/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.spark.transform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.commons.collections.map.ListOrderedMap;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.DataFrames;

public class Normalization {
    public static Dataset<Row> zeromeanUnitVariance(Dataset<Row> frame) {
        return Normalization.zeromeanUnitVariance(frame, Collections.emptyList());
    }

    public static JavaRDD<List<Writable>> zeromeanUnitVariance(Schema schema, JavaRDD<List<Writable>> data) {
        return Normalization.zeromeanUnitVariance(schema, data, Collections.emptyList());
    }

    public static Dataset<Row> normalize(Dataset<Row> dataFrame, double min, double max) {
        return Normalization.normalize(dataFrame, min, max, Collections.emptyList());
    }

    public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> data, double min, double max) {
        Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
        return (JavaRDD)DataFrames.toRecords(Normalization.normalize(frame, min, max, Collections.emptyList())).getSecond();
    }

    public static Dataset<Row> normalize(Dataset<Row> dataFrame) {
        return Normalization.normalize(dataFrame, 0.0, 1.0, Collections.emptyList());
    }

    public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> data) {
        return Normalization.normalize(schema, data, 0.0, 1.0, Collections.emptyList());
    }

    public static Dataset<Row> zeromeanUnitVariance(Dataset<Row> frame, List<String> skipColumns) {
        List<String> columnsList = DataFrames.toList(frame.columns());
        columnsList.removeAll(skipColumns);
        String[] columnNames = DataFrames.toArray(columnsList);
        List<Row> stdDevMean = Normalization.stdDevMeanColumns(frame, columnNames);
        for (int i = 0; i < columnNames.length; ++i) {
            String columnName = columnNames[i];
            double std = ((Number)stdDevMean.get(0).get(i)).doubleValue();
            double mean = ((Number)stdDevMean.get(1).get(i)).doubleValue();
            if (std == 0.0) {
                std = 1.0;
            }
            frame = frame.withColumn(columnName, frame.col(columnName).minus((Object)mean).divide((Object)std));
        }
        return frame;
    }

    public static JavaRDD<List<Writable>> zeromeanUnitVariance(Schema schema, JavaRDD<List<Writable>> data, List<String> skipColumns) {
        Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
        return (JavaRDD)DataFrames.toRecords(Normalization.zeromeanUnitVariance(frame, skipColumns)).getSecond();
    }

    public static JavaRDD<List<List<Writable>>> zeroMeanUnitVarianceSequence(Schema schema, JavaRDD<List<List<Writable>>> sequence) {
        return Normalization.zeroMeanUnitVarianceSequence(schema, sequence, null);
    }

    public static JavaRDD<List<List<Writable>>> zeroMeanUnitVarianceSequence(Schema schema, JavaRDD<List<List<Writable>>> sequence, List<String> excludeColumns) {
        Dataset<Row> frame = DataFrames.toDataFrameSequence(schema, sequence);
        if (excludeColumns == null) {
            excludeColumns = Arrays.asList("__SEQ_UUID", "__SEQ_IDX");
        } else {
            excludeColumns = new ArrayList<String>(excludeColumns);
            excludeColumns.add("__SEQ_UUID");
            excludeColumns.add("__SEQ_IDX");
        }
        frame = Normalization.zeromeanUnitVariance(frame, excludeColumns);
        return (JavaRDD)DataFrames.toRecordsSequence(frame).getSecond();
    }

    public static List<Row> minMaxColumns(Dataset<Row> data, List<String> columns) {
        String[] arr = new String[columns.size()];
        for (int i = 0; i < arr.length; ++i) {
            arr[i] = columns.get(i);
        }
        return Normalization.minMaxColumns(data, arr);
    }

    public static List<Row> minMaxColumns(Dataset<Row> data, String ... columns) {
        return Normalization.aggregate(data, columns, new String[]{"min", "max"});
    }

    public static List<Row> stdDevMeanColumns(Dataset<Row> data, List<String> columns) {
        String[] arr = new String[columns.size()];
        for (int i = 0; i < arr.length; ++i) {
            arr[i] = columns.get(i);
        }
        return Normalization.stdDevMeanColumns(data, arr);
    }

    public static List<Row> stdDevMeanColumns(Dataset<Row> data, String ... columns) {
        return Normalization.aggregate(data, columns, new String[]{"stddev", "mean"});
    }

    /*
     * WARNING - void declaration
     */
    public static List<Row> aggregate(Dataset<Row> data, String[] columns, String[] functions2) {
        String[] rest = new String[columns.length - 1];
        System.arraycopy(columns, 1, rest, 0, rest.length);
        ArrayList<Row> rows = new ArrayList<Row>();
        for (String op : functions2) {
            void var13_20;
            ListOrderedMap expressions = new ListOrderedMap();
            for (String string : columns) {
                expressions.put(string, op);
            }
            Dataset aggregated = data.agg((Map)expressions);
            String[] columns2 = aggregated.columns();
            TreeMap<String, String> opReplace = new TreeMap<String, String>();
            for (String s3 : columns2) {
                if (s3.contains("min(") || s3.contains("max(")) {
                    opReplace.put(s3, s3.replace(op, "").replaceAll("[()]", ""));
                    continue;
                }
                if (s3.contains("avg")) {
                    opReplace.put(s3, s3.replace("avg", "").replaceAll("[()]", ""));
                    continue;
                }
                opReplace.put(s3, s3.replace(op, "").replaceAll("[()]", ""));
            }
            Object var13_19 = null;
            for (Map.Entry entries : opReplace.entrySet()) {
                if (var13_20 == null) {
                    Dataset dataset = aggregated.withColumnRenamed((String)entries.getKey(), (String)entries.getValue());
                    continue;
                }
                Dataset dataset = var13_20.withColumnRenamed((String)entries.getKey(), (String)entries.getValue());
            }
            Dataset dataset = var13_20.select(DataFrames.toColumns(columns));
            rows.addAll(dataset.collectAsList());
        }
        return rows;
    }

    public static Dataset<Row> normalize(Dataset<Row> dataFrame, double min, double max, List<String> skipColumns) {
        List<String> columnsList = DataFrames.toList(dataFrame.columns());
        columnsList.removeAll(skipColumns);
        String[] columnNames = DataFrames.toArray(columnsList);
        List<Row> minMax = Normalization.minMaxColumns(dataFrame, columnNames);
        for (int i = 0; i < columnNames.length; ++i) {
            String columnName = columnNames[i];
            double dMin = ((Number)minMax.get(0).get(i)).doubleValue();
            double dMax = ((Number)minMax.get(1).get(i)).doubleValue();
            double maxSubMin = dMax - dMin;
            if (maxSubMin == 0.0) {
                maxSubMin = 1.0;
            }
            Column newCol = dataFrame.col(columnName).minus((Object)dMin).divide((Object)maxSubMin).multiply((Object)(max - min)).plus((Object)min);
            dataFrame = dataFrame.withColumn(columnName, newCol);
        }
        return dataFrame;
    }

    public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> data, double min, double max, List<String> skipColumns) {
        Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
        return (JavaRDD)DataFrames.toRecords(Normalization.normalize(frame, min, max, skipColumns)).getSecond();
    }

    public static JavaRDD<List<List<Writable>>> normalizeSequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
        return Normalization.normalizeSequence(schema, data, 0.0, 1.0);
    }

    public static JavaRDD<List<List<Writable>>> normalizeSequence(Schema schema, JavaRDD<List<List<Writable>>> data, double min, double max) {
        return Normalization.normalizeSequence(schema, data, min, max, null);
    }

    public static JavaRDD<List<List<Writable>>> normalizeSequence(Schema schema, JavaRDD<List<List<Writable>>> data, double min, double max, List<String> excludeColumns) {
        if (excludeColumns == null) {
            excludeColumns = Arrays.asList("__SEQ_UUID", "__SEQ_IDX");
        } else {
            excludeColumns = new ArrayList<String>(excludeColumns);
            excludeColumns.add("__SEQ_UUID");
            excludeColumns.add("__SEQ_IDX");
        }
        Dataset<Row> frame = DataFrames.toDataFrameSequence(schema, data);
        return (JavaRDD)DataFrames.toRecordsSequence(Normalization.normalize(frame, min, max, excludeColumns)).getSecond();
    }

    public static Dataset<Row> normalize(Dataset<Row> dataFrame, List<String> skipColumns) {
        return Normalization.normalize(dataFrame, 0.0, 1.0, skipColumns);
    }

    public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> data, List<String> skipColumns) {
        return Normalization.normalize(schema, data, 0.0, 1.0, skipColumns);
    }
}

