package org.ddogleg.optimization.math;

import org.ejml.data.DMatrix;
import org.ejml.data.DMatrixRMaj;
import org.ejml.data.ReshapeMatrix;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.interfaces.linsol.LinearSolverSparse;

/* loaded from: classes3.dex */
/**
 * Base implementation of a Hessian stored as the 2x2 block matrix
 * <pre>
 *     H = [ A   B ]
 *         [ B'  D ]
 * </pre>
 * whose linear system {@code H*x = g} is solved by eliminating the {@code A} block, i.e. by
 * solving the reduced system defined by the Schur complement {@code D - B'*inv(A)*B}.
 *
 * <p>Subclasses supply the matrix-type specific primitives (products, additions, diagonal access)
 * through the protected abstract hooks.</p>
 *
 * @param <S> matrix type used for the Hessian blocks and the Jacobians
 */
public abstract class HessianSchurComplement_Base<S extends DMatrix> implements HessianSchurComplement<S> {
    // Solver for systems involving the A block. Configured by initializeSolver() and reused by solve().
    protected LinearSolverSparse<S, DMatrixRMaj> solverA;
    // Solver for the reduced (Schur complement) system. Re-initialized on every call to solve().
    protected LinearSolverSparse<S, DMatrixRMaj> solverD;

    // Right-hand side split into the part matching A (b1) and the part matching D (b2)
    DMatrixRMaj b1 = new DMatrixRMaj(1, 1);
    DMatrixRMaj b2 = new DMatrixRMaj(1, 1);
    // Reduced right-hand side: b2 - B'*inv(A)*b1
    DMatrixRMaj b2_m = new DMatrixRMaj(1, 1);
    // Workspace holding inv(A)*b1
    DMatrixRMaj x = new DMatrixRMaj(1, 1);
    // Solution parts in solve(); also reused as scratch by computeGradient() and extractDiagonals()
    DMatrixRMaj x1 = new DMatrixRMaj(1, 1);
    DMatrixRMaj x2 = new DMatrixRMaj(1, 1);

    // Hessian blocks. NOTE: these initializers call createMatrix() (declared outside this class)
    // while the object is still being constructed, i.e. before any subclass fields are initialized.
    S A = createMatrix();
    S B = createMatrix();
    S D = createMatrix();
    // Workspace holding B'*inv(A)*B
    S tmp0 = createMatrix();
    // Workspace: first inv(A)*B, then the Schur complement D - B'*inv(A)*B
    S D_m = createMatrix();

    /**
     * @param solverA solver used for systems involving the {@code A} block
     * @param solverD solver used for the reduced Schur complement system
     */
    public HessianSchurComplement_Base(LinearSolverSparse<S, DMatrixRMaj> solverA, LinearSolverSparse<S, DMatrixRMaj> solverD) {
        this.solverA = solverA;
        this.solverD = solverD;
    }

    /**
     * Stacks {@code multTransA(jacLeft, residuals)} on top of {@code multTransA(jacRight, residuals)}
     * inside {@code gradient} (presumably {@code J'*r} for {@code J = [jacLeft jacRight]}).
     *
     * <p>{@code gradient} is NOT reshaped here; it must already be large enough to hold
     * {@code jacLeft.getNumCols() + jacRight.getNumCols()} rows.</p>
     *
     * @param jacLeft   Jacobian columns matching the A block
     * @param jacRight  Jacobian columns matching the D block
     * @param residuals residual vector
     * @param gradient  output vector
     */
    @Override
    public void computeGradient(S jacLeft, S jacRight, DMatrixRMaj residuals, DMatrixRMaj gradient) {
        this.x1.reshape(jacLeft.getNumCols(), 1);
        this.x2.reshape(jacRight.getNumCols(), 1);
        // No casts are needed here: only the (S, DMatrixRMaj, DMatrixRMaj) overload is applicable
        multTransA(jacLeft, residuals, this.x1);
        multTransA(jacRight, residuals, this.x2);
        CommonOps_DDRM.insert(this.x1, gradient, 0, 0);
        CommonOps_DDRM.insert(this.x2, gradient, this.x1.numRows, 0);
    }

