package com.hw.langchain.retrievers.self.query.base;

import com.google.common.collect.Maps;
import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.chains.query.constructor.ir.StructuredQuery;
import com.hw.langchain.chains.query.constructor.ir.Visitor;
import com.hw.langchain.chains.query.constructor.schema.AttributeInfo;
import com.hw.langchain.schema.BaseRetriever;
import com.hw.langchain.schema.Document;
import com.hw.langchain.vectorstores.base.SearchType;
import com.hw.langchain.vectorstores.base.VectorStore;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/hw/langchain/retrievers/self/query/base/SelfQueryRetriever.class */
public class SelfQueryRetriever implements BaseRetriever {
    private static final Logger LOG = LoggerFactory.getLogger(SelfQueryRetriever.class);
    private final VectorStore vectorStore;
    private final LLMChain llmChain;
    private final SearchType searchType;
    private final Map<String, Object> searchKwargs;
    private final Visitor structuredQueryTranslator;
    private final boolean useOriginalQuery;

    public SelfQueryRetriever(VectorStore vectorStore, LLMChain lLMChain, Visitor visitor, boolean z) {
        this(vectorStore, lLMChain, SearchType.SIMILARITY, Maps.newHashMap(), visitor, z);
    }

    public SelfQueryRetriever(VectorStore vectorStore, LLMChain lLMChain, SearchType searchType, Map<String, Object> map, Visitor visitor, boolean z) {
        this.vectorStore = vectorStore;
        this.llmChain = lLMChain;
        this.searchType = searchType;
        this.searchKwargs = map;
        this.structuredQueryTranslator = visitor;
        this.useOriginalQuery = z;
    }

    @Override // com.hw.langchain.schema.BaseRetriever
    public List<Document> getRelevantDocuments(String str) {
        StructuredQuery structuredQuery = (StructuredQuery) this.llmChain.predictAndParse(this.llmChain.prepInputs(Map.of("query", str)));
        LOG.info("Structured Query: {}", structuredQuery);
        return this.vectorStore.search(structuredQuery.getQuery(), this.searchType, this.structuredQueryTranslator.visitStructuredQuery(structuredQuery));
    }

    public static SelfQueryRetriever fromLLM(BaseLanguageModel baseLanguageModel, VectorStore vectorStore, String str, List<AttributeInfo> list) {
        return fromLLM(baseLanguageModel, vectorStore, str, list, BaseUtils.getBuiltinTranslator(vectorStore), false, false);
    }

    public static SelfQueryRetriever fromLLM(BaseLanguageModel baseLanguageModel, VectorStore vectorStore, String str, List<AttributeInfo> list, Visitor visitor, boolean z, boolean z2) {
        return new SelfQueryRetriever(vectorStore, com.hw.langchain.chains.query.constructor.base.BaseUtils.loadQueryConstructorChain(baseLanguageModel, str, list, null, visitor.getAllowedComparators(), visitor.getAllowedOperators(), z), visitor, z2);
    }
}
