package org.encog.neural.networks.training.propagation.sgd.update;

import org.encog.neural.networks.training.propagation.sgd.StochasticGradientDescent;

/**
 * Adam (adaptive moment estimation) update rule for {@link StochasticGradientDescent}.
 *
 * <p>For each weight {@code i}, every call to {@link #update(double[], double[])} refreshes two
 * running averages of that weight's gradient {@code g}:
 * <ul>
 *   <li>{@code m[i] = beta1 * m[i] + (1 - beta1) * g}, a decaying mean of the gradients;</li>
 *   <li>{@code v[i] = beta2 * v[i] + (1 - beta2) * g * g}, a decaying mean of squared
 *       gradients.</li>
 * </ul>
 * Both averages start at zero, so each is bias-corrected by dividing by {@code 1 - beta^t}
 * ({@code t} = the trainer's iteration count). The step
 * {@code learningRate * mHat / (sqrt(vHat) + eps)} is then <em>added</em> to the weight.
 *
 * <p>Because the step is added rather than subtracted, the {@code gradients} array is expected to
 * already point in the error-reducing direction. NOTE(review): this relies on the gradient sign
 * convention used by {@link StochasticGradientDescent} — confirm against the trainer.
 *
 * <p>This class performs no synchronization; the moment arrays are mutated in place.
 */
public class AdamUpdate implements UpdateRule {

    /** Default decay rate of the first-moment (gradient mean) estimate. */
    private static final double DEFAULT_BETA1 = 0.9d;
    /** Default decay rate of the second-moment (squared-gradient mean) estimate. */
    private static final double DEFAULT_BETA2 = 0.999d;
    /** Default term added to the denominator to avoid division by zero. */
    private static final double DEFAULT_EPS = 1.0E-8d;

    /** Trainer supplying the learning rate and iteration count; set by {@link #init}. */
    private StochasticGradientDescent training;
    /** Per-weight first-moment estimates. */
    private double[] m;
    /** Per-weight second-moment estimates. */
    private double[] v;
    private double beta1 = DEFAULT_BETA1;
    private double beta2 = DEFAULT_BETA2;
    private double eps = DEFAULT_EPS;

    /**
     * Binds this rule to a trainer and allocates zeroed moment estimates, one slot per weight of
     * the trainer's flat network.
     *
     * @param theTraining the trainer whose weights this rule will update
     */
    @Override
    public void init(StochasticGradientDescent theTraining) {
        this.training = theTraining;
        final int weightCount = theTraining.getFlat().getWeights().length;
        this.m = new double[weightCount];
        this.v = new double[weightCount];
    }

    /**
     * Applies one Adam step to every weight, modifying {@code weights} in place.
     *
     * @param gradients per-weight gradients for this step; must be at least as long as
     *                  {@code weights}
     * @param weights   the weights to update in place
     */
    @Override
    public void update(double[] gradients, double[] weights) {
        // 1 - beta^0 == 0, so an iteration count below 1 would divide by zero and write
        // NaN/Infinity into the weights; clamp so the first step uses t = 1, as in Adam.
        // (double keeps this valid whatever integral type getIteration() returns.)
        final double t = Math.max(1.0d, this.training.getIteration());

        // Identical for every weight: compute once per call instead of 2 Math.pow calls per weight.
        final double learningRate = this.training.getLearningRate();
        final double correction1 = 1.0d - Math.pow(this.beta1, t);
        final double correction2 = 1.0d - Math.pow(this.beta2, t);

        for (int i = 0; i < weights.length; i++) {
            final double g = gradients[i];
            this.m[i] = (this.beta1 * this.m[i]) + ((1.0d - this.beta1) * g);
            this.v[i] = (this.beta2 * this.v[i]) + ((1.0d - this.beta2) * g * g);

            final double mHat = this.m[i] / correction1;
            final double vHat = this.v[i] / correction2;
            weights[i] += (learningRate * mHat) / (Math.sqrt(vHat) + this.eps);
        }
    }

    /** @return the decay rate of the first-moment estimate (default 0.9) */
    public double getBeta1() {
        return this.beta1;
    }

    /**
     * Sets the decay rate of the first-moment estimate.
     *
     * @param beta1 a value in {@code [0, 1)}
     * @throws IllegalArgumentException if {@code beta1} is NaN or outside {@code [0, 1)}
     */
    public void setBeta1(double beta1) {
        this.beta1 = requireDecayRate("beta1", beta1);
    }

    /** @return the decay rate of the second-moment estimate (default 0.999) */
    public double getBeta2() {
        return this.beta2;
    }

    /**
     * Sets the decay rate of the second-moment estimate.
     *
     * @param beta2 a value in {@code [0, 1)}
     * @throws IllegalArgumentException if {@code beta2} is NaN or outside {@code [0, 1)}
     */
    public void setBeta2(double beta2) {
        this.beta2 = requireDecayRate("beta2", beta2);
    }

    /** @return the term added to the step denominator (default 1.0E-8) */
    public double getEps() {
        return this.eps;
    }

    /**
     * Sets the term added to the step denominator to avoid dividing by zero.
     *
     * @param eps a positive, finite value
     * @throws IllegalArgumentException if {@code eps} is NaN, not positive, or infinite
     */
    public void setEps(double eps) {
        if (!(eps > 0.0d && eps < Double.POSITIVE_INFINITY)) {
            throw new IllegalArgumentException("eps must be positive and finite, got " + eps);
        }
        this.eps = eps;
    }

    /** Validates a decay rate; 1 is excluded because 1 - 1^t == 0 in the bias correction. */
    private static double requireDecayRate(String name, double value) {
        if (!(value >= 0.0d && value < 1.0d)) {
            throw new IllegalArgumentException(name + " must be in [0, 1), got " + value);
        }
        return value;
    }
}
