package com.hw.langchain.chains.llm.math.base;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.base.Chain;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.chains.llm.math.prompt.Prompt;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.python.core.PyObject;
import org.python.util.PythonInterpreter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/hw/langchain/chains/llm/math/base/LLMMathChain.class */
/**
 * Chain that asks an LLM to turn a math question into an expression and then evaluates that
 * expression locally with an embedded Jython interpreter.
 *
 * <p>The LLM response is interpreted in {@link #processLLMResult(String)} in one of three ways:
 * <ol>
 *   <li>a leading {@code ```text ... ```} block, whose contents are evaluated;</li>
 *   <li>output that starts with {@code Answer:}, returned as-is;</li>
 *   <li>output containing {@code Answer:} elsewhere, from which the answer text is extracted.</li>
 * </ol>
 */
public class LLMMathChain extends Chain {
    private static final Logger LOG = LoggerFactory.getLogger(LLMMathChain.class);

    /**
     * Matches a {@code ```text ... ```} block at the very start of the LLM output. DOTALL lets the
     * expression span multiple lines.
     */
    private static final Pattern TEXT_PATTERN = Pattern.compile("^```text(.*?)```", Pattern.DOTALL);

    /** Marker the LLM uses when it reports the final answer directly. */
    private static final String ANSWER_PREFIX = "Answer:";

    private LLMChain llmChain;
    private String inputKey = "question";
    private String outputKey = "answer";

    public LLMMathChain() {
    }

    /**
     * Creates a math chain backed by the given LLM chain.
     *
     * @param llmChain chain used to translate the question into an expression
     */
    public LLMMathChain(LLMChain llmChain) {
        this.llmChain = llmChain;
    }

    /**
     * Creates a math chain for {@code baseLanguageModel} using the default {@code Prompt.PROMPT}.
     *
     * @param baseLanguageModel the language model to query
     * @return a new chain
     */
    public static LLMMathChain fromLLM(BaseLanguageModel baseLanguageModel) {
        return fromLLM(baseLanguageModel, Prompt.PROMPT);
    }

    /**
     * Creates a math chain for {@code baseLanguageModel} using a custom prompt.
     *
     * @param baseLanguageModel the language model to query
     * @param basePromptTemplate prompt used to ask for the expression
     * @return a new chain
     */
    public static LLMMathChain fromLLM(BaseLanguageModel baseLanguageModel, BasePromptTemplate basePromptTemplate) {
        return new LLMMathChain(new LLMChain(baseLanguageModel, basePromptTemplate));
    }

    @Override
    public String chainType() {
        return "llm_math_chain";
    }

    @Override
    public List<String> inputKeys() {
        return List.of(this.inputKey);
    }

    @Override
    public List<String> outputKeys() {
        return List.of(this.outputKey);
    }

    /**
     * Evaluates a Python expression with Jython, with {@code pi} and {@code e} pre-bound.
     *
     * <p><b>Security warning:</b> the expression normally originates from LLM output and is handed
     * to a full Python interpreter, so a malicious or confused model can execute arbitrary code
     * (file, network, or process access). Only use this with trusted models and prompts, or replace
     * it with a restricted arithmetic evaluator.
     *
     * @param str the expression to evaluate; leading and trailing whitespace is ignored
     * @return the string form of the result, with a leading {@code [} and a trailing {@code ]}
     *     removed (for example when the result is a list)
     */
    public String evaluateExpression(String str) {
        LOG.debug("expression: {}", str);
        try (PythonInterpreter interpreter = new PythonInterpreter()) {
            interpreter.set("pi", Math.PI);
            interpreter.set("e", Math.E);
            PyObject result = interpreter.eval(str.strip());
            return result.toString().replaceAll("^\\[|\\]$", "");
        }
    }

    /**
     * Turns raw LLM output into the chain's output map.
     *
     * @param str raw LLM output
     * @return a single-entry map from the output key to a string beginning with {@code "Answer: "}
     *     (or the stripped output itself when it already starts with {@code Answer:})
     * @throws IllegalArgumentException if the output has neither a {@code ```text} block nor an
     *     {@code Answer:} marker
     */
    public Map<String, String> processLLMResult(String str) {
        String llmOutput = str.strip();
        Matcher matcher = TEXT_PATTERN.matcher(llmOutput);
        String answer;
        if (matcher.find()) {
            answer = "Answer: " + evaluateExpression(matcher.group(1));
        } else if (llmOutput.startsWith(ANSWER_PREFIX)) {
            answer = llmOutput;
        } else if (llmOutput.contains(ANSWER_PREFIX)) {
            // A limit of -1 keeps trailing empty strings. Output that ends with "Answer:" therefore
            // yields an empty answer instead of an ArrayIndexOutOfBoundsException.
            answer = "Answer: " + llmOutput.split(ANSWER_PREFIX, -1)[1];
        } else {
            throw new IllegalArgumentException("unknown format from LLM: " + llmOutput);
        }
        return Map.of(this.outputKey, answer);
    }

    /**
     * Runs the LLM on the question found under {@code inputKey} and processes its reply.
     *
     * @throws NullPointerException if {@code map} has no value for the input key
     */
    @Override
    public Map<String, String> innerCall(Map<String, Object> map) {
        Object question = Objects.requireNonNull(map.get(this.inputKey),
                () -> "missing input value for key '" + this.inputKey + "'");
        // The question is always passed under the fixed key "question" (presumably the variable name
        // used by the prompt template; confirm). Stopping at "```output" presumably keeps the model
        // from writing its own evaluation result.
        return processLLMResult(this.llmChain.predict(Map.of("question", question, "stop", List.of("```output"))));
    }
}
