You are viewing a plain-text version of this content. The canonical version is available in the Apache mailing-list archive.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/30 04:14:10 UTC

[spark] branch master updated: [SPARK-41328][CONNECT][PYTHON] Add logical and string API to Column

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 40e6592e02c [SPARK-41328][CONNECT][PYTHON] Add logical and string API to Column
40e6592e02c is described below

commit 40e6592e02cbe679daec9e302e1027ffc64e7323
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Wed Nov 30 13:13:57 2022 +0900

    [SPARK-41328][CONNECT][PYTHON] Add logical and string API to Column
    
    ### What changes were proposed in this pull request?
    
    1. Update `_typing.py` so that `FunctionBuilderCallable` returns `Column` instead of `ScalarFunctionExpression`.
    2. Add logical operators (and, or, etc.) and string methods (like, substr, etc.) to `Column`.
    3. Add basic tests for new API.
    
    ### Why are the changes needed?
    
    Improve API coverage
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    UT
    
    Closes #38844 from amaliujia/refactor_column_back_up_2.
    
    Authored-by: Rui Wang <ru...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/connect/_typing.py              |   4 +-
 python/pyspark/sql/connect/column.py               | 339 ++++++++++++++++++++-
 python/pyspark/sql/connect/function_builder.py     |   6 +-
 .../sql/tests/connect/test_connect_basic.py        |  31 +-
 4 files changed, 367 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 8629d1c23cc..e5ade4cfcbe 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -26,7 +26,7 @@ from typing import Union, Optional
 import datetime
 import decimal
 
-from pyspark.sql.connect.column import ScalarFunctionExpression, Column
+from pyspark.sql.connect.column import Column
 
 ColumnOrName = Union[Column, str]
 
@@ -42,7 +42,7 @@ DateTimeLiteral = Union[datetime.datetime, datetime.date]
 
 
 class FunctionBuilderCallable(Protocol):
-    def __call__(self, *_: ColumnOrName) -> ScalarFunctionExpression:
+    def __call__(self, *_: ColumnOrName) -> Column:  # each call builds a Column expression
         ...
 
 
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 83e8b28da0f..c53d2c90bf6 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-from typing import get_args, TYPE_CHECKING, Callable, Any, Union
+from typing import get_args, TYPE_CHECKING, Callable, Any, Union, overload
 
 import json
 import decimal
@@ -29,6 +29,17 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.client import SparkConnectClient
     import pyspark.sql.connect.proto as proto
 
