/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.util;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Values;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexSlot;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.kylin.query.relnode.KapAggregateRel;
import org.apache.kylin.query.relnode.KapJoinRel;
import org.apache.kylin.query.relnode.KapProjectRel;

public class RexUtils {
    private RexUtils() {
    }

    public static boolean joinMoreThanOneTable(Join join) {
        HashSet<Integer> left = new HashSet<Integer>();
        HashSet<Integer> right = new HashSet<Integer>();
        Set<Integer> indexes = RexUtils.getAllInputRefs(join.getCondition()).stream().map(RexSlot::getIndex).collect(Collectors.toSet());
        RexUtils.splitJoinInputIndex(join, indexes, left, right);
        return !RexUtils.colsComeFromSameSideOfJoin(join.getLeft(), left) || !RexUtils.colsComeFromSameSideOfJoin(join.getRight(), right);
    }

    private static boolean colsComeFromSameSideOfJoin(RelNode rel, Set<Integer> indexes) {
        if (rel instanceof Join) {
            Join join = (Join)rel;
            HashSet<Integer> left = new HashSet<Integer>();
            HashSet<Integer> right = new HashSet<Integer>();
            RexUtils.splitJoinInputIndex(join, indexes, left, right);
            if (left.isEmpty()) {
                return RexUtils.colsComeFromSameSideOfJoin(join.getRight(), right);
            }
            if (right.isEmpty()) {
                return RexUtils.colsComeFromSameSideOfJoin(join.getLeft(), left);
            }
            return false;
        }
        if (rel instanceof Project) {
            Set<Integer> inputIndexes = indexes.stream().map(idx -> (RexNode)((Project)rel).getProjects().get((int)idx)).flatMap(rex -> RexUtils.getAllInputRefs(rex).stream()).map(RexSlot::getIndex).collect(Collectors.toSet());
            return RexUtils.colsComeFromSameSideOfJoin(((Project)rel).getInput(), inputIndexes);
        }
        if (rel instanceof TableScan || rel instanceof Values) {
            return true;
        }
        return RexUtils.colsComeFromSameSideOfJoin(rel.getInput(0), indexes);
    }

    public static void splitJoinInputIndex(Join joinRel, Collection<Integer> indexes, Set<Integer> leftInputIndexes, Set<Integer> rightInputIndexes) {
        indexes.forEach(idx -> {
            if (idx < joinRel.getLeft().getRowType().getFieldCount()) {
                leftInputIndexes.add((Integer)idx);
            } else {
                rightInputIndexes.add(idx - joinRel.getLeft().getRowType().getFieldCount());
            }
        });
    }

    public static int countOperatorCall(RexNode condition, final Class<? extends SqlOperator> sqlOperator) {
        final AtomicInteger likeCount = new AtomicInteger(0);
        RexVisitorImpl<Void> likeVisitor = new RexVisitorImpl<Void>(true){

            public Void visitCall(RexCall call) {
                if (call.getOperator().getClass().equals(sqlOperator)) {
                    likeCount.incrementAndGet();
                }
                return (Void)super.visitCall(call);
            }
        };
        condition.accept((RexVisitor)likeVisitor);
        return likeCount.get();
    }

    public static Set<RexInputRef> getAllInputRefs(RexNode rexNode) {
        if (rexNode instanceof RexInputRef) {
            return Collections.singleton((RexInputRef)rexNode);
        }
        if (rexNode instanceof RexCall) {
            return RexUtils.getAllInputRefsCall((RexCall)rexNode);
        }
        return Collections.emptySet();
    }

    private static Set<RexInputRef> getAllInputRefsCall(RexCall rexCall) {
        return rexCall.getOperands().stream().flatMap(rexNode -> RexUtils.getAllInputRefs(rexNode).stream()).collect(Collectors.toSet());
    }

