You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/05/13 06:44:12 UTC

spark git commit: [SPARK-7321][SQL] Add Column expression for conditional statements (when/otherwise)

Repository: spark
Updated Branches:
  refs/heads/master 8fd55358b -> 97dee313f


[SPARK-7321][SQL] Add Column expression for conditional statements (when/otherwise)

This builds on, and should also close, https://github.com/apache/spark/pull/5932.

As an example:
```python
df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
```

Author: Reynold Xin <rx...@databricks.com>
Author: kaka1992 <ka...@163.com>

Closes #6072 from rxin/when-expr and squashes the following commits:

8f49201 [Reynold Xin] Throw exception if otherwise is applied twice.
0455eda [Reynold Xin] Reset run-tests.
bfb9d9f [Reynold Xin] Updated documentation and test cases.
762f6a5 [Reynold Xin] Merge pull request #5932 from kaka1992/IFCASE
95724c6 [kaka1992] Update
8218d0a [kaka1992] Update
801009e [kaka1992] Update
76d6346 [kaka1992] [SPARK-7321][SQL] Add Column expression for conditional statements (if, case)


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/97dee313
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/97dee313
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/97dee313

Branch: refs/heads/master
Commit: 97dee313f23b00f15638cb72a4a80c1f197f8a9d
Parents: 8fd5535
Author: Reynold Xin <rx...@databricks.com>
Authored: Tue May 12 21:43:34 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue May 12 21:43:34 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/__init__.py                  |  2 +
 python/pyspark/sql/dataframe.py                 | 31 ++++++++++
 python/pyspark/sql/functions.py                 | 26 ++++++++-
 .../scala/org/apache/spark/sql/Column.scala     | 61 ++++++++++++++++++++
 .../scala/org/apache/spark/sql/functions.scala  | 24 ++++++++
 .../spark/sql/ColumnExpressionSuite.scala       | 21 +++++++
 6 files changed, 163 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/97dee313/python/pyspark/sql/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index b60b991..7192c89 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -32,6 +32,8 @@ Important classes of Spark SQL and DataFrames:
       Aggregation methods, returned by :func:`DataFrame.groupBy`.
     - L{DataFrameNaFunctions}
       Methods for handling missing data (null values).
+    - L{DataFrameStatFunctions}
+      Methods for statistics functionality.
     - L{functions}
       List of built-in functions available for :class:`DataFrame`.
     - L{types}

http://git-wip-us.apache.org/repos/asf/spark/blob/97dee313/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 078acfd..82cb1c2 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1546,6 +1546,37 @@ class Column(object):
         """
         return (self >= lowerBound) & (self <= upperBound)
 
+    @ignore_unicode_prefix
+    def when(self, condition, value):
+        """Evaluates a list of conditions and returns one of multiple possible result expressions.
+        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+        See :func:`pyspark.sql.functions.when` for example usage.
+
+        :param condition: a boolean :class:`Column` expression.
+        :param value: a literal value, or a :class:`Column` expression.
+        :raises TypeError: if ``condition`` is not a :class:`Column`.
+        """
+        if not isinstance(condition, Column):
+            raise TypeError("condition should be a Column")
+        v = value._jc if isinstance(value, Column) else value
+        # Chain onto this column's CASE WHEN so earlier when() branches are kept.
+        jc = self._jc.when(condition._jc, v)
+        return Column(jc)
+
+    @ignore_unicode_prefix
+    def otherwise(self, value):
+        """Evaluates a list of conditions and returns one of multiple possible result expressions.
+        ``value`` is returned for rows matching none of the conditions; apply it at most once.
+
+        See :func:`pyspark.sql.functions.when` for example usage.
+
+        :param value: a literal value, or a :class:`Column` expression.
+        """
+        v = value._jc if isinstance(value, Column) else value  # py4j needs the JVM Column
+        jc = self._jc.otherwise(v)
+        return Column(jc)
+
     def __repr__(self):
         return 'Column<%s>' % self._jc.toString().encode('utf8')
 

http://git-wip-us.apache.org/repos/asf/spark/blob/97dee313/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 38a043a..d91265e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -32,13 +32,14 @@ from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
 
 __all__ = [
     'approxCountDistinct',
+    'coalesce',
     'countDistinct',
     'monotonicallyIncreasingId',
     'rand',
     'randn',
     'sparkPartitionId',
-    'coalesce',
-    'udf']
+    'udf',
+    'when']
 
 
 def _create_function(name, doc=""):
@@ -291,6 +292,27 @@ def struct(*cols):
     return Column(jc)
 
 
+def when(condition, value):
+    """Evaluates a list of conditions and returns one of multiple possible result expressions.
+    Rows matching no condition get None unless :func:`Column.otherwise` supplies a default.
+
+    :param condition: a boolean :class:`Column` expression.
+    :param value: a literal value, or a :class:`Column` expression.
+
+    >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
+    [Row(age=3), Row(age=4)]
+
+    >>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect()
+    [Row(age=3), Row(age=None)]
+    """
+    if not isinstance(condition, Column):
+        raise TypeError("condition should be a Column")
+    jvm_value = value._jc if isinstance(value, Column) else value
+    jvm_functions = SparkContext._active_spark_context._jvm.functions
+    # Start a new CASE WHEN whose first branch is (condition, value).
+    return Column(jvm_functions.when(condition._jc, jvm_value))
+
+
 class UserDefinedFunction(object):
     """
     User defined function in Python

