You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2019/01/03 03:00:23 UTC

[spark] branch branch-2.2 updated: [SPARK-25591][PYSPARK][SQL][BRANCH-2.2] Avoid overwriting deserialized accumulator

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-2.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.2 by this push:
     new 6f435e9  [SPARK-25591][PYSPARK][SQL][BRANCH-2.2] Avoid overwriting deserialized accumulator
6f435e9 is described below

commit 6f435e9f76a389ea1cbd65fd5a629ebff9c6b229
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Thu Jan 3 11:00:10 2019 +0800

    [SPARK-25591][PYSPARK][SQL][BRANCH-2.2] Avoid overwriting deserialized accumulator
    
    ## What changes were proposed in this pull request?
    
    If we use the same accumulator in more than one UDF, it is possible to overwrite the deserialized accumulator and its value. We should check whether an accumulator was already deserialized before overwriting it in the accumulator registry.
    
    ## How was this patch tested?
    
    Added test.
    
    Closes #23433 from viirya/SPARK-25591-2.2.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/accumulators.py | 12 ++++++++----
 python/pyspark/sql/tests.py    | 25 +++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index bc0be07..5d46b92 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -110,10 +110,14 @@ _accumulatorRegistry = {}
 
 def _deserialize_accumulator(aid, zero_value, accum_param):
     from pyspark.accumulators import _accumulatorRegistry
-    accum = Accumulator(aid, zero_value, accum_param)
-    accum._deserialized = True
-    _accumulatorRegistry[aid] = accum
-    return accum
+    # If this certain accumulator was deserialized, don't overwrite it.
+    if aid in _accumulatorRegistry:
+        return _accumulatorRegistry[aid]
+    else:
+        accum = Accumulator(aid, zero_value, accum_param)
+        accum._deserialized = True
+        _accumulatorRegistry[aid] = accum
+        return accum
 
 
 class Accumulator(object):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0926112..083bb19 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2220,6 +2220,31 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
         self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
 
+    # SPARK-25591
+    def test_same_accumulator_in_udfs(self):
+        from pyspark.sql.functions import udf
+
+        data_schema = StructType([StructField("a", IntegerType(), True),
+                                  StructField("b", IntegerType(), True)])
+        data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
+
+        test_accum = self.sc.accumulator(0)
+
+        def first_udf(x):
+            test_accum.add(1)
+            return x
+
+        def second_udf(x):
+            test_accum.add(100)
+            return x
+
+        func_udf = udf(first_udf, IntegerType())
+        func_udf2 = udf(second_udf, IntegerType())
+        data = data.withColumn("out1", func_udf(data["a"]))
+        data = data.withColumn("out2", func_udf2(data["b"]))
+        data.collect()
+        self.assertEqual(test_accum.value, 101)
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org