/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint.channel;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
import org.apache.flink.runtime.checkpoint.RescaleMappings;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateByteBuffer;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.RecoveredChannelStateHandler;
import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.io.network.partition.consumer.RecoveredInputChannel;

/**
 * {@link RecoveredChannelStateHandler} that restores recovered in-flight data into the {@link
 * RecoveredInputChannel}s of a fixed set of {@link InputGate}s.
 *
 * <p>Every recovered buffer is addressed by the {@link InputChannelInfo} of an <em>old</em> input
 * channel. For each gate, the channel mapping obtained from the {@link
 * InflightDataRescalingDescriptor} is inverted and used to look up the new channel indexes that
 * the old channel maps to. Each of those channels receives the buffer.
 *
 * <p>Both the inverted mappings (keyed by gate index) and the resolved channel lists (keyed by old
 * channel) are computed lazily and cached in plain {@link HashMap}s. No synchronization is
 * performed here.
 */
class InputChannelRecoveredStateHandler
        implements RecoveredChannelStateHandler<InputChannelInfo, Buffer> {

    /** Gates whose channels receive the recovered state; indexed by gate index. */
    private final InputGate[] inputGates;

    /** Source of the per-gate channel mappings (inverted before use). */
    private final InflightDataRescalingDescriptor channelMapping;

    /** Cache: old channel -> recovered channels it is mapped to. */
    private final Map<InputChannelInfo, List<RecoveredInputChannel>> rescaledChannels =
            new HashMap<>();

    /** Cache: gate index -> inverted channel mapping of that gate. */
    private final Map<Integer, RescaleMappings> oldToNewMappings = new HashMap<>();

    /**
     * @param inputGates gates to restore into; {@link #close()} finishes every one of them
     * @param channelMapping descriptor supplying the per-gate channel mappings
     */
    InputChannelRecoveredStateHandler(
            InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping) {
        this.inputGates = inputGates;
        this.channelMapping = channelMapping;
    }

    /**
     * Requests a buffer, via {@code requestBufferBlocking()}, from the first channel that the given
     * old channel is mapped to. The call may block.
     *
     * @param channelInfo the old channel the recovered data belongs to
     * @return the requested buffer, wrapped as a {@link ChannelStateByteBuffer} and also carried as
     *     the context
     * @throws IllegalStateException if the old channel has no mapping, or maps to a channel that is
     *     not a {@link RecoveredInputChannel}
     */
    @Override
    public RecoveredChannelStateHandler.BufferWithContext<Buffer> getBuffer(
            InputChannelInfo channelInfo) throws IOException, InterruptedException {
        // Any mapped channel would do. Presumably the caller later passes this buffer to
        // recover(), which hands the same buffer to every mapped channel.
        final RecoveredInputChannel firstTarget = getMappedChannels(channelInfo).get(0);
        final Buffer requested = firstTarget.requestBufferBlocking();
        return new RecoveredChannelStateHandler.BufferWithContext<>(
                ChannelStateByteBuffer.wrap(requested), requested);
    }

    /**
     * Forwards a recovered buffer to every channel the old channel is mapped to.
     *
     * <p>Each mapped channel first receives a serialized {@link SubtaskConnectionDescriptor} built
     * from {@code oldSubtaskIndex} and the old channel index. It then receives its own retained
     * reference to {@code buffer}. Buffers without readable bytes are not forwarded.
     *
     * <p>The reference held by the caller is always released via {@code recycleBuffer()}, on every
     * path including exceptions.
     *
     * @param channelInfo the old channel the data belongs to
     * @param oldSubtaskIndex index of the subtask the data was recovered from
     * @param buffer recovered data; ownership of this reference passes to this method
     */
    @Override
    public void recover(InputChannelInfo channelInfo, int oldSubtaskIndex, Buffer buffer)
            throws IOException {
        try {
            if (buffer.readableBytes() <= 0) {
                // Nothing to forward. The finally block still releases the caller's reference.
                return;
            }
            final List<RecoveredInputChannel> targets = getMappedChannels(channelInfo);
            for (RecoveredInputChannel target : targets) {
                final Buffer connectionEvent =
                        EventSerializer.toBuffer(
                                new SubtaskConnectionDescriptor(
                                        oldSubtaskIndex, channelInfo.getInputChannelIdx()),
                                false);
                target.onRecoveredStateBuffer(connectionEvent);
                // Each target gets its own reference. The caller's reference is dropped below.
                target.onRecoveredStateBuffer(buffer.retainBuffer());
            }
        } finally {
            buffer.recycleBuffer();
        }
    }

    /**
     * Calls {@code finishReadRecoveredState()} on every input gate, including gates for which no
     * buffer was recovered.
     */
    @Override
    public void close() throws IOException {
        for (InputGate gate : inputGates) {
            gate.finishReadRecoveredState();
        }
    }

    /**
     * Returns the channel at {@code subPartitionIndex} of gate {@code gateIndex}.
     *
     * @throws IllegalStateException if that channel is not a {@link RecoveredInputChannel}
     */
    private RecoveredInputChannel getChannel(int gateIndex, int subPartitionIndex) {
        final InputChannel candidate = inputGates[gateIndex].getChannel(subPartitionIndex);
        if (candidate instanceof RecoveredInputChannel) {
            return (RecoveredInputChannel) candidate;
        }
        throw new IllegalStateException(
                "Cannot restore state to a non-recovered input channel: " + candidate);
    }

    /** Returns the (cached) list of recovered channels that the old channel is mapped to. */
    private List<RecoveredInputChannel> getMappedChannels(InputChannelInfo channelInfo) {
        return rescaledChannels.computeIfAbsent(channelInfo, oldChannel -> calculateMapping(oldChannel));
    }

    /**
     * Resolves the recovered channels that an old channel maps to.
     *
     * <p>The gate's inverted channel mapping is cached per gate index.
     *
     * @throws IllegalStateException if the old channel maps to no new channel
     */
    private List<RecoveredInputChannel> calculateMapping(InputChannelInfo info) {
        final int gateIdx = info.getGateIdx();
        final RescaleMappings oldToNewMapping =
                oldToNewMappings.computeIfAbsent(
                        gateIdx, idx -> channelMapping.getChannelMapping(idx).invert());
        final int[] newChannelIndexes = oldToNewMapping.getMappedIndexes(info.getInputChannelIdx());
        if (newChannelIndexes.length == 0) {
            throw new IllegalStateException(
                    "Recovered a buffer from old "
                            + info
                            + " that has no mapping in "
                            + channelMapping.getChannelMapping(gateIdx));
        }
        return Arrays.stream(newChannelIndexes)
                .mapToObj(newChannelIndex -> getChannel(gateIdx, newChannelIndex))
                .collect(Collectors.toList());
    }
}

