package smile.manifold;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Collection;
import java.util.Iterator;
import java.util.function.Consumer;
import java.util.function.DoublePredicate;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.ToDoubleFunction;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.graph.AdjacencyList;
import smile.graph.Graph;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.LevenbergMarquardt;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.math.matrix.ARPACK;
import smile.math.matrix.Matrix;
import smile.math.matrix.SparseMatrix;
import smile.stat.distribution.GaussianDistribution;

/**
 * Uniform Manifold Approximation and Projection (UMAP).
 * <p>
 * {@link #of(Object[], Distance, int, int, int, double, double, double, int, double)} builds a
 * k-nearest-neighbor graph and keeps the part returned by {@code NearestNeighborGraph.largest}.
 * It converts the graph's edge distances into a symmetric fuzzy membership graph and
 * initializes a low-dimensional layout from eigenvectors of the symmetrically normalized
 * graph Laplacian. Finally it refines the layout with stochastic gradient descent
 * using negative sampling.
 */
public class UMAP implements Serializable {
    private static final long serialVersionUID = 2;
    private static final Logger logger = LoggerFactory.getLogger(UMAP.class);

    /** Edge weights within this tolerance of zero are treated as absent. */
    private static final double ZERO_TOLERANCE = 1E-8;
    /** Convergence tolerance of the per-vertex binary search for sigma. */
    private static final double SIGMA_TOLERANCE = 1E-5;
    /** Maximum number of binary-search iterations for each vertex's sigma. */
    private static final int SIGMA_ITERATIONS = 64;
    /** Lower bound on sigma, as a fraction of a mean non-zero neighbor distance. */
    private static final double MIN_SIGMA_SCALE = 1E-3;
    /** Per-coordinate gradient values are clipped to [-GRADIENT_CLIP, GRADIENT_CLIP]. */
    private static final double GRADIENT_CLIP = 4.0;
    /** Offset added to the squared distance in the repulsive gradient to keep it finite. */
    private static final double REPULSION_EPSILON = 1E-3;
    /** Number of sample points used when fitting the curve parameters (a, b). */
    private static final int CURVE_FIT_SIZE = 300;

    /**
     * Curve family fitted by {@link #fitCurve}: {@code 1 / (1 + a * x^(2b))}.
     * <p>
     * The argument vector is indexed as {@code [a, b, x]}. {@code g} writes the gradient
     * with respect to the two parameters {@code a} and {@code b}. The exponent is
     * {@code 2b} because {@link #optimizeLayout} evaluates the same kernel on
     * <em>squared</em> distances as {@code 1 / (1 + a * d2^b)}, while the curve is fitted
     * on plain distances.
     */
    private static final DifferentiableMultivariateFunction func = new DifferentiableMultivariateFunction() {
        @Override
        public double f(double[] x) {
            return 1.0 / (1.0 + x[0] * Math.pow(x[2], 2.0 * x[1]));
        }

        @Override
        public double g(double[] x, double[] gradient) {
            double pow = Math.pow(x[2], 2.0 * x[1]);
            double de = 1.0 + x[0] * pow;
            double de2 = de * de;
            // d/da = -x^(2b) / de^2
            gradient[0] = -pow / de2;
            // d/db = -2 * a * ln(x) * x^(2b) / de^2 (no extra factor of b).
            gradient[1] = -2.0 * x[0] * Math.log(x[2]) * pow / de2;
            return 1.0 / de;
        }
    };

    /** The embedding coordinates, one row per retained sample. */
    public final double[][] coordinates;
    /** The fuzzy membership graph the embedding was optimized against. */
    public final AdjacencyList graph;
    /**
     * Indices of the samples that appear in the embedding, as returned by
     * {@code NearestNeighborGraph.largest} (presumably the largest connected
     * component of the k-NN graph — confirm).
     */
    public final int[] index;

    /**
     * Constructor.
     *
     * @param index the indices of the embedded samples in the original data.
     * @param coordinates the embedding coordinates.
     * @param graph the fuzzy membership graph.
     */
    public UMAP(int[] index, double[][] coordinates, AdjacencyList graph) {
        this.index = index;
        this.coordinates = coordinates;
        this.graph = graph;
    }

    /**
     * Runs UMAP with k = 15 and the defaults of {@link #of(Object[], Distance, int)}.
     *
     * @param data the input data.
     * @param distance the distance function.
     * @param <T> the data type.
     * @return the embedding.
     */
    public static <T> UMAP of(T[] data, Distance<T> distance) {
        return of(data, distance, 15);
    }

