package smile.manifold;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.function.IntConsumer;
import java.util.function.IntToDoubleFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.stat.distribution.GaussianDistribution;

/* loaded from: classes5.dex */
/**
 * t-distributed stochastic neighbor embedding (t-SNE).
 *
 * <p>Input-space affinities {@code P} are Gaussian kernels whose per-point width (beta) is
 * searched by bisection so that each row reaches the requested perplexity. Output-space
 * affinities {@code Q} use the Student-t kernel {@code 1 / (1 + |y_i - y_j|^2)}. The embedding
 * {@link #coordinates} is fitted by gradient descent with momentum and per-coordinate adaptive
 * gains. {@code P} is exaggerated by a constant factor until the momentum switch.
 *
 * <p>Private field names are part of the serialized form and are therefore kept unchanged. The
 * public {@code lambda$...} / {@code m100...} methods are retained from the compiled form for
 * binary compatibility.
 */
public class TSNE implements Serializable {
    private static final Logger logger = LoggerFactory.getLogger(TSNE.class);
    private static final long serialVersionUID = 2;

    /** Upper bound on bisection steps when calibrating one point's kernel width. */
    private static final int MAX_BETA_SEARCH_ITER = 50;
    /** Early-exaggeration factor applied to P until {@code momentumSwitchIter}. */
    private static final double EXAGGERATION = 12.0d;
    /** Lower bound for affinities; keeps them positive and the log in the cost finite. */
    private static final double MIN_PROBABILITY = 1.0E-16d;

    /** Input-space distance matrix: either supplied directly or computed via {@code MathEx.pdist}. */
    private double[][] D;
    /** Symmetrized input-space affinities. */
    private double[][] P;
    /** Unnormalized output-space affinities of the latest iteration (diagonal kept at 0). */
    private double[][] Q;
    /** Normalizer of {@code Q}: the sum of its off-diagonal entries. */
    private double Qsum;
    /** The embedding: one row per input point, one column per output dimension. */
    public final double[][] coordinates;
    /** Learning rate. */
    private double eta;
    private double finalMomentum;
    /** Per-coordinate adaptive gains multiplying the learning rate. */
    private double[][] gains;
    private double minGain;
    private double momentum;
    private int momentumSwitchIter;
    /** Total gradient iterations performed, across all calls to {@link #update(int)}. */
    private int totalIter;

    /**
     * Fits t-SNE with perplexity 20, learning rate 200 and 1000 iterations.
     *
     * @param dArr input data, or a square pairwise distance matrix
     * @param i    dimension of the embedding space
     */
    public TSNE(double[][] dArr, int i) {
        this(dArr, i, 20.0d, 200.0d, 1000);
    }

    /**
     * Fits t-SNE.
     *
     * @param data       input data with one row per point. A square matrix is taken to be a
     *                   precomputed pairwise distance matrix.
     * @param dimensions dimension of the embedding space
     * @param perplexity target perplexity of each point's Gaussian kernel
     * @param eta        learning rate
     * @param iterations number of gradient iterations to run
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public TSNE(double[][] data, int dimensions, double perplexity, double eta, int iterations) {
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty input data");
        }
        this.momentum = 0.5d;
        this.finalMomentum = 0.8d;
        this.momentumSwitchIter = 250;
        this.minGain = 0.01d;
        this.totalIter = 1;
        this.eta = eta;

        int n = data.length;
        if (n == data[0].length) {
            // A square input is used as the distance matrix as-is.
            this.D = data;
        } else {
            this.D = new double[n][n];
            MathEx.pdist(data, this.D, new TSNE$$ExternalSyntheticLambda1());
        }

        double[][] y = new double[n][dimensions];
        this.coordinates = y;
        this.gains = new double[n][dimensions];
        // Start every gain at 1 and the embedding at small random Gaussian values.
        GaussianDistribution gaussian = new GaussianDistribution(0.0d, 1.0E-4d);
        for (int row = 0; row < n; row++) {
            Arrays.fill(this.gains[row], 1.0d);
            double[] yRow = y[row];
            for (int k = 0; k < dimensions; k++) {
                yRow[k] = gaussian.rand();
            }
        }

        // A loose tolerance suffices: small errors in the kernel width barely matter.
        this.P = expd(this.D, perplexity, 0.001d);
        this.Q = new double[n][n];

        // Symmetrize P. Each row of P sums to 1, so P + P^T sums to 2n. The exaggeration factor
        // is applied here and removed by update() at the momentum switch.
        double total = 2.0d * n;
        for (int r = 0; r < n; r++) {
            double[] pRow = this.P[r];
            for (int c = 0; c < r; c++) {
                double p = (pRow[c] + this.P[c][r]) * EXAGGERATION / total;
                if (Double.isNaN(p) || p < MIN_PROBABILITY) {
                    p = MIN_PROBABILITY;
                }
                pRow[c] = p;
                this.P[c][r] = p;
            }
        }
        // NOTE(review): update() is overridable and runs from the constructor.
        update(iterations);
    }

    /** Fills {@code q} with Student-t affinities of {@code y}; returns their off-diagonal sum. */
    private double computeQ(final double[][] y, final double[][] q) {
        final int n = y.length;
        return IntStream.range(0, n).parallel().mapToDouble(new IntToDoubleFunction() {
            @Override
            public double applyAsDouble(int row) {
                return TSNE.lambda$computeQ$5(y, q, n, row);
            }
        }).sum();
    }

