package smile.vq;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;
import java.util.function.Function;
import java.util.function.IntConsumer;
import java.util.function.IntFunction;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.IntStream;
import smile.clustering.CentroidClustering;
import smile.graph.AdjacencyMatrix;
import smile.graph.Graph;
import smile.math.MathEx;
import smile.math.TimeFunction;
import smile.sort.QuickSort;
import smile.vq.NeuralGas;

/**
 * Neural Gas vector quantizer.
 *
 * <p>Each call to {@link #update(double[])} ranks all neurons by their distance to the input
 * sample ({@code MathEx.distance}) and moves the neuron of rank {@code r} toward the sample by a
 * step of {@code alpha(t) * exp(-r / theta(t))}. In addition, the two closest neurons are linked
 * in an adjacency-matrix graph whose edge weight records the time step of the latest link;
 * {@link #network()} drops edges whose age exceeds {@code lifetime(t)}.
 *
 * <p>This class is not thread-safe: {@code update}, {@code quantize} and {@code neurons} all
 * write to or reorder shared internal arrays.
 */
public class NeuralGas implements VectorQuantizer {
    private static final long serialVersionUID = 2;

    /** Learning-rate schedule, evaluated at the current time step. */
    private final TimeFunction alpha;
    /** Scratch buffer holding the distance of each neuron to the latest input. */
    private final double[] dist;
    /** Neuron connectivity; edge weights store the time step of the most recent link. */
    private final AdjacencyMatrix graph;
    /** Maximum edge age, evaluated at the current time step, used when pruning the network. */
    private final TimeFunction lifetime;
    /** The neurons; reordered by distance on every update and by index by {@link #neurons()}. */
    private final Neuron[] neurons;
    /** Neighborhood-width schedule controlling how fast the step decays with rank. */
    private final TimeFunction theta;
    /** Number of updates performed so far. */
    private int t = 0;
    /** Updates whose step size is not above this threshold are skipped as negligible. */
    private final double eps = 1.0E-7d;

    /**
     * A neuron: its original position {@code i} in the initial codebook and its weight vector.
     *
     * <p>Modifiers and members are kept exactly as-is so the default serial form of this nested
     * class does not change.
     */
    public static class Neuron implements Serializable {
        public final int i;
        public final double[] w;

        public Neuron(int i, double[] dArr) {
            this.i = i;
            this.w = dArr;
        }
    }

    /**
     * Creates a Neural Gas quantizer.
     *
     * @param neurons the initial neuron weight vectors; each row is cloned, so the caller's
     *     arrays are not modified by training.
     * @param alpha the learning-rate schedule.
     * @param theta the neighborhood-width schedule; the step for rank {@code r} is scaled by
     *     {@code exp(-r / theta(t))}.
     * @param lifetime the maximum edge age used by {@link #network()}.
     */
    public NeuralGas(double[][] neurons, TimeFunction alpha, TimeFunction theta, TimeFunction lifetime) {
        this.neurons = IntStream.range(0, neurons.length)
                .mapToObj(i -> new Neuron(i, neurons[i].clone()))
                .toArray(Neuron[]::new);
        this.alpha = alpha;
        this.theta = theta;
        this.lifetime = lifetime;
        this.graph = new AdjacencyMatrix(neurons.length);
        this.dist = new double[neurons.length];
    }

    /**
     * Selects {@code k} initial neurons from the samples using {@code CentroidClustering.seed}
     * with squared distance.
     *
     * @param k the number of neurons to select.
     * @param samples the training samples.
     * @return an array of {@code k} vectors filled in by {@code CentroidClustering.seed}.
     */
    public static double[][] seed(int k, double[][] samples) {
        // Rows are left null here; CentroidClustering.seed populates them.
        double[][] medoids = new double[k][];
        CentroidClustering.seed(samples, medoids, new int[samples.length], MathEx::squaredDistance);
        return medoids;
    }

    /**
     * Removes edges whose age {@code t - weight} exceeds {@code lifetime(t)} and returns the
     * graph.
     *
     * @return the internal neuron graph (not a copy); vertex ids are the neurons' original
     *     indices.
     */
    public Graph network() {
        double maxAge = lifetime.apply(t);
        for (int v = 0; v < neurons.length; v++) {
            for (Graph.Edge edge : graph.getEdges(v)) {
                if (t - edge.weight > maxAge) {
                    graph.setWeight(edge.v1, edge.v2, 0.0d);
                }
            }
        }
        return graph;
    }

    /**
     * Returns the neuron weight vectors in their original (construction) order.
     *
     * <p>Side effect: the internal neuron array is re-sorted by original index. The returned rows
     * are the live weight vectors, not copies, and will change on further updates.
     *
     * @return the neuron weight vectors, row {@code i} being the neuron created from row
     *     {@code i} of the constructor argument.
     */
    public double[][] neurons() {
        Arrays.sort(neurons, Comparator.comparingInt(neuron -> neuron.i));
        return Arrays.stream(neurons).map(neuron -> neuron.w).toArray(double[][]::new);
    }

    /**
     * Returns the weight vector of the neuron closest to {@code x}.
     *
     * @param x the query vector.
     * @return the live weight vector of the nearest neuron (not a copy).
     */
    @Override
    public double[] quantize(double[] x) {
        computeDistances(x);
        return neurons[MathEx.whichMin(dist)].w;
    }

    /**
     * Performs one Neural Gas training step with sample {@code x} and advances the time step.
     *
     * @param x the training sample.
     */
    @Override
    public void update(double[] x) {
        int k = neurons.length;
        int d = x.length;

        computeDistances(x);
        // Sort by distance, rearranging neurons alongside dist, so array index == rank.
        QuickSort.sort(dist, neurons);

        double rate = alpha.apply(t);
        double width = theta.apply(t);
        for (int rank = 0; rank < k; rank++) {
            double step = rate * Math.exp(-rank / width);
            // The step decays with rank; skip neurons whose move would be negligible.
            if (step > eps) {
                double[] w = neurons[rank].w;
                for (int j = 0; j < d; j++) {
                    w[j] += step * (x[j] - w[j]);
                }
            }
        }

        // Link the two closest neurons, stamping the edge with the current time step.
        // Guarded: with fewer than two neurons there is no second winner to link.
        // NOTE(review): at t == 0 the stored weight is 0, which the adjacency matrix may treat
        // as "no edge" — confirm whether the very first link is intentionally dropped.
        if (k > 1) {
            graph.setWeight(neurons[0].i, neurons[1].i, t);
        }
        t++;
    }

    /** Fills {@code dist[i]} with the distance from neuron {@code i} (current order) to {@code x}. */
    private void computeDistances(double[] x) {
        IntStream.range(0, neurons.length)
                .parallel()
                .forEach(i -> dist[i] = MathEx.distance(neurons[i].w, x));
    }
}
