package com.hw.langchain.chains.sql.database.base;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.base.Chain;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.chains.sql.database.prompt.Prompt;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import com.hw.langchain.sql.database.SQLDatabase;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Two-step chain: an LLM "decider" chain first picks the database tables relevant to a question,
 * then a {@link SQLDatabaseChain} answers the question using only those tables.
 *
 * <p>The decider receives the question under {@code "query"} and a comma-separated list of the
 * database's usable table names under {@code "table_names"}. Its parsed output is expected to be a
 * list of table names. Names that do not match a usable table are dropped before the SQL step.
 */
public class SQLDatabaseSequentialChain extends Chain {

    private static final Logger LOG = LoggerFactory.getLogger(SQLDatabaseSequentialChain.class);

    /** Chain that generates and runs SQL for the question. */
    private final SQLDatabaseChain sqlChain;

    /** LLM chain that selects the tables relevant to the question. */
    private final LLMChain deciderChain;

    /** Key under which the question is read from this chain's inputs. */
    private final String inputKey = "query";

    /** Key reported by {@link #outputKeys()}. */
    private final String outputKey = "result";

    /**
     * Creates a sequential chain from its two sub-chains.
     *
     * @param sqlChain chain that answers the question against the database
     * @param deciderChain LLM chain that chooses which tables to use
     * @throws NullPointerException if either argument is {@code null}
     */
    public SQLDatabaseSequentialChain(SQLDatabaseChain sqlChain, LLMChain deciderChain) {
        this.sqlChain = Objects.requireNonNull(sqlChain, "sqlChain");
        this.deciderChain = Objects.requireNonNull(deciderChain, "deciderChain");
    }

    /**
     * Builds a sequential chain whose SQL step and decider step share one language model.
     *
     * @param llm model used by both sub-chains
     * @param database database the SQL chain queries
     * @param prompt prompt handed to {@link SQLDatabaseChain#fromLLM}
     * @param deciderPrompt prompt used by the table-decider {@link LLMChain}
     * @return a new chain
     */
    public static SQLDatabaseSequentialChain fromLLM(
            BaseLanguageModel llm,
            SQLDatabase database,
            BasePromptTemplate prompt,
            BasePromptTemplate deciderPrompt) {
        SQLDatabaseChain sqlChain = SQLDatabaseChain.fromLLM(llm, database, prompt);
        // NOTE(review): "table_names" is presumably the decider's output key — confirm against LLMChain.
        LLMChain deciderChain = new LLMChain(llm, deciderPrompt, "table_names");
        return new SQLDatabaseSequentialChain(sqlChain, deciderChain);
    }

    /**
     * Builds a sequential chain using {@link Prompt#PROMPT} for the SQL step and
     * {@link Prompt#DECIDER_PROMPT} for the table-decider step.
     *
     * @param llm model used by both sub-chains
     * @param database database the SQL chain queries
     * @return a new chain
     */
    public static SQLDatabaseSequentialChain fromLLM(BaseLanguageModel llm, SQLDatabase database) {
        return fromLLM(llm, database, Prompt.PROMPT, Prompt.DECIDER_PROMPT);
    }

    @Override
    public String chainType() {
        return "sql_database_sequential_chain";
    }

    @Override
    public List<String> inputKeys() {
        return List.of(this.inputKey);
    }

    @Override
    public List<String> outputKeys() {
        return List.of(this.outputKey);
    }

    /**
     * Asks the decider chain which tables to use, then runs the SQL chain with the question and the
     * selected tables (passed under {@code "table_names_to_use"}).
     *
     * @param inputs chain inputs; must contain the question under the input key ({@code "query"})
     * @return whatever the SQL chain's {@code call(..., true)} returns
     * @throws NullPointerException if the question is missing from {@code inputs}
     */
    @Override
    public Map<String, String> innerCall(Map<String, Object> inputs) {
        // Map.of rejects null values; fail with a message that names the missing key instead.
        Object query = Objects.requireNonNull(
                inputs.get(this.inputKey), () -> "Missing required input: " + this.inputKey);
        List<String> usableTableNames = this.sqlChain.getDatabase().getUsableTableNames();

        Map<String, Object> deciderInputs =
                Map.of("query", query, "table_names", String.join(", ", usableTableNames));
        // Unchecked: the decider's output parser is expected to yield a list of table names.
        @SuppressWarnings("unchecked")
        List<String> candidateTableNames =
                (List<String>) this.deciderChain.predictAndParse(deciderInputs);

        List<String> tableNamesToUse = selectUsableTables(candidateTableNames, usableTableNames);
        LOG.info("Table names to use: {}", tableNamesToUse);
        return this.sqlChain.call(
                Map.of(this.sqlChain.getInputKey(), query, "table_names_to_use", tableNamesToUse),
                true);
    }

    /**
     * Keeps only the candidates that name a usable table, matching case-insensitively.
     *
     * <p>Each match is returned in the database's own spelling, and an exact-case match takes
     * precedence. This matters because the database side may compare names case-sensitively, so
     * e.g. {@code "Users"} is forwarded as {@code "users"} when that is the registered name.
     * Candidate order is preserved; duplicates and {@code null} entries are dropped.
     */
    private static List<String> selectUsableTables(
            List<String> candidates, List<String> usableTableNames) {
        Set<String> exactNames = new HashSet<>(usableTableNames);
        Map<String, String> canonicalByLowerName = new HashMap<>();
        for (String name : usableTableNames) {
            // Locale.ROOT keeps matching independent of the JVM default locale (e.g. Turkish 'I').
            canonicalByLowerName.putIfAbsent(name.toLowerCase(Locale.ROOT), name);
        }

        List<String> selected = new ArrayList<>();
        for (String candidate : candidates) {
            if (candidate == null) {
                continue;
            }
            String canonical = exactNames.contains(candidate)
                    ? candidate
                    : canonicalByLowerName.get(candidate.toLowerCase(Locale.ROOT));
            if (canonical != null && !selected.contains(canonical)) {
                selected.add(canonical);
            }
        }
        return selected;
    }
}
