package smile.classification;

import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Properties;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.cart.CART;
import smile.base.cart.Loss;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.feature.SHAP;
import smile.math.MathEx;
import smile.regression.RegressionTree;
import smile.util.IntSet;
import smile.util.Strings;

/* loaded from: classes5.dex */
public class GradientTreeBoost implements SoftClassifier<Tuple>, DataFrameClassifier, SHAP<Tuple> {
    /** Class-wide SLF4J logger; the training methods use it to report progress. */
    private static final Logger logger = LoggerFactory.getLogger((Class<?>) GradientTreeBoost.class);
    private static final long serialVersionUID = 2;
    /** Intercept: the starting value of the binary decision score before tree contributions are added. */
    private double b;
    /** Multi-class ensemble, indexed as {@code forest[class][tree]}; unset for binary models. */
    private RegressionTree[][] forest;
    /** Formula used to extract the predictors from a tuple or data frame. */
    private Formula formula;
    /** Per-feature importance scores; the training methods fill it by summing {@code importance()} of every tree. */
    private double[] importance;
    /** Number of classes; the value 2 selects the binary code paths that read {@link #trees}. */
    private int k;
    /** Translates internal class indices into the label values returned by {@code predict} (via {@code valueOf}). */
    private IntSet labels;
    /** Factor applied to every tree prediction when it is added to a score. */
    private double shrinkage;
    /** Binary ensemble; unset for multi-class models, which use {@link #forest} instead. */
    private RegressionTree[] trees;