    /**
     * Runs UMAP with defaults: d = 2, epochs = 200 if there are more than 10000 samples
     * (otherwise 500), learningRate = 1, minDist = 0.1, spread = 1,
     * negativeSamples = 5, repulsionStrength = 1.
     *
     * @param data the input data.
     * @param distance the distance function.
     * @param k the number of nearest neighbors.
     * @param <T> the data type.
     * @return the embedding.
     */
    public static <T> UMAP of(T[] data, Distance<T> distance, int k) {
        return of(data, distance, k, 2, data.length > 10000 ? 200 : 500, 1.0, 0.1, 1.0, 5, 1.0);
    }

    /**
     * Runs UMAP.
     *
     * @param data the input data.
     * @param distance the distance function.
     * @param k the number of nearest neighbors; must be at least 2.
     * @param d the embedding dimension; must be at least 2.
     * @param epochs the number of optimization epochs; must be at least 10.
     * @param learningRate the initial SGD learning rate; must be positive. It decays
     *                     linearly with the epoch number.
     * @param minDist the minimum distance parameter of the fitted curve; must be
     *                positive and not greater than {@code spread}.
     * @param spread the scale parameter of the fitted curve.
     * @param negativeSamples the number of negative samples per positive sample;
     *                        must be positive.
     * @param repulsionStrength the weight of the repulsive (negative-sample) gradient.
     * @param <T> the data type.
     * @return the embedding.
     * @throws IllegalArgumentException if any parameter is out of range.
     */
    public static <T> UMAP of(T[] data, Distance<T> distance, int k, int d, int epochs, double learningRate,
                              double minDist, double spread, int negativeSamples, double repulsionStrength) {
        if (d < 2) {
            throw new IllegalArgumentException("d must be greater than 1: " + d);
        }
        if (k < 2) {
            throw new IllegalArgumentException("k must be greater than 1: " + k);
        }
        if (minDist <= 0.0) {
            throw new IllegalArgumentException("minDist must greater than 0: " + minDist);
        }
        if (minDist > spread) {
            throw new IllegalArgumentException("minDist must be less than or equal to spread: " + minDist + ",spread=" + spread);
        }
        if (epochs < 10) {
            throw new IllegalArgumentException("epochs must be a positive integer of at least 10: " + epochs);
        }
        if (learningRate <= 0.0) {
            throw new IllegalArgumentException("learningRate must greater than 0: " + learningRate);
        }
        if (negativeSamples <= 0) {
            throw new IllegalArgumentException("negativeSamples must greater than 0: " + negativeSamples);
        }

        NearestNeighborGraph nng = NearestNeighborGraph.largest(NearestNeighborGraph.of(data, distance, k, true, null));
        AdjacencyList strength = computeFuzzySimplicialSet(nng.graph, k, SIGMA_ITERATIONS);
        // Snapshot the membership weights as a matrix before the layout is computed.
        SparseMatrix weights = strength.toMatrix();

        double[][] coordinates = spectralLayout(strength, d);
        logger.info("Finish initialization with spectral layout");

        double[] curve = fitCurve(spread, minDist);
        logger.info("Finish fitting the curve parameters");

        SparseMatrix epochsPerSample = computeEpochPerSample(weights, epochs);
        logger.info("Start optimizing the layout");
        optimizeLayout(coordinates, curve, epochsPerSample, epochs, learningRate, negativeSamples, repulsionStrength);
        return new UMAP(nng.index, coordinates, strength);
    }

    /**
     * Runs UMAP on Euclidean data with k = 15 and default parameters.
     *
     * @param data the input data.
     * @return the embedding.
     */
    public static UMAP of(double[][] data) {
        return of(data, 15);
    }

    /**
     * Runs UMAP on Euclidean data with default parameters.
     *
     * @param data the input data.
     * @param k the number of nearest neighbors.
     * @return the embedding.
     */
    public static UMAP of(double[][] data, int k) {
        return of(data, new EuclideanDistance(), k);
    }

    /**
     * Runs UMAP on Euclidean data. See
     * {@link #of(Object[], Distance, int, int, int, double, double, double, int, double)}
     * for the meaning of the parameters.
     *
     * @return the embedding.
     */
    public static UMAP of(double[][] data, int k, int d, int epochs, double learningRate,
                          double minDist, double spread, int negativeSamples, double repulsionStrength) {
        return of(data, new EuclideanDistance(), k, d, epochs, learningRate, minDist, spread, negativeSamples, repulsionStrength);
    }

    /** Clips a gradient component to [-GRADIENT_CLIP, GRADIENT_CLIP]; NaN passes through. */
    private static double clamp(double value) {
        if (value > GRADIENT_CLIP) {
            return GRADIENT_CLIP;
        }
        if (value < -GRADIENT_CLIP) {
            return -GRADIENT_CLIP;
        }
        return value;
    }