    public static boolean isMerelyTableColumnReference(RelNode rel, Collection<Integer> columnIndexes) {
        if (rel instanceof KapProjectRel) {
            return RexUtils.isProjectMerelyTableColumnReference((KapProjectRel)rel, columnIndexes);
        }
        if (rel instanceof KapAggregateRel) {
            return RexUtils.isAggMerelyTableColumnReference((KapAggregateRel)rel, columnIndexes);
        }
        if (rel instanceof KapJoinRel) {
            return RexUtils.isJoinMerelyTableColumnReference(rel, columnIndexes);
        }
        for (RelNode inputRel : rel.getInputs()) {
            if (RexUtils.isMerelyTableColumnReference(inputRel, columnIndexes)) continue;
            return false;
        }
        return true;
    }

    private static boolean isJoinMerelyTableColumnReference(RelNode rel, Collection<Integer> columnIndexes) {
        int offset = 0;
        for (RelNode inputRel : rel.getInputs()) {
            HashSet<Integer> nextInputRefKeys = new HashSet<Integer>();
            for (Integer columnIdx : columnIndexes) {
                if (columnIdx - offset < 0 || columnIdx - offset >= inputRel.getRowType().getFieldCount()) continue;
                nextInputRefKeys.add(columnIdx - offset);
            }
            if (!RexUtils.isMerelyTableColumnReference(inputRel, nextInputRefKeys)) {
                return false;
            }
            offset += inputRel.getRowType().getFieldCount();
        }
        return true;
    }

    private static boolean isAggMerelyTableColumnReference(KapAggregateRel rel, Collection<Integer> columnIndexes) {
        HashSet<Integer> nextInputRefKeys = new HashSet<Integer>();
        KapAggregateRel agg = rel;
        for (Integer columnIdx : columnIndexes) {
            if (columnIdx >= agg.getRewriteGroupKeys().size()) {
                return false;
            }
            nextInputRefKeys.add((Integer)agg.getRewriteGroupKeys().get(columnIdx.intValue()));
        }
        return RexUtils.isMerelyTableColumnReference(agg.getInput(), nextInputRefKeys);
    }

    private static boolean isProjectMerelyTableColumnReference(KapProjectRel rel, Collection<Integer> columnIndexes) {
        HashSet<Integer> nextInputRefKeys = new HashSet<Integer>();
        KapProjectRel project = rel;
        for (Integer columnIdx : columnIndexes) {
            RexNode projExp = project.getProjects().get(columnIdx);
            if (projExp.getKind() == SqlKind.CAST) {
                projExp = (RexNode)((RexCall)projExp).getOperands().get(0);
            }
            if (!(projExp instanceof RexInputRef)) {
                return false;
            }
            nextInputRefKeys.add(((RexInputRef)projExp).getIndex());
        }
        return RexUtils.isMerelyTableColumnReference(project.getInput(), nextInputRefKeys);
    }

    public static boolean isMerelyTableColumnReference(KapJoinRel rel, RexNode condition) {
        return RexUtils.isMerelyTableColumnReference((RelNode)rel, RexUtils.getAllInputRefs(condition).stream().map(RexSlot::getIndex).collect(Collectors.toSet()));
    }

    public static RexNode stripOffCastInColumnEqualPredicate(RexNode predicateNode) {
        if (!(predicateNode instanceof RexCall)) {
            return predicateNode;
        }
        RexCall predicate = (RexCall)predicateNode;
        if (predicate.getKind() == SqlKind.EQUALS) {
            boolean colEqualPredWithCast = false;
            ArrayList predicateOperands = Lists.newArrayList((Iterable)predicate.getOperands());
            for (int predicateOpIdx = 0; predicateOpIdx < predicateOperands.size(); ++predicateOpIdx) {
                RexNode predicateChild = (RexNode)predicateOperands.get(predicateOpIdx);
                if (predicateChild instanceof RexInputRef || !(predicateChild instanceof RexCall) || predicateChild.getKind() != SqlKind.CAST || !(((RexCall)predicateChild).getOperands().get(0) instanceof RexInputRef)) continue;
                predicateOperands.set(predicateOpIdx, ((RexCall)predicateOperands.get(predicateOpIdx)).getOperands().get(0));
                colEqualPredWithCast = true;
            }
            if (colEqualPredWithCast) {
                return predicate.clone(predicate.getType(), (List)predicateOperands);
            }
        }
        return predicate;
    }
}