    /**
     * Creates a binary model using {@code IntSet.of(2)} as the class labels.
     * Delegates to the six-argument binary constructor.
     *
     * @param formula the model formula.
     * @param regressionTreeArr the boosted regression trees.
     * @param d the intercept the decision score starts from.
     * @param d2 the shrinkage applied to each tree's prediction.
     * @param dArr the per-feature importance scores.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[] regressionTreeArr, double d, double d2, double[] dArr) {
        this(formula, regressionTreeArr, d, d2, dArr, IntSet.of(2));
    }

    /**
     * Creates a binary (two-class) model from already-trained trees.
     * The arguments are stored as given; no copies are made.
     *
     * @param formula the model formula.
     * @param regressionTreeArr the boosted regression trees.
     * @param d the intercept the decision score starts from.
     * @param d2 the shrinkage applied to each tree's prediction.
     * @param dArr the per-feature importance scores.
     * @param intSet the mapping from class index to class label.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[] regressionTreeArr, double d, double d2, double[] dArr, IntSet intSet) {
        // This constructor always describes a two-class model.
        this.k = 2;
        this.labels = intSet;
        this.formula = formula;
        this.trees = regressionTreeArr;
        this.importance = dArr;
        this.b = d;
        this.shrinkage = d2;
    }

    /**
     * Creates a multi-class model using {@code IntSet.of(regressionTreeArr.length)}
     * as the class labels. Delegates to the five-argument multi-class constructor.
     *
     * @param formula the model formula.
     * @param regressionTreeArr the boosted trees, one array of trees per class.
     * @param d the shrinkage applied to each tree's prediction.
     * @param dArr the per-feature importance scores.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[][] regressionTreeArr, double d, double[] dArr) {
        this(formula, regressionTreeArr, d, dArr, IntSet.of(regressionTreeArr.length));
    }

    /**
     * Creates a multi-class model from already-trained trees.
     * The number of classes is taken from the length of the outer array.
     * The arguments are stored as given; no copies are made.
     *
     * @param formula the model formula.
     * @param regressionTreeArr the boosted trees, one array of trees per class.
     * @param d the shrinkage applied to each tree's prediction.
     * @param dArr the per-feature importance scores.
     * @param intSet the mapping from class index to class label.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[][] regressionTreeArr, double d, double[] dArr, IntSet intSet) {
        // The multi-class scores start from zero, so the intercept stays at 0.
        this.b = 0.0d;
        this.formula = formula;
        this.forest = regressionTreeArr;
        this.k = regressionTreeArr.length;
        this.shrinkage = d;
        this.importance = dArr;
        this.labels = intSet;
    }

    /**
     * Fits a model with default hyper-parameters by delegating to
     * {@code fit(Formula, DataFrame, Properties)} with an empty property set.
     *
     * @param formula the model formula.
     * @param dataFrame the training data.
     * @return the fitted model.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame dataFrame) {
        return fit(formula, dataFrame, new Properties());
    }

    /**
     * Fits a gradient tree boosting classifier.
     * Two-class problems are trained by {@code train2}, all others by {@code traink}.
     *
     * @param formula the model formula; it is expanded against the schema of {@code dataFrame}.
     * @param dataFrame the training data.
     * @param i the number of trees (boosting iterations); must be at least 1.
     * @param i2 forwarded to the {@code RegressionTree} constructor (filled from
     *           {@code smile.gbt.max.depth} by the {@code Properties} overload).
     * @param i3 forwarded to the {@code RegressionTree} constructor (filled from
     *           {@code smile.gbt.max.nodes} by the {@code Properties} overload).
     * @param i4 forwarded to the {@code RegressionTree} constructor (filled from
     *           {@code smile.gbt.node.size} by the {@code Properties} overload).
     * @param d the shrinkage; must lie in (0, 1].
     * @param d2 the sampling fraction; must lie in (0, 1].
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code i}, {@code d} or {@code d2} is out of range.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame dataFrame, int i, int i2, int i3, int i4, double d, double d2) {
        if (i < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + i);
        }

        if (d <= 0.0d || d > 1.0d) {
            throw new IllegalArgumentException("Invalid shrinkage: " + d);
        }

        if (d2 <= 0.0d || d2 > 1.0d) {
            throw new IllegalArgumentException("Invalid sampling fraction: " + d2);
        }

        Formula expanded = formula.expand(dataFrame.schema());
        DataFrame predictors = expanded.x(dataFrame);
        BaseVector response = expanded.y(dataFrame);
        int[][] order = CART.order(predictors);
        ClassLabels codec = ClassLabels.fit(response);

        if (codec.k == 2) {
            return train2(expanded, predictors, codec, order, i, i2, i3, i4, d, d2);
        }
        return traink(expanded, predictors, codec, order, i, i2, i3, i4, d, d2);
    }

    /**
     * Fits a model whose hyper-parameters are read from a property set.
     * Recognised keys and their defaults:
     * <ul>
     * <li>{@code smile.gbt.trees} (500)</li>
     * <li>{@code smile.gbt.max.depth} (20)</li>
     * <li>{@code smile.gbt.max.nodes} (6)</li>
     * <li>{@code smile.gbt.node.size} (5)</li>
     * <li>{@code smile.gbt.shrinkage} (0.05)</li>
     * <li>{@code smile.gbt.sample.rate} (0.7)</li>
     * </ul>
     *
     * @param formula the model formula.
     * @param dataFrame the training data.
     * @param properties the hyper-parameters.
     * @return the fitted model.
     * @throws NumberFormatException if a property value cannot be parsed as a number.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame dataFrame, Properties properties) {
        // Values are parsed in the same order in which they are passed on.
        int ntrees = Integer.parseInt(properties.getProperty("smile.gbt.trees", "500"));
        int maxDepth = Integer.parseInt(properties.getProperty("smile.gbt.max.depth", "20"));
        int maxNodes = Integer.parseInt(properties.getProperty("smile.gbt.max.nodes", "6"));
        int nodeSize = Integer.parseInt(properties.getProperty("smile.gbt.node.size", "5"));
        double shrinkageValue = Double.parseDouble(properties.getProperty("smile.gbt.shrinkage", "0.05"));
        double sampleRate = Double.parseDouble(properties.getProperty("smile.gbt.sample.rate", "0.7"));
        return fit(formula, dataFrame, ntrees, maxDepth, maxNodes, nodeSize, shrinkageValue, sampleRate);
    }

    /**
     * Array factory used by {@code trees()} when flattening the multi-class
     * forest into a single array.
     *
     * @param i the requested array length.
     * @return a new, empty {@code RegressionTree} array of length {@code i}.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ RegressionTree[] lambda$trees$0(int i) {
        return new RegressionTree[i];
    }

    /**
     * Draws a per-class (stratified) subsample.
     * The flag array is cleared, the row permutation is reshuffled, and for
     * every class the first {@code round(count * d)} rows of that class in
     * permuted order are flagged with 1.
     *
     * @param iArr output flags, one per row; 1 marks a selected row, 0 an unselected one.
     * @param iArr2 row indices; shuffled in place by {@code MathEx.permutate}.
     * @param iArr3 number of rows in each class.
     * @param iArr4 class index of each row.
     * @param d the fraction of each class to select.
     */
    private static void sampling(int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4, double d) {
        final int n = iArr.length;
        final int numClasses = iArr3.length;

        Arrays.fill(iArr, 0);
        MathEx.permutate(iArr2);

        for (int label = 0; label < numClasses; label++) {
            final int quota = (int) Math.round(iArr3[label] * d);
            int taken = 0;
            int pos = 0;
            // Walk the shuffled rows until this class has filled its quota.
            while (pos < n && taken < quota) {
                final int row = iArr2[pos];
                if (iArr4[row] == label) {
                    iArr[row] = 1;
                    taken++;
                }
                pos++;
            }
        }
    }

