/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.data;

import deepnetts.data.MLDataItem;
import deepnetts.util.RandomGenerator;
import deepnetts.util.Tensor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import javax.visrec.ml.data.BasicDataSet;
import javax.visrec.ml.data.Column;
import javax.visrec.ml.data.DataSet;

/**
 * Data set whose items expose an input and a target output (see {@link MLDataItem}),
 * together with the number of input and output values per item and optional column names.
 *
 * <p>Besides storing items, the data set can be shuffled in place and split into
 * several smaller data sets.
 *
 * @param <E> type of the items stored in this data set
 */
public class TabularDataSet<E extends MLDataItem>
extends BasicDataSet<E> {
    /**
     * Slack allowed when checking that split fractions do not add up to more than 1,
     * so that fractions such as {@code 1.0 / n} repeated {@code n} times are not
     * rejected because of floating-point rounding.
     */
    private static final double FRACTION_SUM_TOLERANCE = 1.0E-9;

    private int numInputs;
    private int numOutputs;
    protected String[] columnNames;

    /**
     * Creates an empty data set backed by a new {@link ArrayList}; the number of inputs
     * and outputs is left at zero.
     */
    protected TabularDataSet() {
        this.items = new ArrayList<>();
    }

    /**
     * Creates an empty data set with the given number of inputs and outputs.
     *
     * @param numInputs  number of input values per item
     * @param numOutputs number of target output values per item
     */
    public TabularDataSet(int numInputs, int numOutputs) {
        this();
        this.numInputs = numInputs;
        this.numOutputs = numOutputs;
    }

    /**
     * @return number of input values per item
     */
    public int getNumInputs() {
        return this.numInputs;
    }

    /**
     * @return number of target output values per item
     */
    public int getNumOutputs() {
        return this.numOutputs;
    }

    /**
     * Shuffles this data set and splits it into the given number of equally sized parts.
     *
     * @param parts number of parts to create, must be at least 2
     * @return array with one data set per part
     * @throws IllegalArgumentException if {@code parts} is smaller than 2
     */
    public DataSet[] split(int parts) {
        if (parts < 2) {
            throw new IllegalArgumentException("Number of split parts must be greater than one, got: " + parts);
        }
        // Use the exact fraction. Rounding it to two decimals would make the fractions add up
        // to more than 1 for some counts (6 parts -> 6 * 0.17 = 1.02) and to less than 1 for
        // others (3 parts -> 3 * 0.33 = 0.99), i.e. either a rejected split or dropped items.
        double partSize = 1.0 / (double)parts;
        double[] partsArr = new double[parts];
        for (int i = 0; i < parts; ++i) {
            partsArr[i] = partSize;
        }
        return this.split(partsArr);
    }

    /**
     * Shuffles this data set in place and splits it into parts of the given relative sizes.
     *
     * <p>If a single fraction is given, a second part with the complementary fraction
     * ({@code 1 - parts[0]}) is added. Part boundaries are computed from the running sum
     * of the fractions, so when the fractions add up to 1 every item ends up in exactly one
     * part; when they add up to less than 1 the trailing items are not included in any part.
     * Each returned data set receives this data set's column names (if any) and columns.
     *
     * @param parts fractions of the data set to put in each part; each must be greater than
     *              zero and together they must not exceed 1
     * @return array with one data set per part, in the order of {@code parts}
     * @throws IllegalArgumentException if no fraction is given, if any fraction is not a
     *                                  positive number, or if the fractions add up to more than 1
     */
    public DataSet[] split(double ... parts) {
        if (parts.length < 1) {
            throw new IllegalArgumentException("Number of split parts must be greater than one");
        }
        if (parts.length == 1) {
            parts = new double[]{parts[0], 1.0 - parts[0]};
        }
        double partsSum = 0.0;
        for (double part : parts) {
            // Written as a negated comparison so that NaN is rejected as well.
            if (!(part > 0.0)) {
                throw new IllegalArgumentException("Value of the part cannot be zero or negative!");
            }
            partsSum += part;
        }
        if (partsSum > 1.0 + FRACTION_SUM_TOLERANCE) {
            throw new IllegalArgumentException("Sum of parts cannot be larger than 1!");
        }
        TabularDataSet[] subSets = new TabularDataSet[parts.length];
        this.shuffle();
        int totalItems = this.size();
        int start = 0;
        double cumulative = 0.0;
        for (int p = 0; p < parts.length; ++p) {
            TabularDataSet<E> subSet = new TabularDataSet<E>(this.numInputs, this.numOutputs);
            // Column names are optional; setColumnNames does not accept null.
            if (this.columnNames != null) {
                subSet.setColumnNames(this.columnNames);
            }
            subSet.setColumns(this.getColumns());
            // Derive the end index from the cumulative fraction and round it, instead of
            // truncating each part separately: truncation loses items both to floating-point
            // error (0.29 * 100 evaluates to 28.999...) and to discarded remainders.
            cumulative += parts[p];
            int end = (int)Math.min((long)totalItems, Math.round((double)totalItems * cumulative));
            for (int idx = start; idx < end; ++idx) {
                subSet.add(this.items.get(idx));
            }
            start = Math.max(start, end);
            subSets[p] = subSet;
        }
        return subSets;
    }

    /**
     * Shuffles the items of this data set in place, using the random source obtained from
     * {@code RandomGenerator.getDefault()}.
     */
    public void shuffle() {
        Random rnd = RandomGenerator.getDefault().getRandom();
        Collections.shuffle(this.items, rnd);
    }

    /**
     * Shuffles the items of this data set in place, using a new {@link Random} created
     * from the given seed.
     *
     * @param seed seed for the random number generator
     */
    public void shuffle(int seed) {
        Random rnd = new Random(seed);
        Collections.shuffle(this.items, rnd);
    }

    /**
     * @return the column names array held by this data set (not a copy), or {@code null}
     *         if none has been set
     */
    public String[] getColumnNames() {
        return this.columnNames;
    }

    /**
     * Sets the column names, replaces the columns of this data set with one {@link Column}
     * per name, and marks the {@code numOutputs} names that follow the first
     * {@code numInputs} names as target columns.
     *
     * <p>The arguments are validated before any state is modified.
     *
     * @param columnNames names of all columns, inputs first, followed by outputs
     * @throws NullPointerException     if {@code columnNames} is {@code null}
     * @throws IllegalArgumentException if there are outputs and fewer names than
     *                                  {@code numInputs + numOutputs} are given
     */
    public void setColumnNames(String[] columnNames) {
        if (columnNames == null) {
            throw new NullPointerException("columnNames must not be null");
        }
        // Only the output names are looked up by position, so the length matters only
        // when there is at least one output.
        if (this.numOutputs > 0 && columnNames.length < this.numInputs + this.numOutputs) {
            throw new IllegalArgumentException("Expected at least " + (this.numInputs + this.numOutputs)
                    + " column names (" + this.numInputs + " inputs + " + this.numOutputs
                    + " outputs) but got " + columnNames.length);
        }
        this.columnNames = columnNames;
        ArrayList<Column> columns = new ArrayList<Column>(columnNames.length);
        for (String columnName : columnNames) {
            columns.add(new Column(columnName));
        }
        super.setColumns(columns);
        String[] targetLabels = new String[this.numOutputs];
        for (int i = 0; i < this.numOutputs; ++i) {
            targetLabels[i] = columnNames[this.numInputs + i];
        }
        this.setAsTargetColumns(targetLabels);
    }

    /**
     * Immutable pair of an input tensor and the corresponding target output tensor.
     */
    public static class Item
    implements MLDataItem {
        private final Tensor input;
        private final Tensor targetOutput;

        /**
         * Creates an item by wrapping the given arrays in new tensors.
         *
         * @param in           input values
         * @param targetOutput target output values
         */
        public Item(float[] in, float[] targetOutput) {
            this.input = new Tensor(in);
            this.targetOutput = new Tensor(targetOutput);
        }

        /**
         * Creates an item that holds the given tensors (they are not copied).
         *
         * @param input        input tensor
         * @param targetOutput target output tensor
         */
        public Item(Tensor input, Tensor targetOutput) {
            this.input = input;
            this.targetOutput = targetOutput;
        }

        @Override
        public Tensor getInput() {
            return this.input;
        }

        @Override
        public Tensor getTargetOutput() {
            return this.targetOutput;
        }

        /**
         * @return the number of columns of the input tensor
         */
        public int size() {
            return this.input.getCols();
        }

        @Override
        public String toString() {
            return "BasicDataSetItem{input=" + this.input + ", targetOutput=" + this.targetOutput + '}';
        }
    }
}

