You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/10/28 17:33:13 UTC
spark git commit: [SPARK-22370][SQL][PYSPARK] Config values should be
captured in Driver.
Repository: spark
Updated Branches:
refs/heads/master 683ffe062 -> 4c5269f1a
[SPARK-22370][SQL][PYSPARK] Config values should be captured in Driver.
## What changes were proposed in this pull request?
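`ArrowEvalPythonExec` and `FlatMapGroupsInPandasExec` are referring to config values of `SQLConf` inside the functions passed to `mapPartitions`/`mapPartitionsInternal`, which run on Executors; we should instead capture those values in the Driver and use the captured values in the functions.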
## How was this patch tested?
Added a new test; the change is also covered by existing tests.
Author: Takuya UESHIN <ue...@databricks.com>
Closes #19587 from ueshin/issues/SPARK-22370.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4c5269f1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4c5269f1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4c5269f1
Branch: refs/heads/master
Commit: 4c5269f1aa529e6a397b68d6dc409d89e32685bd
Parents: 683ffe0
Author: Takuya UESHIN <ue...@databricks.com>
Authored: Sat Oct 28 18:33:09 2017 +0100
Committer: Wenchen Fan <we...@databricks.com>
Committed: Sat Oct 28 18:33:09 2017 +0100
----------------------------------------------------------------------
python/pyspark/sql/tests.py | 20 ++++++++++++++++++++
.../spark/sql/catalyst/plans/QueryPlan.scala | 6 ++++++
.../execution/python/ArrowEvalPythonExec.scala | 6 ++++--
.../python/FlatMapGroupsInPandasExec.scala | 3 ++-
4 files changed, 32 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/4c5269f1/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 98afae6..8ed37c9 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3476,6 +3476,26 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
self.assertEquals(expected, ts)
+ def test_vectorized_udf_check_config(self):
+ from pyspark.sql.functions import pandas_udf, col
+ orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
+ self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
+ try:
+ df = self.spark.range(10, numPartitions=1)
+
+ @pandas_udf(returnType=LongType())
+ def check_records_per_batch(x):
+ self.assertTrue(x.size <= 3)
+ return x
+
+ result = df.select(check_records_per_batch(col("id")))
+ self.assertEquals(df.collect(), result.collect())
+ finally:
+ if orig_value is None:
+ self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
+ else:
+ self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)
+
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class GroupbyApplyTests(ReusedPySparkTestCase):
http://git-wip-us.apache.org/repos/asf/spark/blob/4c5269f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index d21b4af..ddf2cbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -25,6 +25,12 @@ import org.apache.spark.sql.types.{DataType, StructType}
abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
self: PlanType =>
+ /**
+ * The active config object within the current scope.
+ * Note that if you want to refer config values during execution, you have to capture them
+ * in Driver and use the captured values in Executors.
+ * See [[SQLConf.get]] for more information.
+ */
def conf: SQLConf = SQLConf.get
def output: Seq[Attribute]
http://git-wip-us.apache.org/repos/asf/spark/blob/4c5269f1/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 0db463a..bcda2da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -61,6 +61,9 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends EvalPythonExec(udfs, output, child) {
+ private val batchSize = conf.arrowMaxRecordsPerBatch
+ private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+
protected override def evaluate(
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
@@ -73,13 +76,12 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
.map { case (attr, i) => attr.withName(s"_$i") })
- val batchSize = conf.arrowMaxRecordsPerBatch
// DO NOT use iter.grouped(). See BatchIterator.
val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
val columnarBatchIter = new ArrowPythonRunner(
funcs, bufferSize, reuseWorker,
- PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, conf.sessionLocalTimeZone)
+ PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone)
.compute(batchIter, context.partitionId(), context)
new Iterator[InternalRow] {
http://git-wip-us.apache.org/repos/asf/spark/blob/4c5269f1/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index cc93fda..e1e04e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -77,6 +77,7 @@ case class FlatMapGroupsInPandasExec(
val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray)
val schema = StructType(child.schema.drop(groupingAttributes.length))
+ val sessionLocalTimeZone = conf.sessionLocalTimeZone
inputRDD.mapPartitionsInternal { iter =>
val grouped = if (groupingAttributes.isEmpty) {
@@ -94,7 +95,7 @@ case class FlatMapGroupsInPandasExec(
val columnarBatchIter = new ArrowPythonRunner(
chainedFunc, bufferSize, reuseWorker,
- PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, conf.sessionLocalTimeZone)
+ PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, sessionLocalTimeZone)
.compute(grouped, context.partitionId(), context)
columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org