package com.hw.langchain.llms.ollama;

import com.fasterxml.jackson.core.type.TypeReference;
import com.google.common.collect.Maps;
import com.hw.langchain.chains.query.constructor.JsonUtils;
import com.hw.langchain.llms.base.BaseLLM;
import com.hw.langchain.requests.TextRequestsWrapper;
import com.hw.langchain.schema.GenerationChunk;
import com.hw.langchain.schema.LLMResult;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.lang3.StringUtils;

/**
 * {@link BaseLLM} implementation that talks to an Ollama server over HTTP.
 *
 * <p>Each prompt is posted to {@code <baseUrl>/api/generate}. The response body is split into lines,
 * every non-blank line is parsed as a JSON object and converted to a {@link GenerationChunk}, and the
 * chunks are concatenated into a single generation per prompt.
 *
 * <p>Create instances with {@link #builder()}. Options not set on the builder fall back to the values
 * returned by the private {@code $default$...} methods; {@code numGqa}, {@code numGpu} and
 * {@code numThread} have no default and are placed in the request options as {@code null} when unset.
 */
public class Ollama extends BaseLLM {
    private String baseUrl;
    private String model;
    private TextRequestsWrapper requestsWrapper;
    private Integer mirostat;
    private Float mirostatEta;
    private Float mirostatTau;
    private Integer numCtx;
    private Integer numGqa;
    private Integer numGpu;
    private Integer numThread;
    private Integer repeatLastN;
    private Float repeatPenalty;
    private Float temperature;
    private Float tfsZ;
    private Integer topK;
    private Float topP;

    /**
     * Builder for {@link Ollama}. Setters for options that have a default also record a {@code $set}
     * flag, so the {@link Ollama} constructor can tell "explicitly set" (even to {@code null}) apart
     * from "not set, use the default".
     *
     * @param <C> the concrete {@link Ollama} type produced by {@link #build()}
     * @param <B> the concrete builder type returned by each setter, for chaining
     */
    public static abstract class OllamaBuilder<C extends Ollama, B extends OllamaBuilder<C, B>>
            extends BaseLLM.BaseLLMBuilder<C, B> {
        private boolean baseUrl$set;
        private String baseUrl$value;
        private boolean model$set;
        private String model$value;
        private TextRequestsWrapper requestsWrapper;
        private boolean mirostat$set;
        private Integer mirostat$value;
        private boolean mirostatEta$set;
        private Float mirostatEta$value;
        private boolean mirostatTau$set;
        private Float mirostatTau$value;
        private boolean numCtx$set;
        private Integer numCtx$value;
        private Integer numGqa;
        private Integer numGpu;
        private Integer numThread;
        private boolean repeatLastN$set;
        private Integer repeatLastN$value;
        private boolean repeatPenalty$set;
        private Float repeatPenalty$value;
        private boolean temperature$set;
        private Float temperature$value;
        private boolean tfsZ$set;
        private Float tfsZ$value;
        private boolean topK$set;
        private Integer topK$value;
        private boolean topP$set;
        private Float topP$value;

        public B baseUrl(String baseUrl) {
            this.baseUrl$value = baseUrl;
            this.baseUrl$set = true;
            return self();
        }

        public B model(String model) {
            this.model$value = model;
            this.model$set = true;
            return self();
        }

        public B requestsWrapper(TextRequestsWrapper requestsWrapper) {
            this.requestsWrapper = requestsWrapper;
            return self();
        }

        public B mirostat(Integer mirostat) {
            this.mirostat$value = mirostat;
            this.mirostat$set = true;
            return self();
        }

        public B mirostatEta(Float mirostatEta) {
            this.mirostatEta$value = mirostatEta;
            this.mirostatEta$set = true;
            return self();
        }

        public B mirostatTau(Float mirostatTau) {
            this.mirostatTau$value = mirostatTau;
            this.mirostatTau$set = true;
            return self();
        }

        public B numCtx(Integer numCtx) {
            this.numCtx$value = numCtx;
            this.numCtx$set = true;
            return self();
        }

        public B numGqa(Integer numGqa) {
            this.numGqa = numGqa;
            return self();
        }

        public B numGpu(Integer numGpu) {
            this.numGpu = numGpu;
            return self();
        }

        public B numThread(Integer numThread) {
            this.numThread = numThread;
            return self();
        }

        public B repeatLastN(Integer repeatLastN) {
            this.repeatLastN$value = repeatLastN;
            this.repeatLastN$set = true;
            return self();
        }

        public B repeatPenalty(Float repeatPenalty) {
            this.repeatPenalty$value = repeatPenalty;
            this.repeatPenalty$set = true;
            return self();
        }

        public B temperature(Float temperature) {
            this.temperature$value = temperature;
            this.temperature$set = true;
            return self();
        }

        public B tfsZ(Float tfsZ) {
            this.tfsZ$value = tfsZ;
            this.tfsZ$set = true;
            return self();
        }

        public B topK(Integer topK) {
            this.topK$value = topK;
            this.topK$set = true;
            return self();
        }

        public B topP(Float topP) {
            this.topP$value = topP;
            this.topP$set = true;
            return self();
        }

        /** Returns this builder typed as {@code B}, so setters can be chained on subclasses. */
        @Override
        public abstract B self();

        /** Creates the {@link Ollama} instance described by the current builder state. */
        @Override
        public abstract C build();

        @Override
        public String toString() {
            return "Ollama.OllamaBuilder(super=" + super.toString()
                    + ", baseUrl$value=" + this.baseUrl$value
                    + ", model$value=" + this.model$value
                    + ", requestsWrapper=" + this.requestsWrapper
                    + ", mirostat$value=" + this.mirostat$value
                    + ", mirostatEta$value=" + this.mirostatEta$value
                    + ", mirostatTau$value=" + this.mirostatTau$value
                    + ", numCtx$value=" + this.numCtx$value
                    + ", numGqa=" + this.numGqa
                    + ", numGpu=" + this.numGpu
                    + ", numThread=" + this.numThread
                    + ", repeatLastN$value=" + this.repeatLastN$value
                    + ", repeatPenalty$value=" + this.repeatPenalty$value
                    + ", temperature$value=" + this.temperature$value
                    + ", tfsZ$value=" + this.tfsZ$value
                    + ", topK$value=" + this.topK$value
                    + ", topP$value=" + this.topP$value
                    + ")";
        }
    }

