package opennlp.tools.ml.perceptron;

// NOTE(review): j$.util.Map (a D8 desugaring artifact) and java.util.Map share the simple name
// Map, which javac rejects (JLS 7.5.1); drop the j$ import once no code relies on Map.EL.
import j$.util.Map;
import java.io.IOException;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import opennlp.tools.ml.AbstractEventModelSequenceTrainer;
import opennlp.tools.ml.model.AbstractDataIndexer;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.ml.model.OnePassDataIndexer;
import opennlp.tools.ml.model.Sequence;
import opennlp.tools.ml.model.SequenceStream;
import opennlp.tools.ml.model.SequenceStreamEventStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: classes2.dex */
/**
 * Trains a {@link PerceptronModel} on whole sequences of events.
 *
 * <p>Each training pass tags every sequence with the current parameters. Whenever any event of
 * a sequence is mis-tagged, the reference sequence's feature values are added to the parameters
 * and the predicted sequence's feature values are subtracted. When averaging is enabled, a
 * running sum of the parameters is kept so that averaged weights can be returned instead.
 */
public class SimplePerceptronSequenceTrainer extends AbstractEventModelSequenceTrainer {

    /** Algorithm name accepted by {@link #validate()}. */
    public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";

    /** Slot of an {@link #updates} cell: parameter value (truncated to int) at its last update. */
    private static final int VALUE = 0;
    /** Slot of an {@link #updates} cell: 0-based iteration of the last update. */
    private static final int ITER = 1;
    /** Slot of an {@link #updates} cell: sequence index (within that iteration) of the last update. */
    private static final int EVENT = 2;

    private static final Logger logger = LoggerFactory.getLogger(SimplePerceptronSequenceTrainer.class);

    /** Running parameter sums; only allocated when averaging is enabled. */
    private MutableContext[] averageParams;
    private int iterations;
    private int numEvents;
    private int numOutcomes;
    private int numPreds;
    /** Number of sequences in one full pass over the stream. */
    private int numSequences;
    /** Outcome label to outcome index. */
    private Map<String, Integer> omap;
    private String[] outcomeLabels;
    /** Outcome index of every indexed event in stream order; trainingStats compares predictions to it. */
    private int[] outcomeList;
    /** Current (non-averaged) parameters, one context per predicate. */
    private MutableContext[] params;
    /** Predicate label to predicate index. */
    private Map<String, Integer> pmap;
    private String[] predLabels;
    private SequenceStream<Event> sequenceStream;
    /** Per [predicate][outcome]: {VALUE, ITER, EVENT} of the last update; only used when averaging. */
    private int[][][] updates;
    private boolean useAverage;

    /**
     * Runs {@code numIterations} training passes, then logs training-set accuracy of the
     * averaged parameters (when averaging) or of the raw parameters.
     */
    private void findParameters(int numIterations) throws IOException {
        logger.info("Performing {} iterations.\n", numIterations);
        for (int iteration = 1; iteration <= numIterations; iteration++) {
            nextIteration(iteration);
        }
        trainingStats(useAverage ? averageParams : params);
    }

    /**
     * Tags every sequence with {@code modelParams} and logs how many tagged outcomes match
     * {@link #outcomeList}, together with the resulting accuracy.
     */
    private void trainingStats(MutableContext[] modelParams) throws IOException {
        // The parameters do not change while evaluating, so one model serves every sequence.
        PerceptronModel model = new PerceptronModel(modelParams, predLabels, outcomeLabels);
        int numCorrect = 0;
        int eventIndex = 0; // position in outcomeList, running across all sequences
        sequenceStream.reset();
        Sequence sequence;
        while ((sequence = sequenceStream.read()) != null) {
            for (Event tagged : sequenceStream.updateContext(sequence, model)) {
                if (omap.get(tagged.getOutcome()) == outcomeList[eventIndex]) {
                    numCorrect++;
                }
                eventIndex++;
            }
        }
        // Divide in floating point: int / int would truncate the accuracy to 0 or 1
        // (and throw when numEvents is 0).
        logger.info(". ({}/{}) {}", numCorrect, numEvents, (double) numCorrect / numEvents);
    }

