package org.encog.mathutil.error;

/* loaded from: input_file:org/encog/mathutil/error/ErrorCalculation.class */
public class ErrorCalculation {

    /**
     * Error calculation mode used by {@link #calculate()} and the {@code updateError} methods.
     * This is a static, unsynchronized field, so changing it affects every instance, including
     * instances that are part-way through accumulating errors.
     */
    private static ErrorCalculationMode mode = ErrorCalculationMode.MSE;

    /**
     * Running total: sum of (weighted) squared deltas, or of log terms in the log-loss modes.
     */
    private double globalError;

    /**
     * Number of accumulated values (one per sample in the log-loss modes of the array update).
     */
    private int setSize;

    /** Running sum of ideal values since the last reset; used by the mean-normalized RMSE. */
    private double sum;

    /** Smallest actual value seen since the last reset; +infinity when none has been seen. */
    private double min = Double.POSITIVE_INFINITY;

    /** Largest actual value seen since the last reset; -infinity when none has been seen. */
    private double max = Double.NEGATIVE_INFINITY;

    /**
     * Returns the error calculation mode shared by all instances.
     *
     * @return the current global mode
     */
    public static ErrorCalculationMode getMode() {
        return mode;
    }

    /**
     * Sets the error calculation mode shared by all instances.
     *
     * @param errorCalculationMode the new global mode
     */
    public static void setMode(ErrorCalculationMode errorCalculationMode) {
        mode = errorCalculationMode;
    }

    /**
     * Calculates the error for the accumulated values using the current global mode.
     * Modes not handled explicitly fall back to MSE.
     *
     * @return the error, or 0 if nothing has been accumulated since the last reset
     */
    public final double calculate() {
        if (this.setSize == 0) {
            return 0.0d;
        }
        switch (getMode()) {
            case RMS:
                return calculateRMS();
            case MSE:
                return calculateMSE();
            case ESS:
                return calculateESS();
            case LOGLOSS:
            case HOT_LOGLOSS:
                return calculateLogLoss();
            case NRMSE_MEAN:
                return calculateMeanNRMSE();
            case NRMSE_RANGE:
                return calculateRangeNRMSE();
            default:
                return calculateMSE();
        }
    }

    /**
     * Calculates the mean squared error.
     *
     * @return accumulated squared error divided by the number of values, or 0 if empty
     */
    public final double calculateMSE() {
        if (this.setSize == 0) {
            return 0.0d;
        }
        return this.globalError / this.setSize;
    }

    /**
     * Calculates the "ESS" error: half of the accumulated squared error.
     * Note that this is not divided by the number of values.
     *
     * @return accumulated squared error divided by 2, or 0 if empty
     */
    public final double calculateESS() {
        if (this.setSize == 0) {
            return 0.0d;
        }
        return this.globalError / 2.0d;
    }

    /**
     * Calculates the RMS error normalized by the mean of the ideal values.
     * If that mean is 0, the result is infinite or NaN.
     *
     * @return RMS divided by the mean ideal value, or 0 if empty
     */
    public final double calculateMeanNRMSE() {
        if (this.setSize == 0) {
            // Without this guard the result would be 0 / (0 / 0) = NaN.
            return 0.0d;
        }
        return calculateRMS() / (this.sum / this.setSize);
    }

    /**
     * Calculates the RMS error normalized by the range (max - min) of the actual values.
     * If all actual values are equal, the range is 0 and the result is infinite or NaN.
     *
     * @return RMS divided by the range of actual values, or 0 if empty
     */
    public final double calculateRangeNRMSE() {
        if (this.setSize == 0) {
            return 0.0d;
        }
        return calculateRMS() / (this.max - this.min);
    }

    /**
     * Calculates the root mean squared error.
     *
     * @return square root of the MSE, or 0 if empty
     */
    public final double calculateRMS() {
        if (this.setSize == 0) {
            return 0.0d;
        }
        return Math.sqrt(this.globalError / this.setSize);
    }

