package com.hw.langchain.llms.openai;

import com.google.common.base.Preconditions;
import com.hw.langchain.llms.base.BaseLLM;
import com.hw.langchain.schema.Generation;
import com.hw.langchain.schema.LLMResult;
import com.hw.langchain.utils.Resilience4jRetryUtils;
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.completions.Choice;
import com.hw.openai.entity.completions.Completion;
import com.hw.openai.entity.completions.CompletionResp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import okhttp3.Interceptor;

/* loaded from: input_file:com/hw/langchain/llms/openai/BaseOpenAI.class */
/**
 * Base LLM wrapper around the OpenAI text-completion endpoint.
 *
 * <p>Prompts are split into batches of {@code batchSize}. Each batch is sent as one completion
 * request through {@link OpenAiClient#create}, retried with exponential backoff up to
 * {@code maxRetries} times. The returned choices are then regrouped into {@code n} generations
 * per prompt.
 */
public class BaseOpenAI extends BaseLLM {

    private static final String DEFAULT_MODEL = "text-davinci-003";
    private static final float DEFAULT_TEMPERATURE = 0.7f;
    private static final int DEFAULT_MAX_TOKENS = 256;
    private static final float DEFAULT_TOP_P = 1.0f;
    private static final int DEFAULT_N = 1;
    private static final int DEFAULT_BEST_OF = 1;
    private static final int DEFAULT_BATCH_SIZE = 20;
    // NOTE(review): the unit of requestTimeout is not visible in this class — presumably seconds; confirm.
    private static final long DEFAULT_REQUEST_TIMEOUT = 16L;
    private static final int DEFAULT_MAX_RETRIES = 6;

    /** Sentinel for {@code maxTokens} meaning "derive the limit from the single prompt". */
    private static final int MAX_TOKENS_FROM_PROMPT = -1;

    /** Client used to issue completion requests. */
    protected OpenAiClient client;
    /** Model name sent with every completion request. */
    protected String model;
    protected float temperature;
    /** Token limit per completion; {@code -1} means derive it from the (single) prompt. */
    protected int maxTokens;
    protected float topP;
    protected float frequencyPenalty;
    protected float presencePenalty;
    /** Number of completions requested per prompt; also used to regroup the response choices. */
    protected int n;
    // NOTE(review): bestOf is stored but not sent in the request built by innerGenerate.
    protected int bestOf;
    // The connection settings below are not read by this class; presumably consumed by subclasses
    // when building the client — confirm.
    protected String openaiApiKey;
    protected String openaiApiBase;
    protected String openaiApiType;
    protected String openaiApiVersion;
    protected String openaiOrganization;
    protected String openaiProxy;
    protected String proxyUsername;
    protected String proxyPassword;
    /** Maximum number of prompts sent in a single completion request; must be positive. */
    protected int batchSize;
    protected long requestTimeout;
    protected Map<String, Float> logitBias;
    /** Maximum retry attempts passed to the exponential-backoff helper. */
    protected int maxRetries;
    protected boolean stream;
    protected Set<String> allowedSpecial;
    protected Set<String> disallowedSpecial;
    protected List<Interceptor> interceptorList;

    /**
     * Self-typed builder for {@link BaseOpenAI} and its subclasses.
     *
     * <p>Fields with a {@code $set} flag fall back to the class defaults when never set.
     */
    public static abstract class BaseOpenAIBuilder<C extends BaseOpenAI, B extends BaseOpenAIBuilder<C, B>> extends BaseLLM.BaseLLMBuilder<C, B> {
        private OpenAiClient client;
        private boolean model$set;
        private String model$value;
        private boolean temperature$set;
        private float temperature$value;
        private boolean maxTokens$set;
        private int maxTokens$value;
        private boolean topP$set;
        private float topP$value;
        private float frequencyPenalty;
        private float presencePenalty;
        private boolean n$set;
        private int n$value;
        private boolean bestOf$set;
        private int bestOf$value;
        private String openaiApiKey;
        private String openaiApiBase;
        private String openaiApiType;
        private String openaiApiVersion;
        private String openaiOrganization;
        private String openaiProxy;
        private String proxyUsername;
        private String proxyPassword;
        private boolean batchSize$set;
        private int batchSize$value;
        private boolean requestTimeout$set;
        private long requestTimeout$value;
        private Map<String, Float> logitBias;
        private boolean maxRetries$set;
        private int maxRetries$value;
        private boolean stream;
        private Set<String> allowedSpecial;
        private Set<String> disallowedSpecial;
        private List<Interceptor> interceptorList;

