/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.dag.library.vertexmanager;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.common.TezUtils;
import org.apache.tez.dag.api.EdgeManager;
import org.apache.tez.dag.api.EdgeManagerContext;
import org.apache.tez.dag.api.EdgeManagerDescriptor;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.api.VertexManagerPlugin;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.events.DataMovementEvent;
import org.apache.tez.runtime.api.events.InputReadErrorEvent;
import org.apache.tez.runtime.api.events.VertexManagerEvent;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;

/**
 * Vertex manager for a vertex that consumes at least one SCATTER_GATHER (shuffle) input.
 *
 * <p>Tasks of the managed vertex are released gradually ("slow start"): as the fraction of
 * completed source tasks moves from {@code min-src-fraction} to {@code max-src-fraction}, a
 * proportional share of this vertex's tasks is scheduled, and everything left is scheduled
 * once all source tasks have completed.
 *
 * <p>When auto-parallelism is enabled, the task count may be lowered once, just before the
 * first tasks are scheduled, using the output sizes carried by {@link VertexManagerEvent}s.
 * The reduction groups consecutive source partitions onto fewer tasks through a
 * {@link CustomShuffleEdgeManager} installed on every SCATTER_GATHER input edge.
 */
public class ShuffleVertexManager
implements VertexManagerPlugin {
    private static final String TEZ_AM_PREFIX = "tez.am.";
    public static final String TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION = "tez.am.shuffle-vertex-manager.min-src-fraction";
    public static final float TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT = 0.25f;
    public static final String TEZ_AM_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION = "tez.am.shuffle-vertex-manager.max-src-fraction";
    public static final float TEZ_AM_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT = 0.75f;
    public static final String TEZ_AM_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL = "tez.am.shuffle-vertex-manager.enable.auto-parallel";
    public static final boolean TEZ_AM_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT = false;
    public static final String TEZ_AM_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE = "tez.am.shuffle-vertex-manager.desired-task-input-size";
    public static final long TEZ_AM_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT = 0x6400000L;
    public static final String TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM = "tez.am.shuffle-vertex-manager.min-task-parallelism";
    public static final int TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM_DEFAULT = 1;
    private static final Log LOG = LogFactory.getLog(ShuffleVertexManager.class);
    VertexManagerPluginContext context;
    /** Completed-source fraction below which slow start schedules no tasks. */
    float slowStartMinSrcCompletionFraction;
    /** Completed-source fraction at which slow start has scheduled every task. */
    float slowStartMaxSrcCompletionFraction;
    /** Target number of input bytes per task used by auto-parallelism. */
    long desiredTaskInputDataSize = 0x6400000L;
    /** Lower bound on the task count chosen by auto-parallelism. */
    int minTaskParallelism = 1;
    boolean enableAutoParallelism = false;
    /** Set the first time auto-parallelism is evaluated; it is evaluated at most once. */
    boolean parallelismDetermined = false;
    /** Total number of tasks across all SCATTER_GATHER source vertices. */
    int numSourceTasks = 0;
    /** Number of distinct source tasks reported as completed. */
    int numSourceTasksCompleted = 0;
    /** Number of output-size reports (VertexManagerEvents) received. */
    int numVertexManagerEventsReceived = 0;
    /** Not-yet-scheduled task indices of this vertex, in ascending order. */
    ArrayList<Integer> pendingTasks;
    /** Task count of this vertex when {@link #pendingTasks} was last rebuilt. */
    int totalTasksToSchedule = 0;
    /** SCATTER_GATHER source vertex name -> indices of its completed tasks. */
    Map<String, Set<Integer>> bipartiteSources = Maps.newHashMap();
    /** Sum of the output sizes carried by the received VertexManagerEvents. */
    long completedSourceTasksOutputSize = 0L;

    /**
     * Builds the pending-task list, replays source completions that happened before start, and
     * schedules whatever the current completion state allows.
     *
     * @param completions source vertex name -> completed task indices; may be {@code null}
     */
    public void onVertexStarted(Map<String, List<Integer>> completions) {
        this.pendingTasks = new ArrayList<Integer>(this.context.getVertexNumTasks(this.context.getVertexName()));
        this.updatePendingTasks();
        this.updateSourceTaskCount();
        LOG.info("OnVertexStarted vertex: " + this.context.getVertexName() + " with " + this.numSourceTasks + " source tasks and " + this.totalTasksToSchedule + " pending tasks");
        if (completions != null) {
            for (Map.Entry<String, List<Integer>> entry : completions.entrySet()) {
                for (Integer taskId : entry.getValue()) {
                    this.onSourceTaskCompleted(entry.getKey(), taskId);
                }
            }
        }
        this.schedulePendingTasks();
    }

    /**
     * Records a source task completion and re-evaluates scheduling.
     *
     * <p>Completions from vertices that are not SCATTER_GATHER sources are ignored. A repeated
     * completion of the same task is counted only once.
     */
    public void onSourceTaskCompleted(String srcVertexName, Integer srcTaskId) {
        this.updateSourceTaskCount();
        Set<Integer> completedSourceTasks = this.bipartiteSources.get(srcVertexName);
        if (completedSourceTasks != null) {
            if (completedSourceTasks.add(srcTaskId)) {
                ++this.numSourceTasksCompleted;
            }
            this.schedulePendingTasks();
        }
    }

    /**
     * Accumulates the output size reported by a source task.
     *
     * <p>The event is only used when auto-parallelism is enabled; otherwise it is ignored.
     *
     * @throws TezUncheckedException if the payload is not a valid
     *     {@code VertexManagerEventPayloadProto}
     */
    public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
        if (!this.enableAutoParallelism) {
            return;
        }
        ShuffleUserPayloads.VertexManagerEventPayloadProto proto;
        try {
            proto = ShuffleUserPayloads.VertexManagerEventPayloadProto.parseFrom(vmEvent.getUserPayload());
        }
        catch (InvalidProtocolBufferException e) {
            throw new TezUncheckedException(e);
        }
        long sourceTaskOutputSize = proto.getOutputSize();
        ++this.numVertexManagerEventsReceived;
        this.completedSourceTasksOutputSize += sourceTaskOutputSize;
        if (LOG.isDebugEnabled()) {
            LOG.debug("Received info of output size: " + sourceTaskOutputSize + " numInfoReceived: " + this.numVertexManagerEventsReceived + " total output size: " + this.completedSourceTasksOutputSize);
        }
    }

    /** Resets {@link #pendingTasks} to every task index of this vertex and updates the total. */
    void updatePendingTasks() {
        this.pendingTasks.clear();
        int numTasks = this.context.getVertexNumTasks(this.context.getVertexName());
        for (int i = 0; i < numTasks; ++i) {
            this.pendingTasks.add(Integer.valueOf(i));
        }
        this.totalTasksToSchedule = this.pendingTasks.size();
    }

    /** Recomputes {@link #numSourceTasks} from the current task counts of all shuffle sources. */
    void updateSourceTaskCount() {
        int numSrcTasks = 0;
        for (String vertex : this.bipartiteSources.keySet()) {
            numSrcTasks += this.context.getVertexNumTasks(vertex);
        }
        this.numSourceTasks = numSrcTasks;
    }

    /**
     * Possibly lowers this vertex's parallelism based on the source output observed so far.
     *
     * <p>The total source output is extrapolated from the received output-size reports. It is
     * then divided by {@link #desiredTaskInputDataSize} and bounded below by
     * {@link #minTaskParallelism} (and by 1).
     *
     * <p>If that allows at least two consecutive partitions per task, every SCATTER_GATHER
     * input edge is switched to a {@link CustomShuffleEdgeManager}, and the pending task list is
     * rebuilt for the new task count.
     *
     * <p>Does nothing until at least one source task has completed and one report has arrived.
     */
    void determineParallelismAndApply() {
        if (this.numSourceTasksCompleted == 0 || this.numVertexManagerEventsReceived == 0) {
            return;
        }
        int currentParallelism = this.pendingTasks.size();
        long expectedTotalSourceTasksOutputSize = (long)this.numSourceTasks * this.completedSourceTasksOutputSize / (long)this.numVertexManagerEventsReceived;
        long desiredParallelismEstimate = (expectedTotalSourceTasksOutputSize + this.desiredTaskInputDataSize - 1L) / this.desiredTaskInputDataSize;
        // Saturate rather than truncate. A plain (int) cast of a huge estimate can wrap to a
        // negative value, which would then be raised to minTaskParallelism and wrongly shrink
        // the vertex.
        int desiredTaskParallelism = (int)Math.min(desiredParallelismEstimate, (long)Integer.MAX_VALUE);
        // Never go below 1. A value of 0 (zero observed output with minTaskParallelism <= 0)
        // would make the partition-range division below throw ArithmeticException.
        desiredTaskParallelism = Math.max(desiredTaskParallelism, Math.max(this.minTaskParallelism, 1));
        if (desiredTaskParallelism >= currentParallelism) {
            return;
        }
        int basePartitionRange = currentParallelism / desiredTaskParallelism;
        if (basePartitionRange <= 1) {
            // Grouping fewer than two partitions per task would not reduce the task count.
            return;
        }
        int numShufflersWithBaseRange = currentParallelism / basePartitionRange;
        int remainderRangeForLastShuffler = currentParallelism % basePartitionRange;
        int finalTaskParallelism = remainderRangeForLastShuffler > 0 ? numShufflersWithBaseRange + 1 : numShufflersWithBaseRange;
        LOG.info("Reduce auto parallelism for vertex: " + this.context.getVertexName() + " to " + finalTaskParallelism + " from " + this.pendingTasks.size() + " . Expected output: " + expectedTotalSourceTasksOutputSize + " based on actual output: " + this.completedSourceTasksOutputSize + " from " + this.numVertexManagerEventsReceived + " vertex manager events. " + " desiredTaskInputSize: " + this.desiredTaskInputDataSize);
        if (finalTaskParallelism < currentParallelism) {
            // Every shuffle input edge gets the same partition-grouping configuration.
            CustomShuffleEdgeManagerConfig edgeManagerConfig = new CustomShuffleEdgeManagerConfig(currentParallelism, finalTaskParallelism, basePartitionRange, remainderRangeForLastShuffler > 0 ? remainderRangeForLastShuffler : basePartitionRange);
            HashMap<String, EdgeManagerDescriptor> edgeManagers = new HashMap<String, EdgeManagerDescriptor>(this.bipartiteSources.size());
            for (String vertex : this.bipartiteSources.keySet()) {
                EdgeManagerDescriptor edgeManagerDescriptor = new EdgeManagerDescriptor(CustomShuffleEdgeManager.class.getName());
                edgeManagerDescriptor.setUserPayload(edgeManagerConfig.toUserPayload());
                edgeManagers.put(vertex, edgeManagerDescriptor);
            }
            this.context.setVertexParallelism(finalTaskParallelism, null, edgeManagers);
            this.updatePendingTasks();
        }
    }

    /**
     * Schedules up to {@code numTasksToSchedule} pending tasks, lowest index first.
     *
     * <p>On the first call with auto-parallelism enabled, parallelism is determined (and
     * possibly reduced) before any task is taken from the pending list.
     */
    void schedulePendingTasks(int numTasksToSchedule) {
        if (this.enableAutoParallelism && !this.parallelismDetermined) {
            this.parallelismDetermined = true;
            this.determineParallelismAndApply();
        }
        int count = Math.max(0, Math.min(numTasksToSchedule, this.pendingTasks.size()));
        // Take the head of the list in one step. Removing index 0 repeatedly is O(n^2).
        List<Integer> head = this.pendingTasks.subList(0, count);
        ArrayList<Integer> scheduledTasks = new ArrayList<Integer>(head);
        head.clear();
        this.context.scheduleVertexTasks(scheduledTasks);
    }

    /**
     * Applies the slow-start policy.
     *
     * <p>Once every source task has completed, all remaining tasks are scheduled. Otherwise the
     * scheduled share of tasks grows linearly as the completed-source fraction moves from the
     * min to the max fraction (clamped to [0, 1]). Only the tasks not yet scheduled toward that
     * share are released.
     */
    void schedulePendingTasks() {
        int numPendingTasks = this.pendingTasks.size();
        if (numPendingTasks == 0) {
            return;
        }
        if (this.numSourceTasksCompleted == this.numSourceTasks) {
            LOG.info("All source tasks assigned. Ramping up " + numPendingTasks + " remaining tasks for vertex: " + this.context.getVertexName());
            this.schedulePendingTasks(numPendingTasks);
            return;
        }
        float completedSourceTaskFraction = this.numSourceTasks != 0 ? (float)this.numSourceTasksCompleted / (float)this.numSourceTasks : 1.0f;
        float tasksFractionToSchedule = 1.0f;
        float percentRange = this.slowStartMaxSrcCompletionFraction - this.slowStartMinSrcCompletionFraction;
        if (percentRange > 0.0f) {
            tasksFractionToSchedule = (completedSourceTaskFraction - this.slowStartMinSrcCompletionFraction) / percentRange;
        } else if (completedSourceTaskFraction < this.slowStartMinSrcCompletionFraction) {
            // min == max: a step function at the min fraction.
            tasksFractionToSchedule = 0.0f;
        }
        if (tasksFractionToSchedule > 1.0f) {
            tasksFractionToSchedule = 1.0f;
        } else if (tasksFractionToSchedule < 0.0f) {
            tasksFractionToSchedule = 0.0f;
        }
        // Target share minus the tasks already scheduled.
        int numTasksToSchedule = (int)(tasksFractionToSchedule * (float)this.totalTasksToSchedule) - (this.totalTasksToSchedule - numPendingTasks);
        if (numTasksToSchedule > 0) {
            LOG.info("Scheduling " + numTasksToSchedule + " tasks for vertex: " + this.context.getVertexName() + " with totalTasks: " + this.totalTasksToSchedule + ". " + this.numSourceTasksCompleted + " source tasks completed out of " + this.numSourceTasks + ". SourceTaskCompletedFraction: " + completedSourceTaskFraction + " min: " + this.slowStartMinSrcCompletionFraction + " max: " + this.slowStartMaxSrcCompletionFraction);
            this.schedulePendingTasks(numTasksToSchedule);
        }
    }

    /**
     * Reads settings from the plugin user payload and registers the SCATTER_GATHER sources.
     *
     * @throws TezUncheckedException if the payload cannot be parsed or no SCATTER_GATHER input
     *     exists
     * @throws IllegalArgumentException if the slow-start fractions are invalid, or if
     *     auto-parallelism is enabled with a non-positive desired task input size
     */
    public void initialize(VertexManagerPluginContext context) {
        Configuration conf;
        try {
            conf = TezUtils.createConfFromUserPayload(context.getUserPayload());
        }
        catch (IOException e) {
            throw new TezUncheckedException(e);
        }
        this.context = context;
        this.slowStartMinSrcCompletionFraction = conf.getFloat(TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT);
        this.slowStartMaxSrcCompletionFraction = conf.getFloat(TEZ_AM_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, TEZ_AM_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT);
        if (this.slowStartMinSrcCompletionFraction < 0.0f || this.slowStartMaxSrcCompletionFraction < this.slowStartMinSrcCompletionFraction) {
            throw new IllegalArgumentException("Invalid values for slowStartMinSrcCompletionFraction/slowStartMaxSrcCompletionFraction. Min cannot be < 0 and max cannot be < min.");
        }
        this.enableAutoParallelism = conf.getBoolean(TEZ_AM_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL, TEZ_AM_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT);
        this.desiredTaskInputDataSize = conf.getLong(TEZ_AM_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE, TEZ_AM_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT);
        if (this.enableAutoParallelism && this.desiredTaskInputDataSize <= 0L) {
            // Used as a divisor when computing parallelism: 0 would throw, and a negative value
            // would collapse the vertex to minTaskParallelism.
            throw new IllegalArgumentException("Invalid value for desiredTaskInputDataSize: " + this.desiredTaskInputDataSize + ". Must be > 0 when auto parallelism is enabled.");
        }
        this.minTaskParallelism = conf.getInt(TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM, TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM_DEFAULT);
        LOG.info("Shuffle Vertex Manager: settings minFrac:" + this.slowStartMinSrcCompletionFraction + " maxFrac:" + this.slowStartMaxSrcCompletionFraction + " auto:" + this.enableAutoParallelism + " desiredTaskIput:" + this.desiredTaskInputDataSize + " minTasks:" + this.minTaskParallelism);
        Map<String, EdgeProperty> inputs = context.getInputVertexEdgeProperties();
        for (Map.Entry<String, EdgeProperty> entry : inputs.entrySet()) {
            if (entry.getValue().getDataMovementType() != EdgeProperty.DataMovementType.SCATTER_GATHER) {
                continue;
            }
            this.bipartiteSources.put(entry.getKey(), new HashSet<Integer>());
        }
        if (this.bipartiteSources.isEmpty()) {
            throw new TezUncheckedException("Atleast 1 bipartite source should exist");
        }
    }

    public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor, List<Event> events) {
    }

    /**
     * Partition-grouping parameters passed to {@link CustomShuffleEdgeManager} as a protobuf
     * user payload.
     */
    private static class CustomShuffleEdgeManagerConfig {
        int numSourceTaskOutputs;
        int numDestinationTasks;
        int basePartitionRange;
        int remainderRangeForLastShuffler;

        private CustomShuffleEdgeManagerConfig(int numSourceTaskOutputs, int numDestinationTasks, int basePartitionRange, int remainderRangeForLastShuffler) {
            this.numSourceTaskOutputs = numSourceTaskOutputs;
            this.numDestinationTasks = numDestinationTasks;
            this.basePartitionRange = basePartitionRange;
            this.remainderRangeForLastShuffler = remainderRangeForLastShuffler;
        }

        /** Serializes this config as a {@code ShuffleEdgeManagerConfigPayloadProto}. */
        public byte[] toUserPayload() {
            return ShuffleUserPayloads.ShuffleEdgeManagerConfigPayloadProto.newBuilder().setNumSourceTaskOutputs(this.numSourceTaskOutputs).setNumDestinationTasks(this.numDestinationTasks).setBasePartitionRange(this.basePartitionRange).setRemainderRangeForLastShuffler(this.remainderRangeForLastShuffler).build().toByteArray();
        }

        /** Parses a payload produced by {@link #toUserPayload()}. */
        public static CustomShuffleEdgeManagerConfig fromUserPayload(byte[] userPayload) throws InvalidProtocolBufferException {
            ShuffleUserPayloads.ShuffleEdgeManagerConfigPayloadProto proto = ShuffleUserPayloads.ShuffleEdgeManagerConfigPayloadProto.parseFrom(userPayload);
            return new CustomShuffleEdgeManagerConfig(proto.getNumSourceTaskOutputs(), proto.getNumDestinationTasks(), proto.getBasePartitionRange(), proto.getRemainderRangeForLastShuffler());
        }
    }

    /**
     * Edge manager that routes groups of consecutive source partitions to one destination task.
     *
     * <p>Every destination task except the last receives {@code basePartitionRange} partitions.
     * The last receives {@code remainderRangeForLastShuffler}. On a destination task, the inputs
     * from one source task are contiguous: source task {@code s} occupies physical input indices
     * {@code [s * range, (s + 1) * range)}, where {@code range} is that task's partition range.
     */
    public static class CustomShuffleEdgeManager
    implements EdgeManager {
        int numSourceTaskOutputs;
        int numDestinationTasks;
        int basePartitionRange;
        int remainderRangeForLastShuffler;

        /**
         * Loads the partition-grouping parameters from the edge manager user payload.
         *
         * @throws RuntimeException if the payload is missing, empty, or not parseable
         */
        public void initialize(EdgeManagerContext edgeManagerContext) {
            CustomShuffleEdgeManagerConfig config;
            byte[] userPayload = edgeManagerContext.getUserPayload();
            if (userPayload == null || userPayload.length == 0) {
                throw new RuntimeException("Could not initialize CustomShuffleEdgeManager from provided user payload");
            }
            try {
                config = CustomShuffleEdgeManagerConfig.fromUserPayload(userPayload);
            }
            catch (InvalidProtocolBufferException e) {
                throw new RuntimeException("Could not initialize CustomShuffleEdgeManager from provided user payload", e);
            }
            this.numSourceTaskOutputs = config.numSourceTaskOutputs;
            this.numDestinationTasks = config.numDestinationTasks;
            this.basePartitionRange = config.basePartitionRange;
            this.remainderRangeForLastShuffler = config.remainderRangeForLastShuffler;
        }

        /** Returns numSourceTasks times the partition range of the given destination task. */
        public int getNumDestinationTaskPhysicalInputs(int numSourceTasks, int destinationTaskIndex) {
            int partitionRange = destinationTaskIndex < this.numDestinationTasks - 1 ? this.basePartitionRange : this.remainderRangeForLastShuffler;
            return numSourceTasks * partitionRange;
        }

        /** Each source task keeps its original number of output partitions. */
        public int getNumSourceTaskPhysicalOutputs(int numDestinationTasks, int sourceTaskIndex) {
            return this.numSourceTaskOutputs;
        }

        /** Routes source partition {@code event.getSourceIndex()} to the task owning its group. */
        public void routeDataMovementEventToDestination(DataMovementEvent event, int sourceTaskIndex, int numDestinationTasks, Map<Integer, List<Integer>> inputIndicesToTaskIndices) {
            int sourceIndex = event.getSourceIndex();
            int destinationTaskIndex = sourceIndex / this.basePartitionRange;
            int partitionRange = destinationTaskIndex < numDestinationTasks - 1 ? this.basePartitionRange : this.remainderRangeForLastShuffler;
            int targetIndex = sourceTaskIndex * partitionRange + sourceIndex % partitionRange;
            inputIndicesToTaskIndices.put(Integer.valueOf(targetIndex), Collections.singletonList(Integer.valueOf(destinationTaskIndex)));
        }

        /**
         * Reports, for every destination task, the physical input indices that came from the
         * failed source task.
         */
        public void routeInputSourceTaskFailedEventToDestination(int sourceTaskIndex, int numDestinationTasks, Map<Integer, List<Integer>> inputIndicesToTaskIndices) {
            if (this.remainderRangeForLastShuffler < this.basePartitionRange) {
                List<Integer> lastTask = Collections.singletonList(Integer.valueOf(numDestinationTasks - 1));
                List<Integer> otherTasks = Lists.newArrayListWithCapacity(numDestinationTasks - 1);
                for (int task = 0; task < numDestinationTasks - 1; ++task) {
                    otherTasks.add(Integer.valueOf(task));
                }
                List<Integer> allTasks = new ArrayList<Integer>(otherTasks);
                allTasks.add(Integer.valueOf(numDestinationTasks - 1));
                int baseStart = sourceTaskIndex * this.basePartitionRange;
                int baseEnd = baseStart + this.basePartitionRange;
                for (int inputIndex = baseStart; inputIndex < baseEnd; ++inputIndex) {
                    inputIndicesToTaskIndices.put(Integer.valueOf(inputIndex), otherTasks);
                }
                int remainderStart = sourceTaskIndex * this.remainderRangeForLastShuffler;
                for (int i = 0; i < this.remainderRangeForLastShuffler; ++i) {
                    int inputIndex = remainderStart + i;
                    // An index inside both ranges belongs to this failed source task on the last
                    // task and on the other tasks alike. Route it to all of them instead of
                    // overwriting the otherTasks entry.
                    boolean sharedWithOtherTasks = inputIndex >= baseStart && inputIndex < baseEnd;
                    inputIndicesToTaskIndices.put(Integer.valueOf(inputIndex), sharedWithOtherTasks ? allTasks : lastTask);
                }
            } else {
                List<Integer> allTasks = Lists.newArrayListWithCapacity(numDestinationTasks);
                for (int task = 0; task < numDestinationTasks; ++task) {
                    allTasks.add(Integer.valueOf(task));
                }
                int startOffset = sourceTaskIndex * this.basePartitionRange;
                for (int i = 0; i < this.basePartitionRange; ++i) {
                    inputIndicesToTaskIndices.put(Integer.valueOf(startOffset + i), allTasks);
                }
            }
        }

        /** Maps a failed physical input index on a destination task back to its source task. */
        public int routeInputErrorEventToSource(InputReadErrorEvent event, int destinationTaskIndex) {
            int partitionRange = destinationTaskIndex < this.numDestinationTasks - 1 ? this.basePartitionRange : this.remainderRangeForLastShuffler;
            return event.getIndex() / partitionRange;
        }

        /** Every destination task consumes output from every source task. */
        public int getNumDestinationConsumerTasks(int sourceTaskIndex, int numDestTasks) {
            return numDestTasks;
        }
    }
}

