package com.hw.langchain.chains.sql.database.base;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.base.Chain;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.chains.sql.database.prompt.Prompt;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import com.hw.langchain.sql.database.SQLDatabase;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/hw/langchain/chains/sql/database/base/SQLDatabaseChain.class */
/**
 * Chain that answers a natural-language question by generating and running SQL.
 *
 * <p>The wrapped {@link LLMChain} is first asked to write a SQL query for the question found under
 * {@code inputKey}. That query is executed against the {@link SQLDatabase}. Unless
 * {@code returnDirect} is set, the LLM is then asked a second time to phrase an answer from the
 * query result. The final text is returned under {@code outputKey}.
 *
 * <p>NOTE(review): {@code returnIntermediateSteps}, {@code useQueryChecker} and
 * {@code queryCheckerPrompt} are stored and take part in equals/hashCode/toString, but
 * {@link #innerCall(Map)} does not act on them. Generated SQL is executed without a checking pass.
 */
public class SQLDatabaseChain extends Chain {
    private static final Logger LOG = LoggerFactory.getLogger(SQLDatabaseChain.class);

    /** Input key that may carry a {@code List<String>} of table names to describe to the LLM. */
    private static final String TABLE_NAMES_KEY = "table_names_to_use";

    private LLMChain llmChain;
    private SQLDatabase database;
    /** Value passed to the prompt as {@code top_k}. */
    private int topK = 5;
    private String inputKey = "query";
    private String outputKey = "result";
    private boolean returnIntermediateSteps;
    /** When true, the raw SQL result is returned instead of a second LLM-phrased answer. */
    private boolean returnDirect;
    private boolean useQueryChecker;
    private BasePromptTemplate queryCheckerPrompt;

    public SQLDatabaseChain(LLMChain lLMChain, SQLDatabase sQLDatabase) {
        this.llmChain = lLMChain;
        this.database = sQLDatabase;
    }

    /**
     * Builds a chain using the built-in SQL prompt registered for the database's dialect.
     *
     * @param baseLanguageModel the model used both to write SQL and to phrase the answer
     * @param sQLDatabase the database to query
     * @return a new chain
     * @throws IllegalArgumentException if {@code Prompt.SQL_PROMPTS} has no entry for the
     *     database's dialect; use {@link #fromLLM(BaseLanguageModel, SQLDatabase, BasePromptTemplate)}
     *     to supply a prompt explicitly
     */
    public static SQLDatabaseChain fromLLM(BaseLanguageModel baseLanguageModel, SQLDatabase sQLDatabase) {
        BasePromptTemplate prompt = Prompt.SQL_PROMPTS.get(sQLDatabase.getDialect());
        if (prompt == null) {
            // Previously a null prompt was handed to LLMChain, which only failed later at call time.
            throw new IllegalArgumentException("No built-in SQL prompt for dialect '" + sQLDatabase.getDialect()
                    + "'; supply one via fromLLM(llm, database, prompt)");
        }
        return fromLLM(baseLanguageModel, sQLDatabase, prompt);
    }

    /**
     * Builds a chain that uses the given prompt to generate SQL.
     *
     * @param baseLanguageModel the model used both to write SQL and to phrase the answer
     * @param sQLDatabase the database to query
     * @param basePromptTemplate the prompt template for the wrapped {@link LLMChain}
     * @return a new chain
     */
    public static SQLDatabaseChain fromLLM(BaseLanguageModel baseLanguageModel, SQLDatabase sQLDatabase, BasePromptTemplate basePromptTemplate) {
        return new SQLDatabaseChain(new LLMChain(baseLanguageModel, basePromptTemplate), sQLDatabase);
    }

    @Override // com.hw.langchain.chains.base.Chain
    public String chainType() {
        return "sql_database_chain";
    }

    @Override // com.hw.langchain.chains.base.Chain
    public List<String> inputKeys() {
        return List.of(this.inputKey);
    }

    @Override // com.hw.langchain.chains.base.Chain
    public List<String> outputKeys() {
        return List.of(this.outputKey);
    }

