Posted to commits@spark.apache.org by ma...@apache.org on 2016/03/25 20:08:00 UTC

spark git commit: [SPARK-12443][SQL] encoderFor should support Decimal

Repository: spark
Updated Branches:
  refs/heads/master 11fa8741c -> ca003354d


[SPARK-12443][SQL] encoderFor should support Decimal

## What changes were proposed in this pull request?

JIRA: https://issues.apache.org/jira/browse/SPARK-12443

`constructorFor` calls `dataTypeFor` to determine whether a type is an `ObjectType` or not. Because `dataTypeFor` has no case for `Decimal`, that type falls through and is treated as an `ObjectType`, which causes the bug.
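
For illustration (a sketch, not part of the patch): with this change, deriving an `ExpressionEncoder` directly for `Decimal` becomes possible. The decode step is elided here because `ExpressionEncoderSuite`'s `encodeDecodeTest` helper resolves and binds the encoder before calling `fromRow`:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.Decimal

// Before this change, dataTypeFor had no case for Decimal, so the type fell
// through to ObjectType and encoder derivation misbehaved.
val encoder = ExpressionEncoder[Decimal]()
val internalRow = encoder.toRow(Decimal("32131413.211321313"))
// After resolving and binding (as the suite does), fromRow(internalRow)
// yields back the original Decimal.
```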

## How was this patch tested?

Tests are added to `ExpressionEncoderSuite` and `RowEncoderSuite`.

Author: Liang-Chi Hsieh <si...@tw.ibm.com>
Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #10399 from viirya/fix-encoder-decimal.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca003354
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca003354
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca003354

Branch: refs/heads/master
Commit: ca003354da5e738e97418efc5af07be071c16d8f
Parents: 11fa874
Author: Liang-Chi Hsieh <si...@tw.ibm.com>
Authored: Fri Mar 25 12:07:56 2016 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Fri Mar 25 12:07:56 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    |  1 +
 .../sql/catalyst/encoders/RowEncoder.scala      | 21 +++++++++++++++++---
 .../org/apache/spark/sql/types/Decimal.scala    |  8 ++++++++
 .../encoders/ExpressionEncoderSuite.scala       |  4 +++-
 .../sql/catalyst/encoders/RowEncoderSuite.scala | 17 ++++++++++++++++
 5 files changed, 47 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 5e1672c..f208401 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -63,6 +63,7 @@ object ScalaReflection extends ScalaReflection {
       case t if t <:< definitions.ByteTpe => ByteType
       case t if t <:< definitions.BooleanTpe => BooleanType
       case t if t <:< localTypeOf[Array[Byte]] => BinaryType
+      case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT
       case _ =>
         val className = getClassNameFromType(tpe)
         className match {
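
A quick way to see the effect of this one-line change (an illustrative check against catalyst-internal API, assuming it is invoked from within `org.apache.spark.sql.catalyst`):

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.{Decimal, DecimalType}

// Decimal now maps to a concrete Catalyst decimal type instead of
// falling through to ObjectType.
assert(ScalaReflection.dataTypeFor[Decimal] == DecimalType.SYSTEM_DEFAULT)
```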

http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 902644e..30f56d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -79,7 +79,7 @@ object RowEncoder {
       StaticInvoke(
         Decimal.getClass,
         DecimalType.SYSTEM_DEFAULT,
-        "apply",
+        "fromDecimal",
         inputObject :: Nil)
 
     case StringType =>
@@ -95,7 +95,7 @@ object RowEncoder {
           classOf[GenericArrayData],
           inputObject :: Nil,
           dataType = t)
-      case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeFor(et))
+      case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeForInput(et))
     }
 
     case t @ MapType(kt, vt, valueNullable) =>
@@ -129,7 +129,7 @@ object RowEncoder {
           Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
           Literal.create(null, f.dataType),
           extractorsFor(
-            Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
+            Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil),
             f.dataType))
       }
       If(IsNull(inputObject),
@@ -137,6 +137,21 @@ object RowEncoder {
         CreateStruct(convertedFields))
   }
 
+  /**
+   * Returns the `DataType` that can be used when generating code that converts input data
+   * into the Spark SQL internal format.  Unlike `externalDataTypeFor`, the `DataType` returned
+   * by this function can be more permissive since multiple external types may map to a single
+   * internal type.  For example, for an input with DecimalType in external row, its external types
+   * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
+   * `org.apache.spark.sql.types.Decimal`.
+   */
+  private def externalDataTypeForInput(dt: DataType): DataType = dt match {
+    // In order to support both Decimal and java BigDecimal in external row, we make this
+    // as java.lang.Object.
+    case _: DecimalType => ObjectType(classOf[java.lang.Object])
+    case _ => externalDataTypeFor(dt)
+  }
+
   private def externalDataTypeFor(dt: DataType): DataType = dt match {
     case _ if ScalaReflection.isNativeType(dt) => dt
     case CalendarIntervalType => dt
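
To illustrate why `externalDataTypeForInput` widens `DecimalType` to `java.lang.Object` (a hypothetical sketch, not part of the patch): all of the following are valid external rows for a single `DecimalType` column, so the extractor cannot assume one concrete class:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.Decimal

// Three external representations of the same DecimalType value:
val fromScalaBigDecimal = Row(BigDecimal("1234.5678"))
val fromJavaBigDecimal  = Row(new java.math.BigDecimal("1234.5678"))
val fromCatalystDecimal = Row(Decimal("1234.5678"))
```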

http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index f0e535b..a30a392 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -376,6 +376,14 @@ object Decimal {
 
   def apply(value: String): Decimal = new Decimal().set(BigDecimal(value))
 
+  // This is used for RowEncoder to handle Decimal inside external row.
+  def fromDecimal(value: Any): Decimal = {
+    value match {
+      case j: java.math.BigDecimal => apply(j)
+      case d: Decimal => d
+    }
+  }
+
   /**
    * Creates a decimal from unscaled, precision and scale without checking the bounds.
    */
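
The new helper's behavior, sketched as hypothetical REPL usage:

```scala
import org.apache.spark.sql.types.Decimal

Decimal.fromDecimal(new java.math.BigDecimal("231341.23123")) // wrapped into a Decimal
Decimal.fromDecimal(Decimal("1234.5678"))                     // returned as-is
// Any other argument type falls through the match and throws a MatchError.
```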

http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 3024858..f6583bf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{ArrayType, ObjectType, StructType}
+import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType}
 
 case class RepeatedStruct(s: Seq[PrimitiveData])
 
@@ -101,6 +101,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
   // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
 
+  encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")
+
   encodeDecodeTest("hello", "string")
   encodeDecodeTest(Date.valueOf("2012-12-23"), "date")
   encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp")

http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index bf0360c..a8fa372 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -143,6 +143,23 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(input.getStruct(0) == convertedBack.getStruct(0))
   }
 
+  test("encode/decode Decimal") {
+    val schema = new StructType()
+      .add("int", IntegerType)
+      .add("string", StringType)
+      .add("double", DoubleType)
+      .add("decimal", DecimalType.SYSTEM_DEFAULT)
+
+    val encoder = RowEncoder(schema)
+
+    val input: Row = Row(100, "test", 0.123, Decimal(1234.5678))
+    val row = encoder.toRow(input)
+    val convertedBack = encoder.fromRow(row)
+    // Decimal inside external row will be converted back to Java BigDecimal when decoding.
+    assert(input.get(3).asInstanceOf[Decimal].toJavaBigDecimal
+      .compareTo(convertedBack.getDecimal(3)) == 0)
+  }
+
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema)
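
Condensed, the round trip the new test performs looks like this (an illustrative sketch mirroring the assertions above):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{Decimal, DecimalType, StructType}

// A Decimal placed in the external row decodes back as java.math.BigDecimal,
// equal under compareTo (scales may differ, so equals() is too strict).
val schema  = new StructType().add("decimal", DecimalType.SYSTEM_DEFAULT)
val encoder = RowEncoder(schema)
val decoded = encoder.fromRow(encoder.toRow(Row(Decimal(1234.5678))))
assert(decoded.getDecimal(0).compareTo(new java.math.BigDecimal("1234.5678")) == 0)
```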