    /** Returns true if {@code w} is not within {@link #ZERO_TOLERANCE} of zero. */
    private static boolean isNonZero(double w) {
        return !MathEx.isZero(w, ZERO_TOLERANCE);
    }

    /**
     * Converts the k-NN distance graph into a symmetric fuzzy membership graph.
     * <p>
     * For each vertex i, {@code rho[i]} is its smallest non-zero edge distance, and
     * {@code sigma[i]} is found by binary search so that the sum of
     * {@code exp(-(dist - rho[i]) / sigma[i])} over non-zero edges approaches {@code log2(k)}.
     * Each directed edge then receives membership
     * {@code exp(-max(0, dist - rho[i]) / sigma[i])}. Both directions are combined by the
     * fuzzy union {@code w + w' - w * w'}.
     * <p>
     * NOTE: the edge weights of {@code nng} are overwritten in place with the directed
     * memberships. The symmetrization step reads the reverse edge through
     * {@code nng.getWeight}, so it relies on that mutation.
     *
     * @param nng the k-NN graph whose edge weights are distances; mutated.
     * @param k the number of nearest neighbors.
     * @param iterations the maximum number of binary-search iterations per vertex.
     * @return a new graph, built with {@code new AdjacencyList(n, false)}, holding the
     *         symmetrized memberships.
     */
    private static AdjacencyList computeFuzzySimplicialSet(AdjacencyList nng, int k, int iterations) {
        int n = nng.getNumVertices();
        double target = MathEx.log2(k);
        double[] rho = new double[n];
        double[] sigma = new double[n];

        // Mean of all non-zero distances, used as a sigma floor for vertices with rho == 0.
        double meanDistance = IntStream.range(0, n)
                .mapToObj(nng::getEdges)
                .flatMapToDouble(edges -> edges.stream().mapToDouble(edge -> edge.weight))
                .filter(UMAP::isNonZero)
                .average()
                .orElse(0.0);

        for (int i = 0; i < n; i++) {
            Collection<Graph.Edge> edges = nng.getEdges(i);
            rho[i] = edges.stream().mapToDouble(edge -> edge.weight).filter(UMAP::isNonZero).min().orElse(0.0);
            sigma[i] = searchSigma(edges, rho[i], target, iterations);

            if (rho[i] > 0.0) {
                double meanIthDistance = edges.stream()
                        .mapToDouble(edge -> edge.weight)
                        .filter(UMAP::isNonZero)
                        .average()
                        .orElse(0.0);
                sigma[i] = Math.max(sigma[i], MIN_SIGMA_SCALE * meanIthDistance);
            } else {
                sigma[i] = Math.max(sigma[i], MIN_SIGMA_SCALE * meanDistance);
            }
        }

        // Replace each directed distance by its membership strength (in place).
        for (int i = 0; i < n; i++) {
            for (Graph.Edge edge : nng.getEdges(i)) {
                edge.weight = Math.exp(-Math.max(0.0, edge.weight - rho[i]) / sigma[i]);
            }
        }

        // Fuzzy union of the two directions of every edge.
        AdjacencyList membership = new AdjacencyList(n, false);
        for (int i = 0; i < n; i++) {
            for (Graph.Edge edge : nng.getEdges(i)) {
                double w = edge.weight;
                double reverse = nng.getWeight(edge.v2, edge.v1);
                membership.setWeight(edge.v1, edge.v2, w + reverse - w * reverse);
            }
        }
        return membership;
    }

    /**
     * Binary-searches the sigma of one vertex.
     * <p>
     * It looks for a sigma at which the sum of {@code exp(-(dist - rho) / sigma)} over the
     * vertex's non-zero edges equals {@code target}. Distances not above {@code rho}
     * contribute 1. The search starts at 1 and doubles while no upper bound is known.
     *
     * @return the last midpoint, after convergence or after {@code iterations} steps.
     */
    private static double searchSigma(Collection<Graph.Edge> edges, double rho, double target, int iterations) {
        double lo = 0.0;
        double hi = Double.POSITIVE_INFINITY;
        double mid = 1.0;
        for (int iter = 0; iter < iterations; iter++) {
            double psum = 0.0;
            for (Graph.Edge edge : edges) {
                if (isNonZero(edge.weight)) {
                    double d = edge.weight - rho;
                    psum += d > 0.0 ? Math.exp(-d / mid) : 1.0;
                }
            }

            if (Math.abs(psum - target) < SIGMA_TOLERANCE) {
                break;
            }

            if (psum > target) {
                hi = mid;
                mid = (lo + hi) / 2.0;
            } else {
                lo = mid;
                mid = Double.isInfinite(hi) ? mid * 2.0 : (lo + hi) / 2.0;
            }
        }
        return mid;
    }

