package deepboof.impl.forward.standard;

import deepboof.BaseTensor;
import deepboof.Tensor;
import deepboof.forward.ConfigPadding;
import deepboof.forward.SpatialPadding2D;

/* loaded from: classes4.dex */
/**
 * Common state and shape bookkeeping for 2D spatial padding implementations.
 *
 * <p>The amount of padding is read from a {@link ConfigPadding}: {@code x0}/{@code x1} are added
 * to the width (column) axis and {@code y0}/{@code y1} to the height (row) axis. Once an input
 * tensor has been supplied through {@link #setInput}, {@code shape} holds the padded shape.</p>
 *
 * @param <T> tensor type being padded
 */
public abstract class BaseSpatialPadding2D<T extends Tensor<T>> extends BaseTensor implements SpatialPadding2D<T> {
    /** Set by {@link #setInput} to {@code config.x0}. */
    protected int COL0;
    /** Set by {@link #setInput} to the input's width (axis 3) plus {@code config.x0}. */
    protected int COL1;
    /** Set by {@link #setInput} to {@code config.y0}. */
    protected int ROW0;
    /** Set by {@link #setInput} to the input's height (axis 2) plus {@code config.y0}. */
    protected int ROW1;
    /** Describes how much padding is applied along each border. */
    protected ConfigPadding config;
    /** Most recent tensor passed to {@link #setInput}. */
    protected T input;

    /**
     * Stores the padding configuration. The reference is kept as-is (not copied).
     *
     * @param configPadding padding amounts used by every shape computation in this class
     */
    public BaseSpatialPadding2D(ConfigPadding configPadding) {
        this.config = configPadding;
    }

    /**
     * Verifies that a single padded channel is consistent with the original 4D tensor.
     *
     * @param tensor  padded channel; must be 2D with shape (padded height, padded width)
     * @param tensor2 original tensor; must be 4D with height on axis 2 and width on axis 3
     * @throws IllegalArgumentException if either dimension count is wrong or the padded extents
     *                                  do not equal the original extents plus the configured padding
     */
    public <S extends Tensor<S>> void checkBackwardsShapeChannel(Tensor<S> tensor, Tensor<S> tensor2) {
        if (tensor.getDimension() != 2) {
            throw new IllegalArgumentException("Padded image expected to be a 2D spatial image, i.e. 2 channels");
        }
        if (tensor2.getDimension() != 4) {
            throw new IllegalArgumentException("Original image expected to be a 4D spatial image, i.e. 4 channels");
        }
        // Sum as ints before building the message; appending the terms one at a time to a
        // String would print them as adjacent digits instead of the expected total.
        int expectedRows = tensor2.length(2) + this.config.y0 + this.config.y1;
        if (tensor.length(0) != expectedRows) {
            throw new IllegalArgumentException("Image heights do not match.  " + tensor.length(0) + " != " + expectedRows);
        }
        int expectedCols = tensor2.length(3) + this.config.x0 + this.config.x1;
        if (tensor.length(1) != expectedCols) {
            throw new IllegalArgumentException("Image widths do not match.  " + tensor.length(1) + " != " + expectedCols);
        }
    }

    /**
     * Verifies that a padded multi-channel image is consistent with the original 4D tensor.
     *
     * @param tensor  padded image; must be 3D with shape (channels, padded height, padded width)
     * @param tensor2 original tensor; must be 4D with channels on axis 1, height on axis 2 and
     *                width on axis 3
     * @throws IllegalArgumentException if either dimension count is wrong, the channel counts
     *                                  differ, or the padded extents do not equal the original
     *                                  extents plus the configured padding
     */
    public <S extends Tensor<S>> void checkBackwardsShapeImage(Tensor<S> tensor, Tensor<S> tensor2) {
        if (tensor.getDimension() != 3) {
            throw new IllegalArgumentException("Padded image expected to be a 3D spatial image, i.e. 3 channels");
        }
        if (tensor2.getDimension() != 4) {
            throw new IllegalArgumentException("Original image expected to be a 4D spatial image, i.e. 4 channels");
        }
        if (tensor.length(0) != tensor2.length(1)) {
            throw new IllegalArgumentException("Image channels do not match.  " + tensor.length(0) + " != " + tensor2.length(1));
        }
        // Sum as ints before building the message; appending the terms one at a time to a
        // String would print them as adjacent digits instead of the expected total.
        int expectedRows = tensor2.length(2) + this.config.y0 + this.config.y1;
        if (tensor.length(1) != expectedRows) {
            throw new IllegalArgumentException("Image heights do not match.  " + tensor.length(1) + " != " + expectedRows);
        }
        int expectedCols = tensor2.length(3) + this.config.x0 + this.config.x1;
        if (tensor.length(2) != expectedCols) {
            throw new IllegalArgumentException("Image widths do not match.  " + tensor.length(2) + " != " + expectedCols);
        }
    }

    /** @return {@code config.x0} */
    @Override // deepboof.forward.SpatialPadding2D
    public int getPaddingCol0() {
        return this.config.x0;
    }

    /** @return {@code config.x1} */
    @Override // deepboof.forward.SpatialPadding2D
    public int getPaddingCol1() {
        return this.config.x1;
    }

    /** @return {@code config.y0} */
    @Override // deepboof.forward.SpatialPadding2D
    public int getPaddingRow0() {
        return this.config.y0;
    }

    /** @return {@code config.y1} */
    @Override // deepboof.forward.SpatialPadding2D
    public int getPaddingRow1() {
        return this.config.y1;
    }

    /**
     * Sets the tensor to be padded and recomputes the row/column bounds and the padded shape.
     *
     * @param t input tensor; must be 4D, with height on axis 2 and width on axis 3
     * @throws IllegalArgumentException if {@code t} is not 4-dimensional
     */
    @Override // deepboof.forward.SpatialPadding2D
    public void setInput(T t) {
        if (t.getDimension() != 4) {
            throw new IllegalArgumentException("Expected 4-DOF spatial tensor");
        }
        this.input = t;
        int rows = t.length(2);
        int cols = t.length(3);
        // Lower bounds are the leading padding; upper bounds add the input's extent on top.
        // NOTE(review): presumably [ROW0, ROW1) x [COL0, COL1) is where the unpadded image sits
        // inside the padded one - confirm against the subclasses that read these fields.
        this.ROW0 = this.config.y0;
        this.COL0 = this.config.x0;
        this.ROW1 = rows + this.config.y0;
        this.COL1 = cols + this.config.x0;
        this.shape = shapeGivenInput(t.shape);
    }

    /**
     * Computes the padded shape for a given input shape. Only the last two axes grow: the
     * second-to-last by {@code y0 + y1} and the last by {@code x0 + x1}. Leading axes are copied.
     *
     * @param iArr input shape with 3 or 4 entries
     * @return a new array of the same length holding the padded shape
     * @throws IllegalArgumentException if {@code iArr} does not have 3 or 4 entries
     */
    @Override // deepboof.forward.SpatialPadding2D
    public int[] shapeGivenInput(int... iArr) {
        if (iArr.length == 3) {
            return new int[]{iArr[0], iArr[1] + this.config.y0 + this.config.y1, iArr[2] + this.config.x0 + this.config.x1};
        }
        if (iArr.length == 4) {
            return new int[]{iArr[0], iArr[1], iArr[2] + this.config.y0 + this.config.y1, iArr[3] + this.config.x0 + this.config.x1};
        }
        throw new IllegalArgumentException("Spatial tensor with 3 or 4 dof expected");
    }
}
