package org.encog.neural.networks.training.propagation.sgd;

import com.github.sarxos.webcam.WebcamMotionDetector;
import java.util.Iterator;
import org.encog.EncogError;
import org.encog.mathutil.randomize.generate.GenerateRandom;
import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;

/* loaded from: input_file:org/encog/neural/networks/training/propagation/sgd/BatchDataSet.class */
/**
 * A read-only view over another {@link MLDataSet} that exposes only one batch
 * of its records at a time.
 *
 * <p>Two modes are supported:
 * <ul>
 *   <li><b>Sequential</b> (default): the batch is a window of {@code batchSize}
 *       consecutive records starting at {@code currentIndex}, wrapping around
 *       the end of the backing set. {@link #advance()} slides the window
 *       forward by one batch.</li>
 *   <li><b>Random</b>: the batch is a sample of {@code batchSize} distinct
 *       record indices. {@link #advance()} draws a fresh sample.</li>
 * </ul>
 *
 * <p>All {@code add} methods throw {@link EncogError}; the view cannot be
 * modified.
 */
public class BatchDataSet implements MLDataSet {
    /**
     * Batch size requested by the constructor. The decompiled source took this
     * value from {@code WebcamMotionDetector.DEFAULT_INTERVAL}, a constant of an
     * unrelated webcam library; it is spelled out here so the class does not
     * depend on that library.
     * NOTE(review): 500 is the value of that constant as far as known - confirm.
     */
    private static final int DEFAULT_BATCH_SIZE = 500;

    /** The backing data set that batches are taken from. */
    private final MLDataSet dataset;

    /** Start of the sequential window inside the backing set. */
    private int currentIndex;

    /** Number of records in one batch; never larger than the backing set at the time it was set. */
    private int batchSize;

    /** Source of randomness for random batches and for seeding additional views. */
    private final GenerateRandom random;

    /** {@code true} when batches are random samples instead of sequential windows. */
    private boolean randomBatches;

    /** Backing-set indices of the current random batch; length equals {@code batchSize}. */
    private int[] randomSample;

    /**
     * Iterates over the records of the current batch, in batch order.
     */
    public class BatchedMLIterator implements Iterator<MLDataPair> {
        /** Offset inside the batch of the next record to return. */
        private int currentIndex = 0;

        public BatchedMLIterator() {
        }

        /**
         * @return {@code true} while fewer than {@code getBatchSize()} records
         *         have been returned.
         */
        @Override // java.util.Iterator
        public final boolean hasNext() {
            return this.currentIndex < BatchDataSet.this.getBatchSize();
        }

        /**
         * Returns the next record of the batch.
         *
         * @return the next record, or {@code null} once the batch is exhausted.
         *         Returning {@code null} instead of throwing is kept on purpose
         *         so existing callers that rely on it keep working.
         */
        @Override // java.util.Iterator
        public final MLDataPair next() {
            if (!hasNext()) {
                return null;
            }
            final int offset = this.currentIndex;
            this.currentIndex = offset + 1;
            return BatchDataSet.this.get(offset);
        }

        /**
         * Always fails; the view is read-only.
         *
         * @throws EncogError always
         */
        @Override // java.util.Iterator
        public final void remove() {
            throw new EncogError("Called remove, unsupported operation.");
        }
    }

    /**
     * Creates a sequential batch view over {@code mLDataSet} with the default
     * batch size (capped at the size of the backing set).
     *
     * @param mLDataSet      the backing data set
     * @param generateRandom random generator used for random batches and for
     *                       seeding views created by {@link #openAdditional()}
     */
    public BatchDataSet(MLDataSet mLDataSet, GenerateRandom generateRandom) {
        this.dataset = mLDataSet;
        this.random = generateRandom;
        setBatchSize(DEFAULT_BATCH_SIZE);
    }

    /**
     * Sets the batch size, capped at the current size of the backing set, and
     * reallocates the random-sample buffer. In random mode a new sample is
     * drawn immediately.
     *
     * @param i the requested batch size
     */
    public void setBatchSize(int i) {
        this.batchSize = Math.min(i, this.dataset.size());
        this.randomSample = new int[this.batchSize];
        if (this.randomBatches) {
            generaterandomSample();
        }
    }

    /**
     * @return the number of records in one batch
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * @return a new iterator over the records of the current batch
     */
    @Override // java.lang.Iterable
    public Iterator<MLDataPair> iterator() {
        return new BatchedMLIterator();
    }

    /**
     * @return the ideal size reported by the backing set
     */
    @Override // org.encog.ml.data.MLDataSet
    public int getIdealSize() {
        return this.dataset.getIdealSize();
    }

    /**
     * @return the input size reported by the backing set
     */
    @Override // org.encog.ml.data.MLDataSet
    public int getInputSize() {
        return this.dataset.getInputSize();
    }

    /**
     * @return whether the backing set reports itself as supervised
     */
    @Override // org.encog.ml.data.MLDataSet
    public boolean isSupervised() {
        return this.dataset.isSupervised();
    }

    /**
     * @return the batch size (not the size of the backing set)
     */
    @Override // org.encog.ml.data.MLDataSet
    public long getRecordCount() {
        return this.batchSize;
    }

    /**
     * Copies the record at batch offset {@code j} into {@code mLDataPair},
     * using the same offset-to-record mapping as {@link #get(int)}.
     *
     * @param j          offset inside the current batch
     * @param mLDataPair destination handed to the backing set's {@code getRecord}
     */
    @Override // org.encog.ml.data.MLDataSet
    public void getRecord(long j, MLDataPair mLDataPair) {
        this.dataset.getRecord(resolveIndex(j), mLDataPair);
    }

