Posted to commits@pig.apache.org by zl...@apache.org on 2017/03/31 08:51:35 UTC

svn commit: r1789631 - in /pig/branches/spark: src/org/apache/pig/backend/hadoop/executionengine/physicalLayer/relationalOperators/ src/org/apache/pig/backend/hadoop/executionengine/spark/ src/org/apache/pig/backend/hadoop/executionengine/spark/convert...

Author: zly
Date: Fri Mar 31 08:51:35 2017
New Revision: 1789631

URL: http://svn.apache.org/viewvc?rev=1789631&view=rev
Log:
PIG-4858: Implement Skewed join for spark engine (Xianda via Liyun)

Modified:
    pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/physicalLayer/relationalOperators/POPoissonSample.java
    pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/JobGraphBuilder.java
    pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SkewedJoinConverter.java
    pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/operator/POPoissonSampleSpark.java
    pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/plan/SparkCompiler.java
    pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/plan/SparkOperator.java
    pig/branches/spark/test/org/apache/pig/test/TestSkewedJoin.java
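
At a high level this patch compiles a skewed join into two Spark jobs: a sampling job whose key-distribution output is broadcast, and the join job itself, which tags each record with a partition id derived from that distribution and joins the two RDDs with a custom Partitioner, so records of a skewed key are spread over several partitions instead of piling onto one reducer. A rough, self-contained Java sketch of that join shape (class and variable names here are invented for illustration, not the classes added by the patch) could look like:

    import java.io.Serializable;

    import org.apache.spark.Partitioner;
    import org.apache.spark.api.java.JavaPairRDD;
    import scala.Tuple2;

    public class SkewedJoinShapeSketch {
        // Key that may carry a partition id assigned from the sampled key distribution.
        // Equality ignores the id, so replicated streamed records still match.
        static class PidKey implements Serializable {
            final Object key;
            final int pid;            // -1 means "not a skewed key, fall back to hashing"
            PidKey(Object key, int pid) { this.key = key; this.pid = pid; }
            @Override public int hashCode() { return key.hashCode(); }
            @Override public boolean equals(Object o) {
                return o instanceof PidKey && key.equals(((PidKey) o).key);
            }
        }

        // Partitioner that honors a precomputed partition id and hashes otherwise.
        static class PidPartitioner extends Partitioner {
            private final int numPartitions;
            PidPartitioner(int numPartitions) { this.numPartitions = numPartitions; }
            @Override public int numPartitions() { return numPartitions; }
            @Override public int getPartition(Object k) {
                PidKey pk = (PidKey) k;
                if (pk.pid >= 0) return pk.pid;
                int code = pk.key.hashCode() % numPartitions;
                return code >= 0 ? code : code + numPartitions;
            }
        }

        // Join the partitioned (skewed) side with the streamed side using the
        // id-aware partitioner; to be called from a Spark driver.
        static <V, W> JavaPairRDD<PidKey, Tuple2<V, W>> join(JavaPairRDD<PidKey, V> skewed,
                JavaPairRDD<PidKey, W> streamed, int parallelism) {
            return skewed.join(streamed, new PidPartitioner(parallelism));
        }
    }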

Modified: pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/physicalLayer/relationalOperators/POPoissonSample.java
URL: http://svn.apache.org/viewvc/pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/physicalLayer/relationalOperators/POPoissonSample.java?rev=1789631&r1=1789630&r2=1789631&view=diff
==============================================================================
--- pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/physicalLayer/relationalOperators/POPoissonSample.java (original)
+++ pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/physicalLayer/relationalOperators/POPoissonSample.java Fri Mar 31 08:51:35 2017
@@ -36,38 +36,38 @@ public class POPoissonSample extends Phy
     // minimum number of samples) and the confidence set to 95%
     public static final int DEFAULT_SAMPLE_RATE = 17;
 
-    private int sampleRate = 0;
+    protected int sampleRate = 0;
 
-    private float heapPerc = 0f;
+    protected float heapPerc = 0f;
 
-    private Long totalMemory;
+    protected Long totalMemory;
 
-    private transient boolean initialized;
+    protected transient boolean initialized;
 
     // num of rows skipped so far
-    private transient int numSkipped;
+    protected transient int numSkipped;
 
     // num of rows sampled so far
-    private transient int numRowsSampled;
+    protected transient int numRowsSampled;
 
     // average size of tuple in memory, for tuples sampled
-    private transient long avgTupleMemSz;
+    protected transient long avgTupleMemSz;
 
     // current row number
-    private transient long rowNum;
+    protected transient long rowNum;
 
     // number of tuples to skip after each sample
-    private transient long skipInterval;
+    protected transient long skipInterval;
 
     // bytes in input to skip after every sample.
     // divide this by avgTupleMemSize to get skipInterval
-    private transient long memToSkipPerSample;
+    protected transient long memToSkipPerSample;
 
     // has the special row with row number information been returned
-    private transient boolean numRowSplTupleReturned;
+    protected transient boolean numRowSplTupleReturned;
 
     // new Sample result
-    private transient Result newSample;
+    protected transient Result newSample;
 
     public POPoissonSample(OperatorKey k, int rp, int sr, float hp, long tm) {
         super(k, rp, null);
@@ -204,7 +204,7 @@ public class POPoissonSample extends Phy
      * and recalculate skipInterval
      * @param t - tuple
      */
-    private void updateSkipInterval(Tuple t) {
+    protected void updateSkipInterval(Tuple t) {
         avgTupleMemSz =
             ((avgTupleMemSz*numRowsSampled) + t.getMemorySize())/(numRowsSampled + 1);
         skipInterval = memToSkipPerSample/avgTupleMemSz;
@@ -224,7 +224,7 @@ public class POPoissonSample extends Phy
      * @return - Tuple appended with special marker string column, num-rows column
      * @throws ExecException
      */
-    private Result createNumRowTuple(Tuple sample) throws ExecException {
+    protected Result createNumRowTuple(Tuple sample) throws ExecException {
         int sz = (sample == null) ? 0 : sample.size();
         Tuple t = mTupleFactory.newTuple(sz + 2);
 

Modified: pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/JobGraphBuilder.java
URL: http://svn.apache.org/viewvc/pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/JobGraphBuilder.java?rev=1789631&r1=1789630&r2=1789631&view=diff
==============================================================================
--- pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/JobGraphBuilder.java (original)
+++ pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/JobGraphBuilder.java Fri Mar 31 08:51:35 2017
@@ -41,6 +41,8 @@ import org.apache.pig.backend.hadoop.exe
 import org.apache.pig.backend.hadoop.executionengine.mapReduceLayer.PhyPlanSetter;
 import org.apache.pig.backend.hadoop.executionengine.mapReduceLayer.UDFFinishVisitor;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.PhysicalOperator;
+import org.apache.pig.backend.hadoop.executionengine.physicalLayer.expressionOperators.ConstantExpression;
+import org.apache.pig.backend.hadoop.executionengine.physicalLayer.plans.PhyPlanVisitor;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.plans.PhysicalPlan;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POBroadcastSpark;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POMergeJoin;
@@ -49,14 +51,17 @@ import org.apache.pig.backend.hadoop.exe
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.util.PlanHelper;
 import org.apache.pig.backend.hadoop.executionengine.spark.converter.FRJoinConverter;
 import org.apache.pig.backend.hadoop.executionengine.spark.converter.RDDConverter;
+import org.apache.pig.backend.hadoop.executionengine.spark.converter.SkewedJoinConverter;
 import org.apache.pig.backend.hadoop.executionengine.spark.operator.NativeSparkOperator;
 import org.apache.pig.backend.hadoop.executionengine.spark.operator.POJoinGroupSpark;
+import org.apache.pig.backend.hadoop.executionengine.spark.operator.POPoissonSampleSpark;
 import org.apache.pig.backend.hadoop.executionengine.spark.plan.SparkOpPlanVisitor;
 import org.apache.pig.backend.hadoop.executionengine.spark.plan.SparkOperPlan;
 import org.apache.pig.backend.hadoop.executionengine.spark.plan.SparkOperator;
 import org.apache.pig.data.Tuple;
 import org.apache.pig.impl.PigContext;
 import org.apache.pig.impl.plan.DependencyOrderWalker;
+import org.apache.pig.impl.plan.DepthFirstWalker;
 import org.apache.pig.impl.plan.OperatorKey;
 import org.apache.pig.impl.plan.VisitorException;
 import org.apache.pig.newplan.logical.relational.LOJoin;
@@ -285,6 +290,11 @@ public class JobGraphBuilder extends Spa
                 setReplicatedInputs(physicalOperator, (FRJoinConverter) converter);
             }
 
+            if (sparkOperator.isSkewedJoin() && converter instanceof SkewedJoinConverter) {
+                SkewedJoinConverter skewedJoinConverter = (SkewedJoinConverter) converter;
+                skewedJoinConverter.setSkewedJoinPartitionFile(sparkOperator.getSkewedJoinPartitionFile());
+            }
+            adjustRuntimeParallelismForSkewedJoin(physicalOperator, sparkOperator, allPredRDDs);
             nextRDD = converter.convert(allPredRDDs, physicalOperator);
 
             if (nextRDD == null) {
@@ -373,4 +383,71 @@ public class JobGraphBuilder extends Spa
         seenJobIDs.addAll(unseenJobIDs);
         return unseenJobIDs;
     }
+
+
+    /**
+     * If the parallelism of the skewed join is NOT specified by the user in the script,
+     * set a default parallelism for the sampling job
+     *
+     * @param physicalOperator
+     * @param sparkOperator
+     * @param allPredRDDs
+     * @throws VisitorException
+     */
+    private void adjustRuntimeParallelismForSkewedJoin(PhysicalOperator physicalOperator,
+                                                       SparkOperator sparkOperator,
+                                                       List<RDD<Tuple>> allPredRDDs) throws VisitorException {
+        // We need to calculate the final number of reducers of the next job (skew-join)
+        // adjust parallelism of ConstantExpression
+        if (sparkOperator.isSampler() && sparkPlan.getSuccessors(sparkOperator) != null
+                && physicalOperator instanceof POPoissonSampleSpark) {
+            // set the runtime #reducer of the next job as the #partition
+
+            int defaultParallelism = SparkUtil.getParallelism(allPredRDDs, physicalOperator);
+
+            ParallelConstantVisitor visitor =
+                    new ParallelConstantVisitor(sparkOperator.physicalPlan, defaultParallelism);
+            visitor.visit();
+        }
+    }
+
+    /**
+     * Here we don't reuse MR/Tez's ParallelConstantVisitor.
+     * To automatically adjust the reducer parallelism for skewed join, we only adjust the
+     * ConstantExpression operator that follows the POPoissonSampleSpark operator
+     */
+    private static class ParallelConstantVisitor extends PhyPlanVisitor {
+
+        private int rp;
+        private boolean replaced = false;
+        private boolean isAfterSampleOperator = false;
+
+        public ParallelConstantVisitor(PhysicalPlan plan, int rp) {
+            super(plan, new DepthFirstWalker<PhysicalOperator, PhysicalPlan>(
+                    plan));
+            this.rp = rp;
+        }
+
+        @Override
+        public void visitConstant(ConstantExpression cnst) throws VisitorException {
+            if (isAfterSampleOperator && cnst.getRequestedParallelism() == -1) {
+                Object obj = cnst.getValue();
+                if (obj instanceof Integer) {
+                    if (replaced) {
+                        // sample job should have only one ConstantExpression
+                        throw new VisitorException("Invalid reduce plan: more " +
+                                "than one ConstantExpression found in sampling job");
+                    }
+                    cnst.setValue(rp);
+                    cnst.setRequestedParallelism(rp);
+                    replaced = true;
+                }
+            }
+        }
+
+        @Override
+        public void visitPoissonSampleSpark(POPoissonSampleSpark po) {
+            isAfterSampleOperator = true;
+        }
+    }
 }
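
ParallelConstantVisitor above rewrites the single integer ConstantExpression that follows POPoissonSampleSpark in the sampling plan, presumably so that the key distribution it produces (read back later via PartitionSkewedKeys.TOTAL_REDUCERS in the converter) reflects the join's actual runtime parallelism. A toy, runnable analogue of that rewrite, with invented types rather than Pig's plan classes:

    import java.util.Arrays;
    import java.util.List;

    public class ParallelConstantSketch {
        // toy stand-in for a physical operator: either a marker ("sampler") or an integer constant
        static class Op {
            final String name;
            Integer value;                    // only set for constants
            int requestedParallelism = -1;
            Op(String name, Integer value) { this.name = name; this.value = value; }
        }

        // after the sampler has been visited, overwrite the first untouched integer constant
        static void adjust(List<Op> visitOrder, int runtimeParallelism) {
            boolean afterSampler = false;
            boolean replaced = false;
            for (Op op : visitOrder) {
                if (op.name.equals("POPoissonSampleSpark")) {
                    afterSampler = true;
                } else if (afterSampler && op.value != null && op.requestedParallelism == -1) {
                    if (replaced) {
                        throw new IllegalStateException("more than one constant in sampling job");
                    }
                    op.value = runtimeParallelism;
                    op.requestedParallelism = runtimeParallelism;
                    replaced = true;
                }
            }
        }

        public static void main(String[] args) {
            List<Op> plan = Arrays.asList(
                    new Op("POPoissonSampleSpark", null),
                    new Op("ConstantExpression", -1));
            adjust(plan, 8);
            System.out.println(plan.get(1).value);   // prints 8
        }
    }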

Modified: pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SkewedJoinConverter.java
URL: http://svn.apache.org/viewvc/pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SkewedJoinConverter.java?rev=1789631&r1=1789630&r2=1789631&view=diff
==============================================================================
--- pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SkewedJoinConverter.java (original)
+++ pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SkewedJoinConverter.java Fri Mar 31 08:51:35 2017
@@ -19,9 +19,20 @@ package org.apache.pig.backend.hadoop.ex
 
 import java.io.IOException;
 import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
+import java.util.HashMap;
 
+import com.google.common.collect.Maps;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.pig.data.DataBag;
+import org.apache.pig.impl.builtin.PartitionSkewedKeys;
+import org.apache.pig.impl.util.Pair;
+import org.apache.spark.Partitioner;
+import org.apache.spark.broadcast.Broadcast;
 import scala.Tuple2;
 import scala.runtime.AbstractFunction1;
 
@@ -47,9 +58,18 @@ import org.apache.spark.rdd.RDD;
 public class SkewedJoinConverter implements
         RDDConverter<Tuple, Tuple, POSkewedJoin>, Serializable {
 
+    private static Log log = LogFactory.getLog(SkewedJoinConverter.class);
+
     private POLocalRearrange[] LRs;
     private POSkewedJoin poSkewedJoin;
 
+    private String skewedJoinPartitionFile;
+
+    public void setSkewedJoinPartitionFile(String partitionFile) {
+        skewedJoinPartitionFile = partitionFile;
+    }
+
+
     @Override
     public RDD<Tuple> convert(List<RDD<Tuple>> predecessors,
                               POSkewedJoin poSkewedJoin) throws IOException {
@@ -64,28 +84,33 @@ public class SkewedJoinConverter impleme
         RDD<Tuple> rdd1 = predecessors.get(0);
         RDD<Tuple> rdd2 = predecessors.get(1);
 
-        // make (key, value) pairs, key has type IndexedKey, value has type Tuple
-        RDD<Tuple2<IndexedKey, Tuple>> rdd1Pair = rdd1.map(new ExtractKeyFunction(
-                this, 0), SparkUtil.<IndexedKey, Tuple>getTuple2Manifest());
-        RDD<Tuple2<IndexedKey, Tuple>> rdd2Pair = rdd2.map(new ExtractKeyFunction(
-                this, 1), SparkUtil.<IndexedKey, Tuple>getTuple2Manifest());
-
-        // join fn is present in JavaPairRDD class ..
-        JavaPairRDD<IndexedKey, Tuple> rdd1Pair_javaRDD = new JavaPairRDD<IndexedKey, Tuple>(
-                rdd1Pair, SparkUtil.getManifest(IndexedKey.class),
+        Broadcast<List<Tuple>> keyDist = SparkUtil.getBroadcastedVars().get(skewedJoinPartitionFile);
+
+        // if there is no keyDist, we need defaultParallelism
+        Integer defaultParallelism = SparkUtil.getParallelism(predecessors, poSkewedJoin);
+
+        // with partition id
+        SkewPartitionIndexKeyFunction skewFun = new SkewPartitionIndexKeyFunction(this, keyDist, defaultParallelism);
+        RDD<Tuple2<PartitionIndexedKey, Tuple>> skewIdxKeyRDD = rdd1.map(skewFun,
+                SparkUtil.<PartitionIndexedKey, Tuple>getTuple2Manifest());
+
+        // Tuple2 RDD to Pair RDD
+        JavaPairRDD<PartitionIndexedKey, Tuple> skewIndexedJavaPairRDD = new JavaPairRDD<PartitionIndexedKey, Tuple>(
+                skewIdxKeyRDD, SparkUtil.getManifest(PartitionIndexedKey.class),
                 SparkUtil.getManifest(Tuple.class));
-        JavaPairRDD<IndexedKey, Tuple> rdd2Pair_javaRDD = new JavaPairRDD<IndexedKey, Tuple>(
-                rdd2Pair, SparkUtil.getManifest(IndexedKey.class),
+
+        // with partition id
+        StreamPartitionIndexKeyFunction streamFun = new StreamPartitionIndexKeyFunction(this, keyDist, defaultParallelism);
+        JavaRDD<Tuple2<PartitionIndexedKey, Tuple>> streamIdxKeyJavaRDD = rdd2.toJavaRDD().flatMap(streamFun);
+
+        // Tuple2 RDD to Pair RDD
+        JavaPairRDD<PartitionIndexedKey, Tuple> streamIndexedJavaPairRDD = new JavaPairRDD<PartitionIndexedKey, Tuple>(
+                streamIdxKeyJavaRDD.rdd(), SparkUtil.getManifest(PartitionIndexedKey.class),
                 SparkUtil.getManifest(Tuple.class));
 
-        int parallelism = SparkUtil.getParallelism(predecessors, poSkewedJoin);
-        // join() returns (key, (t1, t2)) where (key, t1) is in this and (key, t2) is in other
-        JavaPairRDD<IndexedKey, Tuple2<Tuple, Tuple>> result_KeyValue = rdd1Pair_javaRDD
-                .join(rdd2Pair_javaRDD, parallelism);
-
-        // map to get JavaRDD<Tuple> from JAvaPairRDD<IndexedKey, Tuple2<Tuple, Tuple>> by
-        // ignoring the key (of type IndexedKey) and appending the values (the
-        // Tuples)
+        JavaPairRDD<PartitionIndexedKey, Tuple2<Tuple, Tuple>> result_KeyValue = skewIndexedJavaPairRDD
+                .join(streamIndexedJavaPairRDD, buildPartitioner(keyDist, defaultParallelism));
+
         JavaRDD<Tuple> result = result_KeyValue
                 .mapPartitions(new ToValueFunction());
 
@@ -115,68 +140,24 @@ public class SkewedJoinConverter impleme
         return new OperatorKey(poSkewedJoin.getOperatorKey().scope, NodeIdGenerator.getGenerator().getNextNodeId(poSkewedJoin.getOperatorKey().scope));
     }
 
-    private static class ExtractKeyFunction extends
-            AbstractFunction1<Tuple, Tuple2<IndexedKey, Tuple>> implements
-            Serializable {
-
-        private final SkewedJoinConverter poSkewedJoin;
-        private final int LR_index; // 0 for left table, 1 for right table
-
-        public ExtractKeyFunction(SkewedJoinConverter poSkewedJoin, int LR_index) {
-            this.poSkewedJoin = poSkewedJoin;
-            this.LR_index = LR_index;
-        }
-
-        @Override
-        public Tuple2<IndexedKey, Tuple> apply(Tuple tuple) {
-
-            // attach tuple to LocalRearrange
-            poSkewedJoin.LRs[LR_index].attachInput(tuple);
-
-            try {
-                // getNextTuple() returns the rearranged tuple
-                Result lrOut = poSkewedJoin.LRs[LR_index].getNextTuple();
-
-                // If tuple is (AA, 5) and key index is $1, then it lrOut is 0 5
-                // (AA), so get(1) returns key
-                Byte index = (Byte)((Tuple) lrOut.result).get(0);
-                Object key = ((Tuple) lrOut.result).get(1);
-                IndexedKey indexedKey = new IndexedKey(index,key);
-                Tuple value = tuple;
-
-                // make a (key, value) pair
-                Tuple2<IndexedKey, Tuple> tuple_KeyValue = new Tuple2<IndexedKey, Tuple>(
-                        indexedKey, value);
-
-                return tuple_KeyValue;
-
-            } catch (Exception e) {
-                System.out.print(e);
-                return null;
-            }
-        }
-
-    }
-
-    private static class ToValueFunction
-            implements
-            FlatMapFunction<Iterator<Tuple2<IndexedKey, Tuple2<Tuple, Tuple>>>, Tuple>, Serializable {
+    private static class ToValueFunction implements
+            FlatMapFunction<Iterator<Tuple2<PartitionIndexedKey, Tuple2<Tuple, Tuple>>>, Tuple>, Serializable {
 
         private class Tuple2TransformIterable implements Iterable<Tuple> {
 
-            Iterator<Tuple2<IndexedKey, Tuple2<Tuple, Tuple>>> in;
+            Iterator<Tuple2<PartitionIndexedKey, Tuple2<Tuple, Tuple>>> in;
 
             Tuple2TransformIterable(
-                    Iterator<Tuple2<IndexedKey, Tuple2<Tuple, Tuple>>> input) {
+                    Iterator<Tuple2<PartitionIndexedKey, Tuple2<Tuple, Tuple>>> input) {
                 in = input;
             }
 
             public Iterator<Tuple> iterator() {
-                return new IteratorTransform<Tuple2<IndexedKey, Tuple2<Tuple, Tuple>>, Tuple>(
+                return new IteratorTransform<Tuple2<PartitionIndexedKey, Tuple2<Tuple, Tuple>>, Tuple>(
                         in) {
                     @Override
                     protected Tuple transform(
-                            Tuple2<IndexedKey, Tuple2<Tuple, Tuple>> next) {
+                            Tuple2<PartitionIndexedKey, Tuple2<Tuple, Tuple>> next) {
                         try {
 
                             Tuple leftTuple = next._2._1;
@@ -194,9 +175,9 @@ public class SkewedJoinConverter impleme
                                 result.set(i + leftTuple.size(),
                                         rightTuple.get(i));
 
-                            System.out.println("MJC: Result = "
-                                    + result.toDelimitedString(" "));
-
+                            if (log.isDebugEnabled()) {
+                                log.debug("MJC: Result = " + result.toDelimitedString(" "));
+                            }
                             return result;
 
                         } catch (Exception e) {
@@ -210,8 +191,326 @@ public class SkewedJoinConverter impleme
 
         @Override
         public Iterable<Tuple> call(
-                Iterator<Tuple2<IndexedKey, Tuple2<Tuple, Tuple>>> input) {
+                Iterator<Tuple2<PartitionIndexedKey, Tuple2<Tuple, Tuple>>> input) {
             return new Tuple2TransformIterable(input);
         }
     }
+
+    /**
+     * Utility function.
+     * 1. Get the parallelism (the total number of reducers, written into totalReducers[0])
+     * 2. Build a key distribution map from the broadcast key distribution file
+     *
+     * @param keyDist
+     * @param totalReducers
+     * @return
+     */
+    private static Map<Tuple, Pair<Integer, Integer>> loadKeyDistribution(Broadcast<List<Tuple>> keyDist,
+                                                                          Integer[] totalReducers) {
+        Map<Tuple, Pair<Integer, Integer>> reducerMap = new HashMap<>();
+        totalReducers[0] = -1; // set a default value
+
+        if (keyDist == null || keyDist.value() == null || keyDist.value().size() == 0) {
+            // this could happen if sampling is empty
+            log.warn("Empty dist file: ");
+            return reducerMap;
+        }
+
+        try {
+            final TupleFactory tf = TupleFactory.getInstance();
+
+            Tuple t = keyDist.value().get(0);
+
+            Map<String, Object> distMap = (Map<String, Object>) t.get(0);
+            DataBag partitionList = (DataBag) distMap.get(PartitionSkewedKeys.PARTITION_LIST);
+
+            totalReducers[0] = Integer.valueOf("" + distMap.get(PartitionSkewedKeys.TOTAL_REDUCERS));
+
+            Iterator<Tuple> it = partitionList.iterator();
+            while (it.hasNext()) {
+                Tuple idxTuple = it.next();
+                Integer maxIndex = (Integer) idxTuple.get(idxTuple.size() - 1);
+                Integer minIndex = (Integer) idxTuple.get(idxTuple.size() - 2);
+                // maxIndex < minIndex means the range wrapped past the last reducer; unwrap it by adding totalReducers
+                if (maxIndex < minIndex) {
+                    maxIndex = totalReducers[0] + maxIndex;
+                }
+
+                // remove the last 2 fields of the tuple, i.e: minIndex and maxIndex and store
+                // it in the reducer map
+                Tuple keyTuple = tf.newTuple();
+                for (int i = 0; i < idxTuple.size() - 2; i++) {
+                    keyTuple.append(idxTuple.get(i));
+                }
+
+                // number of reducers for this key, minus one (minIndex..maxIndex is inclusive)
+                Integer cnt = maxIndex - minIndex;
+                reducerMap.put(keyTuple, new Pair(minIndex, cnt));
+            }
+
+        } catch (ExecException e) {
+            log.warn(e.getMessage());
+        }
+
+        return reducerMap;
+    }
+
+    private static class PartitionIndexedKey extends IndexedKey {
+        // for user defined partitioner
+        int partitionId;
+
+        public PartitionIndexedKey(byte index, Object key) {
+            super(index, key);
+            partitionId = -1;
+        }
+
+        public PartitionIndexedKey(byte index, Object key, int pid) {
+            super(index, key);
+            partitionId = pid;
+        }
+
+        public int getPartitionId() {
+            return partitionId;
+        }
+
+        private void setPartitionId(int pid) {
+            partitionId = pid;
+        }
+
+        @Override
+        public String toString() {
+            return "PartitionIndexedKey{" +
+                    "index=" + getIndex() +
+                    ", partitionId=" + getPartitionId() +
+                    ", key=" + getKey() +
+                    '}';
+        }
+    }
+
+    /**
+     * Append a partition id to the records from the skewed table,
+     * so that the SkewedJoinPartitioner can send skewed records to different reducers.
+     * <p>
+     * see: https://wiki.apache.org/pig/PigSkewedJoinSpec
+     */
+    private static class SkewPartitionIndexKeyFunction extends
+            AbstractFunction1<Tuple, Tuple2<PartitionIndexedKey, Tuple>> implements
+            Serializable {
+
+        private final SkewedJoinConverter poSkewedJoin;
+
+        private final Broadcast<List<Tuple>> keyDist;
+        private final Integer defaultParallelism;
+
+        transient private boolean initialized = false;
+        transient protected Map<Tuple, Pair<Integer, Integer>> reducerMap;
+        transient private Integer parallelism = -1;
+        transient private Map<Tuple, Integer> currentIndexMap;
+
+        public SkewPartitionIndexKeyFunction(SkewedJoinConverter poSkewedJoin,
+                                             Broadcast<List<Tuple>> keyDist,
+                                             Integer defaultParallelism) {
+            this.poSkewedJoin = poSkewedJoin;
+            this.keyDist = keyDist;
+            this.defaultParallelism = defaultParallelism;
+        }
+
+        @Override
+        public Tuple2<PartitionIndexedKey, Tuple> apply(Tuple tuple) {
+            // attach tuple to LocalRearrange
+            poSkewedJoin.LRs[0].attachInput(tuple);
+
+            try {
+                Result lrOut = poSkewedJoin.LRs[0].getNextTuple();
+
+                // If tuple is (AA, 5) and key index is $1, then it lrOut is 0 5
+                // (AA), so get(1) returns key
+                Byte index = (Byte) ((Tuple) lrOut.result).get(0);
+                Object key = ((Tuple) lrOut.result).get(1);
+
+                Tuple keyTuple = (Tuple) key;
+                int partitionId = getPartitionId(keyTuple);
+                PartitionIndexedKey pIndexKey = new PartitionIndexedKey(index, keyTuple, partitionId);
+
+                // make a (key, value) pair
+                Tuple2<PartitionIndexedKey, Tuple> tuple_KeyValue = new Tuple2<PartitionIndexedKey, Tuple>(
+                        pIndexKey,
+                        tuple);
+
+                return tuple_KeyValue;
+            } catch (Exception e) {
+                System.out.print(e);
+                return null;
+            }
+        }
+
+        private Integer getPartitionId(Tuple keyTuple) {
+            if (!initialized) {
+                Integer[] reducers = new Integer[1];
+                reducerMap = loadKeyDistribution(keyDist, reducers);
+                parallelism = reducers[0];
+
+                if (parallelism <= 0) {
+                    parallelism = defaultParallelism;
+                }
+
+                currentIndexMap = Maps.newHashMap();
+
+                initialized = true;
+            }
+
+            // for partition table, compute the index based on the sampler output
+            Pair<Integer, Integer> indexes;
+            Integer curIndex = -1;
+
+            indexes = reducerMap.get(keyTuple);
+
+            // if the reducerMap does not contain the key return -1 so that the
+            // partitioner will do the default hash based partitioning
+            if (indexes == null) {
+                return -1;
+            }
+
+            if (currentIndexMap.containsKey(keyTuple)) {
+                curIndex = currentIndexMap.get(keyTuple);
+            }
+
+            if (curIndex >= (indexes.first + indexes.second) || curIndex == -1) {
+                curIndex = indexes.first;
+            } else {
+                curIndex++;
+            }
+
+            // set it in the map
+            currentIndexMap.put(keyTuple, curIndex);
+            return (curIndex % parallelism);
+        }
+
+    }
+
+    /**
+     * POPartitionRearrange is not used in spark mode for now.
+     * Here, flatMap with this StreamPartitionIndexKeyFunction copies the
+     * streamed records to the multiple reducers that handle their skewed key
+     * <p>
+     * see: https://wiki.apache.org/pig/PigSkewedJoinSpec
+     */
+    private static class StreamPartitionIndexKeyFunction implements FlatMapFunction<Tuple, Tuple2<PartitionIndexedKey, Tuple>> {
+
+        private SkewedJoinConverter poSkewedJoin;
+        private final Broadcast<List<Tuple>> keyDist;
+        private final Integer defaultParallelism;
+
+        private transient boolean initialized = false;
+        protected transient Map<Tuple, Pair<Integer, Integer>> reducerMap;
+        private transient Integer parallelism;
+
+        public StreamPartitionIndexKeyFunction(SkewedJoinConverter poSkewedJoin,
+                                               Broadcast<List<Tuple>> keyDist,
+                                               Integer defaultParallelism) {
+            this.poSkewedJoin = poSkewedJoin;
+            this.keyDist = keyDist;
+            this.defaultParallelism = defaultParallelism;
+        }
+
+        public Iterable<Tuple2<PartitionIndexedKey, Tuple>> call(Tuple tuple) throws Exception {
+            if (!initialized) {
+                Integer[] reducers = new Integer[1];
+                reducerMap = loadKeyDistribution(keyDist, reducers);
+                parallelism = reducers[0];
+                if (parallelism <= 0) {
+                    parallelism = defaultParallelism;
+                }
+                initialized = true;
+            }
+
+            // streamed table
+            poSkewedJoin.LRs[1].attachInput(tuple);
+            Result lrOut = poSkewedJoin.LRs[1].getNextTuple();
+
+            Byte index = (Byte) ((Tuple) lrOut.result).get(0);
+            Tuple key = (Tuple) ((Tuple) lrOut.result).get(1);
+
+            ArrayList<Tuple2<PartitionIndexedKey, Tuple>> l = new ArrayList();
+            Pair<Integer, Integer> indexes = reducerMap.get(key);
+
+            // For non skewed keys, we set the partition index to be -1
+            // so that the partitioner will do the default hash based partitioning
+            if (indexes == null) {
+                indexes = new Pair<>(-1, 0);
+            }
+
+            for (Integer reducerIdx = indexes.first, cnt = 0; cnt <= indexes.second; reducerIdx++, cnt++) {
+                if (reducerIdx >= parallelism) {
+                    reducerIdx = 0;
+                }
+
+                // set the partition index
+                int partitionId = reducerIdx.intValue();
+                PartitionIndexedKey pIndexKey = new PartitionIndexedKey(index, key, partitionId);
+
+                l.add(new Tuple2(pIndexKey, tuple));
+            }
+
+            return l;
+        }
+    }
+
+    /**
+     * user defined spark partitioner for skewed join
+     */
+    private static class SkewedJoinPartitioner extends Partitioner {
+        private int numPartitions;
+
+        public SkewedJoinPartitioner(int parallelism) {
+            numPartitions = parallelism;
+        }
+
+        @Override
+        public int numPartitions() {
+            return numPartitions;
+        }
+
+        @Override
+        public int getPartition(Object IdxKey) {
+            if (IdxKey instanceof PartitionIndexedKey) {
+                int partitionId = ((PartitionIndexedKey) IdxKey).getPartitionId();
+                if (partitionId >= 0) {
+                    return partitionId;
+                }
+            }
+
+            //else: by default using hashcode
+            Tuple key = (Tuple) ((PartitionIndexedKey) IdxKey).getKey();
+
+
+            int code = key.hashCode() % numPartitions;
+            if (code >= 0) {
+                return code;
+            } else {
+                return code + numPartitions;
+            }
+        }
+    }
+
+    /**
+     * Use the parallelism from keyDist, or the default parallelism, to
+     * create the user-defined partitioner
+     *
+     * @param keyDist
+     * @param defaultParallelism
+     * @return
+     */
+    private SkewedJoinPartitioner buildPartitioner(Broadcast<List<Tuple>> keyDist, Integer defaultParallelism) {
+        Integer parallelism = -1;
+        Integer[] reducers = new Integer[1];
+        loadKeyDistribution(keyDist, reducers);
+        parallelism = reducers[0];
+        if (parallelism <= 0) {
+            parallelism = defaultParallelism;
+        }
+
+        return new SkewedJoinPartitioner(parallelism);
+    }
+
 }
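
To make the distribution handling above concrete, take toy numbers: parallelism 4 and a sampled entry for key "hot" with minIndex=3 and maxIndex=1. Since maxIndex < minIndex the range wrapped around, so loadKeyDistribution unwraps it to maxIndex=5 and stores the pair (3, 2). Records of "hot" from the partitioned (left) table then cycle over partitions 3, 0, 1, while each "hot" record from the streamed (right) table is copied to all three partitions, so every partition still sees all matching pairs. A small, runnable sketch of just that assignment arithmetic (plain maps and invented names, not the patch's classes):

    import java.util.HashMap;
    import java.util.Map;

    public class SkewAssignmentSketch {
        public static void main(String[] args) {
            int parallelism = 4;

            // key -> {minIndex, cnt} as produced by loadKeyDistribution:
            // "hot" was sampled as skewed and owns reducers 3, 0 and 1.
            Map<String, int[]> reducerMap = new HashMap<>();
            reducerMap.put("hot", new int[]{3, 2});

            // partitioned (left) table: successive "hot" records cycle over 3, 0, 1, 3, 0
            Map<String, Integer> currentIndex = new HashMap<>();
            for (int i = 0; i < 5; i++) {
                int[] idx = reducerMap.get("hot");
                int cur = currentIndex.getOrDefault("hot", -1);
                cur = (cur == -1 || cur >= idx[0] + idx[1]) ? idx[0] : cur + 1;
                currentIndex.put("hot", cur);
                System.out.println("left  record " + i + " -> partition " + (cur % parallelism));
            }

            // streamed (right) table: one "hot" record is copied to partitions 3, 0 and 1
            int[] idx = reducerMap.get("hot");
            for (int reducerIdx = idx[0], cnt = 0; cnt <= idx[1]; reducerIdx++, cnt++) {
                if (reducerIdx >= parallelism) {
                    reducerIdx = 0;
                }
                System.out.println("right record copy -> partition " + reducerIdx);
            }
        }
    }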

Modified: pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/operator/POPoissonSampleSpark.java
URL: http://svn.apache.org/viewvc/pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/operator/POPoissonSampleSpark.java?rev=1789631&r1=1789630&r2=1789631&view=diff
==============================================================================
--- pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/operator/POPoissonSampleSpark.java (original)
+++ pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/operator/POPoissonSampleSpark.java Fri Mar 31 08:51:35 2017
@@ -24,52 +24,16 @@ import org.apache.pig.backend.hadoop.exe
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.PhysicalOperator;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.plans.PhyPlanVisitor;
+import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POPoissonSample;
 import org.apache.pig.data.Tuple;
 import org.apache.pig.impl.builtin.PoissonSampleLoader;
 import org.apache.pig.impl.plan.OperatorKey;
 import org.apache.pig.impl.plan.VisitorException;
 
-public class POPoissonSampleSpark extends PhysicalOperator {
+public class POPoissonSampleSpark extends POPoissonSample {
     private static final Log LOG = LogFactory.getLog(POPoissonSampleSpark.class);
     private static final long serialVersionUID = 1L;
 
-    // 17 is not a magic number. It can be obtained by using a poisson
-    // cumulative distribution function with the mean set to 10 (empirically,
-    // minimum number of samples) and the confidence set to 95%
-    public static final int DEFAULT_SAMPLE_RATE = 17;
-
-    private int sampleRate = 0;
-
-    private float heapPerc = 0f;
-
-    private Long totalMemory;
-
-    private transient boolean initialized;
-
-    // num of rows sampled so far
-    private transient int numRowsSampled;
-
-    // average size of tuple in memory, for tuples sampled
-    private transient long avgTupleMemSz;
-
-    // current row number
-    private transient long rowNum;
-
-    // number of tuples to skip after each sample
-    private transient long skipInterval;
-
-    // number of tuples which have been skipped.
-    private transient long numSkipped = 0;
-
-    // bytes in input to skip after every sample.
-    // divide this by avgTupleMemSize to get skipInterval
-    private transient long memToSkipPerSample;
-
-    // has the special row with row number information been returned
-    private transient boolean numRowSplTupleReturned;
-
-    // new Sample result
-    private transient Result newSample;
 
     // Only for Spark
     private boolean endOfInput = false;
@@ -83,19 +47,9 @@ public class POPoissonSampleSpark extend
     }
 
     public POPoissonSampleSpark(OperatorKey k, int rp, int sr, float hp, long tm) {
-        super(k, rp, null);
-        sampleRate = sr;
-        heapPerc = hp;
-        if (tm != -1) {
-            totalMemory = tm;
-        }
+        super(k, rp, sr, hp, tm);
     }
 
-    @Override
-    public Tuple illustratorMarkup(Object in, Object out, int eqClassIndex) {
-        // Auto-generated method stub
-        return null;
-    }
 
     @Override
     public void visit(PhyPlanVisitor v) throws VisitorException {
@@ -126,9 +80,15 @@ public class POPoissonSampleSpark extend
         Result res;
         res = processInput();
 
-        // if reaches at the end, pick last sampled record and return
-        if (this.isEndOfInput() && newSample != null) {
-            return createNumRowTuple((Tuple)newSample.result);
+        // if the end of input has been reached, pick a record and return
+        if (this.isEndOfInput()) {
+            // if we have skipped enough records and the last record read is OK
+            if ( numSkipped == skipInterval
+                    && res.returnStatus == POStatus.STATUS_OK) {
+                return createNumRowTuple((Tuple) res.result);
+            } else if (newSample != null) {
+                return createNumRowTuple((Tuple) newSample.result);
+            }
         }
 
         // just return to read next record from input
@@ -165,60 +125,7 @@ public class POPoissonSampleSpark extend
     }
 
     @Override
-    public boolean supportsMultipleInputs() {
-        return false;
-    }
-
-    @Override
-    public boolean supportsMultipleOutputs() {
-        return false;
-    }
-
-    @Override
     public String name() {
-        return getAliasString() + "PoissonSample - " + mKey.toString();
-    }
-
-    /**
-     * Update the average tuple size base on newly sampled tuple t
-     * and recalculate skipInterval
-     *
-     * @param t - tuple
-     */
-    private void updateSkipInterval(Tuple t) {
-        avgTupleMemSz =
-                ((avgTupleMemSz * numRowsSampled) + t.getMemorySize()) / (numRowsSampled + 1);
-        skipInterval = memToSkipPerSample / avgTupleMemSz;
-
-        // skipping fewer number of rows the first few times, to reduce the
-        // probability of first tuples size (if much smaller than rest)
-        // resulting in very few samples being sampled. Sampling a little extra
-        // is OK
-        if (numRowsSampled < 5) {
-            skipInterval = skipInterval / (10 - numRowsSampled);
-        }
-
-        ++numRowsSampled;
-    }
-
-    /**
-     * @param sample - sample tuple
-     * @return - Tuple appended with special marker string column, num-rows column
-     * @throws ExecException
-     */
-    private Result createNumRowTuple(Tuple sample) throws ExecException {
-        int sz = (sample == null) ? 0 : sample.size();
-        Tuple t = mTupleFactory.newTuple(sz + 2);
-
-        if (sample != null) {
-            for (int i = 0; i < sample.size(); i++) {
-                t.set(i, sample.get(i));
-            }
-        }
-
-        t.set(sz, PoissonSampleLoader.NUMROWS_TUPLE_MARKER);
-        t.set(sz + 1, rowNum);
-        numRowSplTupleReturned = true;
-        return new Result(POStatus.STATUS_OK, t);
+        return getAliasString() + "PoissonSampleSpark - " + mKey.toString();
     }
 }

Modified: pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/plan/SparkCompiler.java
URL: http://svn.apache.org/viewvc/pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/plan/SparkCompiler.java?rev=1789631&r1=1789630&r2=1789631&view=diff
==============================================================================
--- pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/plan/SparkCompiler.java (original)
+++ pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/plan/SparkCompiler.java Fri Mar 31 08:51:35 2017
@@ -25,8 +25,9 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
-import java.util.Map;
 import java.util.Properties;
+import java.util.Map;
+import java.util.Random;
 import java.util.Set;
 
 import org.apache.commons.logging.Log;
@@ -111,7 +112,7 @@ public class SparkCompiler extends PhyPl
     private static final Log LOG = LogFactory.getLog(SparkCompiler.class);
 
     private PigContext pigContext;
-    private Properties pigProperties;
+	private Properties pigProperties;
 
 	// The physicalPlan that is being compiled
 	private PhysicalPlan physicalPlan;
@@ -138,6 +139,7 @@ public class SparkCompiler extends PhyPl
 						physicalPlan));
 		this.physicalPlan = physicalPlan;
 		this.pigContext = pigContext;
+		this.pigProperties = pigContext.getProperties();
 		this.sparkPlan = new SparkOperPlan();
 		this.phyToSparkOpMap = new HashMap<PhysicalOperator, SparkOperator>();
 		this.udfFinder = new UDFFinder();
@@ -695,7 +697,23 @@ public class SparkCompiler extends PhyPl
     @Override
     public void visitSkewedJoin(POSkewedJoin op) throws VisitorException {
         try {
+            Random r = new Random();
+            String pigKeyDistFile = "pig.keyDistFile" + r.nextInt();
+            // firstly, build sample job
+            SparkOperator sampleSparkOp = getSkewedJoinSampleJob(op);
+
+            buildBroadcastForSkewedJoin(sampleSparkOp, pigKeyDistFile);
+
+            sampleSparkOp.markSampler();
+            sparkPlan.add(sampleSparkOp);
+
+            // secondly, build the join job.
             addToPlan(op);
+            curSparkOp.setSkewedJoinPartitionFile(pigKeyDistFile);
+
+            // do sampling job before join job
+            sparkPlan.connect(sampleSparkOp, curSparkOp);
+
             phyToSparkOpMap.put(op, curSparkOp);
         } catch (Exception e) {
             int errCode = 2034;
@@ -705,50 +723,6 @@ public class SparkCompiler extends PhyPl
         }
     }
 
-//    /**
-//     * currently use regular join to replace skewedJoin
-//     * Skewed join currently works with two-table inner join.
-//     * More info about pig SkewedJoin, See https://wiki.apache.org/pig/PigSkewedJoinSpec
-//     *
-//     * @param op
-//     * @throws VisitorException
-//     */
-//    @Override
-//    public void visitSkewedJoin(POSkewedJoin op) throws VisitorException {
-//        try {
-//			Random r = new Random();
-//			String pigKeyDistFile = "pig.keyDistFile" + r.nextInt();
-//            // firstly, build sample job
-//            SparkOperator sampleSparkOp = getSkewedJoinSampleJob(op);
-//
-//			buildBroadcastForSkewedJoin(sampleSparkOp, pigKeyDistFile);
-//
-//			sampleSparkOp.markSampler();
-//			sparkPlan.add(sampleSparkOp);
-//
-//			// secondly, build the join job.
-//			addToPlan(op);
-//			curSparkOp.setSkewedJoinPartitionFile(pigKeyDistFile);
-//
-//			// do sampling job before join job
-//			sparkPlan.connect(sampleSparkOp, curSparkOp);
-//
-//			phyToSparkOpMap.put(op, curSparkOp);
-//        } catch (Exception e) {
-//            int errCode = 2034;
-//            String msg = "Error compiling operator " +
-//                    op.getClass().getSimpleName();
-//            throw new SparkCompilerException(msg, errCode, PigException.BUG, e);
-//        }
-//    }
-
-/*    private void buildBroadcastForSkewedJoin(SparkOperator sampleSparkOp, String pigKeyDistFile) throws PlanException {
-
-        POBroadcastSpark poBroadcast = new POBroadcastSpark(new OperatorKey(scope, nig.getNextNodeId(scope)));
-        poBroadcast.setBroadcastedVariableName(pigKeyDistFile);
-        sampleSparkOp.physicalPlan.addAsLeaf(poBroadcast);
-    }*/
-
     @Override
     public void visitFRJoin(POFRJoin op) throws VisitorException {
 		try {
@@ -1506,6 +1480,18 @@ public class SparkCompiler extends PhyPl
         throw new PlanException(msg, errCode, PigException.BUG);
     }
 
+	/**
+	 * Add POBroadcastSpark operator to broadcast key distribution for SkewedJoin's sampling job
+	 * @param sampleSparkOp
+	 * @throws PlanException
+	 */
+	private void buildBroadcastForSkewedJoin(SparkOperator sampleSparkOp, String pigKeyDistFile) throws PlanException {
+
+		POBroadcastSpark poBroadcast = new POBroadcastSpark(new OperatorKey(scope, nig.getNextNodeId(scope)));
+		poBroadcast.setBroadcastedVariableName(pigKeyDistFile);
+		sampleSparkOp.physicalPlan.addAsLeaf(poBroadcast);
+	}
+
     /**
      * Create Sampling job for skewed join.
      */
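
The sampling job and the join job are tied together only by the generated name: the POBroadcastSpark leaf added to the sampling job publishes the key distribution under pigKeyDistFile, the join SparkOperator records the same string via setSkewedJoinPartitionFile, and SkewedJoinConverter later looks the broadcast up by that name in SparkUtil.getBroadcastedVars(). A toy illustration of that name-based hand-off, with a plain Java map standing in for the broadcast registry:

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.Random;

    public class KeyDistHandOffSketch {
        // stand-in for SparkUtil.getBroadcastedVars(): broadcast name -> broadcast value
        static final Map<String, List<String>> BROADCASTS = new HashMap<>();

        public static void main(String[] args) {
            String pigKeyDistFile = "pig.keyDistFile" + new Random().nextInt();

            // sampling job: the POBroadcastSpark leaf publishes its output under the name
            BROADCASTS.put(pigKeyDistFile, Arrays.asList("(key distribution tuple)"));

            // join job: the converter retrieves it with the name recorded by
            // setSkewedJoinPartitionFile(pigKeyDistFile)
            List<String> keyDist = BROADCASTS.get(pigKeyDistFile);
            System.out.println(keyDist);
        }
    }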

Modified: pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/plan/SparkOperator.java
URL: http://svn.apache.org/viewvc/pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/plan/SparkOperator.java?rev=1789631&r1=1789630&r2=1789631&view=diff
==============================================================================
--- pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/plan/SparkOperator.java (original)
+++ pig/branches/spark/src/org/apache/pig/backend/hadoop/executionengine/spark/plan/SparkOperator.java Fri Mar 31 08:51:35 2017
@@ -269,6 +269,10 @@ public class SparkOperator extends Opera
         }
     }
 
+	public boolean isSkewedJoin() {
+		return (skewedJoinPartitionFile != null);
+	}
+
     public void setRequestedParallelism(int requestedParallelism) {
         this.requestedParallelism = requestedParallelism;
     }

Modified: pig/branches/spark/test/org/apache/pig/test/TestSkewedJoin.java
URL: http://svn.apache.org/viewvc/pig/branches/spark/test/org/apache/pig/test/TestSkewedJoin.java?rev=1789631&r1=1789630&r2=1789631&view=diff
==============================================================================
--- pig/branches/spark/test/org/apache/pig/test/TestSkewedJoin.java (original)
+++ pig/branches/spark/test/org/apache/pig/test/TestSkewedJoin.java Fri Mar 31 08:51:35 2017
@@ -53,7 +53,6 @@ import org.apache.pig.impl.builtin.Parti
 import org.apache.pig.impl.logicalLayer.FrontendException;
 import org.apache.pig.test.utils.TestHelper;
 import org.junit.AfterClass;
-import org.junit.Assume;
 import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -306,12 +305,6 @@ public class TestSkewedJoin {
 
     @Test
     public void testSkewedJoinKeyPartition() throws IOException {
-        // This test relies on how the keys are distributed in Skew Join implementation.
-        // Spark engine currently implements skew join as regular join, and hence does
-        // not control key distribution.
-        // TODO: Enable this test when Spark engine implements Skew Join algorithm.
-        Assume.assumeTrue("Skip this test for Spark until PIG-4858 is resolved!",!Util.isSparkExecType(cluster.getExecType()));
-
         String outputDir = "testSkewedJoinKeyPartition";
         try{
              Util.deleteFile(cluster, outputDir);