    /** Computes row-normalized Gaussian affinities for distance matrix {@code d} (diagonal 0). */
    private double[][] expd(final double[][] d, final double perplexity, final double tol) {
        final int n = d.length;
        final double[][] p = new double[n][n];
        final double[] rowSums = MathEx.rowSums(d);
        IntStream.range(0, n).parallel().forEach(new IntConsumer() {
            @Override
            public void accept(int row) {
                TSNE.lambda$expd$4(perplexity, p, d, n, rowSums, tol, row);
            }
        });
        return p;
    }

    /**
     * Computes row {@code row} of the Student-t affinities {@code 1 / (1 + |y_row - y_j|^2)}.
     * Retained public helper.
     *
     * @return the sum of the row's off-diagonal affinities
     */
    public static double lambda$computeQ$5(double[][] y, double[][] q, int n, int row) {
        double[] yRow = y[row];
        double[] qRow = q[row];
        double sum = 0.0d;
        for (int j = 0; j < n; j++) {
            if (j == row) {
                // The t-SNE normalizer runs over pairs k != l: a point is not its own neighbor.
                qRow[j] = 0.0d;
                continue;
            }
            double num = 1.0d / (MathEx.squaredDistance(yRow, y[j]) + 1.0d);
            qRow[j] = num;
            sum += num;
        }
        return sum;
    }

    /**
     * Searches the kernel precision (beta) of point {@code row} so that the entropy of its Gaussian
     * affinities matches {@code log2(perplexity)} within {@code tol}. It then writes the
     * normalized row into {@code p[row]}. Retained public helper.
     */
    public static void lambda$expd$4(double perplexity, double[][] p, double[][] d, int n,
                                     double[] rowSums, double tol, int row) {
        double logU = MathEx.log2(perplexity);
        double[] pRow = p[row];
        double[] dRow = d[row];

        // sqrt(1 / average distance) as the starting precision.
        double beta = Math.sqrt((n - 1) / rowSums[row]);
        logger.debug("initial beta[{}] = {}", row, beta);
        double betaMin = 0.0d;
        double betaMax = Double.POSITIVE_INFINITY;

        double hDiff = Double.MAX_VALUE;
        double pSum = 1.0d;
        for (int iter = 0; Math.abs(hDiff) > tol && iter < MAX_BETA_SEARCH_ITER; iter++) {
            double sum = 0.0d;
            double weighted = 0.0d;
            for (int j = 0; j < n; j++) {
                double scaled = beta * dRow[j];
                double v = Math.exp(-scaled);
                pRow[j] = v;
                sum += v;
                weighted += v * scaled;
            }
            // Drop the self term exp(-beta * d[row][row]) = 1 (assumes a zero diagonal).
            pRow[row] = 0.0d;
            pSum = sum - 1.0d;

            double entropy = MathEx.log2(pSum) + weighted / pSum;
            hDiff = entropy - logU;

            boolean converged = Math.abs(hDiff) <= tol;
            if (!converged) {
                if (hDiff <= 0.0d) {
                    // Entropy too low: kernel too narrow, so decrease beta.
                    betaMax = beta;
                    beta = (beta + betaMin) / 2.0d;
                } else {
                    // Entropy too high: kernel too wide, so increase beta.
                    betaMin = beta;
                    beta = Double.isInfinite(betaMax) ? 2.0d * beta : (beta + betaMax) / 2.0d;
                }
            }
            logger.debug("Hdiff = {}, beta[{}] = {}, H = {}, logU = {}", hDiff, row, beta, entropy, logU);
        }

        // Normalize the row, also when the search stopped at the iteration cap. Previously such
        // rows were left unnormalized.
        if (pSum > 0.0d) {
            for (int j = 0; j < n; j++) {
                pRow[j] /= pSum;
            }
        }
    }

    /** Subtracts {@code colMeans} from row {@code row} of {@code y}. Retained public helper. */
    public static void lambda$update$3(double[][] y, int dim, double[] colMeans, int row) {
        double[] yRow = y[row];
        for (int k = 0; k < dim; k++) {
            yRow[k] -= colMeans[k];
        }
    }

