package smile.stat.distribution;

import java.lang.reflect.Array;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.stat.distribution.MultivariateMixture;

/* loaded from: classes5.dex */
/**
 * A finite mixture of multivariate exponential-family distributions, fitted with an
 * (optionally regularized) expectation-maximization algorithm.
 */
public class MultivariateExponentialFamilyMixture extends MultivariateMixture {
    private static final Logger logger = LoggerFactory.getLogger(MultivariateExponentialFamilyMixture.class);
    private static final long serialVersionUID = 2;

    /** The log-likelihood supplied at construction (for {@link #fit}, that of the training data). */
    public final double L;

    /** Bayesian information criterion: {@code L - 0.5 * length() * ln(n)}. */
    public final double bic;

    /**
     * Creates a mixture and records its log-likelihood and BIC.
     *
     * @param logLikelihood the log-likelihood of the data the mixture was fitted on.
     * @param n the number of samples; used in the BIC penalty term {@code 0.5 * length() * ln(n)}.
     *          NOTE(review): {@code length()} is inherited from MultivariateMixture and is
     *          presumably the number of free parameters — confirm there.
     * @param components the mixture components.
     * @throws IllegalArgumentException if any component's distribution is not a
     *         {@link MultivariateExponentialFamily}.
     */
    public MultivariateExponentialFamilyMixture(double logLikelihood, int n, MultivariateMixture.Component... components) {
        super(components);

        for (MultivariateMixture.Component component : components) {
            if (!(component.distribution instanceof MultivariateExponentialFamily)) {
                throw new IllegalArgumentException("Component " + component + " is not of multivariate exponential family.");
            }
        }

        this.L = logLikelihood;
        this.bic = logLikelihood - length() * 0.5 * Math.log(n);
    }

    /**
     * Creates a mixture with log-likelihood 0 and a single-sample BIC (which is therefore 0).
     *
     * @param components the mixture components.
     * @throws IllegalArgumentException if any component's distribution is not a
     *         {@link MultivariateExponentialFamily}.
     */
    public MultivariateExponentialFamilyMixture(MultivariateMixture.Component... components) {
        this(0.0, 1, components);
    }

    /**
     * Fits the mixture with EM using default settings: regularization {@code gamma = 0.2},
     * at most 500 iterations, and convergence tolerance {@code 1E-4}.
     *
     * @param data the training samples.
     * @param components the initial components; this array is overwritten with the fitted components.
     * @return the fitted mixture.
     */
    public static MultivariateExponentialFamilyMixture fit(double[][] data, MultivariateMixture.Component... components) {
        return fit(data, components, 0.2, 500, 1E-4);
    }

    /**
     * Fits the mixture with the (regularized) EM algorithm.
     *
     * <p>Iteration stops after {@code maxIter} iterations, or once the log-likelihood improves by
     * no more than {@code tol} over the previous iteration (this includes a decrease).
     *
     * @param data the training samples.
     * @param components the initial components. This array is modified in place: each slot is
     *        replaced with the re-estimated component after every M-step.
     * @param gamma the regularization factor in [0, 0.2]. A value of 0 gives plain EM.
     * @param maxIter the maximum number of EM iterations.
     * @param tol the minimum log-likelihood improvement needed to keep iterating.
     * @return the fitted mixture, carrying the final log-likelihood (0.0 if no iteration ran).
     * @throws IllegalArgumentException if {@code data.length < components.length / 2}, or
     *         {@code gamma} is outside [0, 0.2].
     */
    public static MultivariateExponentialFamilyMixture fit(double[][] data, MultivariateMixture.Component[] components, double gamma, int maxIter, double tol) {
        if (data.length < components.length / 2) {
            throw new IllegalArgumentException("Too many components");
        }

        if (gamma < 0.0 || gamma > 0.2) {
            throw new IllegalArgumentException("Invalid regularization factor gamma.");
        }

        double[][] posteriori = new double[components.length][data.length];

        double L = 0.0;
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; iter++) {
            eStep(data, components, gamma, posteriori);
            mStep(data, components, posteriori);

            double loglikelihood = mixtureLogLikelihood(data, components);

            // Skip the convergence test on the first iteration. Comparing against the 0.0
            // starting value would otherwise stop EM after one step whenever the
            // log-likelihood is negative.
            diff = (iter == 1) ? Double.MAX_VALUE : loglikelihood - L;
            L = loglikelihood;

            if (iter % 10 == 0) {
                logger.info(String.format("The log-likelihood after %d iterations: %.4f", iter, L));
            }
        }

        return new MultivariateExponentialFamilyMixture(L, data.length, components);
    }

    /**
     * E-step. Fills {@code posteriori[i][j]} with the responsibility of component i for sample j.
     * When {@code gamma > 0}, each responsibility is also adjusted by the regularized EM rule.
     */
    private static void eStep(double[][] data, MultivariateMixture.Component[] components, double gamma, double[][] posteriori) {
        int n = data.length;
        int k = components.length;

        for (int i = 0; i < k; i++) {
            MultivariateMixture.Component c = components[i];
            for (int j = 0; j < n; j++) {
                posteriori[i][j] = c.priori * c.distribution.p(data[j]);
            }
        }

        for (int j = 0; j < n; j++) {
            double p = 0.0;
            for (int i = 0; i < k; i++) {
                p += posteriori[i][j];
            }

            // NOTE(review): if every component assigns zero density to sample j, p == 0 and
            // these become NaN. The gamma > 0 branch below clears them, but with gamma == 0
            // they reach the M-step unchanged.
            for (int i = 0; i < k; i++) {
                posteriori[i][j] /= p;
            }

            if (gamma > 0.0) {
                for (int i = 0; i < k; i++) {
                    double r = posteriori[i][j];
                    r *= 1.0 + gamma * MathEx.log2(r);
                    // log2(0) = -inf gives 0 * -inf = NaN, and a large gamma can push the
                    // value negative. Both become a zero responsibility.
                    posteriori[i][j] = (Double.isNaN(r) || r < 0.0) ? 0.0 : r;
                }
            }
        }
    }

    /**
     * M-step. Replaces each component with the re-estimate returned by its
     * {@code MultivariateExponentialFamily.M}, then rescales the priors so they sum to one.
     */
    private static void mStep(double[][] data, MultivariateMixture.Component[] components, double[][] posteriori) {
        double z = 0.0;
        for (int i = 0; i < components.length; i++) {
            MultivariateExponentialFamily family = (MultivariateExponentialFamily) components[i].distribution;
            components[i] = family.M(data, posteriori[i]);
            z += components[i].priori;
        }

        for (int i = 0; i < components.length; i++) {
            components[i] = new MultivariateMixture.Component(components[i].priori / z, components[i].distribution);
        }
    }

    /** Sums the log of the mixture density over all samples, skipping samples whose density is not positive. */
    private static double mixtureLogLikelihood(double[][] data, MultivariateMixture.Component[] components) {
        double sum = 0.0;
        for (double[] x : data) {
            double p = 0.0;
            for (MultivariateMixture.Component c : components) {
                p += c.priori * c.distribution.p(x);
            }

            if (p > 0.0) {
                sum += Math.log(p);
            }
        }
        return sum;
    }
}
