You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/04/12 20:30:11 UTC
spark git commit: [SPARK-14563][ML] use a random table name instead
of __THIS__ in SQLTransformer
Repository: spark
Updated Branches:
refs/heads/master 7f024c474 -> 1995c2e64
[SPARK-14563][ML] use a random table name instead of __THIS__ in SQLTransformer
## What changes were proposed in this pull request?
Use a random table name instead of `__THIS__` in SQLTransformer, and add a test for `transformSchema`. The problems of using `__THIS__` are:
* It doesn't work under HiveContext (in Spark 1.6)
* It is prone to race conditions: every SQLTransformer instance registers a temp table under the same fixed name `__THIS__`, so concurrent transforms can collide.
## How was this patch tested?
* Manual test with HiveContext.
* Added a unit test for `transformSchema` to improve coverage.
cc: yhuai
Author: Xiangrui Meng <me...@databricks.com>
Closes #12330 from mengxr/SPARK-14563.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1995c2e6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1995c2e6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1995c2e6
Branch: refs/heads/master
Commit: 1995c2e6482bf4af5a4be087bfc156311c1bec19
Parents: 7f024c4
Author: Xiangrui Meng <me...@databricks.com>
Authored: Tue Apr 12 11:30:09 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Apr 12 11:30:09 2016 -0700
----------------------------------------------------------------------
.../org/apache/spark/ml/feature/SQLTransformer.scala | 10 ++++++----
.../org/apache/spark/ml/feature/SQLTransformerSuite.scala | 10 ++++++++++
2 files changed, 16 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/1995c2e6/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 95fe942..2002d15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -68,8 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
val tableName = Identifiable.randomUID(uid)
dataset.registerTempTable(tableName)
val realStatement = $(statement).replace(tableIdentifier, tableName)
- val outputDF = dataset.sqlContext.sql(realStatement)
- outputDF
+ dataset.sqlContext.sql(realStatement)
}
@Since("1.6.0")
@@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
val sqlContext = SQLContext.getOrCreate(sc)
val dummyRDD = sc.parallelize(Seq(Row.empty))
val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
- dummyDF.registerTempTable(tableIdentifier)
- val outputSchema = sqlContext.sql($(statement)).schema
+ val tableName = Identifiable.randomUID(uid)
+ val realStatement = $(statement).replace(tableIdentifier, tableName)
+ dummyDF.registerTempTable(tableName)
+ val outputSchema = sqlContext.sql(realStatement).schema
+ sqlContext.dropTempTable(tableName)
outputSchema
}
http://git-wip-us.apache.org/repos/asf/spark/blob/1995c2e6/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index 553e0b8..e213e17 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
class SQLTransformerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -49,4 +50,13 @@ class SQLTransformerSuite
.setStatement("select * from __THIS__")
testDefaultReadWrite(t)
}
+
+ test("transformSchema") {
+ val df = sqlContext.range(10)
+ val outputSchema = new SQLTransformer()
+ .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
+ .transformSchema(df.schema)
+ val expected = StructType(Seq(StructField("id1", LongType, nullable = false)))
+ assert(outputSchema === expected)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org