package com.hw.langchain.agents.agent;

import com.hw.langchain.agents.tools.InvalidTool;
import com.hw.langchain.chains.base.Chain;
import com.hw.langchain.schema.AgentAction;
import com.hw.langchain.schema.AgentFinish;
import com.hw.langchain.schema.AgentResult;
import com.hw.langchain.tools.base.BaseTool;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/hw/langchain/agents/agent/AgentExecutor.class */
/**
 * Chain that drives a {@link BaseSingleActionAgent}: it repeatedly asks the agent to plan the next
 * step, runs the requested tool, and feeds the observation back, until the agent returns an
 * {@link AgentFinish}, a {@code returnDirect} tool is used, or the iteration / time budget runs out.
 */
public class AgentExecutor extends Chain {
    private static final Logger LOG = LoggerFactory.getLogger(AgentExecutor.class);

    /** Output key used for a return-direct tool result when the agent declares no return values. */
    private static final String DEFAULT_RETURN_VALUE_KEY = "output";

    /** Output key under which the intermediate steps are added when requested. */
    private static final String INTERMEDIATE_STEPS_KEY = "intermediate_steps";

    // Builder defaults. Declared as boxed types so the builder's ternaries never unbox a null.
    private static final Integer DEFAULT_MAX_ITERATIONS = 15;
    private static final String DEFAULT_EARLY_STOPPING_METHOD = "force";
    private static final Object DEFAULT_HANDLE_PARSING_ERRORS = Boolean.FALSE;

    private final BaseSingleActionAgent agent;
    private final List<BaseTool> tools;
    private final boolean returnIntermediateSteps;
    /** Maximum number of agent steps; {@code null} means no limit. */
    private final Integer maxIterations;
    /** Maximum elapsed time in seconds; {@code null} means no limit. */
    private final Float maxExecutionTime;
    /** Passed to {@code agent.returnStoppedResponse} when the budget is exhausted. */
    private final String earlyStoppingMethod;
    /** NOTE(review): stored but not consulted anywhere in this class. */
    private final Object handleParsingErrors;

    /**
     * Builder for {@link AgentExecutor}. Options that are never set fall back to the class
     * defaults; options explicitly set (even to {@code null}) keep the given value.
     */
    public static class AgentExecutorBuilder {
        private BaseSingleActionAgent agent;
        private List<BaseTool> tools;
        private boolean returnIntermediateSteps;
        private boolean maxIterationsSet;
        private Integer maxIterationsValue;
        private Float maxExecutionTime;
        private boolean earlyStoppingMethodSet;
        private String earlyStoppingMethodValue;
        private boolean handleParsingErrorsSet;
        private Object handleParsingErrorsValue;

        AgentExecutorBuilder() {
        }

        /** Sets the agent that plans each step. */
        public AgentExecutorBuilder agent(BaseSingleActionAgent agent) {
            this.agent = agent;
            return this;
        }

        /** Sets the tools the agent may call; tools are looked up by {@code getName()}. */
        public AgentExecutorBuilder tools(List<BaseTool> tools) {
            this.tools = tools;
            return this;
        }

        /** Whether to add the intermediate steps to the chain output. */
        public AgentExecutorBuilder returnIntermediateSteps(boolean returnIntermediateSteps) {
            this.returnIntermediateSteps = returnIntermediateSteps;
            return this;
        }

        /** Sets the step limit; {@code null} disables it. Defaults to 15 when never set. */
        public AgentExecutorBuilder maxIterations(Integer maxIterations) {
            this.maxIterationsValue = maxIterations;
            this.maxIterationsSet = true;
            return this;
        }

        /** Sets the time limit in seconds; {@code null} (the default) disables it. */
        public AgentExecutorBuilder maxExecutionTime(Float maxExecutionTime) {
            this.maxExecutionTime = maxExecutionTime;
            return this;
        }

        /** Sets the method passed to the agent on early stop. Defaults to "force" when never set. */
        public AgentExecutorBuilder earlyStoppingMethod(String earlyStoppingMethod) {
            this.earlyStoppingMethodValue = earlyStoppingMethod;
            this.earlyStoppingMethodSet = true;
            return this;
        }

        /** Sets the parsing-error policy. Defaults to {@code false} when never set. */
        public AgentExecutorBuilder handleParsingErrors(Object handleParsingErrors) {
            this.handleParsingErrorsValue = handleParsingErrors;
            this.handleParsingErrorsSet = true;
            return this;
        }

