Posted to issues@spark.apache.org by "Gengliang Wang (Jira)" <ji...@apache.org> on 2022/02/10 04:02:00 UTC

[jira] [Resolved] (SPARK-38146) UDAF fails to aggregate TIMESTAMP_NTZ column

     [ https://issues.apache.org/jira/browse/SPARK-38146?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Gengliang Wang resolved SPARK-38146.
------------------------------------
    Fix Version/s: 3.3.0
       Resolution: Fixed

Issue resolved by pull request 35470
[https://github.com/apache/spark/pull/35470]

> UDAF fails to aggregate TIMESTAMP_NTZ column
> --------------------------------------------
>
>                 Key: SPARK-38146
>                 URL: https://issues.apache.org/jira/browse/SPARK-38146
>             Project: Spark
>          Issue Type: Bug
>          Components: SQL
>    Affects Versions: 3.3.0
>            Reporter: Bruce Robbins
>            Assignee: Bruce Robbins
>            Priority: Major
>             Fix For: 3.3.0
>
>
> When using a UDAF against unsafe rows containing a TIMESTAMP_NTZ column, Spark throws the error:
> {noformat}
> 22/02/08 18:05:12 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0)
> java.lang.UnsupportedOperationException: null
> 	at org.apache.spark.sql.catalyst.expressions.UnsafeRow.update(UnsafeRow.java:218) ~[spark-catalyst_2.12-3.3.0-SNAPSHOT.jar:3.3.0-SNAPSHOT]
> 	at org.apache.spark.sql.execution.aggregate.BufferSetterGetterUtils.$anonfun$createSetters$15(udaf.scala:217) ~[spark-sql_2.12-3.3.0-SNAPSHOT.jar:3.3.0-SNAPSHOT]
> 	at org.apache.spark.sql.execution.aggregate.BufferSetterGetterUtils.$anonfun$createSetters$15$adapted(udaf.scala:215) ~[spark-sql_2.12-3.3.0-SNAPSHOT.jar:3.3.0-SNAPSHOT]
> 	at org.apache.spark.sql.execution.aggregate.MutableAggregationBufferImpl.update(udaf.scala:272) ~[spark-sql_2.12-3.3.0-SNAPSHOT.jar:3.3.0-SNAPSHOT]
> 	at $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$ScalaAggregateFunction.$anonfun$update$1(<console>:46) ~[scala-library.jar:?]
> 	at scala.collection.immutable.Range.foreach$mVc$sp(Range.scala:158) ~[scala-library.jar:?]
> 	at $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$ScalaAggregateFunction.update(<console>:45) ~[scala-library.jar:?]
> 	at org.apache.spark.sql.execution.aggregate.ScalaUDAF.update(udaf.scala:458) ~[spark-sql_2.12-3.3.0-SNAPSHOT.jar:3.3.0-SNAPSHOT]
> 	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1.$anonfun$applyOrElse$2(AggregationIterator.scala:197) ~[spark-sql_2.12-3.3.0-SNAPSHO
> {noformat}
> This is because {{BufferSetterGetterUtils#createSetters}} has no case for {{TimestampNTZType}}, so it falls back to a generated function that calls the generic {{UnsafeRow.update}}, which throws an {{UnsupportedOperationException}}.
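> A plausible shape for the fix, assuming it mirrors the existing {{TimestampType}} handling (a sketch only; see pull request 35470 for the actual patch): TIMESTAMP_NTZ values are physically stored as {{Long}} microseconds, just like TIMESTAMP, so the missing case can reuse the primitive {{setLong}} path:
> {noformat}
> // Sketch of a TimestampNTZType case in BufferSetterGetterUtils#createSetters.
> // TIMESTAMP_NTZ shares TIMESTAMP's physical representation (Long microseconds),
> // so it can also share the primitive setLong path instead of falling through
> // to the generic update(), which UnsafeRow does not support.
> case TimestampType | TimestampNTZType =>
>   (row: InternalRow, ordinal: Int, value: Any) =>
>     if (value != null) {
>       row.setLong(ordinal, value.asInstanceOf[Long])
>     } else {
>       row.setNullAt(ordinal)
>     }
> {noformat}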
> This reproduction example is mostly taken from {{AggregationQuerySuite}}:
> {noformat}
> import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
> import org.apache.spark.sql.types._
> import org.apache.spark.sql.Row
> // Identity-style UDAF: the same schema is used for input, buffer, and result,
> // so every column type must round-trip through the aggregation buffer.
> class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
>   def inputSchema: StructType = schema
>   def bufferSchema: StructType = schema
>   def dataType: DataType = schema
>   def deterministic: Boolean = true
>   def initialize(buffer: MutableAggregationBuffer): Unit = {
>     (0 until schema.length).foreach { i =>
>       buffer.update(i, null)
>     }
>   }
>   // Copy the input row whose id column equals 50 into the buffer; other rows are ignored.
>   def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
>     if (!input.isNullAt(0) && input.getInt(0) == 50) {
>       (0 until schema.length).foreach { i =>
>         buffer.update(i, input.get(i))
>       }
>     }
>   }
>   def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
>     if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) {
>       (0 until schema.length).foreach { i =>
>         buffer1.update(i, buffer2.get(i))
>       }
>     }
>   }
>   def evaluate(buffer: Row): Any = {
>     Row.fromSeq(buffer.toSeq)
>   }
> }
> import java.time.LocalDateTime
> // 50 rows with ids 1 to 50; the UDAF above keeps only the row where id == 50.
> val data = Seq.tabulate(50) { x =>
>   Row((x + 1).toInt, (x + 2).toDouble, (x + 2).toLong, LocalDateTime.parse("2100-01-01T01:33:33.123").minusDays(x + 1))
> }
> val schema = StructType.fromDDL("id int, col1 double, col2 bigint, col3 timestamp_ntz")
> val rdd = spark.sparkContext.parallelize(data, 1)
> val df = spark.createDataFrame(rdd, schema)
> val udaf = new ScalaAggregateFunction(df.schema)
> import org.apache.spark.sql.functions.col  // implicit in spark-shell; needed for standalone runs
> val allColumns = df.schema.fields.map(f => col(f.name))
> df.groupBy().agg(udaf(allColumns: _*)).show(false)
> {noformat}
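> For what it's worth: on an unpatched 3.3.0-SNAPSHOT build the final {{show()}} fails with the stack trace above, while with the fix applied it should print a single row wrapping the one input row whose {{id}} is 50 (the only row the UDAF retains).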



--
This message was sent by Atlassian Jira
(v8.20.1#820001)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@spark.apache.org
For additional commands, e-mail: issues-help@spark.apache.org