    /**
     * Converts membership weights into the number of epochs between samples of each edge.
     * <p>
     * Each non-zero weight {@code w} becomes {@code max / w}, where {@code max} is the
     * largest weight. Weights below {@code max / epochs} are set to 0, and
     * {@link #optimizeLayout} never samples zero entries.
     *
     * @param strength the membership weights; overwritten in place.
     * @param epochs the number of training epochs.
     * @return {@code strength}, after the update.
     */
    private static SparseMatrix computeEpochPerSample(SparseMatrix strength, int epochs) {
        double max = strength.nonzeros().mapToDouble(w -> w.x).max().orElse(0.0);
        double min = max / epochs;
        strength.nonzeros().forEach(w -> {
            if (w.x < min) {
                w.update(0.0);
            } else {
                w.update(max / w.x);
            }
        });
        return strength;
    }

    /**
     * Fits the parameters {@code (a, b)} of {@code 1 / (1 + a * x^(2b))}.
     * <p>
     * The target curve is 1 for {@code x < minDist} and
     * {@code exp(-(x - minDist) / spread)} otherwise. It is sampled at
     * {@link #CURVE_FIT_SIZE} points in {@code (0, 3 * spread]}. The point x = 0 is
     * excluded because the gradient uses {@code ln(x)}.
     *
     * @param spread the scale parameter.
     * @param minDist the minimum distance parameter.
     * @return {@code {a, b}}.
     */
    private static double[] fitCurve(double spread, double minDist) {
        double[] x = new double[CURVE_FIT_SIZE];
        double[] y = new double[CURVE_FIT_SIZE];
        double interval = 3.0 * spread / CURVE_FIT_SIZE;
        for (int i = 0; i < CURVE_FIT_SIZE; i++) {
            x[i] = (i + 1) * interval;
            y[i] = x[i] < minDist ? 1.0 : Math.exp(-(x[i] - minDist) / spread);
        }
        double[] initialGuess = {0.5, 0.0};
        return LevenbergMarquardt.fit(func, x, y, initialGuess).parameters;
    }

    /**
     * Optimizes the layout in place by stochastic gradient descent.
     * <p>
     * In each epoch, every edge whose next-sample epoch has arrived pulls its endpoints
     * together. It also pushes its first endpoint away from randomly chosen vertices
     * (negative samples). The learning rate decays linearly from {@code initialAlpha}
     * toward 0 across epochs.
     *
     * @param embedding the initial layout, updated in place.
     * @param curve the fitted curve parameters {@code {a, b}}.
     * @param epochsPerSample the number of epochs between samples of each edge; an entry
     *                        that is not positive is never sampled.
     * @param epochs the number of epochs.
     * @param initialAlpha the initial learning rate.
     * @param negativeSamples the number of negative samples per positive sample.
     * @param gamma the weight of the repulsive gradient.
     */
    private static void optimizeLayout(double[][] embedding, double[] curve, SparseMatrix epochsPerSample, int epochs,
                                       double initialAlpha, int negativeSamples, double gamma) {
        int n = embedding.length;
        int dim = embedding[0].length;
        double a = curve[0];
        double b = curve[1];

        SparseMatrix epochsPerNegativeSample = epochsPerSample.clone();
        epochsPerNegativeSample.nonzeros().forEach(w -> w.update(w.x / negativeSamples));
        SparseMatrix epochNextNegativeSample = epochsPerNegativeSample.clone();
        SparseMatrix epochNextSample = epochsPerSample.clone();

        double alpha = initialAlpha;
        for (int epoch = 1; epoch <= epochs; epoch++) {
            Iterator<SparseMatrix.Entry> it = epochNextSample.iterator();
            while (it.hasNext()) {
                SparseMatrix.Entry edge = it.next();
                if (edge.x > 0.0 && edge.x <= epoch) {
                    int i = edge.i;
                    int j = edge.j;
                    int idx = edge.index;
                    double[] current = embedding[i];
                    double[] other = embedding[j];

                    // Attractive step along the sampled edge.
                    double dist2 = MathEx.squaredDistance(current, other);
                    if (dist2 > 0.0) {
                        double gradCoeff = -2.0 * a * b * Math.pow(dist2, b - 1.0) / (a * Math.pow(dist2, b) + 1.0);
                        for (int m = 0; m < dim; m++) {
                            double grad = clamp((current[m] - other[m]) * gradCoeff) * alpha;
                            current[m] += grad;
                            other[m] -= grad;
                        }
                    }
                    edge.update(edge.x + epochsPerSample.get(idx));

                    // Repulsive steps against random vertices.
                    double perNegative = epochsPerNegativeSample.get(idx);
                    int negSamples = (int) ((epoch - epochNextNegativeSample.get(idx)) / perNegative);
                    for (int p = 0; p < negSamples; p++) {
                        int k = MathEx.randomInt(n);
                        if (k == i) {
                            continue;
                        }
                        double[] sample = embedding[k];
                        double sampleDist2 = MathEx.squaredDistance(current, sample);
                        double coeff = sampleDist2 > 0.0
                                ? 2.0 * gamma * b / ((sampleDist2 + REPULSION_EPSILON) * (a * Math.pow(sampleDist2, b) + 1.0))
                                : 0.0;
                        for (int m = 0; m < dim; m++) {
                            // Coincident points are pushed by the maximum clipped step.
                            double grad = coeff > 0.0 ? clamp((current[m] - sample[m]) * coeff) : GRADIENT_CLIP;
                            current[m] += grad * alpha;
                        }
                    }
                    epochNextNegativeSample.set(idx, epochNextNegativeSample.get(idx) + perNegative * negSamples);
                }
            }

            logger.info(String.format("The learning rate at %3d iterations: %.5f", epoch, alpha));
            // Floating-point division: with int/int the rate never decayed before the last epoch.
            alpha = initialAlpha * (1.0 - (double) epoch / epochs);
        }
    }

