package smile.feature;

import java.lang.reflect.Array;
import java.util.function.BiFunction;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.gap.BitString;
import smile.gap.Chromosome;
import smile.gap.Crossover;
import smile.gap.Fitness;
import smile.gap.GeneticAlgorithm;
import smile.gap.Selection;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.validation.metric.ClassificationMetric;
import smile.validation.metric.RegressionMetric;

/* loaded from: classes5.dex */
/**
 * Genetic algorithm based feature selection.
 *
 * <p>A candidate feature subset is encoded as a {@link BitString}; a bit equal to 1 marks the
 * corresponding feature as selected. Candidates are scored by a {@link Fitness} function built
 * with one of the {@code fitness} factory methods and evolved by {@link GeneticAlgorithm} using
 * the selection, elitism, crossover and mutation settings of this object.
 */
public class GAFE {
    /** Crossover strategy given to every {@link BitString} in the population. */
    private final Crossover crossover;
    /** Crossover rate given to every {@link BitString} in the population. */
    private final double crossoverRate;
    /** Elitism count given to the {@link GeneticAlgorithm}. */
    private final int elitism;
    /** Mutation rate given to every {@link BitString} in the population. */
    private final double mutationRate;
    /** Selection strategy given to the {@link GeneticAlgorithm}. */
    private final Selection selection;

    /**
     * Creates a feature selector with the defaults {@code Selection.Tournament(3, 0.95)},
     * elitism 1, {@code Crossover.TWO_POINT}, crossover rate 1.0 and mutation rate 0.01.
     */
    public GAFE() {
        this(Selection.Tournament(3, 0.95d), 1, Crossover.TWO_POINT, 1.0d, 0.01d);
    }

    /**
     * Creates a feature selector.
     *
     * @param selection the selection strategy handed to {@link GeneticAlgorithm}.
     * @param elitism the elitism count handed to {@link GeneticAlgorithm}.
     * @param crossover the crossover strategy of each {@link BitString}.
     * @param crossoverRate the crossover rate of each {@link BitString}.
     * @param mutationRate the mutation rate of each {@link BitString}.
     */
    public GAFE(Selection selection, int elitism, Crossover crossover, double crossoverRate, double mutationRate) {
        this.selection = selection;
        this.elitism = elitism;
        this.crossover = crossover;
        this.crossoverRate = crossoverRate;
        this.mutationRate = mutationRate;
    }

    /**
     * Returns a fitness function for classification on data frames.
     *
     * <p>The bits of a chromosome map, in order, to the columns of {@code train} other than the
     * response {@code y}. Each evaluation trains a model with {@code trainer} on the formula
     * {@code y ~ selected columns}, predicts {@code test}, and scores the prediction against the
     * response column of {@code test}. A chromosome with no bit set scores 0.0.
     *
     * @param y the name of the response column.
     * @param train the training data.
     * @param test the validation data; its column {@code y} is read eagerly as the ground truth.
     * @param metric the classification metric used as fitness.
     * @param trainer builds a classifier from a formula and the training data.
     * @return the fitness function.
     */
    public static Fitness<BitString> fitness(final String y, final DataFrame train, final DataFrame test, final ClassificationMetric metric, final BiFunction<Formula, DataFrame, DataFrameClassifier> trainer) {
        final String[] names = train.names();
        final int[] testy = test.column(y).toIntArray();
        return new Fitness<BitString>() {
            @Override
            public double score(BitString chromosome) {
                return GAFE.lambda$fitness$2(names, y, trainer, train, metric, testy, test, chromosome);
            }
        };
    }

    /**
     * Returns a fitness function for regression on data frames.
     *
     * <p>The bits of a chromosome map, in order, to the columns of {@code train} other than the
     * response {@code y}. Each evaluation trains a model with {@code trainer} on the formula
     * {@code y ~ selected columns}, predicts {@code test}, and returns the <em>negated</em> metric
     * against the response column of {@code test}, so smaller metric values give larger fitness.
     * A chromosome with no bit set scores {@link Double#NEGATIVE_INFINITY}.
     *
     * @param y the name of the response column.
     * @param train the training data.
     * @param test the validation data; its column {@code y} is read eagerly as the ground truth.
     * @param metric the regression metric whose negation is used as fitness.
     * @param trainer builds a regression model from a formula and the training data.
     * @return the fitness function.
     */
    public static Fitness<BitString> fitness(final String y, final DataFrame train, final DataFrame test, final RegressionMetric metric, final BiFunction<Formula, DataFrame, DataFrameRegression> trainer) {
        final String[] names = train.names();
        final double[] testy = test.column(y).toDoubleArray();
        return new Fitness<BitString>() {
            @Override
            public double score(BitString chromosome) {
                return GAFE.lambda$fitness$3(names, y, trainer, train, metric, testy, test, chromosome);
            }
        };
    }

