package smile.base.cart;

import java.util.Arrays;
import java.util.function.IntToDoubleFunction;
import java.util.function.IntUnaryOperator;
import smile.base.cart.Loss;
import smile.math.MathEx;
import smile.sort.QuickSelect;

/* loaded from: classes5.dex */
public interface Loss {

    /**
     * Quantile regression loss for percentile {@code val$p} (decompiled from the anonymous
     * class {@code Loss$3}).
     *
     * <p>The intercept and each node output are the element of rank {@code (int) (n * p)} among
     * the relevant residuals, as chosen by {@code QuickSelect.select}. The pseudo-response is
     * {@code signum(residual)}.
     *
     * <p>NOTE(review): the pseudo-response is the same as for least absolute deviation; the
     * percentile only enters through the intercept and {@link #output}. Confirm that this is
     * intended rather than the asymmetric quantile gradient.
     */
    public static class AnonymousClass3 implements Loss {
        /** Residuals {@code y - intercept}, set by {@link #intercept}; presumably updated by the caller afterwards. */
        double[] residual;
        /** Scratch buffer: a copy of y during {@link #intercept}, then the signs of the residuals. */
        double[] response;
        /** Target percentile; {@code Loss.quantile} only accepts values in (0, 1). */
        final /* synthetic */ double val$p;

        AnonymousClass3(double d) {
            this.val$p = d;
        }

        /**
         * Returns the element of rank {@code (int) (n * p)} chosen by {@code QuickSelect.select}
         * on a copy of {@code dArr}, and stores {@code dArr[i] - intercept} as the residuals.
         * {@code dArr} itself is only read.
         */
        @Override
        public double intercept(double[] dArr) {
            int length = dArr.length;
            double[] dArr2 = new double[length];
            this.response = dArr2;
            this.residual = new double[length];
            // Select on a copy so the caller's array is never handed to QuickSelect.
            System.arraycopy(dArr, 0, dArr2, 0, length);
            double select = QuickSelect.select(this.response, (int) (length * this.val$p));
            for (int i = 0; i < length; i++) {
                this.residual[i] = dArr[i] - select;
            }
            return select;
        }

        /** Returns {@code residual[i]}; accessor retained from the decompiled lambda. */
        public /* synthetic */ double m9948lambda$output$0$smilebasecartLoss$3(int i) {
            return this.residual[i];
        }

        /**
         * Returns the element of rank {@code (int) (m * p)} among the residuals of the node's
         * {@code m} samples. Sample counts ({@code iArr2}) are not used.
         */
        @Override
        public double output(int[] iArr, int[] iArr2) {
            double[] nodeResiduals = Arrays.stream(iArr)
                    .mapToDouble(this::m9948lambda$output$0$smilebasecartLoss$3)
                    .toArray();
            // The rank is relative to the node's own residuals, not to the whole training set.
            return QuickSelect.select(nodeResiduals, (int) (nodeResiduals.length * this.val$p));
        }

        /** Returns the internal residual array itself (not a copy). */
        @Override
        public double[] residual() {
            return this.residual;
        }

        /** Overwrites the internal response buffer with {@code signum(residual)} and returns it. */
        @Override
        public double[] response() {
            int i = 0;
            while (true) {
                double[] dArr = this.residual;
                if (i >= dArr.length) {
                    return this.response;
                }
                this.response[i] = Math.signum(dArr[i]);
                i++;
            }
        }

        /**
         * Formats the percentile as a percentage using {@link java.util.Locale#ROOT}. This keeps
         * the decimal separator a '.' regardless of the default locale, so
         * {@code Double.parseDouble} can read the text back.
         */
        @Override
        public String toString() {
            return String.format(java.util.Locale.ROOT, "Quantile(%3.1f%%)", Double.valueOf(this.val$p * 100.0d));
        }
    }

    /**
     * Least absolute deviation loss (decompiled from the anonymous class {@code Loss$4}).
     *
     * <p>The intercept is the median of the responses, each node output is the median of the
     * node's residuals, and the pseudo-response is the sign of every residual.
     */
    public static class AnonymousClass4 implements Loss {
        /** Residuals {@code y - median(y)}, set by {@link #intercept}; presumably updated by the caller afterwards. */
        double[] residual;
        /** Scratch buffer: a copy of y during {@link #intercept}, then the signs of the residuals. */
        double[] response;

