Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/30 07:41:18 UTC

[spark] branch master updated: [SPARK-41315][CONNECT][PYTHON] Implement `DataFrame.replace` and `DataFrame.na.replace`

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5fc482eb591 [SPARK-41315][CONNECT][PYTHON] Implement `DataFrame.replace` and `DataFrame.na.replace`
5fc482eb591 is described below

commit 5fc482eb591a42f6dc1eb6ca7cffea72a4616f09
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Wed Nov 30 15:40:53 2022 +0800

    [SPARK-41315][CONNECT][PYTHON] Implement `DataFrame.replace` and `DataFrame.na.replace`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrame.replace` and `DataFrame.na.replace`
    
    ### Why are the changes needed?
    For API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this adds new APIs: `DataFrame.replace` and `DataFrame.na.replace`.
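    
    For example (a minimal usage sketch; the session `spark` and the column/table
    values below are illustrative, not part of this patch):
    
        df = spark.sql("SELECT * FROM VALUES (1, 'a'), (2, 'b') AS tab(id, name)")
        df.replace(1, 3).show()                            # 1 -> 3 in all compatible columns
        df.na.replace({"a": "x"}, subset=["name"]).show()  # dict form, restricted to 'name'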
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #38836 from zhengruifeng/connect_df_replace.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  30 ++++
 .../org/apache/spark/sql/connect/dsl/package.scala |  25 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  32 ++++
 .../connect/planner/SparkConnectProtoSuite.scala   |  12 ++
 python/pyspark/sql/connect/dataframe.py            | 150 ++++++++++++++++-
 python/pyspark/sql/connect/plan.py                 |  72 +++++++--
 python/pyspark/sql/connect/proto/relations_pb2.py  | 179 ++++++++++++---------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  93 +++++++++++
 .../sql/tests/connect/test_connect_basic.py        |  46 ++++++
 .../sql/tests/connect/test_connect_plan_only.py    |  27 ++++
 10 files changed, 579 insertions(+), 87 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index d0ebce5ccab..8b87845245f 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -57,6 +57,7 @@ message Relation {
     // NA functions
     NAFill fill_na = 90;
     NADrop drop_na = 91;
+    NAReplace replace = 92;
 
     // stat functions
     StatSummary summary = 100;
@@ -466,6 +467,35 @@ message NADrop {
 }
 
 
+// Replaces old values with the corresponding new values.
+// It will invoke 'Dataset.na.replace' (same as 'DataFrameNaFunctions.replace')
+// to compute the results.
+message NAReplace {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Optional) List of column names to consider.
+  //
+  // When it is empty, all the type-compatible columns in the input relation will be considered.
+  repeated string cols = 2;
+
+  // (Optional) The value replacement mapping.
+  repeated Replacement replacements = 3;
+
+  message Replacement {
+    // (Required) The old value.
+    //
+    // Only 4 data types are supported now: null, bool, double, string.
+    Expression.Literal old_value = 1;
+
+    // (Required) The new value.
+    //
+  // Should be of the same data type as the old value.
+    Expression.Literal new_value = 2;
+  }
+}
+
+
 // Rename columns on the input relation by the same length of names.
 message RenameColumnsBySameLengthNames {
   // (Required) The input relation of RenameColumnsBySameLengthNames.
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 42f59e112af..654a4d5ce20 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -311,6 +311,31 @@ package object dsl {
           .setDropNa(dropna.build())
           .build()
       }
+
+      def replace(cols: Seq[String], replacement: Map[Any, Any]): Relation = {
+        require(cols.nonEmpty)
+
+        val replace = proto.NAReplace
+          .newBuilder()
+          .setInput(logicalPlan)
+
+        if (!(cols.length == 1 && cols.head == "*")) {
+          replace.addAllCols(cols.asJava)
+        }
+
+        replacement.foreach { case (oldValue, newValue) =>
+          replace.addReplacements(
+            proto.NAReplace.Replacement
+              .newBuilder()
+              .setOldValue(convertValue(oldValue))
+              .setNewValue(convertValue(newValue)))
+        }
+
+        Relation
+          .newBuilder()
+          .setReplace(replace.build())
+          .build()
+      }
     }
 
     implicit class DslStatFunctions(val logicalPlan: Relation) {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 2c492d973e4..7b9e13cadab 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -85,6 +85,7 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
       case proto.Relation.RelTypeCase.FILL_NA => transformNAFill(rel.getFillNa)
       case proto.Relation.RelTypeCase.DROP_NA => transformNADrop(rel.getDropNa)
+      case proto.Relation.RelTypeCase.REPLACE => transformReplace(rel.getReplace)
       case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary)
       case proto.Relation.RelTypeCase.CROSSTAB =>
         transformStatCrosstab(rel.getCrosstab)
@@ -232,6 +233,37 @@ class SparkConnectPlanner(session: SparkSession) {
     }
   }
 
+  private def transformReplace(rel: proto.NAReplace): LogicalPlan = {
+    def convert(value: proto.Expression.Literal): Any = {
+      value.getLiteralTypeCase match {
+        case proto.Expression.Literal.LiteralTypeCase.NULL => null
+        case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => value.getBoolean
+        case proto.Expression.Literal.LiteralTypeCase.DOUBLE => value.getDouble
+        case proto.Expression.Literal.LiteralTypeCase.STRING => value.getString
+        case other => throw InvalidPlanInput(s"Unsupported value type: $other")
+      }
+    }
+
+    val replacement = mutable.Map.empty[Any, Any]
+    rel.getReplacementsList.asScala.foreach { replace =>
+      replacement.update(convert(replace.getOldValue), convert(replace.getNewValue))
+    }
+
+    if (rel.getColsCount == 0) {
+      Dataset
+        .ofRows(session, transformRelation(rel.getInput))
+        .na
+        .replace("*", replacement.toMap)
+        .logicalPlan
+    } else {
+      Dataset
+        .ofRows(session, transformRelation(rel.getInput))
+        .na
+        .replace(rel.getColsList.asScala.toSeq, replacement.toMap)
+        .logicalPlan
+    }
+  }
+
   private def transformStatSummary(rel: proto.StatSummary): LogicalPlan = {
     Dataset
       .ofRows(session, transformRelation(rel.getInput))
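
The empty-cols fallback above mirrors the classic API, where "*" means every
type-compatible column. In classic PySpark terms (a sketch; `spark` is a regular
SparkSession, not part of this patch):

    df = spark.range(5)
    df.na.replace({1: -1}).show()                 # no subset -> all compatible columns ("*")
    df.na.replace({1: -1}, subset=["id"]).show()  # only the listed columns are considered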
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 9ed706e855a..86253b0016c 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -361,6 +361,18 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       sparkTestRelation.na.drop(minNonNulls = 1, cols = Seq("id", "name")))
   }
 
+  test("SPARK-41315: Test replace") {
+    comparePlans(
+      connectTestRelation.na.replace(cols = Seq("id"), replacement = Map(1.0 -> 2.0)),
+      sparkTestRelation.na.replace(cols = Seq("id"), replacement = Map(1.0 -> 2.0)))
+    comparePlans(
+      connectTestRelation.na.replace(cols = Seq("name"), replacement = Map("a" -> "b")),
+      sparkTestRelation.na.replace(cols = Seq("name"), replacement = Map("a" -> "b")))
+    comparePlans(
+      connectTestRelation.na.replace(cols = Seq("*"), replacement = Map("a" -> "b")),
+      sparkTestRelation.na.replace(col = "*", replacement = Map("a" -> "b")))
+  }
+
   test("Test summary") {
     comparePlans(
       connectTestRelation.summary("count", "mean", "stddev"),
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 034f410b1ad..c9960a71fb8 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -27,9 +27,12 @@ from typing import (
     overload,
     Callable,
     cast,
+    Type,
 )
 
 import pandas
+import warnings
+from collections.abc import Iterable
 
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.connect.readwriter import DataFrameWriter
@@ -39,9 +42,11 @@ from pyspark.sql.types import (
     StructType,
     Row,
 )
+from pyspark import _NoValue
+from pyspark._globals import _NoValueType
 
 if TYPE_CHECKING:
-    from pyspark.sql.connect._typing import ColumnOrName, LiteralType
+    from pyspark.sql.connect._typing import ColumnOrName, LiteralType, OptionalPrimitiveType
     from pyspark.sql.connect.session import SparkSession
 
 
@@ -1054,6 +1059,137 @@ class DataFrame(object):
             session=self._session,
         )
 
+    def replace(
+        self,
+        to_replace: Union[
+            "LiteralType", List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]
+        ],
+        value: Optional[
+            Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType]
+        ] = _NoValue,
+        subset: Optional[List[str]] = None,
+    ) -> "DataFrame":
+        """Returns a new :class:`DataFrame` replacing a value with another value.
+        :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
+        aliases of each other.
+        Values to_replace and value must have the same type and can only be numerics, booleans,
+        or strings. Value can be None. When replacing, the new value will be cast
+        to the type of the existing column.
+        For numeric replacements all values to be replaced should have unique
+        floating point representation. In case of conflicts (for example with `{42: -1, 42.0: 1}`)
+        an arbitrary replacement will be used.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        to_replace : bool, int, float, string, list or dict
+            Value to be replaced.
+            If `to_replace` is a dict, then `value` is ignored or can be omitted, and `to_replace`
+            must be a mapping between a value and a replacement.
+        value : bool, int, float, string or None, optional
+            The replacement value must be a bool, int, float, string or None. If `value` is a
+            list, `value` should be of the same length and type as `to_replace`.
+            If `value` is a scalar and `to_replace` is a sequence, then `value` is
+            used as a replacement for each item in `to_replace`.
+        subset : list, optional
+            Optional list of column names to consider.
+            Columns specified in subset that do not have matching data type are ignored.
+            For example, if `value` is a string, and subset contains a non-string column,
+            then the non-string column is simply ignored.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            DataFrame with replaced values.
+        """
+        if value is _NoValue:
+            if isinstance(to_replace, dict):
+                value = None
+            else:
+                raise TypeError("value argument is required when to_replace is not a dictionary.")
+
+        # Helper functions
+        def all_of(types: Union[Type, Tuple[Type, ...]]) -> Callable[[Iterable], bool]:
+            """Given a type or tuple of types and a sequence of xs
+            check if each x is instance of type(s)
+
+            >>> all_of(bool)([True, False])
+            True
+            >>> all_of(str)(["a", 1])
+            False
+            """
+
+            def all_of_(xs: Iterable) -> bool:
+                return all(isinstance(x, types) for x in xs)
+
+            return all_of_
+
+        all_of_bool = all_of(bool)
+        all_of_str = all_of(str)
+        all_of_numeric = all_of((float, int))
+
+        # Validate input types
+        valid_types = (bool, float, int, str, list, tuple)
+        if not isinstance(to_replace, valid_types + (dict,)):
+            raise TypeError(
+                "to_replace should be a bool, float, int, string, list, tuple, or dict. "
+                "Got {0}".format(type(to_replace))
+            )
+
+        if (
+            not isinstance(value, valid_types)
+            and value is not None
+            and not isinstance(to_replace, dict)
+        ):
+            raise TypeError(
+                "If to_replace is not a dict, value should be "
+                "a bool, float, int, string, list, tuple or None. "
+                "Got {0}".format(type(value))
+            )
+
+        if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
+            if len(to_replace) != len(value):
+                raise ValueError(
+                    "to_replace and value lists should be of the same length. "
+                    "Got {0} and {1}".format(len(to_replace), len(value))
+                )
+
+        if not (subset is None or isinstance(subset, (list, tuple, str))):
+            raise TypeError(
+                "subset should be a list or tuple of column names, "
+                "column name or None. Got {0}".format(type(subset))
+            )
+
+        # Reshape input arguments if necessary
+        if isinstance(to_replace, (float, int, str)):
+            to_replace = [to_replace]
+
+        if isinstance(to_replace, dict):
+            rep_dict = to_replace
+            if value is not None:
+                warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
+        else:
+            if isinstance(value, (float, int, str)) or value is None:
+                value = [value for _ in range(len(to_replace))]
+            rep_dict = dict(zip(to_replace, cast("Iterable[Optional[Union[float, str]]]", value)))
+
+        if isinstance(subset, str):
+            subset = [subset]
+
+        # Verify that we were not passed mixed-type replacements.
+        if not any(
+            all_of_type(rep_dict.keys())
+            and all_of_type(x for x in rep_dict.values() if x is not None)
+            for all_of_type in [all_of_bool, all_of_str, all_of_numeric]
+        ):
+            raise ValueError("Mixed type replacements are not supported")
+
+        return DataFrame.withPlan(
+            plan.NAReplace(child=self._plan, cols=subset, replacements=rep_dict),
+            session=self._session,
+        )
+
     @property
     def stat(self) -> "DataFrameStatFunctions":
         """Returns a :class:`DataFrameStatFunctions` for statistic functions.
@@ -1520,6 +1656,18 @@ class DataFrameNaFunctions:
 
     drop.__doc__ = DataFrame.dropna.__doc__
 
+    def replace(
+        self,
+        to_replace: Union[List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]],
+        value: Optional[
+            Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType]
+        ] = _NoValue,
+        subset: Optional[List[str]] = None,
+    ) -> DataFrame:
+        return self.df.replace(to_replace, value, subset)
+
+    replace.__doc__ = DataFrame.replace.__doc__
+
 
 class DataFrameStatFunctions:
     """Functionality for statistic functions with :class:`DataFrame`.
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index e7bbad25e59..cc544819ad6 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -15,16 +15,7 @@
 # limitations under the License.
 #
 
-from typing import (
-    Any,
-    List,
-    Optional,
-    Sequence,
-    Union,
-    cast,
-    TYPE_CHECKING,
-    Mapping,
-)
+from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
 import pandas
 import pyarrow as pa
 import pyspark.sql.connect.proto as proto
@@ -1095,6 +1086,67 @@ class NADrop(LogicalPlan):
         """
 
 
+class NAReplace(LogicalPlan):
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        cols: Optional[List[str]],
+        replacements: Dict[Any, Any],
+    ) -> None:
+        super().__init__(child)
+
+        for old_value, new_value in replacements.items():
+            if old_value is not None:
+                assert isinstance(old_value, (bool, int, float, str))
+            if new_value is not None:
+                assert isinstance(new_value, (bool, int, float, str))
+
+        self.cols = cols
+        self.replacements = replacements
+
+    def _convert_value(self, v: Any) -> proto.Expression.Literal:
+        value = proto.Expression.Literal()
+        if v is None:
+            value.null = True
+        elif isinstance(v, bool):
+            value.boolean = v
+        elif isinstance(v, (int, float)):
+            value.double = float(v)
+        else:
+            value.string = v
+        return value
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+        plan = proto.Relation()
+        plan.replace.input.CopyFrom(self._child.plan(session))
+        if self.cols is not None and len(self.cols) > 0:
+            plan.replace.cols.extend(self.cols)
+        if len(self.replacements) > 0:
+            for old_value, new_value in self.replacements.items():
+                replacement = proto.NAReplace.Replacement()
+                replacement.old_value.CopyFrom(self._convert_value(old_value))
+                replacement.new_value.CopyFrom(self._convert_value(new_value))
+                plan.replace.replacements.append(replacement)
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        i = " " * indent
+        return f"{i}" f"<NAReplace cols='{self.cols}' " f"replacements='{self.replacements}'>"
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+           <li>
+              <b>NAReplace</b><br />
+              Cols: {self.cols} <br />
+              Replacements: {self.replacements} <br />
+              {self._child_repr_()}
+           </li>
+        </ul>
+        """
+
+
 class StatSummary(LogicalPlan):
     def __init__(self, child: Optional["LogicalPlan"], statistics: List[str]) -> None:
         super().__init__(child)
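
`NAReplace._convert_value` above folds Python scalars into the four supported
literal kinds, widening ints to double. A quick sketch of the mapping (assuming
the connect package from this patch is importable):

    from pyspark.sql.connect.plan import NAReplace

    node = NAReplace(child=None, cols=None, replacements={})
    node._convert_value(None).null      # True
    node._convert_value(True).boolean   # True
    node._convert_value(1).double       # 1.0 -- ints are widened to double
    node._convert_value("a").string     # 'a'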
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 36ce4af222c..3b4a2d5de18 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xef\x0c\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xa5\r\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\ [...]
 )
 
 
@@ -67,6 +67,8 @@ _STATSUMMARY = DESCRIPTOR.message_types_by_name["StatSummary"]
 _STATCROSSTAB = DESCRIPTOR.message_types_by_name["StatCrosstab"]
 _NAFILL = DESCRIPTOR.message_types_by_name["NAFill"]
 _NADROP = DESCRIPTOR.message_types_by_name["NADrop"]
+_NAREPLACE = DESCRIPTOR.message_types_by_name["NAReplace"]
+_NAREPLACE_REPLACEMENT = _NAREPLACE.nested_types_by_name["Replacement"]
 _RENAMECOLUMNSBYSAMELENGTHNAMES = DESCRIPTOR.message_types_by_name["RenameColumnsBySameLengthNames"]
 _RENAMECOLUMNSBYNAMETONAMEMAP = DESCRIPTOR.message_types_by_name["RenameColumnsByNameToNameMap"]
 _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY = (
@@ -403,6 +405,27 @@ NADrop = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(NADrop)
 
+NAReplace = _reflection.GeneratedProtocolMessageType(
+    "NAReplace",
+    (_message.Message,),
+    {
+        "Replacement": _reflection.GeneratedProtocolMessageType(
+            "Replacement",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _NAREPLACE_REPLACEMENT,
+                "__module__": "spark.connect.relations_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.NAReplace.Replacement)
+            },
+        ),
+        "DESCRIPTOR": _NAREPLACE,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.NAReplace)
+    },
+)
+_sym_db.RegisterMessage(NAReplace)
+_sym_db.RegisterMessage(NAReplace.Replacement)
+
 RenameColumnsBySameLengthNames = _reflection.GeneratedProtocolMessageType(
     "RenameColumnsBySameLengthNames",
     (_message.Message,),
@@ -455,79 +478,83 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 82
-    _RELATION._serialized_end = 1729
-    _UNKNOWN._serialized_start = 1731
-    _UNKNOWN._serialized_end = 1740
-    _RELATIONCOMMON._serialized_start = 1742
-    _RELATIONCOMMON._serialized_end = 1791
-    _SQL._serialized_start = 1793
-    _SQL._serialized_end = 1820
-    _READ._serialized_start = 1823
-    _READ._serialized_end = 2249
-    _READ_NAMEDTABLE._serialized_start = 1965
-    _READ_NAMEDTABLE._serialized_end = 2026
-    _READ_DATASOURCE._serialized_start = 2029
-    _READ_DATASOURCE._serialized_end = 2236
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2167
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2225
-    _PROJECT._serialized_start = 2251
-    _PROJECT._serialized_end = 2368
-    _FILTER._serialized_start = 2370
-    _FILTER._serialized_end = 2482
-    _JOIN._serialized_start = 2485
-    _JOIN._serialized_end = 2935
-    _JOIN_JOINTYPE._serialized_start = 2748
-    _JOIN_JOINTYPE._serialized_end = 2935
-    _SETOPERATION._serialized_start = 2938
-    _SETOPERATION._serialized_end = 3334
-    _SETOPERATION_SETOPTYPE._serialized_start = 3197
-    _SETOPERATION_SETOPTYPE._serialized_end = 3311
-    _LIMIT._serialized_start = 3336
-    _LIMIT._serialized_end = 3412
-    _OFFSET._serialized_start = 3414
-    _OFFSET._serialized_end = 3493
-    _TAIL._serialized_start = 3495
-    _TAIL._serialized_end = 3570
-    _AGGREGATE._serialized_start = 3573
-    _AGGREGATE._serialized_end = 3783
-    _SORT._serialized_start = 3786
-    _SORT._serialized_end = 4336
-    _SORT_SORTFIELD._serialized_start = 3940
-    _SORT_SORTFIELD._serialized_end = 4128
-    _SORT_SORTDIRECTION._serialized_start = 4130
-    _SORT_SORTDIRECTION._serialized_end = 4238
-    _SORT_SORTNULLS._serialized_start = 4240
-    _SORT_SORTNULLS._serialized_end = 4322
-    _DROP._serialized_start = 4338
-    _DROP._serialized_end = 4438
-    _DEDUPLICATE._serialized_start = 4441
-    _DEDUPLICATE._serialized_end = 4612
-    _LOCALRELATION._serialized_start = 4614
-    _LOCALRELATION._serialized_end = 4649
-    _SAMPLE._serialized_start = 4652
-    _SAMPLE._serialized_end = 4876
-    _RANGE._serialized_start = 4879
-    _RANGE._serialized_end = 5024
-    _SUBQUERYALIAS._serialized_start = 5026
-    _SUBQUERYALIAS._serialized_end = 5140
-    _REPARTITION._serialized_start = 5143
-    _REPARTITION._serialized_end = 5285
-    _SHOWSTRING._serialized_start = 5288
-    _SHOWSTRING._serialized_end = 5429
-    _STATSUMMARY._serialized_start = 5431
-    _STATSUMMARY._serialized_end = 5523
-    _STATCROSSTAB._serialized_start = 5525
-    _STATCROSSTAB._serialized_end = 5626
-    _NAFILL._serialized_start = 5629
-    _NAFILL._serialized_end = 5763
-    _NADROP._serialized_start = 5766
-    _NADROP._serialized_end = 5900
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5902
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6016
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6019
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6278
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6211
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6278
-    _WITHCOLUMNS._serialized_start = 6281
-    _WITHCOLUMNS._serialized_end = 6412
+    _RELATION._serialized_end = 1783
+    _UNKNOWN._serialized_start = 1785
+    _UNKNOWN._serialized_end = 1794
+    _RELATIONCOMMON._serialized_start = 1796
+    _RELATIONCOMMON._serialized_end = 1845
+    _SQL._serialized_start = 1847
+    _SQL._serialized_end = 1874
+    _READ._serialized_start = 1877
+    _READ._serialized_end = 2303
+    _READ_NAMEDTABLE._serialized_start = 2019
+    _READ_NAMEDTABLE._serialized_end = 2080
+    _READ_DATASOURCE._serialized_start = 2083
+    _READ_DATASOURCE._serialized_end = 2290
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2221
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2279
+    _PROJECT._serialized_start = 2305
+    _PROJECT._serialized_end = 2422
+    _FILTER._serialized_start = 2424
+    _FILTER._serialized_end = 2536
+    _JOIN._serialized_start = 2539
+    _JOIN._serialized_end = 2989
+    _JOIN_JOINTYPE._serialized_start = 2802
+    _JOIN_JOINTYPE._serialized_end = 2989
+    _SETOPERATION._serialized_start = 2992
+    _SETOPERATION._serialized_end = 3388
+    _SETOPERATION_SETOPTYPE._serialized_start = 3251
+    _SETOPERATION_SETOPTYPE._serialized_end = 3365
+    _LIMIT._serialized_start = 3390
+    _LIMIT._serialized_end = 3466
+    _OFFSET._serialized_start = 3468
+    _OFFSET._serialized_end = 3547
+    _TAIL._serialized_start = 3549
+    _TAIL._serialized_end = 3624
+    _AGGREGATE._serialized_start = 3627
+    _AGGREGATE._serialized_end = 3837
+    _SORT._serialized_start = 3840
+    _SORT._serialized_end = 4390
+    _SORT_SORTFIELD._serialized_start = 3994
+    _SORT_SORTFIELD._serialized_end = 4182
+    _SORT_SORTDIRECTION._serialized_start = 4184
+    _SORT_SORTDIRECTION._serialized_end = 4292
+    _SORT_SORTNULLS._serialized_start = 4294
+    _SORT_SORTNULLS._serialized_end = 4376
+    _DROP._serialized_start = 4392
+    _DROP._serialized_end = 4492
+    _DEDUPLICATE._serialized_start = 4495
+    _DEDUPLICATE._serialized_end = 4666
+    _LOCALRELATION._serialized_start = 4668
+    _LOCALRELATION._serialized_end = 4703
+    _SAMPLE._serialized_start = 4706
+    _SAMPLE._serialized_end = 4930
+    _RANGE._serialized_start = 4933
+    _RANGE._serialized_end = 5078
+    _SUBQUERYALIAS._serialized_start = 5080
+    _SUBQUERYALIAS._serialized_end = 5194
+    _REPARTITION._serialized_start = 5197
+    _REPARTITION._serialized_end = 5339
+    _SHOWSTRING._serialized_start = 5342
+    _SHOWSTRING._serialized_end = 5483
+    _STATSUMMARY._serialized_start = 5485
+    _STATSUMMARY._serialized_end = 5577
+    _STATCROSSTAB._serialized_start = 5579
+    _STATCROSSTAB._serialized_end = 5680
+    _NAFILL._serialized_start = 5683
+    _NAFILL._serialized_end = 5817
+    _NADROP._serialized_start = 5820
+    _NADROP._serialized_end = 5954
+    _NAREPLACE._serialized_start = 5957
+    _NAREPLACE._serialized_end = 6253
+    _NAREPLACE_REPLACEMENT._serialized_start = 6112
+    _NAREPLACE_REPLACEMENT._serialized_end = 6253
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6255
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6369
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6372
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6631
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6564
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6631
+    _WITHCOLUMNS._serialized_start = 6634
+    _WITHCOLUMNS._serialized_end = 6765
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 84e0c520f36..df7b083103a 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -84,6 +84,7 @@ class Relation(google.protobuf.message.Message):
     WITH_COLUMNS_FIELD_NUMBER: builtins.int
     FILL_NA_FIELD_NUMBER: builtins.int
     DROP_NA_FIELD_NUMBER: builtins.int
+    REPLACE_FIELD_NUMBER: builtins.int
     SUMMARY_FIELD_NUMBER: builtins.int
     CROSSTAB_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
@@ -139,6 +140,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def drop_na(self) -> global___NADrop: ...
     @property
+    def replace(self) -> global___NAReplace: ...
+    @property
     def summary(self) -> global___StatSummary:
         """stat functions"""
     @property
@@ -173,6 +176,7 @@ class Relation(google.protobuf.message.Message):
         with_columns: global___WithColumns | None = ...,
         fill_na: global___NAFill | None = ...,
         drop_na: global___NADrop | None = ...,
+        replace: global___NAReplace | None = ...,
         summary: global___StatSummary | None = ...,
         crosstab: global___StatCrosstab | None = ...,
         unknown: global___Unknown | None = ...,
@@ -218,6 +222,8 @@ class Relation(google.protobuf.message.Message):
             b"rename_columns_by_same_length_names",
             "repartition",
             b"repartition",
+            "replace",
+            b"replace",
             "sample",
             b"sample",
             "set_op",
@@ -281,6 +287,8 @@ class Relation(google.protobuf.message.Message):
             b"rename_columns_by_same_length_names",
             "repartition",
             b"repartition",
+            "replace",
+            b"replace",
             "sample",
             b"sample",
             "set_op",
@@ -330,6 +338,7 @@ class Relation(google.protobuf.message.Message):
         "with_columns",
         "fill_na",
         "drop_na",
+        "replace",
         "summary",
         "crosstab",
         "unknown",
@@ -1633,6 +1642,90 @@ class NADrop(google.protobuf.message.Message):
 
 global___NADrop = NADrop
 
+class NAReplace(google.protobuf.message.Message):
+    """Replaces old values with the corresponding values.
+    It will invoke 'Dataset.na.replace' (same as 'DataFrameNaFunctions.replace')
+    to compute the results.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    class Replacement(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        OLD_VALUE_FIELD_NUMBER: builtins.int
+        NEW_VALUE_FIELD_NUMBER: builtins.int
+        @property
+        def old_value(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression.Literal:
+            """(Required) The old value.
+
+            Only 4 data types are supported now: null, bool, double, string.
+            """
+        @property
+        def new_value(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression.Literal:
+            """(Required) The new value.
+
+            Should be of the same data type as the old value.
+            """
+        def __init__(
+            self,
+            *,
+            old_value: pyspark.sql.connect.proto.expressions_pb2.Expression.Literal | None = ...,
+            new_value: pyspark.sql.connect.proto.expressions_pb2.Expression.Literal | None = ...,
+        ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal[
+                "new_value", b"new_value", "old_value", b"old_value"
+            ],
+        ) -> builtins.bool: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "new_value", b"new_value", "old_value", b"old_value"
+            ],
+        ) -> None: ...
+
+    INPUT_FIELD_NUMBER: builtins.int
+    COLS_FIELD_NUMBER: builtins.int
+    REPLACEMENTS_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    @property
+    def cols(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+        """(Optional) List of column names to consider.
+
+        When it is empty, all the type-compatible columns in the input relation will be considered.
+        """
+    @property
+    def replacements(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        global___NAReplace.Replacement
+    ]:
+        """(Optional) The value replacement mapping."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        cols: collections.abc.Iterable[builtins.str] | None = ...,
+        replacements: collections.abc.Iterable[global___NAReplace.Replacement] | None = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "cols", b"cols", "input", b"input", "replacements", b"replacements"
+        ],
+    ) -> None: ...
+
+global___NAReplace = NAReplace
+
 class RenameColumnsBySameLengthNames(google.protobuf.message.Message):
     """Rename columns on the input relation by the same length of names."""
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index c499e393e19..a13fe24d064 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -583,6 +583,52 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             self.spark.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
         )
 
+    def test_replace(self):
+        # SPARK-41315: Test replace
+        query = """
+            SELECT * FROM VALUES
+            (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
+            AS tab(a, b, c)
+            """
+        # +-----+----+----+
+        # |    a|   b|   c|
+        # +-----+----+----+
+        # |false|   1|null|
+        # |false|null| 2.0|
+        # | null|   3| 3.0|
+        # +-----+----+----+
+
+        self.assert_eq(
+            self.connect.sql(query).replace(2, 3).toPandas(),
+            self.spark.sql(query).replace(2, 3).toPandas(),
+        )
+        self.assert_eq(
+            self.connect.sql(query).na.replace(False, True).toPandas(),
+            self.spark.sql(query).na.replace(False, True).toPandas(),
+        )
+        self.assert_eq(
+            self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(),
+            self.spark.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(),
+        )
+        self.assert_eq(
+            self.connect.sql(query).na.replace((1, 2), (3, 1)).toPandas(),
+            self.spark.sql(query).na.replace((1, 2), (3, 1)).toPandas(),
+        )
+        self.assert_eq(
+            self.connect.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(),
+            self.spark.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(),
+        )
+
+        with self.assertRaises(ValueError) as context:
+            self.connect.sql(query).replace({None: 1}, subset="a").toPandas()
+        self.assertTrue("Mixed type replacements are not supported" in str(context.exception))
+
+        with self.assertRaises(grpc.RpcError) as context:
+            self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "x")).toPandas()
+        self.assertIn(
+            """Cannot resolve column name "x" among (a, b, c)""", str(context.exception)
+        )
+
     def test_with_columns(self):
         # SPARK-41256: test withColumn(s).
         self.assert_eq(
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 337917335e3..54933d5aa0a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -128,6 +128,33 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.drop_na.cols, ["col_c"])
         self.assertEqual(plan.root.drop_na.min_non_nulls, 1)
 
+    def test_replace(self):
+        # SPARK-41315: Test replace
+        df = self.connect.readTable(table_name=self.tbl_name)
+
+        plan = df.replace(10, 20)._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.replace.cols, [])
+        self.assertEqual(plan.root.replace.replacements[0].old_value.double, 10.0)
+        self.assertEqual(plan.root.replace.replacements[0].new_value.double, 20.0)
+
+        plan = df.na.replace((1, 2, 3), (4, 5, 6), subset=("col_a", "col_b"))._plan.to_proto(
+            self.connect
+        )
+        self.assertEqual(plan.root.replace.cols, ["col_a", "col_b"])
+        self.assertEqual(plan.root.replace.replacements[0].old_value.double, 1.0)
+        self.assertEqual(plan.root.replace.replacements[0].new_value.double, 4.0)
+        self.assertEqual(plan.root.replace.replacements[1].old_value.double, 2.0)
+        self.assertEqual(plan.root.replace.replacements[1].new_value.double, 5.0)
+        self.assertEqual(plan.root.replace.replacements[2].old_value.double, 3.0)
+        self.assertEqual(plan.root.replace.replacements[2].new_value.double, 6.0)
+
+        plan = df.replace(["Alice", "Bob"], ["A", "B"], subset="col_x")._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.replace.cols, ["col_x"])
+        self.assertEqual(plan.root.replace.replacements[0].old_value.string, "Alice")
+        self.assertEqual(plan.root.replace.replacements[0].new_value.string, "A")
+        self.assertEqual(plan.root.replace.replacements[1].old_value.string, "Bob")
+        self.assertEqual(plan.root.replace.replacements[1].new_value.string, "B")
+
     def test_summary(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)

