You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/11/09 10:54:54 UTC
spark git commit: [SPARK-22442][SQL] ScalaReflection should produce correct field names for special characters
Repository: spark
Updated Branches:
refs/heads/master fe93c0bf6 -> 40a8aefaf
[SPARK-22442][SQL] ScalaReflection should produce correct field names for special characters
## What changes were proposed in this pull request?
For a class whose field names contain special characters, e.g.:
```scala
case class MyType(`field.1`: String, `field 2`: String)
```
Although we can manipulate DataFrame/Dataset, the field names are encoded:
```scala
scala> val df = Seq(MyType("a", "b"), MyType("c", "d")).toDF
df: org.apache.spark.sql.DataFrame = [field$u002E1: string, field$u00202: string]
scala> df.as[MyType].collect
res7: Array[MyType] = Array(MyType(a,b), MyType(c,d))
```
This causes a resolution failure when we try to convert data that uses the non-encoded field names:
```scala
spark.read.json(path).as[MyType]
...
[info] org.apache.spark.sql.AnalysisException: cannot resolve '`field$u002E1`' given input columns: [field 2, field.1];
[info] at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
...
```
We should use the decoded field names in the Dataset schema.
## How was this patch tested?
Added tests.
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #19664 from viirya/SPARK-22442.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/40a8aefa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/40a8aefa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/40a8aefa
Branch: refs/heads/master
Commit: 40a8aefaf3e97e80b23fb05d4afdcc30e1922312
Parents: fe93c0b
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Thu Nov 9 11:54:50 2017 +0100
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Nov 9 11:54:50 2017 +0100
----------------------------------------------------------------------
.../spark/sql/catalyst/ScalaReflection.scala | 9 +++++----
.../catalyst/expressions/objects/objects.scala | 11 +++++++----
.../spark/sql/catalyst/ScalaReflectionSuite.scala | 18 +++++++++++++++++-
.../scala/org/apache/spark/sql/DatasetSuite.scala | 12 ++++++++++++
4 files changed, 41 insertions(+), 9 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/40a8aefa/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 17e595f..f62553d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -146,7 +146,7 @@ object ScalaReflection extends ScalaReflection {
def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
val newPath = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
- .getOrElse(UnresolvedAttribute(part))
+ .getOrElse(UnresolvedAttribute.quoted(part))
upCastToExpectedType(newPath, dataType, walkedTypePath)
}
@@ -675,7 +675,7 @@ object ScalaReflection extends ScalaReflection {
val m = runtimeMirror(cls.getClassLoader)
val classSymbol = m.staticClass(cls.getName)
val t = classSymbol.selfType
- constructParams(t).map(_.name.toString)
+ constructParams(t).map(_.name.decodedName.toString)
}
/**
@@ -855,11 +855,12 @@ trait ScalaReflection {
// if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
if (actualTypeArgs.nonEmpty) {
params.map { p =>
- p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+ p.name.decodedName.toString ->
+ p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
}
} else {
params.map { p =>
- p.name.toString -> p.typeSignature
+ p.name.decodedName.toString -> p.typeSignature
}
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/40a8aefa/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 6ae3490..f2eee99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
@@ -214,11 +215,13 @@ case class Invoke(
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+ private lazy val encodedFunctionName = TermName(functionName).encodedName.toString
+
@transient lazy val method = targetObject.dataType match {
case ObjectType(cls) =>
- val m = cls.getMethods.find(_.getName == functionName)
+ val m = cls.getMethods.find(_.getName == encodedFunctionName)
if (m.isEmpty) {
- sys.error(s"Couldn't find $functionName on $cls")
+ sys.error(s"Couldn't find $encodedFunctionName on $cls")
} else {
m
}
@@ -247,7 +250,7 @@ case class Invoke(
}
val evaluate = if (returnPrimitive) {
- getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
+ getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)")
} else {
val funcResult = ctx.freshName("funcResult")
// If the function can return null, we do an extra check to make sure our null bit is still
@@ -265,7 +268,7 @@ case class Invoke(
}
s"""
Object $funcResult = null;
- ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")}
+ ${getFuncResult(funcResult, s"${obj.value}.$encodedFunctionName($argString)")}
$assignResult
"""
}
http://git-wip-us.apache.org/repos/asf/spark/blob/40a8aefa/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index a5b9855..f77af5d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst
import java.sql.{Date, Timestamp}
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -79,6 +80,8 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) {
def this(b: String, a: Int) = this(a, b, c = 1.0)
}
+case class SpecialCharAsFieldData(`field.1`: String, `field 2`: String)
+
object TestingUDT {
@SQLUserDefinedType(udt = classOf[NestedStructUDT])
class NestedStruct(val a: Integer, val b: Long, val c: Double)
@@ -335,4 +338,17 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]]))
}
+ test("SPARK-22442: Generate correct field names for special characters") {
+ val serializer = serializerFor[SpecialCharAsFieldData](BoundReference(
+ 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false))
+ val deserializer = deserializerFor[SpecialCharAsFieldData]
+ assert(serializer.dataType(0).name == "field.1")
+ assert(serializer.dataType(1).name == "field 2")
+
+ val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect {
+ case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts
+ }}
+ assert(argumentsFields(0) == Seq("field.1"))
+ assert(argumentsFields(1) == Seq("field 2"))
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/40a8aefa/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 1537ce3..c67165c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1398,6 +1398,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val actual = kvDataset.toString
assert(expected === actual)
}
+
+ test("SPARK-22442: Generate correct field names for special characters") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val data = """{"field.1": 1, "field 2": 2}"""
+ Seq(data).toDF().repartition(1).write.text(path)
+ val ds = spark.read.json(path).as[SpecialCharClass]
+ checkDataset(ds, SpecialCharClass("1", "2"))
+ }
+ }
}
case class SingleData(id: Int)
@@ -1487,3 +1497,5 @@ case class CircularReferenceClassB(cls: CircularReferenceClassA)
case class CircularReferenceClassC(ar: Array[CircularReferenceClassC])
case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE])
case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD])
+
+case class SpecialCharClass(`field.1`: String, `field 2`: String)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org