/*
 * Decompiled with CFR 0.152.
 */
package org.apache.drill.exec.planner.logical;

import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.util.Pair;
import org.apache.drill.exec.physical.base.DbGroupScan;
import org.apache.drill.exec.planner.index.rules.MatchFunction;
import org.apache.drill.exec.planner.logical.DrillAggregateRel;
import org.apache.drill.exec.planner.logical.DrillFilterRel;
import org.apache.drill.exec.planner.logical.DrillJoin;
import org.apache.drill.exec.planner.logical.DrillLimitRel;
import org.apache.drill.exec.planner.logical.DrillProjectRel;
import org.apache.drill.exec.planner.logical.DrillScanRel;
import org.apache.drill.exec.planner.logical.DrillSemiJoinRel;
import org.apache.drill.exec.planner.logical.RelOptHelper;
import org.apache.drill.exec.planner.logical.RowKeyJoinCallContext;
import org.apache.drill.exec.planner.logical.RowKeyJoinRel;
import org.apache.drill.exec.planner.physical.PrelUtil;
import org.apache.drill.shaded.guava.com.google.common.collect.ImmutableList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Planner rule that rewrites an eligible {@link DrillJoin} into a {@link RowKeyJoinRel}.
 *
 * <p>The rule considers an INNER join (that is not already a {@link RowKeyJoinRel}) whose condition
 * is a single equality in which one operand is a bare column reference resolving to the rowkey
 * column of a {@link DrillScanRel} backed by a {@link DbGroupScan}. On a match, the rowkey-side
 * scan is rebuilt with {@link DbGroupScan#getRestrictedScan}, the Project/Filter operators found
 * above it are re-applied, and the result becomes the left input of a new {@link RowKeyJoinRel}.
 * The rewrite is skipped when the estimated selectivity (rows of the non-rowkey join input divided
 * by rows of the rowkey-side scan) exceeds the planner's
 * {@code getRowKeyJoinConversionSelThreshold()} setting.
 */
public class DrillPushRowKeyJoinToScanRule
extends RelOptRule {
    static final Logger logger = LoggerFactory.getLogger(DrillPushRowKeyJoinToScanRule.class);

    /** Decides whether the rule applies and builds the context consumed by {@link #onMatch}. */
    public final MatchFunction<RowKeyJoinCallContext> match;

    /** Rule instance whose operand matches any {@link DrillJoin}. */
    public static DrillPushRowKeyJoinToScanRule JOIN = new DrillPushRowKeyJoinToScanRule(RelOptHelper.any(DrillJoin.class), "DrillPushRowKeyJoinToScanRule_Join", new MatchRelJ());

    private DrillPushRowKeyJoinToScanRule(RelOptRuleOperand operand, String description, MatchFunction<RowKeyJoinCallContext> match) {
        super(operand, description);
        this.match = match;
    }

    @Override
    public boolean matches(RelOptRuleCall call) {
        return this.match.match(call);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        this.doOnMatch(this.match.onMatch(call));
    }

    /**
     * Returns the best rel chosen for {@code subset} if there is one, otherwise its original rel.
     */
    private static RelNode unwrapSubset(RelSubset subset) {
        RelNode best = subset.getBest();
        return best != null ? best : subset.getOriginal();
    }

    /**
     * Reports whether the join inputs may be swapped for the given rowkey location.
     *
     * <p>For LEFT/BOTH this checks that the right input contains no aggregating
     * {@link DrillAggregateRel}. A RIGHT-side rowkey always yields {@code false}, so the swap path
     * in {@link MatchRelJ#onMatch} is effectively disabled.
     * NOTE(review): presumably disabled because swapping would also require remapping column
     * ordinals in the join condition and in the join's ancestors; confirm before enabling.
     */
    private static boolean canSwapJoinInputs(DrillJoin joinRel, RowKeyJoinCallContext.RowKey rowKeyLocation) {
        if (rowKeyLocation == RowKeyJoinCallContext.RowKey.LEFT || rowKeyLocation == RowKeyJoinCallContext.RowKey.BOTH) {
            return canSwapJoinInputsInternal(joinRel.getRight());
        }
        return false;
    }

    /**
     * Walks {@code rel} and all of its inputs (unwrapping {@link HepRelVertex} and
     * {@link RelSubset}) and returns {@code false} as soon as a {@link DrillAggregateRel} with at
     * least one aggregate call is found; otherwise {@code true}.
     */
    private static boolean canSwapJoinInputsInternal(RelNode rel) {
        if (rel instanceof DrillAggregateRel && !((DrillAggregateRel) rel).getAggCallList().isEmpty()) {
            return false;
        }
        if (rel instanceof HepRelVertex) {
            return canSwapJoinInputsInternal(((HepRelVertex) rel).getCurrentRel());
        }
        if (rel instanceof RelSubset) {
            return canSwapJoinInputsInternal(unwrapSubset((RelSubset) rel));
        }
        for (RelNode child : rel.getInputs()) {
            if (!canSwapJoinInputsInternal(child)) {
                return false;
            }
        }
        return true;
    }

    /** Builds the "cannot push" result: {@code (false, (rowKeyLoc, -1))}. */
    private static Pair<Boolean, Pair<RowKeyJoinCallContext.RowKey, Integer>> cannotPush(RowKeyJoinCallContext.RowKey rowKeyLoc) {
        return Pair.of(false, Pair.of(rowKeyLoc, -1));
    }

    /**
     * Decides whether {@code joinRel} can be converted to a rowkey join.
     *
     * @param joinRel the join being examined
     * @param planner the planner, used to read the selectivity threshold setting
     * @return {@code (canPush, (rowKeyLocation, rowKeyPosition))}. {@code rowKeyPosition} is the
     *     index of the rowkey operand in the join's combined row type, or -1 when the rule does not
     *     apply. For a RIGHT-side rowkey this index is not rebased to the right input.
     */
    private static Pair<Boolean, Pair<RowKeyJoinCallContext.RowKey, Integer>> canPushRowKeyJoinToScan(DrillJoin joinRel, RelOptPlanner planner) {
        RowKeyJoinCallContext.RowKey rowKeyLoc = RowKeyJoinCallContext.RowKey.NONE;
        logger.debug("canPushRowKeyJoinToScan(): Check: Rel={}", joinRel);
        if (joinRel instanceof RowKeyJoinRel) {
            logger.debug("SKIP: Join is a RowKeyJoin");
            return cannotPush(rowKeyLoc);
        }
        if (joinRel.getJoinType() != JoinRelType.INNER) {
            logger.debug("SKIP: JoinType={} - NOT an INNER join", joinRel.getJoinType());
            return cannotPush(rowKeyLoc);
        }
        // Only a single-column equality join condition is supported.
        if (joinRel.getCondition().getKind() != SqlKind.EQUALS || joinRel.getLeftKeys().size() != 1 || joinRel.getRightKeys().size() != 1) {
            logger.debug("SKIP: #LeftKeys={}, #RightKeys={} - NOT single predicate join condition", joinRel.getLeftKeys().size(), joinRel.getRightKeys().size());
            return cannotPush(rowKeyLoc);
        }

        boolean hasLeftRowKeyCol = false;
        boolean hasRightRowKeyCol = false;
        int leftRowKeyPos = -1;
        int rightRowKeyPos = -1;
        if (joinRel.getCondition() instanceof RexCall) {
            int leftFieldCount = joinRel.getLeft().getRowType().getFieldList().size();
            for (RexNode op : ((RexCall) joinRel.getCondition()).getOperands()) {
                // Only bare column references are considered; expressions over the rowkey are not.
                if (!(op instanceof RexInputRef)) {
                    continue;
                }
                int pos = ((RexInputRef) op).getIndex();
                if (pos < leftFieldCount) {
                    if (isRowKeyColumn(pos, joinRel.getLeft())) {
                        logger.debug("FOUND Primary-key: Side=LEFT, RowType={}", joinRel.getLeft().getRowType());
                        hasLeftRowKeyCol = true;
                        leftRowKeyPos = pos;
                        break;
                    }
                } else if (isRowKeyColumn(pos - leftFieldCount, joinRel.getRight())) {
                    logger.debug("FOUND Primary-key: Side=RIGHT, RowType={}", joinRel.getRight().getRowType());
                    hasRightRowKeyCol = true;
                    rightRowKeyPos = pos;
                    break;
                }
            }
        }
        if (!hasLeftRowKeyCol && !hasRightRowKeyCol) {
            logger.debug("SKIP: Primary-key = column condition NOT found");
            return cannotPush(rowKeyLoc);
        }

        // The side holding the rowkey must reach a scan through Project/Filter/Limit only.
        RelNode leftScan = getValidJoinInput(joinRel.getLeft());
        RelNode rightScan = getValidJoinInput(joinRel.getRight());
        if (leftScan == null && rightScan == null) {
            logger.debug("SKIP: Blocking operators between join and scans");
            return cannotPush(rowKeyLoc);
        }
        if (leftScan != null && hasLeftRowKeyCol) {
            rowKeyLoc = RowKeyJoinCallContext.RowKey.LEFT;
        }
        // The search loop above stops at the first rowkey found, so BOTH is not produced today.
        if (rightScan != null && hasRightRowKeyCol) {
            rowKeyLoc = rowKeyLoc == RowKeyJoinCallContext.RowKey.LEFT ? RowKeyJoinCallContext.RowKey.BOTH : RowKeyJoinCallContext.RowKey.RIGHT;
        }
        if (rowKeyLoc == RowKeyJoinCallContext.RowKey.NONE) {
            return cannotPush(rowKeyLoc);
        }

        RelMetadataQuery mq = RelMetadataQuery.instance();
        double ncSel = PrelUtil.getPlannerSettings(planner).getRowKeyJoinConversionSelThreshold();
        boolean rowKeyOnLeft = rowKeyLoc == RowKeyJoinCallContext.RowKey.LEFT;
        RelNode rowKeyScan = rowKeyOnLeft ? leftScan : rightScan;
        RelNode otherInput = rowKeyOnLeft ? joinRel.getRight() : joinRel.getLeft();
        double otherRows = otherInput.estimateRowCount(mq);
        double rowKeyScanRows = rowKeyScan.estimateRowCount(mq);
        double sel = computeSelectivity(otherRows, rowKeyScanRows);
        if (sel > ncSel) {
            logger.debug("SKIP: SEL= {}/{} = {}%, THRESHOLD={}%", new Object[]{otherRows, rowKeyScanRows, sel * 100.0, ncSel * 100.0});
            return cannotPush(rowKeyLoc);
        }

        int rowKeyPos = rowKeyLoc == RowKeyJoinCallContext.RowKey.RIGHT ? rightRowKeyPos : leftRowKeyPos;
        logger.info("FOUND Primary-key: Side={}, RowTypePos={}, Sel={}, Threshold={}", new Object[]{rowKeyLoc.name(), rowKeyPos, sel, ncSel});
        return Pair.of(true, Pair.of(rowKeyLoc, rowKeyPos));
    }

    /**
     * Returns {@code selectRows / totalRows} clamped to [0, 1], or 1.0 when {@code totalRows} is
     * not positive.
     */
    private static double computeSelectivity(double selectRows, double totalRows) {
        if (totalRows <= 0.0) {
            return 1.0;
        }
        return Math.min(1.0, Math.max(0.0, selectRows / totalRows));
    }

    /**
     * Returns the {@link DrillScanRel} reachable from {@code rel} through only
     * {@link DrillProjectRel}, {@link DrillFilterRel} or {@link DrillLimitRel} nodes (unwrapping
     * {@link HepRelVertex} and {@link RelSubset}), or {@code null} if any other operator blocks
     * the path.
     */
    public static RelNode getValidJoinInput(RelNode rel) {
        if (rel instanceof DrillScanRel) {
            return rel;
        }
        if (rel instanceof DrillProjectRel || rel instanceof DrillFilterRel || rel instanceof DrillLimitRel) {
            for (RelNode child : rel.getInputs()) {
                RelNode tgt = getValidJoinInput(child);
                if (tgt != null) {
                    return tgt;
                }
            }
            return null;
        }
        if (rel instanceof HepRelVertex) {
            return getValidJoinInput(((HepRelVertex) rel).getCurrentRel());
        }
        if (rel instanceof RelSubset) {
            return getValidJoinInput(unwrapSubset((RelSubset) rel));
        }
        return null;
    }

    /**
     * Traces field {@code index} of {@code rel} down to a {@link DrillScanRel} and reports whether
     * it is that scan's rowkey column.
     *
     * <p>Each {@link DrillProjectRel} on the path must project the field as a plain column
     * reference; any other expression aborts the search. Multi-input rels are descended into the
     * input that owns the field, with the index rebased to that input. The result is {@code true}
     * only if the scan's group scan is a {@link DbGroupScan}, it returns a non-null restricted scan
     * for the scan's columns, and the field name equals its rowkey name (ignoring case).
     */
    private static boolean isRowKeyColumn(int index, RelNode rel) {
        RelNode curRel = rel;
        int curIndex = index;
        while (curRel != null && !(curRel instanceof DrillScanRel)) {
            logger.debug("IsRowKeyColumn: Rel={}, RowTypePos={}, RowType={}", new Object[]{curRel, curIndex, curRel.getRowType()});
            if (curRel instanceof HepRelVertex) {
                curRel = ((HepRelVertex) curRel).getCurrentRel();
                continue;
            }
            if (curRel instanceof RelSubset) {
                curRel = unwrapSubset((RelSubset) curRel);
                continue;
            }
            // Map the index through a projection before descending into its input.
            if (curRel instanceof DrillProjectRel) {
                List<RexNode> projects = ((DrillProjectRel) curRel).getProjects();
                if (projects != null && !projects.isEmpty()) {
                    RexNode expr = projects.get(curIndex);
                    if (!(expr instanceof RexInputRef)) {
                        logger.debug("IsRowKeyColumn: ABORT: Primary-key EXPR$={}", expr);
                        return false;
                    }
                    curIndex = ((RexInputRef) expr).getIndex();
                }
            }
            // Descend into the input that owns curIndex, rebasing it to that input's fields.
            RelNode child = null;
            for (RelNode input : curRel.getInputs()) {
                int width = input.getRowType().getFieldList().size();
                if (curIndex < width) {
                    child = input;
                    break;
                }
                curIndex -= width;
            }
            curRel = child;
        }
        logger.debug("IsRowKeyColumn:Primary-key Col={} ", curRel != null ? curRel.getRowType().getFieldNames().get(curIndex) : "??");
        if (curRel instanceof DrillScanRel && ((DrillScanRel) curRel).getGroupScan() instanceof DbGroupScan) {
            DrillScanRel scanRel = (DrillScanRel) curRel;
            DbGroupScan dbGroupScan = (DbGroupScan) scanRel.getGroupScan();
            String rowKeyName = dbGroupScan.getRowKeyName();
            DbGroupScan restrictedGroupScan = dbGroupScan.getRestrictedScan(scanRel.getColumns());
            if (restrictedGroupScan != null && scanRel.getRowType().getFieldNames().get(curIndex).equalsIgnoreCase(rowKeyName)) {
                logger.debug("IsRowKeyColumn: FOUND: Rel={}, RowTypePos={}, RowType={}", new Object[]{scanRel, curIndex, scanRel.getRowType()});
                return true;
            }
        }
        logger.debug("IsRowKeyColumn: NOT FOUND");
        return false;
    }

    /**
     * Applies the rewrite described by {@code rkjCallContext}; does nothing when its rowkey
     * location is {@link RowKeyJoinCallContext.RowKey#NONE}.
     */
    protected void doOnMatch(RowKeyJoinCallContext rkjCallContext) {
        if (rkjCallContext.getRowKeyLocation() != RowKeyJoinCallContext.RowKey.NONE) {
            this.doOnMatch(rkjCallContext.getCall(), rkjCallContext.getRowKeyPosition(), rkjCallContext.mustSwapInputs(), rkjCallContext.getJoinRel(), rkjCallContext.getUpperProjectRel(), rkjCallContext.getFilterRel(), rkjCallContext.getLowerProjectRel(), rkjCallContext.getScanRel());
        }
    }

    /**
     * Builds the {@link RowKeyJoinRel} and registers it with {@code call}.
     *
     * <p>The left input is the scan rebuilt with its restricted group scan, wrapped again in
     * whichever of lowerProject, filter and upperProject were matched (each may be {@code null}).
     * The right input is the other join input, which is the original left input when
     * {@code swapInputs} is set.
     */
    private void doOnMatch(RelOptRuleCall call, int rowKeyPosition, boolean swapInputs, DrillJoin joinRel, DrillProjectRel upperProjectRel, DrillFilterRel filterRel, DrillProjectRel lowerProjectRel, DrillScanRel scanRel) {
        if (swapInputs) {
            logger.debug("Transforming: Swapping of join inputs is required!");
        }
        RelNode right = swapInputs ? joinRel.getLeft() : joinRel.getRight();
        List<Integer> leftJoinKeys = ImmutableList.of(rowKeyPosition);
        List<Integer> rightJoinKeys = swapInputs ? joinRel.getLeftKeys() : joinRel.getRightKeys();
        DbGroupScan restrictedGroupScan = ((DbGroupScan) scanRel.getGroupScan()).getRestrictedScan(scanRel.getColumns());
        // Declared as RelNode: it is progressively wrapped by Project/Filter copies below.
        RelNode leftRel = new DrillScanRel(scanRel.getCluster(), scanRel.getTraitSet(), scanRel.getTable(), restrictedGroupScan, scanRel.getRowType(), scanRel.getColumns(), scanRel.partitionFilterPushdown());
        if (lowerProjectRel != null) {
            leftRel = lowerProjectRel.copy(lowerProjectRel.getTraitSet(), ImmutableList.of(leftRel));
        }
        if (filterRel != null) {
            leftRel = filterRel.copy(filterRel.getTraitSet(), leftRel, filterRel.getCondition());
        }
        if (upperProjectRel != null) {
            leftRel = upperProjectRel.copy(upperProjectRel.getTraitSet(), ImmutableList.of(leftRel));
        }
        RexNode joinCondition = RelOptUtil.createEquiJoinCondition(leftRel, leftJoinKeys, right, rightJoinKeys, joinRel.getCluster().getRexBuilder());
        logger.debug("Transforming: LeftKeys={}, LeftRowType={}, RightKeys={}, RightRowType={}", new Object[]{leftJoinKeys, leftRel.getRowType(), rightJoinKeys, right.getRowType()});
        RowKeyJoinRel rowKeyJoin = new RowKeyJoinRel(joinRel.getCluster(), joinRel.getTraitSet(), leftRel, right, joinCondition, joinRel.getJoinType(), joinRel instanceof DrillSemiJoinRel);
        logger.info("Transforming: SUCCESS: Register runtime filter pushdown plan (rowkeyjoin)");
        call.transformTo(rowKeyJoin);
    }

    /**
     * {@link MatchFunction} for a {@link DrillJoin}. It checks eligibility and then locates the
     * Project/Filter/Project/Scan chain on the rowkey side to build a {@link RowKeyJoinCallContext}.
     */
    public static class MatchRelJ
    implements MatchFunction<RowKeyJoinCallContext> {
        private static final Class<?>[] PFPS = new Class<?>[]{DrillProjectRel.class, DrillFilterRel.class, DrillProjectRel.class, DrillScanRel.class};
        private static final Class<?>[] FPS = new Class<?>[]{DrillFilterRel.class, DrillProjectRel.class, DrillScanRel.class};
        private static final Class<?>[] PS = new Class<?>[]{DrillProjectRel.class, DrillScanRel.class};
        private static final Class<?>[] FS = new Class<?>[]{DrillFilterRel.class, DrillScanRel.class};
        private static final Class<?>[] S = new Class<?>[]{DrillScanRel.class};

        /**
         * Returns the rels matching {@code relSequence} along the first-input chain from
         * {@code startingRel}. The list is empty unless the whole sequence matched.
         */
        private List<RelNode> findRelSequence(Class<?>[] relSequence, RelNode startingRel) {
            List<RelNode> matchingRels = new ArrayList<RelNode>();
            this.findRelSequenceInternal(relSequence, 0, startingRel, matchingRels);
            // Guard against a chain that ended early: callers index every position of the sequence.
            if (matchingRels.size() != relSequence.length) {
                matchingRels.clear();
            }
            return matchingRels;
        }

        /**
         * Recursive helper: matches {@code classes[idx]} against {@code rel} (unwrapping planner
         * wrappers). On a hit it appends {@code rel} and continues with its first input; on a miss
         * it clears {@code matchingRels}.
         */
        private void findRelSequenceInternal(Class<?>[] classes, int idx, RelNode rel, List<RelNode> matchingRels) {
            if (rel instanceof HepRelVertex) {
                this.findRelSequenceInternal(classes, idx, ((HepRelVertex) rel).getCurrentRel(), matchingRels);
            } else if (rel instanceof RelSubset) {
                this.findRelSequenceInternal(classes, idx, unwrapSubset((RelSubset) rel), matchingRels);
            } else if (rel != null && classes[idx].isInstance(rel)) {
                matchingRels.add(rel);
                if (idx + 1 < classes.length && !rel.getInputs().isEmpty()) {
                    this.findRelSequenceInternal(classes, idx + 1, rel.getInput(0), matchingRels);
                }
            } else {
                if (logger.isDebugEnabled()) {
                    String unexpected = rel == null ? "null" : displayName(rel.getClass());
                    logger.debug("FindRelSequence: ABORT: Unexpected Rel={}, After={}, CurSeq={}", new Object[]{unexpected, describeRels(matchingRels), describeClasses(classes)});
                }
                matchingRels.clear();
            }
        }

        /** Canonical class name, falling back to the binary name when none exists (e.g. anonymous classes). */
        private static String displayName(Class<?> clazz) {
            String name = clazz.getCanonicalName();
            return name != null ? name : clazz.getName();
        }

        /** Renders a class sequence as {@code A->B->C} for debug logging. */
        private static String describeClasses(Class<?>[] classes) {
            List<String> names = new ArrayList<String>(classes.length);
            for (Class<?> clazz : classes) {
                names.add(displayName(clazz));
            }
            return String.join("->", names);
        }

        /** Renders the classes of the matched rels as {@code A->B->C} for debug logging. */
        private static String describeRels(List<RelNode> rels) {
            List<String> names = new ArrayList<String>(rels.size());
            for (RelNode matched : rels) {
                names.add(displayName(matched.getClass()));
            }
            return String.join("->", names);
        }

        /** Context signalling that the rule must not transform anything. */
        private static RowKeyJoinCallContext noContext(RelOptRuleCall call) {
            return new RowKeyJoinCallContext(call, RowKeyJoinCallContext.RowKey.NONE, -1, false, null, null, null, null, null);
        }

        /**
         * Tries the supported operator chains on {@code joinChildRel}, longest first, and builds a
         * context from the first that matches. It returns a NONE context if none matches.
         */
        private RowKeyJoinCallContext generateContext(RelOptRuleCall call, DrillJoin joinRel, RelNode joinChildRel, RowKeyJoinCallContext.RowKey rowKeyLoc, int rowKeyPos, boolean swapInputs) {
            logger.debug("GenerateContext(): Primary-key: Side={}, RowTypePos={}, SwapInputs={}", new Object[]{rowKeyLoc.name(), rowKeyPos, swapInputs});
            List<RelNode> matchingRels = this.findRelSequence(PFPS, joinChildRel);
            if (!matchingRels.isEmpty()) {
                logger.debug("Matched rel sequence : Project->Filter->Project->Scan");
                return new RowKeyJoinCallContext(call, rowKeyLoc, rowKeyPos, swapInputs, joinRel, (DrillProjectRel) matchingRels.get(0), (DrillFilterRel) matchingRels.get(1), (DrillProjectRel) matchingRels.get(2), (DrillScanRel) matchingRels.get(3));
            }
            matchingRels = this.findRelSequence(FPS, joinChildRel);
            if (!matchingRels.isEmpty()) {
                logger.debug("Matched rel sequence : Filter->Project->Scan");
                return new RowKeyJoinCallContext(call, rowKeyLoc, rowKeyPos, swapInputs, joinRel, null, (DrillFilterRel) matchingRels.get(0), (DrillProjectRel) matchingRels.get(1), (DrillScanRel) matchingRels.get(2));
            }
            matchingRels = this.findRelSequence(PS, joinChildRel);
            if (!matchingRels.isEmpty()) {
                logger.debug("Matched rel sequence : Project->Scan");
                return new RowKeyJoinCallContext(call, rowKeyLoc, rowKeyPos, swapInputs, joinRel, null, null, (DrillProjectRel) matchingRels.get(0), (DrillScanRel) matchingRels.get(1));
            }
            matchingRels = this.findRelSequence(FS, joinChildRel);
            if (!matchingRels.isEmpty()) {
                logger.debug("Matched rel sequence : Filter->Scan");
                return new RowKeyJoinCallContext(call, rowKeyLoc, rowKeyPos, swapInputs, joinRel, null, (DrillFilterRel) matchingRels.get(0), null, (DrillScanRel) matchingRels.get(1));
            }
            matchingRels = this.findRelSequence(S, joinChildRel);
            if (!matchingRels.isEmpty()) {
                logger.debug("Matched rel sequence : Scan");
                return new RowKeyJoinCallContext(call, rowKeyLoc, rowKeyPos, swapInputs, joinRel, null, null, null, (DrillScanRel) matchingRels.get(0));
            }
            logger.debug("Matched rel sequence : None");
            return noContext(call);
        }

        @Override
        public boolean match(RelOptRuleCall call) {
            DrillJoin joinRel = (DrillJoin) call.rel(0);
            logger.debug("DrillPushRowKeyJoinToScanRule begin()");
            return canPushRowKeyJoinToScan(joinRel, call.getPlanner()).left;
        }

        @Override
        public RowKeyJoinCallContext onMatch(RelOptRuleCall call) {
            DrillJoin joinRel = (DrillJoin) call.rel(0);
            Pair<Boolean, Pair<RowKeyJoinCallContext.RowKey, Integer>> res = canPushRowKeyJoinToScan(joinRel, call.getPlanner());
            if (res.left) {
                RowKeyJoinCallContext.RowKey rowKeyLoc = res.right.left;
                int rowKeyPos = res.right.right;
                if (rowKeyLoc == RowKeyJoinCallContext.RowKey.LEFT || rowKeyLoc == RowKeyJoinCallContext.RowKey.BOTH) {
                    return this.generateContext(call, joinRel, joinRel.getLeft(), rowKeyLoc, rowKeyPos, false);
                }
                // canSwapJoinInputs currently rejects RIGHT, so this branch does not produce a context.
                if (rowKeyLoc == RowKeyJoinCallContext.RowKey.RIGHT && canSwapJoinInputs(joinRel, rowKeyLoc)) {
                    return this.generateContext(call, joinRel, joinRel.getRight(), rowKeyLoc, rowKeyPos, true);
                }
            }
            return noContext(call);
        }
    }
}

