package deepboof.impl.forward.standard;

import deepboof.DeepBoofConstants;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/* loaded from: classes2.dex */
/**
 * Forward pass of batch normalization for {@code double} tensors, using precomputed statistics.
 *
 * <p>The single parameter tensor stores one interleaved group of values for every element of a
 * mini-batch sample:
 * <ul>
 *   <li>{@code [mean, variance]} when gamma/beta are not required</li>
 *   <li>{@code [mean, variance, gamma, beta]} when {@link #hasGammaBeta()} is true</li>
 * </ul>
 * When parameters are set, each variance is replaced in the internal copy by
 * {@code 1 / sqrt(variance + EPS)}. The forward pass therefore computes
 * {@code (x - mean) * invStd} or {@code (x - mean) * gamma * invStd + beta}.
 */
public class FunctionBatchNorm_F64 extends BaseFunction<Tensor_F64> {
    /** If true, each parameter group also carries gamma (scale) and beta (offset). */
    protected boolean requiresGammaBeta;
    /** Internal copy of the parameters; variance slots hold 1/sqrt(variance + EPS) after setting. */
    protected Tensor_F64 params = new Tensor_F64(0);
    /** Added to the variance before the square root to avoid division by zero. */
    protected double EPS = DeepBoofConstants.TEST_TOL_F64 * 0.1d;

    /**
     * @param requiresGammaBeta true if the parameter tensor contains gamma and beta after the
     *                          mean and variance of each element
     */
    public FunctionBatchNorm_F64(boolean requiresGammaBeta) {
        this.requiresGammaBeta = requiresGammaBeta;
    }

    /**
     * Normalizes every sample of the mini-batch with the stored per-element statistics.
     *
     * @param input  tensor whose first dimension is the mini-batch
     * @param output tensor receiving the normalized values, written from its {@code startIndex}
     * @throws IllegalArgumentException if {@code input} has fewer than two dimensions
     */
    @Override
    public void _forward(Tensor_F64 input, Tensor_F64 output) {
        if (input.getDimension() <= 1) {
            throw new IllegalArgumentException("Input tensor must be at least 2D.  First dimension of batch.");
        }

        // Number of input elements processed per mini-batch entry before the parameter index resets.
        final int sampleLength = TensorOps.outerLength(input.shape, 1);
        final double[] in = input.d;
        final double[] out = output.d;
        final double[] p = params.d;

        int indexIn = input.startIndex;
        int indexOut = output.startIndex;

        if (requiresGammaBeta) {
            for (int batch = 0; batch < miniBatchSize; batch++) {
                int indexP = params.startIndex;
                final int end = indexIn + sampleLength;
                while (indexIn < end) {
                    double mean = p[indexP];
                    double invStd = p[indexP + 1];
                    double gamma = p[indexP + 2];
                    double beta = p[indexP + 3];
                    indexP += 4;
                    // Evaluation order ((x - mean) * gamma) * invStd + beta is kept on purpose so
                    // floating-point results are unchanged.
                    out[indexOut++] = (in[indexIn++] - mean) * gamma * invStd + beta;
                }
            }
        } else {
            for (int batch = 0; batch < miniBatchSize; batch++) {
                int indexP = params.startIndex;
                final int end = indexIn + sampleLength;
                while (indexIn < end) {
                    double mean = p[indexP];
                    double invStd = p[indexP + 1];
                    indexP += 2;
                    out[indexOut++] = (in[indexIn++] - mean) * invStd;
                }
            }
        }
    }

    /**
     * Sets the output shape equal to the input shape and declares a single parameter tensor.
     *
     * <p>The parameter tensor's shape is {@code TensorOps.WI(shapeInput, N)}, with {@code N = 4}
     * when gamma/beta are required and {@code N = 2} otherwise. NOTE(review): presumably this is
     * shapeInput with N appended; confirm against TensorOps.
     */
    @Override
    public void _initialize() {
        this.shapeOutput = (int[]) this.shapeInput.clone();
        final int valuesPerElement = requiresGammaBeta ? 4 : 2;
        int[] paramShape = TensorOps.WI(this.shapeInput, valuesPerElement);
        this.shapeParameters.add(paramShape);
        this.params.reshape(paramShape);
    }

    /**
     * Copies the first tensor in {@code parameters} into the internal parameter tensor. Each
     * variance in the copy is then replaced by {@code 1 / sqrt(variance + EPS)}.
     *
     * <p>The caller's tensor is not modified; only the internal copy is rewritten.
     *
     * @param parameters list whose first element holds the interleaved parameters
     */
    @Override
    public void _setParameters(List<Tensor_F64> parameters) {
        params.setTo(parameters.get(0));

        // Walk the variance slots relative to startIndex, matching how _forward() reads params.
        final int stride = requiresGammaBeta ? 4 : 2;
        final int begin = params.startIndex;
        final int end = begin + params.length();
        final double[] p = params.d;
        for (int i = begin + 1; i < end; i += stride) {
            p[i] = 1.0d / Math.sqrt(p[i] + EPS);
        }
    }

    /**
     * @return true if the parameter tensor includes gamma and beta for every element
     */
    public boolean hasGammaBeta() {
        return this.requiresGammaBeta;
    }

    /**
     * Sets the value added to the variance before the square root.
     *
     * <p>Only affects later calls to {@link #_setParameters(List)}. Parameters that are already
     * set keep the EPS that was in effect when they were set.
     *
     * @param EPS new epsilon value
     */
    public void setEPS(double EPS) {
        this.EPS = EPS;
    }
}
