package com.hw.langchain.chat.models.openai;

import com.hw.langchain.chat.models.base.BaseChatModel;
import com.hw.langchain.schema.BaseMessage;
import com.hw.langchain.schema.ChatGeneration;
import com.hw.langchain.schema.ChatResult;
import com.hw.langchain.utils.Resilience4jRetryUtils;
import com.hw.langchain.utils.Utils;
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.ChatCompletionResp;
import com.hw.openai.entity.chat.Message;
import com.hw.openai.entity.completions.Usage;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import okhttp3.Interceptor;

/* loaded from: input_file:com/hw/langchain/chat/models/openai/ChatOpenAI.class */
/**
 * Chat model that sends conversations to the OpenAI chat-completions API through an
 * {@link OpenAiClient}.
 *
 * <p>Instances are created with {@link #builder()}. {@link #init()} resolves connection settings
 * (falling back to {@code OPENAI_*} environment variables via {@code Utils.getOrEnvOrDefault}),
 * validates {@code n}/{@code stream}, and builds the underlying client, replacing any client that
 * was supplied through the builder.
 */
public class ChatOpenAI extends BaseChatModel {

    /** Client used to issue requests; (re)created by {@link #init()}. */
    protected OpenAiClient client;

    /** Model name sent with each request; defaults to {@code "gpt-3.5-turbo"}. */
    protected String model;

    /** Sampling temperature sent with each request; defaults to {@code 0.7}. */
    protected float temperature;

    /**
     * Additional model parameters; defaults to an empty map.
     *
     * <p>NOTE(review): this map is not forwarded by {@link #innerGenerate}; confirm whether
     * {@code ChatCompletion} should receive these values.
     */
    protected Map<String, Object> modelKwargs;

    /** API key; {@link #init()} falls back to the {@code OPENAI_API_KEY} environment variable. */
    protected String openaiApiKey;

    /** API base URL; {@link #init()} falls back to {@code OPENAI_API_BASE}, else {@code ""}. */
    protected String openaiApiBase;

    /** API type; {@link #init()} falls back to {@code OPENAI_API_TYPE}, else {@code ""}. */
    protected String openaiApiType;

    /** API version; {@link #init()} falls back to {@code OPENAI_API_VERSION}, else {@code ""}. */
    protected String openaiApiVersion;

    /** Organization id; {@link #init()} falls back to {@code OPENAI_ORGANIZATION}, else {@code ""}. */
    protected String openaiOrganization;

    /** Proxy setting; {@link #init()} falls back to {@code OPENAI_PROXY}, else {@code ""}. */
    protected String openaiProxy;

    /** Timeout passed to {@link OpenAiClient}; defaults to 16 (unit defined by OpenAiClient — TODO confirm). */
    protected long requestTimeout;

    /** Retry budget handed to {@code Resilience4jRetryUtils.retryWithExponentialBackoff}; defaults to 6. */
    protected int maxRetries;

    /** Whether the request asks for a streamed response; must not be combined with {@code n > 1}. */
    protected boolean stream;

    /** Number of completions requested per call; must be at least 1 and defaults to 1. */
    protected int n;

    /** Optional cap on generated tokens; {@code null} leaves it unset on the request. */
    protected Integer maxTokens;

    /** OkHttp interceptors installed on the client built by {@link #init()}. */
    private List<Interceptor> interceptorList;

