Posted to issues@spark.apache.org by "Xiao Jin (Jira)" <ji...@apache.org> on 2021/04/22 07:21:00 UTC

[jira] [Created] (SPARK-35184) Filtering a dataframe after groupBy and user-define-aggregate-function in Pyspark will cause java.lang.UnsupportedOperationException

Xiao Jin created SPARK-35184:
--------------------------------

             Summary: Filtering a dataframe after groupBy and user-define-aggregate-function in Pyspark will cause java.lang.UnsupportedOperationException
                 Key: SPARK-35184
                 URL: https://issues.apache.org/jira/browse/SPARK-35184
             Project: Spark
          Issue Type: Bug
          Components: Optimizer
    Affects Versions: 2.4.0
            Reporter: Xiao Jin


I hit a strange error while writing a PySpark UDAF. After calling groupBy and agg, I want to filter the resulting DataFrame, but the filter fails. My sample code is below.
{code:java}
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType, col
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG)
... def mean_udf(v):
...     return v.mean()
>>> df.groupby("id").agg(mean_udf(df['v']).alias("mean")).filter(col("mean") > 5).show()
{code}
The code above raises the exception printed below.
{code:java}
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/spark/python/pyspark/sql/dataframe.py", line 378, in show
    print(self._jdf.showString(n, 20, vertical))
  File "/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
  File "/opt/spark/python/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o3717.showString.
: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: execute, tree:
Exchange hashpartitioning(id#1726L, 200)
+- *(1) Filter (mean_udf(v#1727) > 5.0)
   +- Scan ExistingRDD[id#1726L,v#1727]

        at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
        at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.doExecute(ShuffleExchangeExec.scala:119)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
        at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
        at org.apache.spark.sql.execution.InputAdapter.inputRDDs(WholeStageCodegenExec.scala:391)
        at org.apache.spark.sql.execution.SortExec.inputRDDs(SortExec.scala:121)
        at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:627)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
        at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
        at org.apache.spark.sql.execution.python.AggregateInPandasExec.doExecute(AggregateInPandasExec.scala:80)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
        at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
        at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:247)
        at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:339)
        at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
        at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3383)
        at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
        at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
        at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3364)
        at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
        at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
        at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
        at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3363)
        at org.apache.spark.sql.Dataset.head(Dataset.scala:2544)
        at org.apache.spark.sql.Dataset.take(Dataset.scala:2758)
        at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254)
        at org.apache.spark.sql.Dataset.showString(Dataset.scala:291)
        at sun.reflect.GeneratedMethodAccessor139.invoke(Unknown Source)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.lang.reflect.Method.invoke(Method.java:498)
        at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
        at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
        at py4j.Gateway.invoke(Gateway.java:282)
        at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
        at py4j.commands.CallCommand.execute(CallCommand.java:79)
        at py4j.GatewayConnection.run(GatewayConnection.java:238)
        at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.UnsupportedOperationException: Cannot evaluate expression: mean_udf(input[1, double, true])
        at org.apache.spark.sql.catalyst.expressions.Unevaluable$class.doGenCode(Expression.scala:261)
        at org.apache.spark.sql.catalyst.expressions.PythonUDF.doGenCode(PythonUDF.scala:50)
        at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:108)
        at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
        at scala.Option.getOrElse(Option.scala:121)
        at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:105)
        at org.apache.spark.sql.catalyst.expressions.BinaryExpression.nullSafeCodeGen(Expression.scala:525)
        at org.apache.spark.sql.catalyst.expressions.BinaryExpression.defineCodeGen(Expression.scala:508)
        at org.apache.spark.sql.catalyst.expressions.BinaryComparison.doGenCode(predicates.scala:563)
        at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:108)
        at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
        at scala.Option.getOrElse(Option.scala:121)
        at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:105)
        at org.apache.spark.sql.execution.FilterExec.org$apache$spark$sql$execution$FilterExec$$genPredicate$1(basicPhysicalOperators.scala:139)
        at org.apache.spark.sql.execution.FilterExec$$anonfun$13.apply(basicPhysicalOperators.scala:179)
        at org.apache.spark.sql.execution.FilterExec$$anonfun$13.apply(basicPhysicalOperators.scala:163)
        at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
        at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
        at scala.collection.immutable.List.foreach(List.scala:392)
        at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
        at scala.collection.immutable.List.map(List.scala:296)
        at org.apache.spark.sql.execution.FilterExec.doConsume(basicPhysicalOperators.scala:163)
        at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:189)
        at org.apache.spark.sql.execution.InputAdapter.consume(WholeStageCodegenExec.scala:374)
        at org.apache.spark.sql.execution.InputAdapter.doProduce(WholeStageCodegenExec.scala:403)
        at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:90)
        at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:85)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
        at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:85)
        at org.apache.spark.sql.execution.InputAdapter.produce(WholeStageCodegenExec.scala:374)
        at org.apache.spark.sql.execution.FilterExec.doProduce(basicPhysicalOperators.scala:125)
        at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:90)
        at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:85)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
        at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:85)
        at org.apache.spark.sql.execution.FilterExec.produce(basicPhysicalOperators.scala:85)
        at org.apache.spark.sql.execution.WholeStageCodegenExec.doCodeGen(WholeStageCodegenExec.scala:544)
        at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:598)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
        at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
        at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
        at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.prepareShuffleDependency(ShuffleExchangeExec.scala:92)
        at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:128)
        at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:119)
        at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
        ... 48 more
{code}
Here is the optimized logical plan. The optimizer has already pushed the Filter down below the Aggregate through the PushDownPredicates rule.
{code:java}
>>> df.groupby("id").agg(mean_udf(df['v']).alias("mean")).filter(col("mean") > 5).explain(True)
== Parsed Logical Plan ==
'Filter ('mean > 5)
+- Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
   +- LogicalRDD [id#0L, v#1], false

== Analyzed Logical Plan ==
id: bigint, mean: double
Filter (mean#79 > cast(5 as double))
+- Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
   +- LogicalRDD [id#0L, v#1], false

== Optimized Logical Plan ==
Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
+- Filter (mean_udf(v#1) > 5.0)
   +- LogicalRDD [id#0L, v#1], false

== Physical Plan ==
!AggregateInPandas [id#0L], [mean_udf(v#1)], [id#0L, mean_udf(v)#78 AS mean#79]
+- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(id#0L, 200)
      +- *(1) Filter (mean_udf(v#1) > 5.0)
         +- Scan ExistingRDD[id#0L,v#1]
{code}
Compare this with the built-in mean function, for which the Filter node is not pushed down by the PushDownPredicates rule.
{code:java}
>>> df.groupby("id").agg(mean_udf(df['v']).alias("mean")).filter(col("mean") > 5).explain(True)
== Parsed Logical Plan ==
'Filter ('mean > 5)
+- Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
   +- LogicalRDD [id#0L, v#1], false

== Analyzed Logical Plan ==
id: bigint, mean: double
Filter (mean#79 > cast(5 as double))
+- Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
   +- LogicalRDD [id#0L, v#1], false

== Optimized Logical Plan ==
Aggregate [id#0L], [id#0L, mean_udf(v#1) AS mean#79]
+- Filter (mean_udf(v#1) > 5.0)
   +- LogicalRDD [id#0L, v#1], false

== Physical Plan ==
!AggregateInPandas [id#0L], [mean_udf(v#1)], [id#0L, mean_udf(v)#78 AS mean#79]
+- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(id#0L, 200)
      +- *(1) Filter (mean_udf(v#1) > 5.0)
         +- Scan ExistingRDD[id#0L,v#1]
{code}
The branch of the PushPredicateThroughNonJoin rule that matches our case is below.
{code:java}
    case filter @ Filter(condition, aggregate: Aggregate)
      if aggregate.aggregateExpressions.forall(_.deterministic)
        && aggregate.groupingExpressions.nonEmpty =>
      val aliasMap = getAliasMap(aggregate)

      // For each filter, expand the alias and check if the filter can be evaluated using
      // attributes produced by the aggregate operator's child operator.
      val (candidates, nonDeterministic) =
        splitConjunctivePredicates(condition).partition(_.deterministic)

      val (pushDown, rest) = candidates.partition { cond =>
        val replaced = replaceAlias(cond, aliasMap)
        cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
      }

      val stayUp = rest ++ nonDeterministic

      if (pushDown.nonEmpty) {
        val pushDownPredicate = pushDown.reduce(And)
        val replaced = replaceAlias(pushDownPredicate, aliasMap)
        val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child))
        // If there is no more filter to stay up, just eliminate the filter.
        // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)".
        if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate)
      } else {
        filter
      }
{code}
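To make the pushdown decision in the rule above easier to follow, here is a minimal plain-Python sketch of the partition step only. It is not Spark code: Predicate, replace_alias and split_pushdown are hypothetical names invented for this illustration, attributes are modelled as plain strings, and the alias map is assumed to map `mean` to the attributes read by mean_udf(v), which matches the optimized plan shown earlier.

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet, List, Tuple


@dataclass(frozen=True)
class Predicate:
    """Toy stand-in for a filter predicate (hypothetical, not a Spark class)."""
    text: str                    # human-readable form, e.g. "mean > 5"
    references: FrozenSet[str]   # attribute names the predicate reads


def replace_alias(pred: Predicate, alias_map: Dict[str, FrozenSet[str]]) -> Predicate:
    """Expand each aliased attribute into the attributes its expression reads."""
    expanded = set()
    for name in pred.references:
        expanded |= alias_map.get(name, frozenset({name}))
    return Predicate(pred.text, frozenset(expanded))


def split_pushdown(
    preds: List[Predicate],
    alias_map: Dict[str, FrozenSet[str]],
    child_output: FrozenSet[str],
) -> Tuple[List[Predicate], List[Predicate]]:
    """Mirror the rule's check: references non-empty and, after alias
    expansion, a subset of the child's output attributes."""
    push_down, stay_up = [], []
    for pred in preds:
        replaced = replace_alias(pred, alias_map)
        if pred.references and replaced.references <= child_output:
            push_down.append(pred)
        else:
            stay_up.append(pred)
    return push_down, stay_up


child_output = frozenset({"id", "v"})                 # output of LogicalRDD [id, v]
condition = Predicate("mean > 5", frozenset({"mean"}))

# Assumption: the alias `mean` expands to mean_udf(v), which only reads `v`.
push_down, stay_up = split_pushdown([condition], {"mean": frozenset({"v"})}, child_output)
print([p.text for p in push_down], [p.text for p in stay_up])

# If the alias is not expanded, `mean` is not produced by the child, so the filter stays up.
push_down, stay_up = split_pushdown([condition], {}, child_output)
print([p.text for p in push_down], [p.text for p in stay_up])
```

In the first call the predicate lands in the push-down list, which corresponds to the Filter appearing below the Aggregate in the optimized plan. In the second call, where the alias is left unexpanded, the same predicate stays above the Aggregate.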
From this it is easy to see that, with a Python UDAF, the references of the condition are a subset of the output of the Aggregate node's child, because a Python UDAF in Catalyst is actually a PythonUDF expression, whose references are the references of its input expressions:
{code:java}
case class PythonUDF(
    name: String,
    func: PythonFunction,
    dataType: DataType,
    children: Seq[Expression],
    evalType: Int,
    udfDeterministic: Boolean,
    resultId: ExprId = NamedExpression.newExprId)
{code}
But the built-in mean function in Catalyst is the Average expression, a DeclarativeAggregate with multiple aggBufferAttributes, which means the references of Average are a sum attribute of type sumDataType and a count attribute of type LongType.
{code:java}
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
  ...
  private lazy val sum = AttributeReference("sum", sumDataType)()
  private lazy val count = AttributeReference("count", LongType)()

  override lazy val aggBufferAttributes = sum :: count :: Nil
  ...
{code}
{code:java}
case class AggregateExpression(
    aggregateFunction: AggregateFunction,
    mode: AggregateMode,
    isDistinct: Boolean,
    filter: Option[Expression],
    resultId: ExprId)
  extends Expression
  with Unevaluable {
  ...
  @transient
  override lazy val references: AttributeSet = {
    val aggAttributes = mode match {
      case Partial | Complete => aggregateFunction.references
      case PartialMerge | Final => AttributeSet(aggregateFunction.inputAggBufferAttributes)
    }
    aggAttributes ++ filterAttributes
  }
{code}
So the references of PythonUDF are a subset of the output of the Aggregate's child, but those of Average are not.

I think the root cause is that Catalyst does not treat a Pandas UDAF as a real AggregateFunction, so the Pandas UDAF is optimized like a normal UDF. Maybe it's time to redesign the definition of the Pandas UDAF so that it gets on the right track.

PS: All the speculation above is only a guess.


