/*
 * Decompiled with CFR 0.152.
 */
package org.apache.drill.exec.compile;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.drill.common.config.DrillConfig;
import org.apache.drill.common.util.DrillStringUtils;
import org.apache.drill.exec.ExecConstants;
import org.apache.drill.exec.compile.AsmUtil;
import org.apache.drill.exec.compile.ByteCodeLoader;
import org.apache.drill.exec.compile.MergeAdapter;
import org.apache.drill.exec.compile.QueryClassLoader;
import org.apache.drill.exec.compile.TemplateClassDefinition;
import org.apache.drill.exec.exception.ClassTransformationException;
import org.apache.drill.exec.expr.CodeGenerator;
import org.apache.drill.exec.server.options.OptionSet;
import org.apache.drill.shaded.guava.com.google.common.annotations.VisibleForTesting;
import org.apache.drill.shaded.guava.com.google.common.base.Preconditions;
import org.apache.drill.shaded.guava.com.google.common.collect.Lists;
import org.apache.drill.shaded.guava.com.google.common.collect.Maps;
import org.apache.drill.shaded.guava.com.google.common.collect.Sets;
import org.codehaus.commons.compiler.CompileException;
import org.objectweb.asm.tree.ClassNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Turns generated source code into a loadable class by merging it with the
 * precompiled template class it was generated for.
 *
 * <p>Each generated class is processed together with the inner classes reported
 * by the merge. For each one, the template's bytecode is read from the class path
 * and merged with the generated bytecode by {@link MergeAdapter}. The result is
 * injected into a {@link QueryClassLoader}. Compiled classes that have no template
 * counterpart are injected with their original bytecode.
 */
public class ClassTransformer {
    private static final Logger logger = LoggerFactory.getLogger(ClassTransformer.class);

    /**
     * Generated sources whose length in characters is at least this value
     * (0x200000 = 2,097,152) are merged without scalar replacement.
     */
    private static final int MAX_SCALAR_REPLACE_CODE_SIZE = 0x200000;

    private final ByteCodeLoader byteCodeLoader = new ByteCodeLoader();
    private final DrillConfig config;
    private final OptionSet optionManager;

    /**
     * @param config        configuration passed to every {@link QueryClassLoader} this transformer creates
     * @param optionManager options passed to those class loaders; also read for the scalar-replacement setting
     */
    public ClassTransformer(DrillConfig config, OptionSet optionManager) {
        this.config = config;
        this.optionManager = optionManager;
    }

    /**
     * Compiles and merges the code held by {@code cg}, using a new {@link QueryClassLoader} per call.
     *
     * @param cg supplies the template definition, the generated source and the generated class name
     * @return the merged, loaded implementation class
     * @throws ClassTransformationException if compilation, merging or loading fails, or the
     *     result does not implement the template's external interface
     */
    public Class<?> getImplementationClass(CodeGenerator<?> cg) throws ClassTransformationException {
        QueryClassLoader loader = new QueryClassLoader(config, optionManager);
        return getImplementationClass(loader, cg.getDefinition(), cg.getGeneratedCode(), cg.getMaterializedClassName());
    }

