package org.encog.ml.hmm.train.bw;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.MLSequenceSet;
import org.encog.ml.hmm.HiddenMarkovModel;
import org.encog.ml.hmm.alog.ForwardBackwardCalculator;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

/* loaded from: input_file:org/encog/ml/hmm/train/bw/BaseBaumWelch.class */
/**
 * Common Baum-Welch trainer for a {@link HiddenMarkovModel}.
 *
 * <p>Each call to {@link #iteration()} works on a clone of the current model and does the following:
 * <ul>
 *   <li>obtains a forward-backward calculator and an xi table for every training sequence (both
 *       supplied by the subclass);</li>
 *   <li>re-estimates the transition matrix, the initial-state probabilities (pi) and every state's
 *       observation distribution;</li>
 *   <li>replaces the trained model with that clone.</li>
 * </ul>
 *
 * <p>Strategies, pause/resume and error reporting are not supported: {@link #addStrategy} is
 * ignored, {@link #pause()} returns {@code null}, and {@link #getError()} always returns 0.
 */
public abstract class BaseBaumWelch implements MLTrain {
    /** Number of completed training iterations. */
    private int iterations;

    /** The model being trained; replaced by a re-estimated copy after every iteration. */
    private HiddenMarkovModel method;

    /** The observation sequences used for training. */
    private final MLSequenceSet training;

    /**
     * Creates a trainer.
     *
     * @param hiddenMarkovModel the initial model to train
     * @param mLSequenceSet     the training sequences
     */
    public BaseBaumWelch(final HiddenMarkovModel hiddenMarkovModel, final MLSequenceSet mLSequenceSet) {
        this.method = hiddenMarkovModel;
        this.training = mLSequenceSet;
    }

    /** Strategies are not supported by this trainer; the request is silently ignored. */
    @Override
    public void addStrategy(final Strategy strategy) {
    }

    /** @return always {@code false}: this trainer cannot be paused and resumed */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Derives the gamma table of one sequence from its xi table.
     *
     * <p>For each step {@code t < xi.length}, {@code gamma[t][i]} is the sum over {@code j} of
     * {@code xi[t][i][j]}. The extra final row {@code gamma[xi.length][j]} is the sum over
     * {@code i} of {@code xi[xi.length - 1][i][j]}, i.e. over the source state of each transition
     * into {@code j} at the last step.
     *
     * @param xi                        xi table indexed {@code [t][i][j]}; must contain at least one step
     * @param forwardBackwardCalculator calculator for the sequence; unused here, but passed so that
     *                                  overriding implementations can use it
     * @return a table with {@code xi.length + 1} rows of {@code xi[0].length} entries each
     * @throws IllegalArgumentException if {@code xi} contains no steps
     */
    protected double[][] estimateGamma(final double[][][] xi,
            final ForwardBackwardCalculator forwardBackwardCalculator) {
        if (xi.length == 0) {
            throw new IllegalArgumentException("xi must contain at least one time step");
        }
        final int steps = xi.length;
        final int stateCount = xi[0].length;
        // Java zero-initializes new arrays, so the accumulators start at 0.0.
        final double[][] gamma = new double[steps + 1][stateCount];

        for (int t = 0; t < steps; t++) {
            for (int i = 0; i < stateCount; i++) {
                for (int j = 0; j < stateCount; j++) {
                    gamma[t][i] += xi[t][i][j];
                }
            }
        }

        final double[][] lastXi = xi[steps - 1];
        for (int j = 0; j < stateCount; j++) {
            for (int i = 0; i < stateCount; i++) {
                gamma[steps][j] += lastXi[i][j];
            }
        }
        return gamma;
    }

    /**
     * Computes the xi table ({@code [t][i][j]}) of one sequence under the given model.
     *
     * @param mLDataSet                 the observation sequence
     * @param forwardBackwardCalculator the calculator produced for this sequence
     * @param hiddenMarkovModel         the model being re-estimated
     * @return the xi table
     */
    public abstract double[][][] estimateXi(MLDataSet mLDataSet,
            ForwardBackwardCalculator forwardBackwardCalculator, HiddenMarkovModel hiddenMarkovModel);

    /** No cleanup is required; does nothing. */
    @Override
    public void finishTraining() {
    }

    /**
     * Creates the forward-backward calculator for one sequence.
     *
     * @param mLDataSet         the observation sequence
     * @param hiddenMarkovModel the model being re-estimated
     * @return the calculator used by {@link #estimateXi} and {@link #estimateGamma}
     */
    public abstract ForwardBackwardCalculator generateForwardBackwardCalculator(MLDataSet mLDataSet,
            HiddenMarkovModel hiddenMarkovModel);

    /** @return always 0; this trainer does not compute an error value */
    @Override
    public double getError() {
        return 0.0d;
    }

    @Override
    public TrainingImplementationType getImplementationType() {
        return TrainingImplementationType.Iterative;
    }

    /** @return the number of iterations performed (or the last value passed to {@link #setIteration}) */
    @Override
    public int getIteration() {
        return this.iterations;
    }

    /** @return the current model; a new instance is installed after every {@link #iteration()} */
    @Override
    public MLMethod getMethod() {
        return this.method;
    }

    /** @return an empty, immutable list, because strategies are not supported */
    @Override
    public List<Strategy> getStrategies() {
        return Collections.emptyList();
    }

    @Override
    public MLDataSet getTraining() {
        return this.training;
    }

    /** @return always {@code false}; termination is left to the caller */
    @Override
    public boolean isTrainingDone() {
        return false;
    }

