/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.ConstructorDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.body.Parameter;
import com.github.javaparser.ast.body.VariableDeclarator;
import com.github.javaparser.ast.expr.CastExpr;
import com.github.javaparser.ast.expr.DoubleLiteralExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.LambdaExpr;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.MethodReferenceExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.NullLiteralExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import com.github.javaparser.ast.expr.SimpleName;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.expr.ThisExpr;
import com.github.javaparser.ast.expr.VariableDeclarationExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.stmt.ExpressionStmt;
import com.github.javaparser.ast.stmt.Statement;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
import com.github.javaparser.ast.type.Type;
import com.github.javaparser.ast.type.UnknownType;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.PredictorTerm;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.api.iinterfaces.SerializableFunction;
import org.kie.pmml.commons.utils.KiePMMLModelUtils;
import org.kie.pmml.compiler.commons.utils.CommonCodegenUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO;
import org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static factory that generates the Java source of regression-table classes from PMML
 * {@link RegressionTable} definitions.
 *
 * <p>Each table is produced by cloning the class template
 * {@value #KIE_PMML_REGRESSION_TABLE_REGRESSION_TEMPLATE_JAVA}, filling in its default constructor and
 * appending one generated method per {@link PredictorTerm}.
 *
 * <p>The only state shared between invocations is {@link #classArity}, an atomic counter used to give
 * every generated class a distinct name. All other working data is kept in local variables.
 */
public class KiePMMLRegressionTableRegressionFactory {
    public static final String KIE_PMML_REGRESSION_TABLE_REGRESSION_TEMPLATE_JAVA = "KiePMMLRegressionTableRegressionTemplate.tmpl";
    public static final String KIE_PMML_REGRESSION_TABLE_REGRESSION_TEMPLATE = "KiePMMLRegressionTableRegressionTemplate";
    private static final Logger logger = LoggerFactory.getLogger(KiePMMLRegressionTableRegressionFactory.class);
    static final String MAIN_CLASS_NOT_FOUND = "Main class not found";
    static final String KIE_PMML_EVALUATE_METHOD_TEMPLATE_JAVA = "KiePMMLEvaluateMethodTemplate.tmpl";
    static final String KIE_PMML_EVALUATE_METHOD_TEMPLATE = "KiePMMLEvaluateMethodTemplate";
    static final List<RegressionModel.NormalizationMethod> SUPPORTED_NORMALIZATION_METHODS = Arrays.asList(RegressionModel.NormalizationMethod.SOFTMAX, RegressionModel.NormalizationMethod.LOGIT, RegressionModel.NormalizationMethod.EXP, RegressionModel.NormalizationMethod.PROBIT, RegressionModel.NormalizationMethod.CLOGLOG, RegressionModel.NormalizationMethod.CAUCHIT, RegressionModel.NormalizationMethod.NONE);
    static final List<RegressionModel.NormalizationMethod> UNSUPPORTED_NORMALIZATION_METHODS = Arrays.asList(RegressionModel.NormalizationMethod.SIMPLEMAX, RegressionModel.NormalizationMethod.LOGLOG);
    private static final String COEFFICIENT = "coefficient";
    /** Name of the single parameter of every generated predictor lambda. */
    private static final String LAMBDA_PARAMETER_NAME = "input";
    /** Suffix source for generated class names; incremented once per generated table. */
    private static final AtomicInteger classArity = new AtomicInteger(0);

    private KiePMMLRegressionTableRegressionFactory() {
        // Utility class: not meant to be instantiated.
    }

    /**
     * Generates one class per regression table held by the given DTO.
     *
     * @param compilationDTO source of the regression tables and of the compilation settings
     * @return insertion-ordered map from the generated fully qualified class name to its source code
     *         paired with the table's target category (empty string when the table has none)
     */
    public static LinkedHashMap<String, KiePMMLTableSourceCategory> getRegressionTables(RegressionCompilationDTO compilationDTO) {
        logger.trace("getRegressionTables {}", compilationDTO.getRegressionTables());
        LinkedHashMap<String, KiePMMLTableSourceCategory> toReturn = new LinkedHashMap<>();
        for (RegressionTable regressionTable : compilationDTO.getRegressionTables()) {
            Map.Entry<String, String> regressionTableEntry = getRegressionTable(regressionTable, compilationDTO);
            Object rawTargetCategory = regressionTable.getTargetCategory();
            String targetCategory = rawTargetCategory != null ? rawTargetCategory.toString() : "";
            toReturn.put(regressionTableEntry.getKey(), new KiePMMLTableSourceCategory(regressionTableEntry.getValue(), targetCategory));
        }
        return toReturn;
    }

    /**
     * Generates the class for a single regression table.
     *
     * @param regressionTable the PMML table to translate
     * @param compilationDTO provides the package name, target field name and default normalization method
     * @return entry whose key is the generated fully qualified class name and whose value is its source code
     * @throws KiePMMLException if the cloned template does not contain the expected class
     * @throws KiePMMLInternalException if the template class has no default constructor
     */
    public static Map.Entry<String, String> getRegressionTable(RegressionTable regressionTable, RegressionCompilationDTO compilationDTO) {
        logger.trace("getRegressionTable {}", regressionTable);
        String className = "KiePMMLRegressionTableRegression" + classArity.addAndGet(1);
        CompilationUnit cloneCU = JavaParserUtils.getKiePMMLModelCompilationUnit(className, compilationDTO.getPackageName(), KIE_PMML_REGRESSION_TABLE_REGRESSION_TEMPLATE_JAVA, KIE_PMML_REGRESSION_TABLE_REGRESSION_TEMPLATE);
        ClassOrInterfaceDeclaration tableTemplate = cloneCU.getClassByName(className)
                .orElseThrow(() -> new KiePMMLException(MAIN_CLASS_NOT_FOUND + ": " + className));
        ConstructorDeclaration constructorDeclaration = tableTemplate.getDefaultConstructor()
                .orElseThrow(() -> new KiePMMLInternalException(String.format("Missing default constructor in ClassOrInterfaceDeclaration %s ", tableTemplate.getName())));
        setConstructor(regressionTable, constructorDeclaration, tableTemplate.getName(), compilationDTO.getTargetFieldName(), regressionTable.getTargetCategory(), compilationDTO.getDefaultNormalizationMethod());
        Map<String, Expression> numericPredictorsMap = createNumericPredictorsExpressions(regressionTable.getNumericPredictors());
        Map<String, MethodDeclaration> predictorTermsMap = addPredictorTerms(regressionTable.getPredictorTerms(), tableTemplate);
        BlockStmt body = constructorDeclaration.getBody();
        // Must run before the map-population calls below: it appends the local map declarations
        // that the categorical lambdas refer to.
        Map<String, Expression> categoricalPredictorsMap = createCategoricalPredictorsExpressions(regressionTable.getCategoricalPredictors(), body);
        CommonCodegenUtils.addMapPopulationExpressions(numericPredictorsMap, body, "numericFunctionMap");
        CommonCodegenUtils.addMapPopulationExpressions(categoricalPredictorsMap, body, "categoricalFunctionMap");
        CommonCodegenUtils.addMapPopulation(predictorTermsMap, body, "predictorTermsFunctionMap");
        return new AbstractMap.SimpleEntry<>(JavaParserUtils.getFullClassName(cloneCU), cloneCU.toString());
    }

    /**
     * Renames the template constructor and sets the values assigned in its body to
     * {@code intercept}, {@code targetField}, {@code targetCategory} and {@code resultUpdater}.
     *
     * @param regressionTable source of the intercept
     * @param constructorDeclaration constructor to modify in place
     * @param tableName name to give to the constructor
     * @param targetField value for the {@code targetField} assignment
     * @param targetCategory value for the {@code targetCategory} assignment
     * @param normalizationMethod drives the expression assigned to {@code resultUpdater}
     */
    static void setConstructor(RegressionTable regressionTable, ConstructorDeclaration constructorDeclaration, SimpleName tableName, String targetField, Object targetCategory, RegressionModel.NormalizationMethod normalizationMethod) {
        constructorDeclaration.setName(tableName);
        BlockStmt body = constructorDeclaration.getBody();
        String interceptLiteral = String.valueOf(regressionTable.getIntercept().doubleValue());
        CommonCodegenUtils.setAssignExpressionValue(body, "intercept", new DoubleLiteralExpr(interceptLiteral));
        CommonCodegenUtils.setAssignExpressionValue(body, "targetField", new StringLiteralExpr(targetField));
        CommonCodegenUtils.setAssignExpressionValue(body, "targetCategory", CommonCodegenUtils.getExpressionForObject(targetCategory));
        CommonCodegenUtils.setAssignExpressionValue(body, "resultUpdater", createResultUpdaterExpression(normalizationMethod));
    }

    /**
     * Creates the expression assigned to {@code resultUpdater}.
     *
     * @param normalizationMethod the normalization method to translate
     * @return a {@code null} literal when the method is listed in
     *         {@link #UNSUPPORTED_NORMALIZATION_METHODS}, otherwise the method reference built by
     *         {@link #createResultUpdaterSupportedExpression(RegressionModel.NormalizationMethod)}
     */
    static Expression createResultUpdaterExpression(RegressionModel.NormalizationMethod normalizationMethod) {
        if (UNSUPPORTED_NORMALIZATION_METHODS.contains(normalizationMethod)) {
            return new NullLiteralExpr();
        }
        return createResultUpdaterSupportedExpression(normalizationMethod);
    }

    /**
     * Builds a method reference of the form
     * {@code ((SerializableFunction<Double, Double>) this)::update<NAME>Result}, where {@code <NAME>}
     * is the enum constant name of the given normalization method.
     *
     * @param normalizationMethod the normalization method whose name selects the referenced method
     * @return the generated method reference expression
     */
    static MethodReferenceExpr createResultUpdaterSupportedExpression(RegressionModel.NormalizationMethod normalizationMethod) {
        String updaterMethodName = String.format("update%sResult", normalizationMethod.name());
        String doubleClassName = Double.class.getSimpleName();
        ClassOrInterfaceType functionType = CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(SerializableFunction.class.getCanonicalName(), Arrays.asList(doubleClassName, doubleClassName));
        CastExpr castedThis = new CastExpr();
        castedThis.setType(functionType);
        castedThis.setExpression(new ThisExpr());
        MethodReferenceExpr toReturn = new MethodReferenceExpr();
        toReturn.setScope(castedThis);
        toReturn.setIdentifier(updaterMethodName);
        return toReturn;
    }

    /**
     * Creates one lambda expression per numeric predictor.
     *
     * @param numericPredictors predictors to translate
     * @return map from predictor name to the expression built by
     *         {@link #createNumericPredictorExpression(NumericPredictor)}
     */
    static Map<String, Expression> createNumericPredictorsExpressions(List<NumericPredictor> numericPredictors) {
        return numericPredictors.stream()
                .collect(Collectors.toMap(numericPredictor -> numericPredictor.getName().getValue(),
                                          KiePMMLRegressionTableRegressionFactory::createNumericPredictorExpression));
    }

    /**
     * Builds {@code (SerializableFunction<Double, Double>) input -> evaluateNumericWithExponent(input, coefficient, exponent)},
     * or the {@code evaluateNumericWithoutExponent(input, coefficient)} variant when
     * {@code Objects.equals(1, exponent)} holds.
     *
     * @param numericPredictor source of coefficient and exponent
     * @return the casted lambda expression
     */
    static CastExpr createNumericPredictorExpression(NumericPredictor numericPredictor) {
        boolean withExponent = !Objects.equals(1, numericPredictor.getExponent());
        String evaluatorMethodName = withExponent ? "evaluateNumericWithExponent" : "evaluateNumericWithoutExponent";
        NodeList<Expression> arguments = new NodeList<>();
        arguments.add(new NameExpr(LAMBDA_PARAMETER_NAME));
        arguments.add(CommonCodegenUtils.getExpressionForObject(numericPredictor.getCoefficient().doubleValue()));
        if (withExponent) {
            arguments.add(CommonCodegenUtils.getExpressionForObject(numericPredictor.getExponent().doubleValue()));
        }
        String doubleClassName = Double.class.getSimpleName();
        return createSerializableFunctionLambda(evaluatorMethodName, arguments, doubleClassName, doubleClassName);
    }

    /**
     * Groups the categorical predictors by field name and, for every group, appends a local
     * value-to-coefficient map to {@code body} and creates the lambda that looks values up in it.
     *
     * @param categoricalPredictors predictors to translate
     * @param body block that receives the generated map declarations and population statements
     * @return map from field name to the expression built by
     *         {@link #createCategoricalPredictorExpression(String)}
     */
    static Map<String, Expression> createCategoricalPredictorsExpressions(List<CategoricalPredictor> categoricalPredictors, BlockStmt body) {
        Map<String, List<CategoricalPredictor>> predictorsByField = categoricalPredictors.stream()
                .collect(Collectors.groupingBy(categoricalPredictor -> categoricalPredictor.getField().getValue()));
        Map<String, Expression> toReturn = new HashMap<>();
        for (Map.Entry<String, List<CategoricalPredictor>> entry : predictorsByField.entrySet()) {
            String categoricalPredictorMapName = KiePMMLModelUtils.getSanitizedVariableName(String.format("%sMap", entry.getKey()));
            populateWithGroupedCategoricalPredictorMap(entry.getValue(), body, categoricalPredictorMapName);
            toReturn.put(entry.getKey(), createCategoricalPredictorExpression(categoricalPredictorMapName));
        }
        return toReturn;
    }

    /**
     * Appends to {@code toPopulate} the declaration of a {@code Map<String, Double>} local variable
     * initialized with a new {@code HashMap<String, Double>}, followed by the statements that put
     * each predictor's value (as string) and coefficient into it.
     *
     * @param categoricalPredictors predictors sharing the same field
     * @param toPopulate block that receives the generated statements
     * @param categoricalPredictorMapName name of the generated local variable
     */
    static void populateWithGroupedCategoricalPredictorMap(List<CategoricalPredictor> categoricalPredictors, BlockStmt toPopulate, String categoricalPredictorMapName) {
        List<String> mapTypeNames = Arrays.asList(String.class.getSimpleName(), Double.class.getSimpleName());
        VariableDeclarator categoricalMapDeclarator = new VariableDeclarator(CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(Map.class.getName(), mapTypeNames), categoricalPredictorMapName);
        ObjectCreationExpr categoricalMapInitializer = new ObjectCreationExpr();
        categoricalMapInitializer.setType(CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(HashMap.class.getName(), mapTypeNames));
        categoricalMapDeclarator.setInitializer(categoricalMapInitializer);
        toPopulate.addStatement(new VariableDeclarationExpr(categoricalMapDeclarator));
        // Insertion order is kept so the generated put statements follow the predictor order.
        Map<String, Expression> mapExpressions = new LinkedHashMap<>();
        for (CategoricalPredictor categoricalPredictor : categoricalPredictors) {
            mapExpressions.put(categoricalPredictor.getValue().toString(),
                               CommonCodegenUtils.getExpressionForObject(categoricalPredictor.getCoefficient().doubleValue()));
        }
        CommonCodegenUtils.addMapPopulationExpressions(mapExpressions, toPopulate, categoricalPredictorMapName);
    }

    /**
     * Builds {@code (SerializableFunction<String, Double>) input -> evaluateCategoricalPredictor(input, <mapName>)}.
     *
     * @param categoricalPredictorMapName name of the local map variable referenced by the lambda
     * @return the casted lambda expression
     */
    static CastExpr createCategoricalPredictorExpression(String categoricalPredictorMapName) {
        NodeList<Expression> arguments = new NodeList<>();
        arguments.add(new NameExpr(LAMBDA_PARAMETER_NAME));
        arguments.add(new NameExpr(categoricalPredictorMapName));
        return createSerializableFunctionLambda("evaluateCategoricalPredictor", arguments, String.class.getSimpleName(), Double.class.getSimpleName());
    }

    /**
     * Adds one generated method to {@code tableTemplate} for every predictor term.
     *
     * <p>Terms are numbered from 1 in encounter order; the number is used for the generated method
     * name and, for unnamed terms, for the {@code predictorTerm<N>} map key. The counter is local to
     * this call so that concurrent invocations cannot disturb each other's numbering.
     *
     * @param predictorTerms terms to translate
     * @param tableTemplate class that receives the generated methods
     * @return map from term name (or {@code predictorTerm<N>}) to the generated method
     */
    static Map<String, MethodDeclaration> addPredictorTerms(List<PredictorTerm> predictorTerms, ClassOrInterfaceDeclaration tableTemplate) {
        AtomicInteger predictorsArity = new AtomicInteger(0);
        return predictorTerms.stream().map(predictorTerm -> {
            int arity = predictorsArity.addAndGet(1);
            String key = predictorTerm.getName() != null ? predictorTerm.getName().getValue() : "predictorTerm" + arity;
            return new AbstractMap.SimpleEntry<String, MethodDeclaration>(key, addPredictorTerm(predictorTerm, tableTemplate, arity));
        }).collect(Collectors.toMap(AbstractMap.SimpleEntry::getKey, AbstractMap.SimpleEntry::getValue));
    }

    /**
     * Clones the {@code evaluatePredictor} method of the evaluate-method template, sets the
     * initializers of its {@code fieldRefs} and {@code coefficient} variables from the given term,
     * and adds it to {@code tableTemplate} as {@code evaluatePredictorTerm<predictorArity>}.
     *
     * @param predictorTerm source of field references and coefficient
     * @param tableTemplate class that receives the generated method
     * @param predictorArity suffix of the generated method name
     * @return the method added to {@code tableTemplate}
     * @throws KiePMMLInternalException wrapping any failure raised while building the method
     */
    static MethodDeclaration addPredictorTerm(PredictorTerm predictorTerm, ClassOrInterfaceDeclaration tableTemplate, int predictorArity) {
        try {
            // Work on a private clone so the parsed template itself is never modified.
            CompilationUnit evaluateCU = JavaParserUtils.getFromFileName(KIE_PMML_EVALUATE_METHOD_TEMPLATE_JAVA).clone();
            ClassOrInterfaceDeclaration evaluateTemplateClass = evaluateCU.getClassByName(KIE_PMML_EVALUATE_METHOD_TEMPLATE)
                    .orElseThrow(() -> new RuntimeException(MAIN_CLASS_NOT_FOUND));
            MethodDeclaration methodTemplate = evaluateTemplateClass.getMethodsByName("evaluatePredictor").get(0);
            BlockStmt body = methodTemplate.getBody()
                    .orElseThrow(() -> new KiePMMLInternalException(String.format("Missing body in %s", methodTemplate.getName())));

            VariableDeclarator fieldRefsDeclarator = CommonCodegenUtils.getVariableDeclarator(body, "fieldRefs")
                    .orElseThrow(() -> new KiePMMLInternalException(String.format("Missing expected variable '%s' in body %s", "fieldRefs", body)));
            NodeList<Expression> fieldRefLiterals = new NodeList<>();
            predictorTerm.getFieldRefs().forEach(fieldRef -> fieldRefLiterals.add(new StringLiteralExpr(fieldRef.getField().getValue())));
            fieldRefsDeclarator.setInitializer(new MethodCallExpr(new NameExpr("Arrays"), "asList", fieldRefLiterals));

            VariableDeclarator coefficientDeclarator = CommonCodegenUtils.getVariableDeclarator(body, COEFFICIENT)
                    .orElseThrow(() -> new KiePMMLInternalException(String.format("Missing expected variable '%s' in body %s", COEFFICIENT, body)));
            coefficientDeclarator.setInitializer(String.valueOf(predictorTerm.getCoefficient().doubleValue()));

            return CommonCodegenUtils.addMethod(methodTemplate, tableTemplate, "evaluatePredictorTerm" + predictorArity);
        } catch (Exception e) {
            throw new KiePMMLInternalException(String.format("Failed to add PredictorTerm %s", predictorTerm), e);
        }
    }

    /**
     * Builds {@code (SerializableFunction<IN, OUT>) input -> <methodName>(<arguments>)}.
     *
     * @param methodName name of the method invoked in the lambda body
     * @param arguments arguments passed to that method
     * @param inputTypeName simple name of the function's input type
     * @param outputTypeName simple name of the function's output type
     * @return the casted lambda expression
     */
    private static CastExpr createSerializableFunctionLambda(String methodName, NodeList<Expression> arguments, String inputTypeName, String outputTypeName) {
        MethodCallExpr lambdaMethodCallExpr = new MethodCallExpr();
        lambdaMethodCallExpr.setName(methodName);
        lambdaMethodCallExpr.setArguments(arguments);
        LambdaExpr lambdaExpr = new LambdaExpr();
        lambdaExpr.setParameters(NodeList.nodeList(new Parameter(new UnknownType(), LAMBDA_PARAMETER_NAME)));
        lambdaExpr.setBody(new ExpressionStmt(lambdaMethodCallExpr));
        ClassOrInterfaceType serializableFunctionType = CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(SerializableFunction.class.getCanonicalName(), Arrays.asList(inputTypeName, outputTypeName));
        CastExpr toReturn = new CastExpr();
        toReturn.setType(serializableFunctionType);
        toReturn.setExpression(lambdaExpr);
        return toReturn;
    }
}