http://git-wip-us.apache.org/repos/asf/spark/blob/97dee313/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 4773ded..42f5bcd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -328,6 +328,67 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   def eqNullSafe(other: Any): Column = this <=> other
 
   /**
+   * Evaluates a list of conditions and returns one of multiple possible result expressions.
+   * If otherwise is not defined at the end, null is returned for unmatched conditions.
+   *
+   * {{{
+   *   // Example: encoding gender string column into integer.
+   *
+   *   // Scala:
+   *   people.select(when(people("gender") === "male", 0)
+   *     .when(people("gender") === "female", 1)
+   *     .otherwise(2))
+   *
+   *   // Java:
+   *   people.select(when(col("gender").equalTo("male"), 0)
+   *     .when(col("gender").equalTo("female"), 1)
+   *     .otherwise(2))
+   * }}}
+   *
+   * @group expr_ops
+   */
+  def when(condition: Column, value: Any): Column = this.expr match {
+    case CaseWhen(branches) if branches.size % 2 == 0 => // even: no otherwise() yet
+      CaseWhen(branches ++ Seq(condition.expr, lit(value).expr))
+    case CaseWhen(_) => throw new IllegalArgumentException("when() cannot follow otherwise()")
+    case _ => throw new IllegalArgumentException(
+      "when() can only be applied on a Column previously generated by when() function")
+  }
+
+  /**
+   * Supplies the value returned when none of the preceding `when` conditions match.
+   * May be applied only once, and only to a Column produced by `when`.
+   *
+   * {{{
+   *   // Example: encoding gender string column into integer.
+   *
+   *   // Scala:
+   *   people.select(when(people("gender") === "male", 0)
+   *     .when(people("gender") === "female", 1)
+   *     .otherwise(2))
+   *
+   *   // Java:
+   *   people.select(when(col("gender").equalTo("male"), 0)
+   *     .when(col("gender").equalTo("female"), 1)
+   *     .otherwise(2))
+   * }}}
+   *
+   * @group expr_ops
+   */
+  def otherwise(value: Any): Column = this.expr match {
+    // An odd branch count means an else value was already appended as the last element.
+    case CaseWhen(branches) if branches.size % 2 != 0 =>
+      throw new IllegalArgumentException(
+        "otherwise() can only be applied once on a Column previously generated by when()")
+    case CaseWhen(branches) =>
+      CaseWhen(branches :+ lit(value).expr)
+    // Any other expression was not built by when().
+    case _ =>
+      throw new IllegalArgumentException(
+        "otherwise() can only be applied on a Column previously generated by when()")
+  }
+
+  /**
    * True if the current column is between the lower bound and upper bound, inclusive.
    *
    * @group java_expr_ops

http://git-wip-us.apache.org/repos/asf/spark/blob/97dee313/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 215787e..099e1d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -420,6 +420,30 @@ object functions {
   def not(e: Column): Column = !e
 
   /**
+   * Evaluates a list of conditions and returns one of multiple possible result expressions.
+   * Chain further `when` calls and an optional `otherwise` on the result; unmatched rows get null.
+   *
+   * {{{
+   *   // Example: encoding gender string column into integer.
+   *
+   *   // Scala:
+   *   people.select(when(people("gender") === "male", 0)
+   *     .when(people("gender") === "female", 1)
+   *     .otherwise(2))
+   *
+   *   // Java:
+   *   people.select(when(col("gender").equalTo("male"), 0)
+   *     .when(col("gender").equalTo("female"), 1)
+   *     .otherwise(2))
+   * }}}
+   *
+   * @group normal_funcs
+   */
+  def when(condition: Column, value: Any): Column =
+    // A fresh CASE WHEN holding one (condition, value) branch and no else value.
+    CaseWhen(condition.expr :: lit(value).expr :: Nil)
+
+  /**
    * Generate a random column with i.i.d. samples from U[0.0, 1.0].
    *
    * @group normal_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/97dee313/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index d96186c..269e185 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -255,6 +255,27 @@ class ColumnExpressionSuite extends QueryTest {
       Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
   }
 
+  test("SPARK-7321 when conditional statements") {
+    val keys = (1 to 3).map(i => (i, i.toString)).toDF("key", "value")
+    val isOne = $"key" === 1
+    val isTwo = $"key" === 2
+
+    // Every row matches a branch or falls through to otherwise().
+    checkAnswer(
+      keys.select(when(isOne, -1).when(isTwo, -2).otherwise(0)),
+      Row(-1) :: Row(-2) :: Row(0) :: Nil)
+
+    // No otherwise(): key 3 matches nothing and yields null. The first branch's value is a
+    // computed Column rather than a literal.
+    checkAnswer(keys.select(when(isOne, lit(0) - $"key").when(isTwo, -2)),
+      Row(-1) :: Row(-2) :: Row(null) :: Nil)
+
+    // when()/otherwise() reject Columns not built by when(), and a second otherwise().
+    intercept[IllegalArgumentException] { $"key".when(isOne, -1) }
+    intercept[IllegalArgumentException] { $"key".otherwise(-1) }
+    intercept[IllegalArgumentException] { when(isOne, -1).otherwise(-1).otherwise(-1) }
+  }
+
   test("sqrt") {
     checkAnswer(
       testData.select(sqrt('key)).orderBy('key.asc),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org