/*
 * Decompiled with CFR 0.152.
 */
package fr.insee.vtl.spark;

import fr.insee.vtl.model.Dataset;
import fr.insee.vtl.model.Structured;
import java.time.Instant;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.MetadataBuilder;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Predef;
import scala.collection.Iterable;
import scala.collection.JavaConverters;
import scala.collection.Seq;
import scala.collection.immutable.Map;

/**
 * A VTL {@link Dataset} backed by a Spark {@code Dataset<Row>}.
 *
 * <p>When wrapping an existing Spark dataset, column types are normalized so that they match the
 * VTL types reported by {@link #toVtlType(DataType)}. Integer columns are widened to long. Float and
 * decimal columns are converted to double. Component roles are stored in each column's
 * {@code vtlRole} metadata entry, and an explicit role map passed to a constructor overrides them.
 */
public class SparkDataset implements Dataset {

    /** Key of the Spark column-metadata entry that holds a component's VTL role name. */
    private static final String VTL_ROLE_KEY = "vtlRole";

    private final org.apache.spark.sql.Dataset<Row> sparkDataset;

    /** Role overrides by column name; they take precedence over the column metadata. */
    private final java.util.Map<String, Dataset.Role> roles;

    /** VTL structure, computed on first use when not set by a constructor; not synchronized. */
    private Structured.DataStructure dataStructure;

    /**
     * Wraps a Spark dataset. The column types are normalized and each column's resolved VTL role
     * is recorded in its metadata.
     *
     * @param sparkDataset the Spark dataset to wrap
     * @param roles role overrides by column name; a column absent from this map uses its
     *     {@code vtlRole} metadata, or {@link Dataset.Role#MEASURE} when it has none
     * @throws NullPointerException if either argument is null
     * @throws UnsupportedOperationException if a column has a type with no VTL equivalent
     */
    public SparkDataset(org.apache.spark.sql.Dataset<Row> sparkDataset, java.util.Map<String, Dataset.Role> roles) {
        this.roles = Objects.requireNonNull(roles, "roles");
        org.apache.spark.sql.Dataset<Row> casted = castIfNeeded(Objects.requireNonNull(sparkDataset, "sparkDataset"));
        Structured.DataStructure structure = fromSparkSchema(casted.schema(), this.roles);
        this.sparkDataset = addMetadata(casted, structure);
        // The metadata written above encodes exactly the roles resolved in `structure`. Re-deriving
        // the structure from the new schema with the same role map would yield the same components,
        // so cache it here.
        this.dataStructure = structure;
    }

    /**
     * Wraps a Spark dataset with normalized column types and no role overrides. Roles are taken
     * from the {@code vtlRole} column metadata, defaulting to {@link Dataset.Role#MEASURE}.
     *
     * @param sparkDataset the Spark dataset to wrap
     * @throws NullPointerException if {@code sparkDataset} is null
     */
    public SparkDataset(org.apache.spark.sql.Dataset<Row> sparkDataset) {
        this.sparkDataset = castIfNeeded(Objects.requireNonNull(sparkDataset, "sparkDataset"));
        this.roles = Collections.emptyMap();
    }

    /**
     * Materializes a VTL dataset as a Spark dataset. The data points of {@code vtlDataset} are
     * loaded into a new DataFrame created by {@code spark}.
     *
     * @param vtlDataset the VTL dataset whose data points and structure are copied
     * @param roles role overrides by column name, applied when the structure is derived
     * @param spark the session used to create the DataFrame
     * @throws NullPointerException if {@code roles} is null
     * @throws UnsupportedOperationException if a component type has no Spark equivalent
     */
    public SparkDataset(Dataset vtlDataset, java.util.Map<String, Dataset.Role> roles, SparkSession spark) {
        this.roles = Objects.requireNonNull(roles, "roles");
        List<Row> rows = vtlDataset.getDataPoints().stream()
                .map(point -> RowFactory.create(point.toArray(new Object[0])))
                .collect(Collectors.toList());
        StructType schema = toSparkSchema(vtlDataset.getDataStructure());
        this.sparkDataset = spark.createDataFrame(rows, schema);
    }

    /**
     * Casts columns whose Spark type is not used by the VTL mapping. Integer becomes long; float and
     * decimal become double. Decimal to double may lose precision.
     */
    private static org.apache.spark.sql.Dataset<Row> castIfNeeded(org.apache.spark.sql.Dataset<Row> sparkDataset) {
        org.apache.spark.sql.Dataset<Row> casted = sparkDataset;
        for (StructField field : sparkDataset.schema().fields()) {
            DataType type = field.dataType();
            DataType target;
            if (DataTypes.IntegerType.sameType(type)) {
                target = DataTypes.LongType;
            } else if (DataTypes.FloatType.sameType(type) || type instanceof DecimalType) {
                target = DataTypes.DoubleType;
            } else {
                continue;
            }
            String name = field.name();
            casted = casted.withColumn(name, casted.col(name).cast(target));
        }
        return casted;
    }

    /**
     * Re-aliases every column of {@code structure} onto itself, attaching the metadata produced by
     * {@link #toSparkSchema}. That metadata carries the {@code vtlRole} entry.
     */
    private static org.apache.spark.sql.Dataset<Row> addMetadata(org.apache.spark.sql.Dataset<Row> sparkDataset, Structured.DataStructure structure) {
        org.apache.spark.sql.Dataset<Row> result = sparkDataset;
        for (StructField field : toSparkSchema(structure).fields()) {
            String name = field.name();
            // Column.as(alias, metadata) is the public API for attaching column metadata.
            result = result.withColumn(name, result.col(name).as(name, field.metadata()));
        }
        return result;
    }

