package smile.glm;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Properties;
import java.util.function.IntConsumer;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.glm.model.Model;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.math.special.Erf;
import smile.stat.Hypothesis;
import smile.validation.ModelSelection;

/* loaded from: classes5.dex */
/**
 * Generalized linear model (GLM) fitted by iteratively reweighted least squares (IRLS).
 *
 * <p>The response is related to the linear predictor {@code eta = X * beta} through the link
 * function of a {@link Model}; fitted means are {@code mu = invlink(eta)}. Instances are normally
 * created by one of the {@code fit} factory methods.
 */
public class GLM implements Serializable {
    private static final long serialVersionUID = 2;
    private static final Logger logger = LoggerFactory.getLogger(GLM.class);

    /** 1 / sqrt(2): maps a standard normal deviate to the argument of {@code erfc}. */
    private static final double INV_SQRT2 = 0.7071067811865476;

    /** Estimated coefficients; when built by {@code fit}, one per design-matrix column (bias included). */
    protected double[] beta;
    /** Residual deviance of the fitted model. */
    protected double deviance;
    /** Per-observation deviance residuals as filled in by {@link Model#deviance}. */
    protected double[] devianceResiduals;
    /** Residual degrees of freedom: number of observations minus number of coefficients. */
    protected int df;
    /** Model formula used to build design matrices for fitting and prediction. */
    protected Formula formula;
    /** Log-likelihood of the fitted model. */
    protected double loglikelihood;
    /** Distribution family / link specification of the GLM. */
    protected Model model;
    /** Fitted mean response for each training observation. */
    protected double[] mu;
    /** Deviance of the intercept-only (mean response) model. */
    protected double nullDeviance;
    /** Names of the coefficients (design-matrix column names when built by {@code fit}). */
    String[] predictors;
    /** Per-coefficient rows of {estimate, std. error, z value, two-sided p-value}; may be null. */
    protected double[][] ztest;

    /**
     * Creates a fitted model from precomputed results.
     *
     * @param formula the model formula.
     * @param predictors the coefficient names.
     * @param model the GLM family / link.
     * @param beta the coefficients.
     * @param loglikelihood the log-likelihood of the fit.
     * @param deviance the residual deviance.
     * @param nullDeviance the null deviance.
     * @param mu the fitted mean responses; its length is the number of observations.
     * @param devianceResiduals the deviance residuals.
     * @param ztest the z-test table, or null.
     */
    public GLM(Formula formula, String[] predictors, Model model, double[] beta, double loglikelihood, double deviance, double nullDeviance, double[] mu, double[] devianceResiduals, double[][] ztest) {
        this.formula = formula;
        this.model = model;
        this.predictors = predictors;
        this.beta = beta;
        this.loglikelihood = loglikelihood;
        this.deviance = deviance;
        this.nullDeviance = nullDeviance;
        this.mu = mu;
        this.devianceResiduals = devianceResiduals;
        this.ztest = ztest;
        this.df = mu.length - beta.length;
    }

    /**
     * Fits a GLM with the default hyper-parameters (see {@link #fit(Formula, DataFrame, Model, Properties)}).
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param model the GLM family / link.
     * @return the fitted model.
     */
    public static GLM fit(Formula formula, DataFrame data, Model model) {
        return fit(formula, data, model, new Properties());
    }