    /**
     * Trains the two-class model.
     * Each iteration draws a stratified subsample, fits one regression tree
     * with the logistic loss, and adds the shrunken tree prediction to the
     * array obtained from {@code Loss.residual()}.
     *
     * @param formula the expanded model formula.
     * @param dataFrame the predictors.
     * @param classLabels the encoded class labels ({@code k}, {@code y}, {@code labels}).
     * @param iArr the column ordering produced by {@code CART.order}.
     * @param i the number of trees.
     * @param i2 forwarded to the {@code RegressionTree} constructor.
     * @param i3 forwarded to the {@code RegressionTree} constructor.
     * @param i4 forwarded to the {@code RegressionTree} constructor.
     * @param d the shrinkage.
     * @param d2 the sampling fraction.
     * @return the fitted binary model.
     */
    private static GradientTreeBoost train2(Formula formula, DataFrame dataFrame, ClassLabels classLabels, int[][] iArr, int i, int i2, int i3, int i4, double d, double d2) {
        final int n = dataFrame.nrows();
        final int[] y = classLabels.y;

        // Number of rows per class; consumed by the stratified sampler.
        final int[] classCount = new int[classLabels.k];
        for (int row = 0; row < n; row++) {
            classCount[y[row]]++;
        }

        final Loss loss = Loss.logistic(y);
        final double intercept = loss.intercept(null);
        // NOTE(review): this array is shared with the loss object, which
        // presumably reads the updated values when the next tree is fit - confirm.
        final double[] residual = loss.residual();
        final StructField field = new StructField("residual", DataTypes.DoubleType);

        final RegressionTree[] ensemble = new RegressionTree[i];
        final int[] permutation = IntStream.range(0, n).toArray();
        final int[] samples = new int[n];

        for (int t = 0; t < i; t++) {
            sampling(samples, permutation, classCount, y, d2);
            logger.info("Training {} tree", Strings.ordinal(t + 1));

            RegressionTree tree = new RegressionTree(dataFrame, loss, field, i2, i3, i4, dataFrame.ncols(), samples, iArr);
            ensemble[t] = tree;

            for (int row = 0; row < n; row++) {
                residual[row] += tree.predict(dataFrame.get(row)) * d;
            }
        }

        // Feature importance of the ensemble is the sum over its trees.
        final double[] total = new double[dataFrame.ncols()];
        for (RegressionTree tree : ensemble) {
            final double[] contribution = tree.importance();
            for (int j = 0; j < contribution.length; j++) {
                total[j] += contribution[j];
            }
        }

        return new GradientTreeBoost(formula, ensemble, intercept, d, total, classLabels.labels);
    }