        AnonymousClass4() {
        }

        /**
         * Returns the median of {@code dArr}, computed on a private copy, and stores
         * {@code dArr[j] - median} as the residuals.
         */
        @Override
        public double intercept(double[] dArr) {
            final int n = dArr.length;
            // The copy is what QuickSelect works on; the caller's array is only read.
            this.response = dArr.clone();
            this.residual = new double[n];
            final double center = QuickSelect.median(this.response);
            for (int j = 0; j < n; j++) {
                this.residual[j] = dArr[j] - center;
            }
            return center;
        }

        /** Returns {@code residual[i]}; accessor retained from the decompiled lambda. */
        public /* synthetic */ double m9949lambda$output$0$smilebasecartLoss$4(int i) {
            return this.residual[i];
        }

        /** Returns the median residual of the node's samples; sample counts are not used. */
        @Override
        public double output(int[] iArr, int[] iArr2) {
            final double[] nodeResiduals = new double[iArr.length];
            for (int j = 0; j < iArr.length; j++) {
                nodeResiduals[j] = m9949lambda$output$0$smilebasecartLoss$4(iArr[j]);
            }
            return QuickSelect.median(nodeResiduals);
        }

        /** Returns the internal residual array itself (not a copy). */
        @Override
        public double[] residual() {
            return this.residual;
        }

        /** Overwrites the internal response buffer with {@code signum(residual)} and returns it. */
        @Override
        public double[] response() {
            final double[] current = this.residual;
            for (int j = 0; j < current.length; j++) {
                this.response[j] = Math.signum(current[j]);
            }
            return this.response;
        }

        public String toString() {
            return "LeastAbsoluteDeviation";
        }
    }

    /**
     * Huber loss (decompiled from the anonymous class {@code Loss$5}).
     *
     * <p>The intercept is the median of the responses. Each {@link #response()} call
     * re-estimates the threshold {@code delta} as the element of rank {@code (int) (n * p)}
     * among the absolute residuals. Residuals no larger than {@code delta} in magnitude pass
     * through; larger ones are clipped to {@code delta * signum(residual)}.
     */
    public static class AnonymousClass5 implements Loss {
        /** Clipping threshold; recomputed by {@link #response()}, read by {@link #output}. It is 0 until response() runs. */
        private double delta;
        /** Residuals {@code y - median(y)}, set by {@link #intercept}; presumably updated by the caller afterwards. */
        double[] residual;
        /** Scratch buffer reused for the copy of y, the absolute residuals, and the clipped residuals. */
        double[] response;
        /** Percentile used to pick {@code delta}; {@code Loss.huber} only accepts values in (0, 1). */
        final /* synthetic */ double val$p;

        AnonymousClass5(double d) {
            this.val$p = d;
        }

        /**
         * Returns the median of {@code dArr}, computed on a copy, and stores
         * {@code dArr[i] - median} as the residuals. {@code dArr} itself is only read.
         */
        @Override
        public double intercept(double[] dArr) {
            int length = dArr.length;
            double[] dArr2 = new double[length];
            this.response = dArr2;
            this.residual = new double[length];
            // Take the median of a copy so the caller's array is never handed to QuickSelect.
            System.arraycopy(dArr, 0, dArr2, 0, length);
            double median = QuickSelect.median(this.response);
            for (int i = 0; i < length; i++) {
                this.residual[i] = dArr[i] - median;
            }
            return median;
        }

        /** Returns {@code residual[i]}; accessor retained from the decompiled lambda. */
        public /* synthetic */ double m9950lambda$output$0$smilebasecartLoss$5(int i) {
            return this.residual[i];
        }

        /**
         * Returns the median residual of the node's samples plus the mean of the deviations from
         * that median, each clipped to magnitude {@code delta}. Sample counts ({@code iArr2}) are
         * not used.
         */
        @Override
        public double output(int[] iArr, int[] iArr2) {
            double median = QuickSelect.median(Arrays.stream(iArr)
                    .mapToDouble(this::m9950lambda$output$0$smilebasecartLoss$5)
                    .toArray());
            double d = 0.0d;
            for (int i : iArr) {
                double d2 = this.residual[i] - median;
                d += Math.signum(d2) * Math.min(this.delta, Math.abs(d2));
            }
            return median + (d / iArr.length);
        }

