package com.hw.langchain.sql.database;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;

/**
 * JDBC-backed helper that renders a database's schema (as {@code CREATE TABLE} text), optional
 * sample rows, and ad-hoc query results as plain strings.
 *
 * <p>Each instance opens one {@link Connection} in its constructor and keeps it until
 * {@link #close()} is called; the class implements {@link AutoCloseable} so it can be used in
 * try-with-resources. No synchronization is performed on the shared connection.
 *
 * <p>Any {@link SQLException} raised by the driver is rethrown as an unchecked
 * {@link SQLDatabaseException} whose cause is the original exception.
 */
public class SQLDatabase implements AutoCloseable {

    /** Trailing newlines stripped from each table's DDL before extra info is appended. */
    private static final Pattern TRAILING_NEWLINES = Pattern.compile("\\n+$");

    private final Connection connection;
    /** If non-empty, the only tables reported as usable (copied from the constructor argument). */
    private final List<String> includeTables;
    /** Tables removed from the full table list; mutually exclusive with {@code includeTables}. */
    private final List<String> ignoreTables;
    /** Number of rows shown per table in {@link #getTableInfo}; {@code <= 0} disables samples. */
    private final int sampleRowsInTableInfo;
    /** Whether {@link #getTableInfo} appends the output of {@link #getTableIndexes}. */
    private final boolean indexesInTableInfo;

    /**
     * Connects with default settings: no table filters, 3 sample rows, no index info.
     *
     * @param url JDBC URL passed to {@link DriverManager#getConnection(String, String, String)}
     * @param username database user
     * @param password database password
     * @throws SQLDatabaseException if the connection cannot be opened
     */
    public SQLDatabase(String url, String username, String password) {
        this(url, username, password, null, null, 3, false);
    }

    /**
     * Connects to the database and configures which tables and extras are reported.
     *
     * @param url JDBC URL passed to {@link DriverManager#getConnection(String, String, String)}
     * @param username database user
     * @param password database password
     * @param includeTables tables to expose exclusively; {@code null}/empty means all tables
     * @param ignoreTables tables to hide; {@code null}/empty hides nothing
     * @param sampleRowsInTableInfo rows sampled per table in {@link #getTableInfo}
     * @param indexesInTableInfo whether {@link #getTableInfo} includes index information
     * @throws IllegalArgumentException if both {@code includeTables} and {@code ignoreTables} are
     *     non-empty (checked before any connection is attempted)
     * @throws SQLDatabaseException if the connection cannot be opened
     */
    public SQLDatabase(String url, String username, String password, List<String> includeTables,
            List<String> ignoreTables, int sampleRowsInTableInfo, boolean indexesInTableInfo) {
        if (isNotEmpty(includeTables) && isNotEmpty(ignoreTables)) {
            throw new IllegalArgumentException("Cannot specify both includeTables and ignoreTables");
        }
        // Defensive copies: later changes to the caller's lists must not alter this instance.
        this.includeTables = includeTables == null ? null : new ArrayList<>(includeTables);
        this.ignoreTables = ignoreTables == null ? null : new ArrayList<>(ignoreTables);
        this.sampleRowsInTableInfo = sampleRowsInTableInfo;
        this.indexesInTableInfo = indexesInTableInfo;
        try {
            this.connection = DriverManager.getConnection(url, username, password);
        } catch (SQLException e) {
            // The URL/credentials are deliberately left out of the message.
            throw new SQLDatabaseException("Failed to open database connection", e);
        }
    }

    /** Static factory equivalent to {@link #SQLDatabase(String, String, String)}. */
    public static SQLDatabase fromUri(String str, String str2, String str3) {
        return new SQLDatabase(str, str2, str3);
    }

    /**
     * Returns the driver-reported database product name, lower-cased locale-independently.
     *
     * @throws SQLDatabaseException if the metadata cannot be read
     */
    public String getDialect() {
        try {
            // Locale.ROOT: e.g. "SQLITE" must not become "sqlıte" under a Turkish default locale.
            return connection.getMetaData().getDatabaseProductName().toLowerCase(Locale.ROOT);
        } catch (SQLException e) {
            throw new SQLDatabaseException("Failed to read database product name", e);
        }
    }

    /**
     * Returns the tables this instance exposes: a copy of {@code includeTables} when it is
     * non-empty, otherwise every table in the current catalog/schema minus {@code ignoreTables}.
     * The returned list is a fresh, mutable copy.
     */
    public List<String> getUsableTableNames() {
        if (isNotEmpty(includeTables)) {
            return new ArrayList<>(includeTables);
        }
        List<String> allTables = getAllTables();
        if (isNotEmpty(ignoreTables)) {
            allTables.removeAll(ignoreTables);
        }
        return allTables;
    }

