package com.hw.langchain.embeddings.openai;

import com.google.common.primitives.Doubles;
import com.google.common.primitives.Floats;
import com.hw.langchain.embeddings.base.Embeddings;
import com.hw.langchain.exception.LangChainException;
import com.hw.langchain.utils.Utils;
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.embeddings.Embedding;
import com.hw.openai.entity.embeddings.EmbeddingData;
import com.hw.openai.entity.embeddings.EmbeddingResp;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.IntStream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * {@link Embeddings} implementation backed by the OpenAI embeddings endpoint.
 *
 * <p>Instances are usually created with {@link #builder()} and must be prepared with
 * {@link #init()} before use. {@code init()} resolves unset connection settings from environment
 * variables and creates the {@link OpenAiClient}.
 *
 * <p>Documents are tokenized and split into windows of at most {@code embeddingCtxLength} tokens.
 * Each window is embedded separately. The window vectors belonging to one text are then combined
 * into a single vector: a token-count-weighted average, scaled to unit L2 norm.
 */
public class OpenAIEmbeddings implements Embeddings {

    private static final String DEFAULT_MODEL = "text-embedding-ada-002";
    private static final int DEFAULT_EMBEDDING_CTX_LENGTH = 8191;
    private static final int DEFAULT_CHUNK_SIZE = 1000;
    private static final int DEFAULT_MAX_RETRIES = 6;
    private static final long DEFAULT_REQUEST_TIMEOUT = 16L;

    private OpenAiClient client;
    private String model;
    private String openaiApiBase;
    private String openaiProxy;
    /** Maximum number of tokens sent to the model in a single input window. */
    private int embeddingCtxLength;
    private String openaiApiKey;
    private String openaiApiType;
    private String openaiApiVersion;
    protected String openaiOrganization;
    /** Maximum number of input windows sent in a single embeddings request. */
    private int chunkSize;
    /** Stored for configuration compatibility; not currently read by this class. */
    private int maxRetries;
    protected long requestTimeout;

    /** Fluent builder for {@link OpenAIEmbeddings}; unset numeric/model options take defaults. */
    public static class OpenAIEmbeddingsBuilder {
        private OpenAiClient client;
        private String model = DEFAULT_MODEL;
        private String openaiApiBase;
        private String openaiProxy;
        private int embeddingCtxLength = DEFAULT_EMBEDDING_CTX_LENGTH;
        private String openaiApiKey;
        private String openaiApiType;
        private String openaiApiVersion;
        private String openaiOrganization;
        private int chunkSize = DEFAULT_CHUNK_SIZE;
        private int maxRetries = DEFAULT_MAX_RETRIES;
        private long requestTimeout = DEFAULT_REQUEST_TIMEOUT;

        OpenAIEmbeddingsBuilder() {
        }

        public OpenAIEmbeddingsBuilder client(OpenAiClient openAiClient) {
            this.client = openAiClient;
            return this;
        }

        public OpenAIEmbeddingsBuilder model(String str) {
            this.model = str;
            return this;
        }

        public OpenAIEmbeddingsBuilder openaiApiBase(String str) {
            this.openaiApiBase = str;
            return this;
        }

        public OpenAIEmbeddingsBuilder openaiProxy(String str) {
            this.openaiProxy = str;
            return this;
        }

        public OpenAIEmbeddingsBuilder embeddingCtxLength(int i) {
            this.embeddingCtxLength = i;
            return this;
        }

        public OpenAIEmbeddingsBuilder openaiApiKey(String str) {
            this.openaiApiKey = str;
            return this;
        }

        public OpenAIEmbeddingsBuilder openaiApiType(String str) {
            this.openaiApiType = str;
            return this;
        }

        public OpenAIEmbeddingsBuilder openaiApiVersion(String str) {
            this.openaiApiVersion = str;
            return this;
        }

        public OpenAIEmbeddingsBuilder openaiOrganization(String str) {
            this.openaiOrganization = str;
            return this;
        }

        public OpenAIEmbeddingsBuilder chunkSize(int i) {
            this.chunkSize = i;
            return this;
        }

        public OpenAIEmbeddingsBuilder maxRetries(int i) {
            this.maxRetries = i;
            return this;
        }

        public OpenAIEmbeddingsBuilder requestTimeout(long j) {
            this.requestTimeout = j;
            return this;
        }

        /** Creates the embeddings instance; call {@link OpenAIEmbeddings#init()} on it before use. */
        public OpenAIEmbeddings build() {
            return new OpenAIEmbeddings(client, model, openaiApiBase, openaiProxy, embeddingCtxLength,
                    openaiApiKey, openaiApiType, openaiApiVersion, openaiOrganization,
                    chunkSize, maxRetries, requestTimeout);
        }

        /** Debug representation; the API key is masked so it never ends up in logs. */
        @Override
        public String toString() {
            return "OpenAIEmbeddings.OpenAIEmbeddingsBuilder(client=" + client
                    + ", model=" + model
                    + ", openaiApiBase=" + openaiApiBase
                    + ", openaiProxy=" + openaiProxy
                    + ", embeddingCtxLength=" + embeddingCtxLength
                    + ", openaiApiKey=" + (openaiApiKey == null ? null : "****")
                    + ", openaiApiType=" + openaiApiType
                    + ", openaiApiVersion=" + openaiApiVersion
                    + ", openaiOrganization=" + openaiOrganization
                    + ", chunkSize=" + chunkSize
                    + ", maxRetries=" + maxRetries
                    + ", requestTimeout=" + requestTimeout + ")";
        }
    }

    /**
     * Resolves connection settings and creates the underlying {@link OpenAiClient}.
     *
     * <p>Each unset setting is looked up in the corresponding {@code OPENAI_*} environment
     * variable. All settings except the API key fall back to an empty string.
     *
     * @return this instance, for chaining
     */
    public OpenAIEmbeddings init() {
        this.openaiApiKey = Utils.getOrEnvOrDefault(this.openaiApiKey, "OPENAI_API_KEY", new String[0]);
        this.openaiApiBase = Utils.getOrEnvOrDefault(this.openaiApiBase, "OPENAI_API_BASE", "");
        this.openaiProxy = Utils.getOrEnvOrDefault(this.openaiProxy, "OPENAI_PROXY", "");
        this.openaiOrganization = Utils.getOrEnvOrDefault(this.openaiOrganization, "OPENAI_ORGANIZATION", "");
        this.openaiApiType = Utils.getOrEnvOrDefault(this.openaiApiType, "OPENAI_API_TYPE", "");
        this.openaiApiVersion = Utils.getOrEnvOrDefault(this.openaiApiVersion, "OPENAI_API_VERSION", "");
        this.client = OpenAiClient.builder()
                .openaiApiBase(this.openaiApiBase)
                .openaiApiKey(this.openaiApiKey)
                .openaiApiVersion(this.openaiApiVersion)
                .openaiApiType(this.openaiApiType)
                .openaiOrganization(this.openaiOrganization)
                .openaiProxy(this.openaiProxy)
                .requestTimeout(this.requestTimeout)
                .build()
                .init();
        return this;
    }

    /**
     * Embeds texts of arbitrary length.
     *
     * <p>Each text is tokenized and split into windows of at most {@code embeddingCtxLength}
     * tokens. Windows are embedded in requests of at most {@code chunkSize} inputs, and each
     * text's window vectors are merged by {@link #weightedAverageNormalized}. A text that encodes
     * to zero tokens is embedded as the empty string.
     *
     * @param texts texts to embed
     * @return one embedding per input text, in input order
     * @throws IllegalStateException if {@code embeddingCtxLength} or {@code chunkSize} is not positive
     * @throws LangChainException if no tokenizer is known for the model, or the service returns
     *     fewer embeddings than inputs sent
     */
    private List<List<Float>> getLenSafeEmbeddings(List<String> texts) {
        // Non-positive step sizes would make the windowing/batching loops below never terminate.
        if (embeddingCtxLength <= 0) {
            throw new IllegalStateException("embeddingCtxLength must be positive: " + embeddingCtxLength);
        }
        if (chunkSize <= 0) {
            throw new IllegalStateException("chunkSize must be positive: " + chunkSize);
        }
        Encoding encoding = Encodings.newDefaultEncodingRegistry()
                .getEncodingForModel(model)
                .orElseThrow(() -> new LangChainException("Encoding not found."));

        // Split every text into token windows, remembering which text each window came from.
        List<List<Integer>> tokenWindows = new ArrayList<>();
        List<Integer> windowOwners = new ArrayList<>();
        for (int textIndex = 0; textIndex < texts.size(); textIndex++) {
            String text = texts.get(textIndex);
            if (model.endsWith("001")) {
                // First-generation models handle newlines poorly; mirror upstream LangChain.
                text = text.replace("\n", " ");
            }
            List<Integer> tokens = encoding.encode(text);
            for (int start = 0; start < tokens.size(); start += embeddingCtxLength) {
                tokenWindows.add(tokens.subList(start, Math.min(start + embeddingCtxLength, tokens.size())));
                windowOwners.add(textIndex);
            }
        }

        // Embed the windows, at most chunkSize inputs per request.
        List<List<Float>> windowEmbeddings = new ArrayList<>(tokenWindows.size());
        for (int start = 0; start < tokenWindows.size(); start += chunkSize) {
            List<List<Integer>> batch = tokenWindows.subList(start, Math.min(start + chunkSize, tokenWindows.size()));
            for (EmbeddingData data : embedWithRetry(batch).getData()) {
                windowEmbeddings.add(data.getEmbedding());
            }
        }
        if (windowEmbeddings.size() != tokenWindows.size()) {
            throw new LangChainException("Expected " + tokenWindows.size() + " embeddings but received "
                    + windowEmbeddings.size() + ".");
        }

        // Group window embeddings and their token counts back under the text they belong to.
        List<List<List<Float>>> embeddingsPerText = new ArrayList<>(texts.size());
        List<List<Integer>> weightsPerText = new ArrayList<>(texts.size());
        for (int i = 0; i < texts.size(); i++) {
            embeddingsPerText.add(new ArrayList<>());
            weightsPerText.add(new ArrayList<>());
        }
        for (int w = 0; w < windowOwners.size(); w++) {
            int owner = windowOwners.get(w);
            embeddingsPerText.get(owner).add(windowEmbeddings.get(w));
            weightsPerText.get(owner).add(tokenWindows.get(w).size());
        }

        List<List<Float>> results = new ArrayList<>(texts.size());
        for (int i = 0; i < texts.size(); i++) {
            List<List<Float>> parts = embeddingsPerText.get(i);
            if (parts.isEmpty()) {
                // Zero tokens means no window was produced; embed "" as upstream LangChain does.
                List<Float> emptyEmbedding = embedWithRetry(List.of("")).getData().get(0).getEmbedding();
                results.add(weightedAverageNormalized(List.of(emptyEmbedding), List.of(1)));
            } else {
                results.add(weightedAverageNormalized(parts, weightsPerText.get(i)));
            }
        }
        return results;
    }

    /**
     * Combines vectors into their weighted mean, scaled to unit L2 norm.
     *
     * <p>Each vector is weighted by the matching entry of {@code weights}. A mean with zero norm
     * is returned unscaled instead of producing NaNs.
     *
     * @param vectors non-empty list of equally sized vectors
     * @param weights positive weight per vector, same length as {@code vectors}
     * @return the normalised weighted mean
     * @throws LangChainException if the vectors do not all have the same dimension
     */
    private static List<Float> weightedAverageNormalized(List<List<Float>> vectors, List<Integer> weights) {
        int dimension = vectors.get(0).size();
        double[] accumulator = new double[dimension];
        double totalWeight = 0;
        for (int v = 0; v < vectors.size(); v++) {
            List<Float> vector = vectors.get(v);
            if (vector.size() != dimension) {
                throw new LangChainException("Embedding dimension mismatch: expected " + dimension
                        + " but got " + vector.size() + ".");
            }
            double weight = weights.get(v);
            totalWeight += weight;
            for (int d = 0; d < dimension; d++) {
                accumulator[d] += weight * vector.get(d);
            }
        }
        double squaredNorm = 0;
        for (int d = 0; d < dimension; d++) {
            accumulator[d] /= totalWeight;
            squaredNorm += accumulator[d] * accumulator[d];
        }
        double norm = Math.sqrt(squaredNorm);
        List<Float> normalized = new ArrayList<>(dimension);
        for (double value : accumulator) {
            normalized.add((float) (norm > 0 ? value / norm : value));
        }
        return normalized;
    }

    /**
     * Embeds a single text.
     *
     * <p>Texts whose character length exceeds {@code embeddingCtxLength} go through the
     * token-windowing path. Shorter texts are sent directly, with newlines replaced by spaces
     * for models whose name ends in {@code "001"}.
     *
     * @param str text to embed
     * @return the embedding vector
     */
    public List<Float> embeddingFunc(String str) {
        if (str.length() > this.embeddingCtxLength) {
            return getLenSafeEmbeddings(List.of(str)).get(0);
        }
        if (this.model.endsWith("001")) {
            str = str.replace("\n", " ");
        }
        return embedWithRetry(List.of(str)).getData().get(0).getEmbedding();
    }

    /** Embeds each document, splitting overly long ones into token windows. */
    @Override
    public List<List<Float>> embedDocuments(List<String> list) {
        return getLenSafeEmbeddings(list);
    }

    /** Embeds a single query text. */
    @Override
    public List<Float> embedQuery(String str) {
        return embeddingFunc(str);
    }

    /**
     * Sends one embeddings request for the given inputs using the configured model.
     *
     * <p>NOTE(review): despite the name, this method performs a single call. The configured
     * {@code maxRetries} is not read anywhere in this class; confirm whether the client retries
     * internally.
     *
     * @param list request inputs (strings or token-id lists)
     * @return the raw service response
     */
    public EmbeddingResp embedWithRetry(List<?> list) {
        return this.client.embedding(Embedding.builder().model(this.model).input(list).build());
    }

    /** Returns a builder pre-populated with the default model, context length, chunk size, retries and timeout. */
    public static OpenAIEmbeddingsBuilder builder() {
        return new OpenAIEmbeddingsBuilder();
    }

    /** All-arguments constructor; prefer {@link #builder()}. */
    public OpenAIEmbeddings(OpenAiClient openAiClient, String str, String str2, String str3, int i, String str4,
            String str5, String str6, String str7, int i2, int i3, long j) {
        this.client = openAiClient;
        this.model = str;
        this.openaiApiBase = str2;
        this.openaiProxy = str3;
        this.embeddingCtxLength = i;
        this.openaiApiKey = str4;
        this.openaiApiType = str5;
        this.openaiApiVersion = str6;
        this.openaiOrganization = str7;
        this.chunkSize = i2;
        this.maxRetries = i3;
        this.requestTimeout = j;
    }
}