    /**
     * Applies {@link #divideRowsCols(double[], int, DMatrix, double[], int)} to each block, using the
     * first {@code A.getNumCols()} entries of {@code scaling} for the A block's rows/columns and the
     * remaining entries for the D block's rows/columns.
     *
     * @param scaling scaling vector covering all parameters
     */
    @Override
    public void divideRowsCols(DMatrixRMaj scaling) {
        final double[] scale = scaling.data;
        final int offsetD = this.A.getNumCols();
        divideRowsCols(scale, 0, this.A, scale, 0);
        divideRowsCols(scale, 0, this.B, scale, offsetD);
        divideRowsCols(scale, this.A.getNumRows(), this.D, scale, offsetD);
    }

    /**
     * Writes the diagonals of {@code A} followed by the diagonal of {@code D} into {@code diag},
     * which is reshaped into a column vector of length {@code A.getNumCols() + D.getNumCols()}.
     *
     * @param diag output vector
     */
    @Override
    public void extractDiagonals(DMatrixRMaj diag) {
        extractDiag(this.A, this.x1);
        extractDiag(this.D, this.x2);
        diag.reshape(this.A.getNumCols() + this.D.getNumCols(), 1);
        CommonOps_DDRM.insert(this.x1, diag, 0, 0);
        CommonOps_DDRM.insert(this.x2, diag, this.x1.numRows, 0);
    }

    /** Intentionally does nothing in this base class. */
    @Override
    public void init(int numParameters) {
    }

    /**
     * Hands the {@code A} block to {@code solverA}. {@link #solve} relies on this having been called
     * (and having succeeded) first, since it never calls {@code solverA.setA} itself.
     *
     * @return the result of {@code solverA.setA(A)}
     */
    @Override
    public boolean initializeSolver() {
        return this.solverA.setA(this.A);
    }

    /**
     * Computes {@code v'*H*v} as {@code v1'*A*v1 + 2*v1'*B*v2 + v2'*D*v2}, where {@code v1} is the
     * first {@code A.getNumRows()} entries of {@code v} and {@code v2} the rest.
     *
     * @param v vector covering all parameters
     * @return the quadratic form value
     */
    @Override
    public double innerVectorHessian(DMatrixRMaj v) {
        final int offset = this.A.getNumRows();
        final double[] data = v.data;
        double sum = 0.0;
        sum += innerProduct(data, 0, this.A, data, 0);
        sum += 2.0 * innerProduct(data, 0, this.B, data, offset);
        sum += innerProduct(data, offset, this.D, data, offset);
        return sum;
    }

    /**
     * Overwrites the diagonals of {@code A} and {@code D} with the values from {@code diag}, in the
     * same layout produced by {@link #extractDiagonals}.
     *
     * @param diag first {@code A.getNumCols()} entries for A, then {@code D.getNumCols()} entries for D
     */
    @Override
    public void setDiagonals(DMatrixRMaj diag) {
        final int sizeA = this.A.getNumCols();
        for (int i = 0; i < sizeA; i++) {
            this.A.set(i, i, diag.data[i]);
        }
        final int sizeD = this.D.getNumCols();
        for (int i = 0; i < sizeD; i++) {
            this.D.set(i, i, diag.data[i + sizeA]);
        }
    }