    /**
     * Fits a GLM by iteratively reweighted least squares.
     *
     * <p>Each iteration solves the weighted least-squares problem {@code min ||diag(w) (X b) - z||}
     * by QR decomposition. It stops when the deviance decreases by less than {@code tol}, or after
     * {@code maxIter} refits.
     *
     * @param formula the model formula; the design matrix includes a bias column.
     * @param data the training data.
     * @param model the GLM family / link.
     * @param tol the tolerance on the deviance decrease; must be positive.
     * @param maxIter the maximum number of IRLS iterations; must be positive.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code tol} or {@code maxIter} is not positive, or the
     *     design matrix does not have more rows than columns.
     */
    public static GLM fit(Formula formula, DataFrame data, Model model, double tol, int maxIter) {
        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        Matrix X = formula.matrix(data, true);
        double[] y = formula.y(data).toDoubleArray();
        int n = X.nrows();
        int p = X.ncols();
        if (n <= p) {
            throw new IllegalArgumentException(String.format("The input matrix is not over determined: %d rows, %d columns", n, p));
        }

        double[] eta = new double[n];       // linear predictor X * beta
        double[] mu = new double[n];        // fitted mean, invlink(eta)
        double[] w = new double[n];         // square roots of the IRLS weights
        double[] z = new double[n];         // working response, already multiplied by w
        double[] residuals = new double[n]; // deviance residuals, filled by Model.deviance

        // Initial working response from the model's starting values for mu.
        IntStream.range(0, n).parallel().forEach(i -> {
            mu[i] = model.mustart(y[i]);
            eta[i] = model.link(mu[i]);
            updateWorkingResponse(model, y, eta, mu, z, w, i);
        });

        // XW is refilled before every decomposition (qr(true) presumably overwrites it).
        Matrix XW = new Matrix(n, p);
        scaleRows(X, w, XW);
        Matrix.QR qr = XW.qr(true);
        double[] beta = qr.solve(z);

        double dev = Double.POSITIVE_INFINITY;
        for (int iter = 0; iter < maxIter; iter++) {
            X.mv(beta, eta);
            IntStream.range(0, n).parallel().forEach(i -> {
                mu[i] = model.invlink(eta[i]);
                updateWorkingResponse(model, y, eta, mu, z, w, i);
            });

            double newDev = model.deviance(y, mu, residuals);
            logger.info(String.format("Deviance after %3d iterations: %.5f", iter, newDev));

            if (dev - newDev < tol) {
                // Report the deviance of the returned beta/mu/residuals, not of the previous iterate.
                dev = newDev;
                break;
            }

            scaleRows(X, w, XW);
            qr = XW.qr(true);
            beta = qr.solve(z);
            dev = newDev;
        }

        // Inverse of XW' XW obtained from the QR factors; its diagonal holds the squared standard errors.
        Matrix inv = qr.CholeskyOfAtA().inverse();

        double[][] ztest = new double[p][4];
        for (int j = 0; j < p; j++) {
            double se = Math.sqrt(inv.get(j, j));
            double zscore = beta[j] / se;
            ztest[j][0] = beta[j];
            ztest[j][1] = se;
            ztest[j][2] = zscore;
            // Two-sided p-value 2 * (1 - Phi(|z|)) = erfc(|z| / sqrt(2)). Evaluating it directly avoids
            // the cancellation in 2 - erfc(-|z| / sqrt(2)), which flushes small p-values to 0.
            ztest[j][3] = Erf.erfc(Math.abs(zscore) * INV_SQRT2);
        }

        return new GLM(formula, X.colNames(), model, beta, model.loglikelihood(y, mu), dev, model.nullDeviance(y, MathEx.mean(y)), mu, residuals, ztest);
    }

    /**
     * Fits a GLM with hyper-parameters read from {@code properties}.
     *
     * <p>Recognized keys: {@code smile.glm.tolerance} (default {@code 1E-5}) and
     * {@code smile.glm.max.iterations} (default {@code 50}).
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param model the GLM family / link.
     * @param properties the hyper-parameters.
     * @return the fitted model.
     * @throws NumberFormatException if a property value cannot be parsed.
     */
    public static GLM fit(Formula formula, DataFrame data, Model model, Properties properties) {
        double tol = Double.parseDouble(properties.getProperty("smile.glm.tolerance", "1E-5"));
        int maxIter = Integer.parseInt(properties.getProperty("smile.glm.max.iterations", "50"));
        return fit(formula, data, model, tol, maxIter);
    }

    /**
     * Computes the square-root IRLS weight {@code w[i] = 1 / (g * sqrt(V(mu)))} for observation i.
     * Also computes the weighted working response {@code z[i] = (eta + (y - mu) * g) * w[i]}, where
     * {@code g = dlink(mu[i])}.
     */
    private static void updateWorkingResponse(Model model, double[] y, double[] eta, double[] mu, double[] z, double[] w, int i) {
        double g = model.dlink(mu[i]);
        w[i] = 1.0 / (g * Math.sqrt(model.variance(mu[i])));
        z[i] = (eta[i] + (y[i] - mu[i]) * g) * w[i];
    }

    /** Fills {@code XW} with the rows of {@code X} scaled by {@code w}, i.e. {@code XW = diag(w) X}. */
    private static void scaleRows(Matrix X, double[] w, Matrix XW) {
        int n = X.nrows();
        int p = X.ncols();
        for (int j = 0; j < p; j++) {
            for (int i = 0; i < n; i++) {
                XW.set(i, j, X.get(i, j) * w[i]);
            }
        }
    }

    /**
     * Returns the Akaike information criterion of the fit.
     *
     * @return AIC computed from the log-likelihood and the number of coefficients.
     */
    public double AIC() {
        return ModelSelection.AIC(loglikelihood, beta.length);
    }

