package com.hw.langchain.math.utils;

import com.hw.langchain.vectorstores.utils.ArrayUtils;
import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:com/hw/langchain/math/utils/MathUtils.class */
/**
 * ND4J-backed numerical helpers.
 *
 * <p>Non-instantiable utility class; all members are static.
 */
public final class MathUtils {

    private MathUtils() {
        // Utility class: no instances.
    }

    /**
     * Computes cosine similarities between the row vectors of {@code x} and of {@code y}.
     *
     * <p>{@code x} is converted to a matrix with one row per inner list, then the call is
     * delegated to {@link #cosineSimilarity(INDArray, INDArray)}.
     *
     * @param x vectors given as a list of lists, one inner list per vector
     * @param y second operand; a 1-D vector or a matrix of row vectors
     * @return the similarity array produced by {@link #cosineSimilarity(INDArray, INDArray)}
     * @throws IllegalArgumentException if the vector dimensions of {@code x} and {@code y} differ
     */
    public static INDArray cosineSimilarity(List<List<Float>> x, INDArray y) {
        return cosineSimilarity(toMatrix(x), y);
    }

    /**
     * Computes cosine similarities between the row vectors of {@code x} and of {@code y}.
     *
     * <p>{@code y} is converted to a matrix with one row per inner list, then the call is
     * delegated to {@link #cosineSimilarity(INDArray, INDArray)}.
     *
     * @param x first operand; a 1-D vector or a matrix of row vectors
     * @param y vectors given as a list of lists, one inner list per vector
     * @return the similarity array produced by {@link #cosineSimilarity(INDArray, INDArray)}
     * @throws IllegalArgumentException if the vector dimensions of {@code x} and {@code y} differ
     */
    public static INDArray cosineSimilarity(INDArray x, List<List<Float>> y) {
        return cosineSimilarity(x, toMatrix(y));
    }

    /**
     * Computes cosine similarities between the row vectors of {@code x} and of {@code y}.
     *
     * <p>If either input is empty, an empty {@code 0 x 0} array is returned. A 1-D input is
     * treated as a single-row matrix, so a lone query embedding can be passed directly.
     * The actual computation is delegated to {@code Transforms.allCosineSimilarities} along
     * the last dimension. NOTE(review): that call is expected to yield one similarity per
     * (row of X, row of Y) pair — confirm against the ND4J version in use.
     *
     * @param x first operand; a 1-D vector or a matrix of row vectors
     * @param y second operand; a 1-D vector or a matrix of row vectors
     * @return the cosine-similarity array, or an empty array when either input is empty
     * @throws IllegalArgumentException if either input is a scalar (rank 0), or if the
     *     number of columns of {@code x} and {@code y} differ
     */
    public static INDArray cosineSimilarity(INDArray x, INDArray y) {
        if (x.isEmpty() || y.isEmpty()) {
            return Nd4j.create(new float[0][0]);
        }
        // Promote 1-D vectors to row matrices; reading shape()[1] on a rank-1 array would
        // otherwise fail with an ArrayIndexOutOfBoundsException.
        INDArray left = asMatrix(x, "X");
        INDArray right = asMatrix(y, "Y");
        if (left.shape()[1] != right.shape()[1]) {
            // Report the caller's original shapes, not the promoted ones.
            throw new IllegalArgumentException(String.format(
                    "Number of columns in X and Y must be the same. X has shape %s and Y has shape %s.",
                    Arrays.toString(x.shape()), Arrays.toString(y.shape())));
        }
        return Transforms.allCosineSimilarities(left, right, new int[] {left.rank() - 1});
    }

    /** Converts a list-of-lists of floats into a 2-D ND4J array (one row per inner list). */
    private static INDArray toMatrix(List<List<Float>> rows) {
        return Nd4j.createFromArray((Float[][]) ArrayUtils.listToArray(rows));
    }

    /**
     * Returns {@code array} viewed as a matrix: rank-1 inputs become a single row, inputs of
     * rank 2 or higher are returned unchanged, and rank-0 inputs are rejected.
     */
    private static INDArray asMatrix(INDArray array, String name) {
        int rank = array.rank();
        if (rank == 1) {
            return array.reshape(1, array.length());
        }
        if (rank < 1) {
            throw new IllegalArgumentException(String.format(
                    "%s must be a vector or a matrix but has shape %s.",
                    name, Arrays.toString(array.shape())));
        }
        return array;
    }
}