    /**
     * Computes the cost gradient for point {@code i} into {@code dC} and updates the point's gains.
     *
     * @param i  point index
     * @param dY previous update step (velocity) of point {@code i}
     * @param dC output buffer receiving the gradient of point {@code i}
     */
    private void sne(int i, double[] dY, double[] dC) {
        double[][] y = this.coordinates;
        int n = y.length;
        int dim = y[0].length;
        double[] yi = y[i];
        double[] pRow = this.P[i];
        double[] qRow = this.Q[i];
        double[] g = this.gains[i];

        Arrays.fill(dC, 0.0d);
        for (int j = 0; j < n; j++) {
            if (j == i) {
                continue;
            }
            double[] yj = y[j];
            double num = qRow[j];
            double z = (pRow[j] - num / this.Qsum) * num;
            for (int k = 0; k < dim; k++) {
                dC[k] += (yi[k] - yj[k]) * 4.0d * z;
            }
        }

        // Grow the gain where gradient and velocity disagree in sign; otherwise shrink it.
        for (int k = 0; k < dim; k++) {
            g[k] = Math.signum(dC[k]) != Math.signum(dY[k]) ? g[k] + 0.2d : g[k] * 0.8d;
            if (g[k] < this.minGain) {
                g[k] = this.minGain;
            }
        }
    }

    /** Computes the gradient of point {@code i}. Retained public helper. */
    public void m10012lambda$update$0$smilemanifoldTSNE(double[][] dY, double[][] dC, int i) {
        sne(i, dY[i], dC[i]);
    }

    /** Applies one momentum/gain-scaled gradient step to point {@code row}. Retained public helper. */
    public void m10013lambda$update$1$smilemanifoldTSNE(double[][] y, double[][] dY, double[][] dC, int dim, int row) {
        double[] yRow = y[row];
        double[] vRow = dY[row];
        double[] cRow = dC[row];
        double[] g = this.gains[row];
        for (int k = 0; k < dim; k++) {
            double step = this.momentum * vRow[k] - this.eta * g[k] * cRow[k];
            vRow[k] = step;
            yRow[k] += step;
        }
    }

    /**
     * Computes the KL-divergence contribution of the lower-triangle entries of row {@code i}.
     * Retained public helper.
     */
    public double m10014lambda$update$2$smilemanifoldTSNE(int i) {
        double[] pRow = this.P[i];
        double[] qRow = this.Q[i];
        double kl = 0.0d;
        for (int j = 0; j < i; j++) {
            double p = pRow[j];
            double q = qRow[j] / this.Qsum;
            if (Double.isNaN(q) || q < MIN_PROBABILITY) {
                q = MIN_PROBABILITY;
            }
            kl += p * MathEx.log2(p / q);
        }
        return kl;
    }

    /**
     * Runs additional gradient iterations on {@link #coordinates}, then re-centers it at zero mean.
     *
     * @param iterations number of iterations to run
     */
    public void update(int iterations) {
        final double[][] y = this.coordinates;
        final int n = y.length;
        final int dim = y[0].length;
        // Per-point velocity (previous step) and gradient buffers.
        final double[][] velocity = new double[n][dim];
        final double[][] gradient = new double[n][dim];

        // Anonymous classes rather than lambdas: javac-synthesized lambda$... methods could
        // clash with the retained public lambda$... helpers of this class.
        for (int iter = 1; iter <= iterations; iter++, this.totalIter++) {
            this.Qsum = computeQ(y, this.Q);

            IntStream.range(0, n).parallel().forEach(new IntConsumer() {
                @Override
                public void accept(int row) {
                    TSNE.this.m10012lambda$update$0$smilemanifoldTSNE(velocity, gradient, row);
                }
            });
            IntStream.range(0, n).parallel().forEach(new IntConsumer() {
                @Override
                public void accept(int row) {
                    TSNE.this.m10013lambda$update$1$smilemanifoldTSNE(y, velocity, gradient, dim, row);
                }
            });

            if (this.totalIter == this.momentumSwitchIter) {
                // End of the early phase: raise momentum and remove the exaggeration from P.
                this.momentum = this.finalMomentum;
                for (double[] pRow : this.P) {
                    for (int j = 0; j < n; j++) {
                        pRow[j] /= EXAGGERATION;
                    }
                }
            }

            if (iter % 100 == 0) {
                double halfCost = IntStream.range(0, n).parallel().mapToDouble(new IntToDoubleFunction() {
                    @Override
                    public double applyAsDouble(int row) {
                        return TSNE.this.m10014lambda$update$2$smilemanifoldTSNE(row);
                    }
                }).sum();
                // Only the lower triangle was summed; P and Q are symmetric, hence the factor 2.
                logger.info("Error after {} iterations: {}", this.totalIter, halfCost * 2.0d);
            }
        }

        // Make the solution zero-mean.
        final double[] colMeans = MathEx.colMeans(y);
        IntStream.range(0, n).parallel().forEach(new IntConsumer() {
            @Override
            public void accept(int row) {
                TSNE.lambda$update$3(y, dim, colMeans, row);
            }
        });
    }
}
