package com.hw.langchain.chains.conversation.base;

import com.google.common.collect.Sets;
import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.conversation.prompt.Prompt;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.memory.buffer.ConversationBufferMemory;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import com.hw.langchain.schema.BaseMemory;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

/**
 * An {@link LLMChain} for conversations: a single user input (under {@link #inputKey}) is
 * combined with the variables supplied by a {@link BaseMemory} to fill the prompt.
 *
 * <p>The prompt's input variables must be exactly the memory's variables plus {@link #inputKey}.
 * This is checked eagerly in the constructor via {@link #validatePromptInputVariables()}.
 */
public class ConversationChain extends LLMChain {

    /** Key under which the user's input is expected; set to {@code "input"} by the constructor. */
    protected String inputKey;

    /**
     * Creates a conversation chain using {@link Prompt#PROMPT} and a new
     * {@link ConversationBufferMemory}.
     *
     * @param baseLanguageModel the language model passed through to {@link LLMChain}
     * @throws IllegalArgumentException if the default prompt's variables do not match the
     *     memory's variables plus the input key
     */
    public ConversationChain(BaseLanguageModel baseLanguageModel) {
        this(baseLanguageModel, Prompt.PROMPT, new ConversationBufferMemory());
    }

    /**
     * Creates a conversation chain with an explicit prompt and memory.
     *
     * @param baseLanguageModel the language model passed through to {@link LLMChain}
     * @param basePromptTemplate the prompt template passed through to {@link LLMChain}
     * @param baseMemory the memory whose variables are merged with the user input; must not be null
     * @throws NullPointerException if {@code baseMemory} is null
     * @throws IllegalArgumentException if the prompt's variables do not match the memory's
     *     variables plus the input key, or if the input key collides with a memory variable
     */
    public ConversationChain(BaseLanguageModel baseLanguageModel, BasePromptTemplate basePromptTemplate, BaseMemory baseMemory) {
        // NOTE(review): "response" is presumably LLMChain's output key -- confirm against LLMChain.
        super(baseLanguageModel, basePromptTemplate, "response");
        this.inputKey = "input";
        // Fail fast with a named argument instead of an anonymous NPE during validation.
        this.memory = Objects.requireNonNull(baseMemory, "baseMemory");
        // NOTE(review): this calls an overridable public method from the constructor; a subclass
        // override would run before the subclass's own fields are initialized.
        validatePromptInputVariables();
    }

    /**
     * Returns the chain's input keys, which is only {@link #inputKey}. Memory variables are not
     * included.
     *
     * @return an immutable single-element list containing {@link #inputKey}
     */
    @Override // com.hw.langchain.chains.llm.LLMChain, com.hw.langchain.chains.base.Chain
    public List<String> inputKeys() {
        return List.of(this.inputKey);
    }

    /**
     * Checks that the prompt's input variables line up with the memory variables and input key.
     *
     * @throws IllegalArgumentException if {@link #inputKey} is also a memory variable, or if the
     *     set of prompt input variables differs from the set of memory variables plus
     *     {@link #inputKey}
     */
    public void validatePromptInputVariables() {
        List<String> memoryVariables = this.memory.memoryVariables();
        if (memoryVariables.contains(this.inputKey)) {
            throw new IllegalArgumentException(String.format("The input key %s was also found in the memory keys %s - please provide keys that don't overlap.", this.inputKey, memoryVariables));
        }
        List<String> inputVariables = this.prompt.getInputVariables();

        // The prompt must consume exactly: every memory variable plus the user input key.
        List<String> expectedVariables = new ArrayList<>(memoryVariables);
        expectedVariables.add(this.inputKey);

        // Compare as sets: order and duplicates are irrelevant, only membership matters.
        Set<String> promptVariableSet = new HashSet<>(inputVariables);
        Set<String> expectedVariableSet = new HashSet<>(expectedVariables);
        if (!Sets.symmetricDifference(promptVariableSet, expectedVariableSet).isEmpty()) {
            throw new IllegalArgumentException(String.format("Got unexpected prompt input variables. The prompt expects %s, but got %s as inputs from memory, and %s as the normal input key.", inputVariables, memoryVariables, this.inputKey));
        }
    }
}
