package com.hw.langchain.chains.base;

import com.google.common.collect.Maps;
import com.hw.langchain.schema.BaseMemory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:com/hw/langchain/chains/base/Chain.class */
public abstract class Chain {
    protected BaseMemory memory;

    public abstract String chainType();

    public abstract List<String> inputKeys();

    public abstract List<String> outputKeys();

    private void validateInputs(Map<String, Object> map) {
        HashSet hashSet = new HashSet(inputKeys());
        hashSet.removeAll(map.keySet());
        if (!hashSet.isEmpty()) {
            throw new IllegalArgumentException(String.format("Missing some input keys: %s", hashSet));
        }
    }

    private void validateOutputs(Map<String, String> map) {
        HashSet hashSet = new HashSet(outputKeys());
        hashSet.removeAll(map.keySet());
        if (!hashSet.isEmpty()) {
            throw new IllegalArgumentException(String.format("Missing some output keys: %s", hashSet));
        }
    }

    public abstract Map<String, String> innerCall(Map<String, Object> map);

    public Map<String, String> call(Object obj, boolean z) {
        return call(prepInputs(obj), z);
    }

    public Map<String, String> call(Map<String, Object> map, boolean z) {
        Map<String, Object> prepInputs = prepInputs(map);
        return prepOutputs(prepInputs, innerCall(prepInputs), z);
    }

    private Map<String, String> prepOutputs(Map<String, Object> map, Map<String, String> map2, boolean z) {
        validateOutputs(map2);
        if (this.memory != null) {
            this.memory.saveContext(map, map2);
        }
        if (z) {
            return map2;
        }
        HashMap newHashMap = Maps.newHashMap();
        map.forEach((str, obj) -> {
            newHashMap.put(str, obj.toString());
        });
        newHashMap.putAll(map2);
        return newHashMap;
    }

    private Map<String, Object> prepInputs(Object obj) {
        HashSet hashSet = new HashSet(inputKeys());
        if (this.memory != null) {
            hashSet.removeAll(new HashSet(this.memory.memoryVariables()));
        }
        if (hashSet.size() != 1) {
            throw new IllegalArgumentException(String.format("A single string input was passed in, but this chain expects multiple inputs (%s). When a chain expects multiple inputs, please call it by passing in a dictionary, eg `chain(Map.of('foo', 1, 'bar', 2))`", hashSet));
        }
        return Map.of((String) new ArrayList(hashSet).get(0), obj);
    }

    public Map<String, Object> prepInputs(Map<String, Object> map) {
        HashMap hashMap = new HashMap(map);
        if (this.memory != null) {
            hashMap.putAll(this.memory.loadMemoryVariables(map));
        }
        validateInputs(hashMap);
        return hashMap;
    }

    public String run(Object obj) {
        if (outputKeys().size() != 1) {
            throw new IllegalArgumentException("The `run` method is not supported when there is not exactly one output key. Got " + outputKeys() + ".");
        }
        return call(obj, false).get(outputKeys().get(0));
    }

    public String run(Map<String, Object> map) {
        if (outputKeys().size() != 1) {
            throw new IllegalArgumentException("The `run` method is not supported when there is not exactly one output key. Got " + outputKeys() + ".");
        }
        return call(map, false).get(outputKeys().get(0));
    }
}