    /**
     * Generates SQL for the question, runs it, and optionally asks the LLM to phrase an answer.
     *
     * @param map chain inputs: the question under {@code inputKey}, and optionally a
     *     {@code List<String>} under {@code "table_names_to_use"}
     * @return a single-entry map from {@code outputKey} to the final answer (or the raw SQL result
     *     when {@code returnDirect} is set)
     */
    @Override // com.hw.langchain.chains.base.Chain
    public Map<String, String> innerCall(Map<String, Object> map) {
        if (this.useQueryChecker) {
            // The flag is accepted but no checking pass exists here; make that visible to the caller.
            LOG.warn("useQueryChecker is enabled but not supported by SQLDatabaseChain; executing generated SQL unchecked");
        }
        String inputText = map.get(this.inputKey) + "\nSQLQuery:";

        // NOTE(review): a missing key yields null here; presumably getTableInfo(null) means "all tables" — confirm.
        @SuppressWarnings("unchecked")
        List<String> tableNamesToUse = (List<String>) map.get(TABLE_NAMES_KEY);
        String tableInfo = this.database.getTableInfo(tableNamesToUse);

        Map<String, Object> llmInputs = new HashMap<>();
        llmInputs.put("input", inputText);
        llmInputs.put("top_k", Integer.valueOf(this.topK));
        llmInputs.put("dialect", this.database.getDialect());
        llmInputs.put("table_info", tableInfo);
        // Stop before the model hallucinates its own query result.
        llmInputs.put("stop", List.of("\nSQLResult:"));

        String sqlCommand = this.llmChain.predict(llmInputs);
        LOG.info("SQL command:\n {}", sqlCommand);
        String sqlResult = this.database.run(sqlCommand, false);
        LOG.info("SQLResult: \n{}", sqlResult);

        String finalResult;
        if (this.returnDirect) {
            finalResult = sqlResult;
        } else {
            // Extend the transcript with the query and its result, then let the LLM answer.
            llmInputs.put("input", inputText + String.format("%s\nSQLResult: %s\nAnswer:", sqlCommand, sqlResult));
            finalResult = this.llmChain.predict(llmInputs).trim();
        }
        LOG.info("Final Result: \n{}", finalResult);
        return Map.of(this.outputKey, finalResult);
    }

    public LLMChain getLlmChain() {
        return this.llmChain;
    }

    public SQLDatabase getDatabase() {
        return this.database;
    }

    public int getTopK() {
        return this.topK;
    }

    public String getInputKey() {
        return this.inputKey;
    }

    public String getOutputKey() {
        return this.outputKey;
    }

    public boolean isReturnIntermediateSteps() {
        return this.returnIntermediateSteps;
    }

    public boolean isReturnDirect() {
        return this.returnDirect;
    }

    public boolean isUseQueryChecker() {
        return this.useQueryChecker;
    }

    public BasePromptTemplate getQueryCheckerPrompt() {
        return this.queryCheckerPrompt;
    }

    public void setLlmChain(LLMChain lLMChain) {
        this.llmChain = lLMChain;
    }

    public void setDatabase(SQLDatabase sQLDatabase) {
        this.database = sQLDatabase;
    }

    public void setTopK(int i) {
        this.topK = i;
    }

    public void setInputKey(String str) {
        this.inputKey = str;
    }

    public void setOutputKey(String str) {
        this.outputKey = str;
    }

    /** Stored only; {@link #innerCall(Map)} does not currently return intermediate steps. */
    public void setReturnIntermediateSteps(boolean z) {
        this.returnIntermediateSteps = z;
    }

    public void setReturnDirect(boolean z) {
        this.returnDirect = z;
    }

    /** Stored only; {@link #innerCall(Map)} logs a warning and runs SQL unchecked when this is set. */
    public void setUseQueryChecker(boolean z) {
        this.useQueryChecker = z;
    }

    public void setQueryCheckerPrompt(BasePromptTemplate basePromptTemplate) {
        this.queryCheckerPrompt = basePromptTemplate;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof SQLDatabaseChain)) {
            return false;
        }
        SQLDatabaseChain other = (SQLDatabaseChain) obj;
        // canEqual keeps equality symmetric if a subclass narrows it.
        return other.canEqual(this)
                && getTopK() == other.getTopK()
                && isReturnIntermediateSteps() == other.isReturnIntermediateSteps()
                && isReturnDirect() == other.isReturnDirect()
                && isUseQueryChecker() == other.isUseQueryChecker()
                && Objects.equals(getLlmChain(), other.getLlmChain())
                && Objects.equals(getDatabase(), other.getDatabase())
                && Objects.equals(getInputKey(), other.getInputKey())
                && Objects.equals(getOutputKey(), other.getOutputKey())
                && Objects.equals(getQueryCheckerPrompt(), other.getQueryCheckerPrompt());
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof SQLDatabaseChain;
    }

    @Override
    public int hashCode() {
        return Objects.hash(getTopK(), isReturnIntermediateSteps(), isReturnDirect(), isUseQueryChecker(),
                getLlmChain(), getDatabase(), getInputKey(), getOutputKey(), getQueryCheckerPrompt());
    }

    @Override
    public String toString() {
        return "SQLDatabaseChain(llmChain=" + getLlmChain() + ", database=" + getDatabase() + ", topK=" + getTopK() + ", inputKey=" + getInputKey() + ", outputKey=" + getOutputKey() + ", returnIntermediateSteps=" + isReturnIntermediateSteps() + ", returnDirect=" + isReturnDirect() + ", useQueryChecker=" + isUseQueryChecker() + ", queryCheckerPrompt=" + getQueryCheckerPrompt() + ")";
    }
}