        /** Creates the executor, substituting defaults for options that were never set. */
        public AgentExecutor build() {
            Integer iterations = maxIterationsSet ? maxIterationsValue : DEFAULT_MAX_ITERATIONS;
            String stoppingMethod = earlyStoppingMethodSet ? earlyStoppingMethodValue : DEFAULT_EARLY_STOPPING_METHOD;
            Object parsingErrors = handleParsingErrorsSet ? handleParsingErrorsValue : DEFAULT_HANDLE_PARSING_ERRORS;
            return new AgentExecutor(agent, tools, returnIntermediateSteps, iterations, maxExecutionTime,
                    stoppingMethod, parsingErrors);
        }

        @Override
        public String toString() {
            return "AgentExecutor.AgentExecutorBuilder(agent=" + this.agent + ", tools=" + this.tools + ", returnIntermediateSteps=" + this.returnIntermediateSteps + ", maxIterations$value=" + this.maxIterationsValue + ", maxExecutionTime=" + this.maxExecutionTime + ", earlyStoppingMethod$value=" + this.earlyStoppingMethodValue + ", handleParsingErrors$value=" + this.handleParsingErrorsValue + ")";
        }
    }

    /** Creates an executor for the given agent and tools with all other options at their defaults. */
    public static AgentExecutor fromAgentAndTools(BaseSingleActionAgent agent, List<BaseTool> tools) {
        return builder().agent(agent).tools(tools).build();
    }

    /** No chain type is defined for agent executors; always returns {@code null}. */
    @Override
    public String chainType() {
        return null;
    }

    /** Delegates to the agent's input keys. */
    @Override
    public List<String> inputKeys() {
        return this.agent.inputKeys();
    }

    /**
     * Delegates to the agent's return values. Note that {@code "intermediate_steps"} is not
     * listed here even when {@code returnIntermediateSteps} is enabled.
     */
    @Override
    public List<String> outputKeys() {
        return this.agent.returnValues();
    }

    /**
     * Builds the chain output from a finished agent run.
     *
     * @param agentFinish       the final agent result
     * @param intermediateSteps the (action, observation) pairs taken so far
     * @return a new map holding the finish's return values, plus {@code "intermediate_steps"}
     *         (the steps' {@code toString()}) when {@code returnIntermediateSteps} is enabled
     */
    public Map<String, String> processOutput(AgentFinish agentFinish, List<Pair<AgentAction, String>> intermediateSteps) {
        Map<String, String> returnValues = agentFinish.getReturnValues();
        // Copy instead of mutating the agent's own (possibly immutable) map.
        Map<String, String> output = returnValues == null ? new LinkedHashMap<>() : new LinkedHashMap<>(returnValues);
        if (this.returnIntermediateSteps) {
            output.put(INTERMEDIATE_STEPS_KEY, intermediateSteps.toString());
        }
        return output;
    }

    /**
     * Asks the agent for its next step and, if it is an action, runs the requested tool.
     *
     * @param nameToToolMap     tools keyed by name
     * @param inputs            the chain inputs
     * @param intermediateSteps the (action, observation) pairs taken so far
     * @return the {@link AgentFinish} if the agent finished; a single-element list holding the
     *         (action, observation) pair if it requested an action; {@code null} if the plan was
     *         neither
     */
    public Object takeNextStep(Map<String, BaseTool> nameToToolMap, Map<String, Object> inputs,
            List<Pair<AgentAction, String>> intermediateSteps) {
        AgentResult plan = this.agent.plan(intermediateSteps, inputs);
        LOG.info("Plan output: {}", plan);
        if (plan instanceof AgentFinish) {
            return plan;
        }
        if (!(plan instanceof AgentAction)) {
            return null;
        }
        AgentAction agentAction = (AgentAction) plan;
        // Work on a copy so the map returned by the agent is never mutated.
        Map<String, Object> toolRunKwargs = mutableCopy(this.agent.toolRunLoggingKwargs());
        String observation;
        BaseTool tool = nameToToolMap.get(agentAction.getTool());
        if (tool != null) {
            if (tool.isReturnDirect()) {
                // Return-direct tools are run with an empty "llm_prefix".
                toolRunKwargs.put("llm_prefix", "");
            }
            observation = tool.run(agentAction.getToolInput(), toolRunKwargs).toString();
            LOG.info("Observation: {}", observation);
        } else {
            // Unknown tool name: let InvalidTool produce the observation fed back to the agent.
            observation = new InvalidTool().run(agentAction.getTool(), toolRunKwargs).toString();
        }
        return List.of(Pair.of(agentAction, observation));
    }