    /**
     * Fluent builder for {@link ChatOpenAI} and its subclasses.
     *
     * <p>Fields that have a default track whether they were explicitly set ({@code xxx$set}) so
     * the target constructor can fall back to the matching {@code $default$xxx()} value.
     *
     * @param <C> concrete model type produced by {@link #build()}
     * @param <B> concrete builder type returned by the fluent setters
     */
    public static abstract class ChatOpenAIBuilder<C extends ChatOpenAI, B extends ChatOpenAIBuilder<C, B>>
            extends BaseChatModel.BaseChatModelBuilder<C, B> {
        private OpenAiClient client;
        private boolean model$set;
        private String model$value;
        private boolean temperature$set;
        private float temperature$value;
        private boolean modelKwargs$set;
        private Map<String, Object> modelKwargs$value;
        private String openaiApiKey;
        private String openaiApiBase;
        private String openaiApiType;
        private String openaiApiVersion;
        private String openaiOrganization;
        private String openaiProxy;
        private boolean requestTimeout$set;
        private long requestTimeout$value;
        private boolean maxRetries$set;
        private int maxRetries$value;
        private boolean stream;
        private boolean n$set;
        private int n$value;
        private Integer maxTokens;
        private List<Interceptor> interceptorList;

        /** Sets the client; note that {@link ChatOpenAI#init()} replaces it. */
        public B client(OpenAiClient openAiClient) {
            this.client = openAiClient;
            return self();
        }

        /** Sets the model name (default {@code "gpt-3.5-turbo"}). */
        public B model(String str) {
            this.model$value = str;
            this.model$set = true;
            return self();
        }

        /** Sets the sampling temperature (default {@code 0.7}). */
        public B temperature(float f) {
            this.temperature$value = f;
            this.temperature$set = true;
            return self();
        }

        /** Sets the extra model parameters (default: empty map). */
        public B modelKwargs(Map<String, Object> map) {
            this.modelKwargs$value = map;
            this.modelKwargs$set = true;
            return self();
        }

        /** Sets the API key. */
        public B openaiApiKey(String str) {
            this.openaiApiKey = str;
            return self();
        }

        /** Sets the API base URL. */
        public B openaiApiBase(String str) {
            this.openaiApiBase = str;
            return self();
        }

        /** Sets the API type. */
        public B openaiApiType(String str) {
            this.openaiApiType = str;
            return self();
        }

        /** Sets the API version. */
        public B openaiApiVersion(String str) {
            this.openaiApiVersion = str;
            return self();
        }

        /** Sets the organization id. */
        public B openaiOrganization(String str) {
            this.openaiOrganization = str;
            return self();
        }

        /** Sets the proxy. */
        public B openaiProxy(String str) {
            this.openaiProxy = str;
            return self();
        }

        /** Sets the request timeout (default 16). */
        public B requestTimeout(long j) {
            this.requestTimeout$value = j;
            this.requestTimeout$set = true;
            return self();
        }

        /** Sets the retry budget (default 6). */
        public B maxRetries(int i) {
            this.maxRetries$value = i;
            this.maxRetries$set = true;
            return self();
        }

        /** Enables or disables streaming. */
        public B stream(boolean z) {
            this.stream = z;
            return self();
        }

        /** Sets the number of completions per call (default 1). */
        public B n(int i) {
            this.n$value = i;
            this.n$set = true;
            return self();
        }

        /** Sets the maximum number of generated tokens. */
        public B maxTokens(Integer num) {
            this.maxTokens = num;
            return self();
        }

        /** Sets the OkHttp interceptors for the client built by {@link ChatOpenAI#init()}. */
        public B interceptorList(List<Interceptor> list) {
            this.interceptorList = list;
            return self();
        }

        /* JADX INFO: Access modifiers changed from: protected */
        @Override // com.hw.langchain.chat.models.base.BaseChatModel.BaseChatModelBuilder
        public abstract B self();

        @Override // com.hw.langchain.chat.models.base.BaseChatModel.BaseChatModelBuilder
        public abstract C build();

        /**
         * Describes the builder state for debugging.
         *
         * <p>Each label is paired with its own field; the decompiled version was shifted by one from
         * {@code maxRetries$value} onward. The API key is masked so it never reaches logs.
         */
        @Override // com.hw.langchain.chat.models.base.BaseChatModel.BaseChatModelBuilder
        public String toString() {
            return "ChatOpenAI.ChatOpenAIBuilder(super=" + super.toString()
                    + ", client=" + this.client
                    + ", model$value=" + this.model$value
                    + ", temperature$value=" + this.temperature$value
                    + ", modelKwargs$value=" + this.modelKwargs$value
                    + ", openaiApiKey=" + maskSecret(this.openaiApiKey)
                    + ", openaiApiBase=" + this.openaiApiBase
                    + ", openaiApiType=" + this.openaiApiType
                    + ", openaiApiVersion=" + this.openaiApiVersion
                    + ", openaiOrganization=" + this.openaiOrganization
                    + ", openaiProxy=" + this.openaiProxy
                    + ", requestTimeout$value=" + this.requestTimeout$value
                    + ", maxRetries$value=" + this.maxRetries$value
                    + ", stream=" + this.stream
                    + ", n$value=" + this.n$value
                    + ", maxTokens=" + this.maxTokens
                    + ", interceptorList=" + this.interceptorList
                    + ")";
        }

        /** Returns {@code null} for an absent secret, otherwise a fixed mask that hides its content. */
        private static String maskSecret(String secret) {
            return secret == null ? null : "****";
        }
    }

