/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.physical.stream;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Pair;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalCalc;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalChangelogNormalize;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalExchange;
import org.apache.flink.table.planner.plan.utils.RexNodeExtractor;

@Internal
public class PushFilterPastChangelogNormalizeRule
extends RelRule<Config> {
    public static final RelOptRule INSTANCE = Config.EMPTY.as(Config.class).onMatch().toRule();

    public PushFilterPastChangelogNormalizeRule(Config config) {
        super(config);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        StreamPhysicalCalc calc = (StreamPhysicalCalc)call.rel(0);
        StreamPhysicalChangelogNormalize changelogNormalize = (StreamPhysicalChangelogNormalize)call.rel(1);
        RexProgram program = calc.getProgram();
        RexNode condition = RexUtil.toCnf(call.builder().getRexBuilder(), program.expandLocalRef(program.getCondition()));
        Set<Integer> primaryKeyIndices = IntStream.of(changelogNormalize.uniqueKeys()).boxed().collect(Collectors.toSet());
        ArrayList<RexNode> primaryKeyPredicates = new ArrayList<RexNode>();
        ArrayList<RexNode> otherPredicates = new ArrayList<RexNode>();
        this.partitionPrimaryKeyPredicates(RelOptUtil.conjunctions(condition), primaryKeyIndices, primaryKeyPredicates, otherPredicates);
        StreamPhysicalChangelogNormalize newChangelogNormalize = this.pushFiltersThroughChangelogNormalize(call, primaryKeyPredicates);
        this.transformWithRemainingPredicates(call, newChangelogNormalize, otherPredicates);
    }

    private void partitionPrimaryKeyPredicates(List<RexNode> predicates, Set<Integer> primaryKeyIndices, List<RexNode> primaryKeyPredicates, List<RexNode> remainingPredicates) {
        for (RexNode predicate : predicates) {
            int[] inputRefs = RexNodeExtractor.extractRefInputFields(Collections.singletonList(predicate));
            if (Arrays.stream(inputRefs).allMatch(primaryKeyIndices::contains)) {
                primaryKeyPredicates.add(predicate);
                continue;
            }
            remainingPredicates.add(predicate);
        }
    }

    private StreamPhysicalChangelogNormalize pushFiltersThroughChangelogNormalize(RelOptRuleCall call, List<RexNode> primaryKeyPredicates) {
        StreamPhysicalChangelogNormalize changelogNormalize = (StreamPhysicalChangelogNormalize)call.rel(1);
        StreamPhysicalExchange exchange = (StreamPhysicalExchange)call.rel(2);
        if (primaryKeyPredicates.isEmpty()) {
            return changelogNormalize;
        }
        StreamPhysicalCalc pushedFiltersCalc = this.projectIdentityWithConditions(call.builder(), exchange.getInput(), primaryKeyPredicates);
        StreamPhysicalExchange newExchange = (StreamPhysicalExchange)exchange.copy(exchange.getTraitSet(), Collections.singletonList(pushedFiltersCalc));
        return (StreamPhysicalChangelogNormalize)changelogNormalize.copy(changelogNormalize.getTraitSet(), Collections.singletonList(newExchange));
    }

    private StreamPhysicalCalc projectIdentityWithConditions(RelBuilder relBuilder, RelNode newInput, List<RexNode> conditions) {
        RexProgramBuilder programBuilder = new RexProgramBuilder(newInput.getRowType(), relBuilder.getRexBuilder());
        programBuilder.addIdentity();
        RexNode condition = relBuilder.and(conditions);
        if (!condition.isAlwaysTrue()) {
            programBuilder.addCondition(condition);
        }
        RexProgram newProgram = programBuilder.getProgram();
        return new StreamPhysicalCalc(newInput.getCluster(), newInput.getTraitSet(), newInput, newProgram, newProgram.getOutputRowType());
    }

    private StreamPhysicalCalc projectWith(RelBuilder relBuilder, StreamPhysicalCalc projectFromCalc, StreamPhysicalCalc calc) {
        RexProgramBuilder programBuilder = new RexProgramBuilder(calc.getRowType(), relBuilder.getRexBuilder());
        if (calc.getProgram().getCondition() != null) {
            programBuilder.addCondition(calc.getProgram().expandLocalRef(calc.getProgram().getCondition()));
        }
        for (Pair<RexLocalRef, String> projectRef : projectFromCalc.getProgram().getNamedProjects()) {
            RexNode project = projectFromCalc.getProgram().expandLocalRef((RexLocalRef)projectRef.left);
            programBuilder.addProject(project, (String)projectRef.right);
        }
        RexProgram newProgram = programBuilder.getProgram();
        return (StreamPhysicalCalc)calc.copy(calc.getTraitSet(), calc.getInput(), newProgram);
    }

    private void transformWithRemainingPredicates(RelOptRuleCall call, StreamPhysicalChangelogNormalize changelogNormalize, List<RexNode> predicates) {
        StreamPhysicalCalc newCalc;
        StreamPhysicalCalc calc = (StreamPhysicalCalc)call.rel(0);
        RelBuilder relBuilder = call.builder();
        StreamPhysicalCalc newProjectedCalc = this.projectWith(relBuilder, calc, newCalc = this.projectIdentityWithConditions(relBuilder, changelogNormalize, predicates));
        if (newProjectedCalc.getProgram().isTrivial()) {
            call.transformTo(changelogNormalize);
        } else {
            call.transformTo(newProjectedCalc);
        }
    }

    public static interface Config
    extends RelRule.Config {
        @Override
        default public RelOptRule toRule() {
            return new PushFilterPastChangelogNormalizeRule(this);
        }

        default public Config onMatch() {
            RelRule.OperandTransform exchangeTransform = operandBuilder -> operandBuilder.operand(StreamPhysicalExchange.class).anyInputs();
            RelRule.OperandTransform changelogNormalizeTransform = operandBuilder -> operandBuilder.operand(StreamPhysicalChangelogNormalize.class).oneInput(exchangeTransform);
            RelRule.OperandTransform calcTransform = operandBuilder -> operandBuilder.operand(StreamPhysicalCalc.class).predicate(calc -> calc.getProgram().getCondition() != null).oneInput(changelogNormalizeTransform);
            return this.withOperandSupplier(calcTransform).as(Config.class);
        }
    }
}