    /**
     * Compiles {@code entireClass} and merges each resulting class with its precompiled
     * template counterpart. The merged bytecode is injected into {@code classLoader},
     * which then loads {@code materializedClassName}.
     *
     * <p>Scalar replacement is controlled by the {@code SCALAR_REPLACEMENT_VALIDATOR} option:
     * <ul>
     *   <li>{@code off} never uses it.</li>
     *   <li>{@code on} always uses it and rethrows any merge failure.</li>
     *   <li>{@code try} uses it but retries a failed merge without it.</li>
     * </ul>
     * Scalar replacement is also skipped when {@code entireClass} has at least
     * {@code MAX_SCALAR_REPLACE_CODE_SIZE} characters.
     *
     * @param classLoader           loader that compiles the source and receives the merged bytecode
     * @param templateDefinition    names the template class and the interface the result must implement
     * @param entireClass           complete generated Java source
     * @param materializedClassName dotted name of the generated top-level class
     * @return the loaded implementation class
     * @throws ClassTransformationException on compile, I/O or class-loading failure, or if the
     *     loaded class does not implement the template's external interface. Runtime exceptions
     *     from merging, and {@code IllegalStateException} for a malformed compiled class,
     *     propagate unchanged.
     */
    public Class<?> getImplementationClass(QueryClassLoader classLoader, TemplateClassDefinition<?> templateDefinition,
            String entireClass, String materializedClassName) throws ClassTransformationException {
        ScalarReplacementOption scalarReplacementOption =
                ScalarReplacementOption.fromString(optionManager.getOption(ExecConstants.SCALAR_REPLACEMENT_VALIDATOR));
        try {
            long startNanos = System.nanoTime();
            ClassSet set = new ClassSet(null, templateDefinition.getTemplateClassName(), materializedClassName);
            byte[][] implementationClasses = classLoader.getClassByteCode(set.generated, entireClass);

            // Index every compiled class by its slash-separated name, keeping the raw bytes
            // alongside the parsed node.
            long totalBytecodeSize = 0L;
            Map<String, Pair<byte[], ClassNode>> classesToMerge = Maps.newHashMap();
            for (byte[] classBytes : implementationClasses) {
                totalBytecodeSize += classBytes.length;
                // NOTE(review): flag value 4 is ASM's ClassReader.SKIP_FRAMES; kept as compiled — confirm intended.
                ClassNode node = AsmUtil.classFromBytes(classBytes, 4);
                if (!AsmUtil.isClassOk(logger, "implementationClasses", node)) {
                    throw new IllegalStateException("Problem found with implementationClasses");
                }
                classesToMerge.put(node.name, Pair.of(classBytes, node));
            }

            // Process the top-level class first, then every inner class the merges report.
            // The completed set skips a class that is reported more than once.
            LinkedList<ClassSet> names = Lists.newLinkedList();
            Set<ClassSet> namesCompleted = Sets.newHashSet();
            names.add(set);
            while (!names.isEmpty()) {
                ClassSet nextSet = names.removeFirst();
                if (namesCompleted.contains(nextSet)) {
                    continue;
                }
                byte[] precompiledBytes = byteCodeLoader.getClassByteCodeFromPath(nextSet.precompiled.clazz);
                ClassNames nextGenerated = nextSet.generated;
                // Removing the entry leaves only unmatched classes in the map for the final pass below.
                Pair<byte[], ClassNode> classNodePair = classesToMerge.remove(nextGenerated.slash);
                ClassNode generatedNode = classNodePair != null ? classNodePair.getValue() : null;

                boolean scalarReplace = scalarReplacementOption != ScalarReplacementOption.OFF
                        && entireClass.length() < MAX_SCALAR_REPLACE_CODE_SIZE;
                MergeAdapter.MergedClassResult result;
                while (true) {
                    try {
                        result = MergeAdapter.getMergedClass(nextSet, precompiledBytes, generatedNode, scalarReplace);
                        break;
                    } catch (RuntimeException e) {
                        // Only a failed scalar-replacement attempt under TRY is retried.
                        if (!scalarReplace || scalarReplacementOption == ScalarReplacementOption.ON) {
                            throw e;
                        }
                        logger.info("scalar replacement failure (retrying)\n", e);
                        scalarReplace = false;
                    }
                }

                for (String innerClass : result.innerClasses) {
                    names.add(nextSet.getChild(innerClass.replace('/', '.')));
                }
                classLoader.injectByteCode(nextGenerated.dot, result.bytes);
                namesCompleted.add(nextSet);
            }

            // Compiled classes never matched to a template are injected with their original bytecode.
            for (Map.Entry<String, Pair<byte[], ClassNode>> entry : classesToMerge.entrySet()) {
                classLoader.injectByteCode(entry.getKey().replace('/', '.'), entry.getValue().getKey());
            }

            Class<?> c = classLoader.findClass(set.generated.dot);
            if (!templateDefinition.getExternalInterface().isAssignableFrom(c)) {
                throw new ClassTransformationException("The requested class did not implement the expected interface.");
            }
            logger.debug("Compiled and merged {}: bytecode size = {}, time = {} ms.",
                    c.getSimpleName(),
                    DrillStringUtils.readable(totalBytecodeSize),
                    (System.nanoTime() - startNanos + 500_000L) / 1_000_000L);
            return c;
        } catch (IOException | ClassNotFoundException | CompileException e) {
            throw new ClassTransformationException(
                    String.format("Failure generating transformation classes for value: \n %s", entireClass), e);
        }
    }