    /* loaded from: input_file:com/hw/langchain/chat/models/openai/ChatOpenAI$ChatOpenAIBuilderImpl.class */
    /** Concrete builder returned by {@link ChatOpenAI#builder()}. */
    private static final class ChatOpenAIBuilderImpl extends ChatOpenAIBuilder<ChatOpenAI, ChatOpenAIBuilderImpl> {
        private ChatOpenAIBuilderImpl() {
        }

        /* JADX INFO: Access modifiers changed from: protected */
        @Override // com.hw.langchain.chat.models.openai.ChatOpenAI.ChatOpenAIBuilder, com.hw.langchain.chat.models.base.BaseChatModel.BaseChatModelBuilder
        public ChatOpenAIBuilderImpl self() {
            return this;
        }

        @Override // com.hw.langchain.chat.models.openai.ChatOpenAI.ChatOpenAIBuilder, com.hw.langchain.chat.models.base.BaseChatModel.BaseChatModelBuilder
        public ChatOpenAI build() {
            return new ChatOpenAI(this);
        }
    }

    /**
     * Validates the configuration, resolves connection settings and builds the API client.
     *
     * @return this instance, for chaining
     * @throws IllegalArgumentException if {@code n < 1}, or if {@code n > 1} while streaming
     */
    public ChatOpenAI init() {
        // Check local invariants first so an invalid config fails before env lookups and client setup.
        if (this.n < 1) {
            throw new IllegalArgumentException("n must be at least 1.");
        }
        if (this.n > 1 && this.stream) {
            throw new IllegalArgumentException("n must be 1 when streaming.");
        }

        // The API key has no fallback default (empty varargs); the other settings default to "".
        this.openaiApiKey = Utils.getOrEnvOrDefault(this.openaiApiKey, "OPENAI_API_KEY", new String[0]);
        this.openaiOrganization = Utils.getOrEnvOrDefault(this.openaiOrganization, "OPENAI_ORGANIZATION", "");
        this.openaiApiBase = Utils.getOrEnvOrDefault(this.openaiApiBase, "OPENAI_API_BASE", "");
        this.openaiProxy = Utils.getOrEnvOrDefault(this.openaiProxy, "OPENAI_PROXY", "");
        this.openaiApiType = Utils.getOrEnvOrDefault(this.openaiApiType, "OPENAI_API_TYPE", "");
        this.openaiApiVersion = Utils.getOrEnvOrDefault(this.openaiApiVersion, "OPENAI_API_VERSION", "");

        this.client = OpenAiClient.builder()
                .openaiApiBase(this.openaiApiBase)
                .openaiApiKey(this.openaiApiKey)
                .openaiApiVersion(this.openaiApiVersion)
                .openaiApiType(this.openaiApiType)
                .openaiOrganization(this.openaiOrganization)
                .openaiProxy(this.openaiProxy)
                .requestTimeout(this.requestTimeout)
                .interceptorList(this.interceptorList)
                .build()
                .init();
        return this;
    }

    /**
     * Merges per-call LLM outputs by summing their {@code token_usage} entries.
     *
     * <p>Null outputs, outputs without a {@code token_usage} entry, and null token counts are
     * skipped rather than raising an NPE. If nothing remains, an empty {@code new Usage()} is used.
     *
     * @param list per-call outputs, each expected to hold a {@link Usage} under {@code "token_usage"}
     * @return a map with {@code "token_usage"} (the summed usage) and {@code "model_name"} (this model)
     */
    @Override // com.hw.langchain.chat.models.base.BaseChatModel
    public Map<String, Object> combineLlmOutputs(List<Map<String, Object>> list) {
        List<Map<String, Object>> outputs = list == null ? List.of() : list;
        Usage total = outputs.stream()
                .filter(Objects::nonNull)
                .map(output -> (Usage) output.get("token_usage"))
                .filter(Objects::nonNull)
                .reduce(ChatOpenAI::sumUsage)
                .orElse(new Usage());

        // HashMap instead of Map.of: model may be null, which Map.of rejects.
        Map<String, Object> combined = new HashMap<>();
        combined.put("token_usage", total);
        combined.put("model_name", this.model);
        return combined;
    }

    /** Adds two usages field by field, treating null counts as zero. */
    private static Usage sumUsage(Usage a, Usage b) {
        return new Usage(
                sumTokens(a.getPromptTokens(), b.getPromptTokens()),
                sumTokens(a.getCompletionTokens(), b.getCompletionTokens()),
                sumTokens(a.getTotalTokens(), b.getTotalTokens()));
    }

