/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.norm;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

/**
 * A batch normalization block.
 *
 * <p>The block owns four parameters: {@code gamma}, {@code beta}, {@code runningMean} and
 * {@code runningVar}. Each has shape {@code (inChannels)}, where {@code inChannels} is the size of
 * the first input shape along {@code axis}. The numerical work is delegated to
 * {@link NDArrayEx#batchNorm}, which receives the data together with those four parameter values.
 *
 * <p>Instances are created through {@link Builder}.
 */
public class BatchNorm extends ParameterBlock {

    /**
     * Encoding version written by {@link #saveParameters} and required by
     * {@link #loadParameters}.
     */
    private static final byte VERSION = 1;

    private final int axis;
    private final float epsilon;
    private final float momentum;
    // Size of the first input along `axis`; set by beforeInitialize or restored by loadParameters.
    private long inChannels;
    private final boolean center;
    private final boolean scale;
    private final Parameter gamma;
    private final Parameter beta;
    private final Parameter runningMean;
    private final Parameter runningVar;

    /**
     * Creates a block from the builder's settings.
     *
     * <p>The boolean passed to each {@link Parameter} is the builder's {@code scale} flag for
     * {@code gamma}, its {@code center} flag for {@code beta}, and {@code false} for both running
     * statistics. This presumably marks whether the parameter requires gradients; confirm against
     * the {@code Parameter} constructor.
     *
     * @param builder the configured builder
     */
    BatchNorm(Builder builder) {
        this.axis = builder.axis;
        this.epsilon = builder.epsilon;
        this.momentum = builder.momentum;
        this.center = builder.center;
        this.scale = builder.scale;
        this.gamma = new Parameter("gamma", this, ParameterType.GAMMA, this.scale);
        this.beta = new Parameter("beta", this, ParameterType.BETA, this.center);
        this.runningMean = new Parameter("runningMean", this, ParameterType.RUNNING_MEAN, false);
        this.runningVar = new Parameter("runningVar", this, ParameterType.RUNNING_VAR, false);
    }

    /**
     * Runs the engine's batch-norm operator on the single input array.
     *
     * @param parameterStore store that supplies the parameter values on the input's device
     * @param inputs a list that must contain exactly one array
     * @param params extra options forwarded unchanged to {@link NDArrayEx#batchNorm}
     * @return the result of {@link NDArrayEx#batchNorm}
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one array
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore, NDList inputs, PairList<String, Object> params) {
        NDList opInputs = this.opInputs(parameterStore, inputs);
        NDArrayEx ex = opInputs.head().getNDArrayInternal();
        return ex.batchNorm(opInputs, this.epsilon, this.momentum, this.axis, params);
    }

    /**
     * Returns the output shapes.
     *
     * @return a single shape equal to the first input shape
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[] {inputShapes[0]};
    }

    /**
     * Returns the parameters owned directly by this block.
     *
     * @return gamma, beta, runningMean and runningVar, in that order
     */
    @Override
    public List<Parameter> getDirectParameters() {
        return Arrays.asList(this.gamma, this.beta, this.runningMean, this.runningVar);
    }

    /**
     * Records the input shapes and derives {@code inChannels} from the first input along
     * {@code axis}.
     *
     * @param inputShapes the shapes of the block inputs; only the first one is inspected
     */
    @Override
    public void beforeInitialize(Shape[] inputShapes) {
        this.inputShapes = inputShapes;
        this.inChannels = inputShapes[0].size(this.axis);
    }

    /**
     * Returns the shape of the named parameter.
     *
     * @param name one of {@code gamma}, {@code beta}, {@code runningMean}, {@code runningVar}
     * @param inputShapes unused; the shape depends only on {@code inChannels}
     * @return {@code (inChannels)} for every supported parameter
     * @throws IllegalArgumentException if {@code name} is not one of this block's parameters
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        switch (name) {
            case "gamma":
            case "beta":
            case "runningMean":
            case "runningVar":
                return new Shape(this.inChannels);
            default:
                throw new IllegalArgumentException("Invalid parameter name: " + name);
        }
    }

    /**
     * Builds the operator input list.
     *
     * <p>The list holds the data followed by gamma, beta, runningMean and runningVar. Each
     * parameter value is fetched on the same device as the data.
     *
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one array
     */
    private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        if (inputs.size() != 1) {
            throw new IllegalArgumentException(
                    "BatchNorm requires exactly 1 NDArray, but got " + inputs.size());
        }
        NDArray data = inputs.singletonOrThrow();
        Device device = data.getDevice();
        NDArray gammaValue = parameterStore.getValue(this.gamma, device);
        NDArray betaValue = parameterStore.getValue(this.beta, device);
        NDArray runningMeanValue = parameterStore.getValue(this.runningMean, device);
        NDArray runningVarValue = parameterStore.getValue(this.runningVar, device);
        return new NDList(data, gammaValue, betaValue, runningMeanValue, runningVarValue);
    }

    /**
     * Writes the version byte, {@code inChannels}, and the four parameters.
     *
     * <p>The parameters are written in the order gamma, beta, runningMean, runningVar. That order
     * must match {@link #loadParameters}.
     *
     * @param os the destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        os.writeLong(this.inChannels);
        this.gamma.save(os);
        this.beta.save(os);
        this.runningMean.save(os);
        this.runningVar.save(os);
    }

    /**
     * Reads data in the format produced by {@link #saveParameters}.
     *
     * @param manager the manager passed to each parameter's {@code load}
     * @param is the source stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if the stored version byte is not {@link #VERSION}
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        this.inChannels = is.readLong();
        this.gamma.load(manager, is);
        this.beta.load(manager, is);
        this.runningMean.load(manager, is);
        this.runningVar.load(manager, is);
    }

    /**
     * Builder for {@link BatchNorm}.
     *
     * <p>Defaults: axis {@code 1}, epsilon {@code 1e-5}, momentum {@code 0.9}, center
     * {@code true}, scale {@code true}.
     */
    public static final class Builder {
        private int axis = 1;
        private float epsilon = 1.0E-5f;
        private float momentum = 0.9f;
        private boolean center = true;
        private boolean scale = true;

        /**
         * Sets the input axis used to size the parameters and passed to the operator.
         *
         * @param val the axis index (default {@code 1})
         * @return this builder
         */
        public Builder optAxis(int val) {
            this.axis = val;
            return this;
        }

        /**
         * Sets the flag passed to the {@code beta} parameter.
         *
         * @param val the flag value (default {@code true})
         * @return this builder
         */
        public Builder optCenter(boolean val) {
            this.center = val;
            return this;
        }

        /**
         * Sets the flag passed to the {@code gamma} parameter.
         *
         * @param val the flag value (default {@code true})
         * @return this builder
         */
        public Builder optScale(boolean val) {
            this.scale = val;
            return this;
        }

        /**
         * Sets the epsilon forwarded to the batch-norm operator.
         *
         * @param val the epsilon value (default {@code 1e-5})
         * @return this builder
         */
        public Builder optEpsilon(float val) {
            this.epsilon = val;
            return this;
        }

        /**
         * Sets the momentum forwarded to the batch-norm operator.
         *
         * <p>How it updates the running statistics is defined by the engine's implementation.
         *
         * @param val the momentum value (default {@code 0.9})
         * @return this builder
         */
        public Builder optMomentum(float val) {
            this.momentum = val;
            return this;
        }

        /**
         * Builds a {@link BatchNorm} from the current settings.
         *
         * @return a new block
         */
        public BatchNorm build() {
            return new BatchNorm(this);
        }
    }
}

