You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/11/09 10:54:54 UTC

spark git commit: [SPARK-22442][SQL] ScalaReflection should produce correct field names for special characters

Repository: spark
Updated Branches:
  refs/heads/master fe93c0bf6 -> 40a8aefaf


[SPARK-22442][SQL] ScalaReflection should produce correct field names for special characters

## What changes were proposed in this pull request?

For a case class whose field names contain special characters, e.g.:
```scala
case class MyType(`field.1`: String, `field 2`: String)
```

Although we can still manipulate the resulting DataFrame/Dataset, its field names are encoded:
```scala
scala> val df = Seq(MyType("a", "b"), MyType("c", "d")).toDF
df: org.apache.spark.sql.DataFrame = [field$u002E1: string, field$u00202: string]
scala> df.as[MyType].collect
res7: Array[MyType] = Array(MyType(a,b), MyType(c,d))
```

This causes a resolution failure when we try to convert data whose field names are not encoded:
```scala
spark.read.json(path).as[MyType]
...
[info]   org.apache.spark.sql.AnalysisException: cannot resolve '`field$u002E1`' given input columns: [field 2, field.1];
[info]   at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
...
```

We should use the decoded field names in the Dataset schema.

## How was this patch tested?

Added tests.

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #19664 from viirya/SPARK-22442.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/40a8aefa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/40a8aefa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/40a8aefa

Branch: refs/heads/master
Commit: 40a8aefaf3e97e80b23fb05d4afdcc30e1922312
Parents: fe93c0b
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Thu Nov 9 11:54:50 2017 +0100
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Nov 9 11:54:50 2017 +0100

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala      |  9 +++++----
 .../catalyst/expressions/objects/objects.scala    | 11 +++++++----
 .../spark/sql/catalyst/ScalaReflectionSuite.scala | 18 +++++++++++++++++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala | 12 ++++++++++++
 4 files changed, 41 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/40a8aefa/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 17e595f..f62553d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -146,7 +146,7 @@ object ScalaReflection extends ScalaReflection {
     def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
       val newPath = path
         .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
-        .getOrElse(UnresolvedAttribute(part))
+        .getOrElse(UnresolvedAttribute.quoted(part))
       upCastToExpectedType(newPath, dataType, walkedTypePath)
     }
 
@@ -675,7 +675,7 @@ object ScalaReflection extends ScalaReflection {
     val m = runtimeMirror(cls.getClassLoader)
     val classSymbol = m.staticClass(cls.getName)
     val t = classSymbol.selfType
-    constructParams(t).map(_.name.toString)
+    constructParams(t).map(_.name.decodedName.toString)
   }
 
   /**
@@ -855,11 +855,12 @@ trait ScalaReflection {
     // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
     if (actualTypeArgs.nonEmpty) {
       params.map { p =>
-        p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+        p.name.decodedName.toString ->
+          p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
       }
     } else {
       params.map { p =>
-        p.name.toString -> p.typeSignature
+        p.name.decodedName.toString -> p.typeSignature
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/40a8aefa/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 6ae3490..f2eee99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
@@ -214,11 +215,13 @@ case class Invoke(
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
 
+  private lazy val encodedFunctionName = TermName(functionName).encodedName.toString
+
   @transient lazy val method = targetObject.dataType match {
     case ObjectType(cls) =>
-      val m = cls.getMethods.find(_.getName == functionName)
+      val m = cls.getMethods.find(_.getName == encodedFunctionName)
       if (m.isEmpty) {
-        sys.error(s"Couldn't find $functionName on $cls")
+        sys.error(s"Couldn't find $encodedFunctionName on $cls")
       } else {
         m
       }
@@ -247,7 +250,7 @@ case class Invoke(
     }
 
     val evaluate = if (returnPrimitive) {
-      getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
+      getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)")
     } else {
       val funcResult = ctx.freshName("funcResult")
       // If the function can return null, we do an extra check to make sure our null bit is still
@@ -265,7 +268,7 @@ case class Invoke(
       }
       s"""
         Object $funcResult = null;
-        ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")}
+        ${getFuncResult(funcResult, s"${obj.value}.$encodedFunctionName($argString)")}
         $assignResult
       """
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/40a8aefa/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index a5b9855..f77af5d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst
 import java.sql.{Date, Timestamp}
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -79,6 +80,8 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) {
   def this(b: String, a: Int) = this(a, b, c = 1.0)
 }
 
+case class SpecialCharAsFieldData(`field.1`: String, `field 2`: String)
+
 object TestingUDT {
   @SQLUserDefinedType(udt = classOf[NestedStructUDT])
   class NestedStruct(val a: Integer, val b: Long, val c: Double)
@@ -335,4 +338,17 @@ class ScalaReflectionSuite extends SparkFunSuite {
     assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]]))
   }
 
+  test("SPARK-22442: Generate correct field names for special characters") {
+    val serializer = serializerFor[SpecialCharAsFieldData](BoundReference(
+      0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false))
+    val deserializer = deserializerFor[SpecialCharAsFieldData]
+    assert(serializer.dataType(0).name == "field.1")
+    assert(serializer.dataType(1).name == "field 2")
+
+    val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect {
+      case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts
+    }}
+    assert(argumentsFields(0) == Seq("field.1"))
+    assert(argumentsFields(1) == Seq("field 2"))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/40a8aefa/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 1537ce3..c67165c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1398,6 +1398,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val actual = kvDataset.toString
     assert(expected === actual)
   }
+
+  test("SPARK-22442: Generate correct field names for special characters") {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      val data = """{"field.1": 1, "field 2": 2}"""
+      Seq(data).toDF().repartition(1).write.text(path)
+      val ds = spark.read.json(path).as[SpecialCharClass]
+      checkDataset(ds, SpecialCharClass("1", "2"))
+    }
+  }
 }
 
 case class SingleData(id: Int)
@@ -1487,3 +1497,5 @@ case class CircularReferenceClassB(cls: CircularReferenceClassA)
 case class CircularReferenceClassC(ar: Array[CircularReferenceClassC])
 case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE])
 case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD])
+
+case class SpecialCharClass(`field.1`: String, `field 2`: String)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org