    /** Null-safe addition of two boxed token counts. */
    private static Long sumTokens(Long x, Long y) {
        return (x == null ? 0L : x) + (y == null ? 0L : y);
    }

    /**
     * Sends one chat-completion request, retrying with exponential backoff up to {@code maxRetries}.
     *
     * @param list  conversation messages to send
     * @param list2 stop sequences passed through as the request's {@code stop}
     * @return the converted chat result
     */
    @Override // com.hw.langchain.chat.models.base.BaseChatModel
    public ChatResult innerGenerate(List<BaseMessage> list, List<String> list2) {
        // NOTE(review): modelKwargs is not applied to the request here.
        ChatCompletion chatCompletion = ChatCompletion.builder()
                .model(this.model)
                .temperature(this.temperature)
                .messages(convertMessages(list))
                .maxTokens(this.maxTokens)
                .stream(this.stream)
                .n(Integer.valueOf(this.n))
                .stop(list2)
                .build();
        ChatCompletionResp response = (ChatCompletionResp) Resilience4jRetryUtils.retryWithExponentialBackoff(
                this.maxRetries, () -> this.client.create(chatCompletion));
        return createChatResult(response);
    }

    /** Converts LangChain messages to OpenAI messages via {@code OpenAI.convertLangChainToOpenAI}. */
    public List<Message> convertMessages(List<BaseMessage> list) {
        return list.stream().map(OpenAI::convertLangChainToOpenAI).toList();
    }

    /**
     * Converts an API response into a {@link ChatResult}: one {@link ChatGeneration} per choice,
     * plus an LLM-output map holding {@code "token_usage"} and {@code "model_name"}.
     *
     * <p>The output map tolerates a missing (null) usage or model, which {@code Map.of} would reject.
     */
    public ChatResult createChatResult(ChatCompletionResp chatCompletionResp) {
        List<ChatGeneration> generations = chatCompletionResp.getChoices().stream()
                .map(chatChoice -> OpenAI.convertOpenAiToLangChain(chatChoice.getMessage()))
                .map(ChatGeneration::new)
                .toList();
        Map<String, Object> llmOutput = new HashMap<>();
        llmOutput.put("token_usage", chatCompletionResp.getUsage());
        llmOutput.put("model_name", chatCompletionResp.getModel());
        return new ChatResult(generations, llmOutput);
    }

    /** Identifier for this model type. */
    @Override // com.hw.langchain.chat.models.base.BaseChatModel
    public String llmType() {
        return "openai-chat";
    }

    private static String $default$model() {
        return "gpt-3.5-turbo";
    }

    private static float $default$temperature() {
        return 0.7f;
    }

    private static Map<String, Object> $default$modelKwargs() {
        return new HashMap<>();
    }

    private static long $default$requestTimeout() {
        return 16L;
    }

    private static int $default$maxRetries() {
        return 6;
    }

    private static int $default$n() {
        return 1;
    }

    /**
     * Copies builder state into a new instance, substituting the {@code $default$xxx()} value for
     * every defaulted field that was not explicitly set on the builder.
     */
    protected ChatOpenAI(ChatOpenAIBuilder<?, ?> chatOpenAIBuilder) {
        super(chatOpenAIBuilder);
        ChatOpenAIBuilder<?, ?> b = chatOpenAIBuilder;
        this.client = b.client;
        this.model = b.model$set ? b.model$value : $default$model();
        this.temperature = b.temperature$set ? b.temperature$value : $default$temperature();
        this.modelKwargs = b.modelKwargs$set ? b.modelKwargs$value : $default$modelKwargs();
        this.openaiApiKey = b.openaiApiKey;
        this.openaiApiBase = b.openaiApiBase;
        this.openaiApiType = b.openaiApiType;
        this.openaiApiVersion = b.openaiApiVersion;
        this.openaiOrganization = b.openaiOrganization;
        this.openaiProxy = b.openaiProxy;
        this.requestTimeout = b.requestTimeout$set ? b.requestTimeout$value : $default$requestTimeout();
        this.maxRetries = b.maxRetries$set ? b.maxRetries$value : $default$maxRetries();
        this.stream = b.stream;
        this.n = b.n$set ? b.n$value : $default$n();
        this.maxTokens = b.maxTokens;
        this.interceptorList = b.interceptorList;
    }

    /** Returns a new builder for {@link ChatOpenAI}. */
    public static ChatOpenAIBuilder<?, ?> builder() {
        return new ChatOpenAIBuilderImpl();
    }
}