    /**
     * Trains the multi-class model.
     * Each iteration first turns the current per-class scores into per-row
     * probabilities with a softmax, then, for every class, draws a stratified
     * subsample, fits one regression tree with that class's logistic loss, and
     * adds the shrunken tree prediction to that class's score array.
     *
     * @param formula the expanded model formula.
     * @param dataFrame the predictors.
     * @param classLabels the encoded class labels ({@code k}, {@code y}, {@code labels}).
     * @param iArr the column ordering produced by {@code CART.order}.
     * @param i the number of trees per class.
     * @param i2 forwarded to the {@code RegressionTree} constructor.
     * @param i3 forwarded to the {@code RegressionTree} constructor.
     * @param i4 forwarded to the {@code RegressionTree} constructor.
     * @param d the shrinkage.
     * @param d2 the sampling fraction.
     * @return the fitted multi-class model.
     */
    private static GradientTreeBoost traink(Formula formula, DataFrame dataFrame, ClassLabels classLabels, int[][] iArr, int i, int i2, int i3, int i4, double d, double d2) {
        final int n = dataFrame.nrows();
        final int numClasses = classLabels.k;
        final int[] y = classLabels.y;

        // Number of rows per class; consumed by the stratified sampler.
        final int[] classCount = new int[numClasses];
        for (int row = 0; row < n; row++) {
            classCount[y[row]]++;
        }

        final StructField field = new StructField("residual", DataTypes.DoubleType);
        final RegressionTree[][] ensemble = new RegressionTree[numClasses][i];

        // Row-major probabilities; the same array instance is handed to every
        // loss object, so it must be updated in place and never reallocated.
        final double[][] prob = new double[n][numClasses];
        // Class-major score arrays; each row is the array returned by the
        // corresponding loss object, hence only the outer dimension is allocated.
        final double[][] score = new double[numClasses][];
        final Loss[] loss = new Loss[numClasses];
        for (int c = 0; c < numClasses; c++) {
            loss[c] = Loss.logistic(c, numClasses, y, prob);
            score[c] = loss[c].residual();
        }

        final int[] permutation = IntStream.range(0, n).toArray();
        final int[] samples = new int[n];

        for (int t = 0; t < i; t++) {
            logger.info("Training {} tree", Strings.ordinal(t + 1));

            // Refresh the class probabilities from the current scores.
            for (int row = 0; row < n; row++) {
                for (int c = 0; c < numClasses; c++) {
                    prob[row][c] = score[c][row];
                }
                MathEx.softmax(prob[row]);
            }

            for (int c = 0; c < numClasses; c++) {
                sampling(samples, permutation, classCount, y, d2);

                RegressionTree tree = new RegressionTree(dataFrame, loss[c], field, i2, i3, i4, dataFrame.ncols(), samples, iArr);
                ensemble[c][t] = tree;

                final double[] classScore = score[c];
                for (int row = 0; row < n; row++) {
                    classScore[row] += tree.predict(dataFrame.get(row)) * d;
                }
            }
        }

        // Feature importance of the ensemble is the sum over all trees of all classes.
        final double[] total = new double[dataFrame.ncols()];
        for (RegressionTree[] classTrees : ensemble) {
            for (RegressionTree tree : classTrees) {
                final double[] contribution = tree.importance();
                for (int j = 0; j < contribution.length; j++) {
                    total[j] += contribution[j];
                }
            }
        }

        return new GradientTreeBoost(formula, ensemble, d, total, classLabels.labels);
    }

    /**
     * Returns the model formula.
     *
     * @return the formula stored at construction.
     */
    @Override // smile.classification.DataFrameClassifier, smile.feature.TreeSHAP
    public Formula formula() {
        return this.formula;
    }

    /**
     * Returns the per-feature importance scores.
     * The internal array is returned directly, without a defensive copy.
     *
     * @return the importance scores.
     */
    public double[] importance() {
        return this.importance;
    }

