/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.fn.harness;

import java.io.IOException;
import java.util.Map;
import org.apache.beam.fn.harness.GroupingTable;
import org.apache.beam.fn.harness.MapFnRunners;
import org.apache.beam.fn.harness.PTransformRunnerFactory;
import org.apache.beam.fn.harness.PrecombineGroupingTable;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.RehydratedComponents;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.function.ThrowingFunction;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;

public class CombineRunners {
    /**
     * Builds a function that merges the accumulators grouped under a key into a single
     * accumulator, using the {@link Combine.CombineFn} carried in the transform's payload.
     *
     * @param pTransformId id of the transform; not used by this method
     * @param pTransform transform whose spec payload is parsed as a {@link RunnerApi.CombinePayload}
     * @return a function mapping {@code KV<key, accumulators>} to {@code KV<key, mergedAccumulator>}
     * @throws IOException if the combine payload cannot be parsed
     */
    static <KeyT, AccumT> ThrowingFunction<KV<KeyT, Iterable<AccumT>>, KV<KeyT, AccumT>> createMergeAccumulatorsMapFunction(String pTransformId, RunnerApi.PTransform pTransform) throws IOException {
        RunnerApi.CombinePayload combinePayload = RunnerApi.CombinePayload.parseFrom(pTransform.getSpec().getPayload());
        // The serialized form carries no generic information, so this cast cannot be checked.
        @SuppressWarnings("unchecked")
        Combine.CombineFn<?, AccumT, ?> combineFn =
            (Combine.CombineFn<?, AccumT, ?>) SerializableUtils.deserializeFromByteArray(
                combinePayload.getCombineFn().getPayload().toByteArray(), "CombineFn");
        return input -> KV.of(input.getKey(), combineFn.mergeAccumulators(input.getValue()));
    }

    /**
     * Builds a function that turns the accumulator stored under a key into its final output,
     * using the {@link Combine.CombineFn} carried in the transform's payload.
     *
     * @param pTransformId id of the transform; not used by this method
     * @param pTransform transform whose spec payload is parsed as a {@link RunnerApi.CombinePayload}
     * @return a function mapping {@code KV<key, accumulator>} to {@code KV<key, output>}
     * @throws IOException if the combine payload cannot be parsed
     */
    static <KeyT, AccumT, OutputT> ThrowingFunction<KV<KeyT, AccumT>, KV<KeyT, OutputT>> createExtractOutputsMapFunction(String pTransformId, RunnerApi.PTransform pTransform) throws IOException {
        RunnerApi.CombinePayload combinePayload = RunnerApi.CombinePayload.parseFrom(pTransform.getSpec().getPayload());
        // The serialized form carries no generic information, so this cast cannot be checked.
        @SuppressWarnings("unchecked")
        Combine.CombineFn<?, AccumT, OutputT> combineFn =
            (Combine.CombineFn<?, AccumT, OutputT>) SerializableUtils.deserializeFromByteArray(
                combinePayload.getCombineFn().getPayload().toByteArray(), "CombineFn");
        return input -> KV.of(input.getKey(), combineFn.extractOutput(input.getValue()));
    }

    /**
     * Builds a function that converts a single input value into an accumulator by adding it to a
     * newly created accumulator of the {@link Combine.CombineFn} carried in the transform's payload.
     *
     * @param pTransformId id of the transform; not used by this method
     * @param pTransform transform whose spec payload is parsed as a {@link RunnerApi.CombinePayload}
     * @return a function mapping {@code KV<key, input>} to {@code KV<key, accumulator>}
     * @throws IOException if the combine payload cannot be parsed
     */
    static <KeyT, InputT, AccumT> ThrowingFunction<KV<KeyT, InputT>, KV<KeyT, AccumT>> createConvertToAccumulatorsMapFunction(String pTransformId, RunnerApi.PTransform pTransform) throws IOException {
        RunnerApi.CombinePayload combinePayload = RunnerApi.CombinePayload.parseFrom(pTransform.getSpec().getPayload());
        // The serialized form carries no generic information, so this cast cannot be checked.
        @SuppressWarnings("unchecked")
        Combine.CombineFn<InputT, AccumT, ?> combineFn =
            (Combine.CombineFn<InputT, AccumT, ?>) SerializableUtils.deserializeFromByteArray(
                combinePayload.getCombineFn().getPayload().toByteArray(), "CombineFn");
        return input -> KV.of(input.getKey(), combineFn.addInput(combineFn.createAccumulator(), input.getValue()));
    }