    /**
     * Runs the plan/act loop until the agent finishes, a return-direct tool is used, or the
     * iteration / time budget is exhausted, in which case the agent's stopped response is used.
     *
     * @throws IllegalStateException if a step yields neither an {@link AgentFinish} nor a list of steps
     */
    @Override
    public Map<String, String> innerCall(Map<String, Object> inputs) {
        Map<String, BaseTool> nameToToolMap = toolsByName();
        List<Pair<AgentAction, String>> intermediateSteps = new ArrayList<>();
        int iterations = 0;
        // nanoTime is monotonic, so the time budget is unaffected by wall-clock adjustments.
        long startNanos = System.nanoTime();
        for (double elapsed = 0.0d; shouldContinue(iterations, elapsed); elapsed = secondsSince(startNanos)) {
            Object nextStepOutput = takeNextStep(nameToToolMap, inputs, intermediateSteps);
            LOG.info("NextStepOutput: {}", nextStepOutput);
            if (nextStepOutput instanceof AgentFinish) {
                return processOutput((AgentFinish) nextStepOutput, intermediateSteps);
            }
            if (!(nextStepOutput instanceof List)) {
                throw new IllegalStateException(
                        "Unexpected next-step output (expected AgentFinish or a list of steps): " + nextStepOutput);
            }
            @SuppressWarnings("unchecked") // takeNextStep only returns lists of (action, observation) pairs
            List<Pair<AgentAction, String>> newSteps = (List<Pair<AgentAction, String>>) nextStepOutput;
            intermediateSteps.addAll(newSteps);
            if (newSteps.size() == 1) {
                AgentFinish toolReturn = getToolReturn(newSteps.get(0), nameToToolMap);
                if (toolReturn != null) {
                    return processOutput(toolReturn, intermediateSteps);
                }
            }
            iterations++;
        }
        AgentFinish stopped = this.agent.returnStoppedResponse(this.earlyStoppingMethod, intermediateSteps, inputs);
        return processOutput(stopped, intermediateSteps);
    }

    /** Returns true while neither the iteration limit nor the time limit (when set) has been reached. */
    private boolean shouldContinue(int iterations, double elapsedSeconds) {
        if (this.maxIterations != null && iterations >= this.maxIterations) {
            return false;
        }
        return this.maxExecutionTime == null || elapsedSeconds < this.maxExecutionTime;
    }

    /** Seconds elapsed since {@code startNanos}, a value previously read from {@link System#nanoTime()}. */
    private static double secondsSince(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000_000.0d;
    }

    /**
     * Converts a step into a final result when its tool is a return-direct tool.
     *
     * @param step the (action, observation) pair
     * @return an {@link AgentFinish} mapping the agent's first return-value key (or
     *         {@code "output"} if it declares none) to the observation, or {@code null} if the
     *         tool is unknown or not return-direct
     */
    public AgentFinish getToolReturn(Pair<AgentAction, String> step) {
        return getToolReturn(step, toolsByName());
    }

    /** Same as {@link #getToolReturn(Pair)} but reuses an already-built name-to-tool map. */
    private AgentFinish getToolReturn(Pair<AgentAction, String> step, Map<String, BaseTool> nameToToolMap) {
        AgentAction agentAction = step.getKey();
        // Invalid (unknown) tools are not in the map, so they never end the run.
        BaseTool tool = nameToToolMap.get(agentAction.getTool());
        if (tool == null || !tool.isReturnDirect()) {
            return null;
        }
        List<String> returnValueKeys = this.agent.returnValues();
        String key = (returnValueKeys == null || returnValueKeys.isEmpty())
                ? DEFAULT_RETURN_VALUE_KEY
                : returnValueKeys.get(0);
        Map<String, String> returnValues = new HashMap<>();
        returnValues.put(key, step.getValue());
        return new AgentFinish(returnValues, "");
    }

    /**
     * Indexes the configured tools by name; a {@code null} tool list yields an empty map.
     *
     * @throws IllegalStateException if two tools share the same name
     */
    private Map<String, BaseTool> toolsByName() {
        if (this.tools == null) {
            return new HashMap<>();
        }
        return this.tools.stream().collect(Collectors.toMap(BaseTool::getName, Function.identity()));
    }

    /** Returns a fresh mutable copy of {@code source}, or an empty map when it is {@code null}. */
    private static Map<String, Object> mutableCopy(Map<String, ?> source) {
        return source == null ? new HashMap<>() : new HashMap<>(source);
    }

    AgentExecutor(BaseSingleActionAgent agent, List<BaseTool> tools, boolean returnIntermediateSteps,
            Integer maxIterations, Float maxExecutionTime, String earlyStoppingMethod, Object handleParsingErrors) {
        this.agent = agent;
        this.tools = tools;
        this.returnIntermediateSteps = returnIntermediateSteps;
        this.maxIterations = maxIterations;
        this.maxExecutionTime = maxExecutionTime;
        this.earlyStoppingMethod = earlyStoppingMethod;
        this.handleParsingErrors = handleParsingErrors;
    }

    /** Returns a new builder with every option unset. */
    public static AgentExecutorBuilder builder() {
        return new AgentExecutorBuilder();
    }
}
