/*
 * Decompiled with CFR 0.152.
 */
package org.apache.drill.exec.planner.fragment;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.drill.exec.ops.QueryContext;
import org.apache.drill.exec.physical.base.Exchange;
import org.apache.drill.exec.physical.base.PhysicalOperator;
import org.apache.drill.exec.physical.config.AbstractMuxExchange;
import org.apache.drill.exec.planner.AbstractOpWrapperVisitor;
import org.apache.drill.exec.planner.cost.NodeResource;
import org.apache.drill.exec.planner.fragment.Fragment;
import org.apache.drill.exec.planner.fragment.PlanningSet;
import org.apache.drill.exec.planner.fragment.Wrapper;
import org.apache.drill.exec.proto.CoordinationProtos;
import org.apache.drill.shaded.guava.com.google.common.base.Preconditions;

public class MemoryCalculator
extends AbstractOpWrapperVisitor<Void, RuntimeException> {
    private final PlanningSet planningSet;
    private final Map<CoordinationProtos.DrillbitEndpoint, List<Pair<PhysicalOperator, Long>>> bufferedOperators;
    private final QueryContext queryContext;

    public MemoryCalculator(PlanningSet planningSet, QueryContext context) {
        this.planningSet = planningSet;
        this.bufferedOperators = new HashMap<CoordinationProtos.DrillbitEndpoint, List<Pair<PhysicalOperator, Long>>>();
        this.queryContext = context;
    }

    private Map<CoordinationProtos.DrillbitEndpoint, Integer> getMinorFragCountPerDrillbit(Wrapper currFragment) {
        return currFragment.getAssignedEndpoints().stream().collect(Collectors.groupingBy(Function.identity(), Collectors.summingInt(x -> 1)));
    }

    private void merge(Wrapper currFrag, Map<CoordinationProtos.DrillbitEndpoint, Integer> minorFragsPerDrillBit, Function<Map.Entry<CoordinationProtos.DrillbitEndpoint, Integer>, Long> getMemory) {
        NodeResource.merge(currFrag.getResourceMap(), minorFragsPerDrillBit.entrySet().stream().collect(Collectors.toMap(x -> (CoordinationProtos.DrillbitEndpoint)x.getKey(), x -> NodeResource.create(0L, (Long)getMemory.apply((Map.Entry<CoordinationProtos.DrillbitEndpoint, Integer>)x)))));
    }

    @Override
    public Void visitSendingExchange(Exchange exchange, Wrapper fragment) throws RuntimeException {
        Wrapper receivingFragment = this.planningSet.get(fragment.getNode().getSendingExchangePair().getNode());
        this.merge(fragment, this.getMinorFragCountPerDrillbit(fragment), x -> exchange.getSenderMemory(receivingFragment.getWidth(), (Integer)x.getValue()));
        return this.visitOp((PhysicalOperator)exchange, fragment);
    }

    @Override
    public Void visitReceivingExchange(Exchange exchange, Wrapper fragment) throws RuntimeException {
        List<Fragment.ExchangeFragmentPair> receivingExchangePairs = fragment.getNode().getReceivingExchangePairs();
        HashMap<CoordinationProtos.DrillbitEndpoint, Integer> sendingFragsPerDrillBit = new HashMap<CoordinationProtos.DrillbitEndpoint, Integer>();
        for (Fragment.ExchangeFragmentPair pair : receivingExchangePairs) {
            if (pair.getExchange() != exchange) continue;
            Wrapper sendingFragment = this.planningSet.get(pair.getNode());
            Preconditions.checkArgument(sendingFragment.isEndpointsAssignmentDone());
            for (CoordinationProtos.DrillbitEndpoint endpoint : sendingFragment.getAssignedEndpoints()) {
                sendingFragsPerDrillBit.putIfAbsent(endpoint, 0);
                sendingFragsPerDrillBit.put(endpoint, (Integer)sendingFragsPerDrillBit.get(endpoint) + 1);
            }
        }
        int totalSendingFrags = sendingFragsPerDrillBit.entrySet().stream().mapToInt(x -> (Integer)x.getValue()).reduce(0, (x, y) -> x + y);
        this.merge(fragment, this.getMinorFragCountPerDrillbit(fragment), x -> exchange.getReceiverMemory(fragment.getWidth(), exchange instanceof AbstractMuxExchange ? (Integer)sendingFragsPerDrillBit.get(x.getKey()) : totalSendingFrags));
        return null;
    }

    public List<Pair<PhysicalOperator, Long>> getBufferedOperators(CoordinationProtos.DrillbitEndpoint endpoint) {
        return this.bufferedOperators.getOrDefault(endpoint, new ArrayList());
    }

    @Override
    public Void visitOp(PhysicalOperator op, Wrapper fragment) {
        long memoryCost = (int)Math.ceil(op.getCost().getMemoryCost());
        if (op.isBufferedOperator(this.queryContext)) {
            long memoryCostPerMinorFrag = (int)Math.ceil(memoryCost / (long)fragment.getAssignedEndpoints().size());
            Map<CoordinationProtos.DrillbitEndpoint, Integer> drillbitEndpointMinorFragMap = this.getMinorFragCountPerDrillbit(fragment);
            Map<CoordinationProtos.DrillbitEndpoint, Pair> bufferedOperatorsPerDrillbit = drillbitEndpointMinorFragMap.entrySet().stream().collect(Collectors.toMap(x -> (CoordinationProtos.DrillbitEndpoint)x.getKey(), x -> Pair.of((Object)op, (Object)(memoryCostPerMinorFrag * (long)((Integer)x.getValue()).intValue()))));
            bufferedOperatorsPerDrillbit.entrySet().forEach(x -> {
                this.bufferedOperators.putIfAbsent((CoordinationProtos.DrillbitEndpoint)x.getKey(), new ArrayList());
                this.bufferedOperators.get(x.getKey()).add((Pair<PhysicalOperator, Long>)((Pair)x.getValue()));
            });
            this.merge(fragment, drillbitEndpointMinorFragMap, x -> memoryCostPerMinorFrag * (long)((Integer)x.getValue()).intValue());
        } else {
            this.merge(fragment, this.getMinorFragCountPerDrillbit(fragment), x -> memoryCost * (long)((Integer)x.getValue()).intValue());
        }
        for (PhysicalOperator child : op) {
            child.accept(this, fragment);
        }
        return null;
    }
}