    /**
     * Calculates the log loss: the negated accumulated log terms averaged over the number of
     * accumulated entries.
     *
     * @return the log loss, or 0 if empty
     */
    public final double calculateLogLoss() {
        if (this.setSize == 0) {
            // Without this guard the result would be 0 * -Infinity = NaN.
            return 0.0d;
        }
        return this.globalError * ((-1.0d) / this.setSize);
    }

    /**
     * Clears all accumulated state so a new set of errors can be collected.
     * The global mode is not changed.
     */
    public final void reset() {
        this.globalError = 0.0d;
        this.setSize = 0;
        // sum/min/max must be cleared too, otherwise the NRMSE variants mix in earlier data.
        this.sum = 0.0d;
        this.min = Double.POSITIVE_INFINITY;
        this.max = Double.NEGATIVE_INFINITY;
    }

    /**
     * Accumulates the error for a single actual/ideal pair.
     *
     * <p>In the LOGLOSS and HOT_LOGLOSS modes this adds {@code log(actual) * ideal}. In every
     * other mode it adds the squared difference, adds {@code ideal} to the running sum, and
     * widens the tracked range of actual values.
     *
     * @param actual the produced value
     * @param ideal the expected value
     */
    public final void updateError(double actual, double ideal) {
        final ErrorCalculationMode current = getMode();
        if (current == ErrorCalculationMode.LOGLOSS || current == ErrorCalculationMode.HOT_LOGLOSS) {
            this.globalError += Math.log(actual) * ideal;
            this.setSize++;
            return;
        }
        final double delta = ideal - actual;
        this.globalError += delta * delta;
        this.sum += ideal;
        trackRange(actual);
        this.setSize++;
    }

    /**
     * Accumulates the error for one sample made of parallel actual/ideal arrays.
     *
     * <ul>
     *   <li>HOT_LOGLOSS: adds {@code log(actual[i]) * ideal[i]} for every {@code ideal[i]} above
     *       1e-13, and counts the sample once.</li>
     *   <li>LOGLOSS: treats {@code ideal[0]} as an index into {@code actual}, adds the log of
     *       that entry, and counts the sample once.</li>
     *   <li>Otherwise: adds each squared delta scaled by {@code significance}, sums the ideal
     *       values, widens the range of actual values, and counts {@code ideal.length} values.</li>
     * </ul>
     *
     * @param actual the produced values
     * @param ideal the expected values (or, in LOGLOSS mode, the expected index in element 0)
     * @param significance weight applied to each delta before it is squared
     */
    public final void updateError(double[] actual, double[] ideal, double significance) {
        final ErrorCalculationMode current = getMode();
        if (current == ErrorCalculationMode.HOT_LOGLOSS) {
            this.setSize++;
            for (int i = 0; i < actual.length; i++) {
                // Skipping (near-)zero targets avoids log(0) * 0 = NaN. Otherwise those terms
                // contribute approximately nothing.
                if (ideal[i] > 1.0E-13d) {
                    this.globalError += Math.log(actual[i]) * ideal[i];
                }
            }
            return;
        }
        if (current == ErrorCalculationMode.LOGLOSS) {
            this.setSize++;
            // Only the entry selected by ideal[0] contributes to the loss.
            this.globalError += Math.log(actual[(int) ideal[0]]);
            return;
        }
        for (int i = 0; i < actual.length; i++) {
            final double delta = (ideal[i] - actual[i]) * significance;
            this.globalError += delta * delta;
            this.sum += ideal[i];
            // Previously min/max were re-seeded for every element of the first call (setSize is
            // only bumped after the loop), so only the last element survived.
            trackRange(actual[i]);
        }
        this.setSize += ideal.length;
    }

    /**
     * Widens the tracked [min, max] range of actual values to include {@code value}.
     * The sentinels start at +/-infinity, so the first value becomes both bounds.
     *
     * @param value the actual value to include
     */
    private void trackRange(final double value) {
        this.min = Math.min(value, this.min);
        this.max = Math.max(value, this.max);
    }
}
