/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.data.preprocessing.scale;

import deepnetts.data.MLDataItem;
import deepnetts.util.Tensor;
import java.io.Serializable;
import javax.visrec.ml.data.DataSet;
import javax.visrec.ml.data.preprocessing.Scaler;

/**
 * Scaler that standardizes the input tensors of a data set: the mean computed
 * at construction time is subtracted from every input and the result is
 * divided by the standard deviation computed at construction time.
 *
 * <p>The statistics are computed once, from the data set handed to the
 * constructor, and are then reused by every call to {@link #apply(DataSet)}.
 *
 * <p>NOTE(review): the statistics tensors are sized with {@code getCols()} of
 * the first item's input, so inputs are presumably expected to be
 * one-dimensional and of equal length - confirm against {@code Tensor}.
 *
 * <p>NOTE(review): a column that is constant across the data set presumably
 * gets a standard deviation of zero, in which case {@link #apply(DataSet)}
 * would divide by zero for that column - confirm how {@code Tensor.div}
 * behaves and whether such columns need to be guarded.
 */
public class Standardizer
implements Scaler<DataSet<MLDataItem>>,
Serializable {
    /** Mean of the inputs of the data set given to the constructor. */
    private final Tensor mean;
    /** Sample standard deviation (divisor {@code n - 1}) of those inputs. */
    private final Tensor std;

    /**
     * Computes the mean and the sample standard deviation of the inputs of the
     * given data set.
     *
     * @param dataSet data set to derive the statistics from; must contain at
     *                least two items, because the sample standard deviation
     *                divides by {@code size - 1}
     * @throws NullPointerException     if {@code dataSet} is {@code null}
     * @throws IllegalArgumentException if {@code dataSet} has fewer than two items
     */
    public Standardizer(DataSet<MLDataItem> dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet must not be null");
        }
        int itemCount = dataSet.size();
        // Fail fast: an empty set has no first item to size the tensors from, and a
        // single item makes the (n - 1) divisor zero, which would otherwise poison
        // every input later passed to apply().
        if (itemCount < 2) {
            throw new IllegalArgumentException(
                    "Standardizer requires a data set with at least 2 items, but got " + itemCount);
        }

        // The first item's input only determines the length of the statistics tensors.
        Tensor firstInput = ((MLDataItem)dataSet.get(0)).getInput();
        this.mean = new Tensor(firstInput.getCols());
        this.std = new Tensor(firstInput.getCols());

        // First pass: accumulate all inputs, then divide by the item count.
        for (MLDataItem item : dataSet) {
            this.mean.add(item.getInput());
        }
        this.mean.div(itemCount);

        // Second pass: accumulate squared deviations from the mean. Work on a copy
        // so the items' own input tensors are left untouched here.
        for (MLDataItem item : dataSet) {
            Tensor deviation = item.getInput().copy();
            deviation.sub(this.mean);
            deviation.multiplyElementWise(deviation);
            this.std.add(deviation);
        }
        // Divisor n - 1 (sample variance), then square root to get the deviation.
        this.std.div(itemCount - 1);
        this.std.sqrt();
    }

    /**
     * Standardizes the input of every item in the given data set by subtracting
     * the stored mean and dividing by the stored standard deviation. The
     * operations are invoked directly on the tensor returned by
     * {@code getInput()}, so the items are presumably modified in place.
     *
     * @param dataSet data set whose item inputs are to be standardized
     */
    public void apply(DataSet<MLDataItem> dataSet) {
        for (MLDataItem item : dataSet) {
            Tensor input = item.getInput();
            input.sub(this.mean);
            input.div(this.std);
        }
    }
}

