package org.encog.ml.train.strategy.end;

import java.io.Serializable;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.MLTrain;
import org.encog.util.obj.SerializeObject;
import org.encog.util.simple.EncogUtility;

/**
 * An {@link EndTrainingStrategy} that implements early stopping against a separate validation set.
 *
 * <p>After every training iteration the strategy records the trainer's current error. On the first
 * iteration after {@link #init(MLTrain)}, and afterwards whenever more than {@code checkFrequency}
 * iterations have elapsed since the previous check, it evaluates the model on the validation set
 * with {@link EncogUtility#calculateRegressionError}.
 *
 * <p>A check counts as an improvement when it is the first check, or when the validation error is
 * lower than the best value seen so far by at least {@link #getMinimumImprovement()}. The iterations
 * between checks without such an improvement accumulate as stagnant iterations. Once more than
 * {@code allowedStagnantIterations} have accumulated, {@link #shouldStop()} returns {@code true}.
 * Training is also stopped immediately if the validation error is infinite or NaN.
 *
 * <p>When {@link #setSaveBest(boolean) saveBest} is enabled, a copy of the model is taken with
 * {@link SerializeObject#serializeClone} each time the validation error improves. In that case the
 * trained method must implement {@link Serializable}.
 */
public class EarlyStoppingStrategy implements EndTrainingStrategy {
    /** Data set used only to measure the validation error. */
    private MLDataSet validationSet;
    /** The trainer this strategy is attached to (set in {@link #init(MLTrain)}). */
    private MLTrain train;
    /** Whether training should end. */
    private boolean stop;
    /** Trainer error recorded after the most recent iteration. */
    private double trainingError;
    /** Validation error from the most recent check; +infinity means "not checked yet". */
    private double lastValidationError;
    /** The method being trained, viewed as a regression model. */
    private MLRegression model;
    /** A check runs once more than this many iterations have passed since the last one. */
    private int checkFrequency;
    /** Iterations elapsed since the last validation check. */
    private int lastCheck;
    /** Stop once more than this many stagnant iterations have accumulated. */
    private int allowedStagnantIterations;
    /** Iterations accumulated since the last improvement. */
    private int stagnantIterations;
    /** Copy of the model at the best validation error (only when {@link #saveBest} is set). */
    private MLRegression bestModel;
    /** Whether to keep a copy of the best model. */
    private boolean saveBest;
    /** Lowest validation error seen so far; +infinity until the first successful check. */
    private double bestValidationError = Double.POSITIVE_INFINITY;
    /** Smallest decrease in validation error that counts as an improvement. */
    private double minimumImprovement = 1.0E-13d;

    /**
     * Creates a strategy with a check frequency of 5 and 50 allowed stagnant iterations.
     *
     * @param validationSet the data set used to measure the validation error
     */
    public EarlyStoppingStrategy(MLDataSet validationSet) {
        this(validationSet, 5, 50);
    }

    /**
     * Creates a strategy with explicit check and patience settings.
     *
     * @param validationSet the data set used to measure the validation error
     * @param checkFrequency a validation check runs once more than this many iterations have
     *     elapsed since the previous check
     * @param allowedStagnantIterations training stops once more than this many iterations have
     *     passed without an improvement in validation error
     */
    public EarlyStoppingStrategy(MLDataSet validationSet, int checkFrequency, int allowedStagnantIterations) {
        this.validationSet = validationSet;
        this.checkFrequency = checkFrequency;
        this.allowedStagnantIterations = allowedStagnantIterations;
    }

    /**
     * Attaches this strategy to a trainer and resets all per-run state.
     *
     * @param trainer the trainer; its method must be an {@link MLRegression}
     */
    @Override
    public void init(MLTrain trainer) {
        this.train = trainer;
        this.model = (MLRegression) trainer.getMethod();
        this.stop = false;
        this.lastCheck = 0;
        // +infinity marks "no validation check performed yet"; the next iteration always checks.
        this.lastValidationError = Double.POSITIVE_INFINITY;
        // Reset the search state so a fresh run never compares against stale results.
        this.bestValidationError = Double.POSITIVE_INFINITY;
        this.stagnantIterations = 0;
        this.bestModel = null;
    }