    /**
     * Predicts the class label of an instance.
     * In the binary case the score is the intercept plus the shrunken sum of
     * the tree predictions, and a positive score selects class index 1. In the
     * multi-class case the class with the strictly largest shrunken sum wins.
     *
     * @param tuple the instance to classify.
     * @return the predicted class label, i.e. the winning class index mapped
     *         through {@code labels.valueOf}.
     */
    @Override // smile.classification.Classifier
    public int predict(Tuple tuple) {
        final Tuple x = this.formula.x(tuple);

        if (this.k == 2) {
            double score = this.b;
            for (RegressionTree tree : this.trees) {
                score += this.shrinkage * tree.predict(x);
            }
            final int index = score > 0.0d ? 1 : 0;
            return this.labels.valueOf(index);
        }

        // Stays at -1 if no class score ever exceeds negative infinity.
        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.k; c++) {
            double score = 0.0d;
            for (RegressionTree tree : this.forest[c]) {
                score += this.shrinkage * tree.predict(x);
            }
            if (score > bestScore) {
                best = c;
                bestScore = score;
            }
        }
        return this.labels.valueOf(best);
    }

    /**
     * Predicts the class label of an instance and fills in the posteriori
     * probabilities.
     * In the binary case {@code dArr[0] = 1 / (1 + exp(2 * score))} and
     * {@code dArr[1] = 1 - dArr[0]}. In the multi-class case the per-class
     * scores are converted with a softmax; the maximum score is subtracted
     * before exponentiation to avoid overflow.
     *
     * @param tuple the instance to classify.
     * @param dArr output array of length {@code k} receiving the posteriori probabilities.
     * @return the predicted class label, i.e. the winning class index mapped
     *         through {@code labels.valueOf}.
     * @throws IllegalArgumentException if {@code dArr.length} differs from the number of classes.
     */
    @Override // smile.classification.SoftClassifier
    public int predict(Tuple tuple, double[] dArr) {
        if (dArr.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", dArr.length, this.k));
        }

        final Tuple x = this.formula.x(tuple);

        if (this.k == 2) {
            double score = this.b;
            for (RegressionTree tree : this.trees) {
                score += this.shrinkage * tree.predict(x);
            }
            final double p0 = 1.0d / (Math.exp(2.0d * score) + 1.0d);
            dArr[0] = p0;
            dArr[1] = 1.0d - p0;
            return this.labels.valueOf(score > 0.0d ? 1 : 0);
        }

        // Stays at -1 if no class score ever exceeds negative infinity.
        int best = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.k; c++) {
            double score = 0.0d;
            for (RegressionTree tree : this.forest[c]) {
                score += this.shrinkage * tree.predict(x);
            }
            dArr[c] = score;
            if (score > max) {
                max = score;
                best = c;
            }
        }

        // Softmax, shifted by the maximum score for numerical stability.
        double sum = 0.0d;
        for (int c = 0; c < this.k; c++) {
            dArr[c] = Math.exp(dArr[c] - max);
            sum += dArr[c];
        }
        for (int c = 0; c < this.k; c++) {
            dArr[c] /= sum;
        }

        return this.labels.valueOf(best);
    }

    /**
     * Returns the schema of the predictors, taken from the first tree of the
     * binary ensemble if present, otherwise from the first tree of the first
     * class of the multi-class ensemble.
     *
     * @return the predictor schema.
     */
    @Override // smile.classification.DataFrameClassifier
    public StructType schema() {
        if (this.trees != null) {
            return this.trees[0].schema();
        }
        return this.forest[0][0].schema();
    }

    /**
     * Computes SHAP values over a whole data frame.
     * The formula is first bound to the data frame's schema; the rows are then
     * handed, as a parallel stream, to the stream-based {@code shap} overload.
     *
     * @param dataFrame the data to explain.
     * @return the array produced by the stream-based {@code shap} overload.
     */
    public double[] shap(DataFrame dataFrame) {
        this.formula.bind(dataFrame.schema());
        // The raw cast is left as decompiled; it selects the Stream overload.
        return shap((Stream) dataFrame.stream().parallel());
    }

    /**
     * Computes the SHAP values of a single instance.
     * The result has {@code p * k} entries, where {@code p} is the number of
     * predictors. Tree contributions are summed and then divided by the number
     * of boosting iterations. For multi-class models the value of feature
     * {@code j} for class {@code c} is stored at index {@code j * k + c}.
     *
     * @param tuple the instance to explain.
     * @return the averaged SHAP values.
     */
    @Override // smile.feature.SHAP
    public double[] shap(Tuple tuple) {
        final Tuple x = this.formula.x(tuple);
        final int p = x.length();
        final int size = this.k * p;
        final double[] phi = new double[size];
        final int ntrees;

        if (this.trees != null) {
            ntrees = this.trees.length;
            for (RegressionTree tree : this.trees) {
                final double[] contribution = tree.shap(x);
                // NOTE(review): reads k * p entries from the tree's SHAP array;
                // assumes RegressionTree.shap returns at least that many - TODO confirm.
                for (int j = 0; j < size; j++) {
                    phi[j] += contribution[j];
                }
            }
        } else {
            ntrees = this.forest[0].length;
            for (int c = 0; c < this.k; c++) {
                for (RegressionTree tree : this.forest[c]) {
                    final double[] contribution = tree.shap(x);
                    for (int j = 0; j < p; j++) {
                        phi[(this.k * j) + c] += contribution[j];
                    }
                }
            }
        }

        for (int j = 0; j < size; j++) {
            phi[j] /= ntrees;
        }
        return phi;
    }

    /**
     * Returns the number of trees reported by {@code trees()}.
     * For a multi-class model this is the total over all classes, because
     * {@code trees()} flattens the per-class arrays.
     *
     * @return the number of trees.
     */
    public int size() {
        return trees().length;
    }

    /**
     * Produces staged predictions on a data set.
     * Entry {@code [t][row]} is the class index predicted for {@code row} by
     * the model truncated to its first {@code t + 1} trees (per class, in the
     * multi-class case).
     *
     * <p>The returned values are raw class indices; unlike {@code predict},
     * they are not mapped through {@code labels.valueOf}.
     *
     * @param dataFrame the data to predict; the predictors are extracted with the model formula.
     * @return the staged class indices, indexed {@code [tree][row]}.
     */
    public int[][] test(DataFrame dataFrame) {
        final DataFrame x = this.formula.x(dataFrame);
        final int n = x.nrows();
        final int ntrees = this.trees != null ? this.trees.length : this.forest[0].length;
        final int[][] prediction = new int[ntrees][n];

        if (this.k == 2) {
            for (int row = 0; row < n; row++) {
                final Tuple xi = x.get(row);
                // Start from the intercept so that every stage agrees with
                // what predict() returns for a model of that size.
                double score = this.b;
                for (int t = 0; t < ntrees; t++) {
                    score += this.shrinkage * this.trees[t].predict(xi);
                    prediction[t][row] = score > 0.0d ? 1 : 0;
                }
            }
        } else {
            final double[] score = new double[this.k];
            for (int row = 0; row < n; row++) {
                final Tuple xi = x.get(row);
                Arrays.fill(score, 0.0d);
                for (int t = 0; t < ntrees; t++) {
                    for (int c = 0; c < this.k; c++) {
                        score[c] += this.shrinkage * this.forest[c][t].predict(xi);
                    }
                    prediction[t][row] = MathEx.whichMax(score);
                }
            }
        }

        return prediction;
    }

    /**
     * Returns the regression trees of the model.
     * A binary model returns its internal tree array directly. A multi-class
     * model returns a newly allocated array holding the trees of all classes,
     * concatenated class by class.
     *
     * @return the regression trees.
     */
    public RegressionTree[] trees() {
        if (this.trees != null) {
            return this.trees;
        }
        return Arrays.stream(this.forest)
                .flatMap(Arrays::stream)
                .toArray(RegressionTree[]::new);
    }

    /**
     * Shrinks the model to its first {@code i} trees (per class, in the
     * multi-class case). Only the tree arrays are replaced; no other field is
     * touched. Nothing happens when {@code i} equals the current size.
     *
     * @param i the new number of trees; must be between 1 and the current size.
     * @throws IllegalArgumentException if {@code i} is smaller than 1 or larger
     *         than the current number of trees.
     */
    public void trim(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Invalid new model size: " + i);
        }

        if (this.k == 2) {
            final int current = this.trees.length;
            if (i > current) {
                throw new IllegalArgumentException("The new model size is larger than the current size.");
            }
            if (i < current) {
                this.trees = Arrays.copyOf(this.trees, i);
            }
            return;
        }

        final int current = this.forest[0].length;
        if (i > current) {
            throw new IllegalArgumentException("The new model size is larger than the current one.");
        }
        if (i == current) {
            return;
        }

        // Truncate every class's tree array to the same length.
        for (int c = 0; c < this.forest.length; c++) {
            this.forest[c] = Arrays.copyOf(this.forest[c], i);
        }
    }
}
