Posted to commits@spark.apache.org by rx...@apache.org on 2015/02/04 07:15:50 UTC

spark git commit: [SPARK-5579][SQL][DataFrame] Support for project/filter using SQL expressions

Repository: spark
Updated Branches:
  refs/heads/master eb1563185 -> 40c4cb2fe


[SPARK-5579][SQL][DataFrame] Support for project/filter using SQL expressions

```scala
df.selectExpr("abs(colA)", "colB")
df.filter("age > 21")
```
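
For comparison, a minimal sketch of the equivalent Column-based call (assuming a DataFrame `df` with an `age` column; the string forms above are parsed by `SqlParser` into the same Catalyst expressions):

```scala
// Column-based equivalent of df.filter("age > 21"):
df.filter(df("age") > 21)
```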

Author: Reynold Xin <rx...@databricks.com>

Closes #4348 from rxin/SPARK-5579 and squashes the following commits:

2baeef2 [Reynold Xin] Fix Python.
b416372 [Reynold Xin] [SPARK-5579][SQL][DataFrame] Support for project/filter using SQL expressions.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/40c4cb2f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/40c4cb2f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/40c4cb2f

Branch: refs/heads/master
Commit: 40c4cb2fe79ceac0d656be7b72cb2ee8d7db7258
Parents: eb15631
Author: Reynold Xin <rx...@databricks.com>
Authored: Tue Feb 3 22:15:35 2015 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Feb 3 22:15:35 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql.py                           |  5 ++---
 .../apache/spark/sql/catalyst/SqlParser.scala   | 10 +++++++++
 .../scala/org/apache/spark/sql/DataFrame.scala  | 23 ++++++++++++++++++--
 .../org/apache/spark/sql/DataFrameImpl.scala    | 22 ++++++++++++++-----
 .../apache/spark/sql/IncomputableColumn.scala   |  8 +++++--
 .../org/apache/spark/sql/DataFrameSuite.scala   | 12 ++++++++++
 6 files changed, 67 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/40c4cb2f/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 268c7ef..74305de 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2126,10 +2126,9 @@ class DataFrame(object):
         """
         if not cols:
             raise ValueError("should sort by at least one column")
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols[1:]],
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
-        jdf = self._jdf.sort(_to_java_column(cols[0]),
-                             self._sc._jvm.Dsl.toColumns(jcols))
+        jdf = self._jdf.sort(self._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     sortBy = sort
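
This Python change tracks the Scala signature change below: `sort` loses its required head column and becomes pure varargs, so the Py4J call no longer needs to split `cols[0]` off from the rest. A sketch of the before/after Scala signatures:

```scala
// Before: one required column plus a varargs tail.
def sort(sortExpr: Column, sortExprs: Column*): DataFrame
// After: pure varargs, so callers pass all columns uniformly.
def sort(sortExprs: Column*): DataFrame
```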

http://git-wip-us.apache.org/repos/asf/spark/blob/40c4cb2f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 5c006e9..a9bd079 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -36,6 +36,16 @@ import org.apache.spark.sql.types._
  * for a SQL-like language should check out the HiveQL support in the sql/hive sub-project.
  */
 class SqlParser extends AbstractSparkSQLParser {
+
+  def parseExpression(input: String): Expression = {
+    // Initialize the Keywords.
+    lexical.initialize(reservedWords)
+    phrase(expression)(new lexical.Scanner(input)) match {
+      case Success(plan, _) => plan
+      case failureOrError => sys.error(failureOrError.toString)
+    }
+  }
+
   // Keyword is a convention with AbstractSparkSQLParser, which scans all of the `Keyword`
   // properties of this class via reflection at runtime to construct the SqlLexical object
   protected val ABS = Keyword("ABS")
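
`parseExpression` is a new entry point that parses a single SQL expression (rather than a full statement) into a Catalyst `Expression`. A minimal sketch of calling it directly, mirroring how `DataFrameImpl` uses it below:

```scala
import org.apache.spark.sql.catalyst.SqlParser

val expr = new SqlParser().parseExpression("abs(key) > 90")
// expr is a Catalyst Expression tree; on a parse failure the method
// raises via sys.error with the parser's failure message.
```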

http://git-wip-us.apache.org/repos/asf/spark/blob/40c4cb2f/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 732b685..a4997fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -173,7 +173,7 @@ trait DataFrame extends RDDApi[Row] {
    * }}}
    */
   @scala.annotation.varargs
-  def sort(sortExpr: Column, sortExprs: Column*): DataFrame
+  def sort(sortExprs: Column*): DataFrame
 
   /**
    * Returns a new [[DataFrame]] sorted by the given expressions.
@@ -187,7 +187,7 @@ trait DataFrame extends RDDApi[Row] {
    * This is an alias of the `sort` function.
    */
   @scala.annotation.varargs
-  def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame
+  def orderBy(sortExprs: Column*): DataFrame
 
   /**
    * Selects column based on the column name and return it as a [[Column]].
@@ -237,6 +237,17 @@ trait DataFrame extends RDDApi[Row] {
   def select(col: String, cols: String*): DataFrame
 
   /**
+   * Selects a set of SQL expressions. This is a variant of `select` that accepts
+   * SQL expressions.
+   *
+   * {{{
+   *   df.selectExpr("colA", "colB as newName", "abs(colC)")
+   * }}}
+   */
+  @scala.annotation.varargs
+  def selectExpr(exprs: String*): DataFrame
+
+  /**
    * Filters rows using the given condition.
    * {{{
    *   // The following are equivalent:
@@ -248,6 +259,14 @@ trait DataFrame extends RDDApi[Row] {
   def filter(condition: Column): DataFrame
 
   /**
+   * Filters rows using the given SQL expression.
+   * {{{
+   *   peopleDf.filter("age > 15")
+   * }}}
+   */
+  def filter(conditionExpr: String): DataFrame
+
+  /**
    * Filters rows using the given condition. This is an alias for `filter`.
    * {{{
    *   // The following are equivalent:
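
A short sketch of the two new trait methods in use, assuming a `peopleDf` with `name` and `age` columns (the `age` example comes from the scaladoc above; the `name` column is hypothetical):

```scala
// Project using SQL expression strings; aliases and function calls are allowed:
val projected = peopleDf.selectExpr("name", "abs(age) as absAge")
// Filter using a SQL boolean expression string:
val adults = peopleDf.filter("age > 15")
```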

http://git-wip-us.apache.org/repos/asf/spark/blob/40c4cb2f/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index a52bfa5..c702adc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -27,7 +27,7 @@ import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -124,11 +124,11 @@ private[sql] class DataFrameImpl protected[sql](
   }
 
   override def sort(sortCol: String, sortCols: String*): DataFrame = {
-    orderBy(apply(sortCol), sortCols.map(apply) :_*)
+    sort((sortCol +: sortCols).map(apply) :_*)
   }
 
-  override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = {
-    val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col =>
+  override def sort(sortExprs: Column*): DataFrame = {
+    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
       col.expr match {
         case expr: SortOrder =>
           expr
@@ -143,8 +143,8 @@ private[sql] class DataFrameImpl protected[sql](
     sort(sortCol, sortCols :_*)
   }
 
-  override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = {
-    sort(sortExpr, sortExprs :_*)
+  override def orderBy(sortExprs: Column*): DataFrame = {
+    sort(sortExprs :_*)
   }
 
   override def col(colName: String): Column = colName match {
@@ -179,10 +179,20 @@ private[sql] class DataFrameImpl protected[sql](
     select((col +: cols).map(Column(_)) :_*)
   }
 
+  override def selectExpr(exprs: String*): DataFrame = {
+    select(exprs.map { expr =>
+      Column(new SqlParser().parseExpression(expr))
+    } :_*)
+  }
+
   override def filter(condition: Column): DataFrame = {
     Filter(condition.expr, logicalPlan)
   }
 
+  override def filter(conditionExpr: String): DataFrame = {
+    filter(Column(new SqlParser().parseExpression(conditionExpr)))
+  }
+
   override def where(condition: Column): DataFrame = {
     filter(condition)
   }
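
The implementation reuses the existing Column-based paths: each string is parsed with a fresh `SqlParser`, wrapped in a `Column`, and handed to the existing overloads. A sketch of what `df.selectExpr("abs(colA)", "colB")` delegates to (note that `Column(expr)` here is the package-internal constructor used in the diff, so this only compiles inside `org.apache.spark.sql`):

```scala
import org.apache.spark.sql.catalyst.SqlParser

// Each expression string becomes a Catalyst Expression wrapped in a Column,
// then the whole sequence is passed to the varargs select overload:
val cols = Seq("abs(colA)", "colB").map { s =>
  Column(new SqlParser().parseExpression(s))
}
df.select(cols :_*)
```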

http://git-wip-us.apache.org/repos/asf/spark/blob/40c4cb2f/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index ba5c735..6b032d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -66,11 +66,11 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
 
   override def sort(sortCol: String, sortCols: String*): DataFrame = err()
 
-  override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = err()
+  override def sort(sortExprs: Column*): DataFrame = err()
 
   override def orderBy(sortCol: String, sortCols: String*): DataFrame = err()
 
-  override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = err()
+  override def orderBy(sortExprs: Column*): DataFrame = err()
 
   override def col(colName: String): Column = err()
 
@@ -80,8 +80,12 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
 
   override def select(col: String, cols: String*): DataFrame = err()
 
+  override def selectExpr(exprs: String*): DataFrame = err()
+
   override def filter(condition: Column): DataFrame = err()
 
+  override def filter(conditionExpr: String): DataFrame = err()
+
   override def where(condition: Column): DataFrame = err()
 
   override def apply(condition: Column): DataFrame = err()

http://git-wip-us.apache.org/repos/asf/spark/blob/40c4cb2f/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 19d4f34..e588555 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -47,6 +47,18 @@ class DataFrameSuite extends QueryTest {
       testData.collect().toSeq)
   }
 
+  test("selectExpr") {
+    checkAnswer(
+      testData.selectExpr("abs(key)", "value"),
+      testData.collect().map(row => Row(math.abs(row.getInt(0)), row.getString(1))).toSeq)
+  }
+
+  test("filterExpr") {
+    checkAnswer(
+      testData.filter("key > 90"),
+      testData.collect().filter(_.getInt(0) > 90).toSeq)
+  }
+
   test("repartition") {
     checkAnswer(
       testData.select('key).repartition(10).select('key),