    /**
     * Creates another batch view over the same backing set with the same batch
     * size and its own random generator, seeded from this view's generator.
     *
     * <p>NOTE(review): the new view starts in sequential mode at index 0; the
     * current index and the random-batch flag are not carried over.
     *
     * @return the new view
     */
    @Override // org.encog.ml.data.MLDataSet
    public MLDataSet openAdditional() {
        BatchDataSet additional =
                new BatchDataSet(this.dataset, new MersenneTwisterGenerateRandom(this.random.nextLong()));
        additional.setBatchSize(getBatchSize());
        return additional;
    }

    /**
     * Not supported.
     *
     * @throws EncogError always
     */
    @Override // org.encog.ml.data.MLDataSet
    public void add(MLData mLData) {
        throw new EncogError("Unsupported.");
    }

    /**
     * Not supported.
     *
     * @throws EncogError always
     */
    @Override // org.encog.ml.data.MLDataSet
    public void add(MLData mLData, MLData mLData2) {
        throw new EncogError("Unsupported.");
    }

    /**
     * Not supported.
     *
     * @throws EncogError always
     */
    @Override // org.encog.ml.data.MLDataSet
    public void add(MLDataPair mLDataPair) {
        throw new EncogError("Unsupported.");
    }

    /** Does nothing; the backing set is not closed by this view. */
    @Override // org.encog.ml.data.MLDataSet
    public void close() {
    }

    /**
     * @return the batch size (not the size of the backing set)
     */
    @Override // org.encog.ml.data.MLDataSet
    public int size() {
        return this.batchSize;
    }

    /**
     * Returns the record at batch offset {@code i}.
     *
     * <p>In sequential mode the offset is added to the current index and
     * wrapped around the backing set. In random mode the offset selects an
     * entry of the current random sample, so it must lie in
     * {@code [0, getBatchSize())}.
     *
     * @param i offset inside the current batch
     * @return the record the backing set returns for the resolved index
     */
    @Override // org.encog.ml.data.MLDataSet
    public MLDataPair get(int i) {
        // The resolved value is either a sample entry or a remainder modulo
        // dataset.size(), so it always fits in an int.
        return this.dataset.get((int) resolveIndex(i));
    }

    /**
     * Moves to the next batch: draws a new random sample in random mode,
     * otherwise slides the sequential window forward by one batch, wrapping
     * around the backing set.
     */
    public void advance() {
        if (this.randomBatches) {
            generaterandomSample();
        } else {
            this.currentIndex = (this.currentIndex + this.batchSize) % this.dataset.size();
        }
    }

    /**
     * @return the start of the sequential window inside the backing set
     */
    public int getCurrentIndex() {
        return this.currentIndex;
    }

    /**
     * @param i the new start of the sequential window inside the backing set
     */
    public void setCurrentIndex(int i) {
        this.currentIndex = i;
    }

    /**
     * @return {@code true} when batches are random samples
     */
    public boolean isRandomBatches() {
        return this.randomBatches;
    }

    /**
     * Switches between random and sequential batches.
     *
     * <p>When random mode is turned on, a sample is drawn right away.
     * Otherwise the sample buffer would still hold its initial zeros and every
     * batch slot would map to record 0 until the next {@link #advance()}.
     *
     * @param z {@code true} for random batches, {@code false} for sequential
     */
    public void setRandomBatches(boolean z) {
        final boolean turningOn = z && !this.randomBatches;
        this.randomBatches = z;
        if (turningOn) {
            generaterandomSample();
        }
    }

    /**
     * Maps an offset inside the current batch to an index in the backing set.
     *
     * @param offset offset inside the current batch
     * @return the backing-set index for that offset
     */
    private long resolveIndex(long offset) {
        if (this.randomBatches) {
            // The sample has exactly batchSize entries and is addressed by the
            // in-batch offset alone; currentIndex only applies to sequential mode.
            return this.randomSample[(int) offset];
        }
        return (offset + this.currentIndex) % this.dataset.size();
    }

    /**
     * Fills {@code randomSample} with {@code batchSize} distinct indices drawn
     * from {@code random.nextInt(0, dataset.size())}, redrawing whenever a
     * candidate is already part of the sample.
     *
     * @throws EncogError if the batch is larger than the backing set, in which
     *                    case enough distinct indices cannot exist
     */
    private void generaterandomSample() {
        final int total = this.dataset.size();
        if (this.batchSize > total) {
            // Only possible if the backing set shrank after setBatchSize();
            // without this check the redraw loop below could never finish.
            throw new EncogError("Batch size " + this.batchSize
                    + " exceeds data set size " + total + ".");
        }
        for (int filled = 0; filled < this.batchSize; filled++) {
            int candidate;
            do {
                // presumably the upper bound is exclusive - TODO confirm against GenerateRandom
                candidate = this.random.nextInt(0, total);
                // Uniqueness is re-evaluated for every candidate, so a single
                // duplicate draw cannot keep the loop spinning forever.
            } while (isAlreadySampled(candidate, filled));
            this.randomSample[filled] = candidate;
        }
    }

    /**
     * Checks whether {@code candidate} occurs among the first {@code filled}
     * entries of {@code randomSample}.
     */
    private boolean isAlreadySampled(int candidate, int filled) {
        for (int k = 0; k < filled; k++) {
            if (this.randomSample[k] == candidate) {
                return true;
            }
        }
        return false;
    }
}
