package smile.base.mlp;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Arrays;
import smile.math.MathEx;
import smile.math.matrix.Matrix;

/* loaded from: classes5.dex */
/**
 * Base class for a fully connected layer of a multilayer perceptron.
 *
 * <p>The layer owns a weight matrix with {@code n} rows and {@code p} columns
 * and a bias vector. All working buffers (outputs, gradients, accumulated
 * gradients, previous updates) are held in {@link ThreadLocal}s, so each
 * thread calling into the layer gets its own scratch space. Those buffers are
 * {@code transient} and are re-created after deserialization.
 */
public abstract class Layer implements Serializable {
    private static final long serialVersionUID = 2L;

    /** Number of output units; taken from the row count of the weight matrix. */
    protected int n;
    /** Number of input variables; taken from the column count of the weight matrix. */
    protected int p;
    /** Weight matrix with {@code n} rows and {@code p} columns. */
    protected Matrix weight;
    /** Bias vector; the first {@code n} entries are read and updated. */
    protected double[] bias;

    /** Per-thread output buffer of length {@code n}. */
    protected transient ThreadLocal<double[]> output;
    /** Per-thread gradient with respect to the layer output, length {@code n}. */
    protected transient ThreadLocal<double[]> outputGradient;
    /** Per-thread accumulated weight gradient, {@code n x p}. */
    protected transient ThreadLocal<Matrix> weightGradient;
    /** Per-thread accumulated bias gradient, length {@code n}. */
    protected transient ThreadLocal<double[]> biasGradient;
    /** Per-thread running average of squared weight gradients, {@code n x p}. */
    protected transient ThreadLocal<Matrix> rmsWeightGradient;
    /** Per-thread running average of squared bias gradients, length {@code n}. */
    protected transient ThreadLocal<double[]> rmsBiasGradient;
    /** Per-thread previous weight update (used by the momentum term), {@code n x p}. */
    protected transient ThreadLocal<Matrix> weightUpdate;
    /** Per-thread previous bias update (used by the momentum term), length {@code n}. */
    protected transient ThreadLocal<double[]> biasUpdate;

    /**
     * Creates a layer with randomly initialized weights and a zero bias.
     *
     * <p>The weights are produced by
     * {@code Matrix.rand(n, p, -b, b)} with {@code b = sqrt(6 / (n + p))}.
     *
     * @param n the number of output units (rows of the weight matrix).
     * @param p the number of input variables (columns of the weight matrix).
     */
    public Layer(int n, int p) {
        this(randomWeights(n, p), new double[n]);
    }

    /**
     * Creates a layer from an existing weight matrix and bias vector. Both
     * arguments are stored by reference, not copied.
     *
     * @param weight the weight matrix; its dimensions define {@code n} and {@code p}.
     * @param bias the bias vector.
     */
    public Layer(Matrix weight, double[] bias) {
        this.n = weight.nrows();
        this.p = weight.ncols();
        this.weight = weight;
        this.bias = bias;
        init();
    }

    /** Builds the initial {@code n x p} weight matrix with bounds +/- sqrt(6 / (n + p)). */
    private static Matrix randomWeights(int n, int p) {
        double bound = Math.sqrt(6.0 / (n + p));
        return Matrix.rand(n, p, -bound, bound);
    }

    /**
     * Returns true when {@code low < value < high}. NaN never satisfies the
     * test, so a NaN hyper-parameter is uniformly treated as "disabled".
     */
    private static boolean isStrictlyBetween(double value, double low, double high) {
        return value > low && value < high;
    }

    /** Restores the serialized state and re-creates the transient per-thread buffers. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        init();
    }

    /** (Re)creates every per-thread working buffer. */
    private void init() {
        this.output = newVectorLocal();
        this.outputGradient = newVectorLocal();
        this.weightGradient = newMatrixLocal();
        this.biasGradient = newVectorLocal();
        this.rmsWeightGradient = newMatrixLocal();
        this.rmsBiasGradient = newVectorLocal();
        this.weightUpdate = newMatrixLocal();
        this.biasUpdate = newVectorLocal();
    }

    /** Creates a thread-local whose initial value is a fresh array of length {@code n}. */
    private ThreadLocal<double[]> newVectorLocal() {
        return new ThreadLocal<double[]>() {
            @Override
            protected double[] initialValue() {
                // n is read lazily, at the first get() of each thread.
                return new double[Layer.this.n];
            }
        };
    }

    /** Creates a thread-local whose initial value is a fresh {@code n x p} matrix. */
    private ThreadLocal<Matrix> newMatrixLocal() {
        return new ThreadLocal<Matrix>() {
            @Override
            protected Matrix initialValue() {
                // n and p are read lazily, at the first get() of each thread.
                return new Matrix(Layer.this.n, Layer.this.p);
            }
        };
    }