    /**
     * Returns a fitness function for regression on feature matrices.
     *
     * <p>Bit {@code j} of a chromosome selects column {@code j} of {@code x} and {@code testx}.
     * Each evaluation trains a model with {@code trainer} on the selected columns of {@code x},
     * predicts the selected columns of {@code testx}, and returns the <em>negated</em> metric
     * against {@code testy}. A chromosome with no bit set scores {@link Double#NEGATIVE_INFINITY}.
     *
     * @param x the training samples.
     * @param y the training responses.
     * @param testx the validation samples.
     * @param testy the validation responses.
     * @param metric the regression metric whose negation is used as fitness.
     * @param trainer builds a regression model from samples and responses.
     * @return the fitness function.
     */
    public static Fitness<BitString> fitness(final double[][] x, final double[] y, final double[][] testx, final double[] testy, final RegressionMetric metric, final BiFunction<double[][], double[], Regression<double[]>> trainer) {
        return new Fitness<BitString>() {
            @Override
            public double score(BitString chromosome) {
                return GAFE.lambda$fitness$1(x, testx, trainer, y, metric, testy, chromosome);
            }
        };
    }

    /**
     * Returns a fitness function for classification on feature matrices.
     *
     * <p>Bit {@code j} of a chromosome selects column {@code j} of {@code x} and {@code testx}.
     * Each evaluation trains a model with {@code trainer} on the selected columns of {@code x},
     * predicts the selected columns of {@code testx}, and scores the prediction against
     * {@code testy}. A chromosome with no bit set scores 0.0.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param testx the validation samples.
     * @param testy the validation labels.
     * @param metric the classification metric used as fitness.
     * @param trainer builds a classifier from samples and labels.
     * @return the fitness function.
     */
    public static Fitness<BitString> fitness(final double[][] x, final int[] y, final double[][] testx, final int[] testy, final ClassificationMetric metric, final BiFunction<double[][], int[], Classifier<double[]>> trainer) {
        return new Fitness<BitString>() {
            @Override
            public double score(BitString chromosome) {
                return GAFE.lambda$fitness$0(x, testx, trainer, y, metric, testy, chromosome);
            }
        };
    }

    /**
     * Returns the positions of the bits equal to 1, in ascending order, or {@code null} when the
     * bits sum to zero.
     */
    private static int[] indexOfOnes(byte[] bits) {
        int count = MathEx.sum(bits);
        if (count == 0) {
            return null;
        }
        int[] index = new int[count];
        int k = 0;
        for (int i = 0; i < bits.length; i++) {
            if (bits[i] == 1) {
                index[k++] = i;
            }
        }
        return index;
    }

    /*
     * NOTE(review): the four lambda$fitness$N methods below are compiler-synthesized lambda
     * bodies exposed by decompilation (JADX reports they were package-private). They are kept
     * public under the same names and erased signatures so existing references still link;
     * they are not intended as API.
     */

    /** Scores one chromosome for the matrix-based classification fitness. */
    public static double lambda$fitness$0(double[][] x, double[][] testx, BiFunction<double[][], int[], Classifier<double[]>> trainer, int[] y, ClassificationMetric metric, int[] testy, BitString chromosome) {
        int[] columns = indexOfOnes(chromosome.bits());
        if (columns == null) {
            return 0.0d;
        }
        Classifier<double[]> model = trainer.apply(select(x, columns), y);
        int[] prediction = model.predict(select(testx, columns));
        return metric.score(testy, prediction);
    }

