package smile.base.mlp;

/**
 * An output layer that pairs an {@link OutputFunction} with the {@link Cost} function used to
 * compute its error gradient.
 *
 * <p>Not every output function / cost combination is accepted; see
 * {@link #OutputLayer(int, int, OutputFunction, Cost)}.
 */
public class OutputLayer extends Layer {
    private static final long serialVersionUID = 2;

    /** The cost function used by {@link #computeOutputGradient(double[], double)}. */
    private final Cost cost;

    /** The output function applied by {@link #f(double[])}. */
    private final OutputFunction f;

    /**
     * Creates an output layer.
     *
     * @param n forwarded unchanged to {@code Layer(int, int)}; {@link #toString()} prints it as
     *     the layer size (presumably the number of output units — confirm against Layer).
     * @param p forwarded unchanged to {@code Layer(int, int)} (presumably the input dimension —
     *     confirm against Layer).
     * @param outputFunction the output function; must not be {@code null}.
     * @param cost the cost function; must not be {@code null}.
     * @throws NullPointerException if {@code cost} or {@code outputFunction} is {@code null}.
     * @throws IllegalArgumentException if {@code outputFunction} is {@code SOFTMAX} while
     *     {@code cost} is {@code MEAN_SQUARED_ERROR}, or {@code outputFunction} is
     *     {@code LINEAR} while {@code cost} is {@code LIKELIHOOD}.
     */
    public OutputLayer(int n, int p, OutputFunction outputFunction, Cost cost) {
        super(n, p);
        // Fail fast on null, checking cost before outputFunction (same order as before).
        java.util.Objects.requireNonNull(cost, "cost");
        java.util.Objects.requireNonNull(outputFunction, "outputFunction");

        // Reject incompatible combinations by comparing enum constants directly rather than
        // through ordinal()-indexed lookup tables.
        if (cost == Cost.MEAN_SQUARED_ERROR && outputFunction == OutputFunction.SOFTMAX) {
            throw new IllegalArgumentException("Softmax output function is not allowed with mean squared error cost function");
        }
        if (cost == Cost.LIKELIHOOD && outputFunction == OutputFunction.LINEAR) {
            throw new IllegalArgumentException("Linear output function is not allowed with likelihood cost function");
        }

        this.f = outputFunction;
        this.cost = cost;
    }

    /**
     * Propagates this layer's output gradient down to the layer below.
     *
     * <p>Delegates to {@code weight.tv(outputGradient, lowerLayerGradient)}. Presumably this
     * writes the transposed-weight product into {@code lowerLayerGradient}; confirm against
     * the matrix {@code tv} contract.
     *
     * @param lowerLayerGradient the buffer that receives the propagated gradient.
     */
    @Override // smile.base.mlp.Layer
    public void backpropagate(double[] lowerLayerGradient) {
        this.weight.tv(this.outputGradient.get(), lowerLayerGradient);
    }

    /**
     * Computes the output gradient for one training instance into this layer's
     * {@code outputGradient} buffer.
     *
     * <p>The raw error {@code target[i] - output[i]} is first written into the buffer. It is
     * then passed through {@code f.g(cost, gradient, output)}, and finally scaled by
     * {@code weight}. Scaling is skipped when {@code weight <= 0} or {@code weight == 1}.
     * A NaN weight is not skipped and therefore propagates.
     *
     * @param target the desired output; its length must equal the layer's output length.
     * @param weight the instance weight (see the scaling rule above).
     * @throws IllegalArgumentException if {@code target.length} differs from the output length.
     */
    public void computeOutputGradient(double[] target, double weight) {
        double[] output = this.output.get();
        double[] gradient = this.outputGradient.get();
        int size = output.length;
        if (target.length != size) {
            throw new IllegalArgumentException(String.format("Invalid target vector size: %d, expected: %d", target.length, size));
        }

        for (int i = 0; i < size; i++) {
            gradient[i] = target[i] - output[i];
        }
        this.f.g(this.cost, gradient, output);

        // NOTE(review): a zero or negative weight leaves the gradient unscaled instead of
        // zeroing it; kept as-is because callers may rely on it. Confirm this is intended.
        if (weight <= 0.0 || weight == 1.0) {
            return;
        }
        for (int i = 0; i < size; i++) {
            gradient[i] *= weight;
        }
    }

    /**
     * Returns the cost function of this layer.
     *
     * @return the cost function supplied at construction.
     */
    public Cost cost() {
        return this.cost;
    }

    /**
     * Applies this layer's output function to {@code x} by delegating to
     * {@code OutputFunction.f(double[])}. Presumably this works in place; confirm against
     * OutputFunction.
     *
     * @param x the values to transform.
     */
    @Override // smile.base.mlp.Layer
    public void f(double[] x) {
        this.f.f(x);
    }

    /**
     * Returns a summary of the form {@code NAME(n) | COST}.
     *
     * @return the output function name, layer size and cost function.
     */
    @Override
    public String toString() {
        return String.format("%s(%d) | %s", this.f.name(), this.n, this.cost);
    }
}
