/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.dag.app.dag.impl;

import com.google.protobuf.InvalidProtocolBufferException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.app.dag.EdgeManager;
import org.apache.tez.dag.app.dag.Vertex;
import org.apache.tez.dag.app.dag.VertexScheduler;
import org.apache.tez.dag.app.dag.impl.Edge;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.events.DataMovementEvent;
import org.apache.tez.runtime.api.events.InputFailedEvent;
import org.apache.tez.runtime.api.events.InputReadErrorEvent;
import org.apache.tez.runtime.api.events.VertexManagerEvent;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Vertex scheduler for a vertex that has at least one scatter-gather ("bipartite") input edge.
 *
 * <p>Tasks of the managed vertex are released gradually ("slow start"). The share released grows
 * as the completed fraction of scatter-gather source tasks moves from
 * {@code tez.am.shuffle-vertex-manager.min-src-fraction} to
 * {@code tez.am.shuffle-vertex-manager.max-src-fraction}.
 *
 * <p>When {@code tez.am.shuffle-vertex-manager.enable.auto-parallel} is set, the output sizes
 * reported through {@link VertexManagerEvent}s are used once, right before the first scheduling
 * round, to estimate the total source output. Parallelism is then reduced so that each task reads
 * about {@code desiredTaskInputDataSize} bytes. The reduction merges contiguous partition ranges
 * via {@link CustomShuffleEdgeManager}.
 */
