/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.physical;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Stack;
import java.util.TreeSet;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.StatsTask;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.tez.TezTask;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.MergeJoinWork;
import org.apache.hadoop.hive.ql.plan.ReduceWork;
import org.apache.hadoop.hive.ql.plan.StatsWork;
import org.apache.hadoop.hive.ql.plan.TezWork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Physical plan resolver that walks the task graph and, for every work unit of
 * each Tez task, assigns a memory budget to the map join operators it contains
 * via {@code MapJoinDesc.setMemoryNeeded}.
 */
public class MemoryDecider
implements PhysicalPlanResolver {
    protected static final transient Logger LOG = LoggerFactory.getLogger(MemoryDecider.class);

    /**
     * Walks all root tasks of the plan with a {@link MemoryCalculator}.
     *
     * @param pctx the physical context whose root tasks are walked
     * @return the same {@code pctx} instance that was passed in
     * @throws SemanticException if walking the task graph fails
     */
    @Override
    public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
        // NOTE(review): the result of this call is unused; it is retained because its
        // side effects (if any) are not visible from this file -- confirm and remove.
        pctx.getConf();
        Dispatcher disp = new MemoryCalculator(pctx);
        TaskGraphWalker ogw = new TaskGraphWalker(disp);
        ArrayList<Node> topNodes = new ArrayList<Node>(pctx.getRootTasks());
        ogw.startWalking(topNodes, null);
        return pctx;
    }

    /**
     * Task-level dispatcher that distributes the configured memory among the map
     * joins of each work unit.
     */
    public class MemoryCalculator
    implements Dispatcher {
        /** Value of {@code HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD}; the total budget per work unit. */
        private final long totalAvailableMemory;
        /** Lower bound applied to a map join's input size before inflation. */
        private final long minimumHashTableSize;
        /** Multiplier applied to the (bounded) input size, from {@code HIVE_HASH_TABLE_INFLATION_FACTOR}. */
        private final double inflationFactor;
        private final PhysicalContext pctx;

        /**
         * Reads the memory-related settings from the configuration of {@code pctx}.
         *
         * @param pctx the physical context supplying the configuration
         */
        public MemoryCalculator(PhysicalContext pctx) {
            this.pctx = pctx;
            this.totalAvailableMemory = HiveConf.getLongVar(pctx.conf, HiveConf.ConfVars.HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD);
            // Widen to long BEFORE multiplying: the product of two ints would otherwise
            // overflow in int arithmetic and only then be widened.
            this.minimumHashTableSize = (long)HiveConf.getIntVar(pctx.conf, HiveConf.ConfVars.HIVEHYBRIDGRACEHASHJOINMINNUMPARTITIONS)
                * HiveConf.getIntVar(pctx.conf, HiveConf.ConfVars.HIVEHYBRIDGRACEHASHJOINMINWBSIZE);
            this.inflationFactor = HiveConf.getFloatVar(pctx.conf, HiveConf.ConfVars.HIVE_HASH_TABLE_INFLATION_FACTOR);
        }

        /**
         * Evaluates every work unit of a Tez task. A {@link StatsTask} is first
         * replaced by the source task of its work; any other task type is ignored.
         *
         * @param nd the task node being visited
         * @param stack the walker's node stack (unused)
         * @param nodeOutputs outputs of previously visited nodes (unused)
         * @return always {@code null}
         * @throws SemanticException if evaluating a work unit fails
         */
        @Override
        public Object dispatch(Node nd, Stack<Node> stack, Object ... nodeOutputs) throws SemanticException {
            Task<?> currTask = (Task<?>)nd;
            if (currTask instanceof StatsTask) {
                currTask = ((StatsWork)((StatsTask)currTask).getWork()).getSourceTask();
            }
            if (currTask instanceof TezTask) {
                TezWork work = (TezWork)((TezTask)currTask).getWork();
                for (BaseWork w : work.getAllWork()) {
                    this.evaluateWork(w);
                }
            }
            return null;
        }

        /**
         * Routes a work unit to the handler for its concrete type; unsupported
         * types are logged and skipped.
         */
        private void evaluateWork(BaseWork w) throws SemanticException {
            if (w instanceof MapWork) {
                this.evaluateMapWork((MapWork)w);
            } else if (w instanceof ReduceWork) {
                this.evaluateReduceWork((ReduceWork)w);
            } else if (w instanceof MergeJoinWork) {
                this.evaluateMergeWork((MergeJoinWork)w);
            } else {
                LOG.info("We are not going to evaluate this work type: " + w.getClass().getCanonicalName());
            }
        }

        /** Evaluates each base work contained in a merge join work independently. */
        private void evaluateMergeWork(MergeJoinWork w) throws SemanticException {
            for (BaseWork baseWork : w.getBaseWorkList()) {
                this.evaluateOperators(baseWork, this.pctx);
            }
        }

        /** Evaluates the operators of a reduce work. */
        private void evaluateReduceWork(ReduceWork w) throws SemanticException {
            this.evaluateOperators(w, this.pctx);
        }

        /** Evaluates the operators of a map work. */
        private void evaluateMapWork(MapWork w) throws SemanticException {
            this.evaluateOperators(w, this.pctx);
        }

        /**
         * Finds all map joins in {@code w} and assigns each a memory budget.
         * Budgets are derived from the joins' estimated sizes; if an estimate is
         * unavailable (or the estimates exceed the total budget) the total memory
         * is split evenly among the joins instead.
         *
         * @param w the work unit whose operator tree is inspected
         * @param pctx the physical context (not used by this method)
         * @throws SemanticException if walking the operator tree fails
         */
        private void evaluateOperators(BaseWork w, PhysicalContext pctx) throws SemanticException {
            LinkedHashSet<MapJoinOperator> mapJoins = this.collectMapJoins(w);
            if (mapJoins.isEmpty()) {
                return;
            }
            try {
                this.assignFromEstimates(mapJoins);
            }
            catch (HiveException e) {
                // Keep the cause visible for diagnosis without making the expected
                // "no statistics" case noisy at info level.
                LOG.debug("Map join size estimation failed; splitting memory evenly", e);
                this.assignEvenShare(mapJoins);
            }
        }

        /** Walks the operator tree of {@code w} and returns its map joins in visit order. */
        private LinkedHashSet<MapJoinOperator> collectMapJoins(BaseWork w) throws SemanticException {
            final LinkedHashSet<MapJoinOperator> mapJoins = new LinkedHashSet<MapJoinOperator>();
            HashMap<Rule, NodeProcessor> rules = new HashMap<Rule, NodeProcessor>();
            rules.put(new RuleRegExp("Map join memory estimator", MapJoinOperator.getOperatorName() + "%"), new NodeProcessor(){

                @Override
                public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) {
                    mapJoins.add((MapJoinOperator)nd);
                    return null;
                }
            });
            Dispatcher disp = new DefaultRuleDispatcher(null, rules, null);
            DefaultGraphWalker ogw = new DefaultGraphWalker(disp);
            ArrayList<Node> topNodes = new ArrayList<Node>(w.getAllRootOperators());
            LinkedHashMap<Node, Object> nodeOutput = new LinkedHashMap<Node, Object>();
            ogw.startWalking(topNodes, nodeOutput);
            return mapJoins;
        }

        /**
         * Assigns budgets based on estimated sizes. Joins are visited in ascending
         * size order; a join strictly smaller than what is left of half the total
         * memory gets exactly its estimate. The joins that do not fit share the
         * unclaimed remainder plus the other half, proportionally to their size.
         * If every join fits, all of them are rescaled proportionally instead.
         *
         * @throws HiveException if a size estimate is unavailable, or if all joins
         *     fit in half the memory yet their sum exceeds the total budget
         */
        private void assignFromEstimates(LinkedHashSet<MapJoinOperator> mapJoins) throws HiveException {
            long total = 0L;
            final HashMap<MapJoinOperator, Long> sizes = new HashMap<MapJoinOperator, Long>();
            final HashMap<MapJoinOperator, Integer> positions = new HashMap<MapJoinOperator, Integer>();
            int i = 0;
            for (MapJoinOperator mj : mapJoins) {
                long size = this.computeSizeToFitInMem(mj);
                sizes.put(mj, size);
                positions.put(mj, i++);
                total += size;
            }
            // Order by size; ties are broken by visit position so that distinct joins
            // of equal size never compare as equal (a TreeSet would drop one of them).
            Comparator<MapJoinOperator> comp = new Comparator<MapJoinOperator>(){

                @Override
                public int compare(MapJoinOperator mj1, MapJoinOperator mj2) {
                    if (mj1 == null || mj2 == null) {
                        throw new NullPointerException();
                    }
                    int res = Long.compare(sizes.get(mj1), sizes.get(mj2));
                    if (res == 0) {
                        res = Integer.compare(positions.get(mj1), positions.get(mj2));
                    }
                    return res;
                }
            };
            TreeSet<MapJoinOperator> sortedMapJoins = new TreeSet<MapJoinOperator>(comp);
            sortedMapJoins.addAll(mapJoins);

            long halfMemory = this.totalAvailableMemory / 2L;
            long remainingSize = halfMemory;
            long totalLargeJoins = 0L;
            Iterator<MapJoinOperator> it = sortedMapJoins.iterator();
            while (it.hasNext()) {
                MapJoinOperator mj = it.next();
                long size = sizes.get(mj);
                LOG.debug("MapJoin: {}, size: {}, remaining: {}", mj, size, remainingSize);
                if (size < remainingSize) {
                    LOG.info("Setting {} bytes needed for {} (in-mem)", size, mj);
                    ((MapJoinDesc)mj.getConf()).setMemoryNeeded(size);
                    remainingSize -= size;
                    // Handled: only the joins that did not fit stay in the sorted set.
                    it.remove();
                } else {
                    totalLargeJoins += size;
                }
            }
            if (sortedMapJoins.isEmpty()) {
                // Everything fit into half the memory: rescale all joins over the
                // whole budget rather than leaving the rest unused.
                sortedMapJoins.addAll(mapJoins);
                totalLargeJoins = total;
                if (totalLargeJoins > this.totalAvailableMemory) {
                    throw new HiveException();
                }
                remainingSize = halfMemory;
            }
            double weight = (double)(remainingSize + halfMemory) / (double)totalLargeJoins;
            for (MapJoinOperator mj : sortedMapJoins) {
                long size = (long)(weight * (double)sizes.get(mj).longValue());
                LOG.info("Setting {} bytes needed for {} (spills)", size, mj);
                ((MapJoinDesc)mj.getConf()).setMemoryNeeded(size);
            }
        }

        /** Fallback: gives every map join an equal share of the total memory. */
        private void assignEvenShare(LinkedHashSet<MapJoinOperator> mapJoins) {
            // Callers guarantee a non-empty set, so the division is safe.
            long size = this.totalAvailableMemory / (long)mapJoins.size();
            LOG.info("Scaling mapjoin memory w/o stats");
            for (MapJoinOperator mj : mapJoins) {
                LOG.info("Setting {} bytes needed for {} (fallback)", size, mj);
                ((MapJoinDesc)mj.getConf()).setMemoryNeeded(size);
            }
        }

        /**
         * Returns the input size of {@code mj}, raised to at least the minimum hash
         * table size and then multiplied by the inflation factor.
         *
         * @throws HiveException if the join has no input size information
         */
        private long computeSizeToFitInMem(MapJoinOperator mj) throws HiveException {
            return (long)((double)Math.max(this.minimumHashTableSize, this.computeInputSize(mj)) * this.inflationFactor);
        }

        /**
         * Sums the parent data sizes recorded in the join's descriptor.
         *
         * @throws HiveException if the descriptor or its sizes are missing, or the sum is zero
         */
        private long computeInputSize(MapJoinOperator mj) throws HiveException {
            long size = 0L;
            if (mj.getConf() != null && ((MapJoinDesc)mj.getConf()).getParentDataSizes() != null) {
                for (long l : ((MapJoinDesc)mj.getConf()).getParentDataSizes().values()) {
                    size += l;
                }
            }
            if (size == 0L) {
                throw new HiveException("No data sizes");
            }
            return size;
        }

        /** Node processor that does nothing and returns {@code null}; not referenced elsewhere in this file. */
        public class DefaultRule
        implements NodeProcessor {
            @Override
            public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
                return null;
            }
        }
    }
}