    /** Does nothing; all work happens in {@link #postIteration()}. */
    @Override
    public void preIteration() {
    }

    /**
     * Records the training error and, when a check is due, evaluates the validation error and
     * updates the best-so-far / stagnation bookkeeping.
     */
    @Override
    public void postIteration() {
        this.lastCheck++;
        this.trainingError = this.train.getError();

        boolean firstCheck = Double.isInfinite(this.lastValidationError);
        if (!firstCheck && this.lastCheck <= this.checkFrequency) {
            return;
        }

        double validationError = EncogUtility.calculateRegressionError(this.model, this.validationSet);
        if (Double.isInfinite(validationError) || Double.isNaN(validationError)) {
            // A non-finite validation error is unusable; stop immediately.
            this.stop = true;
        } else if (firstCheck || isSignificantImprovement(validationError)) {
            if (this.saveBest) {
                this.bestModel = (MLRegression) SerializeObject.serializeClone((Serializable) this.model);
            }
            this.bestValidationError = validationError;
            this.stagnantIterations = 0;
        } else {
            // Every training iteration since the previous check counts as stagnant.
            this.stagnantIterations += this.lastCheck;
            if (this.stagnantIterations > this.allowedStagnantIterations) {
                this.stop = true;
            }
        }
        this.lastValidationError = validationError;
        this.lastCheck = 0;
    }

    /**
     * Decides whether a validation error beats the best one seen so far by a meaningful margin.
     *
     * <p>The decrease must be strictly positive (so the best error never gets worse, even when
     * {@code minimumImprovement <= 0}) and at least {@code minimumImprovement}.
     *
     * @param validationError a finite validation error
     * @return {@code true} if it counts as an improvement
     */
    private boolean isSignificantImprovement(double validationError) {
        double improvement = this.bestValidationError - validationError;
        return improvement > 0.0d && improvement >= this.minimumImprovement;
    }

    /** @return {@code true} once early stopping has triggered */
    @Override
    public boolean shouldStop() {
        return this.stop;
    }

    /** @return the trainer error recorded after the most recent iteration */
    public double getTrainingError() {
        return this.trainingError;
    }

    /** @return the validation error from the most recent check (+infinity before the first check) */
    public double getValidationError() {
        return this.lastValidationError;
    }

    /** @return the number of iterations accumulated since the last improvement */
    public int getStagnantIterations() {
        return this.stagnantIterations;
    }

    /** @param stagnantIterations overrides the accumulated stagnant-iteration count */
    public void setStagnantIterations(int stagnantIterations) {
        this.stagnantIterations = stagnantIterations;
    }

    /** @return the stagnant-iteration count above which training stops */
    public int getAllowedStagnantIterations() {
        return this.allowedStagnantIterations;
    }

    /** @param allowedStagnantIterations the stagnant-iteration count above which training stops */
    public void setAllowedStagnantIterations(int allowedStagnantIterations) {
        this.allowedStagnantIterations = allowedStagnantIterations;
    }

    /** @return whether a copy of the best model is kept */
    public boolean isSaveBest() {
        return this.saveBest;
    }

    /**
     * @param saveBest whether to keep a copy of the model each time the validation error improves;
     *     the trained method must implement {@link Serializable} when enabled
     */
    public void setSaveBest(boolean saveBest) {
        this.saveBest = saveBest;
    }

    /** @return the saved best model, or {@code null} if none has been saved */
    public MLRegression getBestModel() {
        return this.bestModel;
    }

    /** @return the lowest validation error seen so far (+infinity before the first check) */
    public double getBestValidationError() {
        return this.bestValidationError;
    }

    /** @return the smallest decrease in validation error that counts as an improvement */
    public double getMinimumImprovement() {
        return this.minimumImprovement;
    }

    /** @param minimumImprovement the smallest decrease in validation error that counts as an improvement */
    public void setMinimumImprovement(double minimumImprovement) {
        this.minimumImprovement = minimumImprovement;
    }
}
