package smile.math;

import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.matrix.Matrix;

/* loaded from: classes5.dex */
/**
 * Nonlinear least-squares fit of a model {@code y ~ f(p, x)} by a damped
 * (Levenberg-Marquardt style) Gauss-Newton iteration.
 *
 * <p>The model is evaluated on a single array that holds the parameters
 * {@code p} (first {@code p.length} entries) followed by the independent
 * variable(s) of one observation.
 */
public class LevenbergMarquardt {
    private static final Logger logger = LoggerFactory.getLogger(LevenbergMarquardt.class);

    /** Multipliers applied to the current damping value, tried in this order within each iteration. */
    private static final double[] LAMBDA = {0.1, 1.0, 100.0, 10000.0, 1000000.0};

    /** Lower bound on any damping value that is tried. */
    private static final double MIN_LAMBDA = 1.0E-7;

    /** The fitted values f(p, x_i) at the final parameters. */
    public final double[] fittedValues;
    /** The best parameters found (lowest sum of squared residuals seen). */
    public final double[] parameters;
    /** The residuals y_i - f(p, x_i) at the final parameters. */
    public final double[] residuals;
    /** The sum of squared residuals at the final parameters. */
    public final double sse;

    LevenbergMarquardt(double[] parameters, double[] fittedValues, double[] residuals, double sse) {
        this.parameters = parameters;
        this.fittedValues = fittedValues;
        this.residuals = residuals;
        this.sse = sse;
    }

    /**
     * Fits a model with a scalar independent variable, using {@code stol = 1E-4}
     * and at most 20 iterations.
     *
     * @param func the model; see {@link #fit(DifferentiableMultivariateFunction, double[], double[], double[], double, int)}.
     * @param x the independent variable.
     * @param y the responses, same length as {@code x}.
     * @param p the initial parameters.
     * @return the fit result.
     */
    public static LevenbergMarquardt fit(DifferentiableMultivariateFunction func, double[] x, double[] y, double[] p) {
        return fit(func, x, y, p, 1.0E-4, 20);
    }

    /**
     * Fits a model with a scalar independent variable.
     *
     * @param func the model. {@code func.f(v)} and {@code func.g(v, grad)} are
     *             called with {@code v = [p_0 .. p_{d-1}, x_i]}. {@code g} must
     *             return the model value and is expected to write the partial
     *             derivatives w.r.t. the {@code d} parameters into {@code grad}.
     * @param x the independent variable.
     * @param y the responses, same length as {@code x}.
     * @param p the initial parameters (not modified).
     * @param stol the required relative reduction of the sum of squared residuals
     *             for a step to be accepted within one iteration; must be positive.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the fit result.
     * @throws IllegalArgumentException if {@code stol} or {@code maxIter} is not
     *         positive, or if {@code x} and {@code y} differ in length.
     */
    public static LevenbergMarquardt fit(DifferentiableMultivariateFunction func, double[] x, double[] y, double[] p, double stol, int maxIter) {
        checkOptions(stol, maxIter);

        // Present scalar inputs as single-column rows so both overloads share one solver.
        double[][] rows = new double[x.length][];
        for (int i = 0; i < x.length; i++) {
            rows[i] = new double[] {x[i]};
        }
        return solve(func, rows, 1, y, p, stol, maxIter);
    }

    /**
     * Fits a model with a vector-valued independent variable, using
     * {@code stol = 1E-4} and at most 20 iterations.
     *
     * @param func the model; see {@link #fit(DifferentiableMultivariateFunction, double[][], double[], double[], double, int)}.
     * @param x the independent variables, one row per observation.
     * @param y the responses, one per row of {@code x}.
     * @param p the initial parameters.
     * @return the fit result.
     */
    public static LevenbergMarquardt fit(DifferentiableMultivariateFunction func, double[][] x, double[] y, double[] p) {
        return fit(func, x, y, p, 1.0E-4, 20);
    }

    /**
     * Fits a model with a vector-valued independent variable.
     *
     * @param func the model. {@code func.f(v)} and {@code func.g(v, grad)} are
     *             called with {@code v = [p_0 .. p_{d-1}, x_i[0] .. x_i[m-1]]}.
     *             {@code g} must return the model value and is expected to write
     *             the partial derivatives w.r.t. the {@code d} parameters into
     *             {@code grad}.
     * @param x the independent variables; non-empty, all rows of equal length.
     * @param y the responses, one per row of {@code x}.
     * @param p the initial parameters (not modified).
     * @param stol the required relative reduction of the sum of squared residuals
     *             for a step to be accepted within one iteration; must be positive.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the fit result.
     * @throws IllegalArgumentException if {@code stol} or {@code maxIter} is not
     *         positive, if {@code x} is empty or ragged, or if {@code x} and
     *         {@code y} differ in length.
     */
    public static LevenbergMarquardt fit(DifferentiableMultivariateFunction func, double[][] x, double[] y, double[] p, double stol, int maxIter) {
        checkOptions(stol, maxIter);

        if (x.length == 0) {
            throw new IllegalArgumentException("Empty input data");
        }

        int m = x[0].length;
        for (int i = 1; i < x.length; i++) {
            if (x[i].length != m) {
                throw new IllegalArgumentException(String.format("x[%d] has %d columns, expected %d", i, x[i].length, m));
            }
        }
        return solve(func, x, m, y, p, stol, maxIter);
    }

