/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.forward.standard;

import deepboof.DeepBoofConstants;
import deepboof.forward.FunctionBatchNorm;
import deepboof.impl.forward.standard.BaseFunction;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F32;
import java.util.List;

/**
 * Batch normalization forward function for {@link Tensor_F32}.
 *
 * <p>Every element of a single (non-batch) input has its own group of parameters. The groups are
 * stored interleaved in one parameter tensor:
 * <ul>
 *   <li>{@code [mean, variance]} when {@code requiresGammaBeta} is false;</li>
 *   <li>{@code [mean, variance, gamma, beta]} when it is true.</li>
 * </ul>
 * The output for an element {@code x} is {@code (x - mean) / sqrt(variance + EPS)}. When gamma/beta
 * are used, that result is further scaled by {@code gamma} and shifted by {@code beta}.
 *
 * <p>The variance slot of each group is replaced by {@code 1 / sqrt(variance + EPS)} once, inside
 * {@link #_setParameters(List)}. Changing EPS afterwards therefore only takes effect the next time
 * parameters are set.
 */
public class FunctionBatchNorm_F32
extends BaseFunction<Tensor_F32>
implements FunctionBatchNorm<Tensor_F32> {
    /** If true each element has 4 parameters (mean, variance, gamma, beta), otherwise 2 (mean, variance). */
    protected boolean requiresGammaBeta;
    /**
     * Internal copy of the parameters. After {@link #_setParameters(List)} the second value of each
     * group holds {@code 1 / sqrt(variance + EPS)} instead of the raw variance.
     */
    protected Tensor_F32 params = new Tensor_F32(0);
    /** Small constant added to the variance before taking the square root, to avoid division by zero. */
    protected float EPS = DeepBoofConstants.TEST_TOL_F32 * 0.1f;

    /**
     * @param requiresGammaBeta true if the parameters include per-element gamma (scale) and beta (shift)
     */
    public FunctionBatchNorm_F32(boolean requiresGammaBeta) {
        this.requiresGammaBeta = requiresGammaBeta;
    }

    /** Number of interleaved parameter values stored for each input element. */
    private int paramsPerElement() {
        return this.requiresGammaBeta ? 4 : 2;
    }

    /**
     * Computes the output and parameter shapes from {@code shapeInput}, and sizes the internal
     * parameter tensor to match.
     */
    @Override
    public void _initialize() {
        // Normalization is element-wise, so the output has exactly the input's shape.
        this.shapeOutput = this.shapeInput.clone();
        // NOTE(review): TensorOps.WI presumably appends a trailing dimension of the given length,
        // i.e. one group of parameters per input element. Confirm against TensorOps.
        int[] shapeParam = TensorOps.WI(this.shapeInput, paramsPerElement());
        this.shapeParameters.add(shapeParam);
        this.params.reshape(shapeParam);
    }

    /**
     * Copies the parameter tensor into {@link #params}. It then replaces each group's second value
     * (the variance) with {@code 1 / sqrt(variance + EPS)}, so the forward pass can multiply
     * instead of dividing.
     *
     * @param parameters list whose first tensor holds the interleaved per-element parameters
     */
    @Override
    public void _setParameters(List<Tensor_F32> parameters) {
        this.params.setTo(parameters.get(0));
        int stride = paramsPerElement();
        // Offset by startIndex so this walks the same region of params.d that _forward() reads.
        int start = this.params.startIndex;
        int end = start + this.params.length();
        for (int i = start + 1; i < end; i += stride) {
            this.params.d[i] = 1.0f / (float) Math.sqrt(this.params.d[i] + this.EPS);
        }
    }

    /**
     * Applies the normalization to every input in the mini-batch.
     *
     * @param input tensor whose first dimension is the mini-batch
     * @param output tensor that receives the normalized values, same layout as {@code input}
     * @throws IllegalArgumentException if {@code input} has fewer than two dimensions
     */
    @Override
    public void _forward(Tensor_F32 input, Tensor_F32 output) {
        if (input.getDimension() <= 1) {
            throw new IllegalArgumentException("Input tensor must be at least 2D. First dimension of batch.");
        }
        // Number of elements in a single input of the mini-batch. Each one has its own parameter group.
        int D = TensorOps.outerLength(input.shape, 1);
        int indexIn = input.startIndex;
        int indexOut = output.startIndex;
        if (this.requiresGammaBeta) {
            for (int batch = 0; batch < this.miniBatchSize; ++batch) {
                // Every input in the batch reuses the same per-element parameters.
                int indexP = this.params.startIndex;
                int end = indexIn + D;
                while (indexIn < end) {
                    float mean = this.params.d[indexP++];
                    float inv_stdev_eps = this.params.d[indexP++];
                    float gamma = this.params.d[indexP++];
                    float beta = this.params.d[indexP++];
                    output.d[indexOut++] = (input.d[indexIn++] - mean) * (gamma * inv_stdev_eps) + beta;
                }
            }
        } else {
            for (int batch = 0; batch < this.miniBatchSize; ++batch) {
                // Every input in the batch reuses the same per-element parameters.
                int indexP = this.params.startIndex;
                int end = indexIn + D;
                while (indexIn < end) {
                    float mean = this.params.d[indexP++];
                    float inv_stdev_eps = this.params.d[indexP++];
                    output.d[indexOut++] = (input.d[indexIn++] - mean) * inv_stdev_eps;
                }
            }
        }
    }

    /** @return the epsilon added to the variance when parameters are set */
    @Override
    public double getEPS() {
        return this.EPS;
    }

    /**
     * Sets the epsilon added to the variance. It only affects parameters set after this call.
     *
     * @param EPS new epsilon, stored as a float
     */
    @Override
    public void setEPS(double EPS) {
        this.EPS = (float) EPS;
    }

    /** @return true if each element's parameters include gamma and beta */
    @Override
    public boolean hasGammaBeta() {
        return this.requiresGammaBeta;
    }

    @Override
    public Class<Tensor_F32> getTensorType() {
        return Tensor_F32.class;
    }
}