    /**
     * Computes the initial layout from eigenvectors of the normalized graph Laplacian.
     * <p>
     * The Laplacian {@code I - D^-1/2 W D^-1/2} is built from the graph weights. The
     * eigenvectors are requested with {@code ARPACK.SymmOption.SM}, and column
     * {@code ncols - j - 2} is used for output dimension j (presumably skipping the
     * trivial eigenvector — confirm ARPACK's column order). The result is:
     * <ol>
     *   <li>scaled so the largest absolute value is 10;</li>
     *   <li>perturbed with Gaussian noise (mean 0, sigma 1E-4);</li>
     *   <li>rescaled column-wise to [0, 10].</li>
     * </ol>
     *
     * @param nng the membership graph.
     * @param d the embedding dimension.
     * @return an {@code n x d} array of coordinates.
     */
    private static double[][] spectralLayout(AdjacencyList nng, int d) {
        int n = nng.getNumVertices();

        // D^-1/2 from the weighted vertex degrees.
        double[] invSqrtDegree = new double[n];
        for (int i = 0; i < n; i++) {
            double degree = 0.0;
            for (Graph.Edge edge : nng.getEdges(i)) {
                degree += edge.weight;
            }
            invSqrtDegree[i] = 1.0 / Math.sqrt(degree);
        }

        AdjacencyList laplacian = new AdjacencyList(n, false);
        for (int i = 0; i < n; i++) {
            laplacian.setWeight(i, i, 1.0);
            for (Graph.Edge edge : nng.getEdges(i)) {
                laplacian.setWeight(edge.v1, edge.v2, -invSqrtDegree[edge.v1] * edge.weight * invSqrtDegree[edge.v2]);
            }
        }

        Matrix eigenvectors = ARPACK.syev(laplacian.toMatrix(), ARPACK.SymmOption.SM, Math.min(10 * (d + 1), n - 1)).Vr;
        double[][] coordinates = new double[n][d];
        double maxAbs = 0.0;
        for (int j = d - 1; j >= 0; j--) {
            int column = eigenvectors.ncols() - j - 2;
            for (int i = 0; i < n; i++) {
                double value = eigenvectors.get(i, column);
                coordinates[i][j] = value;
                double abs = Math.abs(value);
                if (abs > maxAbs) {
                    maxAbs = abs;
                }
            }
        }

        double expansion = 10.0 / maxAbs;
        GaussianDistribution noise = new GaussianDistribution(0.0, 1.0E-4);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < d; j++) {
                coordinates[i][j] = coordinates[i][j] * expansion + noise.rand();
            }
        }

        double[] colMax = MathEx.colMax(coordinates);
        double[] colMin = MathEx.colMin(coordinates);
        double[] range = new double[d];
        for (int j = 0; j < d; j++) {
            range[j] = colMax[j] - colMin[j];
        }
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < d; j++) {
                coordinates[i][j] = (coordinates[i][j] - colMin[j]) * 10.0 / range[j];
            }
        }
        return coordinates;
    }
}
