package smile.clustering;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.function.IntConsumer;
import java.util.function.ToDoubleFunction;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.sort.QuickSort;

/* loaded from: classes5.dex */
/**
 * X-Means clustering (Pelleg and Moore, 2000).
 *
 * <p>The model starts with a single cluster holding all observations. In each round, every
 * cluster with enough observations is tentatively split in two with 2-means. A split is
 * kept when it improves the Bayesian Information Criterion (BIC) of that cluster's
 * observations. The number of clusters never exceeds {@code kmax}. After the splits are
 * chosen, all observations are reassigned with k-means iterations on a {@link BBDTree}.
 * Rounds repeat until no cluster is split or {@code kmax} clusters are reached.
 */
public class XMeans extends CentroidClustering<double[], double[]> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(XMeans.class);

    /** log(2&pi;), used by the Gaussian log-likelihood terms. */
    private static final double LOG2PI = Math.log(2 * Math.PI);

    /**
     * Clusters with fewer observations than this are never split; their likelihood
     * estimates would be unreliable.
     */
    private static final int MIN_SPLIT_SIZE = 25;

    /**
     * Constructor.
     *
     * @param distortion the total distortion (within-cluster sum of squared distances).
     * @param centroids the cluster centroids.
     * @param y the cluster label of each observation.
     */
    public XMeans(double distortion, double[][] centroids, int[] y) {
        super(distortion, centroids, y);
    }

    /**
     * Computes the BIC of modelling {@code n} observations as a single spherical Gaussian.
     *
     * <p>The variance is estimated as {@code distortion / (n - 1)}, and the model has
     * {@code d + 1} free parameters.
     *
     * @param n the number of observations.
     * @param d the dimensionality of the data.
     * @param distortion the sum of squared distances of the observations to their mean.
     * @return the BIC score (larger is better).
     */
    private static double bic(int n, int d, double distortion) {
        double variance = distortion / (n - 1);

        double p1 = -n * LOG2PI;
        // Multiply in double: n * d as int can overflow on large data sets.
        double p2 = -(double) n * d * Math.log(variance);
        double p3 = -(n - 1);
        double loglik = (p1 + p2 + p3) / 2.0;

        return loglik - (d + 1) * 0.5 * Math.log(n);
    }

    /**
     * Computes the BIC of a k-cluster spherical Gaussian model with a shared variance.
     *
     * <p>The shared variance is estimated as {@code distortion / (n - k)}, and the model has
     * {@code k + k * d} free parameters.
     *
     * @param k the number of clusters.
     * @param n the total number of observations.
     * @param d the dimensionality of the data.
     * @param distortion the total within-cluster sum of squared distances.
     * @param clusterSize the number of observations in each cluster (at least {@code k} entries).
     * @return the BIC score (larger is better).
     */
    private static double bic(int k, int n, int d, double distortion, int[] clusterSize) {
        double variance = distortion / (n - k);

        double loglik = 0.0;
        for (int i = 0; i < k; i++) {
            loglik += logLikelihood(k, n, clusterSize[i], d, variance);
        }

        return loglik - (k + k * d) * 0.5 * Math.log(n);
    }

    /**
     * Fits X-Means with at most 100 k-means iterations per round and tolerance 1E-4.
     *
     * @param data the observations.
     * @param kmax the maximum number of clusters.
     * @return the model.
     * @throws IllegalArgumentException if {@code kmax < 2} or {@code data} is empty.
     */
    public static XMeans fit(double[][] data, int kmax) {
        return fit(data, kmax, 100, 1.0E-4);
    }

    /**
     * Fits X-Means.
     *
     * @param data the observations; all rows are assumed to have the same length.
     * @param kmax the maximum number of clusters.
     * @param maxIter the maximum number of k-means iterations, per 2-means split and per
     *     global reassignment round.
     * @param tol the convergence tolerance on the decrease of distortion.
     * @return the model.
     * @throws IllegalArgumentException if {@code kmax < 2} or {@code data} is empty.
     */
    public static XMeans fit(double[][] data, int kmax, int maxIter, double tol) {
        if (kmax < 2) {
            throw new IllegalArgumentException("Invalid parameter kmax = " + kmax);
        }
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }

        int n = data.length;
        int d = data[0].length;

        // Start with a single cluster holding every observation (all labels 0).
        int k = 1;
        int[] size = new int[kmax];
        size[0] = n;
        int[] y = new int[n];
        double[][] sum = new double[kmax][d];

        double[] mean = MathEx.colMeans(data);
        double[][] centroids = {mean};

        double distortion = Arrays.stream(data)
                .parallel()
                .mapToDouble(x -> MathEx.squaredDistance(x, mean))
                .sum();
        // Per-cluster distortions, indexed by cluster label.
        double[] distortions = new double[kmax];
        distortions[0] = distortion;

        BBDTree bbd = new BBDTree(data);
        KMeans[] kmeans = new KMeans[kmax];
        ArrayList<double[]> centers = new ArrayList<>();

        while (k < kmax) {
            centers.clear();
            double[] score = new double[k];

            // Score every cluster by the BIC improvement of splitting it with 2-means.
            for (int i = 0; i < k; i++) {
                int ni = size[i];
                if (ni < MIN_SPLIT_SIZE) {
                    logger.info("Cluster {} too small to split: {} observations", i, ni);
                    score[i] = 0.0;
                    kmeans[i] = null;
                    continue;
                }

                double[][] subset = new double[ni][];
                for (int j = 0, l = 0; j < n; j++) {
                    if (y[j] == i) {
                        subset[l++] = data[j];
                    }
                }

                KMeans split = KMeans.fit(subset, 2, maxIter, tol);
                kmeans[i] = split;
                double newBIC = bic(2, ni, d, split.distortion, split.size);
                double oldBIC = bic(ni, d, distortions[i]);
                score[i] = newBIC - oldBIC;
                logger.info(String.format("Cluster %3d BIC: %12.4f, BIC after split: %12.4f, improvement: %12.4f", i, oldBIC, newBIC, score[i]));

                // A NaN score (e.g. zero variance) is neither <= 0 nor > 0 and would make the
                // cluster vanish from the partition below; treat it as "no improvement".
                if (Double.isNaN(score[i])) {
                    score[i] = 0.0;
                }
            }

            // Sort scores ascending; index[i] is the original cluster of the i-th score.
            int[] index = QuickSort.sort(score);
            for (int i = 0; i < k; i++) {
                if (score[i] <= 0.0) {
                    centers.add(centroids[index[i]]);
                }
            }

            // Split the most improved clusters first, as long as the final count, after the
            // remaining unprocessed clusters are added, stays within kmax.
            int m = centers.size();
            for (int i = k; --i >= 0; ) {
                if (score[i] > 0.0) {
                    if (centers.size() + i - m + 1 < kmax) {
                        logger.info("Split cluster {}", index[i]);
                        centers.add(kmeans[index[i]].centroids[0]);
                        centers.add(kmeans[index[i]].centroids[1]);
                    } else {
                        centers.add(centroids[index[i]]);
                    }
                }
            }

            if (centers.size() == k) {
                logger.info("No more split. Finish with {} clusters", k);
                break;
            }

            k = centers.size();
            centroids = centers.toArray(new double[k][]);

            // Reassign all observations with k-means iterations until the distortion
            // decrease drops to tol or maxIter is reached.
            double diff = Double.MAX_VALUE;
            for (int iter = 1; iter <= maxIter && diff > tol; iter++) {
                double wcss = bbd.clustering(centroids, sum, size, y);
                diff = distortion - wcss;
                distortion = wcss;
            }

            // Recompute per-cluster distortions in one pass; each cluster's terms are added
            // in increasing observation order.
            Arrays.fill(distortions, 0.0);
            for (int i = 0; i < n; i++) {
                int cluster = y[i];
                distortions[cluster] += MathEx.squaredDistance(data[i], centroids[cluster]);
            }

            logger.info(String.format("Distortion with %d clusters: %.5f", k, distortion));
        }

        return new XMeans(distortion, centroids, y);
    }

    /**
     * Adds to {@code distortions[cluster]} the squared distance from every observation
     * labelled {@code cluster} (among the first {@code n}) to {@code centers.get(cluster)}.
     *
     * <p>Decompilation artifact kept only so the public signature is preserved;
     * {@link #fit(double[][], int, int, double)} no longer calls it.
     *
     * @param centers the cluster centers.
     * @param n the number of observations to scan.
     * @param y the cluster label of each observation.
     * @param distortions the per-cluster distortion accumulator (updated in place).
     * @param data the observations.
     * @param cluster the cluster whose distortion is accumulated.
     */
    public static void lambda$fit$1(ArrayList<double[]> centers, int n, int[] y, double[] distortions, double[][] data, int cluster) {
        double[] center = centers.get(cluster);
        for (int i = 0; i < n; i++) {
            if (y[i] == cluster) {
                distortions[cluster] += MathEx.squaredDistance(data[i], center);
            }
        }
    }

    /**
     * Computes one cluster's contribution to the log-likelihood of a k-cluster spherical
     * Gaussian model with shared variance.
     *
     * @param k the number of clusters.
     * @param n the total number of observations.
     * @param ni the number of observations in this cluster.
     * @param d the dimensionality of the data.
     * @param variance the shared variance estimate.
     * @return this cluster's log-likelihood term.
     */
    private static double logLikelihood(int k, int n, int ni, int d, double variance) {
        double p1 = -ni * LOG2PI;
        // Multiply in double: ni * d as int can overflow on large data sets.
        double p2 = -(double) ni * d * Math.log(variance);
        double p3 = -(ni - k);
        double p4 = ni * Math.log(ni);
        double p5 = -ni * Math.log(n);
        return (p1 + p2 + p3) / 2.0 + p4 + p5;
    }

    /**
     * Returns the squared Euclidean distance between two points.
     *
     * @param x a point.
     * @param y another point.
     * @return the squared Euclidean distance.
     */
    @Override
    public double distance(double[] x, double[] y) {
        return MathEx.squaredDistance(x, y);
    }
}