    /** Lists the names of all {@code TABLE}-type objects in the connection's catalog and schema. */
    private List<String> getAllTables() {
        List<String> tableNames = new ArrayList<>();
        try (ResultSet tables = connection.getMetaData().getTables(
                connection.getCatalog(), connection.getSchema(), "%", new String[] {"TABLE"})) {
            while (tables.next()) {
                tableNames.add(tables.getString("TABLE_NAME"));
            }
        } catch (SQLException e) {
            throw new SQLDatabaseException("Failed to list tables", e);
        }
        return tableNames;
    }

    /**
     * Describes the given tables (or all usable tables when {@code list} is {@code null}).
     *
     * <p>Each entry is the table's DDL, optionally followed by a {@code /* ... *}{@code /} comment
     * holding index info and/or sample rows; entries are separated by blank lines.
     *
     * @param list table names to describe, or {@code null} for all usable tables
     * @throws IllegalArgumentException if any requested table is not a usable table
     */
    public String getTableInfo(List<String> list) {
        List<String> usableTableNames = getUsableTableNames();
        List<String> selected = usableTableNames;
        if (list != null) {
            List<String> missing = new ArrayList<>(list);
            missing.removeAll(usableTableNames);
            if (!missing.isEmpty()) {
                throw new IllegalArgumentException("tableNames " + missing + " not found in database");
            }
            selected = list;
        }
        boolean hasExtraInfo = indexesInTableInfo || sampleRowsInTableInfo > 0;
        List<String> tableInfos = new ArrayList<>(selected.size());
        for (String table : selected) {
            StringBuilder info =
                    new StringBuilder(TRAILING_NEWLINES.matcher(getTableDdl(table)).replaceAll(""));
            if (hasExtraInfo) {
                info.append("\n\n/*");
            }
            if (indexesInTableInfo) {
                info.append('\n').append(getTableIndexes(table)).append('\n');
            }
            if (sampleRowsInTableInfo > 0) {
                info.append('\n').append(getSampleRows(table)).append('\n');
            }
            if (hasExtraInfo) {
                info.append("*/");
            }
            tableInfos.add(info.toString());
        }
        return String.join("\n\n", tableInfos);
    }

    /**
     * Builds an approximate {@code CREATE TABLE} statement for {@code str} from JDBC metadata.
     *
     * @param str table name, matched literally (JDBC pattern wildcards are escaped)
     * @return the DDL text (starting with a newline), or {@code ""} if no such table is found
     * @throws SQLDatabaseException if the metadata cannot be read
     */
    public String getTableDdl(String str) {
        try {
            DatabaseMetaData metaData = connection.getMetaData();
            String catalog = connection.getCatalog();
            String schema = connection.getSchema();
            // getTables/getColumns take LIKE-style patterns: escape '_' and '%' so that e.g.
            // "order_items" does not also match "orderXitems".
            String namePattern = escapeSearchPattern(str, metaData.getSearchStringEscape());

            String tableRemarks;
            // Read what we need and close the table cursor before opening the column cursor.
            try (ResultSet tables =
                    metaData.getTables(catalog, schema, namePattern, new String[] {"TABLE"})) {
                if (!tables.next()) {
                    return "";
                }
                tableRemarks = tables.getString("REMARKS");
            }

            StringBuilder ddl = new StringBuilder("\nCREATE TABLE ").append(str).append(" (");
            try (ResultSet columns = metaData.getColumns(catalog, schema, namePattern, "%")) {
                while (columns.next()) {
                    appendColumnDefinition(ddl, columns);
                    ddl.append(',');
                }
            }
            if (ddl.charAt(ddl.length() - 1) == ',') {
                ddl.setLength(ddl.length() - 1);
            }
            if (isNotEmpty(tableRemarks)) {
                ddl.append("\n) COMMENT ").append(sqlLiteral(tableRemarks)).append("\n\n");
            } else {
                ddl.append("\n)\n\n");
            }
            return ddl.toString();
        } catch (SQLException e) {
            throw new SQLDatabaseException("Failed to read DDL for table " + str, e);
        }
    }

    /** Appends "\n\t NAME TYPE[(size[,digits])] [NOT NULL] [DEFAULT x] [COMMENT 'y']". */
    private static void appendColumnDefinition(StringBuilder ddl, ResultSet columns)
            throws SQLException {
        // Columns are read in their getColumns() result order (4, 6, 7, 9, 11, 12, 13): the
        // ResultSet docs recommend left-to-right reads for maximum driver portability.
        String name = columns.getString("COLUMN_NAME");
        String type = columns.getString("TYPE_NAME");
        int size = columns.getInt("COLUMN_SIZE");
        int decimalDigits = columns.getInt("DECIMAL_DIGITS");
        // NULLABLE is an int code; only an explicit "no nulls" yields NOT NULL
        // (unknown nullability is treated as nullable).
        boolean notNull = columns.getInt("NULLABLE") == DatabaseMetaData.columnNoNulls;
        String remarks = columns.getString("REMARKS");
        String defaultValue = columns.getString("COLUMN_DEF");

        ddl.append("\n\t").append(name).append(' ').append(type);
        if (size > 0) {
            ddl.append('(').append(size);
            if (decimalDigits > 0) {
                ddl.append(',').append(decimalDigits);
            }
            ddl.append(')');
        }
        if (notNull) {
            ddl.append(" NOT NULL");
        }
        if (defaultValue != null) {
            ddl.append(" DEFAULT ").append(defaultValue);
        }
        if (isNotEmpty(remarks)) {
            ddl.append(" COMMENT ").append(sqlLiteral(remarks));
        }
    }

