/*
 * Decompiled with CFR 0.152.
 */
package org.apache.drill.exec.planner.logical;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.rules.FilterJoinRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.drill.exec.planner.DrillRelBuilder;
import org.apache.drill.exec.planner.logical.DrillRelFactories;

/**
 * Filter/join push-down predicates and the Calcite {@code FilterJoinRule} instances that Drill
 * configures with them.
 */
public class DrillFilterJoinRules {
    /**
     * Join-condition predicate that accepts every filter for non-INNER joins. For INNER joins it
     * accepts a filter only if every conjunct can be turned into a join key.
     *
     * <p>A conjunct can become a join key if it is:
     * <ul>
     *   <li>an {@code =} or {@code IS NOT DISTINCT FROM} whose two operands each reference a
     *       single, different input; or</li>
     *   <li>a predicate that references only one input, which is treated as {@code pred = TRUE}.</li>
     * </ul>
     * Any other conjunct makes the predicate return {@code false}.
     */
    public static final FilterJoinRule.Predicate EQUAL_IS_DISTINCT_FROM = (join, joinType, exp) -> {
        if (joinType != JoinRelType.INNER) {
            return true;
        }
        List<RelNode> inputs = Arrays.asList(join.getLeft(), join.getRight());
        List<RexNode> leftKeys = new ArrayList<>();
        List<RexNode> rightKeys = new ArrayList<>();
        List<RelDataTypeField> sysFields = new ArrayList<>();
        // Must be non-null: splitJoinCondition only accepts IS NOT DISTINCT FROM keys when it is.
        List<Integer> filterNulls = new ArrayList<>();
        List<RexNode> nonEquiList = new ArrayList<>();
        // Only the non-equi remainder matters here; the collected keys are discarded.
        splitJoinCondition(sysFields, inputs, exp, Arrays.asList(leftKeys, rightKeys), filterNulls, null, nonEquiList);
        RexNode remaining = RexUtil.composeConjunction(inputs.get(0).getCluster().getRexBuilder(), nonEquiList);
        return remaining.isAlwaysTrue();
    };