+# TODO(SPARK-41329): resolve the circular import between _typing and this module
+# so that _typing.PrimitiveType can be reused instead of this local copy
+PrimitiveType = Union[bool, float, int, str]
+
+
+def _func_op(name: str, doc: str = "") -> Callable[["Column"], "Column"]:
+    def _(self: "Column") -> "Column":  # unary op: SQL function `name` applied to self
+        return scalar_function(name, self)
+    _.__doc__ = doc  # expose the caller-supplied docstring instead of dropping it
+    return _
+
 
 def _bin_op(
     name: str, doc: str = "binary function", reverse: bool = False
@@ -219,6 +230,8 @@ class LiteralExpression(Expression):
                 else:
                     pair.value.CopyFrom(lit(value).to_plan(session).literal)
                 expr.literal.map.pairs.append(pair)
+        elif isinstance(self._value, Column):
+            expr.CopyFrom(self._value.to_plan(session))
         else:
             raise ValueError(f"Could not convert literal for type {type(self._value)}")
 
@@ -352,17 +365,326 @@ class Column(object):
     __rpow__ = _bin_op("pow", reverse=True)
     __ge__ = _bin_op(">=")
     __le__ = _bin_op("<=")
-    # __eq__ = _bin_op("==")  # ignore [assignment]
+
+    _eqNullSafe_doc = """
+        Equality test that is safe for null values.
+
+        Parameters
+        ----------
+        other
+            a value or :class:`Column`
+
+        Examples
+        --------
+        >>> from pyspark.sql import Row
+        >>> df1 = spark.createDataFrame([
+        ...     Row(id=1, value='foo'),
+        ...     Row(id=2, value=None)
+        ... ])
+        >>> df1.select(
+        ...     df1['value'] == 'foo',
+        ...     df1['value'].eqNullSafe('foo'),
+        ...     df1['value'].eqNullSafe(None)
+        ... ).show()
+        +-------------+---------------+----------------+
+        |(value = foo)|(value <=> foo)|(value <=> NULL)|
+        +-------------+---------------+----------------+
+        |         true|           true|           false|
+        |         null|          false|            true|
+        +-------------+---------------+----------------+
+        >>> df2 = spark.createDataFrame([
+        ...     Row(value = 'bar'),
+        ...     Row(value = None)
+        ... ])
+        >>> df1.join(df2, df1["value"] == df2["value"]).count()
+        0
+        >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()
+        1
+        >>> df2 = spark.createDataFrame([
+        ...     Row(id=1, value=float('NaN')),
+        ...     Row(id=2, value=42.0),
+        ...     Row(id=3, value=None)
+        ... ])
+        >>> df2.select(
+        ...     df2['value'].eqNullSafe(None),
+        ...     df2['value'].eqNullSafe(float('NaN')),
+        ...     df2['value'].eqNullSafe(42.0)
+        ... ).show()
+        +----------------+---------------+----------------+
+        |(value <=> NULL)|(value <=> NaN)|(value <=> 42.0)|
+        +----------------+---------------+----------------+
+        |           false|           true|           false|
+        |           false|          false|            true|
+        |            true|          false|           false|
+        +----------------+---------------+----------------+
+
+        Notes
+        -----
+        Unlike Pandas, PySpark doesn't consider NaN values to be NULL. For details, see the
+        `NaN Semantics <https://spark.apache.org/docs/latest/sql-ref-datatypes.html#nan-semantics>`_
+        """
+    eqNullSafe = _bin_op("<=>", _eqNullSafe_doc)  # Spark SQL's null-safe equality operator
+
+    __neg__ = _func_op("negative")  # Spark SQL's registered name for unary minus
+
+    # `and`, `or`, `not` cannot be overloaded in Python,
+    # so use bitwise operators as boolean operators
+    __and__ = _bin_op("and")
+    __or__ = _bin_op("or")
+    __invert__ = _func_op("not")
+    __rand__ = _bin_op("and")
+    __ror__ = _bin_op("or")
+
+    # bitwise operators; Spark SQL registers them as the functions `|`, `&` and `^`
+    _bitwiseOR_doc = """
+        Compute bitwise OR of this expression with another expression.
+
+        Parameters
+        ----------
+        other
+            a value or :class:`Column` to calculate bitwise or(|) with
+            this :class:`Column`.
+
+        Examples
+        --------
+        >>> from pyspark.sql import Row
+        >>> df = spark.createDataFrame([Row(a=170, b=75)])
+        >>> df.select(df.a.bitwiseOR(df.b)).collect()
+        [Row((a | b)=235)]
+        """
+    _bitwiseAND_doc = """
+        Compute bitwise AND of this expression with another expression.
+
+        Parameters
+        ----------
+        other
+            a value or :class:`Column` to calculate bitwise and(&) with
+            this :class:`Column`.
+
+        Examples
+        --------
+        >>> from pyspark.sql import Row
+        >>> df = spark.createDataFrame([Row(a=170, b=75)])
+        >>> df.select(df.a.bitwiseAND(df.b)).collect()
+        [Row((a & b)=10)]
+        """
+    _bitwiseXOR_doc = """
+        Compute bitwise XOR of this expression with another expression.
+
+        Parameters
+        ----------
+        other
+            a value or :class:`Column` to calculate bitwise xor(^) with
+            this :class:`Column`.
+
+        Examples
+        --------
+        >>> from pyspark.sql import Row
+        >>> df = spark.createDataFrame([Row(a=170, b=75)])
+        >>> df.select(df.a.bitwiseXOR(df.b)).collect()
+        [Row((a ^ b)=225)]
+        """
+
+    bitwiseOR = _bin_op("|", _bitwiseOR_doc)
+    bitwiseAND = _bin_op("&", _bitwiseAND_doc)
+    bitwiseXOR = _bin_op("^", _bitwiseXOR_doc)
+
+    # string methods
+    def contains(self, other: Union[PrimitiveType, "Column"]) -> "Column":
+        """
+        Contains the other element. Returns a boolean :class:`Column` based on a string match.
+
+        Parameters
+        ----------
+        other : :class:`Column` or str
+            substring to search for, as a literal value or a :class:`Column`.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.filter(df.name.contains('o')).collect()
+        [Row(age=5, name='Bob')]
+        """
+        return _bin_op("contains")(self, other)
+
+    def startswith(self, other: Union[PrimitiveType, "Column"]) -> "Column":
+        """
+        String starts with. Returns a boolean :class:`Column` based on a string match.
+
+        Parameters
+        ----------
+        other : :class:`Column` or str
+            prefix matched literally; regex anchors such as `^` are not interpreted
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.filter(df.name.startswith('Al')).collect()
+        [Row(age=2, name='Alice')]
+        >>> df.filter(df.name.startswith('^Al')).collect()
+        []
+        """
+        return _bin_op("startsWith")(self, other)
+
+    def endswith(self, other: Union[PrimitiveType, "Column"]) -> "Column":
+        """
+        String ends with. Returns a boolean :class:`Column` based on a string match.
+
+        Parameters
+        ----------
+        other : :class:`Column` or str
+            suffix matched literally; regex anchors such as `$` are not interpreted
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.filter(df.name.endswith('ice')).collect()
+        [Row(age=2, name='Alice')]
+        >>> df.filter(df.name.endswith('ice$')).collect()
+        []
+        """
+        return _bin_op("endsWith")(self, other)
+
+    def like(self: "Column", other: str) -> "Column":
+        """
+        SQL LIKE expression. Returns a boolean :class:`Column` based on a SQL LIKE match.
+
+        Parameters
+        ----------
+        other : str
+            a SQL LIKE pattern
+
+        See Also
+        --------
+        pyspark.sql.Column.rlike
+
+        Returns
+        -------
+        :class:`Column`
+            Column of booleans: whether each element matches the SQL LIKE pattern.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.filter(df.name.like('Al%')).collect()
+        [Row(age=2, name='Alice')]
+        """
+        return _bin_op("like")(self, other)
+
+    def rlike(self: "Column", other: str) -> "Column":
+        """
+        SQL RLIKE expression (LIKE with regex). Returns a boolean :class:`Column`.
+
+        Parameters
+        ----------
+        other : str
+            an extended regex expression
+
+        Returns
+        -------
+        :class:`Column`
+            Column of booleans showing whether each element
+            in the Column is matched by extended regex expression.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.filter(df.name.rlike('ice$')).collect()
+        [Row(age=2, name='Alice')]
+        """
+        return _bin_op("rlike")(self, other)  # regex match; "like" would apply LIKE semantics
+
+    def ilike(self: "Column", other: str) -> "Column":
+        """
+        SQL ILIKE expression (case insensitive LIKE). Returns a boolean :class:`Column`
+        based on a case insensitive match.
+
+        Parameters
+        ----------
+        other : str
+            a SQL LIKE pattern
+
+        See Also
+        --------
+        pyspark.sql.Column.like
+
+        Returns
+        -------
+        :class:`Column`
+            Column of booleans: whether each element matches the pattern, ignoring case.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.filter(df.name.ilike('%Ice')).collect()
+        [Row(age=2, name='Alice')]
+        """
+        return _bin_op("ilike")(self, other)
+
+    @overload
+    def substr(self, startPos: int, length: int) -> "Column":
+        ...
+
+    @overload
+    def substr(self, startPos: "Column", length: "Column") -> "Column":
+        ...
+
+    def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) -> "Column":
+        """
+        Return a :class:`Column` which is a substring of the column.
+
+        Parameters
+        ----------
+        startPos : :class:`Column` or int
+            start position (1-based, as in the example below)
+        length : :class:`Column` or int
+            length of the substring
+
+        Returns
+        -------
+        :class:`Column`
+            Column containing the substring of each element of this Column.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.select(df.name.substr(1, 3).alias("col")).collect()
+        [Row(col='Ali'), Row(col='Bob')]
+        """
+        if type(startPos) != type(length):  # mixing int and Column arguments is rejected
+            raise TypeError(
+                "startPos and length must be the same type. "
+                "Got {startPos_t} and {length_t}, respectively.".format(
+                    startPos_t=type(startPos),
+                    length_t=type(length),
+                )
+            )
+        from pyspark.sql.connect.function_builder import functions as F
+
+        if isinstance(length, int):
+            length_exp = self._lit(length)
+        elif isinstance(length, Column):
+            length_exp = length
+        else:
+            raise TypeError("Unsupported type for substr().")
+
+        if isinstance(startPos, int):
+            start_exp = self._lit(startPos)
+        else:
+            start_exp = startPos  # a Column, since both arguments share one type
+
+        return F.substr(self, start_exp, length_exp)
 
     def __eq__(self, other: Any) -> "Column":  # type: ignore[override]
         """Returns a binary expression with the current column as the left
         side and the other expression as the right side.
         """
