/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.data;

import deepnetts.data.MLDataItem;
import deepnetts.data.TabularDataSet;
import deepnetts.data.preprocessing.scale.MaxScaler;
import deepnetts.data.preprocessing.scale.Standardizer;
import deepnetts.util.ColumnType;
import deepnetts.util.CsvFormat;
import deepnetts.util.DeepNettsException;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.regex.Pattern;
import javax.visrec.ml.data.DataSet;
import javax.visrec.ml.data.preprocessing.Scaler;

public class DataSets {
    public static final String DELIMITER_SPACE = " ";
    public static final String DELIMITER_COMMA = ",";
    public static final String DELIMITER_SEMICOLON = ";";
    public static final String DELIMITER_TAB = "\t";

    public static TabularDataSet readCsv(File csvFile, int numInputs, int numOutputs, boolean hasColumnNames, String delimiter) throws FileNotFoundException, IOException {
        String[] colNames;
        TabularDataSet dataSet = new TabularDataSet(numInputs, numOutputs);
        BufferedReader br = new BufferedReader(new FileReader(csvFile));
        String line = null;
        if (hasColumnNames) {
            line = br.readLine().trim();
            colNames = line.split(delimiter);
            dataSet.setColumnNames(colNames);
        } else {
            colNames = new String[numInputs + numOutputs];
            for (int i = 0; i < numInputs; ++i) {
                colNames[i] = "in" + (i + 1);
            }
            for (int j = 0; j < numOutputs; ++j) {
                colNames[numInputs + j] = "out" + (j + 1);
            }
            dataSet.setColumnNames(colNames);
        }
        while ((line = br.readLine()) != null) {
            if ((line = line.trim()).isEmpty()) continue;
            String[] values = line.split(delimiter);
            if (values.length != numInputs + numOutputs) {
                throw new DeepNettsException("Wrong number of values in the row " + (dataSet.size() + 1) + ": found " + values.length + " expected " + (numInputs + numOutputs));
            }
            float[] in = new float[numInputs];
            float[] out = new float[numOutputs];
            try {
                for (int i = 0; i < numInputs; ++i) {
                    in[i] = Float.parseFloat(values[i]);
                }
                for (int j = 0; j < numOutputs; ++j) {
                    out[j] = Float.parseFloat(values[numInputs + j]);
                }
            }
            catch (NumberFormatException nex) {
                throw new DeepNettsException("Error parsing csv, number expected line in " + (dataSet.size() + 1) + ": " + nex.getMessage(), nex);
            }
            dataSet.add(new TabularDataSet.Item(in, out));
        }
        return dataSet;
    }

    public static TabularDataSet readCsv(String fileName, int numInputs, int numOutputs, boolean hasColumnNames, String delimiter) throws IOException {
        return DataSets.readCsv(new File(fileName), numInputs, numOutputs, hasColumnNames, delimiter);
    }

    public static TabularDataSet readCsv(String fileName, int numInputs, int numOutputs, boolean hasColumnNames) throws IOException {
        return DataSets.readCsv(new File(fileName), numInputs, numOutputs, hasColumnNames, DELIMITER_COMMA);
    }

    public static TabularDataSet readCsv(String fileName, int numInputs, int numOutputs, String delimiter) throws IOException {
        return DataSets.readCsv(new File(fileName), numInputs, numOutputs, false, delimiter);
    }

    public static TabularDataSet readCsv(String fileName, int numInputs, int numOutputs) throws IOException {
        return DataSets.readCsv(new File(fileName), numInputs, numOutputs, false, DELIMITER_COMMA);
    }