    /** Scores one chromosome for the matrix-based regression fitness (negated metric). */
    public static double lambda$fitness$1(double[][] x, double[][] testx, BiFunction<double[][], double[], Regression<double[]>> trainer, double[] y, RegressionMetric metric, double[] testy, BitString chromosome) {
        int[] columns = indexOfOnes(chromosome.bits());
        if (columns == null) {
            return Double.NEGATIVE_INFINITY;
        }
        Regression<double[]> model = trainer.apply(select(x, columns), y);
        double[] prediction = model.predict(select(testx, columns));
        return -metric.score(testy, prediction);
    }

    /** Scores one chromosome for the data-frame classification fitness. */
    public static double lambda$fitness$2(String[] names, String y, BiFunction<Formula, DataFrame, DataFrameClassifier> trainer, DataFrame train, ClassificationMetric metric, int[] testy, DataFrame test, BitString chromosome) {
        String[] features = selectedFeatures(chromosome.bits(), names, y);
        if (features == null) {
            return 0.0d;
        }
        DataFrameClassifier model = trainer.apply(Formula.of(y, features), train);
        return metric.score(testy, model.predict(test));
    }

    /** Scores one chromosome for the data-frame regression fitness (negated metric). */
    public static double lambda$fitness$3(String[] names, String y, BiFunction<Formula, DataFrame, DataFrameRegression> trainer, DataFrame train, RegressionMetric metric, double[] testy, DataFrame test, BitString chromosome) {
        String[] features = selectedFeatures(chromosome.bits(), names, y);
        if (features == null) {
            return Double.NEGATIVE_INFINITY;
        }
        DataFrameRegression model = trainer.apply(Formula.of(y, features), train);
        return -metric.score(testy, model.predict(test));
    }

    /** Copies the given columns of every row of {@code data} into a new matrix. */
    private static double[][] select(double[][] data, int[] columns) {
        double[][] subset = new double[data.length][columns.length];
        for (int i = 0; i < data.length; i++) {
            double[] row = data[i];
            double[] out = subset[i];
            for (int j = 0; j < columns.length; j++) {
                out[j] = row[columns[j]];
            }
        }
        return subset;
    }

    /**
     * Maps the set bits to column names, skipping the response column {@code y}, or returns
     * {@code null} when the bits sum to zero.
     *
     * <p>Bit {@code i} refers to the {@code i}-th column of {@code names} once {@code y} is
     * removed, so positions at or after the response are shifted by one.
     */
    private static String[] selectedFeatures(byte[] bits, String[] names, String y) {
        int count = MathEx.sum(bits);
        if (count == 0) {
            return null;
        }
        String[] features = new String[count];
        int offset = 0; // becomes 1 once the response column has been passed
        int k = 0;
        for (int i = 0; i < bits.length; i++) {
            if (names[i].equals(y)) {
                offset++;
            }
            if (bits[i] == 1) {
                features[k++] = names[i + offset];
            }
        }
        return features;
    }

    /**
     * Runs the genetic algorithm for feature selection.
     *
     * <p>Creates {@code size} random {@link BitString} chromosomes of the given length with this
     * object's crossover settings, evolves them for {@code maxGeneration} generations and returns
     * the population array handed to the {@link GeneticAlgorithm}. NOTE(review): whether the
     * returned array is ordered by fitness depends on {@code GeneticAlgorithm.evolve} — confirm.
     *
     * @param size the population size.
     * @param maxGeneration the number of generations passed to {@code evolve}.
     * @param length the chromosome length, i.e. the number of candidate features (for the
     *     data-frame fitness functions, presumably the number of non-response columns).
     * @param fitness the fitness function.
     * @return the evolved population.
     * @throws IllegalArgumentException if {@code size} or {@code length} is not positive.
     */
    public BitString[] apply(int size, int maxGeneration, int length, Fitness<BitString> fitness) {
        if (size <= 0) {
            throw new IllegalArgumentException("Invalid population size: " + size);
        }
        if (length <= 0) {
            throw new IllegalArgumentException("Invalid chromosome length: " + length);
        }
        BitString[] seeds = new BitString[size];
        for (int i = 0; i < size; i++) {
            seeds[i] = new BitString(length, fitness, this.crossover, this.crossoverRate, this.mutationRate);
        }
        GeneticAlgorithm<BitString> ga = new GeneticAlgorithm<>(seeds, this.selection, this.elitism);
        ga.evolve(maxGeneration);
        return seeds;
    }
}
