package com.hw.langchain.chains.combine.documents.stuff;

import com.google.common.collect.Maps;
import com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain;
import com.hw.langchain.chains.combine.documents.base.BaseUtils;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import com.hw.langchain.schema.Document;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.lang3.tuple.Pair;

/* loaded from: input_file:com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.class */
public class StuffDocumentsChain extends BaseCombineDocumentsChain {
    private final LLMChain llmChain;
    private final BasePromptTemplate documentPrompt;
    private String documentVariableName;
    private final String documentSeparator;

    public StuffDocumentsChain(LLMChain lLMChain, String str) {
        this(lLMChain, StuffUtils.getDefaultDocumentPrompt(), str, "\n\n");
    }

    public StuffDocumentsChain(LLMChain lLMChain, BasePromptTemplate basePromptTemplate, String str, String str2) {
        this.llmChain = lLMChain;
        this.documentPrompt = basePromptTemplate;
        this.documentVariableName = str;
        this.documentSeparator = str2;
        getDefaultDocumentVariableName();
    }

    private void getDefaultDocumentVariableName() {
        List<String> inputVariables = this.llmChain.getPrompt().getInputVariables();
        if (this.documentVariableName != null) {
            if (!inputVariables.contains(this.documentVariableName)) {
                throw new IllegalArgumentException("documentVariableName " + this.documentVariableName + " was not found in llmChain inputVariables: " + inputVariables);
            }
        } else {
            if (inputVariables.size() != 1) {
                throw new IllegalArgumentException("documentVariableName must be provided if there are multiple llmChainVariables");
            }
            this.documentVariableName = inputVariables.get(0);
        }
    }

    private Map<String, Object> getInputs(List<Document> list, Map<String, Object> map) {
        List list2 = list.stream().map(document -> {
            return BaseUtils.formatDocument(document, this.documentPrompt);
        }).toList();
        List<String> inputVariables = this.llmChain.getPrompt().getInputVariables();
        Objects.requireNonNull(inputVariables);
        Map<String, Object> filterKeys = Maps.filterKeys(map, (v1) -> {
            return r1.contains(v1);
        });
        filterKeys.put(this.documentVariableName, String.join(this.documentSeparator, list2));
        return filterKeys;
    }

    @Override // com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain
    public Pair<String, Map<String, String>> combineDocs(List<Document> list, Map<String, Object> map) {
        return Pair.of(this.llmChain.predict(getInputs(list, map)), Map.of());
    }

    @Override // com.hw.langchain.chains.base.Chain
    public String chainType() {
        return "stuff_documents_chain";
    }
}
