package com.hw.langchain.vectorstores.utils;

import com.google.common.collect.Lists;
import com.hw.langchain.math.utils.MathUtils;
import java.util.ArrayList;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:com/hw/langchain/vectorstores/utils/Utils.class */
/**
 * Static helper functions used by vector store implementations.
 */
public class Utils {

    private Utils() {
    }

    /**
     * Selects up to {@code k} candidate indices using Maximal Marginal Relevance (MMR).
     *
     * <p>The first pick is the candidate with the highest cosine similarity to the query. Each later
     * pick maximises {@code lambdaMult * sim(query, c) - (1 - lambdaMult) * max_s sim(c, s)}, where
     * {@code s} ranges over the candidates already selected. This trades relevance to the query
     * against redundancy with what has already been chosen.
     *
     * @param queryEmbedding the query embedding; a rank-1 array is promoted to a single-row matrix
     * @param embeddingList  candidate embeddings
     * @param k              maximum number of indices to return
     * @param lambdaMult     relevance weight: {@code 1} scores by query similarity only, {@code 0}
     *                       scores by (negated) redundancy only
     * @return indices into {@code embeddingList} in selection order. The list is empty when
     *         {@code min(k, embeddingList.size()) <= 0}. It may be shorter than that minimum if the
     *         remaining scores become NaN.
     */
    public static List<Integer> maximalMarginalRelevance(INDArray queryEmbedding, List<List<Float>> embeddingList,
            int k, float lambdaMult) {
        int limit = Math.min(k, embeddingList.size());
        if (limit <= 0) {
            return new ArrayList<>();
        }
        if (queryEmbedding.rank() == 1) {
            queryEmbedding = Nd4j.expandDims(queryEmbedding, 0);
        }
        // Presumably [queries, candidates]; with a single query row, row 0 holds query-vs-candidate
        // similarities. NOTE(review): confirm against MathUtils.cosineSimilarity.
        INDArray similarityToQuery = MathUtils.cosineSimilarity(queryEmbedding, embeddingList).getRow(0L);
        int mostSimilar = Nd4j.argMax(similarityToQuery).getInt(0);

        List<Integer> idxs = new ArrayList<>(limit);
        boolean[] picked = new boolean[(int) similarityToQuery.columns()];
        idxs.add(mostSimilar);
        picked[mostSimilar] = true;
        INDArray selected = Nd4jUtils.createFromList(embeddingList.get(mostSimilar));

        while (idxs.size() < limit) {
            float bestScore = Float.NEGATIVE_INFINITY;
            int idxToAdd = -1;
            // Assumed shape [candidates, selected]; row i = similarity of candidate i to each pick.
            // TODO confirm against MathUtils.cosineSimilarity.
            INDArray similarityToSelected = MathUtils.cosineSimilarity(embeddingList, selected);
            for (int i = 0; i < picked.length; i++) {
                if (picked[i]) {
                    continue;
                }
                // Redundancy is the row MAXIMUM (closest already-selected item). Transforms.max(row, 0)
                // is an element-wise max with 0 and must not be used here.
                float redundantScore = similarityToSelected.getRow(i).maxNumber().floatValue();
                float score = lambdaMult * similarityToQuery.getFloat(i) - (1.0f - lambdaMult) * redundantScore;
                if (score > bestScore) {
                    bestScore = score;
                    idxToAdd = i;
                }
            }
            if (idxToAdd < 0) {
                // Any finite score beats NEGATIVE_INFINITY, so this happens only when every remaining
                // score is NaN. Stop instead of failing on embeddingList.get(-1).
                break;
            }
            idxs.add(idxToAdd);
            picked[idxToAdd] = true;
            selected = Nd4j.vstack(selected, Nd4jUtils.createFromList(embeddingList.get(idxToAdd)));
        }
        return idxs;
    }
}
