/*
 * Decompiled with CFR 0.152.
 */
package fr.insee.vtl.spark;

import fr.insee.vtl.model.AggregationExpression;
import fr.insee.vtl.model.Analytics;
import fr.insee.vtl.model.DataPointRuleset;
import fr.insee.vtl.model.Dataset;
import fr.insee.vtl.model.DatasetExpression;
import fr.insee.vtl.model.HierarchicalRuleset;
import fr.insee.vtl.model.InMemoryDataset;
import fr.insee.vtl.model.Positioned;
import fr.insee.vtl.model.ProcessingEngine;
import fr.insee.vtl.model.ProcessingEngineFactory;
import fr.insee.vtl.model.ResolvableExpression;
import fr.insee.vtl.model.Structured;
import fr.insee.vtl.model.ValidationOutput;
import fr.insee.vtl.model.VtlFunction;
import fr.insee.vtl.spark.SparkDataset;
import fr.insee.vtl.spark.SparkDatasetExpression;
import fr.insee.vtl.spark.SparkFilterFunction;
import fr.insee.vtl.spark.SparkRowMap;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import javax.script.ScriptEngine;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataType;
import scala.collection.Iterator;
import scala.collection.JavaConverters;
import scala.collection.Seq;

public class SparkProcessingEngine
implements ProcessingEngine {
    /** Accuracy argument passed to Spark's {@code percentile_approx} when this engine computes a median. */
    public static final Integer DEFAULT_MEDIAN_ACCURACY = 1000000;
    /**
     * Exception thrown by the analytic methods of this class when they receive a function they do not handle.
     * NOTE(review): this is a single pre-built instance, so its stack trace is the one captured when the
     * exception object was constructed, not the place where it is thrown.
     */
    public static final UnsupportedOperationException UNKNOWN_ANALYTIC_FUNCTION = new UnsupportedOperationException("Unknown analytic function");
    // Column names added to the result of data point ruleset validation.
    private static final String BOOLVAR = "bool_var";
    private static final String ERRORCODE = "errorcode";
    private static final String ERRORLEVEL = "errorlevel";
    private static final String RULEID = "ruleid";
    // Name given to the single measure of the imbalance dataset in simple validation.
    private static final String IMBALANCE = "imbalance";
    // Not referenced in the visible part of this file; presumably validation modes for
    // hierarchical rulesets -- TODO confirm against the code that uses them.
    private static final String NON_NULL = "non_null";
    private static final String NON_ZERO = "non_zero";
    private static final String PARTIAL_NULL = "partial_null";
    private static final String PARTIAL_ZERO = "partial_zero";
    private static final String ALWAYS_NULL = "always_null";
    private static final String ALWAYS_ZERO = "always_zero";
    /** Spark session used when a non-Spark dataset has to be converted to a {@link SparkDataset}. */
    private final SparkSession spark;

    /**
     * Creates a processing engine bound to the given Spark session.
     *
     * <p>Sets {@code spark.sql.datetime.java8API.enabled} to {@code true} on the session's
     * runtime configuration.
     *
     * @param spark the Spark session used by this engine; must not be {@code null}
     * @throws NullPointerException if {@code spark} is {@code null}
     */
    public SparkProcessingEngine(SparkSession spark) {
        // Validate before the first dereference so the null check is actually effective
        // and a null session fails with a message naming the argument.
        this.spark = Objects.requireNonNull(spark, "spark");
        this.spark.conf().set("spark.sql.datetime.java8API.enabled", true);
    }

    /**
     * Builds a map from component name to component role.
     *
     * <p>Callers in this class add entries to the returned map, so the map type is pinned to
     * {@link HashMap} explicitly; the two-argument {@code Collectors.toMap} makes no guarantee
     * about the mutability of the map it returns.
     *
     * @param components the components to index by name
     * @return a new, mutable map of component name to role
     * @throws IllegalStateException if two components share the same name
     */
    private static Map<String, Dataset.Role> getRoleMap(Collection<Structured.Component> components) {
        return components.stream().collect(Collectors.toMap(
                Structured.Component::getName,
                Structured.Component::getRole,
                (first, second) -> {
                    // Same policy as the two-argument toMap: duplicate names are an error.
                    throw new IllegalStateException("Duplicate component name in role map");
                },
                HashMap::new));
    }

    /**
     * Builds the name-to-role map of a dataset from the components of its data structure.
     *
     * @param dataset the dataset whose components are read
     * @return the map of component name to role
     */
    private static Map<String, Dataset.Role> getRoleMap(Dataset dataset) {
        Collection<Structured.Component> components = dataset.getDataStructure().values();
        return getRoleMap(components);
    }

    /**
     * Translates a VTL aggregation expression into the matching Spark aggregate column.
     *
     * <p>The count aggregation counts rows ({@code count("*")}) and does not read
     * {@code columnName}; every other aggregation is applied to {@code columnName}. The
     * resulting column is always aliased to {@code columnName}.
     *
     * @param columnName the column to aggregate, also used as the alias of the result
     * @param expression the aggregation to translate
     * @return the aliased Spark aggregate column
     * @throws UnsupportedOperationException if the aggregation type is not handled
     */
    private static Column convertAggregation(String columnName, AggregationExpression expression) throws UnsupportedOperationException {
        final Column aggregate;
        if (expression instanceof AggregationExpression.MinAggregationExpression) {
            aggregate = functions.min(columnName);
        } else if (expression instanceof AggregationExpression.MaxAggregationExpression) {
            aggregate = functions.max(columnName);
        } else if (expression instanceof AggregationExpression.AverageAggregationExpression) {
            aggregate = functions.avg(columnName);
        } else if (expression instanceof AggregationExpression.SumAggregationExpression) {
            aggregate = functions.sum(columnName);
        } else if (expression instanceof AggregationExpression.CountAggregationExpression) {
            aggregate = functions.count("*");
        } else if (expression instanceof AggregationExpression.MedianAggregationExpression) {
            // The median is the approximate 50th percentile.
            Column source = functions.col(columnName);
            Column half = functions.lit(0.5);
            Column accuracy = functions.lit(DEFAULT_MEDIAN_ACCURACY);
            aggregate = functions.percentile_approx(source, half, accuracy);
        } else if (expression instanceof AggregationExpression.StdDevSampAggregationExpression) {
            aggregate = functions.stddev_samp(columnName);
        } else if (expression instanceof AggregationExpression.VarPopAggregationExpression) {
            aggregate = functions.var_pop(columnName);
        } else if (expression instanceof AggregationExpression.VarSampAggregationExpression) {
            aggregate = functions.var_samp(columnName);
        } else {
            throw new UnsupportedOperationException("unknown aggregation " + expression.getClass());
        }
        return aggregate.alias(columnName);
    }

    /**
     * Builds a window specification that only partitions, with no ordering and no frame.
     *
     * @param partitionBy the partition column names, may be {@code null}
     * @return the Spark window specification
     */
    private static WindowSpec buildWindowSpec(List<String> partitionBy) {
        final Map<String, Analytics.Order> noOrdering = null;
        final Analytics.WindowSpec noFrame = null;
        return buildWindowSpec(partitionBy, noOrdering, noFrame);
    }

    /**
     * Builds a window specification with partitioning and ordering but no frame.
     *
     * @param partitionBy the partition column names, may be {@code null}
     * @param orderBy the ordering per column name, may be {@code null}
     * @return the Spark window specification
     */
    private static WindowSpec buildWindowSpec(List<String> partitionBy, Map<String, Analytics.Order> orderBy) {
        final Analytics.WindowSpec noFrame = null;
        return buildWindowSpec(partitionBy, orderBy, noFrame);
    }

    /**
     * Builds a Spark window specification from VTL analytic clauses.
     *
     * <p>A {@code null} partition list or ordering map is treated as empty. A
     * {@code DataPointWindow} becomes a row frame and a {@code RangeWindow} a range frame; in
     * both cases the lower bound is negated. Any other window (including {@code null}) adds no
     * frame.
     *
     * @param partitionBy the partition column names, may be {@code null}
     * @param orderBy the ordering per column name, may be {@code null}
     * @param window the frame definition, may be {@code null}
     * @return the Spark window specification
     */
    private static WindowSpec buildWindowSpec(List<String> partitionBy, Map<String, Analytics.Order> orderBy, Analytics.WindowSpec window) {
        List<String> partitionColumns = partitionBy != null ? partitionBy : List.of();
        Map<String, Analytics.Order> ordering = orderBy != null ? orderBy : Map.of();
        WindowSpec spec = Window.partitionBy(colNameToCol(partitionColumns)).orderBy(buildOrderCol(ordering));
        if (window instanceof Analytics.DataPointWindow) {
            return spec.rowsBetween(-window.getLower().longValue(), window.getUpper().longValue());
        }
        if (window instanceof Analytics.RangeWindow) {
            return spec.rangeBetween(-window.getLower().longValue(), window.getUpper().longValue());
        }
        return spec;
    }

    /**
     * Converts a list of column names into a Scala sequence of Spark columns, in the same order.
     *
     * @param inputColNames the column names
     * @return the corresponding columns as a Scala {@code Seq}
     */
    public static Seq<Column> colNameToCol(List<String> inputColNames) {
        List<Column> columns = inputColNames.stream()
                .map(name -> functions.col(name))
                .collect(Collectors.toList());
        return ((Iterator)JavaConverters.asScalaIteratorConverter(columns.iterator()).asScala()).toSeq();
    }

    /**
     * Converts an ordering map into a Scala sequence of Spark sort columns.
     *
     * <p>Columns mapped to {@code DESC} are sorted descending; all others are passed as plain
     * columns. The sequence follows the iteration order of the map.
     *
     * @param orderCols the ordering per column name
     * @return the sort columns as a Scala {@code Seq}
     */
    public static Seq<Column> buildOrderCol(Map<String, Analytics.Order> orderCols) {
        List<Column> sortColumns = new ArrayList<>();
        for (Map.Entry<String, Analytics.Order> ordering : orderCols.entrySet()) {
            boolean descending = ordering.getValue().equals(Analytics.Order.DESC);
            Column column = functions.col(ordering.getKey());
            sortColumns.add(descending ? column.desc() : column);
        }
        return ((Iterator)JavaConverters.asScalaIteratorConverter(sortColumns.iterator()).asScala()).toSeq();
    }

    /**
     * Returns the names of the components whose role is {@code IDENTIFIER}, in input order.
     *
     * @param components the components to inspect
     * @return the names of the identifier components
     */
    private static List<String> identifierNames(List<Structured.Component> components) {
        List<String> names = new ArrayList<>();
        for (Structured.Component component : components) {
            if (Dataset.Role.IDENTIFIER.equals(component.getRole())) {
                names.add(component.getName());
            }
        }
        return names;
    }

    /**
     * Resolves a dataset expression (with an empty context) into a {@link SparkDataset}.
     *
     * <p>Spark expressions and expressions that already resolve to a {@link SparkDataset} are
     * returned as is; any other dataset is wrapped in a new {@link SparkDataset} using its own
     * roles and this engine's Spark session.
     *
     * @param expression the expression to resolve
     * @return the resolved dataset as a {@link SparkDataset}
     */
    private SparkDataset asSparkDataset(DatasetExpression expression) {
        if (expression instanceof SparkDatasetExpression) {
            SparkDatasetExpression sparkExpression = (SparkDatasetExpression) expression;
            return sparkExpression.resolve(Map.of());
        }
        Dataset resolved = expression.resolve(Map.of());
        return resolved instanceof SparkDataset
                ? (SparkDataset) resolved
                : new SparkDataset(resolved, getRoleMap(resolved), this.spark);
    }

    /**
     * Adds or replaces computed columns on a dataset.
     *
     * <p>Each target name is first given a temporary alias ({@code name_<index>}). The aliased
     * columns are computed from the expression strings where possible, then from the resolvable
     * expressions for aliases that are still missing, and finally the aliases are renamed back
     * to the target names.
     *
     * @param expression the input dataset expression
     * @param expressions the resolvable expression per target column name
     * @param roles the roles to set for the computed columns, merged over the input roles
     * @param expressionStrings the textual form of the expressions per target column name
     * @return a dataset expression over the computed dataset
     */
    public DatasetExpression executeCalc(DatasetExpression expression, Map<String, ResolvableExpression> expressions, Map<String, Dataset.Role> roles, Map<String, String> expressionStrings) {
        SparkDataset source = this.asSparkDataset(expression);
        org.apache.spark.sql.Dataset<Row> rows = source.getSparkDataset();

        HashMap<String, String> aliasToName = new HashMap<>();
        LinkedHashMap<String, ResolvableExpression> aliasedExpressions = new LinkedHashMap<>();
        LinkedHashMap<String, String> aliasedExpressionStrings = new LinkedHashMap<>();
        for (Map.Entry<String, ResolvableExpression> entry : expressions.entrySet()) {
            String name = entry.getKey();
            // The running size of the alias map serves as a unique suffix.
            String alias = name + "_" + aliasToName.size();
            aliasedExpressions.put(alias, entry.getValue());
            aliasedExpressionStrings.put(alias, expressionStrings.get(name));
            aliasToName.put(alias, name);
        }

        org.apache.spark.sql.Dataset<Row> withInterpreted = this.executeCalcInterpreted(rows, aliasedExpressionStrings);
        org.apache.spark.sql.Dataset<Row> withEvaluated = this.executeCalcEvaluated(withInterpreted, aliasedExpressions);
        org.apache.spark.sql.Dataset<Row> result = this.rename(withEvaluated, aliasToName);

        Map<String, Dataset.Role> resultRoles = getRoleMap(source);
        resultRoles.putAll(roles);
        return new SparkDatasetExpression(new SparkDataset(result, resultRoles), (Positioned) expression);
    }

    /**
     * Adds the columns of {@code expressions} that are not yet present in the dataset by
     * evaluating each resolvable expression row by row through a Spark UDF.
     *
     * <p>The UDF receives a struct of the columns the dataset had on entry and exposes it to
     * the expression through a {@link SparkRowMap}. Evaluation is best effort: when a column
     * cannot be added, a warning carrying the cause is logged and the column is skipped.
     *
     * @param interpreted the dataset to extend
     * @param expressions the resolvable expression per column name
     * @return the dataset with the additional columns
     */
    private org.apache.spark.sql.Dataset<Row> executeCalcEvaluated(org.apache.spark.sql.Dataset<Row> interpreted, Map<String, ResolvableExpression> expressions) {
        Set<String> columnNames = Set.of(interpreted.columns());
        Column structColumns = functions.struct((Column[])((Column[])columnNames.stream().map(colName -> functions.col((String)colName)).toArray(Column[]::new)));
        for (Map.Entry<String, ResolvableExpression> entry : expressions.entrySet()) {
            String name = entry.getKey();
            // Columns that already exist were computed earlier and are left untouched.
            if (columnNames.contains(name)) continue;
            ResolvableExpression expression = entry.getValue();
            try {
                UserDefinedFunction exprFunction = functions.udf((UDF1 & Serializable)row -> {
                    SparkRowMap context = new SparkRowMap((Row)row);
                    return expression.resolve((Map)context);
                }, (DataType)SparkDataset.fromVtlType(expression.getType()));
                interpreted = interpreted.withColumn(name, exprFunction.apply(new Column[]{structColumns}));
            }
            catch (Exception e) {
                // Keep the best-effort behaviour, but report which column failed and why
                // instead of printing only the column name to standard output.
                java.util.logging.Logger.getLogger(SparkProcessingEngine.class.getName()).log(
                        java.util.logging.Level.WARNING,
                        "Could not evaluate calc expression for column '" + name + "'; the column is skipped",
                        e);
            }
        }
        return interpreted;
    }

    /**
     * Adds one column per entry of {@code expressionStrings} by handing the expression text to
     * Spark ({@code functions.expr}).
     *
     * <p>Entries with a {@code null} expression are skipped. An entry whose expression makes
     * Spark throw is skipped as well, so its column is simply absent from the result.
     *
     * @param result the dataset to extend
     * @param expressionStrings the expression text per column name
     * @return the dataset with the columns that could be added
     */
    private org.apache.spark.sql.Dataset<Row> executeCalcInterpreted(org.apache.spark.sql.Dataset<Row> result, Map<String, String> expressionStrings) {
        org.apache.spark.sql.Dataset<Row> current = result;
        for (Map.Entry<String, String> entry : expressionStrings.entrySet()) {
            String expressionText = entry.getValue();
            if (expressionText == null) {
                continue;
            }
            try {
                current = current.withColumn(entry.getKey(), functions.expr(expressionText));
            } catch (Exception ignored) {
                // Deliberately ignored: the column is left out of the result.
            }
        }
        return current;
    }

    /**
     * Filters the rows of a dataset.
     *
     * <p>The textual condition is tried first; if that attempt throws, the rows are filtered
     * with a {@link SparkFilterFunction} built from the resolvable expression instead. Roles
     * are carried over unchanged.
     *
     * @param expression the input dataset expression
     * @param filter the filter as a resolvable expression, used as fallback
     * @param filterText the filter as a condition string
     * @return a dataset expression over the filtered dataset
     */
    public DatasetExpression executeFilter(DatasetExpression expression, ResolvableExpression filter, String filterText) {
        SparkDataset source = this.asSparkDataset(expression);
        org.apache.spark.sql.Dataset<Row> rows = source.getSparkDataset();
        try {
            org.apache.spark.sql.Dataset bySql = rows.filter(filterText);
            SparkDataset filtered = new SparkDataset((org.apache.spark.sql.Dataset<Row>) bySql, getRoleMap(source));
            return new SparkDatasetExpression(filtered, (Positioned) expression);
        } catch (Exception e) {
            // Fallback: evaluate the resolvable expression on every row.
            SparkFilterFunction predicate = new SparkFilterFunction(filter);
            org.apache.spark.sql.Dataset byFunction = rows.filter((FilterFunction) predicate);
            SparkDataset filtered = new SparkDataset((org.apache.spark.sql.Dataset<Row>) byFunction, getRoleMap(source));
            return new SparkDatasetExpression(filtered, (Positioned) expression);
        }
    }

    /**
     * Renames columns of a dataset and gives each new name the role of the column it replaces.
     *
     * <p>The role map of the result starts as a copy of the original roles, so entries for the
     * old names are kept alongside the entries for the new names.
     *
     * @param expression the input dataset expression
     * @param fromTo the new name per current column name
     * @return a dataset expression over the renamed dataset
     */
    public DatasetExpression executeRename(DatasetExpression expression, Map<String, String> fromTo) {
        SparkDataset source = this.asSparkDataset(expression);
        org.apache.spark.sql.Dataset<Row> renamedRows = this.rename(source.getSparkDataset(), fromTo);

        Map<String, Dataset.Role> rolesBefore = getRoleMap(source);
        LinkedHashMap<String, Dataset.Role> rolesAfter = new LinkedHashMap<>(rolesBefore);
        fromTo.forEach((from, to) -> rolesAfter.put(to, rolesBefore.get(from)));

        return new SparkDatasetExpression(new SparkDataset(renamedRows, rolesAfter), (Positioned) expression);
    }

    /**
     * Selects the columns of a dataset, renaming those listed in {@code fromTo}.
     *
     * <p>A column that is not renamed itself but whose name is the target of a rename is
     * dropped, so the renamed column takes its place. All other columns are kept as they are.
     *
     * @param dataset the Spark dataset to rename
     * @param fromTo the new name per current column name
     * @return the dataset with the renamed selection
     */
    public org.apache.spark.sql.Dataset<Row> rename(org.apache.spark.sql.Dataset<Row> dataset, Map<String, String> fromTo) {
        List<Column> selection = new ArrayList<>();
        for (String columnName : dataset.columns()) {
            Column column = functions.col(columnName);
            if (fromTo.containsKey(columnName)) {
                selection.add(column.as(fromTo.get(columnName)));
            } else if (!fromTo.containsValue(columnName)) {
                selection.add(column);
            }
            // Otherwise the name is taken over by a renamed column: drop this one.
        }
        return dataset.select(JavaConverters.iterableAsScalaIterable(selection).toSeq());
    }

    /**
     * Keeps only the given columns of a dataset, in the given order.
     *
     * <p>The role map passed to the result is the one of the input dataset.
     *
     * @param expression the input dataset expression
     * @param columnNames the names of the columns to keep
     * @return a dataset expression over the projected dataset
     */
    public DatasetExpression executeProject(DatasetExpression expression, List<String> columnNames) {
        SparkDataset source = this.asSparkDataset(expression);
        List<Column> projection = new ArrayList<>();
        for (String columnName : columnNames) {
            projection.add(new Column(columnName));
        }
        org.apache.spark.sql.Dataset<Row> projected = source.getSparkDataset()
                .select(JavaConverters.iterableAsScalaIterable(projection).toSeq());
        return new SparkDatasetExpression(new SparkDataset(projected, getRoleMap(source)), (Positioned) expression);
    }

    /**
     * Checks that every dataset has a data structure equal to the one of the first dataset.
     *
     * @param datasets the dataset expressions to compare; must contain at least one element
     * @return {@code true} if all data structures equal the first one, {@code false} otherwise
     */
    private boolean checkColNameCompatibility(List<DatasetExpression> datasets) {
        Structured.DataStructure reference = datasets.get(0).getDataStructure();
        return datasets.stream()
                .skip(1)
                .allMatch(candidate -> reference.equals(candidate.getDataStructure()));
    }

    /**
     * Computes the union of several datasets that share the same data structure.
     *
     * <p>Rows are combined by column name rather than by column position, so operands whose
     * Spark columns are in a different order are still aligned correctly. After the union,
     * rows are de-duplicated on the identifier columns of the first dataset. A single operand
     * is returned unchanged.
     *
     * @param datasets the dataset expressions to unite; must contain at least one element
     * @return a dataset expression over the union, or the only operand if there is just one
     * @throws UnsupportedOperationException if the data structures are not all equal
     */
    public DatasetExpression executeUnion(List<DatasetExpression> datasets) {
        DatasetExpression first = datasets.get(0);
        if (!this.checkColNameCompatibility(datasets)) {
            throw new UnsupportedOperationException("The schema of the dataset is not compatible");
        }
        if (datasets.size() == 1) {
            // Nothing to unite: skip the role and identifier bookkeeping altogether.
            return first;
        }

        // Roles and identifier columns are taken from the first dataset; all operands share
        // the same structure at this point.
        HashMap<String, Dataset.Role> dataRoles = new HashMap<>();
        List<String> identifierColumns = new ArrayList<>();
        for (Structured.Component component : first.getDataStructure().values()) {
            dataRoles.put(component.getName(), component.getRole());
            if (Dataset.Role.IDENTIFIER.equals(component.getRole())) {
                identifierColumns.add(component.getName());
            }
        }

        org.apache.spark.sql.Dataset<Row> result = this.asSparkDataset(first).getSparkDataset();
        for (DatasetExpression other : datasets.subList(1, datasets.size())) {
            // Dataset.union resolves columns by position; unionByName matches them by name.
            result = result.unionByName(this.asSparkDataset(other).getSparkDataset());
        }
        result = result.dropDuplicates(JavaConverters.iterableAsScalaIterable(identifierColumns).toSeq());
        return new SparkDatasetExpression(new SparkDataset(result, dataRoles), (Positioned) first);
    }

    /**
     * Groups a dataset and computes one aggregate column per entry of {@code collectorMap}.
     *
     * @param dataset the input dataset expression
     * @param groupBy the names of the grouping columns
     * @param collectorMap the aggregation per column name; must contain at least one entry
     * @return a dataset expression over the aggregated dataset
     */
    public DatasetExpression executeAggr(DatasetExpression dataset, List<String> groupBy, Map<String, AggregationExpression> collectorMap) {
        SparkDataset source = this.asSparkDataset(dataset);

        List<Column> aggregations = new ArrayList<>();
        for (Map.Entry<String, AggregationExpression> collector : collectorMap.entrySet()) {
            aggregations.add(convertAggregation(collector.getKey(), collector.getValue()));
        }
        List<Column> groupingColumns = new ArrayList<>();
        for (String name : groupBy) {
            groupingColumns.add(functions.col(name));
        }

        // Spark's agg takes the first aggregate separately from the remaining ones.
        org.apache.spark.sql.Dataset<Row> aggregated = source.getSparkDataset()
                .groupBy(JavaConverters.iterableAsScalaIterable(groupingColumns).toSeq())
                .agg(aggregations.get(0), JavaConverters.iterableAsScalaIterable(aggregations.subList(1, aggregations.size())).toSeq());
        return new SparkDatasetExpression(new SparkDataset(aggregated), (Positioned) dataset);
    }

    /**
     * Adds a column holding an aggregate function evaluated over a window.
     *
     * @param dataset the input dataset expression
     * @param targetColName the name of the column to add
     * @param function the analytic function to apply
     * @param sourceColName the name of the column the function reads
     * @param partitionBy the partition column names, may be {@code null}
     * @param orderBy the ordering per column name, may be {@code null}
     * @param window the window frame, may be {@code null}
     * @return a dataset expression over the dataset with the additional column
     * @throws UnsupportedOperationException if the function is not one of the handled ones
     */
    public DatasetExpression executeSimpleAnalytic(DatasetExpression dataset, String targetColName, Analytics.Function function, String sourceColName, List<String> partitionBy, Map<String, Analytics.Order> orderBy, Analytics.WindowSpec window) {
        SparkDataset source = this.asSparkDataset(dataset);
        WindowSpec frame = buildWindowSpec(partitionBy, orderBy, window);

        // Pick the aggregate first; the window is applied once, after the switch.
        final Column aggregate;
        switch (function) {
            case COUNT:
                aggregate = functions.count(sourceColName);
                break;
            case SUM:
                aggregate = functions.sum(sourceColName);
                break;
            case MIN:
                aggregate = functions.min(sourceColName);
                break;
            case MAX:
                aggregate = functions.max(sourceColName);
                break;
            case AVG:
                aggregate = functions.avg(sourceColName);
                break;
            case MEDIAN:
                aggregate = functions.percentile_approx(functions.col(sourceColName), functions.lit(0.5), functions.lit(DEFAULT_MEDIAN_ACCURACY));
                break;
            case STDDEV_POP:
                aggregate = functions.stddev_pop(sourceColName);
                break;
            case STDDEV_SAMP:
                aggregate = functions.stddev_samp(sourceColName);
                break;
            case VAR_POP:
                aggregate = functions.var_pop(sourceColName);
                break;
            case VAR_SAMP:
                aggregate = functions.var_samp(sourceColName);
                break;
            case FIRST_VALUE:
                aggregate = functions.first(sourceColName);
                break;
            case LAST_VALUE:
                aggregate = functions.last(sourceColName);
                break;
            default:
                throw UNKNOWN_ANALYTIC_FUNCTION;
        }

        org.apache.spark.sql.Dataset<Row> extended = source.getSparkDataset().withColumn(targetColName, aggregate.over(frame));
        return new SparkDatasetExpression(new SparkDataset(extended), (Positioned) dataset);
    }

    /**
     * Adds a column holding the value of another column a given number of rows ahead
     * ({@code LEAD}) or behind ({@code LAG}) within a window.
     *
     * @param dataset the input dataset expression
     * @param targetColName the name of the column to add
     * @param function {@code LEAD} or {@code LAG}
     * @param sourceColName the name of the column to read
     * @param offset the offset passed to Spark's {@code lead} / {@code lag}
     * @param partitionBy the partition column names, may be {@code null}
     * @param orderBy the ordering per column name, may be {@code null}
     * @return a dataset expression over the dataset with the additional column
     * @throws UnsupportedOperationException if the function is neither {@code LEAD} nor {@code LAG}
     */
    public DatasetExpression executeLeadOrLagAn(DatasetExpression dataset, String targetColName, Analytics.Function function, String sourceColName, int offset, List<String> partitionBy, Map<String, Analytics.Order> orderBy) {
        SparkDataset source = this.asSparkDataset(dataset);
        WindowSpec frame = buildWindowSpec(partitionBy, orderBy);

        final Column shifted;
        switch (function) {
            case LEAD:
                shifted = functions.lead(sourceColName, offset);
                break;
            case LAG:
                shifted = functions.lag(sourceColName, offset);
                break;
            default:
                throw UNKNOWN_ANALYTIC_FUNCTION;
        }

        org.apache.spark.sql.Dataset<Row> extended = source.getSparkDataset().withColumn(targetColName, shifted.over(frame));
        return new SparkDatasetExpression(new SparkDataset(extended), (Positioned) dataset);
    }

    /**
     * Adds a column holding the ratio of a value to the sum of that column over its partition.
     *
     * <p>The ratio is computed in a single column expression. No temporary
     * {@code "total_" + sourceColName} column is created, so an existing column of that name
     * is neither overwritten nor dropped.
     *
     * @param dataset the input dataset expression
     * @param targetColName the name of the column to add
     * @param function must be {@code RATIO_TO_REPORT}
     * @param sourceColName the name of the column to read
     * @param partitionBy the partition column names, may be {@code null}
     * @return a dataset expression over the dataset with the additional column
     * @throws UnsupportedOperationException if the function is not {@code RATIO_TO_REPORT}
     */
    public DatasetExpression executeRatioToReportAn(DatasetExpression dataset, String targetColName, Analytics.Function function, String sourceColName, List<String> partitionBy) {
        if (!function.equals((Object)Analytics.Function.RATIO_TO_REPORT)) {
            throw UNKNOWN_ANALYTIC_FUNCTION;
        }
        SparkDataset sparkDataset = this.asSparkDataset(dataset);
        WindowSpec windowSpec = SparkProcessingEngine.buildWindowSpec(partitionBy);
        // Partition total as a window expression, used directly as the divisor.
        Column partitionTotal = functions.sum(sourceColName).over(windowSpec);
        Column ratio = functions.col(sourceColName).divide(partitionTotal);
        org.apache.spark.sql.Dataset<Row> result = sparkDataset.getSparkDataset().withColumn(targetColName, ratio);
        return new SparkDatasetExpression(new SparkDataset(result), (Positioned)dataset);
    }

    /**
     * Adds a column holding the rank of each row within its window.
     *
     * @param dataset the input dataset expression
     * @param targetColName the name of the column to add
     * @param function must be {@code RANK}
     * @param partitionBy the partition column names, may be {@code null}
     * @param orderBy the ordering per column name, may be {@code null}
     * @return a dataset expression over the dataset with the additional column
     * @throws UnsupportedOperationException if the function is not {@code RANK}
     */
    public DatasetExpression executeRankAn(DatasetExpression dataset, String targetColName, Analytics.Function function, List<String> partitionBy, Map<String, Analytics.Order> orderBy) {
        boolean isRank = function.equals(Analytics.Function.RANK);
        if (!isRank) {
            throw UNKNOWN_ANALYTIC_FUNCTION;
        }
        SparkDataset source = this.asSparkDataset(dataset);
        WindowSpec frame = buildWindowSpec(partitionBy, orderBy);
        Column rank = functions.rank().over(frame);
        org.apache.spark.sql.Dataset<Row> ranked = source.getSparkDataset().withColumn(targetColName, rank);
        return new SparkDatasetExpression(new SparkDataset(ranked), (Positioned) dataset);
    }

    /**
     * Inner-joins the given datasets on the identifier components.
     *
     * @param datasets the dataset expressions to join, keyed by alias
     * @param components the components of the result; identifiers among them are the join keys
     * @return a dataset expression over the joined dataset, positioned at the first input
     */
    public DatasetExpression executeInnerJoin(Map<String, DatasetExpression> datasets, List<Structured.Component> components) {
        List<org.apache.spark.sql.Dataset<Row>> aliased = this.toAliasedDatasets(datasets);
        List<String> joinKeys = identifierNames(components);
        org.apache.spark.sql.Dataset<Row> joined = this.executeJoin(aliased, joinKeys, "inner");
        DatasetExpression firstInput = datasets.values().iterator().next();
        SparkDataset result = new SparkDataset(joined, getRoleMap(components));
        return new SparkDatasetExpression(result, (Positioned) firstInput);
    }

    /**
     * Left-joins the given datasets on the identifier components.
     *
     * @param datasets the dataset expressions to join, keyed by alias
     * @param components the components of the result; identifiers among them are the join keys
     * @return a dataset expression over the joined dataset, positioned at the first input
     */
    public DatasetExpression executeLeftJoin(Map<String, DatasetExpression> datasets, List<Structured.Component> components) {
        List<org.apache.spark.sql.Dataset<Row>> aliased = this.toAliasedDatasets(datasets);
        List<String> joinKeys = identifierNames(components);
        org.apache.spark.sql.Dataset<Row> joined = this.executeJoin(aliased, joinKeys, "left");
        DatasetExpression firstInput = datasets.values().iterator().next();
        SparkDataset result = new SparkDataset(joined, getRoleMap(components));
        return new SparkDatasetExpression(result, (Positioned) firstInput);
    }

    /**
     * Cross-joins the given datasets; no join keys are used.
     *
     * @param datasets the dataset expressions to join, keyed by alias
     * @param identifiers the components from which the roles of the result are built
     * @return a dataset expression over the joined dataset, positioned at the first input
     */
    public DatasetExpression executeCrossJoin(Map<String, DatasetExpression> datasets, List<Structured.Component> identifiers) {
        List<org.apache.spark.sql.Dataset<Row>> aliased = this.toAliasedDatasets(datasets);
        org.apache.spark.sql.Dataset<Row> joined = this.executeJoin(aliased, List.of(), "cross");
        DatasetExpression firstInput = datasets.values().iterator().next();
        SparkDataset result = new SparkDataset(joined, getRoleMap(identifiers));
        return new SparkDatasetExpression(result, (Positioned) firstInput);
    }

    /**
     * Full-outer-joins the given datasets on the identifier components.
     *
     * @param datasets the dataset expressions to join, keyed by alias
     * @param identifiers the components of the result; identifiers among them are the join keys
     * @return a dataset expression over the joined dataset, positioned at the first input
     */
    public DatasetExpression executeFullJoin(Map<String, DatasetExpression> datasets, List<Structured.Component> identifiers) {
        List<org.apache.spark.sql.Dataset<Row>> aliased = this.toAliasedDatasets(datasets);
        List<String> joinKeys = identifierNames(identifiers);
        org.apache.spark.sql.Dataset<Row> joined = this.executeJoin(aliased, joinKeys, "outer");
        DatasetExpression firstInput = datasets.values().iterator().next();
        SparkDataset result = new SparkDataset(joined, getRoleMap(identifiers));
        return new SparkDatasetExpression(result, (Positioned) firstInput);
    }

    /**
     * Validates a dataset against a data point ruleset.
     *
     * <p>The dataset columns are first renamed with the ruleset's alias map. For every rule a
     * copy of the dataset is computed with four extra columns ({@code ruleid}, {@code bool_var},
     * {@code errorlevel}, {@code errorcode}); the per-rule datasets are united and the alias
     * renaming is reverted. When {@code output} is {@code null} or equals
     * {@code ValidationOutput.INVALID.value}, only the rows with {@code bool_var = false} are
     * kept and the {@code bool_var} column is dropped; otherwise all rows are returned.
     *
     * @param dpr the data point ruleset to apply
     * @param dataset the dataset expression to validate
     * @param output the requested validation output, may be {@code null}
     * @param pos the position attached to the expressions and datasets created here
     * @return a dataset expression over the validation result
     */
    public DatasetExpression executeValidateDPruleset(DataPointRuleset dpr, DatasetExpression dataset, String output, Positioned pos) {
        SparkDataset sparkDataset = this.asSparkDataset(dataset);
        org.apache.spark.sql.Dataset<Row> ds = sparkDataset.getSparkDataset();
        // Work on the column names used by the ruleset.
        org.apache.spark.sql.Dataset<Row> renamedDs = this.rename(ds, dpr.getAlias());
        SparkDataset sparkDs = new SparkDataset(renamedDs);
        SparkDatasetExpression sparkDsExpr = new SparkDatasetExpression(sparkDs, pos);
        Structured.DataStructure dataStructure = sparkDs.getDataStructure();
        // Roles of the input dataset plus the roles of the four validation columns.
        Map<String, Dataset.Role> roleMap = SparkProcessingEngine.getRoleMap(sparkDataset);
        roleMap.put(RULEID, Dataset.Role.IDENTIFIER);
        roleMap.put(BOOLVAR, Dataset.Role.MEASURE);
        roleMap.put(ERRORLEVEL, Dataset.Role.MEASURE);
        roleMap.put(ERRORCODE, Dataset.Role.MEASURE);
        Class errorCodeType = dpr.getErrorCodeType();
        Class errorLevelType = dpr.getErrorLevelType();
        // One dataset per rule, each carrying the rule's validation columns.
        List<DatasetExpression> datasetsExpression = dpr.getRules().stream().map(rule -> {
            String ruleName = rule.getName();
            // ruleid: the constant rule name.
            ResolvableExpression ruleIdExpression = ResolvableExpression.withType(String.class).withPosition(pos).using((VtlFunction & Serializable)context -> ruleName);
            ResolvableExpression antecedentExpression = rule.getBuildAntecedentExpression(dataStructure);
            ResolvableExpression consequentExpression = rule.getBuildConsequentExpression(dataStructure);
            ResolvableExpression errorCodeExpr = rule.getErrorCodeExpression();
            // errorcode: the rule's error code, only when the antecedent is TRUE and the
            // consequent is FALSE; null in every other case.
            ResolvableExpression errorCodeExpression = ResolvableExpression.withType((Class)errorCodeType).withPosition(pos).using((VtlFunction & Serializable)context -> {
                if (errorCodeExpr == null) {
                    return null;
                }
                Map mapContext = (Map)context;
                Object erCode = errorCodeExpr.resolve(mapContext);
                if (erCode == null) {
                    return null;
                }
                Boolean antecedentValue = (Boolean)antecedentExpression.resolve(mapContext);
                Boolean consequentValue = (Boolean)consequentExpression.resolve(mapContext);
                return Boolean.TRUE.equals(antecedentValue) && Boolean.FALSE.equals(consequentValue) ? errorCodeType.cast(erCode) : null;
            });
            ResolvableExpression errorLevelExpr = rule.getErrorLevelExpression();
            // errorlevel: same condition as errorcode, applied to the rule's error level.
            ResolvableExpression errorLevelExpression = ResolvableExpression.withType((Class)errorLevelType).withPosition(pos).using((VtlFunction & Serializable)context -> {
                if (errorLevelExpr == null) {
                    return null;
                }
                Map mapContext = (Map)context;
                Object erLevel = errorLevelExpr.resolve(mapContext);
                if (erLevel == null) {
                    return null;
                }
                Boolean antecedentValue = (Boolean)antecedentExpression.resolve(mapContext);
                Boolean consequentValue = (Boolean)consequentExpression.resolve(mapContext);
                return Boolean.TRUE.equals(antecedentValue) && Boolean.FALSE.equals(consequentValue) ? errorLevelType.cast(erLevel) : null;
            });
            // bool_var: if one side is null the other side is returned as is; otherwise the
            // implication "antecedent => consequent" (not antecedent, or consequent).
            ResolvableExpression BOOLVARExpression = ResolvableExpression.withType(Boolean.class).withPosition(pos).using((VtlFunction & Serializable)context -> {
                Boolean antecedentValue = (Boolean)antecedentExpression.resolve(context);
                Boolean consequentValue = (Boolean)consequentExpression.resolve(context);
                if (antecedentValue == null) {
                    return consequentValue;
                }
                if (consequentValue == null) {
                    return antecedentValue;
                }
                return antecedentValue == false || consequentValue != false;
            });
            HashMap<String, ResolvableExpression> resolvableExpressions = new HashMap<String, ResolvableExpression>();
            resolvableExpressions.put(RULEID, ruleIdExpression);
            resolvableExpressions.put(BOOLVAR, BOOLVARExpression);
            resolvableExpressions.put(ERRORLEVEL, errorLevelExpression);
            resolvableExpressions.put(ERRORCODE, errorCodeExpression);
            // No expression strings are supplied, so executeCalc receives an empty map for them.
            return this.executeCalc(sparkDsExpr, resolvableExpressions, roleMap, Map.of());
        }).collect(Collectors.toList());
        // Unite the per-rule results, then revert the alias renaming.
        // NOTE(review): invertMap is not visible here; presumably it swaps keys and values -- confirm.
        org.apache.spark.sql.Dataset<Row> invertRenamedSparkDs = this.rename(this.asSparkDataset(this.executeUnion(datasetsExpression)).getSparkDataset(), this.invertMap(dpr.getAlias()));
        SparkDatasetExpression sparkDatasetExpression = new SparkDatasetExpression(new SparkDataset(invertRenamedSparkDs), pos);
        if (output == null || output.equals(ValidationOutput.INVALID.value)) {
            // Fallback filter expression handed to executeFilter next to the textual condition; it always yields null.
            ResolvableExpression defaultExpression = ResolvableExpression.withType(Boolean.class).withPosition(pos).using((VtlFunction & Serializable)c -> null);
            DatasetExpression filteredDataset = this.executeFilter(sparkDatasetExpression, defaultExpression, "bool_var = false");
            org.apache.spark.sql.Dataset result = this.asSparkDataset(filteredDataset).getSparkDataset().drop(BOOLVAR);
            return new SparkDatasetExpression(new SparkDataset((org.apache.spark.sql.Dataset<Row>)result), pos);
        }
        return sparkDatasetExpression;
    }

    /**
     * Runs the simple {@code check} validation.
     *
     * <p>The first measure of {@code imbalanceExpr} is renamed to the imbalance column, the renamed
     * dataset is left-joined with {@code dsExpr} on the identifier components of {@code dsExpr}, and the
     * error level and error code columns are then computed for each row and declared as measures.
     *
     * <p>The error code (resp. error level) of a row is the value of {@code errorCodeExpr}
     * (resp. {@code errorLevelExpr}) only when the row's {@code bool_var} is exactly {@code false};
     * it is {@code null} when {@code bool_var} is {@code true} or {@code null}, or when the
     * corresponding expression was not provided.
     *
     * @param dsExpr dataset expression whose identifiers drive the join
     * @param errorCodeExpr expression producing the error code, may be {@code null}
     * @param errorLevelExpr expression producing the error level, may be {@code null}
     * @param imbalanceExpr dataset expression carrying the imbalance measure
     * @param output when {@code null} or equal to {@code ValidationOutput.ALL.value} every row is
     *     returned; otherwise only rows with {@code bool_var = false} are kept and the
     *     {@code bool_var} column is dropped
     * @param pos position attached to the created expressions
     * @return the validation result as a Spark-backed dataset expression
     */
    public DatasetExpression executeValidationSimple(DatasetExpression dsExpr, ResolvableExpression errorCodeExpr, ResolvableExpression errorLevelExpr, DatasetExpression imbalanceExpr, String output, Positioned pos) {
        SparkDataset sparkImbalanceDataset = this.asSparkDataset(imbalanceExpr);
        org.apache.spark.sql.Dataset<Row> sparkImbalanceDatasetRow = sparkImbalanceDataset.getSparkDataset();
        // Rename the (first) measure of the imbalance dataset so it lands in the imbalance column after the join.
        String imbalanceMonomeasureName = (String)imbalanceExpr.getDataStructure().values().stream().filter(c -> c.isMeasure()).map(c -> c.getName()).collect(Collectors.toList()).get(0);
        Map<String, String> varsToRename = Map.ofEntries(Map.entry(imbalanceMonomeasureName, IMBALANCE));
        org.apache.spark.sql.Dataset<Row> renamed = this.rename(sparkImbalanceDatasetRow, varsToRename);
        Map<String, Dataset.Role> imbalanceRoleMap = SparkProcessingEngine.getRoleMap(sparkImbalanceDataset);
        SparkDatasetExpression imbalanceRenamedExpr = new SparkDatasetExpression(new SparkDataset(renamed, imbalanceRoleMap), pos);
        Map<String, DatasetExpression> datasetExpressions = Map.ofEntries(Map.entry("dsExpr", dsExpr), Map.entry("imbalanceExpr", imbalanceRenamedExpr));
        List<Structured.Component> components = dsExpr.getDataStructure().values().stream().filter(Structured.Component::isIdentifier).collect(Collectors.toList());
        DatasetExpression datasetExpression = this.executeLeftJoin(datasetExpressions, components);
        SparkDataset sparkDataset = this.asSparkDataset(datasetExpression);
        org.apache.spark.sql.Dataset<Row> ds = sparkDataset.getSparkDataset();
        Class errorCodeType = errorCodeExpr == null ? String.class : errorCodeExpr.getType();
        ResolvableExpression errorCodeExpression = ResolvableExpression.withType((Class)errorCodeType).withPosition(pos).using((VtlFunction & Serializable)context -> {
            Map contextMap = (Map)context;
            if (errorCodeExpr == null) {
                return null;
            }
            Object erCode = errorCodeExpr.resolve(contextMap);
            Boolean boolVar = (Boolean)contextMap.get(BOOLVAR);
            // bool_var can be null; comparing against Boolean.FALSE avoids the unboxing NullPointerException.
            return Boolean.FALSE.equals(boolVar) ? errorCodeType.cast(erCode) : null;
        });
        Class errorLevelType = errorLevelExpr == null ? String.class : errorLevelExpr.getType();
        ResolvableExpression errorLevelExpression = ResolvableExpression.withType((Class)errorLevelType).withPosition(pos).using((VtlFunction & Serializable)context -> {
            Map contextMap = (Map)context;
            if (errorLevelExpr == null) {
                return null;
            }
            Object erLevel = errorLevelExpr.resolve(contextMap);
            Boolean boolVar = (Boolean)contextMap.get(BOOLVAR);
            // Same null-safe rule as for the error code: only an explicit FALSE reports a level.
            return Boolean.FALSE.equals(boolVar) ? errorLevelType.cast(erLevel) : null;
        });
        Map<String, Dataset.Role> roleMap = SparkProcessingEngine.getRoleMap(sparkDataset);
        roleMap.put(ERRORLEVEL, Dataset.Role.MEASURE);
        roleMap.put(ERRORCODE, Dataset.Role.MEASURE);
        Map<String, ResolvableExpression> resolvableExpressions = Map.ofEntries(Map.entry(ERRORLEVEL, errorLevelExpression), Map.entry(ERRORCODE, errorCodeExpression));
        org.apache.spark.sql.Dataset<Row> calculatedDataset = this.executeCalcEvaluated(ds, resolvableExpressions);
        SparkDatasetExpression sparkDatasetExpression = new SparkDatasetExpression(new SparkDataset(calculatedDataset, roleMap), pos);
        if (output == null || output.equals(ValidationOutput.ALL.value)) {
            return sparkDatasetExpression;
        }
        // Any other output value: keep only the failing rows and drop the now-constant bool_var column.
        DatasetExpression filteredDataset = this.executeFilter(sparkDatasetExpression, ResolvableExpression.withType(Boolean.class).withPosition(pos).using((VtlFunction & Serializable)c -> null), "bool_var = false");
        org.apache.spark.sql.Dataset result = this.asSparkDataset(filteredDataset).getSparkDataset().drop(BOOLVAR);
        return new SparkDatasetExpression(new SparkDataset((org.apache.spark.sql.Dataset<Row>)result), pos);
    }

    /**
     * Applies every rule of a hierarchical ruleset to a dataset and assembles the validation result.
     *
     * <p>The input dataset is resolved once to build a map from the string value of
     * {@code componentID} to the value of the ruleset variable. Each rule is then evaluated against
     * the bindings of its code items; rules for which {@code checkRule} answers {@code FALSE} are
     * skipped. For the remaining rules the rows matching the rule's value domain value are selected
     * and the rule id, component value, boolean result, imbalance, error level and error code columns
     * are computed on them. The per-rule datasets are finally combined with {@code executeUnion}.
     *
     * @param dsE dataset expression to validate; its first measure is used as the measure type
     * @param hr hierarchical ruleset providing the rules, the variable and the error code/level types
     * @param componentID name of the component holding the code items
     * @param validationMode validation mode forwarded to {@code checkRule} and
     *     {@code buildBindingsWithDefault}; may be {@code null}
     * @param inputMode input mode; {@code "dataset_priority"} is rejected
     * @param validationOutput {@code null} or {@code "invalid"} keeps only rows with
     *     {@code bool_var = false} and drops {@code bool_var}; {@code "all"} drops the measure column;
     *     any other value returns the combined dataset unchanged
     * @param pos position attached to the created expressions
     * @return the validation result
     * @throws UnsupportedOperationException if {@code inputMode} is {@code "dataset_priority"}
     */
    public DatasetExpression executeHierarchicalValidation(DatasetExpression dsE, HierarchicalRuleset hr, String componentID, String validationMode, String inputMode, String validationOutput, Positioned pos) {
        DatasetExpression datasetExpression;
        if (inputMode != null && inputMode.equals("dataset_priority")) {
            throw new UnsupportedOperationException("dataset_priority input mode is not supported in check_hierarchy");
        }
        // The dataset is resolved here, outside of the Spark expressions, to build the bindings below.
        Dataset ds = dsE.resolve(Map.of());
        // bindings: string value of componentID -> value of the ruleset variable for that data point.
        Map bindings = ds.getDataAsMap().stream().collect(HashMap::new, (acc, dp) -> acc.put(dp.get(componentID).toString(), dp.get(hr.getVariable())), HashMap::putAll);
        Structured.Component measure = (Structured.Component)dsE.getDataStructure().getMeasures().get(0);
        Class measureType = measure.getType();
        Map<String, Dataset.Role> roleMap = SparkProcessingEngine.getRoleMap(ds);
        roleMap.put(RULEID, Dataset.Role.IDENTIFIER);
        roleMap.put(BOOLVAR, Dataset.Role.MEASURE);
        roleMap.put(IMBALANCE, Dataset.Role.MEASURE);
        roleMap.put(ERRORLEVEL, Dataset.Role.MEASURE);
        roleMap.put(ERRORCODE, Dataset.Role.MEASURE);
        Class errorCodeType = hr.getErrorCodeType();
        Class errorLevelType = hr.getErrorLevelType();
        ArrayList<DatasetExpression> datasetsExpression = new ArrayList<DatasetExpression>();
        hr.getRules().forEach(rule -> {
            DatasetExpression filteredDataset;
            try {
                // Select the rows whose componentID equals the rule's value domain value.
                filteredDataset = this.executeFilterForHR(dsE, componentID + " = \"" + rule.getValueDomainValue() + "\"");
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
            String ruleName = rule.getName();
            List codeItems = rule.getCodeItems();
            Map<String, Object> ruleBindings = this.extractHRRuleBindings(bindings, codeItems);
            Boolean hasToProduceOutputLine = this.checkRule(codeItems, ruleBindings, validationMode);
            if (Boolean.FALSE.equals(hasToProduceOutputLine)) {
                // Nothing to emit for this rule under the requested validation mode.
                return;
            }
            ruleBindings = this.buildBindingsWithDefault(ruleBindings, codeItems, validationMode, measureType);
            HashMap<String, Boolean> resolvedRuleExpressions = new HashMap<String, Boolean>();
            HashMap<String, Double> resolvedLeftRuleExpressions = new HashMap<String, Double>();
            HashMap<String, Double> resolvedRightRuleExpressions = new HashMap<String, Double>();
            // The rule and both of its sides are resolved once, here; any failure is recorded as null.
            try {
                resolvedRuleExpressions.put(ruleName, (Boolean)rule.getExpression().resolve(ruleBindings));
            }
            catch (Exception e) {
                resolvedRuleExpressions.put(ruleName, null);
            }
            try {
                resolvedLeftRuleExpressions.put(ruleName, (Double)rule.getLeftExpression().resolve(ruleBindings));
            }
            catch (Exception e) {
                resolvedLeftRuleExpressions.put(ruleName, null);
            }
            try {
                resolvedRightRuleExpressions.put(ruleName, (Double)rule.getRightExpression().resolve(ruleBindings));
            }
            catch (Exception e) {
                resolvedRightRuleExpressions.put(ruleName, null);
            }
            // The expressions below return values computed above; they do not read the row context,
            // except for the error code/level expressions which resolve against it.
            ResolvableExpression ruleIdExpression = ResolvableExpression.withType(String.class).withPosition(pos).using((VtlFunction & Serializable)context -> ruleName);
            String vd = rule.getValueDomainValue();
            ResolvableExpression valueDomainExpression = ResolvableExpression.withType(String.class).withPosition(pos).using((VtlFunction & Serializable)context -> vd);
            Boolean expression = (Boolean)resolvedRuleExpressions.get(ruleName);
            ResolvableExpression errorCodeExpr = rule.getErrorCodeExpression();
            ResolvableExpression errorCodeExpression = ResolvableExpression.withType((Class)errorCodeType).withPosition(pos).using((VtlFunction & Serializable)context -> {
                if (errorCodeExpr == null || expression == null) {
                    return null;
                }
                Map mapContext = (Map)context;
                Object erCode = errorCodeExpr.resolve(mapContext);
                if (erCode == null) {
                    return null;
                }
                // The error code is reported only when the rule evaluated to FALSE.
                return expression.equals(Boolean.FALSE) ? errorCodeType.cast(erCode) : null;
            });
            ResolvableExpression errorLevelExpr = rule.getErrorLevelExpression();
            ResolvableExpression errorLevelExpression = ResolvableExpression.withType((Class)errorLevelType).withPosition(pos).using((VtlFunction & Serializable)context -> {
                if (errorLevelExpr == null || expression == null) {
                    return null;
                }
                Map mapContext = (Map)context;
                Object erLevel = errorLevelExpr.resolve(mapContext);
                if (erLevel == null) {
                    return null;
                }
                // The error level is reported only when the rule evaluated to FALSE.
                return expression.equals(Boolean.FALSE) ? errorLevelType.cast(erLevel) : null;
            });
            ResolvableExpression BoolvarExpression = ResolvableExpression.withType(Boolean.class).withPosition(pos).using((VtlFunction & Serializable)context -> expression);
            ResolvableExpression imbalanceExpression = ResolvableExpression.withType((Class)measureType).withPosition(pos).using((VtlFunction & Serializable)context -> {
                Double leftExpression = (Double)resolvedLeftRuleExpressions.get(ruleName);
                Double rightExpression = (Double)resolvedRightRuleExpressions.get(ruleName);
                if (leftExpression == null || rightExpression == null) {
                    return null;
                }
                // Imbalance is left minus right, computed on longs when the measure type accepts Long.
                if (measureType.isAssignableFrom(Long.class)) {
                    return leftExpression.longValue() - rightExpression.longValue();
                }
                return leftExpression - rightExpression;
            });
            HashMap<String, ResolvableExpression> resolvableExpressions = new HashMap<String, ResolvableExpression>();
            resolvableExpressions.put(RULEID, ruleIdExpression);
            resolvableExpressions.put(componentID, valueDomainExpression);
            resolvableExpressions.put(BOOLVAR, BoolvarExpression);
            resolvableExpressions.put(IMBALANCE, imbalanceExpression);
            resolvableExpressions.put(ERRORLEVEL, errorLevelExpression);
            resolvableExpressions.put(ERRORCODE, errorCodeExpression);
            datasetsExpression.add(this.executeCalc(filteredDataset, resolvableExpressions, roleMap, Map.of()));
        });
        if (datasetsExpression.size() == 0) {
            // No rule produced a line: fall back to an empty in-memory dataset with the output columns.
            InMemoryDataset emptyCHDataset = new InMemoryDataset(List.of(), Map.of(measure.getName(), measureType, RULEID, String.class, componentID, String.class, BOOLVAR, Boolean.class, IMBALANCE, Double.class, ERRORLEVEL, errorLevelType, ERRORCODE, errorCodeType), roleMap);
            datasetExpression = DatasetExpression.of((Dataset)emptyCHDataset, (Positioned)pos);
        } else {
            datasetExpression = this.executeUnion(datasetsExpression);
        }
        if (null == validationOutput || validationOutput.equals("invalid")) {
            // Default output: only the failing rows, without the bool_var column.
            DatasetExpression filteredDataset = this.executeFilter(datasetExpression, ResolvableExpression.withType(Boolean.class).withPosition(pos).using((VtlFunction & Serializable)c -> null), "bool_var = false");
            org.apache.spark.sql.Dataset result = this.asSparkDataset(filteredDataset).getSparkDataset().drop(BOOLVAR);
            return new SparkDatasetExpression(new SparkDataset((org.apache.spark.sql.Dataset<Row>)result), pos);
        }
        if (validationOutput.equals("all")) {
            // "all" keeps every row but removes the original measure column.
            String measureName = measure.getName();
            org.apache.spark.sql.Dataset result = this.asSparkDataset(datasetExpression).getSparkDataset().drop(measureName);
            return new SparkDatasetExpression(new SparkDataset((org.apache.spark.sql.Dataset<Row>)result), pos);
        }
        // Any other output value (presumably "all_measures" — confirm) returns every column.
        return datasetExpression;
    }

    /**
     * Filters a dataset with a Spark SQL condition for hierarchical validation.
     *
     * <p>When no row satisfies the condition, a single row of the unfiltered dataset
     * ({@code limit(1)}) is returned instead of an empty dataset. NOTE(review): presumably this lets
     * the caller still compute one output line for the rule — confirm.
     *
     * @param expression dataset expression to filter; also used as the position of the result
     * @param filterText condition string handed to Spark's {@code filter}
     * @return the filtered dataset, carrying the role map of the input dataset
     * @throws Exception wrapping any exception raised while filtering
     */
    private DatasetExpression executeFilterForHR(DatasetExpression expression, String filterText) throws Exception {
        SparkDataset input = this.asSparkDataset(expression);
        org.apache.spark.sql.Dataset<Row> rows = input.getSparkDataset();
        try {
            org.apache.spark.sql.Dataset<Row> matching = rows.filter(filterText);
            org.apache.spark.sql.Dataset<Row> selected = matching.isEmpty() ? rows.limit(1) : matching;
            SparkDataset filtered = new SparkDataset(selected, SparkProcessingEngine.getRoleMap(input));
            return new SparkDatasetExpression(filtered, (Positioned)expression);
        }
        catch (Exception failure) {
            throw new Exception(failure);
        }
    }

    /**
     * Restricts a set of bindings to the given items.
     *
     * @param bindings all available bindings
     * @param items names to keep
     * @return a new map holding, for every item present as a key in {@code bindings}, the associated
     *     value (which may be {@code null}); items without a binding are left out
     */
    private Map<String, Object> extractHRRuleBindings(Map<String, Object> bindings, List<String> items) {
        HashMap<String, Object> selected = new HashMap<String, Object>();
        for (String item : items) {
            if (!bindings.containsKey(item)) {
                continue;
            }
            selected.put(item, bindings.get(item));
        }
        return selected;
    }

    /**
     * Tells whether a hierarchical rule has to produce an output line for the given validation mode.
     *
     * <ul>
     *   <li>{@code null} or {@code NON_NULL}: every code item must be bound and no bound value may be
     *       {@code null}.
     *   <li>{@code NON_ZERO}: at least one bound value must be {@code null} or different from zero.
     *       Any {@link Number} is compared through {@link Number#doubleValue()} (so negative zero
     *       counts as zero); a non-numeric value is treated as different from zero.
     *   <li>{@code PARTIAL_NULL} / {@code PARTIAL_ZERO}: at least one bound value must be non-null.
     *   <li>{@code ALWAYS_NULL} / {@code ALWAYS_ZERO}: always.
     *   <li>any other mode: never.
     * </ul>
     *
     * @param codeItems code items referenced by the rule
     * @param ruleBindings bindings available for those code items
     * @param validationMode validation mode, may be {@code null}
     * @return {@code Boolean.TRUE} when a line has to be produced, {@code Boolean.FALSE} otherwise
     */
    private Boolean checkRule(List<String> codeItems, Map<String, Object> ruleBindings, String validationMode) {
        if (validationMode == null || validationMode.equals(NON_NULL)) {
            boolean everyItemBound = codeItems.size() == ruleBindings.size();
            return everyItemBound && ruleBindings.values().stream().noneMatch(Objects::isNull);
        }
        if (validationMode.equals(NON_ZERO)) {
            return ruleBindings.values().stream().anyMatch(value -> {
                if (value == null) {
                    return true;
                }
                if (value instanceof Number) {
                    // Covers Long and Double as before, plus Integer, BigDecimal, ... which used to
                    // leave the intermediate Double null and fail with a NullPointerException.
                    return ((Number)value).doubleValue() != 0.0;
                }
                // A non-numeric value cannot be equal to zero.
                return true;
            });
        }
        if (validationMode.equals(PARTIAL_NULL) || validationMode.equals(PARTIAL_ZERO)) {
            return ruleBindings.values().stream().anyMatch(Objects::nonNull);
        }
        return validationMode.equals(ALWAYS_NULL) || validationMode.equals(ALWAYS_ZERO);
    }

    /**
     * Completes the bindings of a rule with a default value for every item that has no binding.
     *
     * <p>Items already bound keep their value. For a missing item the default depends on the mode:
     * zero for {@code NON_ZERO}, {@code PARTIAL_ZERO} and {@code ALWAYS_ZERO} ({@code 0L} when
     * {@code measureType} accepts a {@code Long}, {@code 0.0} otherwise), {@code null} for
     * {@code PARTIAL_NULL} and {@code ALWAYS_NULL}. For any other mode, including {@code null},
     * missing items are left out of the result.
     *
     * @param bindings bindings found for the rule
     * @param ruleItems items referenced by the rule
     * @param validationMode validation mode, may be {@code null}
     * @param measureType type of the measure, used to choose the zero default
     * @return a new map with the existing bindings plus the defaults
     */
    private Map<String, Object> buildBindingsWithDefault(Map<String, Object> bindings, List<String> ruleItems, String validationMode, Class<?> measureType) {
        // List.of(...).contains(null) throws NullPointerException, so a null mode is ruled out first.
        boolean zeroDefault = validationMode != null && List.of(NON_ZERO, PARTIAL_ZERO, ALWAYS_ZERO).contains(validationMode);
        boolean nullDefault = validationMode != null && List.of(PARTIAL_NULL, ALWAYS_NULL).contains(validationMode);
        HashMap<String, Object> bindingsWithDefault = new HashMap<String, Object>();
        for (String item : ruleItems) {
            if (bindings.containsKey(item)) {
                bindingsWithDefault.put(item, bindings.get(item));
                continue;
            }
            if (zeroDefault) {
                if (measureType.isAssignableFrom(Long.class)) {
                    bindingsWithDefault.put(item, 0L);
                } else {
                    bindingsWithDefault.put(item, 0.0);
                }
            }
            if (nullDefault) {
                bindingsWithDefault.put(item, null);
            }
        }
        return bindingsWithDefault;
    }

    /**
     * Builds the inverse of a map: each value becomes a key pointing to its former key.
     *
     * @param map map to invert
     * @return a new map from the values of {@code map} to their keys
     * @throws IllegalStateException if two entries share the same value
     */
    private <V, K> Map<V, K> invertMap(Map<K, V> map) {
        return map.entrySet()
                .stream()
                .collect(Collectors.toMap(entry -> entry.getValue(), entry -> entry.getKey()));
    }

    /**
     * Converts named dataset expressions to Spark datasets aliased with their name.
     *
     * @param datasets dataset expressions keyed by the alias to give them
     * @return the aliased Spark datasets, in the iteration order of {@code datasets}
     */
    private List<org.apache.spark.sql.Dataset<Row>> toAliasedDatasets(Map<String, DatasetExpression> datasets) {
        ArrayList<org.apache.spark.sql.Dataset<Row>> aliased = new ArrayList<org.apache.spark.sql.Dataset<Row>>(datasets.size());
        datasets.forEach((alias, expression) -> {
            org.apache.spark.sql.Dataset<Row> rows = this.asSparkDataset(expression).getSparkDataset();
            aliased.add(rows.as(alias));
        });
        return aliased;
    }

    /**
     * Joins a list of Spark datasets from left to right.
     *
     * <p>The first dataset is the starting point; every following dataset is joined to the running
     * result. With type {@code "cross"} a cross join is used and {@code identifiers} is ignored;
     * otherwise the datasets are joined on the {@code identifiers} columns with the given join type.
     *
     * @param sparkDatasets datasets to join; must contain at least one element
     * @param identifiers names of the columns to join on
     * @param type join type passed to Spark, or {@code "cross"}
     * @return the joined dataset, or the first dataset when the list has a single element
     */
    public org.apache.spark.sql.Dataset<Row> executeJoin(List<org.apache.spark.sql.Dataset<Row>> sparkDatasets, List<String> identifiers, String type) {
        java.util.Iterator<org.apache.spark.sql.Dataset<Row>> remaining = sparkDatasets.iterator();
        org.apache.spark.sql.Dataset<Row> joined = remaining.next();
        while (remaining.hasNext()) {
            if (type.equals("cross")) {
                joined = joined.crossJoin(remaining.next());
            } else {
                org.apache.spark.sql.Dataset<Row> right = remaining.next();
                joined = joined.join(right, JavaConverters.iterableAsScalaIterable(identifiers).toSeq(), type);
            }
        }
        return joined;
    }

    /**
     * Factory creating {@link SparkProcessingEngine} instances for a script engine.
     *
     * <p>The Spark session is read from the {@code $vtl.spark.session} entry of the script engine;
     * when that entry is absent, {@code SparkSession.active()} is used instead.
     */
    public static class Factory
    implements ProcessingEngineFactory {
        private static final String SPARK_SESSION = "$vtl.spark.session";

        /**
         * Returns the name of this engine.
         *
         * @return {@code "spark"}
         */
        public String getName() {
            return "spark";
        }

        /**
         * Creates a processing engine bound to a Spark session.
         *
         * @param engine script engine holding the optional {@code $vtl.spark.session} entry
         * @return a new {@link SparkProcessingEngine}
         * @throws IllegalArgumentException if the entry is set but is not a {@code SparkSession}, or
         *     if no entry is set and no active session is returned
         */
        public ProcessingEngine getProcessingEngine(ScriptEngine engine) {
            Object configured = engine.get(SPARK_SESSION);
            if (configured == null) {
                // NOTE(review): SparkSession.active() is documented to throw when no session exists
                // rather than return null, in which case the check below never fails — confirm.
                SparkSession fallback = SparkSession.active();
                if (fallback == null) {
                    throw new IllegalArgumentException("no active spark session");
                }
                return new SparkProcessingEngine(fallback);
            }
            if (!(configured instanceof SparkSession)) {
                throw new IllegalArgumentException("$vtl.spark.session was not a spark session");
            }
            return new SparkProcessingEngine((SparkSession)configured);
        }
    }
}

