package com.hw.langchain.chains.question.answering;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.ChainType;
import com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain;
import com.hw.langchain.chains.combine.documents.stuff.StuffDocumentsChain;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

/**
 * Static factory helpers for building question-answering chains that combine documents.
 *
 * <p>This is a utility class and cannot be instantiated.
 */
public final class Init {

    /** Default value passed to {@link StuffDocumentsChain} as the document variable name. */
    private static final String DEFAULT_DOCUMENT_VARIABLE_NAME = "context";

    /** Maps each supported {@link ChainType} to the loader that builds its chain. Immutable. */
    private static final Map<ChainType, Function<BaseLanguageModel, BaseCombineDocumentsChain>> LOADER_MAPPING =
            Map.of(ChainType.STUFF, Init::loadStuffChain);

    private Init() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Builds a "stuff" documents chain using the prompt chosen for {@code baseLanguageModel}
     * by {@code StuffPrompt.PROMPT_SELECTOR} and the default document variable name
     * ({@code "context"}).
     *
     * @param baseLanguageModel the language model that backs the underlying {@link LLMChain}
     * @return a new {@link StuffDocumentsChain}
     */
    public static StuffDocumentsChain loadStuffChain(BaseLanguageModel baseLanguageModel) {
        BasePromptTemplate prompt = StuffPrompt.PROMPT_SELECTOR.getPrompt(baseLanguageModel);
        return loadStuffChain(baseLanguageModel, prompt, DEFAULT_DOCUMENT_VARIABLE_NAME);
    }

    /**
     * Builds a "stuff" documents chain from an explicit prompt and document variable name.
     *
     * @param baseLanguageModel the language model that backs the underlying {@link LLMChain}
     * @param basePromptTemplate the prompt used by the underlying {@link LLMChain}
     * @param documentVariableName passed to {@link StuffDocumentsChain} as its second constructor
     *     argument; presumably the prompt input variable that receives the combined documents
     *     (TODO confirm against {@code StuffDocumentsChain})
     * @return a new {@link StuffDocumentsChain}
     */
    public static StuffDocumentsChain loadStuffChain(BaseLanguageModel baseLanguageModel,
            BasePromptTemplate basePromptTemplate, String documentVariableName) {
        return new StuffDocumentsChain(new LLMChain(baseLanguageModel, basePromptTemplate), documentVariableName);
    }

    /**
     * Builds a question-answering chain of the default type, {@link ChainType#STUFF}.
     *
     * @param baseLanguageModel the language model used by the chain
     * @return the loaded chain
     */
    public static BaseCombineDocumentsChain loadQaChain(BaseLanguageModel baseLanguageModel) {
        return loadQaChain(baseLanguageModel, ChainType.STUFF);
    }

    /**
     * Builds a question-answering chain of the requested type.
     *
     * @param baseLanguageModel the language model used by the chain
     * @param chainType the kind of combine-documents chain to build; must be non-null
     * @return the loaded chain
     * @throws NullPointerException if {@code chainType} is {@code null}
     * @throws IllegalArgumentException if no loader is registered for {@code chainType}
     */
    public static BaseCombineDocumentsChain loadQaChain(BaseLanguageModel baseLanguageModel, ChainType chainType) {
        Objects.requireNonNull(chainType, "chainType");
        Function<BaseLanguageModel, BaseCombineDocumentsChain> loader = LOADER_MAPPING.get(chainType);
        if (loader == null) {
            // Previously this fell through to loader.apply(...) and surfaced as a bare NPE.
            throw new IllegalArgumentException(
                    "Got unsupported chain type: " + chainType + ". Should be one of " + LOADER_MAPPING.keySet());
        }
        return loader.apply(baseLanguageModel);
    }
}
