package org.encog.ml.importance;

import org.encog.EncogError;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.randomize.generate.GenerateRandom;
import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom;
import org.encog.ml.MLContext;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.util.EngineArray;

/* loaded from: input_file:org/encog/ml/importance/PerturbationFeatureImportanceCalc.class */
/**
 * Ranks input features by perturbation: for each input column, the column's
 * values are shuffled across the rows of a dataset and the model's error on
 * the perturbed data is measured. That error is stored as the feature's total
 * weight, and each feature's importance percent is its weight divided by the
 * largest weight observed.
 *
 * <p>This calculation needs a dataset; the no-argument
 * {@link #performRanking()} always throws.</p>
 */
public class PerturbationFeatureImportanceCalc extends AbstractFeatureImportance {
    /** Random source used to shuffle a column; replaceable through {@link #setRnd}. */
    private GenerateRandom rnd = new MersenneTwisterGenerateRandom();

    /** Scratch buffer holding the column currently being shuffled, one slot per dataset row. */
    private double[] shuffleColumn;

    /**
     * Not supported by this algorithm.
     *
     * @throws EncogError always, because a dataset is required.
     */
    @Override // org.encog.ml.importance.FeatureImportance
    public void performRanking() {
        throw new EncogError("This algorithm requires a dataset to measure performance against, please call performRanking with a dataset.");
    }

    /**
     * Computes the model's error on {@code data} after shuffling one input column.
     *
     * <p>The shuffle is an incremental Fisher-Yates pass: at row {@code r} a swap
     * partner is drawn from the not-yet-used range {@code [r, rowCount)}, so the
     * values fed to the model form a permutation of the original column.
     * This assumes {@code rnd.nextInt(n)} returns a value in {@code [0, n)}.</p>
     *
     * @param data   the dataset to evaluate against
     * @param column index of the input column to perturb
     * @return the error reported by {@link ErrorCalculation#calculate()}, or
     *         {@code Double.NaN} if an {@link EncogError} is thrown during evaluation
     */
    private double calculateRegressionError(MLDataSet data, int column) {
        ErrorCalculation errorCalculation = new ErrorCalculation();
        if (getModel() instanceof MLContext) {
            ((MLContext) getModel()).clearContext();
        }

        // Snapshot the column so the dataset itself is never modified.
        for (int row = 0; row < data.size(); row++) {
            this.shuffleColumn[row] = data.get(row).getInput().getData(column);
        }

        BasicMLData perturbedInput = new BasicMLData(data.getInputSize());
        try {
            int rowCount = data.size();
            for (int row = 0; row < rowCount; row++) {
                MLDataPair pair = data.get(row);
                EngineArray.arrayCopy(pair.getInput().getData(), perturbedInput.getData());

                if (row != rowCount - 1) {
                    // Pick the partner from the unused tail [row, rowCount); without the
                    // "row +" offset, already-consumed slots would be swapped back in.
                    int swapWith = row + this.rnd.nextInt(rowCount - row);
                    double tmp = this.shuffleColumn[row];
                    this.shuffleColumn[row] = this.shuffleColumn[swapWith];
                    this.shuffleColumn[swapWith] = tmp;
                }
                // Applied to every row, including the last one, which needs no swap
                // but must still receive its shuffled value.
                perturbedInput.setData(column, this.shuffleColumn[row]);

                errorCalculation.updateError(
                        getModel().compute(perturbedInput).getData(),
                        pair.getIdeal().getData(),
                        pair.getSignificance());
            }
            return errorCalculation.calculate();
        } catch (EncogError e) {
            // Best effort: a failed evaluation is reported as NaN rather than aborting the ranking.
            return Double.NaN;
        }
    }

    /**
     * Ranks every model input by the error produced when that input is shuffled.
     *
     * <p>Each feature's total weight is set to its perturbed error. Importance
     * percent is the weight divided by the largest non-NaN weight. If that
     * maximum is zero, features get 0.0 instead of the NaN that {@code 0/0}
     * would produce. A feature whose evaluation failed keeps NaN.</p>
     *
     * @param mLDataSet the dataset to measure model performance against
     */
    @Override // org.encog.ml.importance.FeatureImportance
    public void performRanking(MLDataSet mLDataSet) {
        this.shuffleColumn = new double[mLDataSet.size()];
        double maxError = 0.0d;
        for (int column = 0; column < getModel().getInputCount(); column++) {
            FeatureRank rank = getFeatures().get(column);
            double error = calculateRegressionError(mLDataSet, column);
            rank.setTotalWeight(error);
            // Math.max would propagate NaN and poison every feature's percentage.
            if (!Double.isNaN(error)) {
                maxError = Math.max(maxError, error);
            }
        }
        for (FeatureRank rank : getFeatures()) {
            double weight = rank.getTotalWeight();
            double percent;
            if (maxError > 0.0d) {
                percent = weight / maxError;
            } else {
                // No positive error was observed: avoid 0/0, but keep NaN visible for failures.
                percent = Double.isNaN(weight) ? Double.NaN : 0.0d;
            }
            rank.setImportancePercent(percent);
        }
    }

    /**
     * @return the random number generator used for shuffling
     */
    public GenerateRandom getRnd() {
        return this.rnd;
    }

    /**
     * Replaces the random number generator used for shuffling.
     *
     * @param generateRandom the generator to use for subsequent rankings
     */
    public void setRnd(GenerateRandom generateRandom) {
        this.rnd = generateRandom;
    }
}