        /** Returns the internal residual array itself (not a copy). */
        @Override
        public double[] residual() {
            return this.residual;
        }

        /** Re-estimates {@code delta}, then overwrites and returns the clipped residuals. */
        @Override
        public double[] response() {
            int length = this.residual.length;
            // Absolute residuals go into the buffer first, so delta can be selected from it.
            for (int i = 0; i < length; i++) {
                this.response[i] = Math.abs(this.residual[i]);
            }
            this.delta = QuickSelect.select(this.response, (int) (length * this.val$p));
            for (int i2 = 0; i2 < length; i2++) {
                double abs = Math.abs(this.residual[i2]);
                double d = this.delta;
                if (abs <= d) {
                    this.response[i2] = this.residual[i2];
                } else {
                    this.response[i2] = d * Math.signum(this.residual[i2]);
                }
            }
            return this.response;
        }

        /**
         * Formats the percentile as a percentage using {@link java.util.Locale#ROOT}. This keeps
         * the decimal separator a '.' regardless of the default locale, so
         * {@code Double.parseDouble} can read the text back.
         */
        @Override
        public String toString() {
            return String.format(java.util.Locale.ROOT, "Huber(%3.1f%%)", Double.valueOf(this.val$p * 100.0d));
        }
    }

    /**
     * Binary logistic loss (decompiled from the anonymous class {@code Loss$6}).
     *
     * <p>Labels are re-encoded as {@code 2 * label - 1}; this presumably maps {0, 1} to
     * {-1, +1} (confirm labels are 0/1 with callers). Unlike the regression losses, the
     * {@code residual} array holds a current score F. It is filled with the intercept and then
     * used as F in the pseudo-response {@code 2y / (1 + exp(2yF))}.
     */
    static class AnonymousClass6 implements Loss {
        /** Current scores F, filled with the intercept by {@link #intercept}; presumably updated by the caller. */
        double[] residual;
        /** Pseudo-response buffer overwritten by {@link #response()}. */
        double[] response;
        /** The original labels. */
        final /* synthetic */ int[] val$labels;
        /** Number of samples (length of the label array). */
        final /* synthetic */ int val$n;
        /** Labels encoded as {@code 2 * label - 1}. */
        int[] y;

        AnonymousClass6(int[] iArr, int i) {
            this.val$labels = iArr;
            this.val$n = i;
            this.y = Arrays.stream(iArr).map(AnonymousClass6::lambda$$0).toArray();
            this.response = new double[i];
            this.residual = new double[i];
        }

        /** Maps a label to {@code 2 * label - 1}. */
        public static /* synthetic */ int lambda$$0(int i) {
            return (i * 2) - 1;
        }

        /**
         * Returns {@code 0.5 * log((1 + m) / (1 - m))}, where {@code m} is the mean encoded
         * label, and fills every score with it. The argument is ignored.
         *
         * <p>NOTE(review): if all labels are equal, {@code m} is +/-1 and the result is infinite.
         */
        @Override
        public double intercept(double[] dArr) {
            double mean = MathEx.mean(this.y);
            double log = Math.log((mean + 1.0d) / (1.0d - mean)) * 0.5d;
            Arrays.fill(this.residual, log);
            return log;
        }

        /**
         * Returns {@code sum(r) / sum(|r| * (2 - |r|))} over the node's pseudo-responses
         * {@code r}. Sample counts ({@code iArr2}) are not used.
         *
         * <p>NOTE(review): unlike the multi-class variant there is no guard here. A zero
         * denominator yields NaN or infinity.
         */
        @Override
        public double output(int[] iArr, int[] iArr2) {
            double d = 0.0d;
            double d2 = 0.0d;
            for (int i : iArr) {
                double abs = Math.abs(this.response[i]);
                d += this.response[i];
                d2 += abs * (2.0d - abs);
            }
            return d / d2;
        }

        /** Returns the internal score array itself (not a copy). */
        @Override
        public double[] residual() {
            return this.residual;
        }

