package com.hw.langchain.chains.api.base;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.api.prompt.Prompt;
import com.hw.langchain.chains.base.Chain;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import com.hw.langchain.requests.TextRequestsWrapper;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/* loaded from: input_file:com/hw/langchain/chains/api/base/ApiChain.class */
/**
 * Chain that answers a question by letting an LLM build an API URL from API documentation,
 * fetching that URL, and then letting a second LLM call turn the raw response into an answer.
 *
 * <p>The request chain's prompt must use exactly the input variables {@code question} and
 * {@code api_docs}. The answer chain's prompt must use exactly {@code question}, {@code api_docs},
 * {@code api_url} and {@code api_response}. Both are checked at construction time.
 */
public class ApiChain extends Chain {

    private static final String QUESTION_KEY = "question";
    private static final String OUTPUT_KEY = "output";
    private static final String API_DOCS = "api_docs";
    private static final String API_URL_KEY = "api_url";
    private static final String API_RESPONSE_KEY = "api_response";

    private final LLMChain apiRequestChain;
    private final LLMChain apiAnswerChain;
    private final TextRequestsWrapper requestsWrapper;
    private final String apiDocs;

    /**
     * Creates the chain and validates both prompts' input variables.
     *
     * @param apiRequestChain chain whose prediction is used as the URL to request
     * @param apiAnswerChain chain that turns the API response into the final answer
     * @param requestsWrapper wrapper used to fetch the generated URL
     * @param apiDocs API documentation text passed to both prompts
     * @throws NullPointerException if any argument is {@code null}
     * @throws IllegalArgumentException if a prompt's input variables are not the expected set
     */
    public ApiChain(LLMChain apiRequestChain, LLMChain apiAnswerChain,
            TextRequestsWrapper requestsWrapper, String apiDocs) {
        this.apiRequestChain = Objects.requireNonNull(apiRequestChain, "apiRequestChain");
        this.apiAnswerChain = Objects.requireNonNull(apiAnswerChain, "apiAnswerChain");
        this.requestsWrapper = Objects.requireNonNull(requestsWrapper, "requestsWrapper");
        // Fail fast: a null value here would otherwise surface later as an opaque NPE from Map.of.
        this.apiDocs = Objects.requireNonNull(apiDocs, "apiDocs");
        validateApiRequestPrompt();
        validateApiAnswerPrompt();
    }

    /** Checks that the URL-building prompt takes exactly {@code question} and {@code api_docs}. */
    private void validateApiRequestPrompt() {
        validateInputVariables(apiRequestChain, Set.of(QUESTION_KEY, API_DOCS));
    }

    /** Checks that the answering prompt takes exactly the question, docs, URL and response. */
    private void validateApiAnswerPrompt() {
        validateInputVariables(apiAnswerChain,
                Set.of(QUESTION_KEY, API_DOCS, API_URL_KEY, API_RESPONSE_KEY));
    }

    /**
     * Compares a chain prompt's input variables, ignoring order and duplicates, against the expected set.
     *
     * @throws IllegalArgumentException if the variable sets differ
     */
    private static void validateInputVariables(LLMChain chain, Set<String> expected) {
        List<String> actual = chain.getPrompt().getInputVariables();
        if (!new HashSet<>(actual).equals(expected)) {
            throw new IllegalArgumentException("Input variables should be " + expected + ", got " + actual);
        }
    }

    @Override
    public List<String> inputKeys() {
        return List.of(QUESTION_KEY);
    }

    @Override
    public List<String> outputKeys() {
        return List.of(OUTPUT_KEY);
    }

    /**
     * Generates an API URL, fetches it, and asks the LLM to answer from the response.
     *
     * @param inputs must contain a non-null value for {@code "question"}
     * @return a single-entry map from {@code "output"} to the answer chain's prediction
     * @throws NullPointerException if {@code "question"} is missing or {@code null}
     */
    @Override
    public Map<String, String> innerCall(Map<String, Object> inputs) {
        // Map.of rejects null values, so report the missing key explicitly instead of a bare NPE.
        Object question = Objects.requireNonNull(inputs.get(QUESTION_KEY),
                "Missing required input key: " + QUESTION_KEY);

        String apiUrl = apiRequestChain.predict(Map.of(QUESTION_KEY, question, API_DOCS, apiDocs)).strip();

        // NOTE(review): apiUrl is model output and is requested without any allow-list or host check;
        // callers exposing this chain to untrusted questions should consider restricting target hosts.
        String apiResponse = requestsWrapper.get(apiUrl);

        String answer = apiAnswerChain.predict(Map.of(
                QUESTION_KEY, question,
                API_DOCS, apiDocs,
                API_URL_KEY, apiUrl,
                API_RESPONSE_KEY, apiResponse));
        return Map.of(OUTPUT_KEY, answer);
    }

    /**
     * Builds an {@code ApiChain} using the default URL and response prompts and no extra request
     * configuration ({@code null} headers map).
     *
     * @param llm model used by both inner chains
     * @param apiDocs API documentation text
     */
    public static ApiChain fromLlmAndApiDocs(BaseLanguageModel llm, String apiDocs) {
        return fromLlmAndApiDocs(llm, apiDocs, null, Prompt.API_URL_PROMPT, Prompt.API_RESPONSE_PROMPT);
    }

    /**
     * Builds an {@code ApiChain} with explicit prompts and request configuration.
     *
     * @param llm model used by both inner chains
     * @param apiDocs API documentation text
     * @param headers map handed to {@link TextRequestsWrapper}; presumably HTTP headers (may be null,
     *     as the two-arg overload passes null) — TODO confirm semantics in TextRequestsWrapper
     * @param apiUrlPrompt prompt for generating the URL; must use {@code question}, {@code api_docs}
     * @param apiResponsePrompt prompt for answering; must use {@code question}, {@code api_docs},
     *     {@code api_url}, {@code api_response}
     */
    public static ApiChain fromLlmAndApiDocs(BaseLanguageModel llm, String apiDocs, Map<String, String> headers,
            BasePromptTemplate apiUrlPrompt, BasePromptTemplate apiResponsePrompt) {
        return new ApiChain(
                new LLMChain(llm, apiUrlPrompt),
                new LLMChain(llm, apiResponsePrompt),
                new TextRequestsWrapper(headers),
                apiDocs);
    }

    @Override
    public String chainType() {
        return "api_chain";
    }
}