        public B client(OpenAiClient openAiClient) {
            this.client = openAiClient;
            return self();
        }

        public B model(String str) {
            this.model$value = str;
            this.model$set = true;
            return self();
        }

        public B temperature(float f) {
            this.temperature$value = f;
            this.temperature$set = true;
            return self();
        }

        public B maxTokens(int i) {
            this.maxTokens$value = i;
            this.maxTokens$set = true;
            return self();
        }

        public B topP(float f) {
            this.topP$value = f;
            this.topP$set = true;
            return self();
        }

        public B frequencyPenalty(float f) {
            this.frequencyPenalty = f;
            return self();
        }

        public B presencePenalty(float f) {
            this.presencePenalty = f;
            return self();
        }

        public B n(int i) {
            this.n$value = i;
            this.n$set = true;
            return self();
        }

        public B bestOf(int i) {
            this.bestOf$value = i;
            this.bestOf$set = true;
            return self();
        }

        public B openaiApiKey(String str) {
            this.openaiApiKey = str;
            return self();
        }

        public B openaiApiBase(String str) {
            this.openaiApiBase = str;
            return self();
        }

        public B openaiApiType(String str) {
            this.openaiApiType = str;
            return self();
        }

        public B openaiApiVersion(String str) {
            this.openaiApiVersion = str;
            return self();
        }

        public B openaiOrganization(String str) {
            this.openaiOrganization = str;
            return self();
        }

        public B openaiProxy(String str) {
            this.openaiProxy = str;
            return self();
        }

        public B proxyUsername(String str) {
            this.proxyUsername = str;
            return self();
        }

        public B proxyPassword(String str) {
            this.proxyPassword = str;
            return self();
        }

        public B batchSize(int i) {
            this.batchSize$value = i;
            this.batchSize$set = true;
            return self();
        }

        public B requestTimeout(long j) {
            this.requestTimeout$value = j;
            this.requestTimeout$set = true;
            return self();
        }

        public B logitBias(Map<String, Float> map) {
            this.logitBias = map;
            return self();
        }

        public B maxRetries(int i) {
            this.maxRetries$value = i;
            this.maxRetries$set = true;
            return self();
        }

        public B stream(boolean z) {
            this.stream = z;
            return self();
        }

        public B allowedSpecial(Set<String> set) {
            this.allowedSpecial = set;
            return self();
        }

        public B disallowedSpecial(Set<String> set) {
            this.disallowedSpecial = set;
            return self();
        }

        public B interceptorList(List<Interceptor> list) {
            this.interceptorList = list;
            return self();
        }

        @Override
        public abstract B self();

        @Override
        public abstract C build();

        /**
         * Describes the builder state for debugging.
         *
         * <p>Each label is paired with its own field. Credentials ({@code openaiApiKey},
         * {@code proxyPassword}) are masked so they never reach logs.
         */
        @Override
        public String toString() {
            return "BaseOpenAI.BaseOpenAIBuilder(super=" + super.toString()
                    + ", client=" + this.client
                    + ", model$value=" + this.model$value
                    + ", temperature$value=" + this.temperature$value
                    + ", maxTokens$value=" + this.maxTokens$value
                    + ", topP$value=" + this.topP$value
                    + ", frequencyPenalty=" + this.frequencyPenalty
                    + ", presencePenalty=" + this.presencePenalty
                    + ", n$value=" + this.n$value
                    + ", bestOf$value=" + this.bestOf$value
                    + ", openaiApiKey=" + maskSecret(this.openaiApiKey)
                    + ", openaiApiBase=" + this.openaiApiBase
                    + ", openaiApiType=" + this.openaiApiType
                    + ", openaiApiVersion=" + this.openaiApiVersion
                    + ", openaiOrganization=" + this.openaiOrganization
                    + ", openaiProxy=" + this.openaiProxy
                    + ", proxyUsername=" + this.proxyUsername
                    + ", proxyPassword=" + maskSecret(this.proxyPassword)
                    + ", batchSize$value=" + this.batchSize$value
                    + ", requestTimeout$value=" + this.requestTimeout$value
                    + ", logitBias=" + this.logitBias
                    + ", maxRetries$value=" + this.maxRetries$value
                    + ", stream=" + this.stream
                    + ", allowedSpecial=" + this.allowedSpecial
                    + ", disallowedSpecial=" + this.disallowedSpecial
                    + ", interceptorList=" + this.interceptorList
                    + ")";
        }

