/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.fn.harness.state;

import com.google.auto.value.AutoValue;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import javax.annotation.Nullable;
import org.apache.beam.fn.harness.Cache;
import org.apache.beam.fn.harness.Caches;
import org.apache.beam.fn.harness.state.AutoValue_StateFetchingIterators_CachingStateIterable_Block;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.stream.DataStreams;
import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.sdk.util.Weighted;
import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Throwables;

public class StateFetchingIterators {
    private StateFetchingIterators() {
        // Not instantiable: this class only hosts a static factory and nested types.
    }

    /**
     * Returns an iterable over the values of the state addressed by {@code stateRequestForFirstChunk},
     * decoding them with {@code valueCoder} and caching fetched blocks in {@code cache}.
     *
     * <p>The cache is accepted as {@code Cache<?, ?>} and is cast (unchecked) to a cache from state
     * keys to cached blocks. The cast assumes that entries stored under these state keys are only
     * ever written by {@link CachingStateIterable}. NOTE(review): confirm no other user of the shared
     * cache stores different value types under the same keys.
     *
     * @param cache cache shared across readers, keyed by state key
     * @param beamFnStateClient client used to fetch chunks missing from the cache
     * @param stateRequestForFirstChunk request for the first chunk of the state
     * @param valueCoder coder used to decode individual values
     * @return a caching, prefetchable iterable over the decoded values
     */
    @SuppressWarnings("unchecked")
    public static <T> CachingStateIterable<T> readAllAndDecodeStartingFrom(Cache<?, ?> cache, BeamFnStateClient beamFnStateClient, BeamFnApi.StateRequest stateRequestForFirstChunk, Coder<T> valueCoder) {
        Cache<BeamFnApi.StateKey, CachingStateIterable.Blocks<T>> typedCache =
            (Cache<BeamFnApi.StateKey, CachingStateIterable.Blocks<T>>) cache;
        return new CachingStateIterable<T>(typedCache, beamFnStateClient, stateRequestForFirstChunk, valueCoder);
    }