    /**
     * Builds a function that combines all values grouped under a key directly into the final
     * output, using the {@link Combine.CombineFn} carried in the transform's payload.
     *
     * @param pTransformId id of the transform; not used by this method
     * @param pTransform transform whose spec payload is parsed as a {@link RunnerApi.CombinePayload}
     * @return a function mapping {@code KV<key, inputs>} to {@code KV<key, output>}
     * @throws IOException if the combine payload cannot be parsed
     */
    static <KeyT, InputT, AccumT, OutputT> ThrowingFunction<KV<KeyT, Iterable<InputT>>, KV<KeyT, OutputT>> createCombineGroupedValuesMapFunction(String pTransformId, RunnerApi.PTransform pTransform) throws IOException {
        RunnerApi.CombinePayload combinePayload = RunnerApi.CombinePayload.parseFrom(pTransform.getSpec().getPayload());
        // The serialized form carries no generic information, so this cast cannot be checked.
        @SuppressWarnings("unchecked")
        Combine.CombineFn<InputT, AccumT, OutputT> combineFn =
            (Combine.CombineFn<InputT, AccumT, OutputT>) SerializableUtils.deserializeFromByteArray(
                combinePayload.getCombineFn().getPayload().toByteArray(), "CombineFn");
        return input -> KV.of(input.getKey(), combineFn.apply(input.getValue()));
    }

    @VisibleForTesting
    public static class PrecombineFactory<KeyT, InputT, AccumT>
    implements PTransformRunnerFactory<PrecombineRunner<KeyT, InputT, AccumT>> {
        @Override
        public PrecombineRunner<KeyT, InputT, AccumT> createRunnerForPTransform(PTransformRunnerFactory.Context context) throws IOException {
            RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(RunnerApi.Components.newBuilder().putAllCoders(context.getCoders()).putAllWindowingStrategies(context.getWindowingStrategies()).build());
            String mainInputTag = Iterables.getOnlyElement(context.getPTransform().getInputsMap().keySet());
            RunnerApi.PCollection mainInput = context.getPCollections().get(context.getPTransform().getInputsOrThrow(mainInputTag));
            Coder<?> uncastInputCoder = rehydratedComponents.getCoder(mainInput.getCoderId());
            KvCoder inputCoder = uncastInputCoder instanceof WindowedValue.WindowedValueCoder ? (KvCoder)((WindowedValue.WindowedValueCoder)uncastInputCoder).getValueCoder() : (KvCoder)rehydratedComponents.getCoder(mainInput.getCoderId());
            Coder keyCoder = inputCoder.getKeyCoder();
            RunnerApi.CombinePayload combinePayload = RunnerApi.CombinePayload.parseFrom(context.getPTransform().getSpec().getPayload());
            Combine.CombineFn combineFn = (Combine.CombineFn)SerializableUtils.deserializeFromByteArray(combinePayload.getCombineFn().getPayload().toByteArray(), "CombineFn");
            Coder<?> accumCoder = rehydratedComponents.getCoder(combinePayload.getAccumulatorCoderId());
            FnDataReceiver consumer = context.getPCollectionConsumer(Iterables.getOnlyElement(context.getPTransform().getOutputsMap().values()));
            PrecombineRunner runner = new PrecombineRunner(context.getPipelineOptions(), combineFn, consumer, keyCoder, accumCoder);
            context.addStartBundleFunction(runner::startBundle);
            context.addPCollectionConsumer(Iterables.getOnlyElement(context.getPTransform().getInputsMap().values()), runner::processElement, inputCoder);
            context.addFinishBundleFunction(runner::finishBundle);
            return runner;
        }
    }

