/*
 * Decompiled with CFR 0.152.
 */
package org.apache.drill.exec.planner.logical;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.calcite.plan.Convention;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Window;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.fun.SqlAvgAggFunction;
import org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.CompositeList;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;
import org.apache.drill.exec.planner.logical.DrillAggregateRel;
import org.apache.drill.exec.planner.logical.DrillRelFactories;
import org.apache.drill.exec.planner.logical.DrillWindowRel;
import org.apache.drill.exec.planner.physical.PlannerSettings;
import org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper;
import org.apache.drill.exec.planner.sql.DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper;
import org.apache.drill.exec.planner.sql.DrillSqlOperator;
import org.apache.drill.exec.planner.sql.TypeInferenceUtils;
import org.apache.drill.exec.planner.sql.parser.DrillCalciteWrapperUtility;
import org.apache.drill.shaded.guava.com.google.common.collect.ImmutableList;
import org.apache.drill.shaded.guava.com.google.common.collect.Lists;
import org.apache.drill.shaded.guava.com.google.common.collect.Maps;

/**
 * Planner rules that rewrite aggregate functions in terms of simpler aggregates.
 *
 * <p>{@link #INSTANCE} matches a {@link LogicalAggregate} that contains a {@link SqlSumAggFunction}
 * or a {@link SqlAvgAggFunction} ({@code AVG}, {@code STDDEV_POP}, {@code STDDEV_SAMP},
 * {@code VAR_POP}, {@code VAR_SAMP}). It replaces the aggregate with a new aggregate over
 * {@code $SUM0} ({@link SqlSumEmptyIsZeroAggFunction}), {@code SUM} and {@code COUNT} calls, plus a
 * project on top that computes the original results:
 * <ul>
 *   <li>{@code SUM(x)} becomes {@code $SUM0(x)}. If the original result is nullable, it is guarded
 *       with {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE $SUM0(x) END};</li>
 *   <li>a non-DECIMAL {@code AVG(x)} becomes
 *       {@code CastHigh(CASE WHEN COUNT(x) = 0 THEN NULL ELSE $SUM0(x) END) / COUNT(x)}. A DECIMAL
 *       {@code AVG} is kept unchanged;</li>
 *   <li>variance and standard deviation become
 *       {@code (SUM(x*x) - SUM(x)*SUM(x)/COUNT(x)) / d}. Here {@code d} is {@code COUNT(x)} for the
 *       population variants and {@code COUNT(x) - 1} (NULL when the count is 1) for the sample
 *       variants. Standard deviations are additionally raised to the power 0.5.</li>
 * </ul>
 *
 * <p>{@link #INSTANCE_SUM} and {@link #INSTANCE_WINDOW_SUM} replace {@code SUM} calls whose result
 * type is NOT NULL with {@code $SUM0}, in {@link DrillAggregateRel} and {@link DrillWindowRel}
 * respectively.
 */
public class DrillReduceAggregatesRule extends RelOptRule {

    /** Reduces SUM/AVG/STDDEV/VAR calls of a {@link LogicalAggregate}. */
    public static final DrillReduceAggregatesRule INSTANCE =
        new DrillReduceAggregatesRule(operand(LogicalAggregate.class, any()));

    /** Converts NOT NULL SUM calls of a {@link DrillAggregateRel} into $SUM0. */
    public static final DrillConvertSumToSumZero INSTANCE_SUM =
        new DrillConvertSumToSumZero(operand(DrillAggregateRel.class, any()));

    /** Converts NOT NULL SUM window calls of a {@link DrillWindowRel} into $SUM0. */
    public static final DrillConvertWindowSumToSumZero INSTANCE_WINDOW_SUM =
        new DrillConvertWindowSumToSumZero(operand(DrillWindowRel.class, any()));

