/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.stats;

import com.facebook.presto.hive.$internal.org.slf4j.Logger;
import com.facebook.presto.hive.$internal.org.slf4j.LoggerFactory;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelVisitor;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SemiJoin;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.metadata.MetadataHandler;
import org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider;
import org.apache.calcite.rel.metadata.RelMdRowCount;
import org.apache.calcite.rel.metadata.RelMetadataProvider;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;
import org.apache.hadoop.hive.ql.optimizer.calcite.stats.HiveRelMdUniqueKeys;

/**
 * Hive-specific row-count metadata handler.
 *
 * <p>Overrides {@link RelMdRowCount} for {@link Join}, {@link SemiJoin} and {@link Sort}. For
 * joins it tries to recognise a primary-key / foreign-key (PK-FK) relationship on a single
 * equi-join column. When one is found, the join output is estimated as the FK side's row count
 * multiplied by {@code min(1, pkSelectivity * ndvScalingFactor)}.
 */
public class HiveRelMdRowCount
extends RelMdRowCount {
    protected static final Logger LOG = LoggerFactory.getLogger(HiveRelMdRowCount.class.getName());
    /** Reflective metadata provider that routes ROW_COUNT requests to this handler. */
    public static final RelMetadataProvider SOURCE = ReflectiveRelMetadataProvider.reflectiveSource((Method)BuiltInMethod.ROW_COUNT.method, (MetadataHandler)new HiveRelMdRowCount());

    protected HiveRelMdRowCount() {
    }

    /**
     * Estimates the row count of a join, preferring a PK-FK based estimate when one applies.
     *
     * @param join the join to estimate
     * @param mq   metadata query used for child statistics
     * @return the PK-FK estimate, or {@code join.estimateRowCount(mq)} otherwise
     */
    public Double getRowCount(Join join, RelMetadataQuery mq) {
        PKFKRelationInfo pkfk = analyzeJoinForPKFK(join, mq);
        if (pkfk != null) {
            return pkfkRowCount(join, pkfk);
        }
        return join.estimateRowCount(mq);
    }

    /**
     * Estimates the row count of a semi-join. {@link #analyzeJoinForPKFK} currently never
     * recognises semi-joins, so this falls back to the Calcite default.
     *
     * @param rel the semi-join to estimate
     * @param mq  metadata query used for child statistics
     * @return the row-count estimate
     */
    public Double getRowCount(SemiJoin rel, RelMetadataQuery mq) {
        PKFKRelationInfo pkfk = analyzeJoinForPKFK(rel, mq);
        if (pkfk != null) {
            return pkfkRowCount(rel, pkfk);
        }
        return super.getRowCount(rel, mq);
    }

    /**
     * Computes the PK-FK based estimate: FK row count times the capped PK-side selectivity.
     * Logs the identified relation at debug level.
     */
    private static double pkfkRowCount(RelNode join, PKFKRelationInfo pkfk) {
        double selectivity = Math.min(1.0, pkfk.pkInfo.selectivity * pkfk.ndvScalingFactor);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Identified Primary - Foreign Key relation: {} {}", RelOptUtil.toString(join), pkfk);
        }
        return pkfk.fkInfo.rowCount * selectivity;
    }

    /**
     * Estimates the row count of a sort.
     *
     * <p>If a fetch is present, returns {@code offset + fetch} when that is smaller than the
     * input's row count. Otherwise the input's row count is returned unchanged.
     *
     * @param rel the sort to estimate
     * @param mq  metadata query used for the input's row count
     * @return the estimate, or {@code null} if the input's row count is unknown
     */
    public Double getRowCount(Sort rel, RelMetadataQuery mq) {
        Double rowCount = mq.getRowCount(rel.getInput());
        if (rowCount == null || rel.fetch == null) {
            return rowCount;
        }
        int offset = rel.offset == null ? 0 : RexLiteral.intValue(rel.offset);
        int limit = RexLiteral.intValue(rel.fetch);
        // Add in double arithmetic: offset + limit can overflow int and wrap negative.
        double offsetLimit = (double) offset + limit;
        if (offsetLimit < rowCount) {
            return offsetLimit;
        }
        return rowCount;
    }

    /**
     * Analyses a join for a PK-FK relationship on a single equi-join column.
     *
     * <p>The join condition, after single-side conjuncts have been classified out, must be
     * exactly one {@code =} between one column of each input. A side can be the PK side only
     * if its join column is a unique key of that input and the join type keeps it non-null
     * extended (INNER/RIGHT for the left input, INNER/LEFT for the right input). When both
     * sides qualify, the input with fewer rows is taken as the PK side.
     *
     * @param joinRel the join to analyse
     * @param mq      metadata query used for keys, row counts and distinct counts
     * @return the relationship info, or {@code null} if no PK-FK relation is recognised or the
     *         required row counts are unknown
     */
    public static PKFKRelationInfo analyzeJoinForPKFK(Join joinRel, RelMetadataQuery mq) {
        // Semi-joins are never analysed.
        if (joinRel instanceof SemiJoin) {
            return null;
        }
        List<RexNode> initJoinFilters = RelOptUtil.conjunctions(joinRel.getCondition());
        if (initJoinFilters.isEmpty()) {
            return null;
        }
        JoinRelType joinType = joinRel.getJoinType();
        List<RexNode> leftFilters = new ArrayList<>();
        List<RexNode> rightFilters = new ArrayList<>();
        List<RexNode> joinFilters = new ArrayList<>(initJoinFilters);
        // Move single-side conjuncts into left/right filters where the join type allows it;
        // whatever remains in joinFilters must be the equi-join condition.
        RelOptUtil.classifyFilters(joinRel, joinFilters, joinType, false,
                !joinType.generatesNullsOnRight(), !joinType.generatesNullsOnLeft(),
                joinFilters, leftFilters, rightFilters);
        Pair<Integer, Integer> joinCols = canHandleJoin(joinRel, leftFilters, rightFilters, joinFilters);
        if (joinCols == null) {
            return null;
        }
        int leftColIdx = joinCols.left;
        int rightColIdx = joinCols.right;
        RelNode left = joinRel.getLeft();
        RelNode right = joinRel.getRight();

        RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder();
        RexNode leftPred = RexUtil.composeConjunction(rexBuilder, leftFilters, true);
        RexNode rightPred = RexUtil.composeConjunction(rexBuilder, rightFilters, true);
        ImmutableBitSet lBitSet = ImmutableBitSet.of(leftColIdx);
        ImmutableBitSet rBitSet = ImmutableBitSet.of(rightColIdx);

        boolean leftIsKey = (joinType == JoinRelType.INNER || joinType == JoinRelType.RIGHT)
                && isKey(lBitSet, left, mq);
        boolean rightIsKey = (joinType == JoinRelType.INNER || joinType == JoinRelType.LEFT)
                && isKey(rBitSet, right, mq);
        if (!leftIsKey && !rightIsKey) {
            return null;
        }

        Double leftRows = mq.getRowCount(left);
        Double rightRows = mq.getRowCount(right);
        if (leftRows == null || rightRows == null) {
            // Without both row counts no estimate can be formed; let the caller fall back.
            return null;
        }
        double leftRowCount = leftRows;
        double rightRowCount = rightRows;
        // When both sides are keys, treat the smaller input as the PK side.
        if (leftIsKey && rightIsKey && rightRowCount < leftRowCount) {
            leftIsKey = false;
        }
        // At least one side is a key here, so the PK side is always 0 (left) or 1 (right).
        int pkSide = leftIsKey ? 0 : 1;

        boolean isPKSideSimpleTree = IsSimpleTreeOnJoinKey.check(
                pkSide == 0 ? left : right, pkSide == 0 ? leftColIdx : rightColIdx, mq);

        // NDVs are only computed for simple PK-side trees; -1.0 marks "not computed/unknown".
        double leftNDV = -1.0;
        double rightNDV = -1.0;
        double ndvScalingFactor = 1.0;
        if (isPKSideSimpleTree) {
            leftNDV = ndvOrUnknown(mq.getDistinctRowCount(left, lBitSet, leftPred));
            rightNDV = ndvOrUnknown(mq.getDistinctRowCount(right, rBitSet, rightPred));
            double pkNDV = pkSide == 0 ? leftNDV : rightNDV;
            double fkNDV = pkSide == 0 ? rightNDV : leftNDV;
            // Scale by the PK/FK NDV ratio only when both NDVs are known and the ratio is
            // finite; otherwise Inf/NaN would leak into the row-count estimate.
            if (pkNDV >= 0.0 && fkNDV > 0.0) {
                ndvScalingFactor = pkNDV / fkNDV;
            }
        }

        if (pkSide == 0) {
            FKSideInfo fkInfo = new FKSideInfo(rightRowCount, rightNDV);
            double pkSelectivity = pkSelectivity(joinRel, mq, true, left, leftRowCount);
            PKSideInfo pkInfo = new PKSideInfo(leftRowCount, leftNDV, pkSelectivity);
            // Left is the PK side, so the FK side is the right input (index 1).
            return new PKFKRelationInfo(1, fkInfo, pkInfo, ndvScalingFactor, isPKSideSimpleTree);
        }
        FKSideInfo fkInfo = new FKSideInfo(leftRowCount, leftNDV);
        double pkSelectivity = pkSelectivity(joinRel, mq, false, right, rightRowCount);
        PKSideInfo pkInfo = new PKSideInfo(rightRowCount, rightNDV, pkSelectivity);
        // Right is the PK side, so the FK side is the left input (index 0).
        return new PKFKRelationInfo(0, fkInfo, pkInfo, ndvScalingFactor, isPKSideSimpleTree);
    }

    /** Unboxes a distinct-row count, mapping {@code null} (unknown) to {@code -1.0}. */
    private static double ndvOrUnknown(Double ndv) {
        return ndv == null ? -1.0 : ndv;
    }

    /**
     * Computes the selectivity of the PK-side input.
     *
     * <p>The selectivity is {@code childRowCount} divided by the row count of the table scan
     * that {@link HiveRelMdUniqueKeys#getTableScan} finds under {@code child}. It is 1.0 in any
     * of these cases:
     * <ul>
     *   <li>the join type null-extends the other input;</li>
     *   <li>no such scan exists;</li>
     *   <li>the scan's row count is unknown or non-positive.</li>
     * </ul>
     *
     * @param joinRel       the join being analysed
     * @param mq            metadata query for the scan's row count
     * @param leftChild     whether {@code child} is the join's left input
     * @param child         the PK-side input
     * @param childRowCount the row count of {@code child}
     * @return a selectivity value (1.0 when it cannot be derived)
     */
    private static double pkSelectivity(Join joinRel, RelMetadataQuery mq, boolean leftChild, RelNode child, double childRowCount) {
        JoinRelType joinType = joinRel.getJoinType();
        boolean otherSideNullExtended = leftChild ? joinType.generatesNullsOnRight() : joinType.generatesNullsOnLeft();
        if (otherSideNullExtended) {
            return 1.0;
        }
        HiveTableScan tScan = HiveRelMdUniqueKeys.getTableScan(child, true);
        if (tScan == null) {
            return 1.0;
        }
        Double tRowCount = mq.getRowCount(tScan);
        if (tRowCount == null || tRowCount <= 0.0) {
            // Guard against NPE on unboxing and against Inf/NaN from dividing by zero.
            return 1.0;
        }
        return childRowCount / tRowCount;
    }

    /**
     * Returns whether {@code c} is exactly one of the unique keys reported for {@code rel}.
     * Returns false when the unique keys are unknown ({@code null}).
     */
    private static boolean isKey(ImmutableBitSet c, RelNode rel, RelMetadataQuery mq) {
        Set<ImmutableBitSet> keys = mq.getUniqueKeys(rel);
        if (keys == null) {
            return false;
        }
        for (ImmutableBitSet key : keys) {
            if (key.equals(c)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks that the residual join condition is a single {@code col = col} equality between
     * one column of each input.
     *
     * @param joinRel      the join
     * @param leftFilters  unused; kept for signature compatibility
     * @param rightFilters unused; kept for signature compatibility
     * @param joinFilters  the residual join conjuncts
     * @return (left column index within the left input, right column index within the right
     *         input), or {@code null} if the condition has another shape
     */
    private static Pair<Integer, Integer> canHandleJoin(Join joinRel, List<RexNode> leftFilters, List<RexNode> rightFilters, List<RexNode> joinFilters) {
        if (joinFilters.size() != 1) {
            return null;
        }
        RexNode joinCond = joinFilters.get(0);
        if (!(joinCond instanceof RexCall)) {
            return null;
        }
        RexCall call = (RexCall) joinCond;
        if (call.getOperator() != SqlStdOperatorTable.EQUALS) {
            return null;
        }
        ImmutableBitSet leftCols = RelOptUtil.InputFinder.bits(call.getOperands().get(0));
        ImmutableBitSet rightCols = RelOptUtil.InputFinder.bits(call.getOperands().get(1));
        if (leftCols.cardinality() != 1 || rightCols.cardinality() != 1) {
            return null;
        }
        int nFieldsLeft = joinRel.getLeft().getRowType().getFieldCount();
        int nFieldsRight = joinRel.getRight().getRowType().getFieldCount();
        int nSysFields = joinRel.getSystemFieldList().size();
        int rightStart = nSysFields + nFieldsLeft;
        ImmutableBitSet leftFieldsBitSet = ImmutableBitSet.range(nSysFields, rightStart);
        ImmutableBitSet rightFieldsBitSet = ImmutableBitSet.range(rightStart, rightStart + nFieldsRight);
        // Normalise "right = left" into "left = right".
        if (rightFieldsBitSet.contains(leftCols)) {
            ImmutableBitSet t = leftCols;
            leftCols = rightCols;
            rightCols = t;
        }
        // Both operands may reference the same input (e.g. a residual outer-join condition);
        // that is not a PK-FK join and would otherwise yield out-of-range column indexes.
        if (!leftFieldsBitSet.contains(leftCols) || !rightFieldsBitSet.contains(rightCols)) {
            return null;
        }
        int leftColIdx = leftCols.nextSetBit(0) - nSysFields;
        int rightColIdx = rightCols.nextSetBit(0) - rightStart;
        return new Pair<>(leftColIdx, rightColIdx);
    }

    /**
     * Visitor that decides whether a subtree is "simple" with respect to a join key.
     *
     * <p>A subtree is simple when every node is one of the following:
     * <ul>
     *   <li>a {@link TableScan};</li>
     *   <li>a {@link Project} whose expression at the tracked key position is a plain input
     *       reference (the key position is then remapped to that input);</li>
     *   <li>a {@link Filter} whose referenced columns form a unique key of the filter.</li>
     * </ul>
     * {@link HepRelVertex} wrappers are unwrapped first.
     */
    private static class IsSimpleTreeOnJoinKey
    extends RelVisitor {
        int joinKey;
        boolean simpleTree;
        RelMetadataQuery mq;

        static boolean check(RelNode r, int joinKey, RelMetadataQuery mq) {
            IsSimpleTreeOnJoinKey v = new IsSimpleTreeOnJoinKey(joinKey, mq);
            v.go(r);
            return v.simpleTree;
        }

        IsSimpleTreeOnJoinKey(int joinKey, RelMetadataQuery mq) {
            this.joinKey = joinKey;
            this.mq = mq;
            this.simpleTree = true;
        }

        @Override
        public void visit(RelNode node, int ordinal, RelNode parent) {
            if (node instanceof HepRelVertex) {
                node = ((HepRelVertex) node).getCurrentRel();
            }
            if (node instanceof TableScan) {
                this.simpleTree = true;
            } else if (node instanceof Project) {
                this.simpleTree = this.isSimple((Project) node);
            } else if (node instanceof Filter) {
                this.simpleTree = this.isSimple((Filter) node, this.mq);
            } else {
                this.simpleTree = false;
            }
            // Descend only while the tree is still simple.
            if (this.simpleTree) {
                super.visit(node, ordinal, parent);
            }
        }

        /** A project is simple if the key passes through as a plain column reference. */
        private boolean isSimple(Project project) {
            RexNode r = project.getProjects().get(this.joinKey);
            if (r instanceof RexInputRef) {
                // Track the key's position in the project's input.
                this.joinKey = ((RexInputRef) r).getIndex();
                return true;
            }
            return false;
        }

        /** A filter is simple if the columns its condition references form a unique key. */
        private boolean isSimple(Filter filter, RelMetadataQuery mq) {
            ImmutableBitSet condBits = RelOptUtil.InputFinder.bits(filter.getCondition());
            return HiveRelMdRowCount.isKey(condBits, filter, mq);
        }
    }

    /** Statistics for the primary-key side: FK-style stats plus its selectivity. */
    static class PKSideInfo
    extends FKSideInfo {
        public final double selectivity;

        public PKSideInfo(double rowCount, double distinctCount, double selectivity) {
            super(rowCount, distinctCount);
            this.selectivity = selectivity;
        }

        @Override
        public String toString() {
            return String.format("PKInfo(rowCount=%.2f,ndv=%.2f,selectivity=%.2f)", this.rowCount, this.distinctCount, this.selectivity);
        }
    }

    /** Row count and join-column distinct count (-1.0 when not computed) of a join side. */
    static class FKSideInfo {
        public final double rowCount;
        public final double distinctCount;

        public FKSideInfo(double rowCount, double distinctCount) {
            this.rowCount = rowCount;
            this.distinctCount = distinctCount;
        }

        @Override
        public String toString() {
            return String.format("FKInfo(rowCount=%.2f,ndv=%.2f)", this.rowCount, this.distinctCount);
        }
    }

    /**
     * Result of {@link #analyzeJoinForPKFK}. {@code fkSide} is the index of the FK input
     * (0 = left, 1 = right).
     */
    static class PKFKRelationInfo {
        public final int fkSide;
        public final double ndvScalingFactor;
        public final FKSideInfo fkInfo;
        public final PKSideInfo pkInfo;
        public final boolean isPKSideSimple;

        PKFKRelationInfo(int fkSide, FKSideInfo fkInfo, PKSideInfo pkInfo, double ndvScalingFactor, boolean isPKSideSimple) {
            this.fkSide = fkSide;
            this.fkInfo = fkInfo;
            this.pkInfo = pkInfo;
            this.ndvScalingFactor = ndvScalingFactor;
            this.isPKSideSimple = isPKSideSimple;
        }

        @Override
        public String toString() {
            return String.format("Primary - Foreign Key join:\n\tfkSide = %d\n\tFKInfo:%s\n\tPKInfo:%s\n\tisPKSideSimple:%s\n\tNDV Scaling Factor:%.2f\n", this.fkSide, this.fkInfo, this.pkInfo, this.isPKSideSimple, this.ndvScalingFactor);
        }
    }
}