-        from pyspark.sql.connect._typing import PrimitiveType
-        from pyspark.sql.connect.functions import lit
-
         if isinstance(other, get_args(PrimitiveType)):
-            other = lit(other)
+            other = self._lit(other)  # wrap bool/float/int/str values as literal Columns
         return scalar_function("==", self, other)
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
@@ -380,5 +702,10 @@ class Column(object):
     def name(self) -> str:
         return self._expr.name()
 
+    # TODO(SPARK-41329): resolve the circular import between functions.py and this
+    # module so that functions.lit can be reused here
+    def _lit(self, x: Any) -> "Column":
+        return Column(LiteralExpression(x))  # wrap a Python value in a literal Column
+
     def __str__(self) -> str:
         return self._expr.__str__()
diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py
index b65348c6862..1edca287367 100644
--- a/python/pyspark/sql/connect/function_builder.py
+++ b/python/pyspark/sql/connect/function_builder.py
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.client import SparkConnectClient
 
 
-def _build(name: str, *args: "ColumnOrName") -> ScalarFunctionExpression:
+def _build(name: str, *args: "ColumnOrName") -> Column:
     """
     Simple wrapper function that converts the arguments into the appropriate types.
     Parameters
@@ -46,14 +46,14 @@ def _build(name: str, *args: "ColumnOrName") -> ScalarFunctionExpression:
     :class:`ScalarFunctionExpression`
     """
     cols = [x if isinstance(x, Column) else col(x) for x in args]
