package smile.classification;

import smile.base.rbf.RBF;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.util.IntSet;

/* loaded from: classes5.dex */
/**
 * Radial basis function (RBF) network classifier.
 * <p>
 * An input is mapped through {@code m} radial basis functions plus a constant
 * bias input of 1, and a linear output layer with one column per class turns
 * those activations into class scores. The predicted class is the one with the
 * largest score.
 * <p>
 * In the normalized variant, the training targets are scaled by the sum of the
 * RBF activations. Therefore the output divided by that sum is what
 * approximates the class indicator.
 *
 * @param <T> the type of input objects.
 */
public class RBFNetwork<T> implements Classifier<T> {
    private static final long serialVersionUID = 2;

    /** The number of classes. */
    private final int k;
    /** Maps internal class indices 0..k-1 back to the original class labels. */
    private final IntSet labels;
    /** True if the network is a normalized RBF network. */
    private final boolean normalized;
    /** The radial basis functions of the hidden layer. */
    private final RBF<T>[] rbf;
    /**
     * Output-layer weights. {@link #predict} multiplies the transpose of this
     * matrix by a vector of length {@code rbf.length + 1} whose last entry is the
     * bias input 1. It produces {@code k} scores, so the matrix is expected to be
     * {@code (rbf.length + 1) x k}.
     */
    private final Matrix w;

    /**
     * Constructor with default class labels {@code IntSet.of(k)}. These are
     * presumably the labels 0..k-1; confirm against IntSet.
     *
     * @param k the number of classes.
     * @param rbf the radial basis functions.
     * @param w the output-layer weights, expected to be
     *          {@code (rbf.length + 1) x k} with the bias in the last row.
     * @param normalized true if this is a normalized RBF network.
     */
    public RBFNetwork(int k, RBF<T>[] rbf, Matrix w, boolean normalized) {
        this(k, rbf, w, normalized, IntSet.of(k));
    }

    /**
     * Constructor.
     *
     * @param k the number of classes.
     * @param rbf the radial basis functions.
     * @param w the output-layer weights, expected to be
     *          {@code (rbf.length + 1) x k} with the bias in the last row.
     * @param normalized true if this is a normalized RBF network.
     * @param labels the mapping from class index to class label.
     */
    public RBFNetwork(int k, RBF<T>[] rbf, Matrix w, boolean normalized, IntSet labels) {
        this.k = k;
        this.rbf = rbf;
        this.w = w;
        this.normalized = normalized;
        this.labels = labels;
    }

    /**
     * Fits a non-normalized RBF network.
     *
     * @param x the training samples.
     * @param y the class labels of the training samples.
     * @param rbf the radial basis functions.
     * @param <T> the type of input objects.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public static <T> RBFNetwork<T> fit(T[] x, int[] y, RBF<T>[] rbf) {
        return fit(x, y, rbf, false);
    }

    /**
     * Fits an RBF network.
     * <p>
     * The output-layer weights are obtained by solving {@code G * W = B} through a
     * QR decomposition. Each row of {@code G} holds the RBF activations of one
     * sample followed by a constant 1 (the bias input). Each row of {@code B}
     * is the sample's one-hot class indicator. When {@code normalized} is true,
     * that indicator is scaled by the sample's activation sum.
     *
     * @param x the training samples.
     * @param y the class labels of the training samples.
     * @param rbf the radial basis functions.
     * @param normalized true to fit a normalized RBF network.
     * @param <T> the type of input objects.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public static <T> RBFNetwork<T> fit(T[] x, int[] y, RBF<T>[] rbf, boolean normalized) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        ClassLabels codec = ClassLabels.fit(y);
        int k = codec.k;
        int n = x.length;
        int m = rbf.length;

        // Design matrix: one column per RBF plus a trailing constant column for the bias.
        Matrix design = new Matrix(n, m + 1);
        // Target matrix: one column per class.
        Matrix targets = new Matrix(n, k);
        for (int i = 0; i < n; i++) {
            double sum = 0.0;
            for (int j = 0; j < m; j++) {
                double activation = rbf[j].f(x[i]);
                design.set(i, j, activation);
                sum += activation;
            }
            design.set(i, m, 1.0);
            targets.set(i, codec.y[i], normalized ? sum : 1.0);
        }

        // solve() is used for its side effect: the solution is read back from
        // `targets` below.
        design.qr(true).solve(targets);

        // Rows 0..m and columns 0..k-1 (Smile submatrix bounds are inclusive). This
        // gives the (m + 1) x k weight matrix that predict() expects.
        return new RBFNetwork<>(k, rbf, targets.submatrix(0, 0, m, k - 1), normalized, codec.labels);
    }

    /**
     * Returns true if this is a normalized RBF network.
     *
     * @return true if this is a normalized RBF network.
     */
    public boolean isNormalized() {
        return this.normalized;
    }

    /**
     * Predicts the class label of an instance.
     * <p>
     * The label is chosen by the largest class score. For a normalized network,
     * the scores are implicitly normalized by the RBF activation sum.
     *
     * @param x the instance to classify.
     * @return the predicted class label.
     */
    @Override // smile.classification.Classifier
    public int predict(T x) {
        int m = rbf.length;
        double[] activations = new double[m + 1];
        activations[m] = 1.0; // bias input
        double sum = 0.0;
        for (int i = 0; i < m; i++) {
            activations[i] = rbf[i].f(x);
            sum += activations[i];
        }

        double[] scores = new double[k];
        w.tv(activations, scores);

        // A normalized network is trained so that scores / sum approximates the
        // class indicator. Dividing by a positive sum does not change the argmax,
        // but if the RBFs can take negative values and the sum is negative, the
        // raw scores rank the classes in reverse. Flipping the sign is
        // argmax-equivalent to dividing by sum, without the risk of overflow
        // from a tiny divisor.
        if (normalized && sum < 0.0) {
            for (int j = 0; j < k; j++) {
                scores[j] = -scores[j];
            }
        }

        return labels.valueOf(MathEx.whichMax(scores));
    }
}
