Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/28 23:39:29 UTC

spark git commit: [SPARK-8003][SQL] Added virtual column support to Spark

Repository: spark
Updated Branches:
  refs/heads/master 8d5bb5283 -> b88b868eb


[SPARK-8003][SQL] Added virtual column support to Spark

Added virtual column support by adding a new resolution rule to the query analyzer. Additional virtual columns can be added by adding case expressions to [the new rule](https://github.com/JDrit/spark/blob/virt_columns/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala#L1026) and by modifying the [logical plan](https://github.com/JDrit/spark/blob/virt_columns/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala#L216) to resolve them.
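
In the merged patch the virtual column is exposed by registering a nondeterministic expression in the `FunctionRegistry`. A minimal sketch of that registration pattern, mirroring the `SQLContext`/`HiveContext` hunks below (any further virtual column would just be another entry in the list):

```scala
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder}
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.execution.expressions.SparkPartitionID

// Register spark__partition__id() on top of the built-in registry; additional
// virtual columns would simply be further entries in this list.
val reg = FunctionRegistry.builtin
val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))](
  FunctionExpression[SparkPartitionID]("spark__partition__id")
)
extendedFunctions.foreach { case (name, (info, fun)) => reg.registerFunction(name, info, fun) }
```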

This also solves [SPARK-8003](https://issues.apache.org/jira/browse/SPARK-8003)

This allows you to perform queries such as:
```sql
select spark__partition__id, count(*) as c from table group by spark__partition__id;
```
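
The same value is also reachable from the DataFrame API through the `sparkPartitionId()` function touched below; a short sketch, assuming an existing `SQLContext` named `ctx` (as in the tests):

```scala
import org.apache.spark.sql.functions.sparkPartitionId
import ctx.implicits._  // assumes an existing SQLContext `ctx`

val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "saying")

// Row count per Spark partition, equivalent to the SQL query above.
df.groupBy(sparkPartitionId().as("spark__partition__id")).count().show()
```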

Author: Joseph Batchik <jo...@gmail.com>
Author: JD <jd...@csh.rit.edu>

Closes #7478 from JDrit/virt_columns and squashes the following commits:

7932bf0 [Joseph Batchik] adding spark__partition__id to hive as well
f8a9c6c [Joseph Batchik] merging in master
e49da48 [JD] fixes for @rxin's suggestions
60e120b [JD] fixing test in merge
4bf8554 [JD] merging in master
c68bc0f [Joseph Batchik] Adding function register ability to SQLContext and adding a function for spark__partition__id()


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b88b868e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b88b868e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b88b868e

Branch: refs/heads/master
Commit: b88b868eb378bdb7459978842b5572a0b498f412
Parents: 8d5bb52
Author: Joseph Batchik <jo...@gmail.com>
Authored: Tue Jul 28 14:39:25 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Jul 28 14:39:25 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/FunctionRegistry.scala |  2 +-
 .../main/scala/org/apache/spark/sql/SQLContext.scala   | 11 ++++++++++-
 .../sql/execution/expressions/SparkPartitionID.scala   |  2 +-
 .../main/scala/org/apache/spark/sql/functions.scala    |  2 +-
 .../src/test/scala/org/apache/spark/sql/UDFSuite.scala |  7 +++++++
 .../execution/expression/NondeterministicSuite.scala   |  2 +-
 .../scala/org/apache/spark/sql/hive/HiveContext.scala  | 13 +++++++++++--
 .../scala/org/apache/spark/sql/hive/UDFSuite.scala     |  9 ++++++++-
 8 files changed, 40 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b88b868e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 61ee6f6..9b60943 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -239,7 +239,7 @@ object FunctionRegistry {
   }
 
   /** See usage above. */
-  private def expression[T <: Expression](name: String)
+  def expression[T <: Expression](name: String)
       (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
 
     // See if we can find a constructor that accepts Seq[Expression]

http://git-wip-us.apache.org/repos/asf/spark/blob/b88b868e/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index dbb2a09..56cd8f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -31,6 +31,8 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder}
+import org.apache.spark.sql.execution.expressions.SparkPartitionID
 import org.apache.spark.sql.SQLConf.SQLConfEntry
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.errors.DialectException
@@ -140,7 +142,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
 
   // TODO how to handle the temp function per user session?
   @transient
-  protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin
+  protected[sql] lazy val functionRegistry: FunctionRegistry = {
+    val reg = FunctionRegistry.builtin
+    val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))](
+      FunctionExpression[SparkPartitionID]("spark__partition__id")
+    )
+    extendedFunctions.foreach { case(name, (info, fun)) => reg.registerFunction(name, info, fun) }
+    reg
+  }
 
   @transient
   protected[sql] lazy val analyzer: Analyzer =

http://git-wip-us.apache.org/repos/asf/spark/blob/b88b868e/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
index 61ef079..98c8eab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types.{IntegerType, DataType}
 /**
  * Expression that returns the current partition id of the Spark task.
  */
-private[sql] case object SparkPartitionID extends LeafExpression with Nondeterministic {
+private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic {
 
   override def nullable: Boolean = false
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b88b868e/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index cec61b6..0148991 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -741,7 +741,7 @@ object functions {
    * @group normal_funcs
    * @since 1.4.0
    */
-  def sparkPartitionId(): Column = execution.expressions.SparkPartitionID
+  def sparkPartitionId(): Column = execution.expressions.SparkPartitionID()
 
   /**
    * Computes the square root of the specified float value.

http://git-wip-us.apache.org/repos/asf/spark/blob/b88b868e/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index c1516b4..9b326c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -51,6 +51,13 @@ class UDFSuite extends QueryTest {
     df.selectExpr("count(distinct a)")
   }
 
+  test("SPARK-8003 spark__partition__id") {
+    val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
+    df.registerTempTable("tmp_table")
+    checkAnswer(ctx.sql("select spark__partition__id() from tmp_table").toDF(), Row(0))
+    ctx.dropTempTable("tmp_table")
+  }
+
   test("error reporting for incorrect number of arguments") {
     val df = ctx.emptyDataFrame
     val e = intercept[AnalysisException] {

http://git-wip-us.apache.org/repos/asf/spark/blob/b88b868e/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala
index 1c5a2ed..b6e79ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala
@@ -27,6 +27,6 @@ class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("SparkPartitionID") {
-    checkEvaluation(SparkPartitionID, 0)
+    checkEvaluation(SparkPartitionID(), 0)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b88b868e/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 110f51a..8b35c12 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -38,6 +38,9 @@ import org.apache.spark.Logging
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder}
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+import org.apache.spark.sql.execution.expressions.SparkPartitionID
 import org.apache.spark.sql.SQLConf.SQLConfEntry
 import org.apache.spark.sql.SQLConf.SQLConfEntry._
 import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect}
@@ -372,8 +375,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
 
   // Note that HiveUDFs will be overridden by functions registered in this context.
   @transient
-  override protected[sql] lazy val functionRegistry: FunctionRegistry =
-    new HiveFunctionRegistry(FunctionRegistry.builtin)
+  override protected[sql] lazy val functionRegistry: FunctionRegistry = {
+    val reg = new HiveFunctionRegistry(FunctionRegistry.builtin)
+    val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))](
+      FunctionExpression[SparkPartitionID]("spark__partition__id")
+    )
+    extendedFunctions.foreach { case(name, (info, fun)) => reg.registerFunction(name, info, fun) }
+    reg
+  }
 
   /* An analyzer that uses the Hive metastore. */
   @transient

http://git-wip-us.apache.org/repos/asf/spark/blob/b88b868e/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
index 4056dee..9cea5d4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -17,13 +17,14 @@
 
 package org.apache.spark.sql.hive
 
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{Row, QueryTest}
 
 case class FunctionResult(f1: String, f2: String)
 
 class UDFSuite extends QueryTest {
 
   private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
+  import ctx.implicits._
 
   test("UDF case insensitive") {
     ctx.udf.register("random0", () => { Math.random() })
@@ -33,4 +34,10 @@ class UDFSuite extends QueryTest {
     assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
     assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
   }
+
+  test("SPARK-8003 spark__partition__id") {
+    val df = Seq((1, "Two Fiiiiive")).toDF("id", "saying")
+    ctx.registerDataFrameAsTable(df, "test_table")
+    checkAnswer(ctx.sql("select spark__partition__id() from test_table LIMIT 1").toDF(), Row(0))
+  }
 }