    /**
     * Returns index information for {@code str}. Index extraction is not implemented: this
     * currently always returns an empty string.
     */
    public String getTableIndexes(String str) {
        return "";
    }

    /**
     * Returns up to {@code sampleRowsInTableInfo} rows of {@code str}, with a header line.
     *
     * <p>NOTE(review): the table name is concatenated into the SQL unquoted, and the query uses
     * {@code LIMIT}; only pass trusted names, on databases that support {@code LIMIT}.
     */
    public String getSampleRows(String str) {
        return String.format("%d rows from %s table:\n%s", sampleRowsInTableInfo, str,
                run("SELECT * FROM " + str + " LIMIT " + sampleRowsInTableInfo, true));
    }

    /**
     * Executes {@code command} and renders the outcome as text.
     *
     * @param command SQL to execute verbatim
     * @param includeColumnNames whether to prefix query results with a tab-separated header line
     * @return "Update Count: n" for statements without a result set; otherwise tab-separated
     *     column values, one row per line (SQL NULL renders as "null")
     * @throws SQLDatabaseException if execution fails
     */
    public String run(String command, boolean includeColumnNames) {
        try (Statement statement = connection.createStatement()) {
            if (!statement.execute(command)) {
                return "Update Count: " + statement.getUpdateCount();
            }
            try (ResultSet resultSet = statement.getResultSet()) {
                return formatResultSet(resultSet, includeColumnNames);
            }
        } catch (SQLException e) {
            throw new SQLDatabaseException("Failed to execute SQL: " + command, e);
        }
    }

    /** Renders every remaining row of {@code resultSet} as tab-separated lines. */
    private static String formatResultSet(ResultSet resultSet, boolean includeColumnNames)
            throws SQLException {
        ResultSetMetaData metaData = resultSet.getMetaData();
        int columnCount = metaData.getColumnCount();
        StringBuilder out = new StringBuilder();
        if (includeColumnNames) {
            List<String> header = new ArrayList<>(columnCount);
            for (int i = 1; i <= columnCount; i++) {
                header.add(metaData.getColumnName(i));
            }
            out.append(String.join("\t", header)).append('\n');
        }
        List<String> rows = new ArrayList<>();
        while (resultSet.next()) {
            List<String> row = new ArrayList<>(columnCount);
            for (int i = 1; i <= columnCount; i++) {
                row.add(resultSet.getString(i));
            }
            rows.add(String.join("\t", row));
        }
        return out.append(String.join("\n", rows)).toString();
    }

    /**
     * Closes the underlying connection.
     *
     * @throws SQLDatabaseException if the driver fails to close the connection
     */
    @Override
    public void close() {
        try {
            if (connection != null) {
                connection.close();
            }
        } catch (SQLException e) {
            throw new SQLDatabaseException("Failed to close database connection", e);
        }
    }

    /** Escapes JDBC metadata pattern wildcards in {@code name}; no-op if the driver has no escape. */
    private static String escapeSearchPattern(String name, String escape) {
        if (escape == null || escape.isEmpty()) {
            return name;
        }
        StringBuilder escaped = new StringBuilder(name.length() + 8);
        for (char c : name.toCharArray()) {
            if (c == '_' || c == '%' || String.valueOf(c).equals(escape)) {
                escaped.append(escape);
            }
            escaped.append(c);
        }
        return escaped.toString();
    }

    /** Renders {@code text} as a single-quoted SQL literal, doubling embedded quotes. */
    private static String sqlLiteral(String text) {
        return "'" + text.replace("'", "''") + "'";
    }

    private static boolean isNotEmpty(List<?> list) {
        return list != null && !list.isEmpty();
    }

    private static boolean isNotEmpty(String text) {
        return text != null && !text.isEmpty();
    }

    /** Unchecked wrapper for {@link SQLException}s raised by the underlying JDBC driver. */
    public static class SQLDatabaseException extends RuntimeException {
        private static final long serialVersionUID = 1L;

        public SQLDatabaseException(String message, SQLException cause) {
            super(message, cause);
        }
    }
}