    /**
     * Builds a Spark schema from a VTL data structure. Every field is nullable and carries its
     * component's role name under the {@code vtlRole} metadata key.
     *
     * @param structure the VTL structure to convert
     * @return the equivalent Spark schema, in the iteration order of {@code structure.values()}
     * @throws UnsupportedOperationException if a component type has no Spark equivalent
     */
    public static StructType toSparkSchema(Structured.DataStructure structure) {
        List<StructField> fields = new ArrayList<>();
        for (Structured.Component component : structure.values()) {
            Metadata metadata = new MetadataBuilder()
                    .putString(VTL_ROLE_KEY, component.getRole().name())
                    .build();
            fields.add(DataTypes.createStructField(component.getName(), fromVtlType(component.getType()), true, metadata));
        }
        return DataTypes.createStructType(fields);
    }

    /**
     * Derives a VTL data structure from a Spark schema.
     *
     * <p>A column's role comes from {@code roles} if the map contains the column name. Otherwise it
     * comes from the column's {@code vtlRole} metadata. Otherwise it is {@link Dataset.Role#MEASURE}.
     *
     * @param schema the Spark schema to convert
     * @param roles role overrides by column name
     * @return the VTL structure, one component per schema field in schema order
     * @throws IllegalArgumentException if a {@code vtlRole} metadata value is not a role name
     * @throws UnsupportedOperationException if a column type has no VTL equivalent
     */
    public static Structured.DataStructure fromSparkSchema(StructType schema, java.util.Map<String, Dataset.Role> roles) {
        List<Structured.Component> components = new ArrayList<>();
        for (StructField field : schema.fields()) {
            String name = field.name();
            Dataset.Role fieldRole;
            if (roles.containsKey(name)) {
                fieldRole = roles.get(name);
            } else if (field.metadata().contains(VTL_ROLE_KEY)) {
                fieldRole = Dataset.Role.valueOf(field.metadata().getString(VTL_ROLE_KEY));
            } else {
                fieldRole = Dataset.Role.MEASURE;
            }
            // NOTE(review): the fourth Component argument is left null; confirm its meaning in
            // Structured.Component.
            components.add(new Structured.Component(name, toVtlType(field.dataType()), fieldRole, null));
        }
        return new Structured.DataStructure(components);
    }

    /**
     * Maps a Spark data type to the Java class used for it in VTL.
     *
     * @param dataType the Spark type
     * @return the corresponding VTL value class
     * @throws UnsupportedOperationException if the type is not supported
     */
    public static Class<?> toVtlType(DataType dataType) {
        if (DataTypes.StringType.sameType(dataType)) {
            return String.class;
        }
        if (DataTypes.IntegerType.sameType(dataType) || DataTypes.LongType.sameType(dataType)) {
            return Long.class;
        }
        if (DataTypes.FloatType.sameType(dataType)
                || DataTypes.DoubleType.sameType(dataType)
                || dataType instanceof DecimalType) {
            return Double.class;
        }
        if (DataTypes.BooleanType.sameType(dataType)) {
            return Boolean.class;
        }
        if (DataTypes.DateType.sameType(dataType) || DataTypes.TimestampType.sameType(dataType)) {
            // NOTE(review): DateType is reported as Instant, but fromVtlType maps LocalDate to
            // DateType. A LocalDate component therefore does not round-trip; confirm this is intended.
            return Instant.class;
        }
        throw new UnsupportedOperationException("unsupported type " + dataType);
    }

    /**
     * Maps a VTL value class to the Spark data type used to store it.
     *
     * @param type the VTL value class
     * @return the corresponding Spark type
     * @throws UnsupportedOperationException if the class is not supported
     */
    public static DataType fromVtlType(Class<?> type) {
        if (String.class.equals(type)) {
            return DataTypes.StringType;
        }
        if (Long.class.equals(type)) {
            return DataTypes.LongType;
        }
        if (Double.class.equals(type)) {
            return DataTypes.DoubleType;
        }
        if (Boolean.class.equals(type)) {
            return DataTypes.BooleanType;
        }
        if (Instant.class.equals(type)) {
            return DataTypes.TimestampType;
        }
        if (LocalDate.class.equals(type)) {
            return DataTypes.DateType;
        }
        throw new UnsupportedOperationException("unsupported type " + type);
    }

    /** Returns the underlying Spark dataset, with normalized types when it was wrapped. */
    public org.apache.spark.sql.Dataset<Row> getSparkDataset() {
        return this.sparkDataset;
    }

    /**
     * Collects every row to the driver and converts each one into a VTL data point.
     *
     * <p>This runs a Spark job via {@code collectAsList()}, so the whole dataset must fit in driver
     * memory.
     */
    @Override
    public List<Structured.DataPoint> getDataPoints() {
        // Resolve the structure once, before triggering the Spark job, rather than once per row.
        Structured.DataStructure structure = getDataStructure();
        List<Row> rows = this.sparkDataset.collectAsList();
        return rows.stream()
                .map(row -> new Structured.DataPoint(structure, JavaConverters.seqAsJavaList(row.toSeq())))
                .collect(Collectors.toList());
    }

    /**
     * Returns the VTL structure of this dataset. On first call it is derived from the Spark schema
     * and the role overrides, then cached.
     */
    @Override
    public Structured.DataStructure getDataStructure() {
        if (this.dataStructure == null) {
            this.dataStructure = fromSparkSchema(this.sparkDataset.schema(), this.roles);
        }
        return this.dataStructure;
    }
}

