package com.hw.langchain.llms.openai;

import com.google.common.base.Preconditions;
import com.hw.langchain.llms.base.BaseLLM;
import com.hw.langchain.schema.Generation;
import com.hw.langchain.schema.LLMResult;
import com.hw.langchain.utils.Utils;
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.chat.ChatChoice;
import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.ChatCompletionResp;
import com.hw.openai.entity.chat.Message;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * LLM wrapper that sends a single prompt, preceded by optional prefix messages, to an
 * {@link OpenAiClient} chat-completion call and returns the content of the first choice.
 *
 * <p>Instances are created through {@link #builder()}. Either supply a client via the builder or
 * call {@link #init()} before generating, since {@link #innerGenerate} uses {@link #client}
 * directly.
 */
public class OpenAIChat extends BaseLLM {
    /** Client used for chat-completion requests; set by the builder or built in {@link #init()}. */
    protected OpenAiClient client;
    /** Model name sent with each request; defaults to {@code "gpt-3.5-turbo"}. */
    protected String model;
    /** Sampling temperature sent with each request; defaults to {@code 0.7}. */
    protected float temperature;
    /** Max-tokens value sent with each request; defaults to {@code 256}. */
    protected int maxTokens;
    /** Top-p value sent with each request; defaults to {@code 1.0}. */
    protected float topP;
    /** Frequency penalty sent with each request; defaults to {@code 0}. */
    protected float frequencyPenalty;
    /** Presence penalty sent with each request; defaults to {@code 0}. */
    protected float presencePenalty;
    /** Number of choices requested; defaults to {@code 1}. Only the first choice is returned. */
    protected int n;
    /** API key; resolved in {@link #init()} from this value or {@code OPENAI_API_KEY}. */
    protected String openaiApiKey;
    /** API base URL; resolved in {@link #init()} from this value or {@code OPENAI_API_BASE}. */
    protected String openaiApiBase;
    /** API type; resolved in {@link #init()} from this value or {@code OPENAI_API_TYPE}. */
    protected String openaiApiType;
    /** API version; resolved in {@link #init()} from this value or {@code OPENAI_API_VERSION}. */
    protected String openaiApiVersion;
    /** Organization; resolved in {@link #init()} from this value or {@code OPENAI_ORGANIZATION}. */
    protected String openaiOrganization;
    /** Proxy; resolved in {@link #init()} from this value or {@code OPENAI_PROXY}. */
    protected String openaiProxy;
    /** Defaults to {@code 6}. NOTE(review): not forwarded to the client by this class. */
    protected int maxRetries;
    /** Messages prepended to every prompt; defaults to an empty list. */
    private List<Message> prefixMessages;
    /** Timeout passed to the client builder; defaults to {@code 16} (unit defined by OpenAiClient — TODO confirm). */
    protected long requestTimeout;
    /** Logit-bias map sent with each request; may be {@code null}. */
    protected Map<String, Float> logitBias;
    /** NOTE(review): stored but not used by this class. */
    protected boolean stream;

    /**
     * Self-typed builder for {@link OpenAIChat} and its subclasses.
     *
     * <p>Fields with a {@code $set} flag fall back to a default in the {@link OpenAIChat}
     * constructor when the corresponding setter was never called. All other fields keep the JVM
     * default ({@code null}, {@code 0} or {@code false}).
     *
     * @param <C> the concrete product type
     * @param <B> the concrete builder type, returned by every setter
     */
    public static abstract class OpenAIChatBuilder<C extends OpenAIChat, B extends OpenAIChatBuilder<C, B>> extends BaseLLM.BaseLLMBuilder<C, B> {
        private OpenAiClient client;
        private boolean model$set;
        private String model$value;
        private boolean temperature$set;
        private float temperature$value;
        private boolean maxTokens$set;
        private int maxTokens$value;
        private boolean topP$set;
        private float topP$value;
        private float frequencyPenalty;
        private float presencePenalty;
        private boolean n$set;
        private int n$value;
        private String openaiApiKey;
        private String openaiApiBase;
        private String openaiApiType;
        private String openaiApiVersion;
        private String openaiOrganization;
        private String openaiProxy;
        private boolean maxRetries$set;
        private int maxRetries$value;
        private boolean prefixMessages$set;
        private List<Message> prefixMessages$value;
        private boolean requestTimeout$set;
        private long requestTimeout$value;
        private Map<String, Float> logitBias;
        private boolean stream;

        public B client(OpenAiClient openAiClient) {
            this.client = openAiClient;
            return self();
        }

        public B model(String str) {
            this.model$value = str;
            this.model$set = true;
            return self();
        }

        public B temperature(float f) {
            this.temperature$value = f;
            this.temperature$set = true;
            return self();
        }

        public B maxTokens(int i) {
            this.maxTokens$value = i;
            this.maxTokens$set = true;
            return self();
        }

        public B topP(float f) {
            this.topP$value = f;
            this.topP$set = true;
            return self();
        }

        public B frequencyPenalty(float f) {
            this.frequencyPenalty = f;
            return self();
        }

        public B presencePenalty(float f) {
            this.presencePenalty = f;
            return self();
        }

        public B n(int i) {
            this.n$value = i;
            this.n$set = true;
            return self();
        }

        public B openaiApiKey(String str) {
            this.openaiApiKey = str;
            return self();
        }

        public B openaiApiBase(String str) {
            this.openaiApiBase = str;
            return self();
        }

        public B openaiApiType(String str) {
            this.openaiApiType = str;
            return self();
        }

        public B openaiApiVersion(String str) {
            this.openaiApiVersion = str;
            return self();
        }

        public B openaiOrganization(String str) {
            this.openaiOrganization = str;
            return self();
        }

        public B openaiProxy(String str) {
            this.openaiProxy = str;
            return self();
        }

        public B maxRetries(int i) {
            this.maxRetries$value = i;
            this.maxRetries$set = true;
            return self();
        }

        public B prefixMessages(List<Message> list) {
            this.prefixMessages$value = list;
            this.prefixMessages$set = true;
            return self();
        }

        public B requestTimeout(long j) {
            this.requestTimeout$value = j;
            this.requestTimeout$set = true;
            return self();
        }

        public B logitBias(Map<String, Float> map) {
            this.logitBias = map;
            return self();
        }

        public B stream(boolean z) {
            this.stream = z;
            return self();
        }

        /** Returns this builder typed as the concrete builder subtype. */
        @Override
        public abstract B self();

        /** Creates the configured product instance. */
        @Override
        public abstract C build();

        /**
         * Describes the builder state. Each label is followed by the value of the field it names.
         * The previous version printed the super-builder string under {@code logitBias} and the
         * logit-bias map under {@code stream}.
         */
        @Override
        public String toString() {
            return "OpenAIChat.OpenAIChatBuilder(super=" + super.toString()
                    + ", client=" + this.client
                    + ", model$value=" + this.model$value
                    + ", temperature$value=" + this.temperature$value
                    + ", maxTokens$value=" + this.maxTokens$value
                    + ", topP$value=" + this.topP$value
                    + ", frequencyPenalty=" + this.frequencyPenalty
                    + ", presencePenalty=" + this.presencePenalty
                    + ", n$value=" + this.n$value
                    + ", openaiApiKey=" + this.openaiApiKey
                    + ", openaiApiBase=" + this.openaiApiBase
                    + ", openaiApiType=" + this.openaiApiType
                    + ", openaiApiVersion=" + this.openaiApiVersion
                    + ", openaiOrganization=" + this.openaiOrganization
                    + ", openaiProxy=" + this.openaiProxy
                    + ", maxRetries$value=" + this.maxRetries$value
                    + ", prefixMessages$value=" + this.prefixMessages$value
                    + ", requestTimeout$value=" + this.requestTimeout$value
                    + ", logitBias=" + this.logitBias
                    + ", stream=" + this.stream
                    + ")";
        }
    }

    /** Concrete builder returned by {@link OpenAIChat#builder()}. */
    private static final class OpenAIChatBuilderImpl extends OpenAIChatBuilder<OpenAIChat, OpenAIChatBuilderImpl> {
        private OpenAIChatBuilderImpl() {
        }

        @Override
        public OpenAIChatBuilderImpl self() {
            return this;
        }

        @Override
        public OpenAIChat build() {
            return new OpenAIChat(this);
        }
    }

    /**
     * Resolves the connection settings and builds {@link #client} from them.
     *
     * <p>Each setting is passed to {@code Utils.getOrEnvOrDefault} together with its
     * {@code OPENAI_*} environment-variable name. All settings except the API key are also given
     * an empty-string fallback. Presumably the explicit value wins over the environment variable
     * — TODO confirm against {@code Utils}.
     *
     * <p>This replaces any client that was supplied via the builder.
     *
     * @return this instance, for chaining
     */
    public OpenAIChat init() {
        this.openaiApiBase = Utils.getOrEnvOrDefault(this.openaiApiBase, "OPENAI_API_BASE", "");
        // No fallback value for the key; behavior when it is absent is defined by Utils.
        this.openaiApiKey = Utils.getOrEnvOrDefault(this.openaiApiKey, "OPENAI_API_KEY", new String[0]);
        this.openaiOrganization = Utils.getOrEnvOrDefault(this.openaiOrganization, "OPENAI_ORGANIZATION", "");
        this.openaiProxy = Utils.getOrEnvOrDefault(this.openaiProxy, "OPENAI_PROXY", "");
        this.openaiApiType = Utils.getOrEnvOrDefault(this.openaiApiType, "OPENAI_API_TYPE", "");
        this.openaiApiVersion = Utils.getOrEnvOrDefault(this.openaiApiVersion, "OPENAI_API_VERSION", "");
        this.client = OpenAiClient.builder()
                .openaiApiBase(this.openaiApiBase)
                .openaiApiKey(this.openaiApiKey)
                .openaiApiVersion(this.openaiApiVersion)
                .openaiApiType(this.openaiApiType)
                .openaiOrganization(this.openaiOrganization)
                .openaiProxy(this.openaiProxy)
                .requestTimeout(this.requestTimeout)
                .build()
                .init();
        return this;
    }

    /** Returns the identifier of this LLM implementation: {@code "openai-chat"}. */
    @Override
    public String llmType() {
        return "openai-chat";
    }

    /**
     * Sends one chat-completion request and wraps the first choice's message content.
     *
     * @param prompts exactly one prompt (see {@link #getChatMessages})
     * @param stop stop sequences forwarded to the request; may be {@code null}
     * @return a result containing one generation list with a single {@link Generation}. The
     *     output map carries {@code token_usage} and {@code model_name}.
     * @throws IllegalArgumentException if {@code prompts} does not contain exactly one element
     * @throws IllegalStateException if the response contains no choices
     */
    @Override
    protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
        ChatCompletion chatCompletion = ChatCompletion.builder()
                .model(this.model)
                .temperature(this.temperature)
                .messages(getChatMessages(prompts))
                .maxTokens(this.maxTokens)
                .topP(this.topP)
                .frequencyPenalty(this.frequencyPenalty)
                .presencePenalty(this.presencePenalty)
                .n(this.n)
                .logitBias(this.logitBias)
                .stop(stop)
                .build();
        ChatCompletionResp response = this.client.create(chatCompletion);

        // Fail with a descriptive error instead of a bare IndexOutOfBounds/NPE on an empty response.
        List<ChatChoice> choices = response.getChoices();
        Preconditions.checkState(choices != null && !choices.isEmpty(),
                "OpenAIChat received a response with no choices from model %s", response.getModel());
        ChatChoice firstChoice = choices.get(0);

        List<List<Generation>> generations = new ArrayList<>();
        generations.add(List.of(Generation.builder().text(firstChoice.getMessage().getContent()).build()));

        Map<String, Object> llmOutput = new HashMap<>();
        llmOutput.put("token_usage", response.getUsage());
        llmOutput.put("model_name", response.getModel());
        return new LLMResult(generations, llmOutput);
    }

    /**
     * Builds the message list for a request: the configured prefix messages followed by the single
     * prompt wrapped via {@code Message.of}.
     *
     * @param prompts must contain exactly one prompt
     * @return a new mutable list; {@link #prefixMessages} is not modified
     * @throws IllegalArgumentException if {@code prompts} does not contain exactly one element
     */
    private List<Message> getChatMessages(List<String> prompts) {
        Preconditions.checkArgument(prompts.size() == 1, "OpenAIChat currently only supports single prompt, got %s", prompts);
        List<Message> messages = new ArrayList<>();
        // prefixMessages can be null if the builder was given prefixMessages(null).
        if (this.prefixMessages != null) {
            messages.addAll(this.prefixMessages);
        }
        messages.add(Message.of(prompts.get(0)));
        return messages;
    }

    private static String $default$model() {
        return "gpt-3.5-turbo";
    }

    private static float $default$temperature() {
        return 0.7f;
    }

    private static int $default$maxTokens() {
        return 256;
    }

    private static float $default$topP() {
        return 1.0f;
    }

    private static int $default$n() {
        return 1;
    }

    private static int $default$maxRetries() {
        return 6;
    }

    private static List<Message> $default$prefixMessages() {
        return new ArrayList<>();
    }

    private static long $default$requestTimeout() {
        return 16L;
    }

    /**
     * Copies builder state into the new instance. Fields whose setter was never called take the
     * corresponding {@code $default$...} value.
     *
     * @param b the populated builder
     */
    protected OpenAIChat(OpenAIChatBuilder<?, ?> b) {
        super(b);
        this.client = b.client;
        this.model = b.model$set ? b.model$value : $default$model();
        this.temperature = b.temperature$set ? b.temperature$value : $default$temperature();
        this.maxTokens = b.maxTokens$set ? b.maxTokens$value : $default$maxTokens();
        this.topP = b.topP$set ? b.topP$value : $default$topP();
        this.frequencyPenalty = b.frequencyPenalty;
        this.presencePenalty = b.presencePenalty;
        this.n = b.n$set ? b.n$value : $default$n();
        this.openaiApiKey = b.openaiApiKey;
        this.openaiApiBase = b.openaiApiBase;
        this.openaiApiType = b.openaiApiType;
        this.openaiApiVersion = b.openaiApiVersion;
        this.openaiOrganization = b.openaiOrganization;
        this.openaiProxy = b.openaiProxy;
        this.maxRetries = b.maxRetries$set ? b.maxRetries$value : $default$maxRetries();
        this.prefixMessages = b.prefixMessages$set ? b.prefixMessages$value : $default$prefixMessages();
        this.requestTimeout = b.requestTimeout$set ? b.requestTimeout$value : $default$requestTimeout();
        this.logitBias = b.logitBias;
        this.stream = b.stream;
    }

    /** Returns a new builder for {@link OpenAIChat}. */
    public static OpenAIChatBuilder<?, ?> builder() {
        return new OpenAIChatBuilderImpl();
    }
}
