Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/30 07:41:18 UTC
[spark] branch master updated: [SPARK-41315][CONNECT][PYTHON] Implement `DataFrame.replace` and `DataFrame.na.replace`
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5fc482eb591 [SPARK-41315][CONNECT][PYTHON] Implement `DataFrame.replace` and `DataFrame.na.replace`
5fc482eb591 is described below
commit 5fc482eb591a42f6dc1eb6ca7cffea72a4616f09
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Wed Nov 30 15:40:53 2022 +0800
[SPARK-41315][CONNECT][PYTHON] Implement `DataFrame.replace` and `DataFrame.na.replace`
### What changes were proposed in this pull request?
Implement `DataFrame.replace` and `DataFrame.na.replace` in the Spark Connect Python client.
### Why are the changes needed?
To improve API coverage of the Spark Connect Python client.
### Does this PR introduce _any_ user-facing change?
Yes, this PR adds two new APIs: `DataFrame.replace` and `DataFrame.na.replace`.
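A minimal usage sketch (assuming `df` is a DataFrame created from a Spark Connect session; the column name `name` is illustrative):

```python
# Replace the numeric value 10 with 20 in every type-compatible column.
df.replace(10, 20)

# The same operation through the `na` accessor; both build the same plan.
df.na.replace(10, 20)

# Dict form: maps old values to new values, so `value` is omitted.
df.replace({"Alice": "A", "Bob": "B"}, subset=["name"])
```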
### How was this patch tested?
Added unit tests.
Closes #38836 from zhengruifeng/connect_df_replace.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../main/protobuf/spark/connect/relations.proto | 30 ++++
.../org/apache/spark/sql/connect/dsl/package.scala | 25 +++
.../sql/connect/planner/SparkConnectPlanner.scala | 32 ++++
.../connect/planner/SparkConnectProtoSuite.scala | 12 ++
python/pyspark/sql/connect/dataframe.py | 150 ++++++++++++++++-
python/pyspark/sql/connect/plan.py | 72 +++++++--
python/pyspark/sql/connect/proto/relations_pb2.py | 179 ++++++++++++---------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 93 +++++++++++
.../sql/tests/connect/test_connect_basic.py | 46 ++++++
.../sql/tests/connect/test_connect_plan_only.py | 27 ++++
10 files changed, 579 insertions(+), 87 deletions(-)
diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index d0ebce5ccab..8b87845245f 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -57,6 +57,7 @@ message Relation {
// NA functions
NAFill fill_na = 90;
NADrop drop_na = 91;
+ NAReplace replace = 92;
// stat functions
StatSummary summary = 100;
@@ -466,6 +467,35 @@ message NADrop {
}
+// Replaces old values with the corresponding new values.
+// It will invoke 'Dataset.na.replace' (same as 'DataFrameNaFunctions.replace')
+// to compute the results.
+message NAReplace {
+ // (Required) The input relation.
+ Relation input = 1;
+
+ // (Optional) List of column names to consider.
+ //
+ // When it is empty, all the type-compatible columns in the input relation will be considered.
+ repeated string cols = 2;
+
+ // (Optional) The value replacement mapping.
+ repeated Replacement replacements = 3;
+
+ message Replacement {
+ // (Required) The old value.
+ //
+ // Only 4 data types are supported now: null, bool, double, string.
+ Expression.Literal old_value = 1;
+
+ // (Required) The new value.
+ //
+ // Should be of the same data type as the old value.
+ Expression.Literal new_value = 2;
+ }
+}
+
+
// Rename columns on the input relation by the same length of names.
message RenameColumnsBySameLengthNames {
// (Required) The input relation of RenameColumnsBySameLengthNames.
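For reference, the new message can also be constructed directly through the generated Python bindings. This is a sketch mirroring what `plan.py` does further below, not something users would normally write; a real plan would additionally set the required `input` relation:

```python
import pyspark.sql.connect.proto as proto

# Encode the mapping 1.0 -> 2.0 as a Replacement; only null, bool,
# double and string literals are accepted by the server-side planner.
old_value = proto.Expression.Literal()
old_value.double = 1.0
new_value = proto.Expression.Literal()
new_value.double = 2.0

replacement = proto.NAReplace.Replacement()
replacement.old_value.CopyFrom(old_value)
replacement.new_value.CopyFrom(new_value)

relation = proto.Relation()
relation.replace.cols.append("id")  # leave `cols` empty to target all columns
relation.replace.replacements.append(replacement)
```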
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 42f59e112af..654a4d5ce20 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -311,6 +311,31 @@ package object dsl {
.setDropNa(dropna.build())
.build()
}
+
+ def replace(cols: Seq[String], replacement: Map[Any, Any]): Relation = {
+ require(cols.nonEmpty)
+
+ val replace = proto.NAReplace
+ .newBuilder()
+ .setInput(logicalPlan)
+
+ if (!(cols.length == 1 && cols.head == "*")) {
+ replace.addAllCols(cols.asJava)
+ }
+
+ replacement.foreach { case (oldValue, newValue) =>
+ replace.addReplacements(
+ proto.NAReplace.Replacement
+ .newBuilder()
+ .setOldValue(convertValue(oldValue))
+ .setNewValue(convertValue(newValue)))
+ }
+
+ Relation
+ .newBuilder()
+ .setReplace(replace.build())
+ .build()
+ }
}
implicit class DslStatFunctions(val logicalPlan: Relation) {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 2c492d973e4..7b9e13cadab 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -85,6 +85,7 @@ class SparkConnectPlanner(session: SparkSession) {
case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
case proto.Relation.RelTypeCase.FILL_NA => transformNAFill(rel.getFillNa)
case proto.Relation.RelTypeCase.DROP_NA => transformNADrop(rel.getDropNa)
+ case proto.Relation.RelTypeCase.REPLACE => transformReplace(rel.getReplace)
case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary)
case proto.Relation.RelTypeCase.CROSSTAB =>
transformStatCrosstab(rel.getCrosstab)
@@ -232,6 +233,37 @@ class SparkConnectPlanner(session: SparkSession) {
}
}
+ private def transformReplace(rel: proto.NAReplace): LogicalPlan = {
+ def convert(value: proto.Expression.Literal): Any = {
+ value.getLiteralTypeCase match {
+ case proto.Expression.Literal.LiteralTypeCase.NULL => null
+ case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => value.getBoolean
+ case proto.Expression.Literal.LiteralTypeCase.DOUBLE => value.getDouble
+ case proto.Expression.Literal.LiteralTypeCase.STRING => value.getString
+ case other => throw InvalidPlanInput(s"Unsupported value type: $other")
+ }
+ }
+
+ val replacement = mutable.Map.empty[Any, Any]
+ rel.getReplacementsList.asScala.foreach { replace =>
+ replacement.update(convert(replace.getOldValue), convert(replace.getNewValue))
+ }
+
+ if (rel.getColsCount == 0) {
+ Dataset
+ .ofRows(session, transformRelation(rel.getInput))
+ .na
+ .replace("*", replacement.toMap)
+ .logicalPlan
+ } else {
+ Dataset
+ .ofRows(session, transformRelation(rel.getInput))
+ .na
+ .replace(rel.getColsList.asScala.toSeq, replacement.toMap)
+ .logicalPlan
+ }
+ }
+
private def transformStatSummary(rel: proto.StatSummary): LogicalPlan = {
Dataset
.ofRows(session, transformRelation(rel.getInput))
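Client and server agree on one convention here: an empty `cols` list means "all columns", which the planner translates into `Dataset.na.replace("*", ...)`. From the Python client that looks like this (a sketch; `df` and the column names are illustrative):

```python
# subset=None leaves `cols` empty in the proto, so the server replaces
# across all type-compatible columns via Dataset.na.replace("*", ...).
df.replace(1.0, 2.0)

# An explicit subset populates `cols`; the server then calls
# Dataset.na.replace(Seq("a", "b"), ...) on just those columns.
df.replace(1.0, 2.0, subset=["a", "b"])
```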
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 9ed706e855a..86253b0016c 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -361,6 +361,18 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
sparkTestRelation.na.drop(minNonNulls = 1, cols = Seq("id", "name")))
}
+ test("SPARK-41315: Test replace") {
+ comparePlans(
+ connectTestRelation.na.replace(cols = Seq("id"), replacement = Map(1.0 -> 2.0)),
+ sparkTestRelation.na.replace(cols = Seq("id"), replacement = Map(1.0 -> 2.0)))
+ comparePlans(
+ connectTestRelation.na.replace(cols = Seq("name"), replacement = Map("a" -> "b")),
+ sparkTestRelation.na.replace(cols = Seq("name"), replacement = Map("a" -> "b")))
+ comparePlans(
+ connectTestRelation.na.replace(cols = Seq("*"), replacement = Map("a" -> "b")),
+ sparkTestRelation.na.replace(col = "*", replacement = Map("a" -> "b")))
+ }
+
test("Test summary") {
comparePlans(
connectTestRelation.summary("count", "mean", "stddev"),
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 034f410b1ad..c9960a71fb8 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -27,9 +27,12 @@ from typing import (
overload,
Callable,
cast,
+ Type,
)
import pandas
+import warnings
+from collections.abc import Iterable
import pyspark.sql.connect.plan as plan
from pyspark.sql.connect.readwriter import DataFrameWriter
@@ -39,9 +42,11 @@ from pyspark.sql.types import (
StructType,
Row,
)
+from pyspark import _NoValue
+from pyspark._globals import _NoValueType
if TYPE_CHECKING:
- from pyspark.sql.connect._typing import ColumnOrName, LiteralType
+ from pyspark.sql.connect._typing import ColumnOrName, LiteralType, OptionalPrimitiveType
from pyspark.sql.connect.session import SparkSession
@@ -1054,6 +1059,137 @@ class DataFrame(object):
session=self._session,
)
+ def replace(
+ self,
+ to_replace: Union[
+ "LiteralType", List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]
+ ],
+ value: Optional[
+ Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType]
+ ] = _NoValue,
+ subset: Optional[List[str]] = None,
+ ) -> "DataFrame":
+ """Returns a new :class:`DataFrame` replacing a value with another value.
+ :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
+ aliases of each other.
+ Values `to_replace` and `value` must have the same type and can only be numerics, booleans,
+ or strings. `value` can be None. When replacing, the new value will be cast
+ to the type of the existing column.
+ For numeric replacements all values to be replaced should have unique
+ floating point representation. In case of conflicts (for example with `{42: -1, 42.0: 1}`)
+ an arbitrary replacement will be used.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ to_replace : bool, int, float, string, list or dict
+ Value to be replaced.
+ If the value is a dict, then `value` is ignored or can be omitted, and `to_replace`
+ must be a mapping between a value and a replacement.
+ value : bool, int, float, string or None, optional
+ The replacement value must be a bool, int, float, string or None. If `value` is a
+ list, `value` should be of the same length and type as `to_replace`.
+ If `value` is a scalar and `to_replace` is a sequence, then `value` is
+ used as a replacement for each item in `to_replace`.
+ subset : list, optional
+ optional list of column names to consider.
+ Columns specified in subset that do not have matching data type are ignored.
+ For example, if `value` is a string, and subset contains a non-string column,
+ then the non-string column is simply ignored.
+
+ Returns
+ -------
+ :class:`DataFrame`
+ DataFrame with replaced values.
+ """
+ if value is _NoValue:
+ if isinstance(to_replace, dict):
+ value = None
+ else:
+ raise TypeError("value argument is required when to_replace is not a dictionary.")
+
+ # Helper functions
+ def all_of(types: Union[Type, Tuple[Type, ...]]) -> Callable[[Iterable], bool]:
+ """Given a type or tuple of types and a sequence of xs
+ check if each x is instance of type(s)
+
+ >>> all_of(bool)([True, False])
+ True
+ >>> all_of(str)(["a", 1])
+ False
+ """
+
+ def all_of_(xs: Iterable) -> bool:
+ return all(isinstance(x, types) for x in xs)
+
+ return all_of_
+
+ all_of_bool = all_of(bool)
+ all_of_str = all_of(str)
+ all_of_numeric = all_of((float, int))
+
+ # Validate input types
+ valid_types = (bool, float, int, str, list, tuple)
+ if not isinstance(to_replace, valid_types + (dict,)):
+ raise TypeError(
+ "to_replace should be a bool, float, int, string, list, tuple, or dict. "
+ "Got {0}".format(type(to_replace))
+ )
+
+ if (
+ not isinstance(value, valid_types)
+ and value is not None
+ and not isinstance(to_replace, dict)
+ ):
+ raise TypeError(
+ "If to_replace is not a dict, value should be "
+ "a bool, float, int, string, list, tuple or None. "
+ "Got {0}".format(type(value))
+ )
+
+ if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
+ if len(to_replace) != len(value):
+ raise ValueError(
+ "to_replace and value lists should be of the same length. "
+ "Got {0} and {1}".format(len(to_replace), len(value))
+ )
+
+ if not (subset is None or isinstance(subset, (list, tuple, str))):
+ raise TypeError(
+ "subset should be a list or tuple of column names, "
+ "column name or None. Got {0}".format(type(subset))
+ )
+
+ # Reshape input arguments if necessary
+ if isinstance(to_replace, (float, int, str)):
+ to_replace = [to_replace]
+
+ if isinstance(to_replace, dict):
+ rep_dict = to_replace
+ if value is not None:
+ warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
+ else:
+ if isinstance(value, (float, int, str)) or value is None:
+ value = [value for _ in range(len(to_replace))]
+ rep_dict = dict(zip(to_replace, cast("Iterable[Optional[Union[float, str]]]", value)))
+
+ if isinstance(subset, str):
+ subset = [subset]
+
+ # Verify we were not passed in mixed type generics.
+ if not any(
+ all_of_type(rep_dict.keys())
+ and all_of_type(x for x in rep_dict.values() if x is not None)
+ for all_of_type in [all_of_bool, all_of_str, all_of_numeric]
+ ):
+ raise ValueError("Mixed type replacements are not supported")
+
+ return DataFrame.withPlan(
+ plan.NAReplace(child=self._plan, cols=subset, replacements=rep_dict),
+ session=self._session,
+ )
+
@property
def stat(self) -> "DataFrameStatFunctions":
"""Returns a :class:`DataFrameStatFunctions` for statistic functions.
@@ -1520,6 +1656,18 @@ class DataFrameNaFunctions:
drop.__doc__ = DataFrame.dropna.__doc__
+ def replace(
+ self,
+ to_replace: Union[List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]],
+ value: Optional[
+ Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType]
+ ] = _NoValue,
+ subset: Optional[List[str]] = None,
+ ) -> DataFrame:
+ return self.df.replace(to_replace, value, subset)
+
+ replace.__doc__ = DataFrame.replace.__doc__
+
class DataFrameStatFunctions:
"""Functionality for statistic functions with :class:`DataFrame`.
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index e7bbad25e59..cc544819ad6 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -15,16 +15,7 @@
# limitations under the License.
#
-from typing import (
- Any,
- List,
- Optional,
- Sequence,
- Union,
- cast,
- TYPE_CHECKING,
- Mapping,
-)
+from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
import pandas
import pyarrow as pa
import pyspark.sql.connect.proto as proto
@@ -1095,6 +1086,67 @@ class NADrop(LogicalPlan):
"""
+class NAReplace(LogicalPlan):
+ def __init__(
+ self,
+ child: Optional["LogicalPlan"],
+ cols: Optional[List[str]],
+ replacements: Dict[Any, Any],
+ ) -> None:
+ super().__init__(child)
+
+ for old_value, new_value in replacements.items():
+ if old_value is not None:
+ assert isinstance(old_value, (bool, int, float, str))
+ if new_value is not None:
+ assert isinstance(new_value, (bool, int, float, str))
+
+ self.cols = cols
+ self.replacements = replacements
+
+ def _convert_value(self, v: Any) -> proto.Expression.Literal:
+ value = proto.Expression.Literal()
+ if v is None:
+ value.null = True
+ elif isinstance(v, bool):
+ value.boolean = v
+ elif isinstance(v, (int, float)):
+ value.double = float(v)
+ else:
+ value.string = v
+ return value
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ assert self._child is not None
+ plan = proto.Relation()
+ plan.replace.input.CopyFrom(self._child.plan(session))
+ if self.cols is not None and len(self.cols) > 0:
+ plan.replace.cols.extend(self.cols)
+ if len(self.replacements) > 0:
+ for old_value, new_value in self.replacements.items():
+ replacement = proto.NAReplace.Replacement()
+ replacement.old_value.CopyFrom(self._convert_value(old_value))
+ replacement.new_value.CopyFrom(self._convert_value(new_value))
+ plan.replace.replacements.append(replacement)
+ return plan
+
+ def print(self, indent: int = 0) -> str:
+ i = " " * indent
+ return f"{i}" f"<NAReplace cols='{self.cols}' " f"replacements='{self.replacements}'>"
+
+ def _repr_html_(self) -> str:
+ return f"""
+ <ul>
+ <li>
+ <b>NAReplace</b><br />
+ Cols: {self.cols} <br />
+ Replacements: {self.replacements} <br />
+ {self._child_repr_()}
+ </li>
+ </ul>
+ """
+
+
class StatSummary(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], statistics: List[str]) -> None:
super().__init__(child)
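Note the ordering inside `_convert_value` above: `bool` is tested before `(int, float)` because `bool` is a subclass of `int` in Python; with the checks reversed, `True` would be serialized as the double `1.0` instead of a boolean literal:

```python
isinstance(True, int)   # True: bool is a subclass of int
isinstance(True, bool)  # True, so the bool branch must come first
```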
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 36ce4af222c..3b4a2d5de18 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xef\x0c\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xa5\r\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\ [...]
)
@@ -67,6 +67,8 @@ _STATSUMMARY = DESCRIPTOR.message_types_by_name["StatSummary"]
_STATCROSSTAB = DESCRIPTOR.message_types_by_name["StatCrosstab"]
_NAFILL = DESCRIPTOR.message_types_by_name["NAFill"]
_NADROP = DESCRIPTOR.message_types_by_name["NADrop"]
+_NAREPLACE = DESCRIPTOR.message_types_by_name["NAReplace"]
+_NAREPLACE_REPLACEMENT = _NAREPLACE.nested_types_by_name["Replacement"]
_RENAMECOLUMNSBYSAMELENGTHNAMES = DESCRIPTOR.message_types_by_name["RenameColumnsBySameLengthNames"]
_RENAMECOLUMNSBYNAMETONAMEMAP = DESCRIPTOR.message_types_by_name["RenameColumnsByNameToNameMap"]
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY = (
@@ -403,6 +405,27 @@ NADrop = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(NADrop)
+NAReplace = _reflection.GeneratedProtocolMessageType(
+ "NAReplace",
+ (_message.Message,),
+ {
+ "Replacement": _reflection.GeneratedProtocolMessageType(
+ "Replacement",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _NAREPLACE_REPLACEMENT,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.NAReplace.Replacement)
+ },
+ ),
+ "DESCRIPTOR": _NAREPLACE,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.NAReplace)
+ },
+)
+_sym_db.RegisterMessage(NAReplace)
+_sym_db.RegisterMessage(NAReplace.Replacement)
+
RenameColumnsBySameLengthNames = _reflection.GeneratedProtocolMessageType(
"RenameColumnsBySameLengthNames",
(_message.Message,),
@@ -455,79 +478,83 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 82
- _RELATION._serialized_end = 1729
- _UNKNOWN._serialized_start = 1731
- _UNKNOWN._serialized_end = 1740
- _RELATIONCOMMON._serialized_start = 1742
- _RELATIONCOMMON._serialized_end = 1791
- _SQL._serialized_start = 1793
- _SQL._serialized_end = 1820
- _READ._serialized_start = 1823
- _READ._serialized_end = 2249
- _READ_NAMEDTABLE._serialized_start = 1965
- _READ_NAMEDTABLE._serialized_end = 2026
- _READ_DATASOURCE._serialized_start = 2029
- _READ_DATASOURCE._serialized_end = 2236
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2167
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2225
- _PROJECT._serialized_start = 2251
- _PROJECT._serialized_end = 2368
- _FILTER._serialized_start = 2370
- _FILTER._serialized_end = 2482
- _JOIN._serialized_start = 2485
- _JOIN._serialized_end = 2935
- _JOIN_JOINTYPE._serialized_start = 2748
- _JOIN_JOINTYPE._serialized_end = 2935
- _SETOPERATION._serialized_start = 2938
- _SETOPERATION._serialized_end = 3334
- _SETOPERATION_SETOPTYPE._serialized_start = 3197
- _SETOPERATION_SETOPTYPE._serialized_end = 3311
- _LIMIT._serialized_start = 3336
- _LIMIT._serialized_end = 3412
- _OFFSET._serialized_start = 3414
- _OFFSET._serialized_end = 3493
- _TAIL._serialized_start = 3495
- _TAIL._serialized_end = 3570
- _AGGREGATE._serialized_start = 3573
- _AGGREGATE._serialized_end = 3783
- _SORT._serialized_start = 3786
- _SORT._serialized_end = 4336
- _SORT_SORTFIELD._serialized_start = 3940
- _SORT_SORTFIELD._serialized_end = 4128
- _SORT_SORTDIRECTION._serialized_start = 4130
- _SORT_SORTDIRECTION._serialized_end = 4238
- _SORT_SORTNULLS._serialized_start = 4240
- _SORT_SORTNULLS._serialized_end = 4322
- _DROP._serialized_start = 4338
- _DROP._serialized_end = 4438
- _DEDUPLICATE._serialized_start = 4441
- _DEDUPLICATE._serialized_end = 4612
- _LOCALRELATION._serialized_start = 4614
- _LOCALRELATION._serialized_end = 4649
- _SAMPLE._serialized_start = 4652
- _SAMPLE._serialized_end = 4876
- _RANGE._serialized_start = 4879
- _RANGE._serialized_end = 5024
- _SUBQUERYALIAS._serialized_start = 5026
- _SUBQUERYALIAS._serialized_end = 5140
- _REPARTITION._serialized_start = 5143
- _REPARTITION._serialized_end = 5285
- _SHOWSTRING._serialized_start = 5288
- _SHOWSTRING._serialized_end = 5429
- _STATSUMMARY._serialized_start = 5431
- _STATSUMMARY._serialized_end = 5523
- _STATCROSSTAB._serialized_start = 5525
- _STATCROSSTAB._serialized_end = 5626
- _NAFILL._serialized_start = 5629
- _NAFILL._serialized_end = 5763
- _NADROP._serialized_start = 5766
- _NADROP._serialized_end = 5900
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5902
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6016
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6019
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6278
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6211
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6278
- _WITHCOLUMNS._serialized_start = 6281
- _WITHCOLUMNS._serialized_end = 6412
+ _RELATION._serialized_end = 1783
+ _UNKNOWN._serialized_start = 1785
+ _UNKNOWN._serialized_end = 1794
+ _RELATIONCOMMON._serialized_start = 1796
+ _RELATIONCOMMON._serialized_end = 1845
+ _SQL._serialized_start = 1847
+ _SQL._serialized_end = 1874
+ _READ._serialized_start = 1877
+ _READ._serialized_end = 2303
+ _READ_NAMEDTABLE._serialized_start = 2019
+ _READ_NAMEDTABLE._serialized_end = 2080
+ _READ_DATASOURCE._serialized_start = 2083
+ _READ_DATASOURCE._serialized_end = 2290
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2221
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2279
+ _PROJECT._serialized_start = 2305
+ _PROJECT._serialized_end = 2422
+ _FILTER._serialized_start = 2424
+ _FILTER._serialized_end = 2536
+ _JOIN._serialized_start = 2539
+ _JOIN._serialized_end = 2989
+ _JOIN_JOINTYPE._serialized_start = 2802
+ _JOIN_JOINTYPE._serialized_end = 2989
+ _SETOPERATION._serialized_start = 2992
+ _SETOPERATION._serialized_end = 3388
+ _SETOPERATION_SETOPTYPE._serialized_start = 3251
+ _SETOPERATION_SETOPTYPE._serialized_end = 3365
+ _LIMIT._serialized_start = 3390
+ _LIMIT._serialized_end = 3466
+ _OFFSET._serialized_start = 3468
+ _OFFSET._serialized_end = 3547
+ _TAIL._serialized_start = 3549
+ _TAIL._serialized_end = 3624
+ _AGGREGATE._serialized_start = 3627
+ _AGGREGATE._serialized_end = 3837
+ _SORT._serialized_start = 3840
+ _SORT._serialized_end = 4390
+ _SORT_SORTFIELD._serialized_start = 3994
+ _SORT_SORTFIELD._serialized_end = 4182
+ _SORT_SORTDIRECTION._serialized_start = 4184
+ _SORT_SORTDIRECTION._serialized_end = 4292
+ _SORT_SORTNULLS._serialized_start = 4294
+ _SORT_SORTNULLS._serialized_end = 4376
+ _DROP._serialized_start = 4392
+ _DROP._serialized_end = 4492
+ _DEDUPLICATE._serialized_start = 4495
+ _DEDUPLICATE._serialized_end = 4666
+ _LOCALRELATION._serialized_start = 4668
+ _LOCALRELATION._serialized_end = 4703
+ _SAMPLE._serialized_start = 4706
+ _SAMPLE._serialized_end = 4930
+ _RANGE._serialized_start = 4933
+ _RANGE._serialized_end = 5078
+ _SUBQUERYALIAS._serialized_start = 5080
+ _SUBQUERYALIAS._serialized_end = 5194
+ _REPARTITION._serialized_start = 5197
+ _REPARTITION._serialized_end = 5339
+ _SHOWSTRING._serialized_start = 5342
+ _SHOWSTRING._serialized_end = 5483
+ _STATSUMMARY._serialized_start = 5485
+ _STATSUMMARY._serialized_end = 5577
+ _STATCROSSTAB._serialized_start = 5579
+ _STATCROSSTAB._serialized_end = 5680
+ _NAFILL._serialized_start = 5683
+ _NAFILL._serialized_end = 5817
+ _NADROP._serialized_start = 5820
+ _NADROP._serialized_end = 5954
+ _NAREPLACE._serialized_start = 5957
+ _NAREPLACE._serialized_end = 6253
+ _NAREPLACE_REPLACEMENT._serialized_start = 6112
+ _NAREPLACE_REPLACEMENT._serialized_end = 6253
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6255
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6369
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6372
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6631
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6564
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6631
+ _WITHCOLUMNS._serialized_start = 6634
+ _WITHCOLUMNS._serialized_end = 6765
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 84e0c520f36..df7b083103a 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -84,6 +84,7 @@ class Relation(google.protobuf.message.Message):
WITH_COLUMNS_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
+ REPLACE_FIELD_NUMBER: builtins.int
SUMMARY_FIELD_NUMBER: builtins.int
CROSSTAB_FIELD_NUMBER: builtins.int
UNKNOWN_FIELD_NUMBER: builtins.int
@@ -139,6 +140,8 @@ class Relation(google.protobuf.message.Message):
@property
def drop_na(self) -> global___NADrop: ...
@property
+ def replace(self) -> global___NAReplace: ...
+ @property
def summary(self) -> global___StatSummary:
"""stat functions"""
@property
@@ -173,6 +176,7 @@ class Relation(google.protobuf.message.Message):
with_columns: global___WithColumns | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
+ replace: global___NAReplace | None = ...,
summary: global___StatSummary | None = ...,
crosstab: global___StatCrosstab | None = ...,
unknown: global___Unknown | None = ...,
@@ -218,6 +222,8 @@ class Relation(google.protobuf.message.Message):
b"rename_columns_by_same_length_names",
"repartition",
b"repartition",
+ "replace",
+ b"replace",
"sample",
b"sample",
"set_op",
@@ -281,6 +287,8 @@ class Relation(google.protobuf.message.Message):
b"rename_columns_by_same_length_names",
"repartition",
b"repartition",
+ "replace",
+ b"replace",
"sample",
b"sample",
"set_op",
@@ -330,6 +338,7 @@ class Relation(google.protobuf.message.Message):
"with_columns",
"fill_na",
"drop_na",
+ "replace",
"summary",
"crosstab",
"unknown",
@@ -1633,6 +1642,90 @@ class NADrop(google.protobuf.message.Message):
global___NADrop = NADrop
+class NAReplace(google.protobuf.message.Message):
+ """Replaces old values with the corresponding values.
+ It will invoke 'Dataset.na.replace' (same as 'DataFrameNaFunctions.replace')
+ to compute the results.
+ """
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ class Replacement(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ OLD_VALUE_FIELD_NUMBER: builtins.int
+ NEW_VALUE_FIELD_NUMBER: builtins.int
+ @property
+ def old_value(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression.Literal:
+ """(Required) The old value.
+
+ Only 4 data types are supported now: null, bool, double, string.
+ """
+ @property
+ def new_value(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression.Literal:
+ """(Required) The new value.
+
+ Should be of the same data type as the old value.
+ """
+ def __init__(
+ self,
+ *,
+ old_value: pyspark.sql.connect.proto.expressions_pb2.Expression.Literal | None = ...,
+ new_value: pyspark.sql.connect.proto.expressions_pb2.Expression.Literal | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "new_value", b"new_value", "old_value", b"old_value"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "new_value", b"new_value", "old_value", b"old_value"
+ ],
+ ) -> None: ...
+
+ INPUT_FIELD_NUMBER: builtins.int
+ COLS_FIELD_NUMBER: builtins.int
+ REPLACEMENTS_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> global___Relation:
+ """(Required) The input relation."""
+ @property
+ def cols(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """(Optional) List of column names to consider.
+
+ When it is empty, all the type-compatible columns in the input relation will be considered.
+ """
+ @property
+ def replacements(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ global___NAReplace.Replacement
+ ]:
+ """(Optional) The value replacement mapping."""
+ def __init__(
+ self,
+ *,
+ input: global___Relation | None = ...,
+ cols: collections.abc.Iterable[builtins.str] | None = ...,
+ replacements: collections.abc.Iterable[global___NAReplace.Replacement] | None = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["input", b"input"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "cols", b"cols", "input", b"input", "replacements", b"replacements"
+ ],
+ ) -> None: ...
+
+global___NAReplace = NAReplace
+
class RenameColumnsBySameLengthNames(google.protobuf.message.Message):
"""Rename columns on the input relation by the same length of names."""
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index c499e393e19..a13fe24d064 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -583,6 +583,52 @@ class SparkConnectTests(SparkConnectSQLTestCase):
self.spark.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
)
+ def test_replace(self):
+ # SPARK-41315: Test replace
+ query = """
+ SELECT * FROM VALUES
+ (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
+ AS tab(a, b, c)
+ """
+ # +-----+----+----+
+ # | a| b| c|
+ # +-----+----+----+
+ # |false| 1|null|
+ # |false|null| 2.0|
+ # | null| 3| 3.0|
+ # +-----+----+----+
+
+ self.assert_eq(
+ self.connect.sql(query).replace(2, 3).toPandas(),
+ self.spark.sql(query).replace(2, 3).toPandas(),
+ )
+ self.assert_eq(
+ self.connect.sql(query).na.replace(False, True).toPandas(),
+ self.spark.sql(query).na.replace(False, True).toPandas(),
+ )
+ self.assert_eq(
+ self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(),
+ self.spark.sql(query).replace({1: 2, 3: -1}, subset=("a", "b")).toPandas(),
+ )
+ self.assert_eq(
+ self.connect.sql(query).na.replace((1, 2), (3, 1)).toPandas(),
+ self.spark.sql(query).na.replace((1, 2), (3, 1)).toPandas(),
+ )
+ self.assert_eq(
+ self.connect.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(),
+ self.spark.sql(query).na.replace((1, 2), (3, 1), subset=("c", "b")).toPandas(),
+ )
+
+ with self.assertRaises(ValueError) as context:
+ self.connect.sql(query).replace({None: 1}, subset="a").toPandas()
+ self.assertTrue("Mixed type replacements are not supported" in str(context.exception))
+
+ with self.assertRaises(grpc.RpcError) as context:
+ self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "x")).toPandas()
+ self.assertIn(
+ """Cannot resolve column name "x" among (a, b, c)""", str(context.exception)
+ )
+
def test_with_columns(self):
# SPARK-41256: test withColumn(s).
self.assert_eq(
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 337917335e3..54933d5aa0a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -128,6 +128,33 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
self.assertEqual(plan.root.drop_na.cols, ["col_c"])
self.assertEqual(plan.root.drop_na.min_non_nulls, 1)
+ def test_replace(self):
+ # SPARK-41315: Test replace
+ df = self.connect.readTable(table_name=self.tbl_name)
+
+ plan = df.replace(10, 20)._plan.to_proto(self.connect)
+ self.assertEqual(plan.root.replace.cols, [])
+ self.assertEqual(plan.root.replace.replacements[0].old_value.double, 10.0)
+ self.assertEqual(plan.root.replace.replacements[0].new_value.double, 20.0)
+
+ plan = df.na.replace((1, 2, 3), (4, 5, 6), subset=("col_a", "col_b"))._plan.to_proto(
+ self.connect
+ )
+ self.assertEqual(plan.root.replace.cols, ["col_a", "col_b"])
+ self.assertEqual(plan.root.replace.replacements[0].old_value.double, 1.0)
+ self.assertEqual(plan.root.replace.replacements[0].new_value.double, 4.0)
+ self.assertEqual(plan.root.replace.replacements[1].old_value.double, 2.0)
+ self.assertEqual(plan.root.replace.replacements[1].new_value.double, 5.0)
+ self.assertEqual(plan.root.replace.replacements[2].old_value.double, 3.0)
+ self.assertEqual(plan.root.replace.replacements[2].new_value.double, 6.0)
+
+ plan = df.replace(["Alice", "Bob"], ["A", "B"], subset="col_x")._plan.to_proto(self.connect)
+ self.assertEqual(plan.root.replace.cols, ["col_x"])
+ self.assertEqual(plan.root.replace.replacements[0].old_value.string, "Alice")
+ self.assertEqual(plan.root.replace.replacements[0].new_value.string, "A")
+ self.assertEqual(plan.root.replace.replacements[1].old_value.string, "Bob")
+ self.assertEqual(plan.root.replace.replacements[1].new_value.string, "B")
+
def test_summary(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)