        /** Overwrites and returns {@code response[i] = 2 * y[i] / (1 + exp(2 * y[i] * F[i]))}. */
        @Override
        public double[] response() {
            for (int i = 0; i < this.val$n; i++) {
                // The decompiled form referenced an undefined temporary here; it is y[i].
                int yi = this.y[i];
                this.response[i] = (yi * 2.0d) / (Math.exp((yi * 2) * this.residual[i]) + 1.0d);
            }
            return this.response;
        }

        @Override
        public String toString() {
            return "Logistic";
        }
    }

    /**
     * Per-class component of the multi-class logistic loss (decompiled from the anonymous
     * class {@code Loss$7}).
     *
     * <p>For class {@code val$c}, each sample's target is 1 if its label equals the class and 0
     * otherwise. The pseudo-response is that target minus {@code val$p[i][val$c]}.
     * {@link #intercept} is not supported.
     */
    static class AnonymousClass7 implements Loss {
        /** Allocated with one slot per sample; not written by this class. */
        double[] residual;
        /** Pseudo-response buffer overwritten by {@link #response()}. */
        double[] response;
        /** The class index this component targets. */
        final /* synthetic */ int val$c;
        /** Presumably the number of classes; used in the factor {@code (k - 1) / k} and in toString(). */
        final /* synthetic */ int val$k;
        /** The original labels. */
        final /* synthetic */ int[] val$labels;
        /** Number of samples. */
        final /* synthetic */ int val$n;
        /** Matrix read as {@code val$p[sample][class]}; presumably the current class probabilities. */
        final /* synthetic */ double[][] val$p;
        /** 0/1 indicator of {@code label == val$c}. */
        int[] y;

        AnonymousClass7(int[] iArr, final int i, int i2, int i3, double[][] dArr) {
            this.val$labels = iArr;
            this.val$c = i;
            this.val$n = i2;
            this.val$k = i3;
            this.val$p = dArr;
            final int[] indicator = new int[iArr.length];
            for (int j = 0; j < iArr.length; j++) {
                indicator[j] = lambda$$0(i, iArr[j]);
            }
            this.y = indicator;
            this.response = new double[i2];
            this.residual = new double[i2];
        }

        /** Returns 1 when {@code i2} equals the target class {@code i}, otherwise 0. */
        public static /* synthetic */ int lambda$$0(int i, int i2) {
            return i2 == i ? 1 : 0;
        }

        /** Not supported for this loss. */
        @Override
        public double intercept(double[] dArr) {
            throw new IllegalStateException("This method should not be called.");
        }

        /**
         * Returns {@code ((k - 1) / k) * sum(r) / sum(|r| * (1 - |r|))} over the node's
         * pseudo-responses {@code r}. When the denominator is below 1e-10 it returns the mean of
         * {@code r} instead. Sample counts ({@code iArr2}) are not used.
         */
        @Override
        public double output(int[] iArr, int[] iArr2) {
            double denominator = 0.0d;
            double numerator = 0.0d;
            for (int idx : iArr) {
                final double g = this.response[idx];
                final double magnitude = Math.abs(g);
                numerator += g;
                denominator += magnitude * (1.0d - magnitude);
            }
            // A near-zero denominator would blow up the ratio; fall back to the plain mean.
            if (denominator < 1.0E-10d) {
                return numerator / iArr.length;
            }
            final double classes = this.val$k;
            return ((classes - 1.0d) / classes) * (numerator / denominator);
        }

        /** Returns the internal residual array itself (not a copy). */
        @Override
        public double[] residual() {
            return this.residual;
        }

        /** Overwrites and returns {@code response[j] = y[j] - p[j][c]} for every sample. */
        @Override
        public double[] response() {
            final int target = this.val$c;
            for (int j = 0; j < this.val$n; j++) {
                this.response[j] = this.y[j] - this.val$p[j][target];
            }
            return this.response;
        }

        public String toString() {
            return String.format("Logistic(%d)", Integer.valueOf(this.val$k));
        }
    }

