package smile.clustering;

import java.util.Arrays;
import smile.math.MathEx;

/**
 * Space-partitioning tree over a fixed set of data points, used to accelerate the
 * assignment step of k-means with the filtering algorithm.
 *
 * <p>Each node covers a contiguous slice of an internal permutation of the data rows
 * and stores the axis-aligned bounding box of its points. It also caches two values:
 * the component-wise sum of its points, and their sum of squared deviations from the
 * node mean. Because of these cached values, a whole node can be assigned to one
 * centroid without visiting its points. Internal nodes split their box at the midpoint
 * of its widest dimension.
 */
public class BBDTree {
    /**
     * Permutation of data row indices. Each node owns the contiguous range
     * {@code [node.index, node.index + node.size)} of this array.
     */
    private final int[] index;

    /** Root node, covering every data point. */
    private final Node root;

    /** A tree node covering {@code size} data rows starting at offset {@code index}. */
    public static class Node {
        /** Midpoint of the node's axis-aligned bounding box. */
        double[] center;
        /**
         * Sum of squared distances from the node's points to their mean. This is 0 for
         * leaves, whose points are treated as coincident.
         */
        double cost;
        /** Start offset of this node's range in the tree's index permutation. */
        int index;
        /** Child holding the points below the split cutoff; {@code null} for leaves. */
        Node lower;
        /** Half-width of the bounding box along each dimension. */
        double[] radius;
        /** Number of data points covered by this node. */
        int size;
        /** Component-wise sum of the node's points. */
        double[] sum;
        /** Child holding the points at or above the split cutoff; {@code null} for leaves. */
        Node upper;

        /**
         * Allocates an empty node for {@code d}-dimensional data.
         *
         * @param d the data dimension.
         */
        Node(int d) {
            this.center = new double[d];
            this.radius = new double[d];
            this.sum = new double[d];
        }
    }

    /**
     * Builds the tree over the given data points.
     *
     * @param data the data points, one per row. The dimension is taken from
     *     {@code data[0].length}.
     * @throws IllegalArgumentException if {@code data} is empty.
     */
    public BBDTree(double[][] data) {
        int n = data.length;
        if (n == 0) {
            throw new IllegalArgumentException("Empty data");
        }

        this.index = new int[n];
        for (int i = 0; i < n; i++) {
            this.index[i] = i;
        }
        this.root = buildNode(data, 0, n);
    }

    /**
     * Recursively builds the subtree over rows {@code index[begin .. end)}.
     *
     * <p>The method reorders that slice of {@link #index} in place, so that each child
     * owns a contiguous sub-range. It requires {@code begin < end}.
     */
    private Node buildNode(double[][] data, int begin, int end) {
        int d = data[0].length;
        Node node = new Node(d);
        node.size = end - begin;
        node.index = begin;

        // Axis-aligned bounding box of the points in this node.
        double[] lowerBound = new double[d];
        double[] upperBound = new double[d];
        System.arraycopy(data[index[begin]], 0, lowerBound, 0, d);
        System.arraycopy(data[index[begin]], 0, upperBound, 0, d);
        for (int i = begin + 1; i < end; i++) {
            double[] x = data[index[i]];
            for (int j = 0; j < d; j++) {
                double v = x[j];
                if (lowerBound[j] > v) {
                    lowerBound[j] = v;
                }
                if (upperBound[j] < v) {
                    upperBound[j] = v;
                }
            }
        }

        // Box center and half-width per dimension. The widest dimension is used for
        // the split. The halves are taken before adding or subtracting, so extreme
        // values cannot overflow. In the normal floating-point range this equals
        // (lo + hi) / 2 and (hi - lo) / 2.
        double maxRadius = -1.0;
        int splitDim = -1;
        for (int j = 0; j < d; j++) {
            node.center[j] = 0.5 * lowerBound[j] + 0.5 * upperBound[j];
            node.radius[j] = 0.5 * upperBound[j] - 0.5 * lowerBound[j];
            if (node.radius[j] > maxRadius) {
                maxRadius = node.radius[j];
                splitDim = j;
            }
        }

        // Box is (numerically) a single point: make a leaf whose sum is the first
        // point scaled by the count, with zero cost.
        if (maxRadius < 1E-10) {
            node.lower = null;
            node.upper = null;
            System.arraycopy(data[index[begin]], 0, node.sum, 0, d);
            if (node.size > 1) {
                for (int j = 0; j < d; j++) {
                    node.sum[j] *= node.size;
                }
            }
            node.cost = 0.0;
            return node;
        }

        // Split at the midpoint of the widest dimension. Rounding can make the midpoint
        // equal the lower bound (e.g. when lo and hi are adjacent doubles). In that case
        // no point falls below it, and recursion would repeat the same range forever.
        // Falling back to the upper bound keeps lo < cutoff <= hi, so both children
        // are non-empty.
        double cutoff = node.center[splitDim];
        if (!(cutoff > lowerBound[splitDim])) {
            cutoff = upperBound[splitDim];
        }

        // In-place two-pointer partition (as in quicksort). Rows below the cutoff move
        // to the front of the range.
        int left = begin;
        int right = end - 1;
        while (left <= right) {
            boolean leftInPlace = data[index[left]][splitDim] < cutoff;
            boolean rightInPlace = data[index[right]][splitDim] >= cutoff;
            if (!leftInPlace && !rightInPlace) {
                int tmp = index[left];
                index[left] = index[right];
                index[right] = tmp;
                leftInPlace = true;
                rightInPlace = true;
            }
            if (leftInPlace) {
                left++;
            }
            if (rightInPlace) {
                right--;
            }
        }

        // Rows [begin, left) are below the cutoff; rows [left, end) are at or above it.
        node.lower = buildNode(data, begin, left);
        node.upper = buildNode(data, left, end);

        double[] mean = new double[d];
        for (int j = 0; j < d; j++) {
            node.sum[j] = node.lower.sum[j] + node.upper.sum[j];
            mean[j] = node.sum[j] / node.size;
        }

        node.cost = getNodeCost(node.lower, mean) + getNodeCost(node.upper, mean);
        return node;
    }