    /** Concrete builder returned by {@link Ollama#builder()}. */
    private static final class OllamaBuilderImpl extends OllamaBuilder<Ollama, OllamaBuilderImpl> {
        private OllamaBuilderImpl() {
        }

        @Override
        public OllamaBuilderImpl self() {
            return this;
        }

        @Override
        public Ollama build() {
            return new Ollama(this);
        }
    }

    @Override
    public String llmType() {
        return "ollama-llm";
    }

    /**
     * Ensures a {@link TextRequestsWrapper} is available for HTTP calls.
     *
     * <p>A wrapper supplied through the builder is kept; otherwise a wrapper sending a
     * {@code Content-Type: application/json} header is created.
     *
     * @return this instance, for chaining after {@code build()}
     */
    public Ollama init() {
        if (this.requestsWrapper == null) {
            this.requestsWrapper = defaultRequestsWrapper();
        }
        return this;
    }

    /** Creates the wrapper used when none was configured: only a JSON content-type header. */
    private static TextRequestsWrapper defaultRequestsWrapper() {
        return new TextRequestsWrapper(Map.of("Content-Type", "application/json"));
    }

    /**
     * Builds the {@code options} object sent with each generate request.
     *
     * @param stop stop sequences; stored under {@code "stop"} as given (may be {@code null})
     * @return a mutable map; unset options are present with {@code null} values
     */
    private Map<String, Object> createParams(List<String> stop) {
        // HashMap (not Map.of) because several values may legitimately be null.
        Map<String, Object> options = Maps.newHashMap();
        options.put("mirostat", this.mirostat);
        options.put("mirostat_eta", this.mirostatEta);
        options.put("mirostat_tau", this.mirostatTau);
        options.put("num_ctx", this.numCtx);
        options.put("num_gqa", this.numGqa);
        options.put("num_gpu", this.numGpu);
        options.put("num_thread", this.numThread);
        options.put("repeat_last_n", this.repeatLastN);
        options.put("repeat_penalty", this.repeatPenalty);
        options.put("temperature", this.temperature);
        options.put("stop", stop);
        options.put("tfs_z", this.tfsZ);
        options.put("top_k", this.topK);
        options.put("top_p", this.topP);
        return options;
    }

    /**
     * Posts {@code prompt} to {@code <baseUrl>/api/generate} and returns the response body split into
     * lines. The whole body is read before splitting; nothing is consumed incrementally.
     *
     * <p>If no requests wrapper has been configured yet, a default JSON wrapper is created first.
     *
     * @param prompt the prompt text; must not be {@code null}
     * @param stop   stop sequences forwarded in the request options (may be {@code null})
     * @return the response lines, in order (may include blank lines)
     * @throws NullPointerException if {@code prompt} or the configured model is {@code null}
     */
    public List<String> createStream(String prompt, List<String> stop) {
        Objects.requireNonNull(prompt, "prompt");
        if (this.requestsWrapper == null) {
            this.requestsWrapper = defaultRequestsWrapper();
        }
        // Drop a trailing slash so "http://host:11434/" does not yield "//api/generate".
        String url = StringUtils.removeEnd(this.baseUrl, "/") + "/api/generate";
        Map<String, Object> body = Map.of("model", this.model, "prompt", prompt, "options", createParams(stop));
        return this.requestsWrapper.post(url, body).lines().toList();
    }

