/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.ppd;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.CommonJoinOperator;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.PreOrderWalker;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.parse.OpParseContext;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.RowResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDynamicListDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.FilterDesc;
import org.apache.hadoop.hive.ql.plan.JoinCondDesc;
import org.apache.hadoop.hive.ql.plan.JoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;

public class SyntheticJoinPredicate
implements Transform {
    /**
     * Class-wide logger. Declared {@code final} because it is assigned exactly once;
     * {@code transient} is dropped since it has no effect on a static field.
     */
    private static final Log LOG = LogFactory.getLog(SyntheticJoinPredicate.class.getName());

    /**
     * Walks the operator tree of the given parse context and lets {@link JoinSynthetic}
     * process every node whose path matches rule "R1" (a table scan, anything in between,
     * then a reduce sink immediately followed by a join).
     *
     * <p>The walk is skipped entirely unless the execution engine is "tez" and
     * {@code TEZ_DYNAMIC_PARTITION_PRUNING} is switched on.
     *
     * @param pctx the parse context holding the configuration and the top operators
     * @return the same {@code pctx} instance that was passed in
     * @throws SemanticException if the graph walk fails
     */
    @Override
    public ParseContext transform(ParseContext pctx) throws SemanticException {
        HiveConf conf = pctx.getConf();
        if (!conf.getVar(HiveConf.ConfVars.HIVE_EXECUTION_ENGINE).equals("tez")) {
            return pctx;
        }
        if (!conf.getBoolVar(HiveConf.ConfVars.TEZ_DYNAMIC_PARTITION_PRUNING)) {
            return pctx;
        }

        // Path pattern: table scan ... reduce sink -> join.
        String tableScan = TableScanOperator.getOperatorName();
        String reduceSink = ReduceSinkOperator.getOperatorName();
        String joinName = JoinOperator.getOperatorName();
        String pattern = "(" + tableScan + "%.*" + reduceSink + "%" + joinName + "%)";

        LinkedHashMap<Rule, NodeProcessor> rules = new LinkedHashMap<Rule, NodeProcessor>();
        rules.put(new RuleRegExp("R1", pattern), new JoinSynthetic());

        // No default processor is registered: only nodes matching R1 are handled.
        SyntheticContext walkContext = new SyntheticContext(pctx);
        PreOrderWalker walker = new PreOrderWalker(new DefaultRuleDispatcher(null, rules, walkContext));

        List<Node> startNodes = new ArrayList<Node>(pctx.getTopOps().values());
        walker.startWalking(startNodes, null);
        return pctx;
    }

    /**
     * Creates a filter operator for {@code filterExpr} and splices it into the operator
     * graph between {@code parent} and {@code target}.
     *
     * @param target     operator that becomes the only child of the new filter
     * @param parent     operator that becomes the only parent of the new filter
     * @param parentRR   row resolver whose column infos define the filter's row schema
     * @param filterExpr predicate evaluated by the filter
     * @return the newly created, fully linked filter operator
     */
    private static Operator<FilterDesc> createFilter(Operator<?> target, Operator<?> parent, RowResolver parentRR, ExprNodeDesc filterExpr) {
        FilterDesc filterDesc = new FilterDesc(filterExpr, false);
        RowSchema filterSchema = new RowSchema(parentRR.getColumnInfos());
        Operator<FilterDesc> filter = OperatorFactory.get(filterDesc, filterSchema, new Operator[0]);

        // Link the filter to its single parent and single child.
        List<Operator<? extends OperatorDesc>> filterParents = new ArrayList<Operator<? extends OperatorDesc>>();
        filterParents.add(parent);
        filter.setParentOperators(filterParents);

        List<Operator<? extends OperatorDesc>> filterChildren = new ArrayList<Operator<? extends OperatorDesc>>();
        filterChildren.add(target);
        filter.setChildOperators(filterChildren);

        // Re-point the existing parent -> target edge so that it runs through the filter.
        parent.replaceChild(target, filter);
        target.replaceParent(parent, filter);
        return filter;
    }

    /**
     * Minimal directed graph over the positions {@code 0 .. length-1}. Each position keeps
     * a lazily created set of direct successors; {@link #traverse(int)} reports everything
     * reachable from a position by following those edges.
     */
    private static class Vectors {
        /** Direct successors per position; an entry stays {@code null} until an edge is added. */
        private final Set<Integer>[] vector;

        /**
         * @param length number of positions in the graph
         */
        @SuppressWarnings("unchecked") // generic array creation; only Set<Integer> is ever stored
        public Vectors(int length) {
            this.vector = new Set[length];
        }

        /**
         * Records a directed edge {@code from -> to}.
         */
        public void add(int from, int to) {
            Set<Integer> successors = this.vector[from];
            if (successors == null) {
                successors = new HashSet<Integer>();
                this.vector[from] = successors;
            }
            successors.add(to);
        }

        /**
         * Returns every position reachable from {@code pos} via one or more edges.
         * {@code pos} itself is included only when a cycle leads back to it.
         */
        public int[] traverse(int pos) {
            Set<Integer> reached = new HashSet<Integer>();
            this.collectReachable(reached, pos);
            return this.unbox(reached);
        }

        /** Copies the set into a primitive array, in the set's iteration order. */
        private int[] unbox(Set<Integer> values) {
            int[] flat = new int[values.size()];
            int next = 0;
            for (Integer value : values) {
                flat[next] = value.intValue();
                next++;
            }
            return flat;
        }

        /** Depth-first walk; {@code reached} doubles as the visited set, so cycles terminate. */
        private void collectReachable(Set<Integer> reached, int pos) {
            Set<Integer> successors = this.vector[pos];
            if (successors != null) {
                for (Integer successor : successors) {
                    if (reached.add(successor)) {
                        this.collectReachable(reached, successor.intValue());
                    }
                }
            }
        }
    }

    /**
     * Node processor invoked for a join operator reached through one of its reduce-sink
     * parents. For every other join input that is allowed to restrict this one, it inserts
     * a filter above the reduce sink of the form
     * {@code key_0 IN (dynamic list) AND key_1 IN (dynamic list) ...}, where each dynamic
     * list refers to the matching key column of the other input's reduce sink.
     */
    private static class JoinSynthetic
    implements NodeProcessor {
        private JoinSynthetic() {
        }

        /**
         * Generates the synthetic filters for the (reduce sink, join) pair on top of the
         * walker stack.
         *
         * @param nd          the join operator being visited
         * @param stack       walker stack; the element just below the top is the reduce
         *                    sink through which the join was reached
         * @param procCtx     must be a {@link SyntheticContext}
         * @param nodeOutputs unused
         * @return always {@code null}
         * @throws SemanticException if building one of the predicate expressions fails
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            ParseContext pCtx = ((SyntheticContext)procCtx).getParseContext();
            @SuppressWarnings("unchecked")
            CommonJoinOperator<JoinDesc> join = (CommonJoinOperator<JoinDesc>)nd;

            // Nothing is generated for joins with a null-safe key, so bail out before
            // touching the rest of the plan.
            if (JoinSynthetic.hasNullSafeKey(join.getConf())) {
                return null;
            }

            ReduceSinkOperator source = (ReduceSinkOperator)stack.get(stack.size() - 2);
            List<ExprNodeDesc> sourceKeys = source.getConf().getKeyCols();
            if (sourceKeys.isEmpty()) {
                // No join keys on this input: there is nothing to build a predicate from.
                return null;
            }

            List<Operator<? extends OperatorDesc>> parents = join.getParentOperators();
            int srcPos = parents.indexOf(source);
            int[][] targets = this.getTargets(join);

            Operator<? extends OperatorDesc> parent = source.getParentOperators().get(0);
            RowResolver parentRR = pCtx.getOpParseCtx().get(parent).getRowResolver();

            for (int targetPos : targets[srcPos]) {
                if (targetPos == srcPos) {
                    // Reachable from itself through a cycle; never filter an input by itself.
                    continue;
                }
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Synthetic predicate: " + srcPos + " --> " + targetPos);
                }
                ReduceSinkOperator target = (ReduceSinkOperator)parents.get(targetPos);
                ExprNodeDesc predicate = JoinSynthetic.buildPredicate(sourceKeys, target);

                Operator<FilterDesc> newFilter = SyntheticJoinPredicate.createFilter(source, parent, parentRR, predicate);
                pCtx.getOpParseCtx().put(newFilter, new OpParseContext(parentRR));
                // The next filter (if any) is stacked between this one and the reduce sink.
                parent = newFilter;
            }
            return null;
        }

        /** Returns {@code true} if at least one join key is flagged as null-safe. */
        private static boolean hasNullSafeKey(JoinDesc desc) {
            boolean[] nullSafes = desc.getNullSafes();
            if (nullSafes == null) {
                return false;
            }
            for (boolean nullSafe : nullSafes) {
                if (nullSafe) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Builds {@code sourceKey_i IN (dynamic list of target key i)} for every source key
         * and chains the terms with {@code and}, left to right.
         * Assumes the target has at least as many key columns as the source.
         */
        private static ExprNodeDesc buildPredicate(List<ExprNodeDesc> sourceKeys, ReduceSinkOperator target) throws SemanticException {
            List<ExprNodeDesc> targetKeys = target.getConf().getKeyCols();
            ExprNodeDesc predicate = null;
            for (int i = 0; i < sourceKeys.size(); ++i) {
                List<ExprNodeDesc> inArgs = new ArrayList<ExprNodeDesc>();
                inArgs.add(sourceKeys.get(i));
                inArgs.add(new ExprNodeDynamicListDesc(targetKeys.get(i).getTypeInfo(), target, i));
                ExprNodeDesc inExpr = ExprNodeGenericFuncDesc.newInstance(FunctionRegistry.getFunctionInfo("in").getGenericUDF(), inArgs);

                if (predicate == null) {
                    predicate = inExpr;
                } else {
                    List<ExprNodeDesc> andArgs = new ArrayList<ExprNodeDesc>();
                    andArgs.add(predicate);
                    andArgs.add(inExpr);
                    predicate = ExprNodeGenericFuncDesc.newInstance(FunctionRegistry.getFunctionInfo("and").getGenericUDF(), andArgs);
                }
            }
            return predicate;
        }

        /**
         * Computes, for every join input position, the positions whose keys may be used to
         * filter it. An edge {@code a -> b} means "input a may be filtered by input b";
         * edges are followed transitively.
         *
         * <ul>
         *   <li>inner and left-semi joins: both directions</li>
         *   <li>left outer join: only the right input may be filtered by the left</li>
         *   <li>right outer join: only the left input may be filtered by the right</li>
         *   <li>any other join type: no edge</li>
         * </ul>
         *
         * @return {@code result[pos]} holds the target positions for input {@code pos}
         */
        private int[][] getTargets(CommonJoinOperator<JoinDesc> join) {
            JoinCondDesc[] conds = join.getConf().getConds();
            int aliases = conds.length + 1;
            Vectors vector = new Vectors(aliases);
            for (JoinCondDesc cond : conds) {
                int left = cond.getLeft();
                int right = cond.getRight();
                switch (cond.getType()) {
                    case JoinDesc.INNER_JOIN:
                    case JoinDesc.LEFT_SEMI_JOIN:
                        vector.add(left, right);
                        vector.add(right, left);
                        break;
                    case JoinDesc.LEFT_OUTER_JOIN:
                        vector.add(right, left);
                        break;
                    case JoinDesc.RIGHT_OUTER_JOIN:
                        vector.add(left, right);
                        break;
                    default:
                        break;
                }
            }
            int[][] result = new int[aliases][];
            for (int pos = 0; pos < aliases; ++pos) {
                result[pos] = vector.traverse(pos);
            }
            return result;
        }
    }

    /**
     * Walker context that simply carries the {@link ParseContext} to the node processors.
     * The reference is fixed at construction time.
     */
    private static class SyntheticContext
    implements NodeProcessorCtx {
        private final ParseContext parseContext;

        /**
         * @param pCtx parse context to hand to the node processors
         */
        public SyntheticContext(ParseContext pCtx) {
            this.parseContext = pCtx;
        }

        /**
         * @return the parse context supplied at construction
         */
        public ParseContext getParseContext() {
            return this.parseContext;
        }
    }
}