    /**
     * Returns a builder of a hidden layer with the linear activation function.
     *
     * @param n the value forwarded to the builder (presumably the number of neurons).
     * @return the hidden layer builder.
     */
    public static HiddenLayerBuilder linear(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.linear());
    }

    /**
     * Returns a builder of a hidden layer with the rectifier activation function.
     *
     * @param n the value forwarded to the builder (presumably the number of neurons).
     * @return the hidden layer builder.
     */
    public static HiddenLayerBuilder rectifier(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.rectifier());
    }

    /**
     * Returns a builder of a hidden layer with the sigmoid activation function.
     *
     * @param n the value forwarded to the builder (presumably the number of neurons).
     * @return the hidden layer builder.
     */
    public static HiddenLayerBuilder sigmoid(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.sigmoid());
    }

    /**
     * Returns a builder of a hidden layer with the tanh activation function.
     *
     * @param n the value forwarded to the builder (presumably the number of neurons).
     * @return the hidden layer builder.
     */
    public static HiddenLayerBuilder tanh(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.tanh());
    }

    /**
     * Returns a builder of an output layer that uses {@link Cost#MEAN_SQUARED_ERROR}.
     *
     * @param n the value forwarded to the builder (presumably the number of neurons).
     * @param f the output function.
     * @return the output layer builder.
     */
    public static OutputLayerBuilder mse(int n, OutputFunction f) {
        return new OutputLayerBuilder(n, f, Cost.MEAN_SQUARED_ERROR);
    }

    /**
     * Returns a builder of an output layer that uses {@link Cost#LIKELIHOOD}.
     *
     * @param n the value forwarded to the builder (presumably the number of neurons).
     * @param f the output function.
     * @return the output layer builder.
     */
    public static OutputLayerBuilder mle(int n, OutputFunction f) {
        return new OutputLayerBuilder(n, f, Cost.LIKELIHOOD);
    }

    /**
     * Returns the number of input variables.
     *
     * @return the column count of the weight matrix captured at construction.
     */
    public int getInputSize() {
        return this.p;
    }

    /**
     * Returns the number of output units.
     *
     * @return the row count of the weight matrix captured at construction.
     */
    public int getOutputSize() {
        return this.n;
    }

    /**
     * Returns the calling thread's output buffer.
     *
     * @return the live buffer (not a copy).
     */
    public double[] output() {
        return this.output.get();
    }

    /**
     * Returns the calling thread's output-gradient buffer.
     *
     * @return the live buffer (not a copy).
     */
    public double[] gradient() {
        return this.outputGradient.get();
    }

    /**
     * Applies the layer's function to the given vector. {@link #propagate}
     * calls it on the output buffer after the bias and the weighted input
     * have been combined, so implementations are expected to transform the
     * array in place.
     *
     * @param x the vector to transform.
     */
    public abstract void f(double[] x);

    /**
     * Back-propagation step implemented by subclasses.
     *
     * <p>NOTE(review): the contract is not visible in this file; presumably the
     * argument is the gradient buffer of the preceding layer that this layer
     * fills in - confirm against the subclasses.
     *
     * @param lowerLayerGradient the buffer handed to the subclass implementation.
     */
    public abstract void backpropagate(double[] lowerLayerGradient);

    /**
     * Computes the layer output for an input vector into the calling thread's
     * output buffer.
     *
     * @param x the input vector.
     */
    public void propagate(double[] x) {
        double[] out = this.output.get();
        // Start from the bias, then let mv() combine it with the weighted
        // input (presumably out = 1.0 * weight * x + 1.0 * out).
        System.arraycopy(this.bias, 0, out, 0, this.n);
        this.weight.mv(1.0, x, 1.0, out);
        f(out);
    }

    /**
     * Accumulates the gradient of a single sample into the calling thread's
     * weight and bias gradient buffers. The accumulated values are consumed
     * and reset by {@link #update}.
     *
     * @param x the input vector of the sample.
     */
    public void computeGradient(double[] x) {
        double[] delta = this.outputGradient.get();
        Matrix weightGrad = this.weightGradient.get();
        double[] biasGrad = this.biasGradient.get();

        // Presumably a rank-1 update: weightGrad += 1.0 * delta * x'.
        weightGrad.add(1.0, delta, x);
        for (int i = 0; i < this.n; i++) {
            biasGrad[i] += delta[i];
        }
    }

    /**
     * Computes the gradient of a single sample and immediately applies it to
     * the weights and bias.
     *
     * @param x the input vector of the sample.
     * @param learningRate the factor applied to the gradient.
     * @param momentum the factor applied to the previous update; only used
     *                 when strictly between 0 and 1.
     * @param decay the factor the weights are multiplied by after the update;
     *              only applied when strictly between 0.9 and 1.
     */
    public void computeGradientUpdate(double[] x, double learningRate, double momentum, double decay) {
        double[] delta = this.outputGradient.get();

        if (isStrictlyBetween(momentum, 0.0, 1.0)) {
            Matrix previousWeightStep = this.weightUpdate.get();
            double[] previousBiasStep = this.biasUpdate.get();

            // step = momentum * previous step + learningRate * gradient
            previousWeightStep.mul(momentum);
            previousWeightStep.add(learningRate, delta, x);
            this.weight.add(1.0, previousWeightStep);
            for (int i = 0; i < this.n; i++) {
                double step = (previousBiasStep[i] * momentum) + (delta[i] * learningRate);
                previousBiasStep[i] = step;
                this.bias[i] += step;
            }
        } else {
            this.weight.add(learningRate, delta, x);
            for (int i = 0; i < this.n; i++) {
                this.bias[i] += delta[i] * learningRate;
            }
        }

        if (isStrictlyBetween(decay, 0.9, 1.0)) {
            this.weight.mul(decay);
        }
    }

    /**
     * Applies the gradient accumulated by {@link #computeGradient} to the
     * weights and bias, then resets the accumulators to zero.
     *
     * @param m the number of samples the accumulated gradient is averaged over.
     * @param learningRate the factor applied to the (averaged) gradient.
     * @param momentum the factor applied to the previous update; only used
     *                 when strictly between 0 and 1.
     * @param decay the factor the weights are multiplied by after the update;
     *              only applied when strictly between 0.9 and 1.
     * @param rho the discount factor of the running average of squared
     *            gradients; gradient normalization is only performed when it
     *            is strictly between 0 and 1.
     * @param epsilon a constant added under the square root when normalizing
     *                the gradient.
     */
    public void update(int m, double learningRate, double momentum, double decay, double rho, double epsilon) {
        Matrix weightGrad = this.weightGradient.get();
        double[] biasGrad = this.biasGradient.get();
        double samples = m;

        double eta;
        if (isStrictlyBetween(rho, 0.0, 1.0)) {
            // The gradient is averaged explicitly before it is normalized,
            // so the learning rate is used as is.
            normalizeByRunningRms(weightGrad, biasGrad, samples, rho, epsilon);
            eta = learningRate;
        } else {
            // Rather than averaging the gradient, scale the learning rate down.
            eta = learningRate / samples;
        }

        if (isStrictlyBetween(momentum, 0.0, 1.0)) {
            Matrix previousWeightStep = this.weightUpdate.get();
            double[] previousBiasStep = this.biasUpdate.get();

            // Presumably previousWeightStep = momentum * previousWeightStep + eta * weightGrad.
            previousWeightStep.add(momentum, eta, weightGrad);
            for (int i = 0; i < this.n; i++) {
                previousBiasStep[i] = (previousBiasStep[i] * momentum) + (biasGrad[i] * eta);
            }
            this.weight.add(1.0, previousWeightStep);
            MathEx.add(this.bias, previousBiasStep);
        } else {
            this.weight.add(eta, weightGrad);
            for (int i = 0; i < this.n; i++) {
                this.bias[i] += biasGrad[i] * eta;
            }
        }

        if (isStrictlyBetween(decay, 0.9, 1.0)) {
            this.weight.mul(decay);
        }

        // Reset the accumulators for the next batch.
        weightGrad.fill(0.0);
        Arrays.fill(biasGrad, 0.0);
    }

    /**
     * Averages the accumulated gradient over the batch, folds its square into
     * the running averages kept in {@link #rmsWeightGradient} and
     * {@link #rmsBiasGradient}, and divides each gradient entry by
     * {@code sqrt(epsilon + running average)}.
     */
    private void normalizeByRunningRms(Matrix weightGrad, double[] biasGrad, double samples, double rho, double epsilon) {
        weightGrad.div(samples);
        for (int i = 0; i < this.n; i++) {
            biasGrad[i] /= samples;
        }

        Matrix rmsWeight = this.rmsWeightGradient.get();
        double[] rmsBias = this.rmsBiasGradient.get();
        double oneMinusRho = 1.0 - rho;

        for (int j = 0; j < this.p; j++) {
            for (int i = 0; i < this.n; i++) {
                rmsWeight.set(i, j, (rmsWeight.get(i, j) * rho) + (MathEx.sqr(weightGrad.get(i, j)) * oneMinusRho));
                weightGrad.div(i, j, Math.sqrt(epsilon + rmsWeight.get(i, j)));
            }
        }

        for (int i = 0; i < this.n; i++) {
            rmsBias[i] = (rmsBias[i] * rho) + (MathEx.sqr(biasGrad[i]) * oneMinusRho);
            biasGrad[i] /= Math.sqrt(epsilon + rmsBias[i]);
        }
    }
}
