/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.OperatorStreamStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.util.Preconditions;

/**
 * {@link OperatorStateRepartitioner} that redistributes operator state among a new number of
 * parallel subtasks, according to each state's {@link OperatorStateHandle.Mode}:
 *
 * <ul>
 *   <li>{@link OperatorStateHandle.Mode#SPLIT_DISTRIBUTE}: for every state name, the partitions
 *       (offsets) of all previous handles are taken in order and split into consecutive runs, one
 *       per new subtask, whose sizes differ by at most one.
 *   <li>{@link OperatorStateHandle.Mode#UNION}: every new subtask receives the state of every
 *       previous handle.
 *   <li>{@link OperatorStateHandle.Mode#BROADCAST}: new subtask {@code i} receives the state of
 *       previous handle {@code i % n}, where {@code n} is the number of previous handles that
 *       contained that state name.
 * </ul>
 *
 * <p>Within one new subtask, all assignments that refer to the same {@link StreamStateHandle} are
 * merged into a single {@link OperatorStateHandle}.
 */
public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepartitioner {

    /** Shared instance; the class holds no instance state. */
    public static final OperatorStateRepartitioner INSTANCE = new RoundRobinOperatorStateRepartitioner();

    /**
     * Repartitions the given operator state handles for a new parallelism.
     *
     * @param previousParallelSubtaskStates handles of the previous subtasks; {@code null} entries
     *     are skipped
     * @param newParallelism number of new subtasks; must be positive
     * @return a list with exactly {@code newParallelism} entries, where entry {@code i} holds the
     *     (possibly empty) list of handles assigned to new subtask {@code i}
     * @throws NullPointerException if {@code previousParallelSubtaskStates} is {@code null}
     * @throws IllegalArgumentException if {@code newParallelism} is not positive
     */
    @Override
    public List<List<OperatorStateHandle>> repartitionState(
            List<OperatorStateHandle> previousParallelSubtaskStates, int newParallelism) {
        Preconditions.checkNotNull(previousParallelSubtaskStates, "previousParallelSubtaskStates");
        Preconditions.checkArgument(newParallelism > 0, "newParallelism must be positive.");

        GroupByStateNameResults nameToStateByMode = groupByStateName(previousParallelSubtaskStates);
        List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList =
                repartition(nameToStateByMode, newParallelism);

        List<List<OperatorStateHandle>> result = new ArrayList<>(mergeMapList.size());
        for (Map<StreamStateHandle, OperatorStateHandle> mergeMap : mergeMapList) {
            result.add(new ArrayList<>(mergeMap.values()));
        }
        return result;
    }

    /**
     * Groups the state of all previous handles by distribution mode and then by state name. Each
     * state name maps to the (stream handle, meta info) pairs of every previous handle containing
     * it, in the order of {@code previousParallelSubtaskStates}.
     */
    private GroupByStateNameResults groupByStateName(List<OperatorStateHandle> previousParallelSubtaskStates) {
        EnumMap<OperatorStateHandle.Mode, Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> nameToStateByMode =
                new EnumMap<>(OperatorStateHandle.Mode.class);
        for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) {
            nameToStateByMode.put(mode, new HashMap<>());
        }

        // State names are map keys, so each name occurs at most once per handle: a per-name list
        // never needs more slots than there are previous handles.
        final int expectedHandlesPerName = previousParallelSubtaskStates.size();

        for (OperatorStateHandle psh : previousParallelSubtaskStates) {
            if (psh == null) {
                continue;
            }
            for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e :
                    psh.getStateNameToPartitionOffsets().entrySet()) {
                OperatorStateHandle.StateMetaInfo metaInfo = e.getValue();
                Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> nameToState =
                        nameToStateByMode.get(metaInfo.getDistributionMode());
                List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> stateLocations =
                        nameToState.computeIfAbsent(e.getKey(), k -> new ArrayList<>(expectedHandlesPerName));
                stateLocations.add(new Tuple2<>(psh.getDelegateStateHandle(), metaInfo));
            }
        }
        return new GroupByStateNameResults(nameToStateByMode);
    }

    /**
     * Builds, for each of the {@code newParallelism} new subtasks, a map from stream handle to the
     * merged operator state handle assigned to that subtask.
     */
    private List<Map<StreamStateHandle, OperatorStateHandle>> repartition(
            GroupByStateNameResults nameToStateByMode, int newParallelism) {
        List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList = new ArrayList<>(newParallelism);
        for (int i = 0; i < newParallelism; ++i) {
            mergeMapList.add(new HashMap<>());
        }

        repartitionSplitState(nameToStateByMode.getByMode(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE), mergeMapList);
        repartitionUnionState(nameToStateByMode.getByMode(OperatorStateHandle.Mode.UNION), mergeMapList);
        repartitionBroadcastState(nameToStateByMode.getByMode(OperatorStateHandle.Mode.BROADCAST), mergeMapList);
        return mergeMapList;
    }

    /**
     * Distributes {@link OperatorStateHandle.Mode#SPLIT_DISTRIBUTE} state. For every state name,
     * its partitions are taken in order and split into {@code mergeMapList.size()} consecutive
     * runs whose sizes differ by at most one. The subtask that receives the first run rotates from
     * one state name to the next, so the subtasks receiving an extra partition also rotate.
     *
     * <p>Each consumed value in {@code nameToState} is replaced by {@code null}.
     */
    private static void repartitionSplitState(
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> nameToState,
            List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList) {
        final int newParallelism = mergeMapList.size();
        final int expectedStatesPerHandle = nameToState.size();
        int startParallelOp = 0;

        for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e :
                nameToState.entrySet()) {
            List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue();

            int totalPartitions = 0;
            for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithOffsets : current) {
                totalPartitions += handleWithOffsets.f1.getOffsets().length;
            }

            int lstIdx = 0;
            int offsetIdx = 0;
            int baseFraction = totalPartitions / newParallelism;
            int remainder = totalPartitions % newParallelism;
            int newStartParallelOp = startParallelOp;

            for (int i = 0; i < newParallelism; ++i) {
                int parallelOpIdx = (i + startParallelOp) % newParallelism;

                // The first `remainder` subtasks (counted from startParallelOp) take one extra
                // partition. The first subtask without an extra one becomes the starting subtask
                // for the next state name.
                int numberOfPartitionsToAssign = baseFraction;
                if (remainder > 0) {
                    ++numberOfPartitionsToAssign;
                    --remainder;
                } else if (remainder == 0) {
                    newStartParallelOp = parallelOpIdx;
                    --remainder;
                }

                // A run may span several consecutive previous handles.
                while (numberOfPartitionsToAssign > 0) {
                    Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithOffsets = current.get(lstIdx);
                    long[] offsets = handleWithOffsets.f1.getOffsets();
                    int remaining = offsets.length - offsetIdx;

                    long[] offs;
                    if (remaining > numberOfPartitionsToAssign) {
                        offs = Arrays.copyOfRange(offsets, offsetIdx, offsetIdx + numberOfPartitionsToAssign);
                        offsetIdx += numberOfPartitionsToAssign;
                    } else {
                        // This handle's remaining offsets are used up; move on to the next handle.
                        offs = Arrays.copyOfRange(offsets, offsetIdx, offsets.length);
                        offsetIdx = 0;
                        ++lstIdx;
                    }
                    numberOfPartitionsToAssign -= remaining;

                    getOrCreateMergedHandle(mergeMapList.get(parallelOpIdx), handleWithOffsets.f0, expectedStatesPerHandle)
                            .getStateNameToPartitionOffsets()
                            .put(e.getKey(), new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
                }
            }
            startParallelOp = newStartParallelOp;
            // The list is not read again; drop the reference to it.
            e.setValue(null);
        }
    }

    /**
     * Distributes {@link OperatorStateHandle.Mode#UNION} state: every new subtask receives, for
     * every union state name, the meta info of every previous handle containing it. The same
     * {@link OperatorStateHandle.StateMetaInfo} instances are shared by all new subtasks.
     */
    private static void repartitionUnionState(
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> unionNameToState,
            List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList) {
        for (Map<StreamStateHandle, OperatorStateHandle> mergeMap : mergeMapList) {
            for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e :
                    unionNameToState.entrySet()) {
                for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithMetaInfo : e.getValue()) {
                    getOrCreateMergedHandle(mergeMap, handleWithMetaInfo.f0, unionNameToState.size())
                            .getStateNameToPartitionOffsets()
                            .put(e.getKey(), handleWithMetaInfo.f1);
                }
            }
        }
    }

    /**
     * Distributes {@link OperatorStateHandle.Mode#BROADCAST} state: new subtask {@code i}
     * receives, for every broadcast state name, the entry of previous handle {@code i % n}, where
     * {@code n} is the number of previous handles containing that name.
     */
    private static void repartitionBroadcastState(
            Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> broadcastNameToState,
            List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList) {
        for (int i = 0; i < mergeMapList.size(); ++i) {
            Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(i);
            for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e :
                    broadcastNameToState.entrySet()) {
                List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> handles = e.getValue();
                // Cycle through the previous handles holding this state name.
                Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithMetaInfo =
                        handles.get(i % handles.size());
                getOrCreateMergedHandle(mergeMap, handleWithMetaInfo.f0, broadcastNameToState.size())
                        .getStateNameToPartitionOffsets()
                        .put(e.getKey(), handleWithMetaInfo.f1);
            }
        }
    }

    /**
     * Returns the handle registered in {@code mergeMap} for {@code streamStateHandle}. If there is
     * none, this creates one with an empty state-name map (pre-sized to
     * {@code expectedStateCount}) and registers it.
     */
    private static OperatorStateHandle getOrCreateMergedHandle(
            Map<StreamStateHandle, OperatorStateHandle> mergeMap,
            StreamStateHandle streamStateHandle,
            int expectedStateCount) {
        OperatorStateHandle operatorStateHandle = mergeMap.get(streamStateHandle);
        if (operatorStateHandle == null) {
            operatorStateHandle = new OperatorStreamStateHandle(
                    new HashMap<String, OperatorStateHandle.StateMetaInfo>(expectedStateCount), streamStateHandle);
            mergeMap.put(streamStateHandle, operatorStateHandle);
        }
        return operatorStateHandle;
    }

    /** State locations grouped by distribution mode and then by state name. */
    private static final class GroupByStateNameResults {
        private final EnumMap<OperatorStateHandle.Mode, Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> byMode;

        GroupByStateNameResults(
                EnumMap<OperatorStateHandle.Mode, Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> byMode) {
            this.byMode = Preconditions.checkNotNull(byMode);
        }

        /** Returns the state-name-to-locations map for {@code mode}. */
        public Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> getByMode(
                OperatorStateHandle.Mode mode) {
            return this.byMode.get(mode);
        }
    }
}