    @Override
    public AbstractModel doTrain(SequenceStream<Event> events) throws IOException {
        // The covariant-return bridge doTrain(SequenceStream) is generated by the compiler.
        return trainModel(getIterations(), events, getCutoff(),
            this.trainingParameters.getBooleanParameter("UseAverage", true));
    }

    /**
     * Equivalent to calling {@link #validate()} and mapping an
     * {@link IllegalArgumentException} to {@code false}.
     *
     * @deprecated retained for compatibility; {@link #validate()} also reports the reason.
     */
    @Override
    @Deprecated
    public boolean isValid() {
        try {
            validate();
            return true;
        } catch (IllegalArgumentException ignored) {
            return false;
        }
    }

    /**
     * Runs one training pass over the sequence stream.
     *
     * <p>Each sequence is tagged with the current parameters. If any event is mis-tagged, the
     * reference sequence's feature values are added to the parameters, the predicted sequence's
     * values are subtracted, and a fresh model is built for the remaining sequences of the pass.
     * With averaging enabled, the running sums are finalized after the last iteration.
     *
     * @param iteration the 1-based iteration number
     * @throws IOException if reading the sequence stream fails
     */
    public void nextIteration(int iteration) throws IOException {
        int iter = iteration - 1; // 0-based, as stored in the ITER slot of updates
        int numCorrect = 0;
        int sequenceIndex = 0;

        // Per outcome index: feature -> (reference count - predicted count) for one sequence.
        List<Map<String, Float>> featureCounts = new ArrayList<>(numOutcomes);
        for (int oi = 0; oi < numOutcomes; oi++) {
            featureCounts.add(new HashMap<>());
        }

        PerceptronModel model = new PerceptronModel(params, predLabels, outcomeLabels);
        sequenceStream.reset();
        Sequence sequence;
        while ((sequence = sequenceStream.read()) != null) {
            Event[] taggerEvents = sequenceStream.updateContext(sequence, model);
            Event[] events = sequence.getEvents();
            boolean mistagged = false;
            for (int ei = 0; ei < events.length; ei++) {
                if (taggerEvents[ei].getOutcome().equals(events[ei].getOutcome())) {
                    numCorrect++;
                } else {
                    mistagged = true;
                }
            }
            if (mistagged) {
                for (Map<String, Float> counts : featureCounts) {
                    counts.clear();
                }
                traceOutcomes("train: {}", events);
                accumulateFeatureCounts(featureCounts, events, false);
                traceOutcomes("test: {}", taggerEvents);
                accumulateFeatureCounts(featureCounts, taggerEvents, true);
                applyFeatureCounts(featureCounts, iter, sequenceIndex);
                // Later sequences in this pass are tagged with the updated parameters.
                model = new PerceptronModel(params, predLabels, outcomeLabels);
            }
            sequenceIndex++;
        }

        if (useAverage && iter == iterations - 1) {
            finalizeAverages(sequenceIndex);
        }
        logger.info("{}. ({}/{}) {}", iter, numCorrect, numEvents, (double) numCorrect / numEvents);
    }

    /** Logs the outcomes of {@code events}, space-separated, at trace level. */
    private static void traceOutcomes(String format, Event[] events) {
        if (logger.isTraceEnabled()) {
            StringBuilder outcomes = new StringBuilder();
            for (Event event : events) {
                outcomes.append(' ').append(event.getOutcome());
            }
            logger.trace(format, outcomes);
        }
    }

