You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2016/03/25 20:08:00 UTC
spark git commit: [SPARK-12443][SQL] encoderFor should support Decimal
Repository: spark
Updated Branches:
refs/heads/master 11fa8741c -> ca003354d
[SPARK-12443][SQL] encoderFor should support Decimal
## What changes were proposed in this pull request?
JIRA: https://issues.apache.org/jira/browse/SPARK-12443
`constructorFor` calls `dataTypeFor` to determine whether a type is an `ObjectType`. Because `dataTypeFor` has no case for `Decimal`, a `Decimal` is treated as an `ObjectType`, which causes the bug.
## How was this patch tested?
Tests are added to `ExpressionEncoderSuite` and `RowEncoderSuite`.
Author: Liang-Chi Hsieh <si...@tw.ibm.com>
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #10399 from viirya/fix-encoder-decimal.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca003354
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca003354
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca003354
Branch: refs/heads/master
Commit: ca003354da5e738e97418efc5af07be071c16d8f
Parents: 11fa874
Author: Liang-Chi Hsieh <si...@tw.ibm.com>
Authored: Fri Mar 25 12:07:56 2016 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Fri Mar 25 12:07:56 2016 -0700
----------------------------------------------------------------------
.../spark/sql/catalyst/ScalaReflection.scala | 1 +
.../sql/catalyst/encoders/RowEncoder.scala | 21 +++++++++++++++++---
.../org/apache/spark/sql/types/Decimal.scala | 8 ++++++++
.../encoders/ExpressionEncoderSuite.scala | 4 +++-
.../sql/catalyst/encoders/RowEncoderSuite.scala | 17 ++++++++++++++++
5 files changed, 47 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 5e1672c..f208401 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -63,6 +63,7 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< definitions.ByteTpe => ByteType
case t if t <:< definitions.BooleanTpe => BooleanType
case t if t <:< localTypeOf[Array[Byte]] => BinaryType
+ case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT
case _ =>
val className = getClassNameFromType(tpe)
className match {
http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 902644e..30f56d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -79,7 +79,7 @@ object RowEncoder {
StaticInvoke(
Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
- "apply",
+ "fromDecimal",
inputObject :: Nil)
case StringType =>
@@ -95,7 +95,7 @@ object RowEncoder {
classOf[GenericArrayData],
inputObject :: Nil,
dataType = t)
- case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeFor(et))
+ case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeForInput(et))
}
case t @ MapType(kt, vt, valueNullable) =>
@@ -129,7 +129,7 @@ object RowEncoder {
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, f.dataType),
extractorsFor(
- Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
+ Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil),
f.dataType))
}
If(IsNull(inputObject),
@@ -137,6 +137,21 @@ object RowEncoder {
CreateStruct(convertedFields))
}
+ /**
+ * Returns the `DataType` that can be used when generating code that converts input data
+ * into the Spark SQL internal format. Unlike `externalDataTypeFor`, the `DataType` returned
+ * by this function can be more permissive since multiple external types may map to a single
+ * internal type. For example, for an input with DecimalType in external row, its external types
+ * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
+ * `org.apache.spark.sql.types.Decimal`.
+ */
+ private def externalDataTypeForInput(dt: DataType): DataType = dt match {
+ // In order to support both Decimal and java BigDecimal in external row, we make this
+ // as java.lang.Object.
+ case _: DecimalType => ObjectType(classOf[java.lang.Object])
+ case _ => externalDataTypeFor(dt)
+ }
+
private def externalDataTypeFor(dt: DataType): DataType = dt match {
case _ if ScalaReflection.isNativeType(dt) => dt
case CalendarIntervalType => dt
http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index f0e535b..a30a392 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -376,6 +376,14 @@ object Decimal {
def apply(value: String): Decimal = new Decimal().set(BigDecimal(value))
+ // This is used for RowEncoder to handle Decimal inside external row.
+ def fromDecimal(value: Any): Decimal = {
+ value match {
+ case j: java.math.BigDecimal => apply(j)
+ case d: Decimal => d
+ }
+ }
+
/**
* Creates a decimal from unscaled, precision and scale without checking the bounds.
*/
http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 3024858..f6583bf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{ArrayType, ObjectType, StructType}
+import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType}
case class RepeatedStruct(s: Seq[PrimitiveData])
@@ -101,6 +101,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
// encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
+ encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")
+
encodeDecodeTest("hello", "string")
encodeDecodeTest(Date.valueOf("2012-12-23"), "date")
encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp")
http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index bf0360c..a8fa372 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -143,6 +143,23 @@ class RowEncoderSuite extends SparkFunSuite {
assert(input.getStruct(0) == convertedBack.getStruct(0))
}
+ test("encode/decode Decimal") {
+ val schema = new StructType()
+ .add("int", IntegerType)
+ .add("string", StringType)
+ .add("double", DoubleType)
+ .add("decimal", DecimalType.SYSTEM_DEFAULT)
+
+ val encoder = RowEncoder(schema)
+
+ val input: Row = Row(100, "test", 0.123, Decimal(1234.5678))
+ val row = encoder.toRow(input)
+ val convertedBack = encoder.fromRow(row)
+ // Decimal inside external row will be converted back to Java BigDecimal when decoding.
+ assert(input.get(3).asInstanceOf[Decimal].toJavaBigDecimal
+ .compareTo(convertedBack.getDecimal(3)) == 0)
+ }
+
private def encodeDecodeTest(schema: StructType): Unit = {
test(s"encode/decode: ${schema.simpleString}") {
val encoder = RowEncoder(schema)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org