Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2018/12/20 06:22:24 UTC

[GitHub] asfgit closed pull request #23308: [SPARK-26308][SQL] Avoid cast of decimals for ScalaUDF

URL: https://github.com/apache/spark/pull/23308

This is a pull request merged from a forked repository. Because GitHub hides the original diff once a foreign pull request is merged, it is reproduced below for the sake of provenance:

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 133fa119b7aa6..1706b3eece6d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -879,6 +879,37 @@ object TypeCoercion {
           }
         }
         e.withNewChildren(children)
+
+      case udf: ScalaUDF if udf.inputTypes.nonEmpty =>
+        val children = udf.children.zip(udf.inputTypes).map { case (in, expected) =>
+          implicitCast(in, udfInputToCastType(in.dataType, expected)).getOrElse(in)
+        }
+        udf.withNewChildren(children)
+    }
+
+    private def udfInputToCastType(input: DataType, expectedType: DataType): DataType = {
+      (input, expectedType) match {
+        // SPARK-26308: avoid casting to an arbitrary precision and scale for decimals. Please note
+        // that precision and scale cannot be inferred properly for a ScalaUDF because, when it is
+        // created, it is not bound to any column. So here the precision and scale of the input
+        // column is used.
+        case (in: DecimalType, _: DecimalType) => in
+        case (ArrayType(dtIn, _), ArrayType(dtExp, nullableExp)) =>
+          ArrayType(udfInputToCastType(dtIn, dtExp), nullableExp)
+        case (MapType(keyDtIn, valueDtIn, _), MapType(keyDtExp, valueDtExp, nullableExp)) =>
+          MapType(udfInputToCastType(keyDtIn, keyDtExp),
+            udfInputToCastType(valueDtIn, valueDtExp),
+            nullableExp)
+        case (StructType(fieldsIn), StructType(fieldsExp)) =>
+          val fieldTypes =
+            fieldsIn.map(_.dataType).zip(fieldsExp.map(_.dataType)).map { case (dtIn, dtExp) =>
+              udfInputToCastType(dtIn, dtExp)
+            }
+          StructType(fieldsExp.zip(fieldTypes).map { case (field, newDt) =>
+            field.copy(dataType = newDt)
+          })
+        case (_, other) => other
+      }
     }
 
     /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index fae90caebf96c..a23aaa3a0b3ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -52,7 +52,7 @@ case class ScalaUDF(
     udfName: Option[String] = None,
     nullable: Boolean = true,
     udfDeterministic: Boolean = true)
-  extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {
+  extends Expression with NonSQLExpression with UserDefinedExpression {
 
   // The constructor for SPARK 2.1 and 2.2
   def this(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 20dcefa7e3cad..a26d306cff6b5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.math.BigDecimal
+
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.execution.QueryExecution
@@ -26,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationComm
 import org.apache.spark.sql.functions.{lit, udf}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
-import org.apache.spark.sql.types.{DataTypes, DoubleType}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.QueryExecutionListener
 
 
@@ -420,4 +422,32 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null")))
     }
   }
+
+  test("SPARK-26308: udf with decimal") {
+    val df1 = spark.createDataFrame(
+      sparkContext.parallelize(Seq(Row(new BigDecimal("2011000000000002456556")))),
+      StructType(Seq(StructField("col1", DecimalType(30, 0)))))
+    val udf1 = org.apache.spark.sql.functions.udf((value: BigDecimal) => {
+      if (value == null) null else value.toBigInteger.toString
+    })
+    checkAnswer(df1.select(udf1(df1.col("col1"))), Seq(Row("2011000000000002456556")))
+  }
+
+  test("SPARK-26308: udf with complex types of decimal") {
+    val df1 = spark.createDataFrame(
+      sparkContext.parallelize(Seq(Row(Array(new BigDecimal("2011000000000002456556"))))),
+      StructType(Seq(StructField("col1", ArrayType(DecimalType(30, 0))))))
+    val udf1 = org.apache.spark.sql.functions.udf((arr: Seq[BigDecimal]) => {
+      arr.map(value => if (value == null) null else value.toBigInteger.toString)
+    })
+    checkAnswer(df1.select(udf1($"col1")), Seq(Row(Array("2011000000000002456556"))))
+
+    val df2 = spark.createDataFrame(
+      sparkContext.parallelize(Seq(Row(Map("a" -> new BigDecimal("2011000000000002456556"))))),
+      StructType(Seq(StructField("col1", MapType(StringType, DecimalType(30, 0))))))
+    val udf2 = org.apache.spark.sql.functions.udf((map: Map[String, BigDecimal]) => {
+      map.mapValues(value => if (value == null) null else value.toBigInteger.toString)
+    })
+    checkAnswer(df2.select(udf2($"col1")), Seq(Row(Map("a" -> "2011000000000002456556"))))
+  }
 }
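For context, the new UDFSuite test above exercises the exact scenario reported in
SPARK-26308. Here is a minimal sketch of it in spark-shell terms (assuming the
`spark` session object that spark-shell provides): with the removed
ImplicitCastInputTypes trait, TypeCoercion used to cast the Decimal(30, 0) column
to the UDF's inferred input type, Decimal(38, 18); a 22-digit integer does not fit
into the 20 integer digits of Decimal(38, 18), so the cast overflowed and the UDF
received null.

import java.math.BigDecimal
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{DecimalType, StructField, StructType}

// A Decimal(30, 0) column holding a 22-digit integer.
val df = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row(new BigDecimal("2011000000000002456556")))),
  StructType(Seq(StructField("col1", DecimalType(30, 0)))))

val toBigIntString = udf((value: BigDecimal) =>
  if (value == null) null else value.toBigInteger.toString)

// Before this patch the implicit cast to Decimal(38, 18) overflowed to null;
// with it, the input column keeps its own precision and scale.
df.select(toBigIntString(df("col1"))).show(false)  // 2011000000000002456556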
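The udfInputToCastType helper keeps this behavior consistent for nested inputs.
As a rough standalone restatement (illustration only; the real rule above also
feeds the result through implicitCast for each child):

import org.apache.spark.sql.types._

// Resolve the cast target for a UDF input: a decimal input keeps its own
// precision and scale, recursively through arrays, maps, and structs;
// everything else still casts to the expected type.
def targetType(input: DataType, expected: DataType): DataType =
  (input, expected) match {
    case (in: DecimalType, _: DecimalType) => in
    case (ArrayType(i, _), ArrayType(e, n)) =>
      ArrayType(targetType(i, e), n)
    case (MapType(ki, vi, _), MapType(ke, ve, n)) =>
      MapType(targetType(ki, ke), targetType(vi, ve), n)
    case (StructType(fIn), StructType(fExp)) =>
      StructType(fExp.zip(fIn).map { case (ef, inF) =>
        ef.copy(dataType = targetType(inF.dataType, ef.dataType))
      })
    case (_, other) => other
  }

targetType(ArrayType(DecimalType(30, 0)), ArrayType(DecimalType(38, 18)))
// => ArrayType(DecimalType(30,0),true): the element decimal survives untouched.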


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org