    /** Values of the scalar-replacement option, parsed from the exact lowercase strings "off", "try" and "on". */
    @VisibleForTesting
    public enum ScalarReplacementOption {
        OFF,
        TRY,
        ON;

        /**
         * Parses an option value.
         *
         * @param s one of {@code "off"}, {@code "try"}, {@code "on"} (case-sensitive)
         * @return the matching constant
         * @throws IllegalArgumentException if {@code s} is null or not one of the accepted strings
         */
        public static ScalarReplacementOption fromString(String s) {
            if (s != null) {
                switch (s) {
                    case "off":
                        return OFF;
                    case "try":
                        return TRY;
                    case "on":
                        return ON;
                    default:
                        break;
                }
            }
            throw new IllegalArgumentException("Invalid ScalarReplacementOption \"" + s + "\"");
        }
    }

    /**
     * Pairs a precompiled template class name with the name of the generated class it is
     * merged into. Sets for inner classes point back to the enclosing set through {@code parent}.
     */
    public static class ClassSet {
        public final ClassSet parent;
        public final ClassNames precompiled;
        public final ClassNames generated;

        /**
         * @param parent      enclosing set, or {@code null} for a top-level class
         * @param precompiled dotted name of the template class
         * @param generated   dotted name of the generated class
         * @throws IllegalArgumentException if {@code generated} starts with {@code precompiled}
         */
        public ClassSet(ClassSet parent, String precompiled, String generated) {
            // Guava substitutes the %s placeholders only when the check fails.
            Preconditions.checkArgument(!generated.startsWith(precompiled),
                    "The new name of a class cannot start with the old name of a class, otherwise class renaming will cause problems. Precompiled class name %s. Generated class name %s",
                    precompiled, generated);
            this.parent = parent;
            this.precompiled = new ClassNames(precompiled);
            this.generated = new ClassNames(generated);
        }

        /** Creates a child set with explicitly supplied precompiled and generated names. */
        public ClassSet getChild(String precompiled, String generated) {
            return new ClassSet(this, precompiled, generated);
        }

        /**
         * Creates a child set whose generated name is {@code precompiled} with every occurrence
         * of this set's precompiled dotted name replaced by this set's generated dotted name.
         */
        public ClassSet getChild(String precompiled) {
            return new ClassSet(this, precompiled, precompiled.replace(this.precompiled.dot, this.generated.dot));
        }

        @Override
        public int hashCode() {
            // Same value as the 31-based accumulation over (generated, parent, precompiled).
            return Objects.hash(generated, parent, precompiled);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            ClassSet other = (ClassSet) obj;
            return Objects.equals(generated, other.generated)
                    && Objects.equals(parent, other.parent)
                    && Objects.equals(precompiled, other.precompiled);
        }
    }

    /** The three spellings of one class name used during compilation and merging. */
    public static class ClassNames {
        /** The name as given, dot-separated (e.g. {@code a.b.C}). */
        public final String dot;
        /** The name with every '.' replaced by '/' (e.g. {@code a/b/C}). */
        public final String slash;
        /** Class-path resource path of the class file (e.g. {@code /a/b/C.class}). */
        public final String clazz;

        public ClassNames(String className) {
            this.dot = className;
            this.slash = className.replace('.', '/');
            this.clazz = '/' + this.slash + ".class";
        }

        @Override
        public int hashCode() {
            // Same value as the 31-based accumulation over (clazz, dot, slash).
            return Objects.hash(clazz, dot, slash);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            ClassNames other = (ClassNames) obj;
            return Objects.equals(clazz, other.clazz)
                    && Objects.equals(dot, other.dot)
                    && Objects.equals(slash, other.slash);
        }
    }
}

