Posted to issues@spark.apache.org by "Xiao Jin (Jira)" <ji...@apache.org> on 2021/04/22 07:21:00 UTC
[jira] [Created] (SPARK-35184) Filtering a dataframe after groupBy
and user-define-aggregate-function in Pyspark will cause
java.lang.UnsupportedOperationException
Xiao Jin created SPARK-35184:
--------------------------------
Summary: Filtering a dataframe after groupBy and user-define-aggregate-function in Pyspark will cause java.lang.UnsupportedOperationException
Key: SPARK-35184
URL: https://issues.apache.org/jira/browse/SPARK-35184
Project: Spark
Issue Type: Bug
Components: Optimizer
Affects Versions: 2.4.0
Reporter: Xiao Jin
I ran into a strange error while writing a PySpark UDAF. After calling groupBy and agg, I want to filter the resulting DataFrame, but the query fails. My sample code is below.
{code:java}
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType, col
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG)
... def mean_udf(v):
...     return v.mean()
...
>>> df.groupby("id").agg(mean_udf(df['v']).alias("mean")).filter(col("mean") > 5).show()
{code}
The code above raises the exception printed below.
{code:java}
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/opt/spark/python/pyspark/sql/dataframe.py", line 378, in show
print(self._jdf.showString(n, 20, vertical))
File "/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
File "/opt/spark/python/pyspark/sql/utils.py", line 63, in deco
return f(*a, **kw)
File "/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o3717.showString.
: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: execute, tree:
Exchange hashpartitioning(id#1726L, 200)
+- *(1) Filter (mean_udf(v#1727) > 5.0)
   +- Scan ExistingRDD[id#1726L,v#1727]
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.doExecute(ShuffleExchangeExec.scala:119)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
at org.apache.spark.sql.execution.InputAdapter.inputRDDs(WholeStageCodegenExec.scala:391)
at org.apache.spark.sql.execution.SortExec.inputRDDs(SortExec.scala:121)
at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:627)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
at org.apache.spark.sql.execution.python.AggregateInPandasExec.doExecute(AggregateInPandasExec.scala:80)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:247)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:339)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3383)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3364)
at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3363)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2544)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2758)
at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:291)
at sun.reflect.GeneratedMethodAccessor139.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.UnsupportedOperationException: Cannot evaluate expression: mean_udf(input[1, double, true])
at org.apache.spark.sql.catalyst.expressions.Unevaluable$class.doGenCode(Expression.scala:261)
at org.apache.spark.sql.catalyst.expressions.PythonUDF.doGenCode(PythonUDF.scala:50)
at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:108)
at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
at scala.Option.getOrElse(Option.scala:121)
at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:105)
at org.apache.spark.sql.catalyst.expressions.BinaryExpression.nullSafeCodeGen(Expression.scala:525)
at org.apache.spark.sql.catalyst.expressions.BinaryExpression.defineCodeGen(Expression.scala:508)
at org.apache.spark.sql.catalyst.expressions.BinaryComparison.doGenCode(predicates.scala:563)
at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:108)
at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
at scala.Option.getOrElse(Option.scala:121)
at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:105)
at org.apache.spark.sql.execution.FilterExec.org$apache$spark$sql$execution$FilterExec$$genPredicate$1(basicPhysicalOperators.scala:139)
at org.apache.spark.sql.execution.FilterExec$$anonfun$13.apply(basicPhysicalOperators.scala:179)
at org.apache.spark.sql.execution.FilterExec$$anonfun$13.apply(basicPhysicalOperators.scala:163)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.immutable.List.foreach(List.scala:392)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at scala.collection.immutable.List.map(List.scala:296)
at org.apache.spark.sql.execution.FilterExec.doConsume(basicPhysicalOperators.scala:163)
at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:189)
at org.apache.spark.sql.execution.InputAdapter.consume(WholeStageCodegenExec.scala:374)
at org.apache.spark.sql.execution.InputAdapter.doProduce(WholeStageCodegenExec.scala:403)
at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:90)
at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:85)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:85)
at org.apache.spark.sql.execution.InputAdapter.produce(WholeStageCodegenExec.scala:374)
at org.apache.spark.sql.execution.FilterExec.doProduce(basicPhysicalOperators.scala:125)
at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:90)
at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:85)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:85)
at org.apache.spark.sql.execution.FilterExec.produce(basicPhysicalOperators.scala:85)
at org.apache.spark.sql.execution.WholeStageCodegenExec.doCodeGen(WholeStageCodegenExec.scala:544)
at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:598)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.prepareShuffleDependency(ShuffleExchangeExec.scala:92)
at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:128)
at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:119)
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
... 48 more
{code}
Here is the optimized logical plan. The optimizer has already pushed the Filter below the Aggregate through the PushDownPredicates rule.
{code:java}
>>> df.groupby("id").agg(mean_udf(df['v']).alias("mean")).filter(col("mean") > 5).explain(True)
== Parsed Logical Plan ==
'Filter ('mean > 5)
+- Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
   +- LogicalRDD [id#0L, v#1], false

== Analyzed Logical Plan ==
id: bigint, mean: double
Filter (mean#79 > cast(5 as double))
+- Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
   +- LogicalRDD [id#0L, v#1], false

== Optimized Logical Plan ==
Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
+- Filter (mean_udf(v#1) > 5.0)
   +- LogicalRDD [id#0L, v#1], false

== Physical Plan ==
!AggregateInPandas [id#0L], [mean_udf(v#1)], [id#0L, mean_udf(v)#78 AS mean#79]
+- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(id#0L, 200)
      +- *(1) Filter (mean_udf(v#1) > 5.0)
         +- Scan ExistingRDD[id#0L,v#1]
{code}
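To illustrate why a filter on an aggregated value cannot simply be moved below the aggregate, here is a minimal plain-Python sketch (no Spark involved) on the same sample rows. The per-row filter v > 5 is only an analogy chosen for illustration: in the real plan Spark cannot evaluate mean_udf on a single row at all, which is what raises the exception. The helper name group_means is hypothetical.

```python
from collections import defaultdict
from statistics import mean

# Same sample rows as in the report: (id, v)
rows = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]


def group_means(data):
    """Group (id, v) pairs by id and return {id: mean of v}."""
    groups = defaultdict(list)
    for key, value in data:
        groups[key].append(value)
    return {key: mean(values) for key, values in groups.items()}


# Intended semantics: aggregate first, then filter on the aggregated value.
filter_after = {k: m for k, m in group_means(rows).items() if m > 5}

# Analogy for the pushed-down plan: a per-row filter placed below the aggregate.
filter_before = group_means([(k, v) for k, v in rows if v > 5])

print(filter_after)   # {2: 6.0}
print(filter_before)  # {2: 10.0}
```

The two results differ, so a predicate on the output of an aggregate is not equivalent to a predicate applied to the aggregate's input rows.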
By comparison, with the built-in mean function the Filter node is not pushed down by the PushDownPredicates rule.
The branch of the PushPredicateThroughNonJoin rule that matches this case is shown below.
{code:java}
case filter @ Filter(condition, aggregate: Aggregate)
  if aggregate.aggregateExpressions.forall(_.deterministic)
    && aggregate.groupingExpressions.nonEmpty =>
  val aliasMap = getAliasMap(aggregate)

  // For each filter, expand the alias and check if the filter can be evaluated using
  // attributes produced by the aggregate operator's child operator.
  val (candidates, nonDeterministic) =
    splitConjunctivePredicates(condition).partition(_.deterministic)

  val (pushDown, rest) = candidates.partition { cond =>
    val replaced = replaceAlias(cond, aliasMap)
    cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
  }

  val stayUp = rest ++ nonDeterministic

  if (pushDown.nonEmpty) {
    val pushDownPredicate = pushDown.reduce(And)
    val replaced = replaceAlias(pushDownPredicate, aliasMap)
    val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child))
    // If there is no more filter to stay up, just eliminate the filter.
    // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)".
    if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate)
  } else {
    filter
  }
{code}
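The decision made by the branch above can be modeled in a few lines of plain Python. This is only a toy sketch, not Spark code: expressions are reduced to the set of attribute names they reference, and Predicate, replace_alias and split_predicates are hypothetical names introduced for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Predicate:
    """A filter condition, reduced to the attribute names it references."""
    text: str
    references: frozenset
    deterministic: bool = True


def replace_alias(pred, alias_map):
    """Expand each aliased name into the references of the aliased expression."""
    expanded = set()
    for name in pred.references:
        expanded |= alias_map.get(name, {name})
    return Predicate(pred.text, frozenset(expanded), pred.deterministic)


def split_predicates(conjuncts, alias_map, child_output):
    """Return (push_down, stay_up), loosely mirroring the partition logic above."""
    push_down, stay_up = [], []
    for pred in conjuncts:
        replaced = replace_alias(pred, alias_map)
        can_push = (
            pred.deterministic
            and bool(pred.references)
            and replaced.references <= child_output
        )
        (push_down if can_push else stay_up).append(pred)
    return push_down, stay_up


child_output = frozenset({"id", "v"})
cond = Predicate("mean > 5", frozenset({"mean"}))

# Assumption taken from the report: "mean" aliases mean_udf(v), whose references are {v}.
push_down, stay_up = split_predicates([cond], {"mean": frozenset({"v"})}, child_output)
print([p.text for p in push_down])  # ['mean > 5']

# With no alias entry for "mean", the condition refers to an attribute
# the child does not produce, so it stays above the aggregate.
push_down2, stay_up2 = split_predicates([cond], {}, child_output)
print([p.text for p in stay_up2])   # ['mean > 5']
```

Under this model the condition is pushed down exactly when the alias expands to attributes that the Aggregate's child already produces, which is the situation the report describes for mean_udf.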
It is easy to see that when I use a Python UDAF, the references of the condition are a subset of the output of the Aggregate node's child: in Catalyst a Python UDAF is actually a PythonUDF expression, whose references are the references of its input expressions.
{code:java}
case class PythonUDF(
    name: String,
    func: PythonFunction,
    dataType: DataType,
    children: Seq[Expression],
    evalType: Int,
    udfDeterministic: Boolean,
    resultId: ExprId = NamedExpression.newExprId)
{code}
But the built-in mean function is the Average expression in Catalyst, a DeclarativeAggregate with multiple aggBufferAttributes, so the attributes Average refers to are sum (of type sumDataType) and count (of type LongType).
{code:java}
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
  ...
  private lazy val sum = AttributeReference("sum", sumDataType)()
  private lazy val count = AttributeReference("count", LongType)()

  override lazy val aggBufferAttributes = sum :: count :: Nil
  ...
{code}
{code:java}
case class AggregateExpression(
    aggregateFunction: AggregateFunction,
    mode: AggregateMode,
    isDistinct: Boolean,
    filter: Option[Expression],
    resultId: ExprId)
  extends Expression
  with Unevaluable {
  ...
  @transient
  override lazy val references: AttributeSet = {
    val aggAttributes = mode match {
      case Partial | Complete => aggregateFunction.references
      case PartialMerge | Final => AttributeSet(aggregateFunction.inputAggBufferAttributes)
    }
    aggAttributes ++ filterAttributes
  }
{code}
So the references of PythonUDF are a subset of the output of the Aggregate's child, but those of Average are not.
I think the root cause of the problem is that Catalyst does not treat the Pandas UDAF as a real AggregateFunction, so the Pandas UDAF is optimized like a normal UDF. Maybe it is time to redesign the definition of the Pandas UDAF so that it gets on the right track.
PS: All the speculation above is only a guess.
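To make the suspected root cause concrete: if the Pandas UDAF were recognized as an aggregate, a filter on its alias would have to stay above the Aggregate node. The plain-Python sketch below is a toy model of that idea only, not Spark's implementation; AggOutput, the is_aggregate flag and pushable_aliases are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AggOutput:
    """One entry of the aggregate list: alias, input references, and kind."""
    alias: str
    references: frozenset
    is_aggregate: bool  # hypothetical flag: True for an aggregate such as a Pandas UDAF


def pushable_aliases(agg_outputs, child_output):
    """Aliases a filter may be rewritten through: non-aggregate expressions only."""
    return {
        out.alias
        for out in agg_outputs
        if not out.is_aggregate and out.references <= child_output
    }


child_output = frozenset({"id", "v"})
outputs = [
    AggOutput("id", frozenset({"id"}), is_aggregate=False),  # grouping column
    AggOutput("mean", frozenset({"v"}), is_aggregate=True),  # mean_udf(v), flagged as aggregate
]

print(pushable_aliases(outputs, child_output))  # {'id'}
```

In this model a filter on id could still be pushed down, while a filter on mean could not, which matches the behavior described above for the built-in mean function.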