    /**
     * Adds ({@code subtract == false}) or subtracts each event's context feature values into
     * the count map of that event's outcome (a missing value array means weight 1). On the
     * subtracting pass, counts that reach exactly zero are dropped, so features on which the
     * reference and the prediction agree produce no update.
     */
    private void accumulateFeatureCounts(List<Map<String, Float>> featureCounts, Event[] events,
                                         boolean subtract) {
        for (Event event : events) {
            String[] context = event.getContext();
            float[] values = event.getValues();
            int oi = omap.get(event.getOutcome());
            Map<String, Float> counts = featureCounts.get(oi);
            for (int ci = 0; ci < context.length; ci++) {
                float value = values != null ? values[ci] : 1.0f;
                float delta = subtract ? -value : value;
                Float current = counts.get(context[ci]);
                float updated = current == null ? delta : current + delta;
                if (subtract && updated == 0.0f) {
                    counts.remove(context[ci]);
                } else {
                    counts.put(context[ci], updated);
                }
            }
        }
    }

    /**
     * Adds each accumulated feature count to the matching parameter. When averaging, it also
     * credits the parameter's previous value for every sequence it was in effect and records
     * the new value and position in {@link #updates}.
     */
    private void applyFeatureCounts(List<Map<String, Float>> featureCounts, int iter, int sequenceIndex) {
        for (int oi = 0; oi < numOutcomes; oi++) {
            for (Map.Entry<String, Float> entry : featureCounts.get(oi).entrySet()) {
                Integer pi = pmap.get(entry.getKey());
                if (pi == null) {
                    continue; // not an indexed predicate (presumably removed by the indexer cutoff)
                }
                float delta = entry.getValue();
                if (logger.isTraceEnabled()) {
                    logger.trace("{} {} {} {}", sequenceIndex, outcomeLabels[oi], entry.getKey(), delta);
                }
                params[pi].updateParameter(oi, delta);
                if (!useAverage) {
                    continue;
                }
                int[] last = updates[pi][oi];
                if (last[VALUE] != 0) {
                    // Computed in double: numSequences * iterations can overflow int on large corpora.
                    double span = (double) numSequences * (iter - last[ITER]) + (sequenceIndex - last[EVENT]);
                    averageParams[pi].updateParameter(oi, last[VALUE] * span);
                    if (logger.isTraceEnabled()) {
                        logger.trace("p avp[{}].{}={}", pi, oi, averageParams[pi].getParameters()[oi]);
                    }
                }
                if (logger.isTraceEnabled()) {
                    logger.trace("p updates[{}][{}]=({},{},{})({},{},{}) -> {}", pi, oi, last[ITER],
                        last[EVENT], last[VALUE], iter, sequenceIndex,
                        params[pi].getParameters()[oi], averageParams[pi].getParameters()[oi]);
                }
                // NOTE(review): the value is truncated to int, so averaging is exact only for
                // integer-valued parameters; kept because updates is an int array.
                last[VALUE] = (int) params[pi].getParameters()[oi];
                last[ITER] = iter;
                last[EVENT] = sequenceIndex;
            }
        }
    }

    /**
     * Credits every parameter for the sequences between its last update and the end of
     * training, then divides the running sums by {@code iterations * sequencesPerIteration}.
     */
    private void finalizeAverages(int sequencesPerIteration) {
        // Widen before multiplying so the product cannot overflow int.
        double totalVisits = (double) iterations * sequencesPerIteration;
        for (int pi = 0; pi < numPreds; pi++) {
            double[] predParams = averageParams[pi].getParameters();
            for (int oi = 0; oi < numOutcomes; oi++) {
                int[] last = updates[pi][oi];
                if (last[VALUE] != 0) {
                    predParams[oi] += last[VALUE]
                        * ((double) numSequences * (iterations - last[ITER]) - last[EVENT]);
                }
                if (predParams[oi] != 0) {
                    predParams[oi] /= totalVisits;
                    averageParams[pi].setParameter(oi, predParams[oi]);
                    if (logger.isTraceEnabled()) {
                        logger.trace("updates[{}][{}]=({},{},{})({},{},{}) -> {}", pi, oi, last[ITER],
                            last[EVENT], last[VALUE], iterations, 0,
                            params[pi].getParameters()[oi], averageParams[pi].getParameters()[oi]);
                    }
                }
            }
        }
    }