    /** Validates the solver options shared by all overloads. */
    private static void checkOptions(double stol, int maxIter) {
        // Written as !(stol > 0) so that NaN is rejected too.
        if (!(stol > 0.0)) {
            throw new IllegalArgumentException("Invalid gradient tolerance: " + stol);
        }

        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
    }

    /**
     * Runs the damped iteration on rows {@code x} of width {@code m}.
     *
     * <p>Each iteration:
     * <ol>
     *   <li>computes residuals, SSE and the Jacobian at the current best parameters;</li>
     *   <li>scales the Jacobian columns to unit norm and takes its SVD;</li>
     *   <li>tries the damping values in {@link #LAMBDA} until one gives a trial SSE
     *       of at most {@code (1 - stol)} times the iteration's starting SSE.</li>
     * </ol>
     * Any trial that lowers the SSE is kept as the new best, even if not accepted.
     *
     * <p>Iteration stops after {@code maxIter} iterations, when no damping value
     * achieves the required reduction, or when the accepted SSE falls below
     * {@code MathEx.EPSILON}.
     */
    private static LevenbergMarquardt solve(DifferentiableMultivariateFunction func, double[][] x, int m,
                                            double[] y, double[] p, double stol, int maxIter) {
        int n = x.length;
        int d = p.length;

        if (y.length != n) {
            throw new IllegalArgumentException(String.format("The sizes of x and y are different: %d != %d", n, y.length));
        }

        // Each evaluation vector holds the d parameters followed by the m inputs of one row.
        double[] best = new double[d + m];   // best parameters seen so far
        double[] start = new double[d + m];  // parameters at which the Jacobian is taken
        double[] trial = new double[d + m];  // candidate parameters for one damping value
        System.arraycopy(p, 0, best, 0, d);

        double[] residuals = new double[n];
        double[] gradient = new double[d];
        double[] scale = new double[d];
        double[] projected = new double[d];   // U' * residuals
        double[] scaledStep = new double[d];
        double[] step = new double[d];
        Matrix jacobian = new Matrix(n, d);

        double lambda = 1.0;

        for (int iter = 1; iter <= maxIter; iter++) {
            System.arraycopy(best, 0, start, 0, d);

            double sse = 0.0;
            for (int i = 0; i < n; i++) {
                System.arraycopy(x[i], 0, start, d, m);
                double ri = y[i] - func.g(start, gradient);
                residuals[i] = ri;
                sse += ri * ri;
                for (int j = 0; j < d; j++) {
                    jacobian.set(i, j, gradient[j]);
                }
            }

            double sseTol = (1.0 - stol) * sse;

            // Scale each Jacobian column to unit norm (zero columns are left unscaled).
            for (int j = 0; j < d; j++) {
                double norm2 = 0.0;
                for (int i = 0; i < n; i++) {
                    double v = jacobian.get(i, j);
                    norm2 += v * v;
                }
                scale[j] = norm2 > 0.0 ? 1.0 / Math.sqrt(norm2) : 1.0;
                for (int i = 0; i < n; i++) {
                    jacobian.mul(i, j, scale[j]);
                }
            }

            Matrix.SVD svd = jacobian.svd(true, true);
            // NOTE(review): the damping uses the scalar sum of squared singular values,
            // not the per-component sigma_k / (sigma_k^2 + lambda) of textbook LM.
            // Kept as-is; confirm this is intended.
            double s2 = MathEx.dot(svd.s, svd.s);
            Matrix rightVectors = svd.V;
            svd.U.tv(residuals, projected);

            boolean accepted = false;
            double trialSse = sse;
            for (double factor : LAMBDA) {
                double trialLambda = Math.max(lambda * factor, MIN_LAMBDA);
                double denom = Math.sqrt(s2 + trialLambda);
                for (int j = 0; j < d; j++) {
                    scaledStep[j] = projected[j] / denom;
                }

                rightVectors.mv(scaledStep, step);
                // Undo the column scaling to get the step in parameter space.
                for (int j = 0; j < d; j++) {
                    trial[j] = step[j] * scale[j] + start[j];
                }

                trialSse = 0.0;
                for (int i = 0; i < n; i++) {
                    System.arraycopy(x[i], 0, trial, d, m);
                    double ri = y[i] - func.f(trial);
                    trialSse += ri * ri;
                }

                // Keep any improvement, even one that does not meet sseTol.
                if (trialSse < sse) {
                    System.arraycopy(trial, 0, best, 0, d);
                    sse = trialSse;
                }

                if (trialSse <= sseTol) {
                    lambda = trialLambda;
                    accepted = true;
                    break;
                }
            }

            logger.info(String.format("SSE after %3d iterations: %.5f", iter, sse));
            if (!accepted || trialSse < MathEx.EPSILON) {
                logger.info(String.format("converges on SSE after %d iterations", iter));
                break;
            }
        }

        // Final evaluation at the best parameters.
        double[] parameters = Arrays.copyOf(best, d);
        double[] fitted = new double[n];
        double rss = 0.0;
        for (int i = 0; i < n; i++) {
            System.arraycopy(x[i], 0, best, d, m);
            double yi = func.f(best);
            fitted[i] = yi;
            double ri = y[i] - yi;
            residuals[i] = ri;
            rss += ri * ri;
        }
        return new LevenbergMarquardt(parameters, fitted, residuals, rss);
    }
}
