Posted to commits@spark.apache.org by ma...@apache.org on 2014/01/02 21:57:49 UTC

[2/4] git commit: Make Python function/line appear in the UI.

Make Python function/line appear in the UI.
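
The diff below derives a concise "function at file:line" call site from the
Python stack and hands it to the JVM through setCallSite, so the web UI can
attribute a job to the user code that triggered it rather than to PySpark
internals. A rough standalone sketch of the same idea follows (illustrative
only, not the commit's code; the example path in the closing comment is made
up):

    import os
    import traceback

    def concise_call_site():
        # Walk the stack from the outermost frame inward and stop at the first
        # frame that lives in this package's directory; attribute the call to
        # the user frame just above it.
        tb = traceback.extract_stack()
        if not tb:
            return "I'm lost!"
        pkg_dir = os.path.dirname(tb[-1][0])   # directory of the current file
        for i, (fname, line, fun, _) in enumerate(tb):
            if fname.startswith(pkg_dir):
                if i == 0:
                    # No user frames at all; report the package frame itself.
                    return "%s at %s:%d" % (fun, fname, line)
                caller_file, caller_line = tb[i - 1][0], tb[i - 1][1]
                return "%s at %s:%d" % (fun, caller_file, caller_line)
        # Unreachable in practice (the innermost frame is always in pkg_dir),
        # kept only as a defensive fallback.
        return "%s at %s:%d" % (tb[0][2], tb[0][0], tb[0][1])

    # A user script calling rdd.collect() would yield something like
    # "collect at /home/user/job.py:42", which is what the UI displays.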


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/fec01664
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/fec01664
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/fec01664

Branch: refs/heads/master
Commit: fec01664a717c8ecf84eaf7a2523a62cd5d3b4b8
Parents: d812aee
Author: Tor Myklebust <tm...@gmail.com>
Authored: Sat Dec 28 23:34:16 2013 -0500
Committer: Tor Myklebust <tm...@gmail.com>
Committed: Sat Dec 28 23:34:16 2013 -0500

----------------------------------------------------------------------
 python/pyspark/rdd.py | 66 ++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 55 insertions(+), 11 deletions(-)
----------------------------------------------------------------------
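
The _JavaStackTrace context manager added in the diff below keeps a
module-level depth counter so that only the outermost PySpark operation sets
the JVM call site; JVM calls made while one is already in flight leave it
untouched, and it is cleared once the outermost operation returns. A minimal
self-contained model of that guard (FakeContext and CallSiteGuard are made-up
names for illustration; the real code talks to the JavaSparkContext):

    _depth = 0

    class FakeContext(object):
        """Stand-in for the JVM-backed context; records the last call site."""
        def __init__(self):
            self.call_site = None

        def setCallSite(self, site):
            self.call_site = site

    class CallSiteGuard(object):
        def __init__(self, ctx, site):
            self._ctx = ctx
            self._site = site

        def __enter__(self):
            global _depth
            if _depth == 0:                # only the outermost operation wins
                self._ctx.setCallSite(self._site)
            _depth += 1

        def __exit__(self, exc_type, exc_value, tb):
            global _depth
            _depth -= 1
            if _depth == 0:                # cleared when the outermost one ends
                self._ctx.setCallSite(None)

    ctx = FakeContext()
    with CallSiteGuard(ctx, "collect at job.py:42"):
        with CallSiteGuard(ctx, "nested at rdd.py:7"):   # nested: no effect
            assert ctx.call_site == "collect at job.py:42"
    assert ctx.call_site is None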


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/fec01664/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index f87923e..6fb4a7b 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -23,6 +23,7 @@ import operator
 import os
 import sys
 import shlex
+import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -39,6 +40,46 @@ from py4j.java_collections import ListConverter, MapConverter
 
 __all__ = ["RDD"]
 
+def _extract_concise_traceback():
+    tb = traceback.extract_stack()
+    if len(tb) == 0:
+        return "I'm lost!"
+    # HACK:  This function is in a file called 'rdd.py' in the top level of
+    # everything PySpark.  Just trim off the directory name and assume
+    # everything in that tree is PySpark guts.
+    file, line, module, what = tb[len(tb) - 1]
+    sparkpath = os.path.dirname(file)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]
+        return "%s at %s:%d" % (fun, file, line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]
+    ufile, uline, ufun, uwhat = tb[first_spark_frame-1]
+    return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0
+
+class _JavaStackTrace(object):
+    def __init__(self, sc):
+        self._traceback = _extract_concise_traceback()
+        self._context = sc
+
+    def __enter__(self):
+        global _spark_stack_depth
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._traceback)
+        _spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        global _spark_stack_depth
+        _spark_stack_depth -= 1
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)
 
 class RDD(object):
     """
@@ -401,7 +442,8 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        bytesInJava = self._jrdd.collect().iterator()
+        with _JavaStackTrace(self.context) as st:
+            bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
@@ -582,13 +624,14 @@ class RDD(object):
         # TODO(shivaram): Similar to the scala implementation, update the take 
         # method to scan multiple splits based on an estimate of how many elements 
         # we have per-split.
-        for partition in range(mapped._jrdd.splits().size()):
-            partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-            partitionsToTake[0] = partition
-            iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-            items.extend(mapped._collect_iterator_through_file(iterator))
-            if len(items) >= num:
-                break
+        with _JavaStackTrace(self.context) as st:
+            for partition in range(mapped._jrdd.splits().size()):
+                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+                partitionsToTake[0] = partition
+                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+                items.extend(mapped._collect_iterator_through_file(iterator))
+                if len(items) >= num:
+                    break
         return items[:num]
 
     def first(self):
@@ -765,9 +808,10 @@ class RDD(object):
                 yield outputSerializer.dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
-        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
-                                                     id(partitionFunc))
+        with _JavaStackTrace(self.context) as st:
+            pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+            partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+                                                          id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if