    /**
     * Join-condition predicate that accepts every filter for non-INNER joins. For INNER joins it
     * accepts a filter only when Calcite's two-input
     * {@code RelOptUtil.splitJoinCondition(left, right, ...)} leaves an always-true remainder.
     * NOTE(review): presumably stricter than {@link #EQUAL_IS_DISTINCT_FROM} (key shapes accepted
     * depend on the Calcite version in use) — confirm.
     */
    public static final FilterJoinRule.Predicate STRICT_EQUAL_IS_DISTINCT_FROM = (join, joinType, exp) -> {
        if (joinType != JoinRelType.INNER) {
            return true;
        }
        List<Integer> leftKeys = new ArrayList<>();
        List<Integer> rightKeys = new ArrayList<>();
        List<Boolean> filterNulls = new ArrayList<>();
        RexNode remaining = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), exp, leftKeys, rightKeys, filterNulls);
        return remaining.isAlwaysTrue();
    };

    /** Filter-into-join rule using {@link #EQUAL_IS_DISTINCT_FROM} and Drill's logical rel builder. */
    public static final RelOptRule FILTER_INTO_JOIN = FilterJoinRule.FilterIntoJoinRule.FilterIntoJoinRuleConfig.DEFAULT.withPredicate(EQUAL_IS_DISTINCT_FROM).withRelBuilderFactory(DrillRelFactories.LOGICAL_BUILDER).toRule();

    /**
     * Filter-into-join rule using {@link #STRICT_EQUAL_IS_DISTINCT_FROM} and a rel builder that
     * creates Drill logical project/filter nodes.
     */
    public static final RelOptRule DRILL_FILTER_INTO_JOIN = FilterJoinRule.FilterIntoJoinRule.FilterIntoJoinRuleConfig.DEFAULT.withPredicate(STRICT_EQUAL_IS_DISTINCT_FROM).withRelBuilderFactory(DrillRelBuilder.proto(DrillRelFactories.DRILL_LOGICAL_PROJECT_FACTORY, DrillRelFactories.DRILL_LOGICAL_FILTER_FACTORY)).toRule();

    /** Join-condition push rule using {@link #EQUAL_IS_DISTINCT_FROM} and Drill's logical rel builder. */
    public static final RelOptRule JOIN_PUSH_CONDITION = FilterJoinRule.JoinConditionPushRule.JoinConditionPushRuleConfig.DEFAULT.withPredicate(EQUAL_IS_DISTINCT_FROM).withRelBuilderFactory(DrillRelFactories.LOGICAL_BUILDER).toRule();

    /**
     * Splits {@code condition} into per-input join keys and a list of conjuncts that could not be
     * turned into keys.
     *
     * <p>The field ranges of the inputs and the index adjustments are computed once here. The
     * recursive walk over the conjuncts is then delegated to
     * {@link #splitConjunct(List, RexNode, List, List, List, List, ImmutableBitSet[], int[])}.
     *
     * @param sysFieldList system fields; only the count is used, as a field offset before each input
     * @param inputs       join inputs, in field order
     * @param condition    join condition to split
     * @param joinKeys     one output key list per input, in the same order as {@code inputs}
     * @param filterNulls  if non-null, receives the key position of each {@code =} key and enables
     *                     {@code IS NOT DISTINCT FROM} keys
     * @param rangeOp      if non-null, a single range comparison may also be extracted and its
     *                     operator is appended here
     * @param nonEquiList  receives every conjunct that could not be turned into join keys
     */
    private static void splitJoinCondition(List<RelDataTypeField> sysFieldList, List<RelNode> inputs, RexNode condition, List<List<RexNode>> joinKeys, List<Integer> filterNulls, List<SqlOperator> rangeOp, List<RexNode> nonEquiList) {
        int sysFieldCount = sysFieldList.size();
        ImmutableBitSet[] inputsRange = new ImmutableBitSet[inputs.size()];
        int totalFieldCount = 0;
        for (int i = 0; i < inputs.size(); i++) {
            int firstField = totalFieldCount + sysFieldCount;
            totalFieldCount = firstField + inputs.get(i).getRowType().getFieldCount();
            inputsRange[i] = ImmutableBitSet.range(firstField, totalFieldCount);
        }
        // Shift for each combined field index so that it becomes relative to its own input.
        // Iterating the range's bits (rather than nextSetBit(0)..length()) keeps a zero-field
        // input, whose range is empty, from indexing adjustments[-1].
        int[] adjustments = new int[totalFieldCount];
        for (ImmutableBitSet range : inputsRange) {
            int shift = range.nextSetBit(0);
            for (int field : range) {
                adjustments[field] = -shift;
            }
        }
        splitConjunct(inputs, condition, joinKeys, filterNulls, rangeOp, nonEquiList, inputsRange, adjustments);
    }

    /**
     * Recursively handles one (sub)condition: {@code AND}s are flattened, and each leaf is either
     * added to the key lists or appended to {@code nonEquiList}.
     */
    private static void splitConjunct(List<RelNode> inputs, RexNode condition, List<List<RexNode>> joinKeys, List<Integer> filterNulls, List<SqlOperator> rangeOp, List<RexNode> nonEquiList, ImmutableBitSet[] inputsRange, int[] adjustments) {
        if (condition.getKind() == SqlKind.AND) {
            for (RexNode operand : ((RexCall) condition).getOperands()) {
                splitConjunct(inputs, operand, joinKeys, filterNulls, rangeOp, nonEquiList, inputsRange, adjustments);
            }
            return;
        }
        if (condition instanceof RexCall) {
            RelOptCluster cluster = inputs.get(0).getCluster();
            RexBuilder rexBuilder = cluster.getRexBuilder();
            RelDataTypeFactory typeFactory = cluster.getTypeFactory();

            RexNode leftKey = null;
            RexNode rightKey = null;
            int leftInput = 0;
            int rightInput = 0;
            List<RelDataTypeField> leftFields = null;
            List<RelDataTypeField> rightFields = null;
            // True when the right key came from operand 0, i.e. the comparison must be mirrored.
            boolean reverse = false;

            RexCall call = RelOptUtil.collapseExpandedIsNotDistinctFromExpr((RexCall) condition, rexBuilder);
            SqlKind kind = call.getKind();
            boolean isRangeComparison = kind == SqlKind.GREATER_THAN
                    || kind == SqlKind.GREATER_THAN_OR_EQUAL
                    || kind == SqlKind.LESS_THAN
                    || kind == SqlKind.LESS_THAN_OR_EQUAL;
            // Range comparisons are only considered when none has been recorded yet.
            if (kind == SqlKind.EQUALS
                    || (filterNulls != null && kind == SqlKind.IS_NOT_DISTINCT_FROM)
                    || (rangeOp != null && rangeOp.isEmpty() && isRangeComparison)) {
                List<RexNode> operands = call.getOperands();
                RexNode op0 = operands.get(0);
                RexNode op1 = operands.get(1);
                ImmutableBitSet projRefs0 = RelOptUtil.InputFinder.bits(op0);
                ImmutableBitSet projRefs1 = RelOptUtil.InputFinder.bits(op1);

                boolean foundBothInputs = false;
                for (int i = 0; i < inputs.size() && !foundBothInputs; i++) {
                    if (isWithinInput(projRefs0, inputsRange[i])) {
                        if (leftKey == null) {
                            leftKey = op0;
                            leftInput = i;
                            leftFields = inputs.get(i).getRowType().getFieldList();
                        } else {
                            rightKey = op0;
                            rightInput = i;
                            rightFields = inputs.get(i).getRowType().getFieldList();
                            reverse = true;
                            foundBothInputs = true;
                        }
                    } else if (isWithinInput(projRefs1, inputsRange[i])) {
                        if (leftKey == null) {
                            leftKey = op1;
                            leftInput = i;
                            leftFields = inputs.get(i).getRowType().getFieldList();
                        } else {
                            rightKey = op1;
                            rightInput = i;
                            rightFields = inputs.get(i).getRowType().getFieldList();
                            foundBothInputs = true;
                        }
                    }
                }

                if (leftKey != null && rightKey != null) {
                    // Re-base input references so that each key is relative to its own input.
                    rightKey = rightKey.accept(new RelOptUtil.RexInputConverter(rexBuilder, rightFields, rightFields, adjustments));
                    leftKey = leftKey.accept(new RelOptUtil.RexInputConverter(rexBuilder, leftFields, leftFields, adjustments));
                    RelDataType leftKeyType = leftKey.getType();
                    RelDataType rightKeyType = rightKey.getType();
                    // NOTE(review): identity comparison of types is intentional in the original;
                    // presumably it relies on the type factory returning canonical instances.
                    if (leftKeyType != rightKeyType) {
                        RelDataType targetKeyType = typeFactory.leastRestrictive(Arrays.asList(leftKeyType, rightKeyType));
                        if (targetKeyType == null) {
                            throw new AssertionError("Cannot find common type for join keys " + leftKey + " (type " + leftKeyType + ") and " + rightKey + " (type " + rightKeyType + ")");
                        }
                        if (leftKeyType != targetKeyType) {
                            leftKey = rexBuilder.makeCast(targetKeyType, leftKey);
                        }
                        if (rightKeyType != targetKeyType) {
                            rightKey = rexBuilder.makeCast(targetKeyType, rightKey);
                        }
                    }
                }
            }

            // Equi-join-only mode: a condition that references a single input becomes the key
            // pair (condition, TRUE).
            if (rangeOp == null && (leftKey == null || rightKey == null)) {
                ImmutableBitSet projRefs = RelOptUtil.InputFinder.bits(condition);
                leftKey = null;
                rightKey = null;
                for (int i = 0; i < inputs.size(); i++) {
                    if (inputsRange[i].contains(projRefs)) {
                        leftInput = i;
                        leftFields = inputs.get(i).getRowType().getFieldList();
                        leftKey = condition.accept(new RelOptUtil.RexInputConverter(rexBuilder, leftFields, leftFields, adjustments));
                        rightKey = rexBuilder.makeLiteral(true);
                        kind = SqlKind.EQUALS;
                        // NOTE(review): rightInput is not reassigned here and is still 0, so for a
                        // condition on input 0 both keys go to joinKeys.get(0). This is harmless for
                        // EQUAL_IS_DISTINCT_FROM, which ignores the key lists.
                        break;
                    }
                }
            }

            if (leftKey != null && rightKey != null) {
                // A previously recorded range key stays last in each key list.
                boolean preserveLast = rangeOp != null && !rangeOp.isEmpty();
                addJoinKey(joinKeys.get(leftInput), leftKey, preserveLast);
                addJoinKey(joinKeys.get(rightInput), rightKey, preserveLast);
                if (filterNulls != null && kind == SqlKind.EQUALS) {
                    // '=' keys (but not IS NOT DISTINCT FROM keys) are recorded by position.
                    filterNulls.add(joinKeys.get(leftInput).size() - 1);
                }
                // Only range comparisons belong in rangeOp. Exclude both equi-key kinds; the
                // original excluded IS_DISTINCT_FROM, which can never reach this point.
                if (rangeOp != null && kind != SqlKind.EQUALS && kind != SqlKind.IS_NOT_DISTINCT_FROM) {
                    SqlOperator op = call.getOperator();
                    if (reverse) {
                        op = Objects.requireNonNull(op.reverse());
                    }
                    rangeOp.add(op);
                }
                return;
            }
        }
        nonEquiList.add(condition);
    }

    /** Returns whether {@code refs} is non-empty and falls entirely within {@code range}. */
    private static boolean isWithinInput(ImmutableBitSet refs, ImmutableBitSet range) {
        return refs.intersects(range) && range.contains(refs);
    }

    /**
     * Adds {@code key} to {@code joinKeyList}. When {@code preserveLastElementInList} is true and
     * the list is non-empty, the key is inserted just before the current last element, so that
     * element stays last.
     */
    private static void addJoinKey(List<RexNode> joinKeyList, RexNode key, boolean preserveLastElementInList) {
        if (!joinKeyList.isEmpty() && preserveLastElementInList) {
            joinKeyList.add(joinKeyList.size() - 1, key);
        } else {
            joinKeyList.add(key);
        }
    }
}

