package smile.regression;

import coil.disk.DiskLruCache;
import java.util.Arrays;
import java.util.Properties;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.math.blas.UPLO;
import smile.math.matrix.Matrix;

/* loaded from: classes5.dex */
/**
 * Ridge (L2-penalized) linear regression.
 *
 * <p>Every entry point delegates to
 * {@link #fit(Formula, DataFrame, double[], double[], double[])}, which standardizes each
 * column of the design matrix (subtract the column mean, divide by the column standard
 * deviation), solves the penalized weighted normal equations with a Cholesky decomposition,
 * and maps the coefficients back to the original column scale.
 */
public class RidgeRegression {
    /** Shrinkage used when the {@code smile.ridge.lambda} property is not set. */
    private static final String DEFAULT_LAMBDA = "1";

    /**
     * Fits a ridge regression using default hyperparameters (see
     * {@link #fit(Formula, DataFrame, Properties)}).
     *
     * @param formula the model formula.
     * @param data the training data.
     * @return the fitted model.
     */
    public static LinearModel fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Properties());
    }

    /**
     * Fits a ridge regression with unit sample weights, a single shrinkage value applied to
     * every coefficient, and a zero prior mean for the coefficients.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param lambda the shrinkage parameter; must be finite and non-negative.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code lambda} is invalid or a column is constant.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda) {
        double[] weights = new double[data.size()];
        Arrays.fill(weights, 1.0);
        return fit(formula, data, weights, new double[] {lambda}, new double[] {0.0});
    }

    /**
     * Fits a ridge regression configured from properties.
     *
     * <p>Reads the shrinkage parameter from {@code smile.ridge.lambda}, defaulting to
     * {@code "1"} when absent.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param properties the hyperparameters.
     * @return the fitted model.
     * @throws NumberFormatException if {@code smile.ridge.lambda} is not a valid double.
     * @throws IllegalArgumentException if the lambda value is invalid or a column is constant.
     */
    public static LinearModel fit(Formula formula, DataFrame data, Properties properties) {
        double lambda =
                Double.parseDouble(properties.getProperty("smile.ridge.lambda", DEFAULT_LAMBDA));
        return fit(formula, data, lambda);
    }

    /**
     * Fits a weighted ridge regression with per-coefficient shrinkage and prior means.
     *
     * <p>With {@code Z} the column-standardized design matrix and {@code W = diag(weights)},
     * solves {@code (Z'WZ + diag(lambda)) b = Z'Wy + diag(lambda) beta0}. Each {@code b[j]} is
     * then divided by the standard deviation of column {@code j}. The intercept is
     * {@code mean(y) - b . colMeans(X)}. Because {@code beta0} enters the standardized system,
     * it is a prior on the coefficients of the standardized columns.
     *
     * <p>NOTE(review): centering uses the unweighted column means and the unweighted mean of
     * {@code y}. With non-uniform weights this is not the weighted least-squares intercept.
     * Confirm this is intended.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param weights one finite, strictly positive weight per row.
     * @param lambda per-coefficient shrinkage (finite, non-negative), or a single value used for
     *     every coefficient.
     * @param beta0 per-coefficient prior means, or a single value used for every coefficient.
     * @return the fitted model.
     * @throws IllegalArgumentException if any vector has the wrong size, a weight or lambda is
     *     out of range, or a column of the design matrix is constant.
     */
    public static LinearModel fit(
            Formula formula, DataFrame data, double[] weights, double[] lambda, double[] beta0) {
        Formula expanded = formula.expand(data.schema());
        StructType schema = expanded.bind(data.schema());

        // NOTE(review): `false` presumably omits an explicit bias column, since the intercept
        // is computed separately below. Confirm against Formula.matrix.
        Matrix x = expanded.matrix(data, false);
        double[] y = expanded.y(data).toDoubleArray();
        int n = x.nrows();
        int p = x.ncols();

        checkWeights(weights, n);
        double[] penalty = broadcast(lambda, p, "lambda");
        checkLambda(penalty);
        double[] prior = broadcast(beta0, p, "beta0");

        double[] center = x.colMeans();
        double[] scale = x.colSds();
        for (int j = 0; j < scale.length; j++) {
            // Standardizing divides by the standard deviation, so a constant column is unusable.
            if (MathEx.isZero(scale[j])) {
                throw new IllegalArgumentException(
                        String.format("The column '%s' is constant", x.colName(j)));
            }
        }
        Matrix z = x.scale(center, scale);

        // Materialize Z'W: row j holds weights[i] * z[i][j] for every sample i.
        Matrix ztw = new Matrix(p, n);
        for (int j = 0; j < p; j++) {
            for (int i = 0; i < n; i++) {
                ztw.set(j, i, weights[i] * z.get(i, j));
            }
        }

        // Right-hand side: Z'Wy + diag(lambda) beta0.
        double[] rhs = ztw.mv(y);
        for (int j = 0; j < p; j++) {
            rhs[j] += penalty[j] * prior[j];
        }

        // Left-hand side: Z'WZ + diag(lambda), declared as lower-triangular storage for Cholesky.
        Matrix lhs = ztw.mm(z);
        lhs.uplo(UPLO.LOWER);
        for (int j = 0; j < p; j++) {
            lhs.add(j, j, penalty[j]);
        }

        double[] coefficients = lhs.cholesky(true).solve(rhs);
        // Map coefficients from the standardized scale back to the original column scale.
        for (int j = 0; j < p; j++) {
            coefficients[j] /= scale[j];
        }

        double intercept = MathEx.mean(y) - MathEx.dot(coefficients, center);
        return new LinearModel(expanded, schema, x, y, coefficients, intercept);
    }

    /** Checks that there is exactly one finite, strictly positive weight per row. */
    private static void checkWeights(double[] weights, int n) {
        if (weights.length != n) {
            throw new IllegalArgumentException(
                    String.format("Invalid weights vector size: %d != %d", weights.length, n));
        }
        for (int i = 0; i < n; i++) {
            double w = weights[i];
            // Negated comparison so that NaN is rejected as well as non-positive values.
            if (!(w > 0.0) || Double.isInfinite(w)) {
                throw new IllegalArgumentException(
                        String.format("Invalid weights[%d] = %f", i, w));
            }
        }
    }

    /** Checks that every shrinkage value is finite and non-negative. */
    private static void checkLambda(double[] lambda) {
        for (int j = 0; j < lambda.length; j++) {
            double l = lambda[j];
            // Negated comparison so that NaN is rejected as well as negative values.
            if (!(l >= 0.0) || Double.isInfinite(l)) {
                throw new IllegalArgumentException(
                        String.format("Invalid lambda[%d] = %f", j, l));
            }
        }
    }

    /**
     * Expands a length-1 array to a new array of length {@code p} filled with that value.
     * Otherwise checks that {@code values} already has length {@code p} and returns it
     * unchanged (not copied).
     */
    private static double[] broadcast(double[] values, int p, String name) {
        if (values.length == 1) {
            double[] expanded = new double[p];
            Arrays.fill(expanded, values[0]);
            return expanded;
        }
        if (values.length != p) {
            throw new IllegalArgumentException(
                    String.format("Invalid %s vector size: %d != %d", name, values.length, p));
        }
        return values;
    }
}
