/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.metadata;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.TableModify;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.metadata.BuiltInMetadata;
import org.apache.calcite.rel.metadata.MetadataDef;
import org.apache.calcite.rel.metadata.MetadataHandler;
import org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider;
import org.apache.calcite.rel.metadata.RelMetadataProvider;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexTableInputRef;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.checkerframework.checker.nullness.qual.KeyFor;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * Metadata handler that computes {@link BuiltInMetadata.ExpressionLineage}.
 *
 * <p>For an expression over the output row of a relational expression, each
 * {@code getExpressionLineage} overload returns the set of expressions obtained by
 * rewriting every input reference in terms of {@link RexTableInputRef} columns of the
 * {@link TableScan}s it originates from. A {@code null} result means that the lineage
 * could not be determined.
 */
public class RelMdExpressionLineage
implements MetadataHandler<BuiltInMetadata.ExpressionLineage> {
    /** Provider that dispatches {@code ExpressionLineage} requests to this handler. */
    public static final RelMetadataProvider SOURCE = ReflectiveRelMetadataProvider.reflectiveSource(new RelMdExpressionLineage(), BuiltInMetadata.ExpressionLineage.Handler.class);

    protected RelMdExpressionLineage() {
    }

    @Override
    public MetadataDef<BuiltInMetadata.ExpressionLineage> getDef() {
        return BuiltInMetadata.ExpressionLineage.DEF;
    }

    /**
     * Catch-all overload for relational expressions that no more specific overload
     * matches: the lineage is reported as unknown.
     *
     * @return always {@code null}
     */
    public @Nullable Set<RexNode> getExpressionLineage(RelNode rel, RelMetadataQuery mq, RexNode outputExpression) {
        return null;
    }

    /**
     * Delegates to the best plan of the subset, or to its original expression when no
     * best plan has been chosen yet.
     *
     * @return lineage of that expression, or {@code null} if neither exists or the
     *     lineage is unknown
     */
    public @Nullable Set<RexNode> getExpressionLineage(RelSubset rel, RelMetadataQuery mq, RexNode outputExpression) {
        final RelNode bestOrOriginal = Util.first(rel.getBest(), rel.getOriginal());
        if (bestOrOriginal == null) {
            return null;
        }
        return mq.getExpressionLineage(bestOrOriginal, outputExpression);
    }

    /**
     * Base case: every column referenced by {@code outputExpression} is replaced by a
     * {@link RexTableInputRef} pointing at the same column of this scan's table
     * (entity number 0).
     */
    public @Nullable Set<RexNode> getExpressionLineage(TableScan rel, RelMetadataQuery mq, RexNode outputExpression) {
        final RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(outputExpression);
        final List<RelDataTypeField> fields = rel.getRowType().getFieldList();
        final RexTableInputRef.RelTableRef tableRef = RexTableInputRef.RelTableRef.of(rel.getTable(), 0);
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        for (int idx : inputFieldsUsed) {
            final RexInputRef ref = RexInputRef.of(idx, fields);
            final RexNode tableInputRef = RexTableInputRef.of(tableRef, ref);
            mapping.put(ref, ImmutableSet.of(tableInputRef));
        }
        return createAllPossibleExpressions(rexBuilder, outputExpression, mapping);
    }

    /**
     * Traces group-by columns back to the corresponding input columns.
     *
     * @return {@code null} if {@code outputExpression} references any field at or
     *     beyond {@code rel.getGroupCount()} (i.e. an aggregate call result), or if
     *     the lineage of an input column is unknown
     */
    public @Nullable Set<RexNode> getExpressionLineage(Aggregate rel, RelMetadataQuery mq, RexNode outputExpression) {
        final RelNode input = rel.getInput();
        final RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(outputExpression);
        final int groupCount = rel.getGroupCount();
        // Validate every reference before issuing any metadata query for the input.
        for (int idx : inputFieldsUsed) {
            if (idx >= groupCount) {
                return null;
            }
        }
        final List<RelDataTypeField> inputFields = input.getRowType().getFieldList();
        final List<RelDataTypeField> outputFields = rel.getRowType().getFieldList();
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        for (int idx : inputFieldsUsed) {
            // The idx-th output column is the idx-th grouping key of the input.
            final RexInputRef inputRef = RexInputRef.of(rel.getGroupSet().nth(idx), inputFields);
            final Set<RexNode> originalExprs = mq.getExpressionLineage(input, inputRef);
            if (originalExprs == null) {
                return null;
            }
            mapping.put(RexInputRef.of(idx, outputFields), originalExprs);
        }
        return createAllPossibleExpressions(rexBuilder, outputExpression, mapping);
    }

    /**
     * Traces columns of a join to its left and right inputs.
     *
     * <p>For LEFT (RIGHT) outer joins, references to the right (left) side make the
     * lineage unknown, because NULL-padded values of the null-generating side come
     * from no table row; FULL outer joins are not handled at all. Table references of
     * the right input are renumbered so that a table that also appears on the left
     * gets distinct entity numbers.
     */
    public @Nullable Set<RexNode> getExpressionLineage(Join rel, RelMetadataQuery mq, RexNode outputExpression) {
        final RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        final RelNode leftInput = rel.getLeft();
        final RelNode rightInput = rel.getRight();
        final int nLeftColumns = leftInput.getRowType().getFieldList().size();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(outputExpression);
        if (rel.getJoinType().isOuterJoin()) {
            if (rel.getJoinType() == JoinRelType.LEFT) {
                final ImmutableBitSet rightFields = ImmutableBitSet.range(nLeftColumns, rel.getRowType().getFieldCount());
                if (inputFieldsUsed.intersects(rightFields)) {
                    return null;
                }
            } else if (rel.getJoinType() == JoinRelType.RIGHT) {
                final ImmutableBitSet leftFields = ImmutableBitSet.range(0, nLeftColumns);
                if (inputFieldsUsed.intersects(leftFields)) {
                    return null;
                }
            } else {
                return null;
            }
        }
        final Set<RexTableInputRef.RelTableRef> leftTableRefs = mq.getTableReferences(leftInput);
        if (leftTableRefs == null) {
            return null;
        }
        final Set<RexTableInputRef.RelTableRef> rightTableRefs = mq.getTableReferences(rightInput);
        if (rightTableRefs == null) {
            return null;
        }
        final HashMultimap<List<String>, RexTableInputRef.RelTableRef> qualifiedNamesToRefs = HashMultimap.create();
        for (RexTableInputRef.RelTableRef leftRef : leftTableRefs) {
            qualifiedNamesToRefs.put(leftRef.getQualifiedName(), leftRef);
        }
        // Shift each right-side entity number by the count of same-named tables on the left.
        final Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef> currentTablesMapping = new HashMap<>();
        for (RexTableInputRef.RelTableRef rightRef : rightTableRefs) {
            // Multimap.get never returns null; an absent key yields an empty collection.
            final int shift = qualifiedNamesToRefs.get(rightRef.getQualifiedName()).size();
            currentTablesMapping.put(rightRef, RexTableInputRef.RelTableRef.of(rightRef.getTable(), shift + rightRef.getEntityNumber()));
        }
        final List<RelDataTypeField> leftFieldList = leftInput.getRowType().getFieldList();
        final List<RelDataTypeField> rightFieldList = rightInput.getRowType().getFieldList();
        // Row type of left + right columns; built once, and only if a right column is used.
        @Nullable RelDataType fullRowType = null;
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        for (int idx : inputFieldsUsed) {
            if (idx < nLeftColumns) {
                final RexInputRef inputRef = RexInputRef.of(idx, leftFieldList);
                final Set<RexNode> originalExprs = mq.getExpressionLineage(leftInput, inputRef);
                if (originalExprs == null) {
                    return null;
                }
                // Left-side table references are kept unchanged.
                mapping.put(RexInputRef.of(idx, rel.getRowType().getFieldList()), originalExprs);
                continue;
            }
            final RexInputRef inputRef = RexInputRef.of(idx - nLeftColumns, rightFieldList);
            final Set<RexNode> originalExprs = mq.getExpressionLineage(rightInput, inputRef);
            if (originalExprs == null) {
                return null;
            }
            if (fullRowType == null) {
                fullRowType = SqlValidatorUtil.createJoinType(rexBuilder.getTypeFactory(), rel.getLeft().getRowType(), rel.getRight().getRowType(), null, ImmutableList.of());
            }
            final Set<RexNode> updatedExprs = ImmutableSet.copyOf(Util.transform(originalExprs, e -> RexUtil.swapTableReferences(rexBuilder, e, currentTablesMapping)));
            mapping.put(RexInputRef.of(idx, fullRowType), updatedExprs);
        }
        return createAllPossibleExpressions(rexBuilder, outputExpression, mapping);
    }

    /**
     * Traces columns of a union: the lineage of each output column is the union of the
     * lineages of that column in every input. Table references of later inputs are
     * renumbered so that a table appearing in several inputs gets distinct entity
     * numbers.
     *
     * @return {@code null} if table references or lineage of any input are unknown
     */
    public @Nullable Set<RexNode> getExpressionLineage(Union rel, RelMetadataQuery mq, RexNode outputExpression) {
        final RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(outputExpression);
        final List<RelDataTypeField> outputFields = rel.getRowType().getFieldList();
        // Renumbered table references produced by the inputs processed so far.
        final HashMultimap<List<String>, RexTableInputRef.RelTableRef> qualifiedNamesToRefs = HashMultimap.create();
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        for (RelNode input : rel.getInputs()) {
            final Set<RexTableInputRef.RelTableRef> tableRefs = mq.getTableReferences(input);
            if (tableRefs == null) {
                return null;
            }
            final Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef> currentTablesMapping = new HashMap<>();
            for (RexTableInputRef.RelTableRef tableRef : tableRefs) {
                // Multimap.get never returns null; an absent key yields an empty collection.
                final int shift = qualifiedNamesToRefs.get(tableRef.getQualifiedName()).size();
                currentTablesMapping.put(tableRef, RexTableInputRef.RelTableRef.of(tableRef.getTable(), shift + tableRef.getEntityNumber()));
            }
            final List<RelDataTypeField> inputFields = input.getRowType().getFieldList();
            for (int idx : inputFieldsUsed) {
                final RexInputRef inputRef = RexInputRef.of(idx, inputFields);
                final Set<RexNode> originalExprs = mq.getExpressionLineage(input, inputRef);
                if (originalExprs == null) {
                    return null;
                }
                final RexInputRef ref = RexInputRef.of(idx, outputFields);
                // Accumulate into a set this method owns, so it is guaranteed mutable.
                final Set<RexNode> accumulated = mapping.computeIfAbsent(ref, k -> new HashSet<>());
                for (RexNode e : originalExprs) {
                    accumulated.add(RexUtil.swapTableReferences(rexBuilder, e, currentTablesMapping));
                }
            }
            for (RexTableInputRef.RelTableRef newRef : currentTablesMapping.values()) {
                qualifiedNamesToRefs.put(newRef.getQualifiedName(), newRef);
            }
        }
        return createAllPossibleExpressions(rexBuilder, outputExpression, mapping);
    }

    /**
     * Traces each referenced output column through the corresponding project expression
     * to the input.
     */
    public @Nullable Set<RexNode> getExpressionLineage(Project rel, RelMetadataQuery mq, RexNode outputExpression) {
        final RelNode input = rel.getInput();
        final RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(outputExpression);
        final List<RelDataTypeField> outputFields = rel.getRowType().getFieldList();
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        for (int idx : inputFieldsUsed) {
            final RexNode inputExpr = rel.getProjects().get(idx);
            final Set<RexNode> originalExprs = mq.getExpressionLineage(input, inputExpr);
            if (originalExprs == null) {
                return null;
            }
            mapping.put(RexInputRef.of(idx, outputFields), originalExprs);
        }
        return createAllPossibleExpressions(rexBuilder, outputExpression, mapping);
    }

    /** Passes the same expression through to the input (column positions are unchanged). */
    public @Nullable Set<RexNode> getExpressionLineage(Filter rel, RelMetadataQuery mq, RexNode outputExpression) {
        return mq.getExpressionLineage(rel.getInput(), outputExpression);
    }

    /** Passes the same expression through to the input (column positions are unchanged). */
    public @Nullable Set<RexNode> getExpressionLineage(Sort rel, RelMetadataQuery mq, RexNode outputExpression) {
        return mq.getExpressionLineage(rel.getInput(), outputExpression);
    }

    /** Passes the same expression through to the input. */
    public @Nullable Set<RexNode> getExpressionLineage(TableModify rel, RelMetadataQuery mq, RexNode outputExpression) {
        return mq.getExpressionLineage(rel.getInput(), outputExpression);
    }

    /** Passes the same expression through to the input (column positions are unchanged). */
    public @Nullable Set<RexNode> getExpressionLineage(Exchange rel, RelMetadataQuery mq, RexNode outputExpression) {
        return mq.getExpressionLineage(rel.getInput(), outputExpression);
    }

    /**
     * Traces each referenced output column through the corresponding projection of the
     * calc's program (as returned by {@code RexProgram.split()}) to the input.
     */
    public @Nullable Set<RexNode> getExpressionLineage(Calc calc, RelMetadataQuery mq, RexNode outputExpression) {
        final RelNode input = calc.getInput();
        final RexBuilder rexBuilder = calc.getCluster().getRexBuilder();
        final ImmutableBitSet inputFieldsUsed = extractInputRefs(outputExpression);
        final List<RexNode> projects = calc.getProgram().split().left;
        final List<RelDataTypeField> outputFields = calc.getRowType().getFieldList();
        final Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        for (int idx : inputFieldsUsed) {
            final RexNode inputExpr = projects.get(idx);
            final Set<RexNode> originalExprs = mq.getExpressionLineage(input, inputExpr);
            if (originalExprs == null) {
                return null;
            }
            mapping.put(RexInputRef.of(idx, outputFields), originalExprs);
        }
        return createAllPossibleExpressions(rexBuilder, outputExpression, mapping);
    }

    /**
     * Builds every expression obtained by replacing each input reference of
     * {@code expr} with one of its candidate replacements in {@code mapping}, over all
     * combinations of candidates.
     *
     * @param rexBuilder builder used to copy {@code expr} before substitution
     * @param expr expression whose input references are substituted
     * @param mapping candidate replacements per input reference; it is modified while
     *     the combinations are enumerated and restored before this method returns
     * @return the singleton {@code {expr}} if {@code expr} references no input; the set
     *     of all substituted expressions otherwise; or {@code null} if an
     *     {@link UnsupportedOperationException} is raised while building them
     *     (presumably by {@code rexBuilder.copy} for unsupported node kinds)
     */
    protected static @Nullable Set<RexNode> createAllPossibleExpressions(RexBuilder rexBuilder, RexNode expr, Map<RexInputRef, Set<RexNode>> mapping) {
        final ImmutableBitSet predFieldsUsed = extractInputRefs(expr);
        if (predFieldsUsed.isEmpty()) {
            return ImmutableSet.of(expr);
        }
        try {
            return createAllPossibleExpressions(rexBuilder, expr, predFieldsUsed, mapping, new HashMap<>());
        } catch (UnsupportedOperationException e) {
            return null;
        }
    }

    /**
     * Recursive step: fixes a replacement for the first remaining key of
     * {@code mapping} (recording it in {@code singleMapping}) and recurses on the rest.
     * Both maps are restored to their entry state on exit, including on exceptions.
     */
    private static Set<RexNode> createAllPossibleExpressions(RexBuilder rexBuilder, RexNode expr, ImmutableBitSet predFieldsUsed, Map<RexInputRef, Set<RexNode>> mapping, Map<RexInputRef, RexNode> singleMapping) {
        final @KeyFor("mapping") RexInputRef inputRef = mapping.keySet().iterator().next();
        final Set<RexNode> replacements = Objects.requireNonNull(mapping.remove(inputRef), () -> "mapping.remove(inputRef) is null for " + inputRef);
        final Set<RexNode> result = new HashSet<>();
        assert !replacements.isEmpty();
        try {
            if (predFieldsUsed.get(inputRef.getIndex())) {
                for (RexNode replacement : replacements) {
                    singleMapping.put(inputRef, replacement);
                    createExpressions(rexBuilder, expr, predFieldsUsed, mapping, singleMapping, result);
                }
            } else {
                createExpressions(rexBuilder, expr, predFieldsUsed, mapping, singleMapping, result);
            }
        } finally {
            // Undo this level's changes so callers see their maps as they passed them.
            singleMapping.remove(inputRef);
            mapping.put(inputRef, replacements);
        }
        return result;
    }

    /**
     * Adds to {@code result} either the fully substituted expression (when every key
     * has been assigned in {@code singleMapping}) or all combinations for the keys
     * still left in {@code mapping}.
     */
    private static void createExpressions(RexBuilder rexBuilder, RexNode expr, ImmutableBitSet predFieldsUsed, Map<RexInputRef, Set<RexNode>> mapping, Map<RexInputRef, RexNode> singleMapping, Set<RexNode> result) {
        if (mapping.isEmpty()) {
            final RexReplacer replacer = new RexReplacer(singleMapping);
            result.add(rexBuilder.copy(expr).accept(replacer));
        } else {
            result.addAll(createAllPossibleExpressions(rexBuilder, expr, predFieldsUsed, mapping, singleMapping));
        }
    }

    /** Returns the indexes of the input fields referenced by {@code expr}. */
    private static ImmutableBitSet extractInputRefs(RexNode expr) {
        final Set<RelDataTypeField> inputExtraFields = new LinkedHashSet<>();
        final RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder(inputExtraFields);
        expr.accept(inputFinder);
        return inputFinder.build();
    }

    /**
     * Shuttle that replaces each {@link RexInputRef} by its value in the given map;
     * a reference missing from the map raises a {@link NullPointerException}.
     */
    private static final class RexReplacer
    extends RexShuttle {
        private final Map<RexInputRef, RexNode> replacementValues;

        RexReplacer(Map<RexInputRef, RexNode> replacementValues) {
            this.replacementValues = replacementValues;
        }

        @Override
        public RexNode visitInputRef(RexInputRef inputRef) {
            return Objects.requireNonNull(this.replacementValues.get(inputRef), () -> "no replacement found for inputRef " + inputRef);
        }
    }
}