    /**
     * Identifiers of the regression losses.
     *
     * <p>NOTE(review): this enum is not referenced in the code shown here. Its constant names
     * correspond to the loss names that {@code Loss.valueOf} recognizes.
     */
    public enum Type {
        /** Least squares loss. */
        LeastSquares,
        /** Quantile loss. */
        Quantile,
        /** Least absolute deviation loss. */
        LeastAbsoluteDeviation,
        /** Huber loss. */
        Huber;
    }

    /**
     * Returns a Huber loss whose clipping threshold is re-estimated from the absolute residuals
     * at percentile {@code d}.
     *
     * @param d the percentile, strictly between 0 and 1.
     * @return a new Huber loss.
     * @throws IllegalArgumentException if {@code d} is not in (0, 1), including when it is NaN.
     */
    static Loss huber(double d) {
        // Negated range test: every comparison with NaN is false, so NaN is rejected as well.
        if (!(d > 0.0d && d < 1.0d)) {
            throw new IllegalArgumentException("Invalid percentile: " + d);
        }
        return new AnonymousClass5(d);
    }

    /**
     * Returns a new least absolute deviation loss, whose intercept and node outputs are medians.
     *
     * @return a new LAD loss.
     */
    static Loss lad() {
        final Loss absoluteDeviation = new AnonymousClass4();
        return absoluteDeviation;
    }

    /**
     * Returns the component of the multi-class logistic loss for one class.
     *
     * @param i the target class index; a sample's target is 1 when its label equals it.
     * @param i2 presumably the number of classes; used in the factor {@code (k - 1) / k}.
     * @param iArr the sample labels; its length fixes the number of samples.
     * @param dArr a matrix read as {@code dArr[sample][i]}; presumably the current class
     *     probabilities, kept by reference.
     * @return the per-class loss.
     */
    static Loss logistic(int i, int i2, int[] iArr, double[][] dArr) {
        final int sampleCount = iArr.length;
        final Loss component = new AnonymousClass7(iArr, i, sampleCount, i2, dArr);
        return component;
    }

    /**
     * Returns the binary logistic loss for the given labels.
     *
     * @param iArr the labels, one per sample; encoded internally as {@code 2 * label - 1}
     *     (presumably 0/1 labels — confirm with callers).
     * @return a new binary logistic loss.
     */
    static Loss logistic(int[] iArr) {
        final int sampleCount = iArr.length;
        return new AnonymousClass6(iArr, sampleCount);
    }

    /**
     * Returns a least squares loss. The intercept is the mean response, and each node output
     * is the sample-count-weighted mean residual of the node's samples.
     *
     * @return a new least squares loss.
     */
    static Loss ls() {
        return new Loss() {
            /** {@code y - mean(y)} after intercept(); also returned as the pseudo-response. */
            double[] residual;

            @Override
            public double intercept(double[] dArr) {
                final int count = dArr.length;
                residual = new double[count];
                final double average = MathEx.mean(dArr);
                int j = 0;
                while (j < count) {
                    residual[j] = dArr[j] - average;
                    j++;
                }
                return average;
            }

            @Override
            public double output(int[] iArr, int[] iArr2) {
                double weightedSum = 0.0d;
                int totalCount = 0;
                for (int k = 0; k < iArr.length; k++) {
                    final int sample = iArr[k];
                    final int copies = iArr2[sample];
                    totalCount += copies;
                    weightedSum += residual[sample] * copies;
                }
                return weightedSum / totalCount;
            }

            @Override
            public double[] residual() {
                return residual;
            }

            @Override
            public double[] response() {
                // For squared error the pseudo-response is the residual array itself.
                return residual;
            }

            public String toString() {
                return "LeastSquares";
            }
        };
    }

    /**
     * Returns a least squares loss over precomputed residuals.
     *
     * <p>{@code dArr} is used directly (not copied) as the pseudo-response. Each node output is
     * the sample-count-weighted mean of {@code dArr} over the node's samples.
     * {@code intercept} and {@code residual} are not supported and throw
     * {@link IllegalStateException}.
     *
     * @param dArr the values to fit; the array is aliased, not copied.
     * @return the loss.
     */
    static Loss ls(double[] dArr) {
        // An anonymous implementation of an interface cannot take constructor arguments, so the
        // (effectively final) parameter is captured instead.
        return new Loss() {
            /** The caller's array, returned as-is by response(). */
            final double[] residual = dArr;

            @Override
            public double intercept(double[] dArr2) {
                throw new IllegalStateException("This method should not be called.");
            }

            @Override
            public double output(int[] iArr, int[] iArr2) {
                double weightedSum = 0.0d;
                int totalCount = 0;
                for (int sample : iArr) {
                    int count = iArr2[sample];
                    totalCount += count;
                    weightedSum += this.residual[sample] * count;
                }
                return weightedSum / totalCount;
            }

            @Override
            public double[] residual() {
                throw new IllegalStateException("This method should not be called.");
            }

            @Override
            public double[] response() {
                return this.residual;
            }

            @Override
            public String toString() {
                return "LeastSquares";
            }
        };
    }

