package com.hw.langchain.chains.retrieval.qa.base;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.ChainType;
import com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain;
import com.hw.langchain.chains.question.answering.Init;
import com.hw.langchain.schema.BaseRetriever;
import com.hw.langchain.schema.Document;
import java.util.List;
import java.util.Objects;

/* loaded from: input_file:com/hw/langchain/chains/retrieval/qa/base/RetrievalQa.class */
/**
 * Retrieval question-answering chain whose documents come from a {@link BaseRetriever}.
 *
 * <p>This class only supplies the document lookup ({@link #getDocs(String)}) and the chain type
 * identifier; the combine-documents chain it is constructed with is handed to
 * {@link BaseRetrievalQA}, which presumably drives the actual answering step (not visible here).
 */
public class RetrievalQa extends BaseRetrievalQA {

    /** Source of documents relevant to a question; never {@code null}. */
    private final BaseRetriever retriever;

    /**
     * Creates a retrieval QA chain.
     *
     * @param baseCombineDocumentsChain the chain passed to the {@link BaseRetrievalQA} constructor
     * @param baseRetriever the retriever queried by {@link #getDocs(String)}
     * @throws NullPointerException if {@code baseRetriever} is {@code null}
     */
    public RetrievalQa(BaseCombineDocumentsChain baseCombineDocumentsChain, BaseRetriever baseRetriever) {
        super(baseCombineDocumentsChain);
        // Fail fast: getDocs always dereferences the retriever, so a null here would otherwise
        // surface only later as an NPE far from the construction site.
        this.retriever = Objects.requireNonNull(baseRetriever, "baseRetriever must not be null");
    }

    /**
     * Builds a retrieval QA chain using the {@link ChainType#STUFF} QA chain type.
     *
     * @param baseLanguageModel the language model used to load the QA chain
     * @param baseRetriever the retriever that supplies documents
     * @return a new {@link RetrievalQa}
     * @throws NullPointerException if {@code baseRetriever} is {@code null}
     */
    public static BaseRetrievalQA fromChainType(BaseLanguageModel baseLanguageModel, BaseRetriever baseRetriever) {
        return fromChainType(baseLanguageModel, ChainType.STUFF, baseRetriever);
    }

    /**
     * Builds a retrieval QA chain whose combine-documents chain is produced by
     * {@link Init#loadQaChain} for the given chain type.
     *
     * @param baseLanguageModel the language model used to load the QA chain
     * @param chainType the QA chain type passed to {@code Init.loadQaChain}
     * @param baseRetriever the retriever that supplies documents
     * @return a new {@link RetrievalQa}
     * @throws NullPointerException if {@code baseRetriever} is {@code null}
     */
    public static BaseRetrievalQA fromChainType(BaseLanguageModel baseLanguageModel, ChainType chainType, BaseRetriever baseRetriever) {
        BaseCombineDocumentsChain combineChain = Init.loadQaChain(baseLanguageModel, chainType);
        return new RetrievalQa(combineChain, baseRetriever);
    }

    /**
     * Returns the documents the retriever considers relevant to the question.
     *
     * @param str the question text, forwarded unchanged to the retriever
     * @return the retriever's result for {@code str}
     */
    @Override
    public List<Document> getDocs(String str) {
        return this.retriever.getRelevantDocuments(str);
    }

    /**
     * Identifies this chain's type.
     *
     * @return the constant {@code "retrieval_qa"}
     */
    @Override
    public String chainType() {
        return "retrieval_qa";
    }
}