    /**
     * Returns the Bayesian information criterion of the fit.
     *
     * @return BIC computed from the log-likelihood, the number of coefficients and the number of observations.
     */
    public double BIC() {
        return ModelSelection.BIC(loglikelihood, beta.length, mu.length);
    }

    /**
     * Returns the coefficients. The internal array is returned, not a copy.
     *
     * @return the coefficients.
     */
    public double[] coefficients() {
        return beta;
    }

    /**
     * Returns the residual deviance.
     *
     * @return the residual deviance.
     */
    public double deviance() {
        return deviance;
    }

    /**
     * Returns the deviance residuals. The internal array is returned, not a copy.
     *
     * @return the deviance residuals.
     */
    public double[] devianceResiduals() {
        return devianceResiduals;
    }

    /**
     * Returns the fitted mean responses of the training data. The internal array is returned, not a copy.
     *
     * @return the fitted values.
     */
    public double[] fittedValues() {
        return mu;
    }

    /**
     * Returns the log-likelihood of the fit.
     *
     * @return the log-likelihood.
     */
    public double loglikelihood() {
        return loglikelihood;
    }

    /**
     * Predicts the mean response of one observation.
     *
     * @param x the observation; it is encoded with a bias term and dummy-coded categorical variables.
     * @return {@code invlink(x * beta)}.
     */
    public double predict(Tuple x) {
        double[] xi = formula.x(x).toArray(true, CategoricalEncoder.DUMMY);
        double eta = 0.0;
        for (int i = 0; i < beta.length; i++) {
            eta += xi[i] * beta[i];
        }
        return model.invlink(eta);
    }

    /**
     * Predicts the mean response of every row in a data frame.
     *
     * @param data the data frame.
     * @return {@code invlink(X * beta)} for each row.
     */
    public double[] predict(DataFrame data) {
        double[] eta = formula.matrix(data, true).mv(beta);
        for (int i = 0; i < eta.length; i++) {
            eta[i] = model.invlink(eta[i]);
        }
        return eta;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("Generalized Linear Model - %s:\n", model));

        // Summarize a copy, in case the quantile helpers reorder their input.
        double[] r = devianceResiduals.clone();
        sb.append("\nDeviance Residuals:\n       Min          1Q      Median          3Q         Max\n");
        sb.append(String.format("%10.4f  %10.4f  %10.4f  %10.4f  %10.4f%n", MathEx.min(r), MathEx.q1(r), MathEx.median(r), MathEx.q3(r), MathEx.max(r)));

        sb.append("\nCoefficients:\n");
        if (ztest != null) {
            sb.append("                  Estimate Std. Error    z value   Pr(>|z|)\n");
            // One row per coefficient; fit() names every design-matrix column, so no row is dropped.
            int rows = Math.min(ztest.length, predictors.length);
            for (int i = 0; i < rows; i++) {
                double[] t = ztest[i];
                sb.append(String.format("%-15s %10.3e %10.3e %10.4f %10.5f %s%n", predictors[i], t[0], t[1], t[2], t[3], Hypothesis.significance(t[3])));
            }
            sb.append("---------------------------------------------------------------------\nSignificance codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n");
        } else if (predictors.length == beta.length) {
            // Every coefficient has its own name (as produced by fit()).
            for (int i = 0; i < beta.length; i++) {
                sb.append(String.format("%-15s %10.4f%n", predictors[i], beta[i]));
            }
        } else {
            // Legacy layout: the unnamed intercept is stored as the last coefficient.
            int last = beta.length - 1;
            sb.append(String.format("Intercept       %10.4f%n", beta[last]));
            for (int i = 0; i < last; i++) {
                sb.append(String.format("%-15s %10.4f%n", predictors[i], beta[i]));
            }
        }

        // Null model has n - 1 degrees of freedom; df = n - beta.length.
        sb.append(String.format("%n    Null deviance: %.1f on %d degrees of freedom", nullDeviance, df + beta.length - 1));
        sb.append(String.format("%nResidual deviance: %.1f on %d degrees of freedom", deviance, df));
        sb.append(String.format("%nAIC: %.4f     BIC: %.4f%n", AIC(), BIC()));
        return sb.toString();
    }

    /**
     * Returns the z-test table: one row per coefficient with its estimate, standard error, z value
     * and two-sided p-value. The internal array is returned, not a copy.
     *
     * @return the z-test table, or null if none was supplied.
     */
    public double[][] ztest() {
        return ztest;
    }
}
