/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.nodes.exec.common;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator;
import org.apache.flink.table.planner.plan.nodes.exec.utils.CommonPythonUtil;
import org.apache.flink.table.planner.plan.utils.PythonUtil;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;

public abstract class CommonExecPythonCalc
extends ExecNodeBase<RowData>
implements SingleTransformationTranslator<RowData> {
    private static final String PYTHON_SCALAR_FUNCTION_OPERATOR_NAME = "org.apache.flink.table.runtime.operators.python.scalar.RowDataPythonScalarFunctionOperator";
    private static final String ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME = "org.apache.flink.table.runtime.operators.python.scalar.arrow.RowDataArrowPythonScalarFunctionOperator";
    private final RexProgram calcProgram;

    public CommonExecPythonCalc(RexProgram calcProgram, InputProperty inputProperty, RowType outputType, String description) {
        super(Collections.singletonList(inputProperty), (LogicalType)outputType, description);
        this.calcProgram = calcProgram;
    }

    @Override
    protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
        ExecEdge inputEdge = this.getInputEdges().get(0);
        Transformation<?> inputTransform = inputEdge.translateToPlan(planner);
        Configuration config = CommonPythonUtil.getMergedConfig(planner.getExecEnv(), planner.getTableConfig());
        OneInputTransformation<RowData, RowData> ret = this.createPythonOneInputTransformation(inputTransform, this.calcProgram, this.getDescription(), config);
        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(config)) {
            ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
        }
        return ret;
    }

    private OneInputTransformation<RowData, RowData> createPythonOneInputTransformation(Transformation<RowData> inputTransform, RexProgram calcProgram, String name, Configuration config) {
        List<RexCall> pythonRexCalls = calcProgram.getProjectList().stream().map(calcProgram::expandLocalRef).filter(x -> x instanceof RexCall).map(x -> (RexCall)x).collect(Collectors.toList());
        List forwardedFields = calcProgram.getProjectList().stream().map(calcProgram::expandLocalRef).filter(x -> x instanceof RexInputRef).map(x -> ((RexInputRef)x).getIndex()).collect(Collectors.toList());
        Tuple2<int[], PythonFunctionInfo[]> extractResult = this.extractPythonScalarFunctionInfos(pythonRexCalls);
        int[] pythonUdfInputOffsets = (int[])extractResult.f0;
        PythonFunctionInfo[] pythonFunctionInfos = (PythonFunctionInfo[])extractResult.f1;
        LogicalType[] inputLogicalTypes = ((InternalTypeInfo)inputTransform.getOutputType()).toRowFieldTypes();
        InternalTypeInfo pythonOperatorInputTypeInfo = (InternalTypeInfo)inputTransform.getOutputType();
        List forwardedFieldsLogicalTypes = forwardedFields.stream().map(i -> inputLogicalTypes[i]).collect(Collectors.toList());
        List pythonCallLogicalTypes = pythonRexCalls.stream().map(node -> FlinkTypeFactory.toLogicalType(node.getType())).collect(Collectors.toList());
        ArrayList fieldsLogicalTypes = new ArrayList();
        fieldsLogicalTypes.addAll(forwardedFieldsLogicalTypes);
        fieldsLogicalTypes.addAll(pythonCallLogicalTypes);
        InternalTypeInfo pythonOperatorResultTyeInfo = InternalTypeInfo.ofFields((LogicalType[])fieldsLogicalTypes.toArray(new LogicalType[0]));
        OneInputStreamOperator<RowData, RowData> pythonOperator = this.getPythonScalarFunctionOperator(config, (InternalTypeInfo<RowData>)pythonOperatorInputTypeInfo, (InternalTypeInfo<RowData>)pythonOperatorResultTyeInfo, pythonUdfInputOffsets, pythonFunctionInfos, forwardedFields.stream().mapToInt(x -> x).toArray(), calcProgram.getExprList().stream().anyMatch(x -> PythonUtil.containsPythonCall(x, PythonFunctionKind.PANDAS)));
        return new OneInputTransformation(inputTransform, name, pythonOperator, (TypeInformation)pythonOperatorResultTyeInfo, inputTransform.getParallelism());
    }

    private Tuple2<int[], PythonFunctionInfo[]> extractPythonScalarFunctionInfos(List<RexCall> rexCalls) {
        LinkedHashMap inputNodes = new LinkedHashMap();
        PythonFunctionInfo[] pythonFunctionInfos = rexCalls.stream().map(x -> CommonPythonUtil.createPythonFunctionInfo(x, inputNodes)).collect(Collectors.toList()).toArray(new PythonFunctionInfo[rexCalls.size()]);
        int[] udfInputOffsets = inputNodes.keySet().stream().map(x -> {
            if (x instanceof RexInputRef) {
                return ((RexInputRef)x).getIndex();
            }
            if (x instanceof RexFieldAccess) {
                return ((RexFieldAccess)x).getField().getIndex();
            }
            return null;
        }).mapToInt(i -> i).toArray();
        return Tuple2.of((Object)udfInputOffsets, (Object)pythonFunctionInfos);
    }

    private OneInputStreamOperator<RowData, RowData> getPythonScalarFunctionOperator(Configuration config, InternalTypeInfo<RowData> inputRowTypeInfo, InternalTypeInfo<RowData> outputRowTypeInfo, int[] udfInputOffsets, PythonFunctionInfo[] pythonFunctionInfos, int[] forwardedFields, boolean isArrow) {
        Class clazz = isArrow ? CommonPythonUtil.loadClass(ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME) : CommonPythonUtil.loadClass(PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
        try {
            Constructor ctor = clazz.getConstructor(Configuration.class, PythonFunctionInfo[].class, RowType.class, RowType.class, int[].class, int[].class);
            return (OneInputStreamOperator)ctor.newInstance(config, pythonFunctionInfos, inputRowTypeInfo.toRowType(), outputRowTypeInfo.toRowType(), udfInputOffsets, forwardedFields);
        }
        catch (Exception e) {
            throw new TableException("Python Scalar Function Operator constructed failed.", (Throwable)e);
        }
    }
}

