package deepboof.impl.forward.standard;

import deepboof.DeepBoofConstants;
import deepboof.forward.FunctionBatchNorm;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F32;
import java.util.List;

/* loaded from: input_file:deepboof/impl/forward/standard/FunctionBatchNorm_F32.class */
/**
 * Float32 implementation of batch normalization applied at inference time.
 *
 * <p>Each input element has its own group of parameters. Groups are stored interleaved in a
 * single parameter tensor whose shape is {@code TensorOps.WI(inputShape, 4)} with gamma/beta, or
 * {@code TensorOps.WI(inputShape, 2)} without them. Within a group:
 * <ul>
 *   <li>index 0 — value subtracted from the input (the mean)</li>
 *   <li>index 1 — variance; replaced in place by {@code 1/sqrt(variance + EPS)} when parameters are set</li>
 *   <li>index 2 — gamma scale (only when {@link #hasGammaBeta()} is true)</li>
 *   <li>index 3 — beta offset (only when {@link #hasGammaBeta()} is true)</li>
 * </ul>
 */
public class FunctionBatchNorm_F32 extends BaseFunction<Tensor_F32> implements FunctionBatchNorm<Tensor_F32> {
    /** True when every element carries four parameters (mean, variance, gamma, beta) instead of two. */
    protected boolean requiresGammaBeta;
    /** Internal copy of the parameters; the variance slots hold inverse standard deviations once set. */
    protected Tensor_F32 params = new Tensor_F32(0);
    /**
     * Constant added to the variance before the square root. It is only read in
     * {@link #_setParameters(List)}, so changing it afterwards does not affect parameters
     * that are already loaded.
     */
    protected float EPS = DeepBoofConstants.TEST_TOL_F32 * 0.1f;

    /**
     * Creates the function.
     *
     * @param requiresGammaBeta whether the parameter tensor also contains gamma and beta
     */
    public FunctionBatchNorm_F32(boolean requiresGammaBeta) {
        this.requiresGammaBeta = requiresGammaBeta;
    }

    /** Number of interleaved parameters stored for each input element. */
    private int parametersPerElement() {
        return requiresGammaBeta ? 4 : 2;
    }

    /** Output has the same shape as the input; registers the single parameter tensor's shape. */
    @Override
    public void _initialize() {
        shapeOutput = shapeInput.clone();
        int[] parameterShape = TensorOps.WI(shapeInput, parametersPerElement());
        shapeParameters.add(parameterShape);
        params.reshape(parameterShape);
    }

    /**
     * Copies the first tensor in {@code parameters}, then converts every variance entry into
     * {@code 1 / sqrt(variance + EPS)} so that the forward pass only needs multiplications.
     *
     * @param parameters list whose first element is the interleaved parameter tensor
     */
    @Override
    public void _setParameters(List<Tensor_F32> parameters) {
        params.setTo(parameters.get(0));
        final int end = params.length();
        final int stride = parametersPerElement();
        for (int varianceIndex = 1; varianceIndex < end; varianceIndex += stride) {
            params.d[varianceIndex] = 1.0f / (float) Math.sqrt(params.d[varianceIndex] + EPS);
        }
    }

    /**
     * Normalizes every sample in the mini-batch.
     *
     * @param input tensor whose first dimension is the batch
     * @param output destination written sequentially from its start index
     * @throws IllegalArgumentException if {@code input} has fewer than two dimensions
     */
    @Override
    public void _forward(Tensor_F32 input, Tensor_F32 output) {
        if (input.getDimension() <= 1) {
            throw new IllegalArgumentException("Input tensor must be at least 2D.  First dimension of batch.");
        }
        final int sampleLength = TensorOps.outerLength(input.shape, 1);
        if (requiresGammaBeta) {
            forwardWithGammaBeta(input, output, sampleLength);
        } else {
            forwardWithoutGammaBeta(input, output, sampleLength);
        }
    }

    /** Computes {@code (x - mean) * invStd} for every element of every sample. */
    private void forwardWithoutGammaBeta(Tensor_F32 input, Tensor_F32 output, int sampleLength) {
        final float[] src = input.d;
        final float[] dst = output.d;
        final float[] p = params.d;
        int in = input.startIndex;
        int out = output.startIndex;
        for (int sample = 0; sample < miniBatchSize; sample++) {
            // Every sample reuses the same parameter groups from the start.
            int k = params.startIndex;
            final int sampleEnd = in + sampleLength;
            while (in < sampleEnd) {
                dst[out++] = (src[in++] - p[k]) * p[k + 1];
                k += 2;
            }
        }
    }

    /** Computes {@code (x - mean) * gamma * invStd + beta} for every element of every sample. */
    private void forwardWithGammaBeta(Tensor_F32 input, Tensor_F32 output, int sampleLength) {
        final float[] src = input.d;
        final float[] dst = output.d;
        final float[] p = params.d;
        int in = input.startIndex;
        int out = output.startIndex;
        for (int sample = 0; sample < miniBatchSize; sample++) {
            int k = params.startIndex;
            final int sampleEnd = in + sampleLength;
            while (in < sampleEnd) {
                // Multiplication order (gamma first, then invStd) is kept for bit-exact float results.
                dst[out++] = ((src[in++] - p[k]) * p[k + 2] * p[k + 1]) + p[k + 3];
                k += 4;
            }
        }
    }

    /** @return the epsilon added to the variance, widened to double */
    @Override
    public double getEPS() {
        return EPS;
    }

    /**
     * Sets the epsilon. It takes effect the next time parameters are set.
     *
     * @param d new epsilon, narrowed to float
     */
    @Override
    public void setEPS(double d) {
        EPS = (float) d;
    }

    /** @return true when the parameter tensor includes gamma and beta */
    @Override
    public boolean hasGammaBeta() {
        return requiresGammaBeta;
    }

    /** @return {@code Tensor_F32.class} */
    @Override
    public Class<Tensor_F32> getTensorType() {
        return Tensor_F32.class;
    }
}