    public static CsvFormat detectCsvFormat(String fileName) throws FileNotFoundException, IOException {
        BufferedReader br = new BufferedReader(new FileReader(fileName));
        String firstLine = br.readLine();
        String delimiter = null;
        if (firstLine.contains(DELIMITER_COMMA)) {
            delimiter = DELIMITER_COMMA;
        } else if (firstLine.contains(DELIMITER_SEMICOLON)) {
            delimiter = DELIMITER_SEMICOLON;
        } else if (firstLine.contains(DELIMITER_TAB)) {
            delimiter = DELIMITER_TAB;
        } else if (firstLine.contains(DELIMITER_SPACE)) {
            delimiter = DELIMITER_SPACE;
        } else {
            throw new DeepNettsException("Unknown delimiter");
        }
        boolean hasColumnNames = false;
        String[] columnNames = null;
        String[] firstLineFields = firstLine.split(delimiter);
        int colCount = firstLineFields.length;
        String intRegex = "^-?[0-9]+$";
        String decimalRegex = "^-?[0-9]+\\.[0-9]+$";
        String binaryRegex = "^[01]$";
        String numRegex = "^-?[0-9]+\\.?[0-9]+$";
        String alphaNumRegex = "^[a-zA-Z0-9_\\s\\-]+$";
        String alphaRegex = "^[a-zA-Z_\\s\\-]+$";
        boolean allNumeric = true;
        boolean allAlphaNum = true;
        for (String field : firstLineFields) {
            boolean isNum = Pattern.matches(numRegex, field);
            boolean isAlphaNum = Pattern.matches(alphaNumRegex, field);
            allNumeric = allNumeric && isNum;
            allAlphaNum = allAlphaNum && isAlphaNum;
        }
        if (allNumeric) {
            hasColumnNames = false;
        } else if (allAlphaNum) {
            columnNames = firstLineFields;
            hasColumnNames = true;
        } else {
            hasColumnNames = false;
        }
        String[][] sampleRows = new String[5][colCount];
        for (int i = 0; i < 5; ++i) {
            String line = br.readLine();
            String[] fields = line.split(delimiter);
            sampleRows[i] = fields;
        }
        ColumnType[] colTypes = new ColumnType[colCount];
        for (int c = 0; c < colCount; ++c) {
            boolean allColsAlphaNum = true;
            boolean allColsBinary = true;
            boolean allColsDecimal = true;
            boolean allColsInt = true;
            for (int r = 0; r < 5; ++r) {
                boolean isBinary = Pattern.matches(binaryRegex, sampleRows[r][c]);
                allColsBinary = allColsBinary && isBinary;
                boolean isInt = Pattern.matches(intRegex, sampleRows[r][c]);
                allColsInt = allColsInt && isInt;
                boolean isDecimal = Pattern.matches(decimalRegex, sampleRows[r][c]);
                allColsDecimal = allColsDecimal && (isDecimal || isInt);
                boolean isAlphaNum = Pattern.matches(alphaNumRegex, sampleRows[r][c]);
                allColsAlphaNum = allColsAlphaNum && isAlphaNum;
            }
            colTypes[c] = allColsBinary ? ColumnType.BINARY : (allColsInt ? ColumnType.INTEGER : (allColsDecimal ? ColumnType.DECIMAL : ColumnType.STRING));
        }
        CsvFormat csvFormat = new CsvFormat();
        csvFormat.setDelimiter(delimiter);
        csvFormat.setColumnTypes(colTypes);
        csvFormat.setColumnNames(columnNames);
        csvFormat.setHasHeader(hasColumnNames);
        return csvFormat;
    }

    public static Scaler scaleMax(DataSet dataSet) {
        MaxScaler scaler = new MaxScaler((DataSet<MLDataItem>)dataSet);
        scaler.apply((DataSet<MLDataItem>)dataSet);
        return scaler;
    }

    public static Scaler standardize(DataSet dataSet) {
        Standardizer scaler = new Standardizer((DataSet<MLDataItem>)dataSet);
        scaler.apply((DataSet<MLDataItem>)dataSet);
        return scaler;
    }

    public static float[] oneHotEncode(String label, String[] labels) {
        float[] vect = new float[labels.length];
        for (int i = 0; i < labels.length; ++i) {
            if (!labels[i].equals(label)) continue;
            vect[i] = 1.0f;
        }
        return vect;
    }

    public static DataSet<?>[] trainTestSplit(DataSet<?> dataSet, double split) {
        dataSet.shuffle();
        return dataSet.split(new double[]{split, 1.0 - split});
    }
}