        /** Returns {@code "null"} for an absent secret and a fixed mask otherwise. */
        private static String maskSecret(String secret) {
            return secret == null ? "null" : "****";
        }
    }

    /** Concrete builder returned by {@link BaseOpenAI#builder()}. */
    private static final class BaseOpenAIBuilderImpl extends BaseOpenAIBuilder<BaseOpenAI, BaseOpenAIBuilderImpl> {
        private BaseOpenAIBuilderImpl() {
        }

        @Override
        public BaseOpenAIBuilderImpl self() {
            return this;
        }

        @Override
        public BaseOpenAI build() {
            return new BaseOpenAI(this);
        }
    }

    @Override
    public String llmType() {
        return "openai";
    }

    /**
     * Generates completions for every prompt, batching requests by {@code batchSize}.
     *
     * @param list  prompts to complete
     * @param list2 stop sequences forwarded to the request (may be {@code null})
     * @return one list of {@code n} generations per prompt, plus model metadata
     * @throws IllegalArgumentException if {@code maxTokens == -1} with more than one prompt, or
     *                                  {@code batchSize} is not positive
     * @throws IllegalStateException    if the API returned fewer choices than {@code prompts * n}
     */
    @Override
    protected LLMResult innerGenerate(List<String> list, List<String> list2) {
        // Resolve the token limit per call so the -1 sentinel on the instance is never overwritten.
        int effectiveMaxTokens = resolveMaxTokens(list);
        List<List<String>> batches = getSubPrompts(list);

        Completion completion = Completion.builder()
                .model(this.model)
                .temperature(this.temperature)
                .maxTokens(effectiveMaxTokens)
                .topP(this.topP)
                .frequencyPenalty(this.frequencyPenalty)
                .presencePenalty(this.presencePenalty)
                .n(this.n)
                .logitBias(this.logitBias)
                .stop(list2)
                .build();

        List<Choice> choices = new ArrayList<>();
        for (List<String> batch : batches) {
            // The request object is reused; only the prompt changes between batches. The retry
            // helper runs the lambda before returning, so the mutation cannot race with it.
            completion.setPrompt(batch);
            CompletionResp response = Resilience4jRetryUtils.retryWithExponentialBackoff(
                    this.maxRetries, () -> this.client.create(completion));
            choices.addAll(response.getChoices());
        }
        // TODO: aggregate token usage from the responses; currently always reported empty.
        return createLLMResult(choices, list, Map.of());
    }

    /**
     * Regroups flat API choices into {@code n} generations per prompt.
     *
     * <p>Choices are expected in prompt order: choices {@code [i*n, (i+1)*n)} belong to prompt
     * {@code i}.
     */
    private LLMResult createLLMResult(List<Choice> list, List<String> list2, Map<String, Integer> map) {
        int expected = list2.size() * this.n;
        if (list.size() < expected) {
            throw new IllegalStateException("Expected at least " + expected + " choices for "
                    + list2.size() + " prompt(s) with n=" + this.n + ", but got " + list.size());
        }
        List<List<Generation>> generations = new ArrayList<>(list2.size());
        for (int promptIndex = 0; promptIndex < list2.size(); promptIndex++) {
            List<Choice> perPrompt = list.subList(promptIndex * this.n, (promptIndex + 1) * this.n);
            List<Generation> promptGenerations = new ArrayList<>(perPrompt.size());
            for (Choice choice : perPrompt) {
                Map<String, Object> info = new HashMap<>(2);
                info.put("finishReason", choice.getFinishReason());
                info.put("logprobs", choice.getLogprobs());
                promptGenerations.add(Generation.builder().text(choice.getText()).generationInfo(info).build());
            }
            generations.add(promptGenerations);
        }
        // HashMap rather than Map.of: the model name may be null if explicitly set so.
        Map<String, Object> llmOutput = new HashMap<>(2);
        llmOutput.put("token_usage", map);
        llmOutput.put("model_name", this.model);
        return new LLMResult(generations, llmOutput);
    }

