Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/29 05:27:51 UTC
[spark] branch master updated: [SPARK-41148][CONNECT][PYTHON] Implement `DataFrame.dropna` and `DataFrame.na.drop`
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 661064a7a38 [SPARK-41148][CONNECT][PYTHON] Implement `DataFrame.dropna` and `DataFrame.na.drop`
661064a7a38 is described below
commit 661064a7a3811da27da0d5b024764238d2a1fb3f
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Tue Nov 29 13:27:33 2022 +0800
[SPARK-41148][CONNECT][PYTHON] Implement `DataFrame.dropna` and `DataFrame.na.drop`
### What changes were proposed in this pull request?
Implement `DataFrame.dropna` and `DataFrame.na.drop`
### Why are the changes needed?
For API coverage
### Does this PR introduce _any_ user-facing change?
Yes, this adds the new methods `DataFrame.dropna` and `DataFrame.na.drop` to the Spark Connect Python client.
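For example (a minimal sketch of the new API; `spark` is a SparkSession backed by Spark Connect and the data is made up):

```python
df = spark.createDataFrame([(1, None), (None, None), (2, 3)], ["a", "b"])

df.dropna().show()            # how='any' (default): keeps only (2, 3)
df.na.drop(how="all").show()  # drops only the all-null row (None, None)
df.dropna(thresh=2).show()    # keeps rows with >= 2 non-null values: (2, 3)
df.dropna(subset="a").show()  # considers only column 'a': keeps (1, None) and (2, 3)
```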
### How was this patch tested?
Added unit tests.
Closes #38819 from zhengruifeng/connect_df_na_drop.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../main/protobuf/spark/connect/relations.proto | 25 ++++
.../org/apache/spark/sql/connect/dsl/package.scala | 33 +++++
.../sql/connect/planner/SparkConnectPlanner.scala | 18 +++
.../connect/planner/SparkConnectProtoSuite.scala | 19 +++
python/pyspark/sql/connect/dataframe.py | 81 +++++++++++
python/pyspark/sql/connect/plan.py | 39 +++++
python/pyspark/sql/connect/proto/relations_pb2.py | 158 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 77 ++++++++++
.../sql/tests/connect/test_connect_basic.py | 32 +++++
.../sql/tests/connect/test_connect_plan_only.py | 16 +++
10 files changed, 426 insertions(+), 72 deletions(-)
diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index a676871c9e0..cbdf6311657 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -55,6 +55,7 @@ message Relation {
// NA functions
NAFill fill_na = 90;
+ NADrop drop_na = 91;
// stat functions
StatSummary summary = 100;
@@ -440,6 +441,30 @@ message NAFill {
repeated Expression.Literal values = 3;
}
+
+// Drop rows containing null values.
+// It will invoke 'Dataset.na.drop' (same as 'DataFrameNaFunctions.drop') to compute the results.
+message NADrop {
+ // (Required) The input relation.
+ Relation input = 1;
+
+ // (Optional) List of column names to consider.
+ //
+ // When it is empty, all the columns in the input relation will be considered.
+ repeated string cols = 2;
+
+ // (Optional) The minimum number of non-null and non-NaN values required to keep a row.
+ //
+ // When not set, it is equivalent to the number of considered columns, which means
+ // a row will be kept only if all columns are non-null.
+ //
+ // 'how' options ('all', 'any') can be easily converted to this field:
+ // - 'all' -> set 'min_non_nulls' to 1;
+ // - 'any' -> keep 'min_non_nulls' unset;
+ optional int32 min_non_nulls = 3;
+}
+
+
// Rename columns on the input relation by the same length of names.
message RenameColumnsBySameLengthNames {
// (Required) The input relation of RenameColumnsBySameLengthNames.
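The 'how'/'thresh' to 'min_non_nulls' conversion described in the NADrop comment above is mechanical; an illustrative Python helper (hypothetical, mirroring the client-side logic added to dataframe.py further down):

```python
from typing import Optional

def how_to_min_non_nulls(how: str = "any", thresh: Optional[int] = None) -> Optional[int]:
    # 'thresh' always overrides 'how'.
    if thresh is not None:
        return thresh
    # 'all': keep a row as long as it has at least one non-null value.
    if how == "all":
        return 1
    # 'any': leave min_non_nulls unset; the server then requires every
    # considered column to be non-null for a row to survive.
    return None
```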
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 61d7abe9e15..dd1c7f0574b 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -278,6 +278,39 @@ package object dsl {
.build())
.build()
}
+
+ def drop(
+ how: Option[String] = None,
+ minNonNulls: Option[Int] = None,
+ cols: Seq[String] = Seq.empty): Relation = {
+ require(!(how.nonEmpty && minNonNulls.nonEmpty))
+ require(how.isEmpty || Seq("any", "all").contains(how.get))
+
+ val dropna = proto.NADrop
+ .newBuilder()
+ .setInput(logicalPlan)
+
+ if (cols.nonEmpty) {
+ dropna.addAllCols(cols.asJava)
+ }
+
+ var _minNonNulls = -1
+ how match {
+ case Some("all") => _minNonNulls = 1
+ case _ =>
+ }
+ if (minNonNulls.nonEmpty) {
+ _minNonNulls = minNonNulls.get
+ }
+ if (_minNonNulls > 0) {
+ dropna.setMinNonNulls(_minNonNulls)
+ }
+
+ Relation
+ .newBuilder()
+ .setDropNa(dropna.build())
+ .build()
+ }
}
implicit class DslStatFunctions(val logicalPlan: Relation) {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b4eaa03df5d..b2b6e6ffc54 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -83,6 +83,7 @@ class SparkConnectPlanner(session: SparkSession) {
transformSubqueryAlias(rel.getSubqueryAlias)
case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
case proto.Relation.RelTypeCase.FILL_NA => transformNAFill(rel.getFillNa)
+ case proto.Relation.RelTypeCase.DROP_NA => transformNADrop(rel.getDropNa)
case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary)
case proto.Relation.RelTypeCase.CROSSTAB =>
transformStatCrosstab(rel.getCrosstab)
@@ -212,6 +213,23 @@ class SparkConnectPlanner(session: SparkSession) {
}
}
+ private def transformNADrop(rel: proto.NADrop): LogicalPlan = {
+ val dataset = Dataset.ofRows(session, transformRelation(rel.getInput))
+
+ val cols = rel.getColsList.asScala.toArray
+
+ (cols.nonEmpty, rel.hasMinNonNulls) match {
+ case (true, true) =>
+ dataset.na.drop(minNonNulls = rel.getMinNonNulls, cols = cols).logicalPlan
+ case (true, false) =>
+ dataset.na.drop(cols = cols).logicalPlan
+ case (false, true) =>
+ dataset.na.drop(minNonNulls = rel.getMinNonNulls).logicalPlan
+ case (false, false) =>
+ dataset.na.drop().logicalPlan
+ }
+ }
+
private def transformStatSummary(rel: proto.StatSummary): LogicalPlan = {
Dataset
.ofRows(session, transformRelation(rel.getInput))
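The four-way match in transformNADrop above exists because `Dataset.na.drop` is overloaded rather than taking defaulted arguments; the same dispatch sketched against the classic PySpark API (illustrative only, with `thresh`/`subset` standing in for Scala's `minNonNulls`/`cols`):

```python
def transform_na_drop(dataset, cols, min_non_nulls):
    # Hypothetical mirror of transformNADrop; `dataset` is a plain PySpark DataFrame.
    if cols and min_non_nulls is not None:
        return dataset.na.drop(thresh=min_non_nulls, subset=cols)
    elif cols:
        return dataset.na.drop(subset=cols)
    elif min_non_nulls is not None:
        return dataset.na.drop(thresh=min_non_nulls)
    else:
        return dataset.na.drop()
```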
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 6c4297f5437..63bd3eccf17 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -342,6 +342,25 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
sparkTestRelation.na.fill(Map("id" -> 1L, "name" -> "xyz")))
}
+ test("SPARK-41148: Test drop na") {
+ comparePlans(connectTestRelation.na.drop(), sparkTestRelation.na.drop())
+ comparePlans(
+ connectTestRelation.na.drop(cols = Seq("id")),
+ sparkTestRelation.na.drop(cols = Seq("id")))
+ comparePlans(
+ connectTestRelation.na.drop(how = Some("all")),
+ sparkTestRelation.na.drop(how = "all"))
+ comparePlans(
+ connectTestRelation.na.drop(how = Some("all"), cols = Seq("id", "name")),
+ sparkTestRelation.na.drop(how = "all", cols = Seq("id", "name")))
+ comparePlans(
+ connectTestRelation.na.drop(minNonNulls = Some(1)),
+ sparkTestRelation.na.drop(minNonNulls = 1))
+ comparePlans(
+ connectTestRelation.na.drop(minNonNulls = Some(1), cols = Seq("id", "name")),
+ sparkTestRelation.na.drop(minNonNulls = 1, cols = Seq("id", "name")))
+ }
+
test("Test summary") {
comparePlans(
connectTestRelation.summary("count", "mean", "stddev"),
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index f157835f4a4..725c7fc90da 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -873,6 +873,77 @@ class DataFrame(object):
session=self._session,
)
+ def dropna(
+ self,
+ how: str = "any",
+ thresh: Optional[int] = None,
+ subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
+ ) -> "DataFrame":
+ """Returns a new :class:`DataFrame` omitting rows with null values.
+ :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ how : str, optional
+ 'any' or 'all'.
+ If 'any', drop a row if it contains any nulls.
+ If 'all', drop a row only if all its values are null.
+ thresh: int, optional
+ default None
+ If specified, drop rows that have fewer than `thresh` non-null values.
+ This overrides the `how` parameter.
+ subset : str, tuple or list, optional
+ optional list of column names to consider.
+
+ Returns
+ -------
+ :class:`DataFrame`
+ DataFrame with rows containing null values removed.
+ """
+ min_non_nulls: Optional[int] = None
+
+ if how is not None:
+ if not isinstance(how, str):
+ raise TypeError(f"how should be a str, but got {type(how).__name__}")
+ if how == "all":
+ min_non_nulls = 1
+ elif how == "any":
+ min_non_nulls = None
+ else:
+ raise ValueError("how ('" + how + "') should be 'any' or 'all'")
+
+ if thresh is not None:
+ if not isinstance(thresh, int):
+ raise TypeError(f"thresh should be a int, but got {type(thresh).__name__}")
+
+ # 'thresh' overrides 'how'
+ min_non_nulls = thresh
+
+ _cols: List[str] = []
+ if subset is not None:
+ if isinstance(subset, str):
+ _cols = [subset]
+ elif isinstance(subset, (tuple, list)):
+ for c in subset:
+ if not isinstance(c, str):
+ raise TypeError(
+ f"cols should be a str, tuple[str] or list[str], "
+ f"but got {type(c).__name__}"
+ )
+ _cols = list(subset)
+ else:
+ raise TypeError(
+ f"cols should be a str, tuple[str] or list[str], "
+ f"but got {type(subset).__name__}"
+ )
+
+ return DataFrame.withPlan(
+ plan.NADrop(child=self._plan, cols=_cols, min_non_nulls=min_non_nulls),
+ session=self._session,
+ )
+
@property
def stat(self) -> "DataFrameStatFunctions":
"""Returns a :class:`DataFrameStatFunctions` for statistic functions.
@@ -1312,6 +1383,16 @@ class DataFrameNaFunctions:
fill.__doc__ = DataFrame.fillna.__doc__
+ def drop(
+ self,
+ how: str = "any",
+ thresh: Optional[int] = None,
+ subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
+ ) -> DataFrame:
+ return self.df.dropna(how=how, thresh=thresh, subset=subset)
+
+ drop.__doc__ = DataFrame.dropna.__doc__
+
class DataFrameStatFunctions:
"""Functionality for statistic functions with :class:`DataFrame`.
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 805628cfe5b..0f611654ee5 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -995,6 +995,45 @@ class NAFill(LogicalPlan):
"""
+class NADrop(LogicalPlan):
+ def __init__(
+ self,
+ child: Optional["LogicalPlan"],
+ cols: Optional[List[str]],
+ min_non_nulls: Optional[int],
+ ) -> None:
+ super().__init__(child)
+
+ self.cols = cols
+ self.min_non_nulls = min_non_nulls
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ assert self._child is not None
+ plan = proto.Relation()
+ plan.drop_na.input.CopyFrom(self._child.plan(session))
+ if self.cols is not None and len(self.cols) > 0:
+ plan.drop_na.cols.extend(self.cols)
+ if self.min_non_nulls is not None:
+ plan.drop_na.min_non_nulls = self.min_non_nulls
+ return plan
+
+ def print(self, indent: int = 0) -> str:
+ i = " " * indent
+ return f"{i}" f"<NADrop cols='{self.cols}' " f"min_non_nulls='{self.min_non_nulls}'>"
+
+ def _repr_html_(self) -> str:
+ return f"""
+ <ul>
+ <li>
+ <b>NADrop</b><br />
+ Cols: {self.cols} <br />
+ Min_non_nulls: {self.min_non_nulls} <br />
+ {self._child_repr_()}
+ </li>
+ </ul>
+ """
+
+
class StatSummary(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], statistics: List[str]) -> None:
super().__init__(child)
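A quick sketch of the proto that `NADrop.plan` produces, mirroring the plan-only tests at the end of this patch (`df` is a Connect DataFrame and `client` its SparkConnectClient; hypothetical names):

```python
plan = df.dropna(how="all", subset="col_c")._plan.to_proto(client)
assert plan.root.drop_na.cols == ["col_c"]
assert plan.root.drop_na.min_non_nulls == 1  # 'all' -> 1

plan = df.dropna()._plan.to_proto(client)
assert not plan.root.drop_na.HasField("min_non_nulls")  # 'any' leaves it unset
```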
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 5ac2e8d7952..856fbe5b68b 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xfc\x0b\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xae\x0c\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
)
@@ -66,6 +66,7 @@ _SHOWSTRING = DESCRIPTOR.message_types_by_name["ShowString"]
_STATSUMMARY = DESCRIPTOR.message_types_by_name["StatSummary"]
_STATCROSSTAB = DESCRIPTOR.message_types_by_name["StatCrosstab"]
_NAFILL = DESCRIPTOR.message_types_by_name["NAFill"]
+_NADROP = DESCRIPTOR.message_types_by_name["NADrop"]
_RENAMECOLUMNSBYSAMELENGTHNAMES = DESCRIPTOR.message_types_by_name["RenameColumnsBySameLengthNames"]
_RENAMECOLUMNSBYNAMETONAMEMAP = DESCRIPTOR.message_types_by_name["RenameColumnsByNameToNameMap"]
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY = (
@@ -390,6 +391,17 @@ NAFill = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(NAFill)
+NADrop = _reflection.GeneratedProtocolMessageType(
+ "NADrop",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _NADROP,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.NADrop)
+ },
+)
+_sym_db.RegisterMessage(NADrop)
+
RenameColumnsBySameLengthNames = _reflection.GeneratedProtocolMessageType(
"RenameColumnsBySameLengthNames",
(_message.Message,),
@@ -431,75 +443,77 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 82
- _RELATION._serialized_end = 1614
- _UNKNOWN._serialized_start = 1616
- _UNKNOWN._serialized_end = 1625
- _RELATIONCOMMON._serialized_start = 1627
- _RELATIONCOMMON._serialized_end = 1676
- _SQL._serialized_start = 1678
- _SQL._serialized_end = 1705
- _READ._serialized_start = 1708
- _READ._serialized_end = 2134
- _READ_NAMEDTABLE._serialized_start = 1850
- _READ_NAMEDTABLE._serialized_end = 1911
- _READ_DATASOURCE._serialized_start = 1914
- _READ_DATASOURCE._serialized_end = 2121
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2052
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2110
- _PROJECT._serialized_start = 2136
- _PROJECT._serialized_end = 2253
- _FILTER._serialized_start = 2255
- _FILTER._serialized_end = 2367
- _JOIN._serialized_start = 2370
- _JOIN._serialized_end = 2820
- _JOIN_JOINTYPE._serialized_start = 2633
- _JOIN_JOINTYPE._serialized_end = 2820
- _SETOPERATION._serialized_start = 2823
- _SETOPERATION._serialized_end = 3219
- _SETOPERATION_SETOPTYPE._serialized_start = 3082
- _SETOPERATION_SETOPTYPE._serialized_end = 3196
- _LIMIT._serialized_start = 3221
- _LIMIT._serialized_end = 3297
- _OFFSET._serialized_start = 3299
- _OFFSET._serialized_end = 3378
- _TAIL._serialized_start = 3380
- _TAIL._serialized_end = 3455
- _AGGREGATE._serialized_start = 3458
- _AGGREGATE._serialized_end = 3668
- _SORT._serialized_start = 3671
- _SORT._serialized_end = 4221
- _SORT_SORTFIELD._serialized_start = 3825
- _SORT_SORTFIELD._serialized_end = 4013
- _SORT_SORTDIRECTION._serialized_start = 4015
- _SORT_SORTDIRECTION._serialized_end = 4123
- _SORT_SORTNULLS._serialized_start = 4125
- _SORT_SORTNULLS._serialized_end = 4207
- _DROP._serialized_start = 4223
- _DROP._serialized_end = 4323
- _DEDUPLICATE._serialized_start = 4326
- _DEDUPLICATE._serialized_end = 4497
- _LOCALRELATION._serialized_start = 4499
- _LOCALRELATION._serialized_end = 4534
- _SAMPLE._serialized_start = 4537
- _SAMPLE._serialized_end = 4761
- _RANGE._serialized_start = 4764
- _RANGE._serialized_end = 4909
- _SUBQUERYALIAS._serialized_start = 4911
- _SUBQUERYALIAS._serialized_end = 5025
- _REPARTITION._serialized_start = 5028
- _REPARTITION._serialized_end = 5170
- _SHOWSTRING._serialized_start = 5173
- _SHOWSTRING._serialized_end = 5314
- _STATSUMMARY._serialized_start = 5316
- _STATSUMMARY._serialized_end = 5408
- _STATCROSSTAB._serialized_start = 5410
- _STATCROSSTAB._serialized_end = 5511
- _NAFILL._serialized_start = 5514
- _NAFILL._serialized_end = 5648
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5650
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5764
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5767
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6026
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5959
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6026
+ _RELATION._serialized_end = 1664
+ _UNKNOWN._serialized_start = 1666
+ _UNKNOWN._serialized_end = 1675
+ _RELATIONCOMMON._serialized_start = 1677
+ _RELATIONCOMMON._serialized_end = 1726
+ _SQL._serialized_start = 1728
+ _SQL._serialized_end = 1755
+ _READ._serialized_start = 1758
+ _READ._serialized_end = 2184
+ _READ_NAMEDTABLE._serialized_start = 1900
+ _READ_NAMEDTABLE._serialized_end = 1961
+ _READ_DATASOURCE._serialized_start = 1964
+ _READ_DATASOURCE._serialized_end = 2171
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2102
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2160
+ _PROJECT._serialized_start = 2186
+ _PROJECT._serialized_end = 2303
+ _FILTER._serialized_start = 2305
+ _FILTER._serialized_end = 2417
+ _JOIN._serialized_start = 2420
+ _JOIN._serialized_end = 2870
+ _JOIN_JOINTYPE._serialized_start = 2683
+ _JOIN_JOINTYPE._serialized_end = 2870
+ _SETOPERATION._serialized_start = 2873
+ _SETOPERATION._serialized_end = 3269
+ _SETOPERATION_SETOPTYPE._serialized_start = 3132
+ _SETOPERATION_SETOPTYPE._serialized_end = 3246
+ _LIMIT._serialized_start = 3271
+ _LIMIT._serialized_end = 3347
+ _OFFSET._serialized_start = 3349
+ _OFFSET._serialized_end = 3428
+ _TAIL._serialized_start = 3430
+ _TAIL._serialized_end = 3505
+ _AGGREGATE._serialized_start = 3508
+ _AGGREGATE._serialized_end = 3718
+ _SORT._serialized_start = 3721
+ _SORT._serialized_end = 4271
+ _SORT_SORTFIELD._serialized_start = 3875
+ _SORT_SORTFIELD._serialized_end = 4063
+ _SORT_SORTDIRECTION._serialized_start = 4065
+ _SORT_SORTDIRECTION._serialized_end = 4173
+ _SORT_SORTNULLS._serialized_start = 4175
+ _SORT_SORTNULLS._serialized_end = 4257
+ _DROP._serialized_start = 4273
+ _DROP._serialized_end = 4373
+ _DEDUPLICATE._serialized_start = 4376
+ _DEDUPLICATE._serialized_end = 4547
+ _LOCALRELATION._serialized_start = 4549
+ _LOCALRELATION._serialized_end = 4584
+ _SAMPLE._serialized_start = 4587
+ _SAMPLE._serialized_end = 4811
+ _RANGE._serialized_start = 4814
+ _RANGE._serialized_end = 4959
+ _SUBQUERYALIAS._serialized_start = 4961
+ _SUBQUERYALIAS._serialized_end = 5075
+ _REPARTITION._serialized_start = 5078
+ _REPARTITION._serialized_end = 5220
+ _SHOWSTRING._serialized_start = 5223
+ _SHOWSTRING._serialized_end = 5364
+ _STATSUMMARY._serialized_start = 5366
+ _STATSUMMARY._serialized_end = 5458
+ _STATCROSSTAB._serialized_start = 5460
+ _STATCROSSTAB._serialized_end = 5561
+ _NAFILL._serialized_start = 5564
+ _NAFILL._serialized_end = 5698
+ _NADROP._serialized_start = 5701
+ _NADROP._serialized_end = 5835
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5837
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5951
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5954
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6213
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6146
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6213
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 6b05621744f..6832db56190 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -82,6 +82,7 @@ class Relation(google.protobuf.message.Message):
DROP_FIELD_NUMBER: builtins.int
TAIL_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
+ DROP_NA_FIELD_NUMBER: builtins.int
SUMMARY_FIELD_NUMBER: builtins.int
CROSSTAB_FIELD_NUMBER: builtins.int
UNKNOWN_FIELD_NUMBER: builtins.int
@@ -133,6 +134,8 @@ class Relation(google.protobuf.message.Message):
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
+ def drop_na(self) -> global___NADrop: ...
+ @property
def summary(self) -> global___StatSummary:
"""stat functions"""
@property
@@ -165,6 +168,7 @@ class Relation(google.protobuf.message.Message):
drop: global___Drop | None = ...,
tail: global___Tail | None = ...,
fill_na: global___NAFill | None = ...,
+ drop_na: global___NADrop | None = ...,
summary: global___StatSummary | None = ...,
crosstab: global___StatCrosstab | None = ...,
unknown: global___Unknown | None = ...,
@@ -182,6 +186,8 @@ class Relation(google.protobuf.message.Message):
b"deduplicate",
"drop",
b"drop",
+ "drop_na",
+ b"drop_na",
"fill_na",
b"fill_na",
"filter",
@@ -241,6 +247,8 @@ class Relation(google.protobuf.message.Message):
b"deduplicate",
"drop",
b"drop",
+ "drop_na",
+ b"drop_na",
"fill_na",
b"fill_na",
"filter",
@@ -312,6 +320,7 @@ class Relation(google.protobuf.message.Message):
"drop",
"tail",
"fill_na",
+ "drop_na",
"summary",
"crosstab",
"unknown",
@@ -1547,6 +1556,74 @@ class NAFill(google.protobuf.message.Message):
global___NAFill = NAFill
+class NADrop(google.protobuf.message.Message):
+ """Drop rows containing null values.
+ It will invoke 'Dataset.na.drop' (same as 'DataFrameNaFunctions.drop') to compute the results.
+ """
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ INPUT_FIELD_NUMBER: builtins.int
+ COLS_FIELD_NUMBER: builtins.int
+ MIN_NON_NULLS_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> global___Relation:
+ """(Required) The input relation."""
+ @property
+ def cols(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """(Optional) Optional list of column names to consider.
+
+ When it is empty, all the columns in the input relation will be considered.
+ """
+ min_non_nulls: builtins.int
+ """(Optional) The minimum number of non-null and non-NaN values required to keep.
+
+ When not set, it is equivalent to the number of considered columns, which means
+ a row will be kept only if all columns are non-null.
+
+ 'how' options ('all', 'any') can be easily converted to this field:
+ - 'all' -> set 'min_non_nulls' to 1;
+ - 'any' -> keep 'min_non_nulls' unset;
+ """
+ def __init__(
+ self,
+ *,
+ input: global___Relation | None = ...,
+ cols: collections.abc.Iterable[builtins.str] | None = ...,
+ min_non_nulls: builtins.int | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_min_non_nulls",
+ b"_min_non_nulls",
+ "input",
+ b"input",
+ "min_non_nulls",
+ b"min_non_nulls",
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_min_non_nulls",
+ b"_min_non_nulls",
+ "cols",
+ b"cols",
+ "input",
+ b"input",
+ "min_non_nulls",
+ b"min_non_nulls",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_min_non_nulls", b"_min_non_nulls"]
+ ) -> typing_extensions.Literal["min_non_nulls"] | None: ...
+
+global___NADrop = NADrop
+
class RenameColumnsBySameLengthNames(google.protobuf.message.Message):
"""Rename columns on the input relation by the same length of names."""
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 0fec48779ef..eb025fd5d04 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -498,6 +498,38 @@ class SparkConnectTests(SparkConnectSQLTestCase):
self.spark.sql(query).na.fill({"a": True, "b": 2}).toPandas(),
)
+ def test_drop_na(self):
+ # SPARK-41148: Test drop na
+ query = """
+ SELECT * FROM VALUES
+ (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
+ AS tab(a, b, c)
+ """
+ # +-----+----+----+
+ # | a| b| c|
+ # +-----+----+----+
+ # |false| 1|null|
+ # |false|null| 2.0|
+ # | null| 3| 3.0|
+ # +-----+----+----+
+
+ self.assert_eq(
+ self.connect.sql(query).dropna().toPandas(),
+ self.spark.sql(query).dropna().toPandas(),
+ )
+ self.assert_eq(
+ self.connect.sql(query).na.drop(how="all", thresh=1).toPandas(),
+ self.spark.sql(query).na.drop(how="all", thresh=1).toPandas(),
+ )
+ self.assert_eq(
+ self.connect.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
+ self.spark.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
+ )
+ self.assert_eq(
+ self.connect.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
+ self.spark.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
+ )
+
def test_empty_dataset(self):
# SPARK-41005: Test arrow based collection with empty dataset.
self.assertTrue(
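Hand-deriving the expected row counts from the fixture in `test_drop_na` above (assertions are hypothetical, not part of the patch):

```python
cdf = self.connect.sql(query)
assert len(cdf.dropna().toPandas()) == 0           # every row contains a null
assert len(cdf.na.drop(thresh=1).toPandas()) == 3  # each row has >= 1 non-null value
assert len(cdf.na.drop(thresh=2, subset="a").toPandas()) == 0  # only column 'a' is considered
```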
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 109702af7e8..367c514d0a8 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -102,6 +102,22 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
self.assertEqual(plan.root.fill_na.values[1].string, "abc")
self.assertEqual(plan.root.fill_na.cols, ["col_a", "col_b"])
+ def test_drop_na(self):
+ # SPARK-41148: Test drop na
+ df = self.connect.readTable(table_name=self.tbl_name)
+
+ plan = df.dropna()._plan.to_proto(self.connect)
+ self.assertEqual(plan.root.drop_na.cols, [])
+ self.assertEqual(plan.root.drop_na.HasField("min_non_nulls"), False)
+
+ plan = df.na.drop(thresh=2, subset=("col_a", "col_b"))._plan.to_proto(self.connect)
+ self.assertEqual(plan.root.drop_na.cols, ["col_a", "col_b"])
+ self.assertEqual(plan.root.drop_na.min_non_nulls, 2)
+
+ plan = df.dropna(how="all", subset="col_c")._plan.to_proto(self.connect)
+ self.assertEqual(plan.root.drop_na.cols, ["col_c"])
+ self.assertEqual(plan.root.drop_na.min_non_nulls, 1)
+
def test_summary(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)