You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2016/04/12 20:30:11 UTC

spark git commit: [SPARK-14563][ML] use a random table name instead of __THIS__ in SQLTransformer

Repository: spark
Updated Branches:
  refs/heads/master 7f024c474 -> 1995c2e64


[SPARK-14563][ML] use a random table name instead of __THIS__ in SQLTransformer

## What changes were proposed in this pull request?

Use a random table name instead of `__THIS__` in SQLTransformer, and add a test for `transformSchema`. The problems with registering the temp table under the fixed name `__THIS__` are:

* It doesn't work under HiveContext (in Spark 1.6)
* It is prone to race conditions, because every `transformSchema` call registers a temp table under the same name

## How was this patch tested?

* Manual test with HiveContext.
* Added a unit test for `transformSchema` to improve coverage.

cc: yhuai

Author: Xiangrui Meng <me...@databricks.com>

Closes #12330 from mengxr/SPARK-14563.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1995c2e6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1995c2e6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1995c2e6

Branch: refs/heads/master
Commit: 1995c2e6482bf4af5a4be087bfc156311c1bec19
Parents: 7f024c4
Author: Xiangrui Meng <me...@databricks.com>
Authored: Tue Apr 12 11:30:09 2016 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Apr 12 11:30:09 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/SQLTransformer.scala      | 10 ++++++----
 .../org/apache/spark/ml/feature/SQLTransformerSuite.scala | 10 ++++++++++
 2 files changed, 16 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1995c2e6/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 95fe942..2002d15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -68,8 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
     val tableName = Identifiable.randomUID(uid)
     dataset.registerTempTable(tableName)
     val realStatement = $(statement).replace(tableIdentifier, tableName)
-    val outputDF = dataset.sqlContext.sql(realStatement)
-    outputDF
+    dataset.sqlContext.sql(realStatement)
   }
 
   @Since("1.6.0")
@@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
     val sqlContext = SQLContext.getOrCreate(sc)
     val dummyRDD = sc.parallelize(Seq(Row.empty))
     val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
-    dummyDF.registerTempTable(tableIdentifier)
-    val outputSchema = sqlContext.sql($(statement)).schema
+    val tableName = Identifiable.randomUID(uid)
+    val realStatement = $(statement).replace(tableIdentifier, tableName)
+    dummyDF.registerTempTable(tableName)
+    val outputSchema = sqlContext.sql(realStatement).schema
+    sqlContext.dropTempTable(tableName)
     outputSchema
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1995c2e6/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index 553e0b8..e213e17 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
 
 class SQLTransformerSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -49,4 +50,13 @@ class SQLTransformerSuite
       .setStatement("select * from __THIS__")
     testDefaultReadWrite(t)
   }
+
+  test("transformSchema") {
+    val df = sqlContext.range(10)
+    val outputSchema = new SQLTransformer()
+      .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
+      .transformSchema(df.schema)
+    val expected = StructType(Seq(StructField("id1", LongType, nullable = false)))
+    assert(outputSchema === expected)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org