/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.store.embedding.redis;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.Experimental;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.redis.RedisRequestFailedException;
import dev.langchain4j.store.embedding.redis.RedisSchema;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.json.Path2;
import redis.clients.jedis.search.Document;
import redis.clients.jedis.search.FTCreateParams;
import redis.clients.jedis.search.IndexDataType;
import redis.clients.jedis.search.IndexDefinition;
import redis.clients.jedis.search.Query;
import redis.clients.jedis.search.RediSearchUtil;
import redis.clients.jedis.search.SearchResult;

/**
 * {@link EmbeddingStore} backed by Redis, using RedisJSON documents and a RediSearch index.
 * <p>
 * Every entry is written as a JSON document under the key {@code schema.prefix() + id}. The document holds the
 * vector and, when a {@link TextSegment} is supplied, the segment text plus its metadata entries. On
 * construction, a RediSearch index over JSON documents with that prefix is created if none with the configured
 * name exists yet. {@link #findRelevant(Embedding, int, double)} runs a KNN query against that index.
 */
public class RedisEmbeddingStore
implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(RedisEmbeddingStore.class);
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    /** Index name used when the caller does not provide one. */
    private static final String DEFAULT_INDEX_NAME = "embedding-index";
    /** Alias under which the KNN query returns each hit's distance. */
    private static final String SCORE_FIELD_NAME = "vector_score";
    /** KNN query: arguments are max results, vector field name and score alias; {@code $BLOB} is a parameter. */
    private static final String KNN_QUERY_TEMPLATE = "*=>[ KNN %d @%s $BLOB AS %s ]";
    private final JedisPooled client;
    private final RedisSchema schema;

    /**
     * Connects to Redis and makes sure the search index exists.
     *
     * @param host         Redis host; must not be blank
     * @param port         Redis port; must not be null
     * @param user         user name, or {@code null} to connect without credentials
     * @param password     password; only used when {@code user} is non-null
     * @param indexName    index name, defaulting to {@code "embedding-index"} when null; the key prefix is the
     *                     effective index name followed by {@code ":"}
     * @param dimension    embedding dimension; must not be null
     * @param metadataKeys metadata keys to store and return with search hits; {@code null} means none
     * @throws RedisRequestFailedException if the index does not exist and creating it fails
     */
    public RedisEmbeddingStore(String host, Integer port, String user, String password, String indexName, Integer dimension, Collection<String> metadataKeys) {
        ValidationUtils.ensureNotBlank(host, "host");
        ValidationUtils.ensureNotNull(port, "port");
        ValidationUtils.ensureNotNull(dimension, "dimension");
        this.client = user == null
                ? new JedisPooled(host, port.intValue())
                : new JedisPooled(host, port.intValue(), user, password);
        // Derive the prefix from the *effective* name: using the raw argument would produce the literal
        // prefix "null:" whenever indexName is omitted.
        String effectiveIndexName = Utils.getOrDefault(indexName, DEFAULT_INDEX_NAME);
        // An explicit null would otherwise reach the schema and fail later when the keys are streamed.
        Collection<String> effectiveMetadataKeys = metadataKeys == null ? new ArrayList<String>() : metadataKeys;
        this.schema = RedisSchema.builder()
                .indexName(effectiveIndexName)
                .prefix(effectiveIndexName + ":")
                .dimension(dimension.intValue())
                .metadataKeys(effectiveMetadataKeys)
                .build();
        if (!this.isIndexExist(this.schema.indexName())) {
            this.createIndex(this.schema.indexName());
        }
    }

    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.add(id, embedding);
        return id;
    }

    @Override
    public void add(String id, Embedding embedding) {
        this.addInternal(id, embedding, null);
    }

    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, textSegment);
        return id;
    }

    @Override
    public String add(String id, Embedding embedding, TextSegment textSegment) {
        this.addInternal(id, embedding, textSegment);
        return id;
    }

    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        this.addAllInternal(ids, embeddings, null);
        return ids;
    }

    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        this.addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Removes the entry with the given id.
     *
     * @param id entry id; must not be blank
     */
    @Override
    @Experimental
    public void remove(String id) {
        ValidationUtils.ensureNotBlank(id, "id");
        this.removeAll(Collections.singletonList(id));
    }

    /**
     * Removes all entries with the given ids.
     * <p>
     * Ids may be given with or without the store's key prefix; the prefix is added where missing, because
     * entries are stored under {@code schema.prefix() + id}.
     *
     * @param ids ids to delete; must not be null. An empty collection is a no-op.
     */
    @Override
    @Experimental
    public void removeAll(Collection<String> ids) {
        ValidationUtils.ensureNotNull(ids, "ids");
        if (ids.isEmpty()) {
            // DEL needs at least one key; sending it with none would be rejected by Redis.
            return;
        }
        String prefix = this.schema.prefix();
        String[] keys = ids.stream()
                .map(id -> id.startsWith(prefix) ? id : prefix + id)
                .toArray(String[]::new);
        this.client.del(keys);
    }

    /**
     * Runs a KNN search for {@code referenceEmbedding} and returns hits scoring at least {@code minScore}.
     *
     * @param referenceEmbedding query vector
     * @param maxResults         the K of the KNN query
     * @param minScore           minimum relevance score a hit must reach to be returned
     * @return matches sorted by ascending distance (i.e. descending score); never null
     */
    @Override
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        String queryString = String.format(KNN_QUERY_TEMPLATE, maxResults, this.schema.vectorFieldName(), SCORE_FIELD_NAME);
        List<String> returnFields = new ArrayList<>(this.schema.metadataKeys());
        returnFields.addAll(Arrays.asList(this.schema.vectorFieldName(), this.schema.scalarFieldName(), SCORE_FIELD_NAME));
        Query query = new Query(queryString)
                .addParam("BLOB", RediSearchUtil.ToByteArray(referenceEmbedding.vector()))
                .returnFields(returnFields.toArray(new String[0]))
                .setSortBy(SCORE_FIELD_NAME, true)
                .dialect(2);
        SearchResult result = this.client.ftSearch(this.schema.indexName(), query);
        return this.toEmbeddingMatch(result.getDocuments(), minScore);
    }

    /**
     * Creates a RediSearch index over JSON documents whose keys start with the schema prefix.
     *
     * @throws RedisRequestFailedException if Redis does not answer {@code "OK"}
     */
    private void createIndex(String indexName) {
        String res = this.client.ftCreate(indexName,
                FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.schema.prefix()),
                this.schema.toSchemaFields());
        if (!"OK".equals(res)) {
            log.error("create index error, msg={}", res);
            throw new RedisRequestFailedException("create index error, msg=" + res);
        }
    }

    /** Returns whether an index called {@code indexName} is listed by {@code FT._LIST}. */
    private boolean isIndexExist(String indexName) {
        Set<String> indexes = this.client.ftList();
        return indexes.contains(indexName);
    }

    /** Single-entry convenience wrapper around {@link #addAllInternal}. */
    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        this.addAllInternal(
                Collections.singletonList(id),
                Collections.singletonList(embedding),
                embedded == null ? null : Collections.singletonList(embedded));
    }

    /**
     * Writes the entries as JSON documents in one pipeline.
     *
     * @param ids        entry ids (without prefix); same size as {@code embeddings}
     * @param embeddings vectors to store
     * @param embedded   segments to store alongside, or {@code null}; if non-null, same size as {@code embeddings}
     * @throws IllegalArgumentException    if the list sizes disagree
     * @throws RedisRequestFailedException if any pipelined write does not answer {@code "OK"}
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("do not add empty embeddings to redis");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");

        List<Object> responses;
        try (Pipeline pipeline = this.client.pipelined()) {
            for (int i = 0; i < ids.size(); i++) {
                TextSegment textSegment = embedded == null ? null : embedded.get(i);
                Map<String, Object> fields = this.toDocumentFields(embeddings.get(i), textSegment);
                pipeline.jsonSetWithEscape(this.schema.prefix() + ids.get(i), Path2.of("$"), fields);
            }
            responses = pipeline.syncAndReturnAll();
        }

        Optional<Object> errResponse = responses.stream().filter(response -> !"OK".equals(response)).findFirst();
        if (errResponse.isPresent()) {
            log.error("add embedding failed, msg={}", errResponse.get());
            throw new RedisRequestFailedException("add embedding failed, msg=" + errResponse.get());
        }
    }

    /**
     * Builds the JSON body for one entry: the vector, plus the segment text and metadata when a segment is given.
     */
    private Map<String, Object> toDocumentFields(Embedding embedding, TextSegment textSegment) {
        Map<String, Object> fields = new HashMap<>();
        if (textSegment != null) {
            // Metadata goes in first so an entry that happens to share the vector or text field name
            // cannot overwrite those reserved fields.
            fields.putAll(textSegment.metadata().asMap());
            fields.put(this.schema.scalarFieldName(), textSegment.text());
        }
        fields.put(this.schema.vectorFieldName(), embedding.vector());
        return fields;
    }

    /**
     * Converts search hits into matches, dropping those scoring below {@code minScore}.
     *
     * @return a mutable list; empty when {@code documents} is null or empty
     */
    private List<EmbeddingMatch<TextSegment>> toEmbeddingMatch(List<Document> documents, double minScore) {
        if (documents == null || documents.isEmpty()) {
            return new ArrayList<>();
        }
        return documents.stream()
                .map(this::toMatch)
                .filter(match -> match.score() >= minScore)
                .collect(Collectors.toList());
    }

    /** Converts a single search hit into an {@link EmbeddingMatch}. */
    private EmbeddingMatch<TextSegment> toMatch(Document document) {
        // NOTE(review): (2 - d) / 2 assumes the index uses the COSINE metric (distance in [0, 2]),
        // mapping it onto a relevance score in [0, 1]; confirm against RedisSchema's distance metric.
        double score = (2.0 - Double.parseDouble(document.getString(SCORE_FIELD_NAME))) / 2.0;
        // Keys are stored as prefix + id, so strip the prefix to recover the caller's id.
        String id = document.getId().substring(this.schema.prefix().length());
        TextSegment textSegment = this.toTextSegment(document);
        Embedding embedding = this.parseEmbedding(document);
        return new EmbeddingMatch<>(score, id, embedding, textSegment);
    }

    /**
     * Rebuilds the text segment of a hit, or returns {@code null} when the hit carries no text field.
     */
    private TextSegment toTextSegment(Document document) {
        String textField = this.schema.scalarFieldName();
        String text = document.hasProperty(textField) ? document.getString(textField) : null;
        if (text == null) {
            return null;
        }
        // Only keys present on the hit are copied. The merge function keeps the first value, so duplicate
        // configured keys do not make toMap throw.
        Map<String, String> metadata = this.schema.metadataKeys().stream()
                .filter(document::hasProperty)
                .collect(Collectors.toMap(Function.identity(), document::getString, (first, second) -> first));
        return new TextSegment(text, new Metadata(metadata));
    }

    /**
     * Parses the JSON array stored in the vector field back into an {@link Embedding}.
     *
     * @throws RedisRequestFailedException if the stored value cannot be parsed
     */
    private Embedding parseEmbedding(Document document) {
        String json = document.getString(this.schema.vectorFieldName());
        try {
            return new Embedding(OBJECT_MAPPER.readValue(json, float[].class));
        } catch (IOException e) {
            // JsonProcessingException is an IOException, so one clause covers every parse failure.
            throw new RedisRequestFailedException("failed to parse embedding", e);
        }
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Fluent builder for {@link RedisEmbeddingStore}; see the constructor for the meaning of each setting. */
    public static class Builder {
        private String host;
        private Integer port;
        private String user;
        private String password;
        private String indexName;
        private Integer dimension;
        private Collection<String> metadataKeys = new ArrayList<String>();

        public Builder host(String host) {
            this.host = host;
            return this;
        }

        public Builder port(Integer port) {
            this.port = port;
            return this;
        }

        public Builder user(String user) {
            this.user = user;
            return this;
        }

        public Builder password(String password) {
            this.password = password;
            return this;
        }

        public Builder indexName(String indexName) {
            this.indexName = indexName;
            return this;
        }

        public Builder dimension(Integer dimension) {
            this.dimension = dimension;
            return this;
        }

        /**
         * @deprecated use {@link #metadataKeys(Collection)} instead; this is an alias for it.
         */
        @Deprecated
        public Builder metadataFieldsName(Collection<String> metadataFieldsName) {
            this.metadataKeys = metadataFieldsName;
            return this;
        }

        public Builder metadataKeys(Collection<String> metadataKeys) {
            this.metadataKeys = metadataKeys;
            return this;
        }

        /**
         * Creates the store, connecting to Redis and creating the index if needed.
         */
        public RedisEmbeddingStore build() {
            return new RedisEmbeddingStore(this.host, this.port, this.user, this.password, this.indexName, this.dimension, this.metadataKeys);
        }
    }
}