    /**
     * Indexes the sequence stream, trains for {@code iterations} passes and returns the model.
     *
     * @param iterations     number of training passes
     * @param sequenceStream training sequences; it is read several times and reset between reads
     * @param cutoff         stored as the {@code Cutoff} training parameter before indexing
     * @param useAverage     whether to maintain and return averaged parameters
     * @return a model over the averaged parameters when {@code useAverage}, else the raw ones
     * @throws IOException if reading the sequence stream fails
     */
    public AbstractModel trainModel(int iterations, SequenceStream<Event> sequenceStream, int cutoff,
                                    boolean useAverage) throws IOException {
        this.iterations = iterations;
        this.sequenceStream = sequenceStream;
        this.useAverage = useAverage;

        this.trainingParameters.put("Cutoff", cutoff);
        this.trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false);
        OnePassDataIndexer indexer = new OnePassDataIndexer();
        indexer.init(this.trainingParameters, this.reportMap);
        indexer.index(new SequenceStreamEventStream(sequenceStream));

        numSequences = 0;
        sequenceStream.reset();
        while (sequenceStream.read() != null) {
            numSequences++;
        }

        outcomeList = indexer.getOutcomeList();
        predLabels = indexer.getPredLabels();
        pmap = indexLabels(predLabels);
        logger.info("Incorporating indexed data for training... ");
        numEvents = indexer.getNumEvents();
        outcomeLabels = indexer.getOutcomeLabels();
        omap = indexLabels(outcomeLabels);
        numPreds = predLabels.length;
        numOutcomes = outcomeLabels.length;
        if (useAverage) {
            updates = new int[numPreds][numOutcomes][3];
        }
        logger.info("done.");
        logger.info("\tNumber of Event Tokens: {} \n\t Number of Outcomes: {} \n\t Number of Predicates: {}",
            numEvents, numOutcomes, numPreds);

        // Every predicate context covers all outcomes; the index array is shared between contexts.
        int[] allOutcomes = new int[numOutcomes];
        for (int oi = 0; oi < numOutcomes; oi++) {
            allOutcomes[oi] = oi;
        }
        params = new MutableContext[numPreds];
        if (useAverage) {
            averageParams = new MutableContext[numPreds];
        }
        for (int pi = 0; pi < numPreds; pi++) {
            // new double[] is zero-filled, so no explicit initialisation is needed.
            params[pi] = new MutableContext(allOutcomes, new double[numOutcomes]);
            if (useAverage) {
                averageParams[pi] = new MutableContext(allOutcomes, new double[numOutcomes]);
            }
        }

        logger.info("Computing model parameters...");
        findParameters(iterations);
        logger.info("...done.");
        return new PerceptronModel(useAverage ? averageParams : params, predLabels, outcomeLabels);
    }

    /** Maps each label to its position in {@code labels} (a later duplicate wins). */
    private static Map<String, Integer> indexLabels(String[] labels) {
        Map<String, Integer> index = new HashMap<>();
        for (int i = 0; i < labels.length; i++) {
            index.put(labels[i], i);
        }
        return index;
    }

    /**
     * Validates the training parameters via the superclass. Any algorithm name other than
     * {@link #PERCEPTRON_SEQUENCE_VALUE} is rejected; an absent (null) name is accepted.
     *
     * @throws IllegalArgumentException if the parameters are invalid
     */
    @Override
    public void validate() {
        super.validate();
        String algorithm = getAlgorithm();
        if (algorithm != null && !PERCEPTRON_SEQUENCE_VALUE.equals(algorithm)) {
            throw new IllegalArgumentException("algorithmName must be PERCEPTRON_SEQUENCE");
        }
    }
}