    /**
     * Returns a quantile regression loss for percentile {@code d}.
     *
     * @param d the target percentile, strictly between 0 and 1.
     * @return a new quantile loss.
     * @throws IllegalArgumentException if {@code d} is not in (0, 1), including when it is NaN.
     */
    static Loss quantile(double d) {
        // Negated range test: every comparison with NaN is false, so NaN is rejected as well.
        if (!(d > 0.0d && d < 1.0d)) {
            throw new IllegalArgumentException("Invalid percentile: " + d);
        }
        return new AnonymousClass3(d);
    }

    /**
     * Parses a loss from its string form.
     *
     * <p>Accepted forms are {@code "LeastSquares"}, {@code "LeastAbsoluteDeviation"},
     * {@code "Quantile(p)"} and {@code "Huber(p)"}. The percentile {@code p} may be written as a
     * fraction ({@code "Quantile(0.5)"}) or as a percentage with a trailing {@code %}
     * ({@code "Quantile(50.0%)"}). The percentage form is what the quantile and Huber losses'
     * {@code toString()} produce, so their output parses back.
     *
     * @param str the loss name.
     * @return the loss.
     * @throws NullPointerException if {@code str} is null.
     * @throws IllegalArgumentException if the name is unsupported, the percentile is not a
     *     number (thrown as {@link NumberFormatException}), or it is outside (0, 1).
     */
    static Loss valueOf(String str) {
        if (str.equals("LeastAbsoluteDeviation")) {
            return lad();
        }
        if (str.equals("LeastSquares")) {
            return ls();
        }
        boolean isQuantile = str.startsWith("Quantile(");
        boolean isHuber = str.startsWith("Huber(");
        if ((isQuantile || isHuber) && str.endsWith(")")) {
            String arg = str.substring(isQuantile ? 9 : 6, str.length() - 1).trim();
            double p;
            if (arg.endsWith("%")) {
                // toString() prints the percentile as a percentage, e.g. "Quantile(50.0%)".
                p = Double.parseDouble(arg.substring(0, arg.length() - 1)) / 100.0d;
            } else {
                p = Double.parseDouble(arg);
            }
            return isQuantile ? quantile(p) : huber(p);
        }
        throw new IllegalArgumentException("Unsupported loss: " + str);
    }

    /**
     * Computes the initial constant prediction (intercept) from the training responses and
     * (re)initializes the loss's internal residual state.
     *
     * <p>NOTE(review): {@code ls(double[])} and the multi-class {@code logistic} loss throw
     * {@link IllegalStateException} from this method. The binary logistic loss ignores the
     * argument.
     *
     * @param dArr the training responses.
     * @return the intercept.
     */
    double intercept(double[] dArr);

    /**
     * Computes the output value of a tree node from the samples that fall into it.
     *
     * @param iArr the indices of the samples in the node; used to index the residual and
     *     response arrays.
     * @param iArr2 per-sample counts indexed by sample index. Among the implementations in this
     *     file, only the least squares losses read it, as weights.
     * @return the node output.
     */
    double output(int[] iArr, int[] iArr2);

    /**
     * Returns the internal residual array (a live reference, not a copy).
     *
     * <p>NOTE(review): {@code ls(double[])} throws {@link IllegalStateException} here.
     *
     * @return the residuals.
     */
    double[] residual();

    /**
     * Computes and returns the pseudo-response, presumably the target the next tree is fit to.
     * Implementations return an internal array (not a copy), which most of them overwrite on
     * each call.
     *
     * @return the pseudo-response.
     */
    double[] response();
}