    /**
     * Unary {@code CastHigh} operator. Its inferred return type is ANY, with the nullability of its
     * operand.
     * NOTE(review): presumably it widens its argument to a higher-precision type at execution time,
     * so that sums and products do not overflow. Confirm against the CastHigh function implementation.
     */
    private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1, false,
        new SqlReturnTypeInference() {
            @Override
            public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
                return TypeInferenceUtils.createCalciteTypeWithNullability(
                    opBinding.getTypeFactory(),
                    SqlTypeName.ANY,
                    opBinding.getOperandType(0).isNullable());
            }
        }, false);

    protected DrillReduceAggregatesRule(RelOptRuleOperand operand) {
        super(operand, DrillRelFactories.LOGICAL_BUILDER, null);
    }

    /** Fires only when the aggregate contains at least one call this rule can reduce. */
    @Override
    public boolean matches(RelOptRuleCall call) {
        if (!super.matches(call)) {
            return false;
        }
        Aggregate oldAggRel = (Aggregate) call.rels[0];
        return containsAvgStddevVarCall(oldAggRel.getAggCallList());
    }

    @Override
    public void onMatch(RelOptRuleCall ruleCall) {
        Aggregate oldAggRel = (Aggregate) ruleCall.rels[0];
        reduceAggs(ruleCall, oldAggRel);
    }

    /**
     * Returns whether any call in the list (after unwrapping Drill wrappers) is an AVG-family call
     * or a SUM call.
     */
    private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
        for (AggregateCall call : aggCallList) {
            SqlAggFunction sqlAggFunction =
                DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(call.getAggregation());
            if (sqlAggFunction instanceof SqlAvgAggFunction
                || sqlAggFunction instanceof SqlSumAggFunction) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the replacement for {@code oldAggRel}. The result is an optional project that adds any
     * new input expressions, then an aggregate with the reduced calls, then a project computing the
     * original output columns under the original field names.
     */
    private void reduceAggs(RelOptRuleCall ruleCall, Aggregate oldAggRel) {
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
        int nGroups = oldAggRel.getGroupCount();

        List<AggregateCall> newCalls = new ArrayList<>();
        Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
        List<RexNode> projList = new ArrayList<>();

        // Group-by keys pass through unchanged as the leading output columns.
        for (int i = 0; i < nGroups; ++i) {
            projList.add(rexBuilder.makeInputRef(getFieldType(oldAggRel, i), i));
        }

        // Start from an identity projection of the input. reduceStddev may replace entries and
        // append new expressions to it.
        RelNode input = oldAggRel.getInput();
        List<RexNode> inputExprs = new ArrayList<>();
        for (RelDataTypeField field : input.getRowType().getFieldList()) {
            inputExprs.add(rexBuilder.makeInputRef(field.getType(), inputExprs.size()));
        }

        for (AggregateCall oldCall : oldCalls) {
            projList.add(reduceAgg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
        }

        // If expressions were appended, project them below the aggregate. The appended columns are
        // given null names and are left for the builder to name.
        int extraArgCount = inputExprs.size() - input.getRowType().getFieldCount();
        if (extraArgCount > 0) {
            input = relBuilderFactory.create(input.getCluster(), null)
                .push(input)
                .projectNamed(inputExprs,
                    CompositeList.of(input.getRowType().getFieldNames(),
                        Collections.<String>nCopies(extraArgCount, null)),
                    true)
                .build();
        }

        Aggregate newAggRel = newAggregateRel(oldAggRel, input, newCalls);
        RelNode projectRel = relBuilderFactory.create(newAggRel.getCluster(), null)
            .push(newAggRel)
            .projectNamed(projList, oldAggRel.getRowType().getFieldNames(), true)
            .build();
        ruleCall.transformTo(projectRel);
    }

    /**
     * Returns the expression, over the new aggregate's output, that replaces {@code oldCall}.
     * Any new aggregate calls it needs are registered in {@code newCalls}.
     */
    private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls,
                              Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        SqlAggFunction sqlAggFunction =
            DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(oldCall.getAggregation());
        if (sqlAggFunction instanceof SqlSumAggFunction) {
            return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
        }
        if (sqlAggFunction instanceof SqlAvgAggFunction) {
            // AVG-family calls with a DECIMAL result are not reduced; the original call is kept.
            if (oldCall.getType().getSqlTypeName() == SqlTypeName.DECIMAL) {
                RelDataType argType = getFieldType(oldAggRel.getInput(), oldCall.getArgList().get(0));
                return oldAggRel.getCluster().getRexBuilder().addAggCall(oldCall,
                    oldAggRel.getGroupCount(), newCalls, aggCallMapping, ImmutableList.of(argType));
            }
            SqlKind subtype = sqlAggFunction.getKind();
            switch (subtype) {
                case AVG:
                    return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
                case STDDEV_POP:
                    return reduceStddev(oldAggRel, oldCall, true, true, newCalls, aggCallMapping, inputExprs);
                case STDDEV_SAMP:
                    return reduceStddev(oldAggRel, oldCall, false, true, newCalls, aggCallMapping, inputExprs);
                case VAR_POP:
                    return reduceStddev(oldAggRel, oldCall, true, false, newCalls, aggCallMapping, inputExprs);
                case VAR_SAMP:
                    return reduceStddev(oldAggRel, oldCall, false, false, newCalls, aggCallMapping, inputExprs);
                default:
                    throw Util.unexpected(subtype);
            }
        }

        // Any other aggregate call is passed through unchanged.
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        int nGroups = oldAggRel.getGroupCount();
        List<Integer> ordinals = oldCall.getArgList();
        assert ordinals.size() <= inputExprs.size();
        List<RelDataType> oldArgTypes = new ArrayList<>(ordinals.size());
        for (int ordinal : ordinals) {
            oldArgTypes.add(inputExprs.get(ordinal).getType());
        }
        // An equal call may already be registered with a different result type. In that case add
        // this call separately so that its own type is preserved.
        RexNode existing = aggCallMapping.get(oldCall);
        if (existing != null && !existing.getType().equals(oldCall.getType())) {
            int index = newCalls.size() + nGroups;
            newCalls.add(oldCall);
            return rexBuilder.makeInputRef(oldCall.getType(), index);
        }
        return rexBuilder.addAggCall(oldCall, nGroups, newCalls, aggCallMapping, oldArgTypes);
    }

    /**
     * Rewrites a non-DECIMAL {@code AVG(x)} as
     * {@code CastHigh(CASE WHEN COUNT(x) = 0 THEN NULL ELSE $SUM0(x) END) / COUNT(x)}.
     *
     * <p>With type inference enabled, the division uses a Drill {@code divide} operator typed as the
     * original call. Otherwise it uses the standard DIVIDE operator, and the result is cast to ANY.
     */
    private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls,
                              Map<AggregateCall, RexNode> aggCallMapping) {
        boolean isInferenceEnabled = isTypeInferenceEnabled(oldAggRel);
        int nGroups = oldAggRel.getGroupCount();
        RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        int iAvgInput = oldCall.getArgList().get(0);
        RelDataType avgInputType = getFieldType(oldAggRel.getInput(), iAvgInput);

        // Type the intermediate sum the way Drill types SUM. It is forced nullable when there is
        // no GROUP BY (nGroups == 0), presumably because a global aggregate returns a row even for
        // empty input.
        RelDataType sumType = TypeInferenceUtils
            .getDrillSqlReturnTypeInference(SqlKind.SUM.name(), ImmutableList.of())
            .inferReturnType(oldCall.createBinding(oldAggRel));
        sumType = typeFactory.createTypeWithNullability(sumType, sumType.isNullable() || nGroups == 0);

        SqlAggFunction sumAgg =
            new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
        AggregateCall sumCall = getAggCall(oldCall, sumAgg, sumType);
        SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
        RelDataType countType = countAgg.getReturnType(typeFactory);
        AggregateCall countCall = getAggCall(oldCall, countAgg, countType);

        // Registration order matters: it fixes the output ordinals of the new aggregate.
        RexNode sumRef = rexBuilder.addAggCall(sumCall, nGroups, newCalls, aggCallMapping,
            ImmutableList.of(avgInputType));
        RexNode countRef = rexBuilder.addAggCall(countCall, nGroups, newCalls, aggCallMapping,
            ImmutableList.of(avgInputType));

        // $SUM0 yields 0 when there are no values; turn that back into NULL.
        RexNode sumOrNull = rexBuilder.makeCall(SqlStdOperatorTable.CASE,
            rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countRef,
                rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
            rexBuilder.constantNull(),
            sumRef);
        RexNode numeratorRef = rexBuilder.makeCall(CastHighOp, sumOrNull);

        if (isInferenceEnabled) {
            return rexBuilder.makeCall(new DrillSqlOperator("divide", 2, true, oldCall.getType(), false),
                numeratorRef, countRef);
        }
        RexNode divideRef = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, countRef);
        return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), divideRef);
    }

    /**
     * Creates a call of {@code aggFunction} with result type {@code sumType}. The arguments,
     * distinct/approximate/ignore-nulls flags, filter, distinct keys and collation are copied from
     * {@code oldCall}, and the name is left null.
     */
    private static AggregateCall getAggCall(AggregateCall oldCall, SqlAggFunction aggFunction,
                                            RelDataType sumType) {
        return AggregateCall.create(aggFunction, oldCall.isDistinct(), oldCall.isApproximate(),
            oldCall.ignoreNulls(), oldCall.getArgList(), oldCall.filterArg, oldCall.distinctKeys,
            oldCall.getCollation(), sumType, null);
    }

    /**
     * Rewrites {@code SUM(x)} as {@code $SUM0(x)}. If the original result type is nullable, the
     * result becomes {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE $SUM0(x) END}.
     */
    private RexNode reduceSum(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls,
                              Map<AggregateCall, RexNode> aggCallMapping) {
        boolean isInferenceEnabled = isTypeInferenceEnabled(oldAggRel);
        int nGroups = oldAggRel.getGroupCount();
        RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        int arg = oldCall.getArgList().get(0);
        RelDataType argType = getFieldType(oldAggRel.getInput(), arg);
        RelDataType sumType = isInferenceEnabled
            ? oldCall.getType()
            : typeFactory.createTypeWithNullability(oldCall.getType(), argType.isNullable());

        SqlAggFunction sumZeroAgg =
            new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
        AggregateCall sumZeroCall = getAggCall(oldCall, sumZeroAgg, sumType);
        SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
        RelDataType countType = countAgg.getReturnType(typeFactory);
        AggregateCall countCall = getAggCall(oldCall, countAgg, countType);

        RexNode sumZeroRef = rexBuilder.addAggCall(sumZeroCall, nGroups, newCalls, aggCallMapping,
            ImmutableList.of(argType));
        if (!oldCall.getType().isNullable()) {
            return sumZeroRef;
        }
        RexNode countRef = rexBuilder.addAggCall(countCall, nGroups, newCalls, aggCallMapping,
            ImmutableList.of(argType));
        return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
            rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countRef,
                rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
            rexBuilder.constantNull(),
            sumZeroRef);
    }

    /**
     * Rewrites a variance or standard-deviation call over column {@code x} as
     * {@code (SUM(x*x) - SUM(x) * SUM(x) / COUNT(x)) / d}.
     * <ul>
     *   <li>{@code x} is first replaced by {@code CastHigh(x)} in the input projection.</li>
     *   <li>{@code d} is {@code COUNT(x)} when {@code biased} (population variants). Otherwise it is
     *       {@code CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END} (sample variants).</li>
     *   <li>When {@code sqrt} is set, the quotient is raised to the power 0.5.</li>
     * </ul>
     *
     * <p>Side effect: {@code inputExprs} is modified in place. Entry {@code x} is replaced, and
     * {@code x * x} is appended unless an equal expression is already present.
     */
    private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt,
                                 List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping,
                                 List<RexNode> inputExprs) {
        boolean isInferenceEnabled = isTypeInferenceEnabled(oldAggRel);
        int nGroups = oldAggRel.getGroupCount();
        RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
        int argOrdinal = oldCall.getArgList().get(0);
        RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);

        RexNode argRef = rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal));
        inputExprs.set(argOrdinal, argRef);
        RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
        int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);

        RelDataType sumType = TypeInferenceUtils
            .getDrillSqlReturnTypeInference(SqlKind.SUM.name(), ImmutableList.of())
            .inferReturnType(oldCall.createBinding(oldAggRel));
        sumType = typeFactory.createTypeWithNullability(sumType, true);

        // Registration order matters: it fixes the output ordinals of the new aggregate.
        RexNode sumArgSquared = rexBuilder.addAggCall(createSumCall(oldCall, sumType, argSquaredOrdinal),
            nGroups, newCalls, aggCallMapping, ImmutableList.of(argType));
        RexNode sumArg = rexBuilder.addAggCall(createSumCall(oldCall, sumType, argOrdinal),
            nGroups, newCalls, aggCallMapping, ImmutableList.of(argType));
        RexNode sumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);

        SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
        RelDataType countType = countAgg.getReturnType(typeFactory);
        AggregateCall countArgAggCall = getAggCall(oldCall, countAgg, countType);
        RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, newCalls, aggCallMapping,
            ImmutableList.of(argType));

        RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
        RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);

        RexNode denominator;
        if (biased) {
            denominator = countArg;
        } else {
            RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
            RexLiteral nul = rexBuilder.makeNullLiteral(countArg.getType());
            RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
            RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
            denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
        }

        SqlOperator divide = isInferenceEnabled
            ? new DrillSqlOperator("divide", 2, true, oldCall.getType(), false)
            : SqlStdOperatorTable.DIVIDE;
        RexNode result = rexBuilder.makeCall(divide, diff, denominator);
        if (sqrt) {
            RexLiteral half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
            result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, result, half);
        }
        if (isInferenceEnabled) {
            return result;
        }
        return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), result);
    }

    /**
     * Creates a wrapped {@code SUM} call over the single input column {@code argOrdinal} with result
     * type {@code sumType}. The remaining attributes (distinct, approximate, ignore-nulls, filter,
     * distinct keys, collation) are copied from {@code oldCall}.
     */
    private static AggregateCall createSumCall(AggregateCall oldCall, RelDataType sumType, int argOrdinal) {
        SqlAggFunction sum = new SqlSumAggFunction(sumType);
        return AggregateCall.create(new DrillCalciteSqlAggFunctionWrapper(sum, sumType),
            oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(),
            ImmutableIntList.of(argOrdinal), oldCall.filterArg, oldCall.distinctKeys,
            oldCall.getCollation(), sumType, null);
    }

    /** Reads the type-inference flag from the planner context, which is assumed to be PlannerSettings. */
    private static boolean isTypeInferenceEnabled(Aggregate aggRel) {
        PlannerSettings plannerSettings = (PlannerSettings) aggRel.getCluster().getPlanner().getContext();
        return plannerSettings.isTypeInferenceEnabled();
    }

    /** Returns the index of {@code element} in {@code list}, appending it first if it is absent. */
    private static <T> int lookupOrAdd(List<T> list, T element) {
        int ordinal = list.indexOf(element);
        if (ordinal == -1) {
            ordinal = list.size();
            list.add(element);
        }
        return ordinal;
    }

    /**
     * Creates the replacement aggregate: a {@link LogicalAggregate} with NONE convention and no hints.
     * It takes the grouping of {@code oldAggRel}, the given input and the given calls.
     */
    protected Aggregate newAggregateRel(Aggregate oldAggRel, RelNode inputRel, List<AggregateCall> newCalls) {
        RelOptCluster cluster = inputRel.getCluster();
        return new LogicalAggregate(cluster, cluster.traitSetOf(Convention.NONE), Collections.emptyList(),
            inputRel, oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), newCalls);
    }

    /** Returns the type of the {@code i}-th field of {@code relNode}'s row type. */
    private RelDataType getFieldType(RelNode relNode, int i) {
        return relNode.getRowType().getFieldList().get(i).getType();
    }

    /** Returns whether {@code sqlOperator} (after unwrapping) is a SUM with a NOT NULL result type. */
    private static boolean isConversionToSumZeroNeeded(SqlOperator sqlOperator, RelDataType type) {
        SqlOperator unwrapped = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(sqlOperator);
        return unwrapped instanceof SqlSumAggFunction && !type.isNullable();
    }

    /** Replaces NOT NULL SUM calls of a {@link DrillAggregateRel} with $SUM0, one for one. */
    private static class DrillConvertSumToSumZero extends RelOptRule {
        public DrillConvertSumToSumZero(RelOptRuleOperand operand) {
            super(operand, DrillRelFactories.LOGICAL_BUILDER, null);
        }

        @Override
        public boolean matches(RelOptRuleCall call) {
            DrillAggregateRel oldAggRel = (DrillAggregateRel) call.rels[0];
            for (AggregateCall aggregateCall : oldAggRel.getAggCallList()) {
                if (isConversionToSumZeroNeeded(aggregateCall.getAggregation(), aggregateCall.getType())) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            DrillAggregateRel oldAggRel = (DrillAggregateRel) call.rels[0];
            RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
            List<AggregateCall> newAggregateCalls = new ArrayList<>();
            for (AggregateCall oldAggregateCall : oldAggRel.getAggCallList()) {
                if (!isConversionToSumZeroNeeded(oldAggregateCall.getAggregation(), oldAggregateCall.getType())) {
                    newAggregateCalls.add(oldAggregateCall);
                    continue;
                }
                RelDataType callType = oldAggregateCall.getType();
                RelDataType sumType = typeFactory.createTypeWithNullability(callType, callType.isNullable());
                SqlAggFunction sumZeroAgg = new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper(
                    new SqlSumEmptyIsZeroAggFunction(), sumType);
                // Append directly rather than via RexBuilder.addAggCall, which merges equal calls.
                // The rewritten aggregate must keep exactly one call per original call, so that its
                // row type matches the rel being replaced.
                newAggregateCalls.add(AggregateCall.create(sumZeroAgg, oldAggregateCall.isDistinct(),
                    oldAggregateCall.isApproximate(), oldAggregateCall.ignoreNulls(),
                    oldAggregateCall.getArgList(), oldAggregateCall.filterArg, oldAggregateCall.distinctKeys,
                    oldAggregateCall.getCollation(), sumType, oldAggregateCall.getName()));
            }
            call.transformTo(new DrillAggregateRel(oldAggRel.getCluster(), oldAggRel.getTraitSet(),
                oldAggRel.getInput(), oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), newAggregateCalls));
        }
    }

    /** Replaces NOT NULL SUM window calls of a {@link DrillWindowRel} with $SUM0, one for one. */
    private static class DrillConvertWindowSumToSumZero extends RelOptRule {
        public DrillConvertWindowSumToSumZero(RelOptRuleOperand operand) {
            super(operand, DrillRelFactories.LOGICAL_BUILDER, null);
        }

        @Override
        public boolean matches(RelOptRuleCall call) {
            DrillWindowRel oldWinRel = (DrillWindowRel) call.rels[0];
            for (Window.Group group : oldWinRel.groups) {
                for (Window.RexWinAggCall rexWinAggCall : group.aggCalls) {
                    if (isConversionToSumZeroNeeded(rexWinAggCall.getOperator(), rexWinAggCall.getType())) {
                        return true;
                    }
                }
            }
            return false;
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            DrillWindowRel oldWinRel = (DrillWindowRel) call.rels[0];
            RelDataTypeFactory typeFactory = oldWinRel.getCluster().getTypeFactory();
            ImmutableList.Builder<Window.Group> builder = ImmutableList.builder();
            for (Window.Group group : oldWinRel.groups) {
                List<Window.RexWinAggCall> aggCalls = new ArrayList<>();
                for (Window.RexWinAggCall rexWinAggCall : group.aggCalls) {
                    if (isConversionToSumZeroNeeded(rexWinAggCall.getOperator(), rexWinAggCall.getType())) {
                        RelDataType callType = rexWinAggCall.getType();
                        RelDataType sumType = typeFactory.createTypeWithNullability(callType, callType.isNullable());
                        SqlAggFunction sumZeroAgg = new DrillCalciteSqlSumEmptyIsZeroAggFunctionWrapper(
                            new SqlSumEmptyIsZeroAggFunction(), sumType);
                        aggCalls.add(new Window.RexWinAggCall(sumZeroAgg, sumType, rexWinAggCall.operands,
                            rexWinAggCall.ordinal, rexWinAggCall.distinct));
                    } else {
                        aggCalls.add(rexWinAggCall);
                    }
                }
                builder.add(new Window.Group(group.keys, group.isRows, group.lowerBound, group.upperBound,
                    group.orderKeys, aggCalls));
            }
            call.transformTo(new DrillWindowRel(oldWinRel.getCluster(), oldWinRel.getTraitSet(),
                oldWinRel.getInput(), oldWinRel.constants, oldWinRel.getRowType(), builder.build()));
        }
    }
}