    /**
     * Generates one completion per prompt by merging every parsed response line into a single chunk.
     *
     * @param prompts prompts to complete, processed sequentially
     * @param stop    stop sequences forwarded to the server (may be {@code null})
     * @return an {@link LLMResult} holding one single-element generation list per prompt
     * @throws NullPointerException  if the server returned no non-blank lines for a prompt
     * @throws IllegalStateException if a response line reports an error
     */
    @Override
    protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
        List<List<GenerationChunk>> generations = new ArrayList<>(prompts.size());
        for (String prompt : prompts) {
            GenerationChunk merged = null;
            for (String line : createStream(prompt, stop)) {
                // Blank lines carry no JSON object; skip them rather than feed them to the parser.
                if (StringUtils.isBlank(line)) {
                    continue;
                }
                GenerationChunk chunk = streamResponseToGenerationChunk(line);
                merged = merged == null ? chunk : merged.add(chunk);
            }
            generations.add(List.of(Objects.requireNonNull(merged,
                    "Ollama returned no response lines for prompt")));
        }
        return new LLMResult(generations);
    }

    /**
     * Converts one JSON line of an Ollama generate response into a {@link GenerationChunk}.
     *
     * <p>The chunk text is the {@code response} field ({@code ""} when absent or {@code null}). The
     * full parsed object is attached as generation info only on the line whose {@code done} field is
     * {@code true}; other lines get {@code null} info.
     *
     * @param streamResponse a single JSON object, as one line of the response body
     * @return the parsed chunk
     * @throws IllegalStateException if the object contains an {@code error} field
     */
    public static GenerationChunk streamResponseToGenerationChunk(String streamResponse) {
        @SuppressWarnings("unchecked")
        Map<String, Object> response = (Map<String, Object>) JsonUtils.convertFromJsonStr(streamResponse,
                new TypeReference<Map<String, Object>>() {
                });
        // NOTE(review): Ollama is understood to report failures as {"error": "..."}; surface them
        // instead of silently producing an empty chunk.
        Object error = response.get("error");
        if (error != null) {
            throw new IllegalStateException("Ollama returned an error: " + error);
        }
        // Null-safe: lines without a "done" field are treated as not-final.
        Map<String, Object> generationInfo = Boolean.TRUE.equals(response.get("done")) ? response : null;
        String text = Objects.toString(response.get("response"), "");
        return new GenerationChunk(text, generationInfo);
    }

    private static String $default$baseUrl() {
        return "http://localhost:11434";
    }

    private static String $default$model() {
        return "llama2";
    }

    private static Integer $default$mirostat() {
        return 0;
    }

    private static Float $default$mirostatEta() {
        return 0.1f;
    }

    private static Float $default$mirostatTau() {
        return 5.0f;
    }

    private static Integer $default$numCtx() {
        return 2048;
    }

    private static Integer $default$repeatLastN() {
        return 64;
    }

    private static Float $default$repeatPenalty() {
        return 1.1f;
    }

    private static Float $default$temperature() {
        return 0.8f;
    }

    private static Float $default$tfsZ() {
        return 1.0f;
    }

    private static Integer $default$topK() {
        return 40;
    }

    private static Float $default$topP() {
        return 0.9f;
    }

    /**
     * Copies builder state into a new instance. An option whose setter was called keeps that value,
     * even {@code null}. Otherwise the corresponding {@code $default$...} value is used.
     *
     * @param builder the populated builder
     */
    protected Ollama(OllamaBuilder<?, ?> builder) {
        super(builder);
        this.baseUrl = builder.baseUrl$set ? builder.baseUrl$value : $default$baseUrl();
        this.model = builder.model$set ? builder.model$value : $default$model();
        this.requestsWrapper = builder.requestsWrapper;
        this.mirostat = builder.mirostat$set ? builder.mirostat$value : $default$mirostat();
        this.mirostatEta = builder.mirostatEta$set ? builder.mirostatEta$value : $default$mirostatEta();
        this.mirostatTau = builder.mirostatTau$set ? builder.mirostatTau$value : $default$mirostatTau();
        this.numCtx = builder.numCtx$set ? builder.numCtx$value : $default$numCtx();
        this.numGqa = builder.numGqa;
        this.numGpu = builder.numGpu;
        this.numThread = builder.numThread;
        this.repeatLastN = builder.repeatLastN$set ? builder.repeatLastN$value : $default$repeatLastN();
        this.repeatPenalty = builder.repeatPenalty$set ? builder.repeatPenalty$value : $default$repeatPenalty();
        this.temperature = builder.temperature$set ? builder.temperature$value : $default$temperature();
        this.tfsZ = builder.tfsZ$set ? builder.tfsZ$value : $default$tfsZ();
        this.topK = builder.topK$set ? builder.topK$value : $default$topK();
        this.topP = builder.topP$set ? builder.topP$value : $default$topP();
    }

    /** Returns a new builder; call {@code build()} and then {@link #init()} to obtain a usable instance. */
    public static OllamaBuilder<?, ?> builder() {
        return new OllamaBuilderImpl();
    }
}