    /**
     * Solves {@code H*step = gradient} using the Schur complement of {@code A}:
     * <ol>
     *     <li>{@code b2_m = b2 - B'*inv(A)*b1}</li>
     *     <li>{@code D_m = D - B'*inv(A)*B}</li>
     *     <li>{@code x2 = inv(D_m)*b2_m}</li>
     *     <li>{@code x1 = inv(A)*(b1 - B*x2)}</li>
     * </ol>
     * Requires {@link #initializeSolver()} to have configured {@code solverA}. {@code step} is NOT
     * reshaped; it must already be large enough to hold the full solution.
     *
     * @param gradient right-hand side, split into {@code b1} (first {@code A.getNumCols()} rows) and {@code b2}
     * @param step     output solution vector
     * @return false if {@code solverD} rejects the Schur complement, true otherwise
     */
    @Override
    public boolean solve(DMatrixRMaj gradient, DMatrixRMaj step) {
        // Split the right-hand side into the parts matching A and D
        CommonOps_DDRM.extract(gradient, 0, this.A.getNumCols(), 0, gradient.numCols, this.b1);
        CommonOps_DDRM.extract(gradient, this.A.getNumCols(), gradient.numRows, 0, gradient.numCols, this.b2);

        // b2_m = b2 - B'*inv(A)*b1
        // NOTE(review): b1 is read again during back substitution, so this relies on solverA not
        // modifying its right-hand side argument.
        this.x.reshape(this.A.getNumRows(), 1);
        this.solverA.solve(this.b1, this.x);
        multTransA(this.B, this.x, this.b2_m);
        CommonOps_DDRM.subtract(this.b2, this.b2_m, this.b2_m);

        // D_m = D - B'*inv(A)*B   (B is reused below, so solveSparse must leave it intact)
        ((ReshapeMatrix) this.D_m).reshape(this.A.getNumRows(), this.B.getNumCols());
        this.solverA.solveSparse(this.B, this.D_m);   // D_m = inv(A)*B
        multTransA(this.B, this.D_m, this.tmp0);      // tmp0 = B'*inv(A)*B
        add(1.0d, this.D, -1.0d, this.tmp0, this.D_m);

        // Solve the reduced system D_m*x2 = b2_m
        if (!this.solverD.setA(this.D_m)) {
            return false;
        }
        this.x2.reshape(this.D_m.getNumRows(), this.b2_m.numCols);
        this.solverD.solve(this.b2_m, this.x2);

        // Back substitution: x1 = inv(A)*(b1 - B*x2)
        mult(this.B, this.x2, this.x1);
        CommonOps_DDRM.subtract(this.b1, this.x1, this.b1);
        this.solverA.solve(this.b1, this.x1);

        CommonOps_DDRM.insert(this.x1, step, 0, 0);
        CommonOps_DDRM.insert(this.x2, step, this.x1.numRows, 0);
        return true;
    }

    /**
     * Combines two matrices into {@code output}. Called here as {@code add(1, D, -1, tmp0, D_m)} to
     * form {@code D - tmp0}, so presumably {@code output = alpha*a + beta*b}.
     */
    protected abstract void add(double alpha, S a, double beta, S b, S output);

    /**
     * Scales {@code matrix} in place using a row scaling array and a column scaling array. Presumably
     * element (i,j) is divided by {@code scaleRows[rowOffset+i]*scaleCols[colOffset+j]} — confirm
     * against the subclass implementations.
     */
    protected abstract void divideRowsCols(double[] scaleRows, int rowOffset, S matrix, double[] scaleCols, int colOffset);

    /** Writes the diagonal of {@code matrix} into {@code diag} (presumably reshaping it). */
    protected abstract void extractDiag(S matrix, DMatrixRMaj diag);

    /**
     * Presumably computes {@code a'*matrix*b}, where {@code a} and {@code b} are read from
     * {@code arrayA} and {@code arrayB} starting at the given offsets.
     */
    protected abstract double innerProduct(double[] arrayA, int offsetA, S matrix, double[] arrayB, int offsetB);

    /** Presumably computes {@code output = a*b}. */
    protected abstract void mult(S a, DMatrixRMaj b, DMatrixRMaj output);

    /** Presumably computes {@code output = a'*b} for two matrices of type {@code S}. */
    protected abstract void multTransA(S a, S b, S output);

    /** Presumably computes {@code output = a'*b} for a dense {@code b}. */
    protected abstract void multTransA(S a, DMatrixRMaj b, DMatrixRMaj output);
}