    /**
     * Runner that puts incoming elements into a {@link GroupingTable} and forwards whatever the
     * table emits to the downstream receiver.
     *
     * <p>The table is assigned in {@link #startBundle()}; {@link #processElement} and
     * {@link #finishBundle()} dereference it, so they must only be called after
     * {@code startBundle()}.
     */
    private static class PrecombineRunner<KeyT, InputT, AccumT> {
        // Last argument to PrecombineGroupingTable.combiningAndSampling.
        // NOTE(review): presumably the sampling rate used by the table — confirm against
        // PrecombineGroupingTable.
        private static final double SAMPLING_RATE = 0.001;

        private final PipelineOptions options;
        private final Combine.CombineFn<InputT, AccumT, ?> combineFn;
        private final FnDataReceiver<WindowedValue<KV<KeyT, AccumT>>> output;
        private final Coder<KeyT> keyCoder;
        private final Coder<AccumT> accumCoder;
        // Reassigned on every startBundle(), hence not final.
        private GroupingTable<WindowedValue<KeyT>, InputT, AccumT> groupingTable;

        /**
         * @param options pipeline options handed to the grouping table factory
         * @param combineFn combine function handed to the grouping table factory
         * @param output receiver for every element the grouping table emits
         * @param keyCoder coder for keys, handed to the grouping table factory
         * @param accumCoder coder for accumulators, handed to the grouping table factory
         */
        PrecombineRunner(PipelineOptions options, Combine.CombineFn<InputT, AccumT, ?> combineFn, FnDataReceiver<WindowedValue<KV<KeyT, AccumT>>> output, Coder<KeyT> keyCoder, Coder<AccumT> accumCoder) {
            this.options = options;
            this.combineFn = combineFn;
            this.output = output;
            this.keyCoder = keyCoder;
            this.accumCoder = accumCoder;
        }

        /** Obtains the grouping table used for the bundle that is starting. */
        void startBundle() {
            this.groupingTable = PrecombineGroupingTable.combiningAndSampling(this.options, this.combineFn, this.keyCoder, this.accumCoder, SAMPLING_RATE);
        }

        /**
         * Puts {@code elem} into the grouping table; anything the table emits in response is
         * forwarded to the output receiver.
         *
         * @throws Exception if the table or the output receiver fails
         */
        void processElement(WindowedValue<KV<KeyT, InputT>> elem) throws Exception {
            this.groupingTable.put(elem, this::emit);
        }

        /**
         * Flushes the grouping table, forwarding everything it emits to the output receiver.
         *
         * @throws Exception if the table or the output receiver fails
         */
        void finishBundle() throws Exception {
            this.groupingTable.flush(this::emit);
        }

        /** Forwards one element emitted by the grouping table to the output receiver. */
        @SuppressWarnings("unchecked") // the table hands its output back untyped
        private void emit(Object outputElem) throws Exception {
            this.output.accept((WindowedValue<KV<KeyT, AccumT>>) outputElem);
        }
    }

    public static class Registrar
    implements PTransformRunnerFactory.Registrar {
        @Override
        public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
            return ImmutableMap.of("beam:transform:combine_per_key_precombine:v1", new PrecombineFactory(), "beam:transform:combine_per_key_merge_accumulators:v1", MapFnRunners.forValueMapFnFactory(CombineRunners::createMergeAccumulatorsMapFunction), "beam:transform:combine_per_key_extract_outputs:v1", MapFnRunners.forValueMapFnFactory(CombineRunners::createExtractOutputsMapFunction), "beam:transform:combine_per_key_convert_to_accumulators:v1", MapFnRunners.forValueMapFnFactory(CombineRunners::createConvertToAccumulatorsMapFunction), "beam:transform:combine_grouped_values:v1", MapFnRunners.forValueMapFnFactory(CombineRunners::createCombineGroupedValuesMapFunction));
        }
    }
}