    /**
     * Assigns the points of {@code node} to centroids, considering only the first
     * {@code k} entries of {@code candidates}.
     *
     * <p>The method finds the candidate closest to the node's box center, then prunes
     * every candidate that cannot be closer than it anywhere in the box. If at least
     * two candidates survive on an internal node, it recurses into both children.
     * Otherwise the whole node goes to the closest candidate: its sum is added to
     * {@code sum}, its count to {@code size}, and its rows are labelled in {@code y}.
     *
     * @return the sum of squared distances from the node's points to their assigned
     *     centroids.
     */
    private double filter(Node node, double[][] centroids, int[] candidates, int k,
                          double[][] sum, int[] size, int[] y) {
        int d = centroids[0].length;

        int closest = candidates[0];
        double minDist = MathEx.squaredDistance(node.center, centroids[closest]);
        for (int i = 1; i < k; i++) {
            double dist = MathEx.squaredDistance(node.center, centroids[candidates[i]]);
            if (dist < minDist) {
                minDist = dist;
                closest = candidates[i];
            }
        }

        if (node.lower != null) {
            // The closest candidate is never pruned, so at least one always survives.
            int[] survivors = new int[k];
            int m = 0;
            for (int i = 0; i < k; i++) {
                if (!prune(node.center, node.radius, centroids, closest, candidates[i])) {
                    survivors[m++] = candidates[i];
                }
            }

            if (m > 1) {
                return filter(node.lower, centroids, survivors, m, sum, size, y)
                        + filter(node.upper, centroids, survivors, m, sum, size, y);
            }
        }

        // Only one centroid can own this node's points: assign them all at once.
        for (int j = 0; j < d; j++) {
            sum[closest][j] += node.sum[j];
        }
        size[closest] += node.size;

        int last = node.index + node.size;
        for (int i = node.index; i < last; i++) {
            y[index[i]] = closest;
        }

        return getNodeCost(node, centroids[closest]);
    }

    /**
     * Returns the sum of squared distances from the node's points to {@code center}.
     *
     * <p>It is computed from cached statistics as
     * {@code node.cost + node.size * ||node.sum / node.size - center||^2}, which
     * decomposes the squared error about {@code center} into the error about the node
     * mean plus the shift of the mean.
     */
    private double getNodeCost(Node node, double[] center) {
        int d = center.length;
        double dist = 0.0;
        for (int j = 0; j < d; j++) {
            double diff = (node.sum[j] / node.size) - center[j];
            dist += diff * diff;
        }
        return node.cost + (node.size * dist);
    }

    /**
     * Tests whether centroid {@code testIndex} can be dropped as a candidate for the box
     * {@code center ± radius}.
     *
     * <p>The test uses the box corner that lies furthest in the direction
     * {@code test - best}, i.e. {@code center + radius} where the difference is positive
     * and {@code center - radius} otherwise. It returns true when that corner is at
     * least as close to {@code bestIndex} as to {@code testIndex}. The algebraic form
     * {@code ||test - best||^2 >= 2 * (corner - best) · (test - best)} avoids computing
     * the corner explicitly.
     *
     * @return {@code false} when both indices are equal; otherwise the result of the
     *     corner test.
     */
    private boolean prune(double[] center, double[] radius, double[][] centroids,
                          int bestIndex, int testIndex) {
        if (bestIndex == testIndex) {
            return false;
        }

        int d = centroids[0].length;
        double[] best = centroids[bestIndex];
        double[] test = centroids[testIndex];
        double lhs = 0.0;
        double rhs = 0.0;
        for (int j = 0; j < d; j++) {
            double diff = test[j] - best[j];
            lhs += diff * diff;
            double corner = diff > 0.0 ? center[j] + radius[j] : center[j] - radius[j];
            rhs += (corner - best[j]) * diff;
        }

        return lhs >= 2.0 * rhs;
    }

    /**
     * Runs one k-means step over the tree's data.
     *
     * <p>Every data point is assigned to a centroid with the filtering algorithm. Then
     * each centroid that received at least one point is moved to the mean of its
     * points.
     *
     * @param centroids the current centroids, one per row, updated in place. Centroids
     *     that receive no points are left unchanged.
     * @param sum output: per-centroid component-wise sum of the assigned points. It
     *     needs {@code centroids.length} rows and is zeroed before use.
     * @param size output: number of points assigned to each centroid. It is zeroed
     *     before use.
     * @param y output: centroid index assigned to each data point. It needs at least
     *     as many entries as the tree has points.
     * @return the sum of squared distances between each point and the centroid, before
     *     the update, to which it was assigned.
     * @throws IllegalArgumentException if {@code centroids} is empty.
     */
    public double clustering(double[][] centroids, double[][] sum, int[] size, int[] y) {
        int k = centroids.length;
        if (k == 0) {
            throw new IllegalArgumentException("No centroids");
        }

        Arrays.fill(size, 0);
        int[] candidates = new int[k];
        for (int i = 0; i < k; i++) {
            candidates[i] = i;
            Arrays.fill(sum[i], 0.0);
        }

        double wcss = filter(root, centroids, candidates, k, sum, size, y);

        int d = centroids[0].length;
        for (int i = 0; i < k; i++) {
            if (size[i] > 0) {
                for (int j = 0; j < d; j++) {
                    centroids[i][j] = sum[i][j] / size[i];
                }
            }
        }

        return wcss;
    }
}
