/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.RowExpressionRewriteRuleSet;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

public class RewriteCaseToMap
extends RowExpressionRewriteRuleSet {
    public RewriteCaseToMap(FunctionAndTypeManager functionAndTypeManager) {
        super(new Rewriter(functionAndTypeManager));
    }

    @Override
    public boolean isRewriterEnabled(Session session) {
        return session.getSystemProperty("rewrite_case_to_map_enabled", Boolean.class);
    }

    @Override
    public Set<Rule<?>> rules() {
        return ImmutableSet.of(this.projectRowExpressionRewriteRule(), this.filterRowExpressionRewriteRule(), this.joinRowExpressionRewriteRule());
    }

    private static class CaseToMapRewriter
    extends RowExpressionRewriter<Void> {
        private final FunctionAndTypeManager functionAndTypeManager;
        private final FunctionResolution functionResolution;
        private final LogicalRowExpressions logicalRowExpressions;

        private CaseToMapRewriter(FunctionAndTypeManager functionAndTypeManager) {
            this.functionAndTypeManager = functionAndTypeManager;
            this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
            this.logicalRowExpressions = new LogicalRowExpressions((DeterminismEvaluator)new RowExpressionDeterminismEvaluator(functionAndTypeManager), (StandardFunctionResolution)this.functionResolution, (FunctionMetadataManager)functionAndTypeManager);
        }

        private boolean addKeyValue(RowExpression key, Set<RowExpression> keySet, List<RowExpression> keys, RowExpression value, List<RowExpression> values) {
            if (!(key instanceof ConstantExpression) || ((ConstantExpression)key).getValue() == null || keys.size() > 0 && !keys.get(0).getType().equals(key.getType())) {
                return false;
            }
            if (keySet.add(key)) {
                if (values.size() > 0 && !values.get(0).getType().equals(value.getType())) {
                    return false;
                }
                keys.add(key);
                values.add(value);
            }
            return true;
        }

        public RowExpression rewriteSpecialForm(SpecialFormExpression node, Void context, RowExpressionTreeRewriter<Void> treeRewriter) {
            int start;
            RowExpression checkExpr;
            if (node.getForm() != SpecialFormExpression.Form.SWITCH) {
                return this.rewriteRowExpression((RowExpression)node, context, treeRewriter);
            }
            int numArgs = node.getArguments().size();
            RowExpression lastArg = (RowExpression)node.getArguments().get(numArgs - 1);
            Preconditions.checkState((numArgs >= 2 ? 1 : 0) != 0);
            Preconditions.checkState((!(lastArg instanceof SpecialFormExpression) || !((SpecialFormExpression)lastArg).getForm().equals((Object)SpecialFormExpression.Form.WHEN) ? 1 : 0) != 0);
            if (!(lastArg instanceof ConstantExpression)) {
                return node;
            }
            RowExpression firstArg = (RowExpression)node.getArguments().get(0);
            HashSet<RowExpression> keySet = new HashSet<RowExpression>();
            ArrayList<RowExpression> whens = new ArrayList<RowExpression>(node.getArguments().size());
            ArrayList<RowExpression> thens = new ArrayList<RowExpression>(node.getArguments().size());
            if (!(firstArg instanceof SpecialFormExpression) || !((SpecialFormExpression)firstArg).getForm().equals((Object)SpecialFormExpression.Form.WHEN)) {
                checkExpr = firstArg.equals((Object)Expressions.constant(true, (Type)BooleanType.BOOLEAN)) ? null : firstArg;
                start = 1;
            } else {
                checkExpr = null;
                start = 0;
            }
            for (int i = start; i < numArgs - 1; ++i) {
                RowExpression key;
                RowExpression curCheck;
                RowExpression whenClause = (RowExpression)node.getArguments().get(i);
                RowExpression value = (RowExpression)whenClause.getChildren().get(1);
                if (!(value instanceof ConstantExpression)) {
                    return node;
                }
                RowExpression when = (RowExpression)whenClause.getChildren().get(0);
                if (when instanceof ConstantExpression) {
                    if (this.addKeyValue(when, keySet, whens, value, thens)) continue;
                    return node;
                }
                if (this.logicalRowExpressions.isEqualsExpression(when)) {
                    RowExpression lhs = (RowExpression)when.getChildren().get(0);
                    RowExpression rhs = (RowExpression)when.getChildren().get(1);
                    if (!lhs.getType().equals(rhs.getType())) {
                        return node;
                    }
                    if (lhs instanceof ConstantExpression) {
                        curCheck = rhs;
                        key = lhs;
                    } else if (rhs instanceof ConstantExpression) {
                        curCheck = lhs;
                        key = rhs;
                    } else {
                        return node;
                    }
                    if (checkExpr == null) {
                        checkExpr = curCheck;
                    } else if (!curCheck.equals((Object)checkExpr)) {
                        return node;
                    }
                    if (this.addKeyValue(key, keySet, whens, value, thens)) continue;
                    return node;
                }
                if (when instanceof SpecialFormExpression && ((SpecialFormExpression)when).getForm() == SpecialFormExpression.Form.IN) {
                    curCheck = (RowExpression)((SpecialFormExpression)when).getArguments().get(0);
                    if (checkExpr == null) {
                        checkExpr = curCheck;
                    } else if (!curCheck.equals((Object)checkExpr)) {
                        return node;
                    }
                    for (int j = 1; j < ((SpecialFormExpression)when).getArguments().size(); ++j) {
                        key = (RowExpression)((SpecialFormExpression)when).getArguments().get(j);
                        if (this.addKeyValue(key, keySet, whens, value, thens)) continue;
                        return node;
                    }
                    continue;
                }
                return node;
            }
            if (checkExpr == null) {
                return node;
            }
            RowExpression mapLookup = this.makeMapAndAccess(whens, thens, checkExpr);
            if (lastArg != null && !lastArg.equals((Object)Expressions.constantNull(((RowExpression)thens.get(0)).getType()))) {
                CallExpression keyArray = Expressions.call("ARRAY", this.functionResolution.arrayConstructor(whens.stream().map(x -> x.getType()).collect(Collectors.toList())), (Type)new ArrayType(((RowExpression)whens.get(0)).getType()), whens);
                CallExpression contains = Expressions.call(this.functionAndTypeManager, "contains", (Type)BooleanType.BOOLEAN, new RowExpression[]{keyArray, checkExpr});
                return Expressions.coalesce(mapLookup, (RowExpression)Expressions.specialForm(SpecialFormExpression.Form.IF, mapLookup.getType(), new RowExpression[]{contains, Expressions.constant(null, mapLookup.getType()), lastArg}));
            }
            return mapLookup;
        }

        private RowExpression makeMapAndAccess(List<RowExpression> keys, List<RowExpression> values, RowExpression mapIndex) {
            CallExpression keyArray = Expressions.call("ARRAY", this.functionResolution.arrayConstructor(keys.stream().map(x -> x.getType()).collect(Collectors.toList())), (Type)new ArrayType(keys.get(0).getType()), keys);
            CallExpression valueArray = Expressions.call("ARRAY", this.functionResolution.arrayConstructor(values.stream().map(x -> x.getType()).collect(Collectors.toList())), (Type)new ArrayType(values.get(0).getType()), values);
            Type keyType = keys.get(0).getType();
            Type valueType = values.get(0).getType();
            MethodHandle keyEquals = this.functionAndTypeManager.getJavaScalarFunctionImplementation(this.functionAndTypeManager.resolveOperator(OperatorType.EQUAL, TypeSignatureProvider.fromTypes((Type[])new Type[]{keyType, keyType}))).getMethodHandle();
            MethodHandle keyHashcode = this.functionAndTypeManager.getJavaScalarFunctionImplementation(this.functionAndTypeManager.resolveOperator(OperatorType.HASH_CODE, TypeSignatureProvider.fromTypes((Type[])new Type[]{keyType}))).getMethodHandle();
            CallExpression map = Expressions.call(this.functionAndTypeManager, "MAP", (Type)new MapType(keyType, valueType, keyEquals, keyHashcode), new RowExpression[]{keyArray, valueArray});
            return Expressions.call(this.functionAndTypeManager, "element_at", valueType, new RowExpression[]{map, mapIndex});
        }
    }

    private static class Rewriter
    implements RowExpressionRewriteRuleSet.PlanRowExpressionRewriter {
        private final CaseToMapRewriter caseToMapRewriter;

        public Rewriter(FunctionAndTypeManager functionAndTypeManager) {
            Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.caseToMapRewriter = new CaseToMapRewriter(functionAndTypeManager);
        }

        @Override
        public RowExpression rewrite(RowExpression expression, Rule.Context context) {
            return RowExpressionTreeRewriter.rewriteWith((RowExpressionRewriter)this.caseToMapRewriter, (RowExpression)expression);
        }
    }
}

