package smile.base.mlp;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import smile.math.TimeFunction;

/* loaded from: classes5.dex */
/**
 * Abstract base of a multilayer perceptron built from a chain of {@link Layer}s that ends in an
 * {@link OutputLayer}.
 *
 * <p>This class validates and wires the layers together and implements the shared training
 * machinery: forward propagation, backpropagation of the output gradient, and weight updates
 * driven by a learning-rate schedule, a momentum schedule, L2 weight decay and RMSProp
 * parameters. The per-thread {@link #target} buffer is read by
 * {@link #backpropagate(double[], boolean)}; it is presumably filled by subclasses before each
 * call (TODO confirm against subclasses).
 */
public abstract class MultilayerPerceptron implements Serializable {
    private static final long serialVersionUID = 2;

    /** All layers except the output layer, in feed-forward order (index 0 receives the input). */
    protected Layer[] net;
    /** The last layer of the network. */
    protected OutputLayer output;
    /** Input dimension, i.e. the input size of the first layer. */
    protected int p;
    /**
     * Per-thread buffer sized to the output layer, consumed by {@code backpropagate}.
     * Transient because it is rebuilt by {@code init()} after deserialization.
     */
    protected transient ThreadLocal<double[]> target;
    /** Learning-rate schedule, evaluated at time {@link #t}. */
    protected TimeFunction learningRate = TimeFunction.constant(0.01d);
    /** Momentum schedule, evaluated at time {@link #t}; values must lie in [0, 1). */
    protected TimeFunction momentum = TimeFunction.constant(0.0d);
    /** RMSProp decay rate passed to {@code Layer.update}; set through {@link #setRMSProp}. */
    protected double rho = 0.0d;
    /** RMSProp epsilon passed to {@code Layer.update}; set through {@link #setRMSProp}. */
    protected double epsilon = 1.0E-7d;
    /** L2 regularization factor; turned into the multiplicative decay {@code 1 - 2 * eta * lambda}. */
    protected double lambda = 0.0d;
    /**
     * Time step fed to the learning-rate and momentum schedules. Not advanced in this class;
     * presumably incremented by subclasses during training.
     */
    protected int t = 0;

    /**
     * Builds a network from layers given in feed-forward order.
     *
     * @param layers at least two layers; the output size of each layer must equal the input size
     *     of the next one, and the last layer must be an {@link OutputLayer}.
     * @throws IllegalArgumentException if fewer than two layers are given, adjacent layer sizes
     *     do not match, or the last layer is not an {@link OutputLayer}.
     */
    public MultilayerPerceptron(Layer... layers) {
        if (layers.length < 2) {
            throw new IllegalArgumentException("Too few layers: " + layers.length);
        }

        for (int i = 1; i < layers.length; i++) {
            Layer lower = layers[i - 1];
            Layer upper = layers[i];
            if (upper.getInputSize() != lower.getOutputSize()) {
                throw new IllegalArgumentException(String.format(
                        "Invalid network architecture. Layer %d has %d neurons while layer %d takes %d inputs",
                        i - 1, lower.getOutputSize(), i, upper.getInputSize()));
            }
        }

        Layer last = layers[layers.length - 1];
        if (!(last instanceof OutputLayer)) {
            // Fail with a descriptive message instead of an opaque ClassCastException.
            throw new IllegalArgumentException("The last layer must be an OutputLayer: " + last);
        }

        this.output = (OutputLayer) last;
        this.net = Arrays.copyOf(layers, layers.length - 1);
        this.p = layers[0].getInputSize();
        init();
    }

    /** (Re)creates the transient per-thread target buffer, sized to the output layer. */
    private void init() {
        this.target = new ThreadLocal<double[]>() {
            // Called once per thread by ThreadLocal.get(); no extra synchronization is needed.
            @Override
            protected double[] initialValue() {
                return new double[output.getOutputSize()];
            }
        };
    }

    /** Restores the non-serialized per-thread buffer after default deserialization. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        init();
    }

    /**
     * Propagates the output-layer gradient back through the network, then either accumulates
     * plain gradients or computes gradient updates.
     *
     * @param x the network input; assumed to be the same vector last passed to
     *     {@link #propagate(double[])} (TODO confirm with callers).
     * @param update if {@code true}, call {@code computeGradientUpdate} with the current
     *     learning rate, momentum and weight decay; otherwise call {@code computeGradient}.
     * @throws IllegalArgumentException if {@code update} is set and the learning rate or
     *     momentum at time {@link #t} is out of range (including NaN).
     * @throws IllegalStateException if {@code update} is set and the resulting weight decay is
     *     below 0.9 or NaN.
     */
    public void backpropagate(double[] x, boolean update) {
        output.computeOutputGradient(target.get(), 1.0d);

        // Walk from the output layer down, feeding each lower layer's gradient buffer.
        Layer upper = output;
        for (int i = net.length - 1; i >= 0; i--) {
            upper.backpropagate(net[i].gradient());
            upper = net[i];
        }
        // The bottom layer has no lower layer to receive its gradient.
        upper.backpropagate(null);

        double[] input = x;
        if (update) {
            double eta = checkedLearningRate();
            double alpha = checkedMomentum();
            double decay = checkedDecay(eta);

            for (Layer layer : net) {
                layer.computeGradientUpdate(input, eta, alpha, decay);
                input = layer.output();
            }
            output.computeGradientUpdate(input, eta, alpha, decay);
        } else {
            for (Layer layer : net) {
                layer.computeGradient(input);
                input = layer.output();
            }
            output.computeGradient(input);
        }
    }