-    return ScalarFunctionExpression(name, *cols)
+    return Column(ScalarFunctionExpression(name, *cols))
 
 
 class FunctionBuilder:
     """This class is used to build arbitrary functions used in expressions"""
 
     def __getattr__(self, name: str) -> "FunctionBuilderCallable":
-        def _(*args: "ColumnOrName") -> ScalarFunctionExpression:
+        def _(*args: "ColumnOrName") -> Column:
             return _build(name, *args)
 
         _.__doc__ = f"""Function to apply {name}"""
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 47a50a2cecb..c499e393e19 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -116,8 +116,35 @@ class SparkConnectTests(SparkConnectSQLTestCase):
 
     def test_columns(self):
         # SPARK-41036: test `columns` API for python client.
-        columns = self.connect.read.table(self.tbl_name).columns
-        self.assertEqual(["id", "name"], columns)
+        df = self.connect.read.table(self.tbl_name)
+        df2 = self.spark.read.table(self.tbl_name)
+        self.assertEqual(["id", "name"], df.columns)
+
+        self.assert_eq(  # regex/wildcard patterns so a like-vs-rlike mix-up can surface
+            df.filter(df.name.rlike("^2")).toPandas(), df2.filter(df2.name.rlike("^2")).toPandas()
+        )
+        self.assert_eq(
+            df.filter(df.name.like("2%")).toPandas(), df2.filter(df2.name.like("2%")).toPandas()
+        )
+        self.assert_eq(
+            df.filter(df.name.ilike("2%")).toPandas(), df2.filter(df2.name.ilike("2%")).toPandas()
+        )
+        self.assert_eq(
+            df.filter(df.name.contains("20")).toPandas(),
+            df2.filter(df2.name.contains("20")).toPandas(),
+        )
+        self.assert_eq(
+            df.filter(df.name.startswith("2")).toPandas(),
+            df2.filter(df2.name.startswith("2")).toPandas(),
+        )
+        self.assert_eq(
+            df.filter(df.name.endswith("0")).toPandas(),
+            df2.filter(df2.name.endswith("0")).toPandas(),
+        )
+        self.assert_eq(
+            df.select(df.name.substr(0, 1).alias("col")).toPandas(),
+            df2.select(df2.name.substr(0, 1).alias("col")).toPandas(),
+        )
 
     def test_collect(self):
         df = self.connect.read.table(self.tbl_name)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org