public class ShuffleVertexManager
implements VertexScheduler {
    private static final Log LOG = LogFactory.getLog(ShuffleVertexManager.class);
    /** Default desired input size per task: 0x6400000 bytes (100 MiB). */
    private static final long DEFAULT_DESIRED_TASK_INPUT_DATA_SIZE = 0x6400000L;

    final Vertex managedVertex;
    // Slow-start window, as fractions of completed scatter-gather source tasks.
    float slowStartMinSrcCompletionFraction;
    float slowStartMaxSrcCompletionFraction;
    // Target bytes of input per task when auto-parallelism is enabled.
    long desiredTaskInputDataSize = DEFAULT_DESIRED_TASK_INPUT_DATA_SIZE;
    int minTaskParallelism = 1;
    boolean enableAutoParallelism = false;
    // Set the first time auto-parallelism is evaluated; it is never re-evaluated afterwards.
    boolean parallelismDetermined = false;
    // Total task count across all scatter-gather source vertices.
    int numSourceTasks = 0;
    int numSourceTasksCompleted = 0;
    int numVertexManagerEventsReceived = 0;
    // Tasks not yet passed to managedVertex.scheduleTasks(); allocated in onVertexStarted().
    ArrayList<TezTaskID> pendingTasks;
    // Task count captured the last time pendingTasks was rebuilt.
    int totalTasksToSchedule = 0;
    HashMap<TezVertexID, Vertex> bipartiteSources = new HashMap<TezVertexID, Vertex>();
    Set<TezTaskID> completedSourceTasks = new HashSet<TezTaskID>();
    // Sum of output sizes reported by the VertexManagerEvents received so far.
    long completedSourceTasksOutputSize = 0L;

    /**
     * Records every input vertex connected through a SCATTER_GATHER edge.
     *
     * @param managedVertex the vertex whose tasks this scheduler releases
     * @throws TezUncheckedException if no scatter-gather input exists
     */
    public ShuffleVertexManager(Vertex managedVertex) {
        this.managedVertex = managedVertex;
        Map<Vertex, Edge> inputs = managedVertex.getInputVertices();
        for (Map.Entry<Vertex, Edge> entry : inputs.entrySet()) {
            if (entry.getValue().getEdgeProperty().getDataMovementType()
                    == EdgeProperty.DataMovementType.SCATTER_GATHER) {
                Vertex source = entry.getKey();
                this.bipartiteSources.put(source.getVertexId(), source);
            }
        }
        if (this.bipartiteSources.isEmpty()) {
            throw new TezUncheckedException("Atleast 1 bipartite source should exist");
        }
    }

    /**
     * Builds the pending task list and replays source completions that happened earlier.
     * It then runs a scheduling round.
     *
     * @param completions source task attempts already completed; may be null
     */
    @Override
    public void onVertexStarted(List<TezTaskAttemptID> completions) {
        this.pendingTasks = new ArrayList<TezTaskID>(this.managedVertex.getTotalTasks());
        this.updatePendingTasks();
        this.updateSourceTaskCount();
        LOG.info("OnVertexStarted vertex: " + this.managedVertex.getVertexId() + " with " + this.numSourceTasks + " source tasks and " + this.totalTasksToSchedule + " pending tasks");
        if (completions != null) {
            for (TezTaskAttemptID srcAttemptId : completions) {
                this.onSourceTaskCompleted(srcAttemptId);
            }
        }
        this.schedulePendingTasks();
    }

    /**
     * Counts a completed source task, once per task, if it belongs to a scatter-gather source.
     * Then it re-evaluates scheduling.
     */
    @Override
    public void onSourceTaskCompleted(TezTaskAttemptID srcAttemptId) {
        // Source vertices may change their task counts, so refresh on every completion.
        this.updateSourceTaskCount();
        TezTaskID srcTaskId = srcAttemptId.getTaskID();
        TezVertexID srcVertexId = srcTaskId.getVertexID();
        if (!this.bipartiteSources.containsKey(srcVertexId)) {
            return;
        }
        // Re-run attempts of the same task must not be double counted.
        if (this.completedSourceTasks.add(srcTaskId)) {
            ++this.numSourceTasksCompleted;
        }
        this.schedulePendingTasks();
    }

    /**
     * Accumulates the output size carried by a vertex manager event.
     * The event is ignored unless auto-parallelism is enabled.
     *
     * @throws TezUncheckedException if the payload cannot be parsed
     */
    @Override
    public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
        if (!this.enableAutoParallelism) {
            return;
        }
        ShuffleUserPayloads.VertexManagerEventPayloadProto proto;
        try {
            proto = ShuffleUserPayloads.VertexManagerEventPayloadProto.parseFrom((byte[]) vmEvent.getUserPayload());
        } catch (InvalidProtocolBufferException e) {
            throw new TezUncheckedException(e);
        }
        long sourceTaskOutputSize = proto.getOutputSize();
        ++this.numVertexManagerEventsReceived;
        this.completedSourceTasksOutputSize += sourceTaskOutputSize;
        if (LOG.isDebugEnabled()) {
            LOG.debug("Received info of output size: " + sourceTaskOutputSize + " numInfoReceived: " + this.numVertexManagerEventsReceived + " total output size: " + this.completedSourceTasksOutputSize);
        }
    }

    /** Resets the pending list to every task of the managed vertex and records its size. */
    void updatePendingTasks() {
        this.pendingTasks.clear();
        this.pendingTasks.addAll(this.managedVertex.getTasks().keySet());
        this.totalTasksToSchedule = this.pendingTasks.size();
    }

    /** Recomputes the total task count over all scatter-gather source vertices. */
    void updateSourceTaskCount() {
        int numSrcTasks = 0;
        for (Vertex vertex : this.bipartiteSources.values()) {
            numSrcTasks += vertex.getTotalTasks();
        }
        this.numSourceTasks = numSrcTasks;
    }

    /**
     * Shrinks the managed vertex's parallelism, if that is worthwhile.
     *
     * <p>The total source output is extrapolated from the sizes reported so far. Parallelism is
     * reduced so each task receives about {@code desiredTaskInputDataSize} bytes, never going
     * below {@code minTaskParallelism}. Adjacent partitions are merged in contiguous ranges of
     * {@code basePartitionRange}; the last task takes whatever remains.
     */
    void determineParallelismAndApply() {
        if (this.numSourceTasksCompleted == 0 || this.numVertexManagerEventsReceived == 0) {
            // No data to extrapolate from yet.
            return;
        }
        int currentParallelism = this.pendingTasks.size();
        long expectedTotalSourceTasksOutputSize = (long) this.numSourceTasks * this.completedSourceTasksOutputSize / (long) this.numVertexManagerEventsReceived;
        // Ceiling division kept in long so a very large estimate cannot wrap when narrowed to int.
        long desiredTaskParallelism = (expectedTotalSourceTasksOutputSize + this.desiredTaskInputDataSize - 1L) / this.desiredTaskInputDataSize;
        // Honor the configured minimum, and never drop below 1 (that would divide by zero below).
        desiredTaskParallelism = Math.max(desiredTaskParallelism, Math.max((long) this.minTaskParallelism, 1L));
        if (desiredTaskParallelism >= currentParallelism) {
            return;
        }
        // Safe narrowing: desiredTaskParallelism < currentParallelism here.
        int basePartitionRange = currentParallelism / (int) desiredTaskParallelism;
        if (basePartitionRange <= 1) {
            return;
        }
        int numShufflersWithBaseRange = currentParallelism / basePartitionRange;
        int remainderRangeForLastShuffler = currentParallelism % basePartitionRange;
        int finalTaskParallelism = remainderRangeForLastShuffler > 0 ? numShufflersWithBaseRange + 1 : numShufflersWithBaseRange;
        if (finalTaskParallelism >= currentParallelism) {
            return;
        }
        LOG.info("Reducing parallelism for vertex: " + this.managedVertex.getVertexId() + " to " + finalTaskParallelism + " from " + this.pendingTasks.size() + " . Expected output: " + expectedTotalSourceTasksOutputSize + " based on actual output: " + this.completedSourceTasksOutputSize + " from " + this.numVertexManagerEventsReceived + " vertex manager events. " + " desiredTaskInputSize: " + this.desiredTaskInputDataSize);
        // Range covered by the last destination task (a full base range when the division is exact).
        int lastShufflerRange = remainderRangeForLastShuffler > 0 ? remainderRangeForLastShuffler : basePartitionRange;
        HashMap<Vertex, EdgeManager> edgeManagers = new HashMap<Vertex, EdgeManager>(this.bipartiteSources.size());
        for (Vertex vertex : this.bipartiteSources.values()) {
            edgeManagers.put(vertex, new CustomShuffleEdgeManager(currentParallelism, finalTaskParallelism, basePartitionRange, lastShufflerRange));
        }
        this.managedVertex.setParallelism(finalTaskParallelism, edgeManagers);
        this.updatePendingTasks();
    }

    /**
     * Hands up to {@code numTasksToSchedule} tasks from the head of the pending list to the vertex.
     *
     * <p>On the first call with auto-parallelism enabled, parallelism is evaluated first. This is
     * the latest point at which it can still be changed.
     */
    void schedulePendingTasks(int numTasksToSchedule) {
        if (this.enableAutoParallelism && !this.parallelismDetermined) {
            this.parallelismDetermined = true;
            this.determineParallelismAndApply();
        }
        int count = Math.max(0, Math.min(numTasksToSchedule, this.pendingTasks.size()));
        // Copy the head through a subList view and clear it: O(n) instead of repeated remove(0).
        List<TezTaskID> head = this.pendingTasks.subList(0, count);
        ArrayList<TezTaskID> scheduledTasks = new ArrayList<TezTaskID>(head);
        head.clear();
        this.managedVertex.scheduleTasks(scheduledTasks);
    }

    /**
     * Decides how many pending tasks to release, based on source completion.
     *
     * <p>Once every source task is complete, all remaining tasks are released. Otherwise the
     * released share of {@code totalTasksToSchedule} grows linearly from 0 to 1 as the completed
     * fraction moves across [min, max]. With a zero-width window, the share is all or nothing
     * at min.
     */
    void schedulePendingTasks() {
        int numPendingTasks = this.pendingTasks.size();
        if (numPendingTasks == 0) {
            return;
        }
        if (this.numSourceTasksCompleted == this.numSourceTasks) {
            LOG.info("All source tasks assigned. Ramping up " + numPendingTasks + " remaining tasks for vertex: " + this.managedVertex.getName());
            this.schedulePendingTasks(numPendingTasks);
            return;
        }
        float completedSourceTaskFraction = this.numSourceTasks != 0 ? (float) this.numSourceTasksCompleted / (float) this.numSourceTasks : 1.0f;
        float tasksFractionToSchedule = 1.0f;
        float percentRange = this.slowStartMaxSrcCompletionFraction - this.slowStartMinSrcCompletionFraction;
        if (percentRange > 0.0f) {
            tasksFractionToSchedule = (completedSourceTaskFraction - this.slowStartMinSrcCompletionFraction) / percentRange;
        } else if (completedSourceTaskFraction < this.slowStartMinSrcCompletionFraction) {
            tasksFractionToSchedule = 0.0f;
        }
        tasksFractionToSchedule = Math.max(0.0f, Math.min(1.0f, tasksFractionToSchedule));
        // Target cumulative count minus the tasks already scheduled.
        int numTasksToSchedule = (int) (tasksFractionToSchedule * (float) this.totalTasksToSchedule) - (this.totalTasksToSchedule - numPendingTasks);
        if (numTasksToSchedule > 0) {
            LOG.info("Scheduling " + numTasksToSchedule + " tasks for vertex: " + this.managedVertex.getVertexId() + " with totalTasks: " + this.totalTasksToSchedule + ". " + this.numSourceTasksCompleted + " source tasks completed out of " + this.numSourceTasks + ". SourceTaskCompletedFraction: " + completedSourceTaskFraction + " min: " + this.slowStartMinSrcCompletionFraction + " max: " + this.slowStartMaxSrcCompletionFraction);
            this.schedulePendingTasks(numTasksToSchedule);
        }
    }

    /**
     * Reads the slow-start and auto-parallelism settings from {@code conf}.
     *
     * @throws IllegalArgumentException if min fraction &lt; 0, max &lt; min, or (with
     *     auto-parallelism enabled) the desired task input size is not positive
     */
    @Override
    public void initialize(Configuration conf) {
        this.slowStartMinSrcCompletionFraction = conf.getFloat("tez.am.shuffle-vertex-manager.min-src-fraction", 0.25f);
        this.slowStartMaxSrcCompletionFraction = conf.getFloat("tez.am.shuffle-vertex-manager.max-src-fraction", 0.75f);
        if (this.slowStartMinSrcCompletionFraction < 0.0f || this.slowStartMaxSrcCompletionFraction < this.slowStartMinSrcCompletionFraction) {
            throw new IllegalArgumentException("Invalid values for slowStartMinSrcCompletionFraction/slowStartMaxSrcCompletionFraction. Min cannot be < 0 and max cannot be < min.");
        }
        this.enableAutoParallelism = conf.getBoolean("tez.am.shuffle-vertex-manager.enable.auto-parallel", false);
        this.desiredTaskInputDataSize = conf.getLong("tez.am.shuffle-vertex-manager.desired-task-input-size", DEFAULT_DESIRED_TASK_INPUT_DATA_SIZE);
        this.minTaskParallelism = conf.getInt("tez.am.shuffle-vertex-manager.min-task-parallelism", 1);
        // The size is a divisor in determineParallelismAndApply(), so it must be positive when used.
        if (this.enableAutoParallelism && this.desiredTaskInputDataSize <= 0L) {
            throw new IllegalArgumentException("Invalid value for desiredTaskInputDataSize: " + this.desiredTaskInputDataSize + ". It must be > 0 when auto-parallelism is enabled.");
        }
    }

    /** Intentionally a no-op for this scheduler. */
    @Override
    public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor, List<Event> events) {
    }

    /**
     * Edge manager that merges contiguous source partitions into fewer destination tasks.
     *
     * <p>Destination task {@code d} consumes partitions
     * {@code [d * basePartitionRange, d * basePartitionRange + range(d))}. Here {@code range(d)}
     * is {@code basePartitionRange}, except for the last task, which uses
     * {@code remainderRangeForLastShuffler}. Inputs from the same source task sit next to each
     * other: input index = {@code sourceTaskIndex * range(d) + offsetWithinRange}.
     */
    public class CustomShuffleEdgeManager
    extends EdgeManager {
        int numSourceTaskOutputs;
        int numDestinationTasks;
        int basePartitionRange;
        int remainderRangeForLastShuffler;

        CustomShuffleEdgeManager(int numSourceTaskOutputs, int numDestinationTasks, int basePartitionRange, int remainderPartitionForLastShuffler) {
            this.numSourceTaskOutputs = numSourceTaskOutputs;
            this.numDestinationTasks = numDestinationTasks;
            this.basePartitionRange = basePartitionRange;
            this.remainderRangeForLastShuffler = remainderPartitionForLastShuffler;
        }

        /** Number of source partitions merged into the given destination task. */
        private int partitionRangeForDestination(int destinationTaskIndex) {
            return destinationTaskIndex < this.numDestinationTasks - 1 ? this.basePartitionRange : this.remainderRangeForLastShuffler;
        }

        /** Input index at the destination for partition {@code sourceIndex} of a source task. */
        private int targetIndexForPartition(int sourceIndex, int sourceTaskIndex, int destinationTaskIndex) {
            // Stride by the destination's own range so the last (smaller) task stays within bounds.
            int offsetWithinRange = sourceIndex % this.basePartitionRange;
            return sourceTaskIndex * this.partitionRangeForDestination(destinationTaskIndex) + offsetWithinRange;
        }

        @Override
        public int getNumDestinationTaskInputs(int numSourceTasks, int destinationTaskIndex) {
            return numSourceTasks * this.partitionRangeForDestination(destinationTaskIndex);
        }

        @Override
        public int getNumSourceTaskOutputs(int numDestinationTasks, int sourceTaskIndex) {
            return this.numSourceTaskOutputs;
        }

        @Override
        public void routeEventToDestinationTasks(DataMovementEvent event, int sourceTaskIndex, int numDestinationTasks, List<Integer> taskIndices) {
            int sourceIndex = event.getSourceIndex();
            int destinationTaskIndex = sourceIndex / this.basePartitionRange;
            event.setTargetIndex(this.targetIndexForPartition(sourceIndex, sourceTaskIndex, destinationTaskIndex));
            taskIndices.add(Integer.valueOf(destinationTaskIndex));
        }

        @Override
        public void routeEventToDestinationTasks(InputFailedEvent event, int sourceTaskIndex, int numDestinationTasks, List<Integer> taskIndices) {
            int sourceIndex = event.getSourceIndex();
            int destinationTaskIndex = sourceIndex / this.basePartitionRange;
            event.setTargetIndex(this.targetIndexForPartition(sourceIndex, sourceTaskIndex, destinationTaskIndex));
            taskIndices.add(Integer.valueOf(destinationTaskIndex));
        }

        /** Maps a destination input index back to the index of the source task that produced it. */
        @Override
        public int routeEventToSourceTasks(int destinationTaskIndex, InputReadErrorEvent event) {
            return event.getIndex() / this.partitionRangeForDestination(destinationTaskIndex);
        }

        @Override
        public int getDestinationConsumerTaskNumber(int sourceTaskIndex, int numDestTasks) {
            return numDestTasks;
        }
    }
}