    @VisibleForTesting
    static class LazyBlockingStateFetchingIterator
    implements PrefetchableIterator<ByteString> {
        private final BeamFnStateClient beamFnStateClient;
        private final BeamFnApi.StateRequest stateRequestForFirstChunk;
        private ByteString continuationToken;
        private CompletableFuture<BeamFnApi.StateResponse> prefetchedResponse;

        LazyBlockingStateFetchingIterator(BeamFnStateClient beamFnStateClient, BeamFnApi.StateRequest stateRequestForFirstChunk) {
            this.beamFnStateClient = beamFnStateClient;
            this.stateRequestForFirstChunk = stateRequestForFirstChunk;
            this.continuationToken = stateRequestForFirstChunk.getGet().getContinuationToken();
        }

        @Nullable
        public ByteString getContinuationToken() {
            return this.continuationToken;
        }

        public void seekToContinuationToken(@Nullable ByteString continuationToken) {
            if (Objects.equals(this.continuationToken, continuationToken)) {
                return;
            }
            this.continuationToken = continuationToken;
            this.prefetchedResponse = null;
        }

        @Override
        public boolean isReady() {
            if (this.prefetchedResponse == null) {
                return this.continuationToken == null;
            }
            return this.prefetchedResponse.isDone();
        }

        @Override
        public void prefetch() {
            if (this.continuationToken != null && this.prefetchedResponse == null) {
                this.prefetchedResponse = this.loadPrefetchedResponse(this.continuationToken);
            }
        }

        public CompletableFuture<BeamFnApi.StateResponse> loadPrefetchedResponse(ByteString continuationToken) {
            return this.beamFnStateClient.handle(this.stateRequestForFirstChunk.toBuilder().setGet(BeamFnApi.StateGetRequest.newBuilder().setContinuationToken(continuationToken)));
        }

        @Override
        public boolean hasNext() {
            return this.continuationToken != null;
        }

        @Override
        public ByteString next() {
            BeamFnApi.StateResponse stateResponse;
            if (!this.hasNext()) {
                throw new NoSuchElementException();
            }
            this.prefetch();
            try {
                stateResponse = this.prefetchedResponse.get();
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
            catch (ExecutionException e) {
                if (e.getCause() == null) {
                    throw new IllegalStateException(e);
                }
                Throwables.throwIfUnchecked(e.getCause());
                throw new IllegalStateException(e.getCause());
            }
            this.prefetchedResponse = null;
            if (ByteString.EMPTY.equals(stateResponse.getGet().getContinuationToken())) {
                this.continuationToken = null;
            } else {
                this.continuationToken = stateResponse.getGet().getContinuationToken();
                this.prefetch();
            }
            return stateResponse.getGet().getData();
        }
    }

    static class CachingStateIterable<T>
    extends PrefetchableIterables.Default<T> {
        private final Cache<BeamFnApi.StateKey, Blocks<T>> cache;
        private final BeamFnStateClient beamFnStateClient;
        private final BeamFnApi.StateRequest stateRequestForFirstChunk;
        private final Coder<T> valueCoder;

        /**
         * Sums the weights of {@code blocks}, saturating at {@link Long#MAX_VALUE} if the exact sum
         * overflows.
         */
        private static <T> long sumWeight(List<Block<T>> blocks) {
            long total = 0L;
            try {
                for (Block<T> block : blocks) {
                    total = Math.addExact(total, block.getWeight());
                }
            } catch (ArithmeticException overflow) {
                // Saturate instead of wrapping around.
                return Long.MAX_VALUE;
            }
            return total;
        }

        /** Returns {@code first + second}, or {@link Long#MAX_VALUE} if the exact sum overflows. */
        private static long addBoundByMax(long first, long second) {
            long sum = first + second;
            // Overflow occurred iff both operands share a sign that the result does not
            // (the same test Math.addExact uses); any overflow saturates to MAX_VALUE.
            boolean overflowed = ((first ^ sum) & (second ^ sum)) < 0;
            return overflowed ? Long.MAX_VALUE : sum;
        }

        /**
         * Creates an iterable over the values of the state addressed by {@code stateRequestForFirstChunk}.
         *
         * @param cache cache holding previously fetched or locally mutated blocks; entries are keyed by
         *     the state key of {@code stateRequestForFirstChunk}
         * @param beamFnStateClient client used to fetch chunks that are not cached
         * @param stateRequestForFirstChunk request for the first chunk of the state
         * @param valueCoder coder used to decode values and to compute their structural values
         */
        public CachingStateIterable(Cache<BeamFnApi.StateKey, Blocks<T>> cache, BeamFnStateClient beamFnStateClient, BeamFnApi.StateRequest stateRequestForFirstChunk, Coder<T> valueCoder) {
            this.valueCoder = valueCoder;
            this.stateRequestForFirstChunk = stateRequestForFirstChunk;
            this.beamFnStateClient = beamFnStateClient;
            this.cache = cache;
        }

        /**
         * Removes values from the cached copy of this iterable, if one exists.
         *
         * <p>Each cached value is compared by {@code valueCoder.structuralValue(value)}, so
         * {@code toRemoveStructuralValues} must contain structural values, not the original objects.
         *
         * <p>If nothing is cached, there is nothing to update. If only a prefix is cached (the last
         * cached block still has a next token), the entry is evicted instead. The values that were
         * never loaded are unknown here, so any rewritten entry would present a truncated iterable as
         * if it were complete.
         *
         * @param toRemoveStructuralValues structural values of the elements to drop
         */
        public void remove(Set<Object> toRemoveStructuralValues) {
            if (toRemoveStructuralValues.isEmpty()) {
                return;
            }
            BeamFnApi.StateKey cacheKey = this.stateRequestForFirstChunk.getStateKey();
            Blocks<T> existing = this.cache.peek(cacheKey);
            if (existing == null) {
                return;
            }
            List<Block<T>> blocks = existing.getBlocks();
            if (blocks.get(blocks.size() - 1).getNextToken() != null) {
                // Only a prefix is cached. Evict it and stop: re-caching the prefix as a single
                // mutated block (which has no next token) would hide all unloaded values.
                this.cache.remove(cacheKey);
                return;
            }
            int totalSize = 0;
            for (Block<T> block : blocks) {
                totalSize += block.getValues().size();
            }
            List<T> allValues = new ArrayList<T>(totalSize);
            long totalWeight = 0L;
            for (Block<T> block : blocks) {
                int startIndex = allValues.size();
                for (T value : block.getValues()) {
                    if (!toRemoveStructuralValues.contains(this.valueCoder.structuralValue(value))) {
                        allValues.add(value);
                    }
                }
                if (allValues.size() - startIndex == block.getValues().size()) {
                    // Nothing was removed from this block, so its precomputed weight still applies.
                    totalWeight = addBoundByMax(totalWeight, block.getWeight());
                } else {
                    for (int j = startIndex; j < allValues.size(); ++j) {
                        totalWeight = addBoundByMax(totalWeight, Caches.weigh(allValues.get(j)));
                    }
                }
            }
            this.cache.put(cacheKey, new MutatedBlocks<T>(Block.mutatedBlock(allValues, totalWeight)));
        }

        /**
         * Replaces the cached copy of this iterable with a copy of {@code values}, stored as a single
         * mutated block.
         */
        public void clearAndAppend(List<T> values) {
            BeamFnApi.StateKey cacheKey = this.stateRequestForFirstChunk.getStateKey();
            // Copy so that later changes to the caller's list do not leak into the cache.
            List<T> snapshot = new ArrayList<T>(values);
            long weight = Caches.weigh(values);
            Block<T> onlyBlock = Block.mutatedBlock(snapshot, weight);
            this.cache.put(cacheKey, new MutatedBlocks<T>(onlyBlock));
        }

        /** Returns a fresh iterator positioned before the first value of this iterable. */
        @Override
        public PrefetchableIterator<T> createIterator() {
            CachingStateIterator iterator = new CachingStateIterator();
            return iterator;
        }

        /**
         * Appends {@code values} to the cached copy of this iterable, if one exists.
         *
         * <p>If nothing is cached, there is nothing to update. If only a prefix is cached (the last
         * cached block still has a next token), the entry is evicted instead. The appended values
         * belong after values this cache has never loaded, so they cannot be placed correctly.
         *
         * @param values values to append, in order
         */
        public void append(List<T> values) {
            if (values.isEmpty()) {
                return;
            }
            BeamFnApi.StateKey cacheKey = this.stateRequestForFirstChunk.getStateKey();
            Blocks<T> existing = this.cache.peek(cacheKey);
            if (existing == null) {
                return;
            }
            List<Block<T>> blocks = existing.getBlocks();
            if (blocks.get(blocks.size() - 1).getNextToken() != null) {
                // Only a prefix is cached. Evict it and stop: caching prefix + appended values as a
                // single mutated block (which has no next token) would drop all unloaded values.
                this.cache.remove(cacheKey);
                return;
            }
            long totalWeight = addBoundByMax(Caches.weigh(values), sumWeight(blocks));
            int totalSize = values.size();
            for (Block<T> block : blocks) {
                totalSize += block.getValues().size();
            }
            List<T> allValues = new ArrayList<T>(totalSize);
            for (Block<T> block : blocks) {
                allValues.addAll(block.getValues());
            }
            allValues.addAll(values);
            this.cache.put(cacheKey, new MutatedBlocks<T>(Block.mutatedBlock(allValues, totalWeight)));
        }

        /**
         * Iterator that serves values from cached blocks when they are available and otherwise
         * fetches them chunk by chunk. Blocks loaded while the cache holds a prefix ending at the
         * current block are appended to that cached prefix; the first block is cached when the cache
         * is empty.
         */
        class CachingStateIterator
        implements PrefetchableIterator<T> {
            private final LazyBlockingStateFetchingIterator underlyingStateFetchingIterator;
            private final DataStreams.DataStreamDecoder<T> dataStreamDecoder;
            // Block currently being served; its next token identifies the block that follows it.
            private Block<T> currentBlock;
            // Index in currentBlock of the next value to return.
            private int currentCachedBlockValueIndex;

            public CachingStateIterator() {
                this.underlyingStateFetchingIterator = new LazyBlockingStateFetchingIterator(CachingStateIterable.this.beamFnStateClient, CachingStateIterable.this.stateRequestForFirstChunk);
                this.dataStreamDecoder = new DataStreams.DataStreamDecoder<T>(CachingStateIterable.this.valueCoder, this.underlyingStateFetchingIterator);
                // Start from an empty placeholder whose next token is the first request's token, so
                // the first hasNext()/isReady() call resolves the first real block.
                this.currentBlock = Block.fromValues(Collections.<T>emptyList(), CachingStateIterable.this.stateRequestForFirstChunk.getGet().getContinuationToken());
                this.currentCachedBlockValueIndex = 0;
            }

            /**
             * Returns true if the next hasNext()/next() can proceed without a remote fetch. That is
             * the case when the current block has values left, the iterable is exhausted, or the
             * following block is already cached. Advances through cached blocks as a side effect.
             */
            @Override
            public boolean isReady() {
                while (true) {
                    if (this.currentBlock.getValues().size() > this.currentCachedBlockValueIndex) {
                        return true;
                    }
                    ByteString nextToken = this.currentBlock.getNextToken();
                    if (nextToken == null) {
                        return true;
                    }
                    Blocks<T> existing = CachingStateIterable.this.cache.peek(CachingStateIterable.this.stateRequestForFirstChunk.getStateKey());
                    if (existing == null) {
                        return false;
                    }
                    if (ByteString.EMPTY.equals(nextToken)) {
                        // An empty token denotes the first block of the iterable.
                        this.currentBlock = existing.getBlocks().get(0);
                    } else {
                        List<Block<T>> blocks = existing.getBlocks();
                        int currentBlockIndex = this.indexOfBlockWithNextToken(blocks, nextToken);
                        if (currentBlockIndex + 1 >= blocks.size()) {
                            return false;
                        }
                        this.currentBlock = blocks.get(currentBlockIndex + 1);
                    }
                    this.currentCachedBlockValueIndex = 0;
                }
            }

            /** Starts fetching the next chunk when the next block is neither loaded nor cached. */
            @Override
            public void prefetch() {
                if (!this.isReady()) {
                    this.underlyingStateFetchingIterator.seekToContinuationToken(this.currentBlock.getNextToken());
                    this.underlyingStateFetchingIterator.prefetch();
                }
            }

            /**
             * Advances to the next non-exhausted block, taking it from the cache or loading it.
             * Loaded blocks are added to the cache when they start the iterable (and nothing is cached
             * yet) or directly extend the cached prefix.
             */
            @Override
            public boolean hasNext() {
                while (this.currentBlock.getValues().size() <= this.currentCachedBlockValueIndex) {
                    ByteString nextToken = this.currentBlock.getNextToken();
                    if (nextToken == null) {
                        return false;
                    }
                    BeamFnApi.StateKey cacheKey = CachingStateIterable.this.stateRequestForFirstChunk.getStateKey();
                    Blocks<T> existing = CachingStateIterable.this.cache.peek(cacheKey);
                    boolean isFirstBlock = ByteString.EMPTY.equals(nextToken);
                    if (existing == null) {
                        this.currentBlock = this.loadNextBlock(nextToken);
                        if (isFirstBlock) {
                            CachingStateIterable.this.cache.put(cacheKey, new BlocksPrefix<T>(Collections.singletonList(this.currentBlock)));
                        }
                    } else if (isFirstBlock) {
                        this.currentBlock = existing.getBlocks().get(0);
                    } else {
                        Preconditions.checkState(existing instanceof BlocksPrefix, "Unexpected blocks type %s, expected a %s.", existing.getClass(), BlocksPrefix.class);
                        List<Block<T>> blocks = existing.getBlocks();
                        int currentBlockIndex = this.indexOfBlockWithNextToken(blocks, nextToken);
                        if (currentBlockIndex + 1 < blocks.size()) {
                            this.currentBlock = blocks.get(currentBlockIndex + 1);
                        } else {
                            this.currentBlock = this.loadNextBlock(nextToken);
                            if (currentBlockIndex == blocks.size() - 1) {
                                // The loaded block directly extends the cached prefix: cache the longer prefix.
                                List<Block<T>> newBlocks = new ArrayList<Block<T>>(blocks.size() + 1);
                                newBlocks.addAll(blocks);
                                newBlocks.add(this.currentBlock);
                                CachingStateIterable.this.cache.put(cacheKey, new BlocksPrefix<T>(newBlocks));
                            }
                        }
                    }
                    this.currentCachedBlockValueIndex = 0;
                }
                return true;
            }

            /**
             * Returns the index of the first block in {@code blocks} whose next token equals
             * {@code token}, or {@code blocks.size()} if there is none.
             */
            private int indexOfBlockWithNextToken(List<Block<T>> blocks, ByteString token) {
                int index = 0;
                while (index < blocks.size() && !token.equals(blocks.get(index).getNextToken())) {
                    ++index;
                }
                return index;
            }

            /**
             * Fetches and decodes the chunk(s) starting at {@code continuationToken} up to the next
             * chunk boundary. The returned block's next token is {@code null} when no further data
             * remains.
             */
            @VisibleForTesting
            Block<T> loadNextBlock(ByteString continuationToken) {
                this.underlyingStateFetchingIterator.seekToContinuationToken(continuationToken);
                List<T> values = this.dataStreamDecoder.decodeFromChunkBoundaryToChunkBoundary();
                ByteString nextToken = this.underlyingStateFetchingIterator.getContinuationToken();
                if (ByteString.EMPTY.equals(nextToken)) {
                    nextToken = null;
                }
                return Block.fromValues(values, nextToken);
            }

            /**
             * Returns the next value.
             *
             * @throws NoSuchElementException if the iterable is exhausted
             */
            @Override
            public T next() {
                if (!this.hasNext()) {
                    throw new NoSuchElementException();
                }
                return this.currentBlock.getValues().get(this.currentCachedBlockValueIndex++);
            }
        }

        /**
         * An immutable run of decoded values, the continuation token of the data that follows it
         * ({@code null} when nothing follows), and its cache weight.
         *
         * <p>Implemented by AutoValue. The generated constructor takes the properties in the order
         * the abstract accessors are declared below, so that order must be kept in sync with the
         * factory methods.
         */
        @AutoValue
        static abstract class Block<T>
        implements Weighted {
            Block() {
            }

            /** Creates a block with no next token and a caller-computed {@code weight}. */
            public static <T> Block<T> mutatedBlock(List<T> values, long weight) {
                ByteString noNextToken = null;
                return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<T>(values, noNextToken, weight);
            }

            /** Creates a block weighted by the combined weight of {@code values} and {@code nextToken}. */
            public static <T> Block<T> fromValues(List<T> values, @Nullable ByteString nextToken) {
                long valuesWeight = Caches.weigh(values);
                long tokenWeight = Caches.weigh(nextToken);
                return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<T>(values, nextToken, valuesWeight + tokenWeight);
            }

            /** The values held by this block, in iteration order. */
            abstract List<T> getValues();

            /** Token for the data after this block, or {@code null} if this block is the last. */
            @Nullable
            abstract ByteString getNextToken();

            @Override
            public abstract long getWeight();
        }

        /**
         * Cache entry holding the first blocks of the iterable, in order. The last block's next token
         * tells whether more data follows. Shrinking keeps only the first half of the blocks.
         */
        static class BlocksPrefix<T>
        extends Blocks<T>
        implements Cache.Shrinkable<BlocksPrefix<T>> {
            private final List<Block<T>> blocks;

            BlocksPrefix(List<Block<T>> blocks) {
                this.blocks = blocks;
            }

            @Override
            public List<Block<T>> getBlocks() {
                return this.blocks;
            }

            /** Sum of the block weights, saturating at {@link Long#MAX_VALUE}. */
            @Override
            public long getWeight() {
                return CachingStateIterable.sumWeight(this.blocks);
            }

            /**
             * Returns a prefix made of the first half of the blocks (rounded down), or {@code null} if
             * that would leave no blocks.
             */
            @Override
            public BlocksPrefix<T> shrink() {
                List<Block<T>> current = this.getBlocks();
                int keep = current.size() / 2;
                if (keep == 0) {
                    return null;
                }
                return new BlocksPrefix<T>(new ArrayList<Block<T>>(current.subList(0, keep)));
            }
        }

        /** Cache entry written after a local mutation; all of its values live in one block. */
        static class MutatedBlocks<T>
        extends Blocks<T> {
            private final Block<T> wholeBlock;

            MutatedBlocks(Block<T> wholeBlock) {
                this.wholeBlock = wholeBlock;
            }

            /** The weight of the single underlying block. */
            @Override
            public long getWeight() {
                return this.wholeBlock.getWeight();
            }

            /** Returns a one-element list containing the single underlying block. */
            @Override
            public List<Block<T>> getBlocks() {
                List<Block<T>> single = Collections.singletonList(this.wholeBlock);
                return single;
            }
        }

        /**
         * Value stored in the cache for a state key: an ordered list of {@link Block}s that is
         * {@link Weighted} for cache accounting (weights are computed by the subclasses).
         */
        static abstract class Blocks<T>
        implements Weighted {
            Blocks() {}

            /** Returns the cached blocks in iteration order. */
            public abstract List<Block<T>> getBlocks();
        }
    }
}

