package smile.clustering;

import java.io.Serializable;
import java.util.function.IntConsumer;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.blas.UPLO;
import smile.math.matrix.ARPACK;
import smile.math.matrix.Matrix;

/* loaded from: classes5.dex */
/**
 * Spectral clustering. Samples are embedded into the space spanned by the k leading
 * eigenvectors (ARPACK option {@code LA}) of the normalized affinity matrix
 * D<sup>-1/2</sup> W D<sup>-1/2</sup>, where D holds the vertex degrees. Each embedded row
 * is then unit-normalized and the rows are clustered with k-means.
 *
 * <p>Three entry points are provided:
 * <ul>
 *   <li>a precomputed adjacency matrix;</li>
 *   <li>raw data with a Gaussian kernel;</li>
 *   <li>raw data with a Nystrom approximation, which only materializes an n-by-l slice of the
 *       affinity matrix against l randomly sampled points.</li>
 * </ul>
 */
public class SpectralClustering extends PartitionClustering implements Serializable {
    private static final Logger logger = LoggerFactory.getLogger(SpectralClustering.class);
    private static final long serialVersionUID = 2L;

    /** The distortion reported by the final k-means run on the spectral embedding. */
    public final double distortion;

    /**
     * Constructor.
     *
     * @param distortion the distortion of the final k-means run.
     * @param k the number of clusters (forwarded to {@code PartitionClustering}).
     * @param y the cluster label of each sample (forwarded to {@code PartitionClustering}).
     */
    public SpectralClustering(double distortion, int k, int[] y) {
        super(k, y);
        this.distortion = distortion;
    }

    /**
     * Spectral clustering of a graph given by its adjacency matrix.
     *
     * <p>Uses at most 100 k-means iterations and a k-means tolerance of 1E-4.
     *
     * @param W the square, symmetric adjacency matrix of the graph. It is overwritten in place
     *          with its normalized form.
     * @param k the number of clusters.
     * @return the model.
     */
    public static SpectralClustering fit(Matrix W, int k) {
        return fit(W, k, 100, 1E-4);
    }