    /**
     * Returns the token limit for this call without mutating {@code this.maxTokens}.
     *
     * @throws IllegalArgumentException if the limit must be derived but there is more than one prompt
     */
    private int resolveMaxTokens(List<String> prompts) {
        if (this.maxTokens != MAX_TOKENS_FROM_PROMPT) {
            return this.maxTokens;
        }
        Preconditions.checkArgument(prompts.size() == 1, "maxTokens set to -1 not supported for multiple inputs.");
        return maxTokensForPrompt(prompts.get(0));
    }

    /**
     * Splits prompts into consecutive batches of at most {@code batchSize} elements.
     *
     * <p>Batches are {@link List#subList} views of the input and are not copies.
     *
     * @throws IllegalArgumentException if {@code batchSize} is not positive (would never advance)
     */
    private List<List<String>> getSubPrompts(List<String> list) {
        Preconditions.checkArgument(this.batchSize > 0, "batchSize must be positive, got %s", this.batchSize);
        int size = list.size();
        List<List<String>> batches = new ArrayList<>();
        int start = 0;
        while (start < size) {
            // Computed as start + min(...) so the end index cannot overflow.
            int end = start + Math.min(this.batchSize, size - start);
            batches.add(list.subList(start, end));
            start = end;
        }
        return batches;
    }

    /**
     * Computes the remaining token budget for a single prompt.
     *
     * <p>TODO: stub — always returns 0; needs a tokenizer and the model's context size.
     */
    private int maxTokensForPrompt(String str) {
        return 0;
    }

    /**
     * Creates an instance from builder state, applying the class defaults to unset fields.
     */
    public BaseOpenAI(BaseOpenAIBuilder<?, ?> baseOpenAIBuilder) {
        super(baseOpenAIBuilder);
        BaseOpenAIBuilder<?, ?> b = baseOpenAIBuilder;
        this.client = b.client;
        this.model = b.model$set ? b.model$value : DEFAULT_MODEL;
        this.temperature = b.temperature$set ? b.temperature$value : DEFAULT_TEMPERATURE;
        this.maxTokens = b.maxTokens$set ? b.maxTokens$value : DEFAULT_MAX_TOKENS;
        this.topP = b.topP$set ? b.topP$value : DEFAULT_TOP_P;
        this.frequencyPenalty = b.frequencyPenalty;
        this.presencePenalty = b.presencePenalty;
        this.n = b.n$set ? b.n$value : DEFAULT_N;
        this.bestOf = b.bestOf$set ? b.bestOf$value : DEFAULT_BEST_OF;
        this.openaiApiKey = b.openaiApiKey;
        this.openaiApiBase = b.openaiApiBase;
        this.openaiApiType = b.openaiApiType;
        this.openaiApiVersion = b.openaiApiVersion;
        this.openaiOrganization = b.openaiOrganization;
        this.openaiProxy = b.openaiProxy;
        this.proxyUsername = b.proxyUsername;
        this.proxyPassword = b.proxyPassword;
        this.batchSize = b.batchSize$set ? b.batchSize$value : DEFAULT_BATCH_SIZE;
        this.requestTimeout = b.requestTimeout$set ? b.requestTimeout$value : DEFAULT_REQUEST_TIMEOUT;
        this.logitBias = b.logitBias;
        this.maxRetries = b.maxRetries$set ? b.maxRetries$value : DEFAULT_MAX_RETRIES;
        this.stream = b.stream;
        this.allowedSpecial = b.allowedSpecial;
        this.disallowedSpecial = b.disallowedSpecial;
        this.interceptorList = b.interceptorList;
    }

    /** Returns a new builder for a plain {@link BaseOpenAI}. */
    public static BaseOpenAIBuilder<?, ?> builder() {
        return new BaseOpenAIBuilderImpl();
    }
}
