/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.correlation;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Stack;
import org.apache.hadoop.hive.ql.exec.CommonMergeJoinOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.ForwardWalker;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.optimizer.correlation.AbstractCorrelationProcCtx;
import org.apache.hadoop.hive.ql.optimizer.correlation.CorrelationUtilities;
import org.apache.hadoop.hive.ql.optimizer.correlation.ReduceSinkDeDuplicationUtils;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Optimizer transform that visits every {@link ReduceSinkOperator} of the plan and tries to
 * merge it with the closest upstream reduce sink(s).
 *
 * <p>When the merge succeeds, the visited reduce sink is flagged as <em>forwarding</em> and the
 * reduce sinks involved are pinned to a {@code FIXED} parallelism equal to the largest reducer
 * count seen among them. Reduce sinks feeding a {@link MapJoinOperator} or a
 * {@link CommonMergeJoinOperator} are handled as a group: all reduce-sink parents of the join
 * take part in the merge and in the parallelism update.
 */
public class ReduceSinkJoinDeDuplication
extends Transform {
    protected static final Logger LOG = LoggerFactory.getLogger(ReduceSinkJoinDeDuplication.class);

    /** Parse context handed to the most recent {@link #transform(ParseContext)} call. */
    protected ParseContext pGraphContext;

    /**
     * Walks the operator graph from its top operators and applies {@link ReducerProc} to every
     * operator whose name matches the reduce-sink rule.
     *
     * @param pctx parse context holding the operator graph to optimize
     * @return the same parse context that was passed in (the graph is updated in place)
     * @throws SemanticException if a node processor or the walker fails
     */
    @Override
    public ParseContext transform(ParseContext pctx) throws SemanticException {
        this.pGraphContext = pctx;
        ReduceSinkJoinDeDuplicateProcCtx cppCtx = new ReduceSinkJoinDeDuplicateProcCtx(this.pGraphContext);

        // Single rule: fire the reducer processor on every reduce sink operator.
        LinkedHashMap<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
        opRules.put(
            new RuleRegExp("R1", ReduceSinkOperator.getOperatorName() + "%"),
            ReduceSinkJoinDeDuplicateProcFactory.getReducerMapJoinProc());

        DefaultRuleDispatcher disp = new DefaultRuleDispatcher(
            ReduceSinkJoinDeDuplicateProcFactory.getDefaultProc(), opRules, cppCtx);
        ForwardWalker ogw = new ForwardWalker(disp);

        List<Node> topNodes = new ArrayList<Node>(this.pGraphContext.getTopOps().values());
        ogw.startWalking(topNodes, null);
        return this.pGraphContext;
    }

    /**
     * Processor fired for each reduce sink. Returns {@code Boolean.TRUE} when the reduce sink was
     * merged with its upstream reduce sink(s) and marked as forwarding, {@code Boolean.FALSE}
     * otherwise.
     */
    static class ReducerProc
    implements NodeProcessor {

        /**
         * Attempts to deduplicate the reduce sink {@code nd} against the first reduce sink found
         * upstream of it.
         *
         * @param nd the current node; must be a {@link ReduceSinkOperator}
         * @param stack walker stack (unused)
         * @param procCtx must be a {@link ReduceSinkJoinDeDuplicateProcCtx}
         * @param nodeOutputs outputs of previously processed nodes (unused)
         * @return {@code true} if {@code nd} was set to forward data, {@code false} otherwise
         * @throws SemanticException propagated from the correlation utilities
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            ReduceSinkJoinDeDuplicateProcCtx dedupCtx = (ReduceSinkJoinDeDuplicateProcCtx) procCtx;
            ReduceSinkOperator cRS = (ReduceSinkOperator) nd;
            ReduceSinkDesc cRSConf = cRS.getConf();

            // Already deduplicated, or nothing to partition/sort on: leave it alone.
            if (cRSConf.isForwarding() || cRSConf.getKeyCols().isEmpty()) {
                return false;
            }

            boolean onlyPartitioning = false;
            Operator<? extends OperatorDesc> cRSChild = cRS.getChildOperators().get(0);
            if (isJoin(cRSChild)) {
                // Every input of the join has to be a reduce sink for the merge to apply.
                if (!hasOnlyReduceSinkParents(cRSChild)) {
                    return false;
                }
                // For a map join the sort-order-preserving parent lookup is not used.
                onlyPartitioning = cRSChild instanceof MapJoinOperator;
            }

            ReduceSinkOperator pRS = onlyPartitioning
                ? CorrelationUtilities.findFirstPossibleParent(cRS, ReduceSinkOperator.class, dedupCtx.trustScript())
                : CorrelationUtilities.findFirstPossibleParentPreserveSortOrder(cRS, ReduceSinkOperator.class, dedupCtx.trustScript());
            if (pRS == null) {
                return false;
            }

            int maxNumReducers = cRSConf.getNumReducers();
            Operator<? extends OperatorDesc> pRSChild = pRS.getChildOperators().get(0);
            boolean merged;
            if (isJoin(pRSChild)) {
                // Upstream map join: both it and the join below cRS must be dynamic partition hash joins.
                if (pRSChild instanceof MapJoinOperator
                        && !bothDynamicPartitionHashJoins((MapJoinOperator) pRSChild, cRSChild)) {
                    return false;
                }
                // Same guard for both join kinds: a non reduce-sink input means "do not merge"
                // rather than a ClassCastException on the cast below.
                if (!hasOnlyReduceSinkParents(pRSChild)) {
                    return false;
                }
                ImmutableList.Builder<ReduceSinkOperator> joinInputs = ImmutableList.builder();
                for (Operator<? extends OperatorDesc> parent : pRSChild.getParentOperators()) {
                    ReduceSinkOperator rsOp = (ReduceSinkOperator) parent;
                    joinInputs.add(rsOp);
                    maxNumReducers = Math.max(maxNumReducers, rsOp.getConf().getNumReducers());
                }
                merged = ReduceSinkDeDuplicationUtils.strictMerge(cRS, joinInputs.build());
            } else {
                maxNumReducers = Math.max(maxNumReducers, pRS.getConf().getNumReducers());
                merged = ReduceSinkDeDuplicationUtils.strictMerge(cRS, pRS);
            }

            if (!merged) {
                return false;
            }
            LOG.debug("Set {} to forward data", cRS);
            cRSConf.setForwarding(true);
            propagateMaxNumReducers(dedupCtx, cRS, maxNumReducers);
            return true;
        }

        /** Tells whether {@code op} is one of the two join operators this processor treats as a group. */
        private static boolean isJoin(Operator<? extends OperatorDesc> op) {
            return op instanceof MapJoinOperator || op instanceof CommonMergeJoinOperator;
        }

        /** Tells whether every parent of {@code op} is a {@link ReduceSinkOperator}. */
        private static boolean hasOnlyReduceSinkParents(Operator<? extends OperatorDesc> op) {
            for (Operator<? extends OperatorDesc> parent : op.getParentOperators()) {
                if (!(parent instanceof ReduceSinkOperator)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Tells whether {@code upstreamJoin} is a dynamic partition hash join and {@code downstream}
         * is a map join that is one as well.
         */
        private static boolean bothDynamicPartitionHashJoins(MapJoinOperator upstreamJoin,
                Operator<? extends OperatorDesc> downstream) {
            return upstreamJoin.getConf().isDynamicPartitionHashJoin()
                && downstream instanceof MapJoinOperator
                && ((MapJoinOperator) downstream).getConf().isDynamicPartitionHashJoin();
        }

        /**
         * Pins {@code rsOp} (or, when it feeds a join, every input of that join) to
         * {@code maxNumReducers} reducers with {@code FIXED} traits, following forwarding reduce
         * sinks upstream recursively.
         *
         * @param dedupCtx processing context, used for the trust-script setting
         * @param rsOp reduce sink to start from; {@code null} ends the propagation
         * @param maxNumReducers reducer count to apply
         * @throws SemanticException propagated from the parent lookup
         */
        private static void propagateMaxNumReducers(ReduceSinkJoinDeDuplicateProcCtx dedupCtx, ReduceSinkOperator rsOp, int maxNumReducers) throws SemanticException {
            if (rsOp == null) {
                return;
            }
            Operator<? extends OperatorDesc> child = rsOp.getChildOperators().get(0);
            if (!isJoin(child)) {
                fixParallelism(dedupCtx, rsOp, maxNumReducers);
                return;
            }
            for (Operator<? extends OperatorDesc> parent : child.getParentOperators()) {
                // NOTE(review): unchecked cast kept from the original; presumably every input of a
                // join reached here is a reduce sink - confirm against the callers' invariants.
                fixParallelism(dedupCtx, (ReduceSinkOperator) parent, maxNumReducers);
            }
        }

        /**
         * Applies the fixed parallelism to a single reduce sink and, if that reduce sink is itself
         * forwarding, continues the propagation from its first possible reduce-sink parent.
         */
        private static void fixParallelism(ReduceSinkJoinDeDuplicateProcCtx dedupCtx, ReduceSinkOperator rsOp, int maxNumReducers) throws SemanticException {
            ReduceSinkDesc conf = rsOp.getConf();
            conf.setReducerTraits(EnumSet.of(ReduceSinkDesc.ReducerTraits.FIXED));
            conf.setNumReducers(maxNumReducers);
            LOG.debug("Set {} to FIXED parallelism: {}", rsOp, maxNumReducers);
            if (conf.isForwarding()) {
                ReduceSinkOperator upstream = CorrelationUtilities.findFirstPossibleParent(rsOp, ReduceSinkOperator.class, dedupCtx.trustScript());
                propagateMaxNumReducers(dedupCtx, upstream, maxNumReducers);
            }
        }
    }

    /** Fallback processor for nodes that match no rule; it does nothing. */
    static class DefaultProc
    implements NodeProcessor {

        /**
         * No-op.
         *
         * @return always {@code null}
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            return null;
        }
    }

    /** Factory for the node processors used by {@link #transform(ParseContext)}. */
    static class ReduceSinkJoinDeDuplicateProcFactory {

        /** @return a new processor to run on reduce sink operators */
        public static NodeProcessor getReducerMapJoinProc() {
            return new ReducerProc();
        }

        /** @return a new no-op processor for all other operators */
        public static NodeProcessor getDefaultProc() {
            return new DefaultProc();
        }
    }

    /** Walker context for this transform; adds nothing to {@link AbstractCorrelationProcCtx}. */
    protected class ReduceSinkJoinDeDuplicateProcCtx
    extends AbstractCorrelationProcCtx {

        /**
         * @param pctx parse context forwarded to the superclass constructor
         */
        public ReduceSinkJoinDeDuplicateProcCtx(ParseContext pctx) {
            super(pctx);
        }
    }
}

