package smile.feature;

import java.util.DoubleSummaryStatistics;
import java.util.function.IntFunction;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import smile.data.AbstractTuple;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.data.vector.DoubleVector;
import smile.math.MathEx;

/* loaded from: classes5.dex */
/**
 * Min-max scaler: maps each numeric feature {@code x} to {@code (x - lo) / (hi - lo)},
 * clamped to the interval [0, 1]. Non-numeric columns of a data frame or tuple are passed
 * through unchanged.
 */
public class Scaler implements FeatureTransform {
    private static final long serialVersionUID = 2;
    /**
     * Per-feature scaling span. After construction this holds {@code hi - lo} (or 1.0 when that
     * difference is zero), not the raw upper bound.
     */
    double[] hi;
    /** Per-feature lower bound subtracted before dividing by the span. */
    double[] lo;
    /** Schema that inputs to {@code transform} must match exactly. */
    StructType schema;

    /**
     * Creates a scaler from explicit per-feature bounds.
     *
     * <p>The bound arrays are copied, so later changes by the caller do not affect this
     * scaler, and the caller's arrays are never modified.
     *
     * @param structType the schema of the data to be scaled.
     * @param dArr the per-feature lower bounds.
     * @param dArr2 the per-feature upper bounds.
     * @throws IllegalArgumentException if the schema length and the two array lengths differ.
     */
    public Scaler(StructType structType, double[] dArr, double[] dArr2) {
        if (structType.length() != dArr.length || dArr.length != dArr2.length) {
            throw new IllegalArgumentException("Schema and scaling factor size don't match");
        }
        this.schema = structType;
        this.lo = dArr.clone();
        this.hi = new double[dArr2.length];
        for (int i = 0; i < this.lo.length; i++) {
            double span = dArr2[i] - this.lo[i];
            // A (near-)zero span would divide by zero in scale(); fall back to 1.0.
            this.hi[i] = MathEx.isZero(span) ? 1.0d : span;
        }
    }

    /**
     * Learns per-column minimum and maximum from a data frame. Non-numeric columns get
     * bounds [0, 0], which the constructor turns into a span of 1.0.
     *
     * @param dataFrame the training data.
     * @return the fitted scaler.
     * @throws IllegalArgumentException if the data frame is empty.
     */
    public static Scaler fit(DataFrame dataFrame) {
        if (dataFrame.isEmpty()) {
            throw new IllegalArgumentException("Empty data frame");
        }
        StructType schema = dataFrame.schema();
        int length = schema.length();
        double[] mins = new double[length];
        double[] maxs = new double[length];
        for (int i = 0; i < length; i++) {
            if (schema.field(i).isNumeric()) {
                // One pass gives both extremes; like DoubleStream.min()/max(), NaN propagates.
                DoubleSummaryStatistics stats = dataFrame.doubleVector(i).stream().summaryStatistics();
                mins[i] = stats.getMin();
                maxs[i] = stats.getMax();
            }
        }
        return new Scaler(schema, mins, maxs);
    }

    /**
     * Learns per-column minimum and maximum from a numeric matrix.
     *
     * @param dArr the training data, one row per sample.
     * @return the fitted scaler.
     */
    public static Scaler fit(double[][] dArr) {
        return fit(DataFrame.of(dArr, new String[0]));
    }

    /**
     * Scales a single value of feature {@code i} and clamps the result to [0, 1].
     * NaN input yields NaN.
     *
     * @param d the raw value.
     * @param i the feature index.
     * @return the scaled value.
     */
    public double scale(double d, int i) {
        double d2 = (d - this.lo[i]) / this.hi[i];
        if (d2 < 0.0d) {
            d2 = 0.0d;
        }
        if (d2 > 1.0d) {
            return 1.0d;
        }
        return d2;
    }

    /**
     * Formats feature {@code i} for {@link #toString()}. This is a decompiler-generated lambda
     * body, kept public for compatibility.
     */
    public String m10004lambda$toString$1$smilefeatureScaler(int i) {
        return String.format("%s[%.4f, %.4f]", this.schema.field(i).name, Double.valueOf(this.lo[i]), Double.valueOf(this.hi[i]));
    }

    /**
     * Scales field {@code i} of {@code tuple}. This is a decompiler-generated lambda body,
     * kept public for compatibility.
     */
    public double m10005lambda$transform$0$smilefeatureScaler(int i, Tuple tuple) {
        return scale(tuple.getDouble(i), i);
    }

    /**
     * Returns a description of the scaler.
     *
     * <p>NOTE: the second number printed for each feature is the stored span
     * ({@code hi - lo}, or 1.0 when that was zero), not the raw upper bound.
     */
    @Override
    public String toString() {
        return IntStream.range(0, this.lo.length)
                .mapToObj(this::m10004lambda$toString$1$smilefeatureScaler)
                .collect(Collectors.joining(",", "Scaler(", ")"));
    }

    /**
     * Scales every numeric column of a data frame; non-numeric columns are reused as-is.
     *
     * @param dataFrame the input, whose schema must equal this scaler's schema.
     * @return a new data frame with the numeric columns scaled.
     * @throws IllegalArgumentException if the schema does not match.
     */
    @Override // smile.feature.FeatureTransform
    public DataFrame transform(DataFrame dataFrame) {
        if (!this.schema.equals(dataFrame.schema())) {
            throw new IllegalArgumentException(String.format("Invalid schema %s, expected %s", dataFrame.schema(), this.schema));
        }
        BaseVector[] columns = new BaseVector[this.schema.length()];
        for (int i = 0; i < this.lo.length; i++) {
            StructField field = this.schema.field(i);
            if (field.isNumeric()) {
                // The lambda needs an effectively-final copy of the loop index.
                final int column = i;
                columns[i] = DoubleVector.of(field,
                        dataFrame.stream().mapToDouble(row -> scale(row.getDouble(column), column)));
            } else {
                columns[i] = dataFrame.column(i);
            }
        }
        return DataFrame.of(columns);
    }

    /**
     * Returns a lazy view of {@code tuple} whose numeric fields are scaled on access.
     *
     * @param tuple the input, whose schema must equal this scaler's schema.
     * @return the scaled tuple view.
     * @throws IllegalArgumentException if the schema does not match.
     */
    @Override // smile.feature.FeatureTransform
    public Tuple transform(final Tuple tuple) {
        if (!this.schema.equals(tuple.schema())) {
            throw new IllegalArgumentException(String.format("Invalid schema %s, expected %s", tuple.schema(), this.schema));
        }
        return new AbstractTuple() {
            @Override // smile.data.Tuple
            public Object get(int i) {
                return Scaler.this.schema.field(i).isNumeric()
                        ? Double.valueOf(Scaler.this.scale(tuple.getDouble(i), i))
                        : tuple.get(i);
            }

            @Override // smile.data.Tuple
            public StructType schema() {
                return Scaler.this.schema;
            }
        };
    }

    /**
     * Scales a raw feature vector element-wise. Element {@code i} uses the bounds of
     * feature {@code i}.
     *
     * @param dArr the raw feature values.
     * @return a new array with the scaled values.
     */
    @Override // smile.feature.FeatureTransform
    public double[] transform(double[] dArr) {
        double[] scaled = new double[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            scaled[i] = scale(dArr[i], i);
        }
        return scaled;
    }
}