    /**
     * Spectral clustering of a graph given by its adjacency matrix.
     *
     * @param W the square, symmetric adjacency matrix of the graph. Only the lower triangle
     *          (including the diagonal) is read when normalizing, and W is overwritten in place
     *          with D^{-1/2} W D^{-1/2}.
     * @param k the number of clusters.
     * @param maxIter the maximum number of k-means iterations (forwarded to KMeans.fit).
     * @param tol the k-means convergence tolerance (forwarded to KMeans.fit).
     * @return the model.
     * @throws IllegalArgumentException if {@code k < 2}, if W is not square, or if some vertex
     *         has zero degree.
     */
    public static SpectralClustering fit(Matrix W, int k, int maxIter, double tol) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }

        int n = W.nrows();
        if (W.ncols() != n) {
            throw new IllegalArgumentException("Adjacency matrix is not square: " + n + " x " + W.ncols());
        }

        // D[i] <- 1 / sqrt(degree of vertex i).
        double[] D = W.colSums();
        for (int i = 0; i < n; i++) {
            if (D[i] == 0.0) {
                throw new IllegalArgumentException("Isolated vertex: " + i);
            }
            D[i] = 1.0 / Math.sqrt(D[i]);
        }

        // W <- D^{-1/2} W D^{-1/2}. The diagonal is included so that self-loops, which were
        // counted in the degrees above, are normalized consistently. The lower triangle is
        // mirrored into the upper one to keep W symmetric.
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                double w = D[i] * W.get(i, j) * D[j];
                W.set(i, j, w);
                W.set(j, i, w);
            }
        }
        W.uplo(UPLO.LOWER);

        // Rows of the n-by-k eigenvector matrix are the spectral embedding of the vertices.
        double[][] Y = ARPACK.syev(W, ARPACK.SymmOption.LA, k).Vr.toArray();
        for (double[] row : Y) {
            MathEx.unitize2(row);
        }

        KMeans kmeans = KMeans.fit(Y, k, maxIter, tol);
        return new SpectralClustering(kmeans.distortion, k, kmeans.y);
    }

    /**
     * Spectral clustering of raw data with a Gaussian kernel.
     *
     * <p>Uses at most 100 k-means iterations and a k-means tolerance of 1E-4.
     *
     * @param data the input samples.
     * @param k the number of clusters.
     * @param sigma the standard deviation (width) of the Gaussian kernel.
     * @return the model.
     */
    public static SpectralClustering fit(double[][] data, int k, double sigma) {
        return fit(data, k, sigma, 100, 1E-4);
    }

    /**
     * Spectral clustering of raw data with a Gaussian kernel. The full n-by-n affinity
     * matrix w(i, j) = exp(-|x_i - x_j|^2 / (2 sigma^2)), i != j, is built and clustered with
     * {@link #fit(Matrix, int, int, double)}.
     *
     * @param data the input samples.
     * @param k the number of clusters.
     * @param sigma the standard deviation (width) of the Gaussian kernel.
     * @param maxIter the maximum number of k-means iterations (forwarded to KMeans.fit).
     * @param tol the k-means convergence tolerance (forwarded to KMeans.fit).
     * @return the model.
     * @throws IllegalArgumentException if {@code k < 2} or {@code sigma <= 0}.
     */
    public static SpectralClustering fit(double[][] data, int k, double sigma, int maxIter, double tol) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (sigma <= 0.0) {
            throw new IllegalArgumentException("Invalid standard deviation of Gaussian kernel: " + sigma);
        }

        int n = data.length;
        double gamma = -0.5 / (sigma * sigma);
        // The diagonal is never written; it keeps whatever new Matrix initializes it to
        // (presumably zero, i.e. no self-loops) -- NOTE(review): confirm against Matrix.
        Matrix W = new Matrix(n, n);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < i; j++) {
                double w = Math.exp(gamma * MathEx.squaredDistance(data[i], data[j]));
                W.set(i, j, w);
                W.set(j, i, w);
            }
        }
        return fit(W, k, maxIter, tol);
    }

    /**
     * Spectral clustering of raw data with the Nystrom approximation.
     *
     * <p>Uses at most 100 k-means iterations and a k-means tolerance of 1E-4.
     *
     * @param data the input samples.
     * @param k the number of clusters.
     * @param l the number of randomly sampled landmark points.
     * @param sigma the standard deviation (width) of the Gaussian kernel.
     * @return the model.
     */
    public static SpectralClustering fit(double[][] data, int k, int l, double sigma) {
        return fit(data, k, l, sigma, 100, 1E-4);
    }

    /**
     * Spectral clustering of raw data with the Nystrom approximation.
     *
     * <p>The samples are randomly permuted and the first l of them serve as landmarks. Only the
     * n-by-l affinity slice C is materialized. The eigenvectors of its leading l-by-l block are
     * extended to all samples via C U diag(sqrt(l/n) / lambda).
     *
     * @param data the input samples.
     * @param k the number of clusters.
     * @param l the number of randomly sampled landmark points; must satisfy k <= l < n.
     * @param sigma the standard deviation (width) of the Gaussian kernel.
     * @param maxIter the maximum number of k-means iterations (forwarded to KMeans.fit).
     * @param tol the k-means convergence tolerance (forwarded to KMeans.fit).
     * @return the model; labels are reported in the original sample order.
     * @throws IllegalArgumentException if l, k, or sigma is out of range.
     * @throws IllegalStateException if one of the k leading eigenvalues is at most 1E-8.
     */
    public static SpectralClustering fit(double[][] data, int k, int l, double sigma, int maxIter, double tol) {
        if (l < k || l >= data.length) {
            throw new IllegalArgumentException("Invalid number of random samples: " + l);
        }
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (sigma <= 0.0) {
            throw new IllegalArgumentException("Invalid standard deviation of Gaussian kernel: " + sigma);
        }

        int n = data.length;
        double gamma = -0.5 / (sigma * sigma);

        // Shuffle so that the first l rows are a random sample of landmarks.
        int[] index = MathEx.permutate(n);
        double[][] x = new double[n][];
        for (int i = 0; i < n; i++) {
            x[i] = data[index[i]];
        }

        Matrix C = new Matrix(n, l);
        double[] D = new double[n];
        // Each task writes only row i of C and the element D[i], so tasks never share a cell.
        IntStream.range(0, n).parallel().forEach(i -> D[i] = affinityRow(x, i, l, gamma, C));

        for (int i = 0; i < n; i++) {
            if (D[i] < 1E-4) {
                logger.error(String.format("Small D[%d] = %f. The data may contain outliers.", i, D[i]));
            }
            D[i] = 1.0 / Math.sqrt(D[i]);
        }

        // C <- D^{-1/2} C D_l^{-1/2}, where D_l is the leading l-by-l block of D.
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < l; j++) {
                C.set(i, j, D[i] * C.get(i, j) * D[j]);
            }
        }

        Matrix W = C.submatrix(0, 0, l - 1, l - 1);
        W.uplo(UPLO.LOWER);
        Matrix.EVD eigen = ARPACK.syev(W, ARPACK.SymmOption.LA, k);
        double[] e = eigen.wr;
        // Floating-point division: l < n, so integer division would always yield 0 and
        // collapse the whole embedding to zero vectors.
        double scale = Math.sqrt((double) l / n);
        for (int j = 0; j < k; j++) {
            if (e[j] <= 1E-8) {
                throw new IllegalStateException("Non-positive eigen value: " + e[j]);
            }
            e[j] = scale / e[j];
        }

        // U <- U diag(sqrt(l/n) / lambda).
        Matrix U = eigen.Vr;
        for (int i = 0; i < l; i++) {
            for (int j = 0; j < k; j++) {
                U.mul(i, j, e[j]);
            }
        }

        double[][] Y = C.mm(U).toArray();
        for (double[] row : Y) {
            MathEx.unitize2(row);
        }

        KMeans kmeans = KMeans.fit(Y, k, maxIter, tol);
        // Undo the permutation so labels line up with the caller's data order.
        int[] y = new int[n];
        for (int i = 0; i < n; i++) {
            y[index[i]] = kmeans.y[i];
        }
        return new SpectralClustering(kmeans.distortion, k, y);
    }

    /**
     * Computes the Gaussian affinity between sample i and every other sample, and writes the
     * affinities to the first l (landmark) samples into row i of C.
     *
     * @return the sum of the affinities of sample i to all other samples (its degree).
     */
    private static double affinityRow(double[][] x, int i, int l, double gamma, Matrix C) {
        double sum = 0.0;
        for (int j = 0; j < x.length; j++) {
            if (j == i) {
                continue;
            }
            double w = Math.exp(gamma * MathEx.squaredDistance(x[i], x[j]));
            sum += w;
            if (j < l) {
                C.set(i, j, w);
            }
        }
        return sum;
    }
}
