package dev.langchain4j.store.embedding.redis;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.Experimental;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.json.Path2;
import redis.clients.jedis.search.Document;
import redis.clients.jedis.search.FTCreateParams;
import redis.clients.jedis.search.IndexDataType;
import redis.clients.jedis.search.IndexDefinition;
import redis.clients.jedis.search.Query;
import redis.clients.jedis.search.RediSearchUtil;

/* loaded from: input_file:dev/langchain4j/store/embedding/redis/RedisEmbeddingStore.class */
/**
 * {@link EmbeddingStore} implementation that keeps embeddings in Redis.
 *
 * <p>Each entry is written as a JSON document under the key {@code schema.prefix() + id}. The
 * document holds the vector, optionally the segment text, and the segment's metadata entries.
 * Lookups run a KNN query against the search index named by the schema; the index is created
 * in the constructor when it is not already listed by the server.
 */
public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(RedisEmbeddingStore.class);
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /** Index name used when the caller does not supply one. */
    private static final String DEFAULT_INDEX_NAME = "embedding-index";

    /** Alias under which the KNN query exposes the computed score of each hit. */
    private static final String SCORE_FIELD_NAME = "vector_score";

    /** Reply expected from Redis for a successful write or index creation. */
    private static final String OK_RESPONSE = "OK";

    private final JedisPooled client;
    private final RedisSchema schema;

    /** Fluent builder for {@link RedisEmbeddingStore}. */
    public static class Builder {
        private String host;
        private Integer port;
        private String user;
        private String password;
        private String indexName;
        private Integer dimension;
        private Collection<String> metadataKeys = new ArrayList<>();

        /** Sets the Redis host (required, must not be blank). */
        public Builder host(String host) {
            this.host = host;
            return this;
        }

        /** Sets the Redis port (required). */
        public Builder port(Integer port) {
            this.port = port;
            return this;
        }

        /** Sets the user name; when left {@code null} the client connects without credentials. */
        public Builder user(String user) {
            this.user = user;
            return this;
        }

        /** Sets the password; only used when a user name is set. */
        public Builder password(String password) {
            this.password = password;
            return this;
        }

        /** Sets the search index name; {@code "embedding-index"} is used when {@code null}. */
        public Builder indexName(String indexName) {
            this.indexName = indexName;
            return this;
        }

        /** Sets the embedding dimension (required). */
        public Builder dimension(Integer dimension) {
            this.dimension = dimension;
            return this;
        }

        /**
         * Sets the metadata keys handed to the schema.
         *
         * @deprecated use {@link #metadataKeys(Collection)} instead
         */
        @Deprecated
        public Builder metadataFieldsName(Collection<String> metadataFieldsName) {
            return metadataKeys(metadataFieldsName);
        }

        /** Sets the metadata keys handed to the schema and requested back on search. */
        public Builder metadataKeys(Collection<String> metadataKeys) {
            this.metadataKeys = metadataKeys;
            return this;
        }

        /** Builds the store; this connects to Redis and creates the index if it is missing. */
        public RedisEmbeddingStore build() {
            return new RedisEmbeddingStore(host, port, user, password, indexName, dimension, metadataKeys);
        }
    }

    /**
     * Creates the store, connects to Redis and makes sure the search index exists.
     *
     * @param host         Redis host, must not be blank
     * @param port         Redis port, must not be null
     * @param user         user name, or {@code null} to connect without credentials
     * @param password     password, only used when {@code user} is not null
     * @param indexName    search index name, {@code "embedding-index"} when null
     * @param dimension    embedding dimension, must not be null
     * @param metadataKeys metadata keys handed to the schema
     * @throws RedisRequestFailedException if the index has to be created and Redis does not answer "OK"
     */
    public RedisEmbeddingStore(String host, Integer port, String user, String password, String indexName, Integer dimension, Collection<String> metadataKeys) {
        ValidationUtils.ensureNotBlank(host, "host");
        ValidationUtils.ensureNotNull(port, "port");
        ValidationUtils.ensureNotNull(dimension, "dimension");
        this.client = user == null
                ? new JedisPooled(host, port.intValue())
                : new JedisPooled(host, port.intValue(), user, password);
        // NOTE(review): the key prefix is derived from the raw indexName, so a null indexName
        // yields the prefix "null:" even though the index itself falls back to the default name.
        // Left as-is because changing it would orphan documents already stored under that prefix.
        this.schema = RedisSchema.builder()
                .indexName(Utils.getOrDefault(indexName, DEFAULT_INDEX_NAME))
                .prefix(indexName + ":")
                .dimension(dimension.intValue())
                .metadataKeys(metadataKeys)
                .build();
        try {
            if (!isIndexExist(this.schema.indexName())) {
                createIndex(this.schema.indexName());
            }
        } catch (RuntimeException e) {
            // The instance never escapes when construction fails, so release the pool here.
            this.client.close();
            throw e;
        }
    }

    /** Stores the embedding under a freshly generated id and returns that id. */
    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    /** Stores the embedding under the given id, without any text segment. */
    @Override
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /** Stores the embedding together with its text segment under a generated id and returns that id. */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /** Stores the embedding together with its text segment under the given id and returns that id. */
    @Override
    public String add(String id, Embedding embedding, TextSegment textSegment) {
        addInternal(id, embedding, textSegment);
        return id;
    }

    /** Stores all embeddings under generated ids and returns the ids in input order. */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = generateIds(embeddings);
        addAllInternal(ids, embeddings, null);
        return ids;
    }

    /**
     * Stores all embeddings with their text segments under generated ids.
     *
     * @param embeddings   embeddings to store
     * @param textSegments segments matching {@code embeddings} position by position
     * @return the generated ids in input order
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> textSegments) {
        List<String> ids = generateIds(embeddings);
        addAllInternal(ids, embeddings, textSegments);
        return ids;
    }

    /** Removes the entry with the given id; the id must not be blank. */
    @Override
    @Experimental
    public void remove(String id) {
        ValidationUtils.ensureNotBlank(id, "id");
        removeAll(Collections.singletonList(id));
    }

    /**
     * Removes the entries with the given ids.
     *
     * <p>Entries are stored under {@code schema.prefix() + id}, so each id is turned into its
     * Redis key before deletion. Ids that already start with the prefix are passed through
     * unchanged. An empty collection is a no-op.
     *
     * @param ids ids of the entries to remove, must not be null
     */
    @Override
    @Experimental
    public void removeAll(Collection<String> ids) {
        ValidationUtils.ensureNotNull(ids, "ids");
        if (ids.isEmpty()) {
            // DEL needs at least one key; there is nothing to delete anyway.
            return;
        }
        String prefix = this.schema.prefix();
        String[] keys = ids.stream()
                .map(id -> id.startsWith(prefix) ? id : prefix + id)
                .toArray(String[]::new);
        this.client.del(keys);
    }

    /**
     * Runs a KNN search for the given embedding and keeps the hits scoring at least {@code minScore}.
     *
     * @param referenceEmbedding embedding to search with
     * @param maxResults         number of neighbours requested from Redis
     * @param minScore           lower bound (inclusive) for the score of returned matches
     * @return matches in the order returned by Redis (ascending {@code vector_score})
     */
    @Override
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        List<String> returnFields = new ArrayList<>(this.schema.metadataKeys());
        returnFields.addAll(Arrays.asList(this.schema.vectorFieldName(), this.schema.scalarFieldName(), SCORE_FIELD_NAME));

        String knnQuery = String.format("*=>[ KNN %d @%s $BLOB AS %s ]", maxResults, this.schema.vectorFieldName(), SCORE_FIELD_NAME);
        Query query = new Query(knnQuery)
                .addParam("BLOB", RediSearchUtil.ToByteArray(referenceEmbedding.vector()))
                .returnFields(returnFields.toArray(new String[0]))
                .setSortBy(SCORE_FIELD_NAME, true)
                .dialect(2);

        return toEmbeddingMatch(this.client.ftSearch(this.schema.indexName(), query).getDocuments(), minScore);
    }

    /**
     * Creates a JSON search index restricted to this store's key prefix.
     *
     * @throws RedisRequestFailedException if Redis answers with anything other than "OK"
     */
    private void createIndex(String indexName) {
        String response = this.client.ftCreate(
                indexName,
                FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.schema.prefix()),
                this.schema.toSchemaFields());
        if (OK_RESPONSE.equals(response)) {
            return;
        }
        log.error("create index error, msg={}", response);
        throw new RedisRequestFailedException("create index error, msg=" + response);
    }

    /** Tells whether the server already lists an index with the given name. */
    private boolean isIndexExist(String indexName) {
        return this.client.ftList().contains(indexName);
    }

    /** Produces one random id per embedding, in input order. */
    private static List<String> generateIds(List<Embedding> embeddings) {
        return embeddings.stream()
                .map(ignored -> Utils.randomUUID())
                .collect(Collectors.toList());
    }

    /** Single-entry convenience wrapper around {@link #addAllInternal(List, List, List)}. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAllInternal(
                Collections.singletonList(id),
                Collections.singletonList(embedding),
                textSegment == null ? null : Collections.singletonList(textSegment));
    }

    /**
     * Writes all entries to Redis through one pipeline.
     *
     * @param ids          ids, one per embedding
     * @param embeddings   embeddings to store
     * @param textSegments segments matching {@code embeddings} position by position, or {@code null}
     * @throws RedisRequestFailedException if any write is answered with something other than "OK"
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("do not add empty embeddings to redis");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(textSegments == null || embeddings.size() == textSegments.size(), "embeddings size is not equal to embedded size");

        List<Object> responses;
        try (Pipeline pipeline = this.client.pipelined()) {
            for (int i = 0; i < ids.size(); i++) {
                TextSegment textSegment = textSegments == null ? null : textSegments.get(i);
                Map<String, Object> fields = new HashMap<>();
                fields.put(this.schema.vectorFieldName(), embeddings.get(i).vector());
                if (textSegment != null) {
                    fields.put(this.schema.scalarFieldName(), textSegment.text());
                    fields.putAll(textSegment.metadata().asMap());
                }
                pipeline.jsonSetWithEscape(this.schema.prefix() + ids.get(i), Path2.of("$"), fields);
            }
            responses = pipeline.syncAndReturnAll();
        }

        Optional<Object> failure = responses.stream()
                .filter(response -> !OK_RESPONSE.equals(response))
                .findAny();
        if (failure.isPresent()) {
            log.error("add embedding failed, msg={}", failure.get());
            throw new RedisRequestFailedException("add embedding failed, msg=" + failure.get());
        }
    }

    /**
     * Converts search hits into matches and drops those scoring below {@code minScore}.
     *
     * @return a list of matches; empty when {@code documents} is null or empty
     */
    private List<EmbeddingMatch<TextSegment>> toEmbeddingMatch(List<Document> documents, double minScore) {
        if (documents == null || documents.isEmpty()) {
            return new ArrayList<>();
        }
        return documents.stream()
                .map(this::toMatch)
                .filter(match -> match.score() >= minScore)
                .collect(Collectors.toList());
    }

    /**
     * Converts one search hit into an {@link EmbeddingMatch}.
     *
     * @throws RedisRequestFailedException if the stored vector cannot be parsed as JSON
     */
    private EmbeddingMatch<TextSegment> toMatch(Document document) {
        // Score is (2 - vector_score) / 2; presumably vector_score is a cosine distance
        // in [0, 2], which would map to [0, 1] - TODO confirm against the index definition.
        double score = (2.0d - Double.parseDouble(document.getString(SCORE_FIELD_NAME))) / 2.0d;
        // Strip the key prefix to get back the id handed out by add(...).
        String id = document.getId().substring(this.schema.prefix().length());

        String text = document.hasProperty(this.schema.scalarFieldName())
                ? document.getString(this.schema.scalarFieldName())
                : null;
        TextSegment textSegment = null;
        if (text != null) {
            Map<String, String> metadata = this.schema.metadataKeys().stream()
                    .filter(document::hasProperty)
                    .collect(Collectors.toMap(Function.identity(), document::getString));
            textSegment = new TextSegment(text, new Metadata(metadata));
        }

        try {
            float[] vector = OBJECT_MAPPER.readValue(document.getString(this.schema.vectorFieldName()), float[].class);
            return new EmbeddingMatch<>(score, id, new Embedding(vector), textSegment);
        } catch (JsonProcessingException e) {
            // Must come before IOException: JsonProcessingException is a subclass of it.
            throw new RedisRequestFailedException("failed to parse embedding", e);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }
}