    /** Returns the learning rate at the current time step {@link #t}. */
    public double getLearningRate() {
        return this.learningRate.apply(this.t);
    }

    /** Returns the momentum factor at the current time step {@link #t}. */
    public double getMomentum() {
        return this.momentum.apply(this.t);
    }

    /** Returns the L2 regularization factor {@code lambda}. */
    public double getWeightDecay() {
        return this.lambda;
    }

    /**
     * Runs a forward pass: each layer consumes the previous layer's output, starting with
     * {@code x}, and the output layer consumes the last hidden output.
     *
     * @param x the network input.
     */
    public void propagate(double[] x) {
        double[] input = x;
        for (Layer layer : net) {
            layer.propagate(input);
            input = layer.output();
        }
        output.propagate(input);
    }

    /**
     * Sets the learning-rate schedule.
     *
     * @param learningRate the schedule, evaluated at time {@link #t}; must not be null.
     * @throws NullPointerException if {@code learningRate} is null.
     */
    public void setLearningRate(TimeFunction learningRate) {
        this.learningRate = Objects.requireNonNull(learningRate, "learningRate");
    }

    /**
     * Sets the momentum schedule.
     *
     * @param momentum the schedule, evaluated at time {@link #t}; must not be null.
     * @throws NullPointerException if {@code momentum} is null.
     */
    public void setMomentum(TimeFunction momentum) {
        this.momentum = Objects.requireNonNull(momentum, "momentum");
    }

    /**
     * Sets the RMSProp parameters passed to every layer's {@code update}.
     *
     * @param rho decay rate in [0, 1).
     * @param epsilon strictly positive stabilizing constant.
     * @throws IllegalArgumentException if either value is out of range or NaN.
     */
    public void setRMSProp(double rho, double epsilon) {
        // Negated in-range tests so that NaN is rejected too.
        if (!(rho >= 0.0d && rho < 1.0d)) {
            throw new IllegalArgumentException("Invalid rho = " + rho);
        }
        if (!(epsilon > 0.0d)) {
            throw new IllegalArgumentException("Invalid epsilon = " + epsilon);
        }
        this.rho = rho;
        this.epsilon = epsilon;
    }

    /**
     * Sets the L2 regularization factor.
     *
     * @param lambda non-negative factor.
     * @throws IllegalArgumentException if {@code lambda} is negative or NaN.
     */
    public void setWeightDecay(double lambda) {
        if (!(lambda >= 0.0d)) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
        this.lambda = lambda;
    }

    @Override
    public String toString() {
        String hidden = Arrays.stream(net).map(Layer::toString).collect(Collectors.joining(" -> "));
        return String.format(
                "x(%d) -> %s -> %s(learning rate = %s, momentum = %s, weight decay = %.2f)",
                p, hidden, output, learningRate, momentum, lambda);
    }

    /**
     * Applies the accumulated updates to every layer.
     *
     * @param m forwarded unchanged to each layer's {@code update}; presumably the mini-batch
     *     size (TODO confirm in Layer).
     * @throws IllegalArgumentException if the learning rate or momentum at time {@link #t} is
     *     out of range (including NaN).
     * @throws IllegalStateException if the resulting weight decay is below 0.9 or NaN.
     */
    public void update(int m) {
        double eta = checkedLearningRate();
        double alpha = checkedMomentum();
        double decay = checkedDecay(eta);

        for (Layer layer : net) {
            layer.update(m, eta, alpha, decay, rho, epsilon);
        }
        output.update(m, eta, alpha, decay, rho, epsilon);
    }

    /** Returns the learning rate at time {@link #t}, rejecting non-positive values and NaN. */
    private double checkedLearningRate() {
        double eta = learningRate.apply(t);
        if (!(eta > 0.0d)) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
        return eta;
    }

    /** Returns the momentum at time {@link #t}, rejecting values outside [0, 1) and NaN. */
    private double checkedMomentum() {
        double alpha = momentum.apply(t);
        if (!(alpha >= 0.0d && alpha < 1.0d)) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
        return alpha;
    }

    /**
     * Computes the multiplicative weight decay {@code 1 - 2 * eta * lambda}. Values below 0.9,
     * or NaN (e.g. an infinite rate with zero lambda), would shrink or corrupt the weights too
     * aggressively, so they are rejected.
     */
    private double checkedDecay(double eta) {
        double decay = 1.0d - 2.0d * eta * lambda;
        if (!(decay >= 0.9d)) {
            throw new IllegalStateException(String.format(
                    "Invalid learning rate (eta = %.2f) and/or L2 regularization (lambda = %.2f) such that weight decay = %.2f",
                    eta, lambda, decay));
        }
        return decay;
    }
}
