
package dev.langchain4j.store.embedding.redis;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.Experimental;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.json.Path2;
import redis.clients.jedis.search.Document;
import redis.clients.jedis.search.FTCreateParams;
import redis.clients.jedis.search.IndexDataType;
import redis.clients.jedis.search.IndexDefinition;
import redis.clients.jedis.search.Query;
import redis.clients.jedis.search.RediSearchUtil;
import redis.clients.jedis.search.SearchResult;
import redis.clients.jedis.search.IndexDefinition.Type;

/**
 * {@link EmbeddingStore} backed by a Redis (RediSearch) index over JSON documents.
 *
 * <p>Each entry is stored as a JSON document under the key {@code schema.prefix() + id}. The
 * document holds the embedding vector, the optional segment text and the segment metadata.
 * Similarity search is performed with a RediSearch KNN query.
 */
public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(RedisEmbeddingStore.class);
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /** Index name used when the caller does not supply one. */
    private static final String DEFAULT_INDEX_NAME = "embedding-index";

    /** Alias under which the KNN query returns the vector distance of each hit. */
    private static final String VECTOR_SCORE_FIELD = "vector_score";

    private final JedisPooled client;
    private final RedisSchema schema;

    /**
     * Connects to Redis and creates the search index if it does not exist yet.
     *
     * @param host         Redis host; must not be blank
     * @param port         Redis port; must not be null
     * @param user         Redis user; when null, the connection is made without credentials
     *                     (the password is then ignored)
     * @param password     Redis password used together with {@code user}
     * @param indexName    index name; defaults to {@code "embedding-index"} when null. Keys are
     *                     stored under the prefix {@code <indexName>:}
     * @param dimension    embedding dimension; must not be null
     * @param metadataKeys metadata keys to index and return with search results
     * @throws RedisRequestFailedException if the index cannot be created
     */
    public RedisEmbeddingStore(String host, Integer port, String user, String password, String indexName, Integer dimension, Collection<String> metadataKeys) {
        ValidationUtils.ensureNotBlank(host, "host");
        ValidationUtils.ensureNotNull(port, "port");
        ValidationUtils.ensureNotNull(dimension, "dimension");
        this.client = user == null ? new JedisPooled(host, port) : new JedisPooled(host, port, user, password);

        // Derive the key prefix from the *resolved* index name. Previously the raw argument was
        // used, so a null indexName produced the literal prefix "null:".
        // NOTE(review): an existing index named "embedding-index" created with the old "null:"
        // prefix is not recreated here (see isIndexExist) and would need a manual migration.
        String resolvedIndexName = Utils.getOrDefault(indexName, DEFAULT_INDEX_NAME);
        this.schema = RedisSchema.builder()
                .indexName(resolvedIndexName)
                .prefix(resolvedIndexName + ":")
                .dimension(dimension)
                .metadataKeys(metadataKeys)
                .build();

        if (!isIndexExist(schema.indexName())) {
            createIndex(schema.indexName());
        }
    }

    /**
     * Stores an embedding under a freshly generated random id.
     *
     * @return the generated id
     */
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    /** Stores an embedding under the given id, without any text segment. */
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /**
     * Stores an embedding together with its text segment under a freshly generated random id.
     *
     * @return the generated id
     */
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Stores an embedding together with its text segment under the given id.
     *
     * @return the id that was passed in
     */
    public String add(String id, Embedding embedding, TextSegment textSegment) {
        addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Stores all embeddings, each under a freshly generated random id.
     *
     * @return the generated ids, in the same order as {@code embeddings}
     */
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream()
                .map(ignored -> Utils.randomUUID())
                .collect(Collectors.toList());
        addAllInternal(ids, embeddings, null);
        return ids;
    }

    /**
     * Stores all embeddings with their text segments, each under a freshly generated random id.
     *
     * @param embeddings embeddings to store
     * @param embedded   text segments; must have the same size as {@code embeddings}
     * @return the generated ids, in the same order as {@code embeddings}
     */
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = embeddings.stream()
                .map(ignored -> Utils.randomUUID())
                .collect(Collectors.toList());
        addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Removes the entry with the given id.
     *
     * @param id entry id; must not be blank
     */
    @Experimental
    public void remove(String id) {
        ValidationUtils.ensureNotBlank(id, "id");
        removeAll(Collections.singletonList(id));
    }

    /**
     * Removes the entries with the given ids.
     *
     * <p>Ids may be passed either bare (as returned by {@code add}) or already carrying the store's
     * key prefix. Bare ids are prefixed before deletion so that they match the stored keys.
     *
     * @param ids ids to remove; must not be null or empty
     * @throws IllegalArgumentException if {@code ids} is null or empty
     */
    @Experimental
    public void removeAll(Collection<String> ids) {
        // An empty DEL is rejected by Redis, so fail fast with a clear message instead.
        if (Utils.isNullOrEmpty(ids)) {
            throw new IllegalArgumentException("ids cannot be null or empty");
        }
        String prefix = schema.prefix();
        // Map each id to its stored key. The previous implementation mutated a lambda-local
        // variable, which had no effect, and then deleted the unprefixed ids.
        String[] keys = ids.stream()
                .map(id -> id.startsWith(prefix) ? id : prefix + id)
                .toArray(String[]::new);
        client.del(keys);
    }

    /**
     * Finds up to {@code maxResults} entries closest to {@code referenceEmbedding} whose relevance
     * score is at least {@code minScore}.
     *
     * @param referenceEmbedding query embedding; must not be null
     * @param maxResults         number of nearest neighbours requested from the KNN query
     * @param minScore           minimum relevance score; matches below it are dropped
     * @return matches ordered by ascending vector distance, i.e. most relevant first
     */
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        ValidationUtils.ensureNotNull(referenceEmbedding, "referenceEmbedding");

        String queryString = String.format("*=>[ KNN %d @%s $BLOB AS %s ]",
                maxResults, schema.vectorFieldName(), VECTOR_SCORE_FIELD);

        List<String> returnFields = new ArrayList<>(schema.metadataKeys());
        returnFields.add(schema.vectorFieldName());
        returnFields.add(schema.scalarFieldName());
        returnFields.add(VECTOR_SCORE_FIELD);

        Query query = new Query(queryString)
                .addParam("BLOB", RediSearchUtil.ToByteArray(referenceEmbedding.vector()))
                .returnFields(returnFields.toArray(new String[0]))
                .setSortBy(VECTOR_SCORE_FIELD, true)
                .dialect(2); // KNN query syntax with $-parameters requires dialect 2

        SearchResult result = client.ftSearch(schema.indexName(), query);
        return toEmbeddingMatch(result.getDocuments(), minScore);
    }

    /**
     * Creates a JSON search index over keys with the schema prefix.
     *
     * @throws RedisRequestFailedException if Redis does not acknowledge with {@code OK}
     */
    private void createIndex(String indexName) {
        String res = client.ftCreate(indexName,
                FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(schema.prefix()),
                schema.toSchemaFields());
        if (!"OK".equals(res)) {
            log.error("create index error, msg={}", res);
            throw new RedisRequestFailedException("create index error, msg=" + res);
        }
    }

    /** Returns whether an index with the given name is listed by {@code FT._LIST}. */
    private boolean isIndexExist(String indexName) {
        Set<String> indexes = client.ftList();
        return indexes.contains(indexName);
    }

    /** Stores a single entry; delegates to the batch path. */
    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        addAllInternal(
                Collections.singletonList(id),
                Collections.singletonList(embedding),
                embedded == null ? null : Collections.singletonList(embedded));
    }

    /**
     * Writes all entries as JSON documents in a single pipeline.
     *
     * @param ids        entry ids (stored under {@code schema.prefix() + id})
     * @param embeddings embeddings; must match {@code ids} in size
     * @param embedded   optional text segments; when non-null must match {@code embeddings} in size
     * @throws RedisRequestFailedException if any pipelined write is not acknowledged with {@code OK}
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("do not add empty embeddings to redis");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");

        List<Object> responses;
        try (Pipeline pipeline = client.pipelined()) {
            for (int i = 0; i < ids.size(); i++) {
                TextSegment textSegment = embedded == null ? null : embedded.get(i);
                String key = schema.prefix() + ids.get(i);
                pipeline.jsonSetWithEscape(key, Path2.of("$"), toFields(embeddings.get(i), textSegment));
            }
            responses = pipeline.syncAndReturnAll();
        }

        Optional<Object> errResponse = responses.stream()
                .filter(response -> !"OK".equals(response))
                .findFirst();
        if (errResponse.isPresent()) {
            log.error("add embedding failed, msg={}", errResponse.get());
            throw new RedisRequestFailedException("add embedding failed, msg=" + errResponse.get());
        }
    }

    /**
     * Builds the JSON document for one entry. The document always holds the vector field. When a
     * segment is present, it also holds the segment text and all of its metadata entries.
     */
    private Map<String, Object> toFields(Embedding embedding, TextSegment textSegment) {
        Map<String, Object> fields = new HashMap<>();
        fields.put(schema.vectorFieldName(), embedding.vector());
        if (textSegment != null) {
            fields.put(schema.scalarFieldName(), textSegment.text());
            fields.putAll(textSegment.metadata().asMap());
        }
        return fields;
    }

    /**
     * Converts search hits to matches and drops those scoring below {@code minScore}.
     *
     * @return a mutable list; empty when {@code documents} is null or empty
     */
    private List<EmbeddingMatch<TextSegment>> toEmbeddingMatch(List<Document> documents, double minScore) {
        if (documents == null || documents.isEmpty()) {
            return new ArrayList<>();
        }
        return documents.stream()
                .map(this::toEmbeddingMatch)
                .filter(embeddingMatch -> embeddingMatch.score() >= minScore)
                .collect(Collectors.toList());
    }

    /** Converts a single search hit into an {@link EmbeddingMatch}. */
    private EmbeddingMatch<TextSegment> toEmbeddingMatch(Document document) {
        // Maps a distance d in [0, 2] to a score in [0, 1] via (2 - d) / 2.
        // NOTE(review): assumes the index uses the COSINE metric (distance range [0, 2]); the
        // metric is configured in RedisSchema, which is not visible here — confirm.
        double score = (2.0 - Double.parseDouble(document.getString(VECTOR_SCORE_FIELD))) / 2.0;
        String id = document.getId().substring(schema.prefix().length());
        TextSegment embedded = toTextSegment(document);
        Embedding embedding = parseEmbedding(document);
        return new EmbeddingMatch<>(score, id, embedding, embedded);
    }

    /**
     * Rebuilds the text segment of a hit.
     *
     * @return the segment, or null when the hit has no text field
     */
    private TextSegment toTextSegment(Document document) {
        String scalarField = schema.scalarFieldName();
        if (!document.hasProperty(scalarField)) {
            return null;
        }
        String text = document.getString(scalarField);
        if (text == null) {
            return null;
        }
        // A plain loop tolerates duplicate configured keys, where Collectors.toMap would throw.
        Map<String, String> metadata = new HashMap<>();
        for (String metadataKey : schema.metadataKeys()) {
            if (document.hasProperty(metadataKey)) {
                metadata.put(metadataKey, document.getString(metadataKey));
            }
        }
        return new TextSegment(text, new Metadata(metadata));
    }

    /**
     * Parses the JSON-encoded vector field of a hit.
     *
     * @throws RedisRequestFailedException if the vector cannot be parsed
     */
    private Embedding parseEmbedding(Document document) {
        try {
            float[] vector = OBJECT_MAPPER.readValue(document.getString(schema.vectorFieldName()), float[].class);
            return new Embedding(vector);
        } catch (IOException e) { // also covers Jackson's JsonProcessingException
            throw new RedisRequestFailedException("failed to parse embedding", e);
        }
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Fluent builder for {@link RedisEmbeddingStore}. */
    public static class Builder {
        private String host;
        private Integer port;
        private String user;
        private String password;
        private String indexName;
        private Integer dimension;
        private Collection<String> metadataKeys = new ArrayList<>();

        public Builder() {
        }

        /** Redis host (required). */
        public Builder host(String host) {
            this.host = host;
            return this;
        }

        /** Redis port (required). */
        public Builder port(Integer port) {
            this.port = port;
            return this;
        }

        /** Redis user. When unset, the connection is made without credentials. */
        public Builder user(String user) {
            this.user = user;
            return this;
        }

        /** Redis password; only used when a user is set. */
        public Builder password(String password) {
            this.password = password;
            return this;
        }

        /** Index name; also determines the key prefix {@code <indexName>:}. */
        public Builder indexName(String indexName) {
            this.indexName = indexName;
            return this;
        }

        /** Embedding dimension (required). */
        public Builder dimension(Integer dimension) {
            this.dimension = dimension;
            return this;
        }

        /** @deprecated use {@link #metadataKeys(Collection)} instead. */
        @Deprecated
        public Builder metadataFieldsName(Collection<String> metadataFieldsName) {
            this.metadataKeys = metadataFieldsName;
            return this;
        }

        /** Metadata keys to index and return with search results. */
        public Builder metadataKeys(Collection<String> metadataKeys) {
            this.metadataKeys = metadataKeys;
            return this;
        }

        /** Builds the store; connects to Redis and creates the index if it is missing. */
        public RedisEmbeddingStore build() {
            return new RedisEmbeddingStore(host, port, user, password, indexName, dimension, metadataKeys);
        }
    }
}