    /**
     * Performs one Baum-Welch re-estimation pass over all training sequences.
     *
     * <p>The current model is cloned. The clone's transitions, pi and state distributions are
     * re-estimated, and the clone then replaces the current model. The iteration counter is
     * incremented.
     *
     * @throws InternalError if the model does not support cloning
     */
    @Override
    public void iteration() {
        final HiddenMarkovModel next;
        try {
            // Presumably the decompiler's name for the covariant HiddenMarkovModel.clone().
            next = this.method.m1212clone();
        } catch (final CloneNotSupportedException e) {
            final InternalError error = new InternalError("HiddenMarkovModel must support cloning");
            error.initCause(e);
            throw error;
        }

        final int stateCount = this.method.getStateCount();
        final double[][] transitionNum = new double[stateCount][stateCount];
        final double[] transitionDen = new double[stateCount];
        final double[][][] allGamma = accumulateStatistics(transitionNum, transitionDen);

        reestimateTransitions(next, transitionNum, transitionDen);
        reestimatePi(next, allGamma);
        reestimateDistributions(next, allGamma);

        this.method = next;
        this.iterations++;
    }

    /**
     * Runs the forward-backward computation on every training sequence and accumulates the
     * transition re-estimation sums.
     *
     * @param transitionNum receives, per (i, j), the sum of {@code xi[t][i][j]} over all sequences
     *                      and all {@code t < size - 1}
     * @param transitionDen receives, per i, the sum of {@code gamma[t][i]} over all sequences and
     *                      all {@code t < size - 1}
     * @return the gamma table of each sequence, indexed in {@code getSequences()} order
     */
    private double[][][] accumulateStatistics(final double[][] transitionNum, final double[] transitionDen) {
        final int stateCount = this.method.getStateCount();
        final double[][][] allGamma = new double[this.training.getSequenceCount()][][];
        int sequenceIndex = 0;
        for (final MLDataSet sequence : this.training.getSequences()) {
            final ForwardBackwardCalculator fbc = generateForwardBackwardCalculator(sequence, this.method);
            final double[][][] xi = estimateXi(sequence, fbc, this.method);
            final double[][] gamma = estimateGamma(xi, fbc);
            allGamma[sequenceIndex++] = gamma;

            // Transitions only exist between consecutive observations, hence size() - 1.
            for (int t = 0; t < sequence.size() - 1; t++) {
                for (int i = 0; i < stateCount; i++) {
                    transitionDen[i] += gamma[t][i];
                    for (int j = 0; j < stateCount; j++) {
                        transitionNum[i][j] += xi[t][i][j];
                    }
                }
            }
        }
        return allGamma;
    }

    /**
     * Sets {@code target}'s transition probabilities to {@code num[i][j] / den[i]}.
     *
     * <p>A state whose denominator is exactly 0 keeps the current model's transition row.
     */
    private void reestimateTransitions(final HiddenMarkovModel target, final double[][] num, final double[] den) {
        final int stateCount = this.method.getStateCount();
        for (int i = 0; i < stateCount; i++) {
            for (int j = 0; j < stateCount; j++) {
                final double probability = (den[i] == 0.0d)
                        ? this.method.getTransitionProbability(i, j)
                        : num[i][j] / den[i];
                target.setTransitionProbability(i, j, probability);
            }
        }
    }

    /** Sets each pi entry of {@code target} to the mean of {@code gamma[0][i]} over all sequences. */
    private void reestimatePi(final HiddenMarkovModel target, final double[][][] allGamma) {
        final int stateCount = this.method.getStateCount();
        final int sequenceCount = this.training.getSequenceCount();
        for (int i = 0; i < stateCount; i++) {
            double pi = 0.0d;
            for (int o = 0; o < sequenceCount; o++) {
                pi += allGamma[o][0][i] / sequenceCount;
            }
            target.setPi(i, pi);
        }
    }

    /**
     * Refits every state's distribution of {@code target} on the whole training set.
     *
     * <p>Each observation is weighted by its gamma value for that state, and the weights are
     * normalized to sum to 1.
     */
    private void reestimateDistributions(final HiddenMarkovModel target, final double[][][] allGamma) {
        final int stateCount = this.method.getStateCount();
        for (int i = 0; i < stateCount; i++) {
            final double[] weights = new double[this.training.size()];
            double sum = 0.0d;
            int position = 0;
            int sequenceIndex = 0;
            for (final MLDataSet sequence : this.training.getSequences()) {
                for (int t = 0; t < sequence.size(); t++, position++) {
                    final double weight = allGamma[sequenceIndex][t][i];
                    weights[position] = weight;
                    sum += weight;
                }
                sequenceIndex++;
            }
            // NOTE(review): if sum is 0 these weights become NaN; behavior kept as-is.
            for (position--; position >= 0; position--) {
                weights[position] /= sum;
            }
            target.getStateDistribution(i).fit(this.training, weights);
        }
    }

    /**
     * Performs {@code i} consecutive iterations.
     *
     * @param i number of iterations to run
     */
    @Override
    public void iteration(final int i) {
        for (int n = 0; n < i; n++) {
            iteration();
        }
    }

    /** @return always {@code null}; pausing is not supported (see {@link #canContinue()}) */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /** Resuming is not supported; does nothing. */
    @Override
    public void resume(final TrainingContinuation trainingContinuation) {
    }

    /** The error is not tracked by this trainer; the value is ignored. */
    @Override
    public void setError(final double d) {
    }

    /** @param i the new value for the iteration counter */
    @Override
    public void setIteration(final int i) {
        this.iterations = i;
    }
}
