You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/29 05:27:51 UTC

[spark] branch master updated: [SPARK-41148][CONNECT][PYTHON] Implement `DataFrame.dropna` and `DataFrame.na.drop`

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 661064a7a38 [SPARK-41148][CONNECT][PYTHON] Implement `DataFrame.dropna` and `DataFrame.na.drop`
661064a7a38 is described below

commit 661064a7a3811da27da0d5b024764238d2a1fb3f
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Tue Nov 29 13:27:33 2022 +0800

    [SPARK-41148][CONNECT][PYTHON] Implement `DataFrame.dropna` and `DataFrame.na.drop`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrame.dropna` and `DataFrame.na.drop` in the Spark Connect Python client.
    
    ### Why are the changes needed?
    To extend Spark Connect API coverage: both methods already exist in the classic PySpark `DataFrame` API.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, new methods: `DataFrame.dropna` and `DataFrame.na.drop` in Spark Connect.
    
    ### How was this patch tested?
    Added unit tests (Scala planner tests, plus Python plan-only and end-to-end tests).
    
    Closes #38819 from zhengruifeng/connect_df_na_drop.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  25 ++++
 .../org/apache/spark/sql/connect/dsl/package.scala |  33 +++++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  18 +++
 .../connect/planner/SparkConnectProtoSuite.scala   |  19 +++
 python/pyspark/sql/connect/dataframe.py            |  81 +++++++++++
 python/pyspark/sql/connect/plan.py                 |  39 +++++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 158 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  77 ++++++++++
 .../sql/tests/connect/test_connect_basic.py        |  32 +++++
 .../sql/tests/connect/test_connect_plan_only.py    |  16 +++
 10 files changed, 426 insertions(+), 72 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index a676871c9e0..cbdf6311657 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -55,6 +55,7 @@ message Relation {
 
     // NA functions
     NAFill fill_na = 90;
+    NADrop drop_na = 91;
 
     // stat functions
     StatSummary summary = 100;
@@ -440,6 +441,30 @@ message NAFill {
   repeated Expression.Literal values = 3;
 }
 
+
+// Drop rows containing null (or NaN) values.
+// The server evaluates it with 'Dataset.na.drop' (i.e. 'DataFrameNaFunctions.drop').
+message NADrop {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Optional) Names of the columns to consider when counting non-null values.
+  //
+  // When it is empty, all the columns in the input relation will be considered.
+  repeated string cols = 2;
+
+  // (Optional) The minimum number of non-null and non-NaN values a row needs to be kept.
+  //
+  // When not set, it is equivalent to the number of considered columns, which means
+  // a row will be kept only if all considered columns are non-null and non-NaN.
+  //
+  // 'how' options ('all', 'any') map onto this field as follows:
+  //   - 'all' -> set 'min_non_nulls' to 1;
+  //   - 'any' -> leave 'min_non_nulls' unset;
+  optional int32 min_non_nulls = 3;
+}
+
+
 // Rename columns on the input relation by the same length of names.
 message RenameColumnsBySameLengthNames {
   // (Required) The input relation of RenameColumnsBySameLengthNames.
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 61d7abe9e15..dd1c7f0574b 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -278,6 +278,39 @@ package object dsl {
               .build())
           .build()
       }
+
+      def drop(
+          how: Option[String] = None,
+          minNonNulls: Option[Int] = None,
+          cols: Seq[String] = Seq.empty): Relation = {
+        require(!(how.nonEmpty && minNonNulls.nonEmpty))
+        require(how.forall(h => Seq("any", "all").contains(h)))
+
+        val dropna = proto.NADrop
+          .newBuilder()
+          .setInput(logicalPlan)
+
+        if (cols.nonEmpty) {
+          dropna.addAllCols(cols.asJava)
+        }
+
+        // Mirror Dataset.na.drop: 'all' needs at least one non-null value, while
+        // 'any' (or no 'how') leaves min_non_nulls unset so the server requires
+        // every considered column to be non-null. An explicit minNonNulls is always
+        // forwarded as-is, including 0 or negative values, which keep every row.
+        val resolvedMinNonNulls: Option[Int] = minNonNulls.orElse {
+          how match {
+            case Some("all") => Some(1)
+            case _ => None
+          }
+        }
+        resolvedMinNonNulls.foreach(n => dropna.setMinNonNulls(n))
+
+        Relation
+          .newBuilder()
+          .setDropNa(dropna.build())
+          .build()
+      }
     }
 
     implicit class DslStatFunctions(val logicalPlan: Relation) {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b4eaa03df5d..b2b6e6ffc54 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -83,6 +83,7 @@ class SparkConnectPlanner(session: SparkSession) {
         transformSubqueryAlias(rel.getSubqueryAlias)
       case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
       case proto.Relation.RelTypeCase.FILL_NA => transformNAFill(rel.getFillNa)
+      case proto.Relation.RelTypeCase.DROP_NA => transformNADrop(rel.getDropNa)
       case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary)
       case proto.Relation.RelTypeCase.CROSSTAB =>
         transformStatCrosstab(rel.getCrosstab)
@@ -212,6 +213,23 @@ class SparkConnectPlanner(session: SparkSession) {
     }
   }
 
+  private def transformNADrop(rel: proto.NADrop): LogicalPlan = {
+    val naFunctions = Dataset.ofRows(session, transformRelation(rel.getInput)).na
+    val columns = rel.getColsList.asScala.toArray
+    // Empty 'cols' means "all columns"; unset 'min_non_nulls' falls back to drop()'s default.
+    val dropped =
+      if (rel.hasMinNonNulls && columns.nonEmpty) {
+        naFunctions.drop(minNonNulls = rel.getMinNonNulls, cols = columns)
+      } else if (rel.hasMinNonNulls) {
+        naFunctions.drop(minNonNulls = rel.getMinNonNulls)
+      } else if (columns.nonEmpty) {
+        naFunctions.drop(cols = columns)
+      } else {
+        naFunctions.drop()
+      }
+    dropped.logicalPlan
+  }
+
   private def transformStatSummary(rel: proto.StatSummary): LogicalPlan = {
     Dataset
       .ofRows(session, transformRelation(rel.getInput))
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 6c4297f5437..63bd3eccf17 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -342,6 +342,25 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       sparkTestRelation.na.fill(Map("id" -> 1L, "name" -> "xyz")))
   }
 
+  test("SPARK-41148: Test drop na") {
+    // Each pair is (Connect DSL plan, equivalent plan built with the Dataset API).
+    Seq(
+      connectTestRelation.na.drop() -> sparkTestRelation.na.drop(),
+      connectTestRelation.na.drop(cols = Seq("id")) ->
+        sparkTestRelation.na.drop(cols = Seq("id")),
+      connectTestRelation.na.drop(how = Some("all")) ->
+        sparkTestRelation.na.drop(how = "all"),
+      connectTestRelation.na.drop(how = Some("all"), cols = Seq("id", "name")) ->
+        sparkTestRelation.na.drop(how = "all", cols = Seq("id", "name")),
+      connectTestRelation.na.drop(minNonNulls = Some(1)) ->
+        sparkTestRelation.na.drop(minNonNulls = 1),
+      connectTestRelation.na.drop(minNonNulls = Some(1), cols = Seq("id", "name")) ->
+        sparkTestRelation.na.drop(minNonNulls = 1, cols = Seq("id", "name"))
+    ).foreach { case (connectPlan, sparkPlan) =>
+      comparePlans(connectPlan, sparkPlan)
+    }
+  }
+
   test("Test summary") {
     comparePlans(
       connectTestRelation.summary("count", "mean", "stddev"),
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index f157835f4a4..725c7fc90da 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -873,6 +873,77 @@ class DataFrame(object):
             session=self._session,
         )
 
+    def dropna(
+        self,
+        how: str = "any",
+        thresh: Optional[int] = None,
+        subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
+    ) -> "DataFrame":
+        """Returns a new :class:`DataFrame` omitting rows with null values.
+        :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        how : str, optional
+            'any' or 'all'.
+            If 'any', drop a row if it contains any nulls.
+            If 'all', drop a row only if all its values are null.
+        thresh: int, optional
+            default None
+            If specified, drop rows that have less than `thresh` non-null values.
+            This overwrites the `how` parameter.
+        subset : str, tuple or list, optional
+            optional list of column names to consider.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            DataFrame with null only rows excluded.
+        """
+        min_non_nulls: Optional[int] = None
+
+        if how is not None:
+            if not isinstance(how, str):
+                raise TypeError(f"how should be a str, but got {type(how).__name__}")
+            if how == "all":
+                min_non_nulls = 1
+            elif how == "any":
+                min_non_nulls = None
+            else:
+                raise ValueError("how ('" + how + "') should be 'any' or 'all'")
+
+        if thresh is not None:
+            if not isinstance(thresh, int):
+                raise TypeError(f"thresh should be a int, but got {type(thresh).__name__}")
+
+            # 'thresh' overwrites 'how'
+            min_non_nulls = thresh
+
+        _cols: List[str] = []
+        if subset is not None:
+            if isinstance(subset, str):
+                _cols = [subset]
+            elif isinstance(subset, (tuple, list)):
+                for c in subset:
+                    if not isinstance(c, str):
+                        raise TypeError(
+                            f"cols should be a str, tuple[str] or list[str], "
+                            f"but got {type(c).__name__}"
+                        )
+                _cols = list(subset)
+            else:
+                raise TypeError(
+                    f"cols should be a str, tuple[str] or list[str], "
+                    f"but got {type(subset).__name__}"
+                )
+
+        return DataFrame.withPlan(
+            plan.NADrop(child=self._plan, cols=_cols, min_non_nulls=min_non_nulls),
+            session=self._session,
+        )
+
     @property
     def stat(self) -> "DataFrameStatFunctions":
         """Returns a :class:`DataFrameStatFunctions` for statistic functions.
@@ -1312,6 +1383,16 @@ class DataFrameNaFunctions:
 
     fill.__doc__ = DataFrame.fillna.__doc__
 
+    def drop(
+        self,
+        how: str = "any",
+        thresh: Optional[int] = None,
+        subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
+    ) -> DataFrame:
+        return self.df.dropna(how=how, thresh=thresh, subset=subset)
+
+    drop.__doc__ = DataFrame.dropna.__doc__
+
 
 class DataFrameStatFunctions:
     """Functionality for statistic functions with :class:`DataFrame`.
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 805628cfe5b..0f611654ee5 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -995,6 +995,45 @@ class NAFill(LogicalPlan):
         """
 
 
+class NADrop(LogicalPlan):
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        cols: Optional[List[str]],
+        min_non_nulls: Optional[int],
+    ) -> None:
+        super().__init__(child)
+
+        self.cols = cols
+        self.min_non_nulls = min_non_nulls
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+        plan = proto.Relation()
+        plan.drop_na.input.CopyFrom(self._child.plan(session))
+        if self.cols is not None and len(self.cols) > 0:
+            plan.drop_na.cols.extend(self.cols)
+        if self.min_non_nulls is not None:
+            plan.drop_na.min_non_nulls = self.min_non_nulls
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        i = " " * indent
+        return f"{i}" f"<NADrop cols='{self.cols}' " f"min_non_nulls='{self.min_non_nulls}'>"
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+           <li>
+              <b>NADrop</b><br />
+              Cols: {self.cols} <br />
+              Min_non_nulls: {self.min_non_nulls} <br />
+              {self._child_repr_()}
+           </li>
+        </ul>
+        """
+
+
 class StatSummary(LogicalPlan):
     def __init__(self, child: Optional["LogicalPlan"], statistics: List[str]) -> None:
         super().__init__(child)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 5ac2e8d7952..856fbe5b68b 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xfc\x0b\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xae\x0c\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 
 
@@ -66,6 +66,7 @@ _SHOWSTRING = DESCRIPTOR.message_types_by_name["ShowString"]
 _STATSUMMARY = DESCRIPTOR.message_types_by_name["StatSummary"]
 _STATCROSSTAB = DESCRIPTOR.message_types_by_name["StatCrosstab"]
 _NAFILL = DESCRIPTOR.message_types_by_name["NAFill"]
+_NADROP = DESCRIPTOR.message_types_by_name["NADrop"]
 _RENAMECOLUMNSBYSAMELENGTHNAMES = DESCRIPTOR.message_types_by_name["RenameColumnsBySameLengthNames"]
 _RENAMECOLUMNSBYNAMETONAMEMAP = DESCRIPTOR.message_types_by_name["RenameColumnsByNameToNameMap"]
 _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY = (
@@ -390,6 +391,17 @@ NAFill = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(NAFill)
 
+NADrop = _reflection.GeneratedProtocolMessageType(
+    "NADrop",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _NADROP,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.NADrop)
+    },
+)
+_sym_db.RegisterMessage(NADrop)
+
 RenameColumnsBySameLengthNames = _reflection.GeneratedProtocolMessageType(
     "RenameColumnsBySameLengthNames",
     (_message.Message,),
@@ -431,75 +443,77 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 82
-    _RELATION._serialized_end = 1614
-    _UNKNOWN._serialized_start = 1616
-    _UNKNOWN._serialized_end = 1625
-    _RELATIONCOMMON._serialized_start = 1627
-    _RELATIONCOMMON._serialized_end = 1676
-    _SQL._serialized_start = 1678
-    _SQL._serialized_end = 1705
-    _READ._serialized_start = 1708
-    _READ._serialized_end = 2134
-    _READ_NAMEDTABLE._serialized_start = 1850
-    _READ_NAMEDTABLE._serialized_end = 1911
-    _READ_DATASOURCE._serialized_start = 1914
-    _READ_DATASOURCE._serialized_end = 2121
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2052
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2110
-    _PROJECT._serialized_start = 2136
-    _PROJECT._serialized_end = 2253
-    _FILTER._serialized_start = 2255
-    _FILTER._serialized_end = 2367
-    _JOIN._serialized_start = 2370
-    _JOIN._serialized_end = 2820
-    _JOIN_JOINTYPE._serialized_start = 2633
-    _JOIN_JOINTYPE._serialized_end = 2820
-    _SETOPERATION._serialized_start = 2823
-    _SETOPERATION._serialized_end = 3219
-    _SETOPERATION_SETOPTYPE._serialized_start = 3082
-    _SETOPERATION_SETOPTYPE._serialized_end = 3196
-    _LIMIT._serialized_start = 3221
-    _LIMIT._serialized_end = 3297
-    _OFFSET._serialized_start = 3299
-    _OFFSET._serialized_end = 3378
-    _TAIL._serialized_start = 3380
-    _TAIL._serialized_end = 3455
-    _AGGREGATE._serialized_start = 3458
-    _AGGREGATE._serialized_end = 3668
-    _SORT._serialized_start = 3671
-    _SORT._serialized_end = 4221
-    _SORT_SORTFIELD._serialized_start = 3825
-    _SORT_SORTFIELD._serialized_end = 4013
-    _SORT_SORTDIRECTION._serialized_start = 4015
-    _SORT_SORTDIRECTION._serialized_end = 4123
-    _SORT_SORTNULLS._serialized_start = 4125
-    _SORT_SORTNULLS._serialized_end = 4207
-    _DROP._serialized_start = 4223
-    _DROP._serialized_end = 4323
-    _DEDUPLICATE._serialized_start = 4326
-    _DEDUPLICATE._serialized_end = 4497
-    _LOCALRELATION._serialized_start = 4499
-    _LOCALRELATION._serialized_end = 4534
-    _SAMPLE._serialized_start = 4537
-    _SAMPLE._serialized_end = 4761
-    _RANGE._serialized_start = 4764
-    _RANGE._serialized_end = 4909
-    _SUBQUERYALIAS._serialized_start = 4911
-    _SUBQUERYALIAS._serialized_end = 5025
-    _REPARTITION._serialized_start = 5028
-    _REPARTITION._serialized_end = 5170
-    _SHOWSTRING._serialized_start = 5173
-    _SHOWSTRING._serialized_end = 5314
-    _STATSUMMARY._serialized_start = 5316
-    _STATSUMMARY._serialized_end = 5408
-    _STATCROSSTAB._serialized_start = 5410
-    _STATCROSSTAB._serialized_end = 5511
-    _NAFILL._serialized_start = 5514
-    _NAFILL._serialized_end = 5648
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5650
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5764
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5767
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6026
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5959
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6026
+    _RELATION._serialized_end = 1664
+    _UNKNOWN._serialized_start = 1666
+    _UNKNOWN._serialized_end = 1675
+    _RELATIONCOMMON._serialized_start = 1677
+    _RELATIONCOMMON._serialized_end = 1726
+    _SQL._serialized_start = 1728
+    _SQL._serialized_end = 1755
+    _READ._serialized_start = 1758
+    _READ._serialized_end = 2184
+    _READ_NAMEDTABLE._serialized_start = 1900
+    _READ_NAMEDTABLE._serialized_end = 1961
+    _READ_DATASOURCE._serialized_start = 1964
+    _READ_DATASOURCE._serialized_end = 2171
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2102
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2160
+    _PROJECT._serialized_start = 2186
+    _PROJECT._serialized_end = 2303
+    _FILTER._serialized_start = 2305
+    _FILTER._serialized_end = 2417
+    _JOIN._serialized_start = 2420
+    _JOIN._serialized_end = 2870
+    _JOIN_JOINTYPE._serialized_start = 2683
+    _JOIN_JOINTYPE._serialized_end = 2870
+    _SETOPERATION._serialized_start = 2873
+    _SETOPERATION._serialized_end = 3269
+    _SETOPERATION_SETOPTYPE._serialized_start = 3132
+    _SETOPERATION_SETOPTYPE._serialized_end = 3246
+    _LIMIT._serialized_start = 3271
+    _LIMIT._serialized_end = 3347
+    _OFFSET._serialized_start = 3349
+    _OFFSET._serialized_end = 3428
+    _TAIL._serialized_start = 3430
+    _TAIL._serialized_end = 3505
+    _AGGREGATE._serialized_start = 3508
+    _AGGREGATE._serialized_end = 3718
+    _SORT._serialized_start = 3721
+    _SORT._serialized_end = 4271
+    _SORT_SORTFIELD._serialized_start = 3875
+    _SORT_SORTFIELD._serialized_end = 4063
+    _SORT_SORTDIRECTION._serialized_start = 4065
+    _SORT_SORTDIRECTION._serialized_end = 4173
+    _SORT_SORTNULLS._serialized_start = 4175
+    _SORT_SORTNULLS._serialized_end = 4257
+    _DROP._serialized_start = 4273
+    _DROP._serialized_end = 4373
+    _DEDUPLICATE._serialized_start = 4376
+    _DEDUPLICATE._serialized_end = 4547
+    _LOCALRELATION._serialized_start = 4549
+    _LOCALRELATION._serialized_end = 4584
+    _SAMPLE._serialized_start = 4587
+    _SAMPLE._serialized_end = 4811
+    _RANGE._serialized_start = 4814
+    _RANGE._serialized_end = 4959
+    _SUBQUERYALIAS._serialized_start = 4961
+    _SUBQUERYALIAS._serialized_end = 5075
+    _REPARTITION._serialized_start = 5078
+    _REPARTITION._serialized_end = 5220
+    _SHOWSTRING._serialized_start = 5223
+    _SHOWSTRING._serialized_end = 5364
+    _STATSUMMARY._serialized_start = 5366
+    _STATSUMMARY._serialized_end = 5458
+    _STATCROSSTAB._serialized_start = 5460
+    _STATCROSSTAB._serialized_end = 5561
+    _NAFILL._serialized_start = 5564
+    _NAFILL._serialized_end = 5698
+    _NADROP._serialized_start = 5701
+    _NADROP._serialized_end = 5835
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5837
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5951
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5954
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6213
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6146
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6213
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 6b05621744f..6832db56190 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -82,6 +82,7 @@ class Relation(google.protobuf.message.Message):
     DROP_FIELD_NUMBER: builtins.int
     TAIL_FIELD_NUMBER: builtins.int
     FILL_NA_FIELD_NUMBER: builtins.int
+    DROP_NA_FIELD_NUMBER: builtins.int
     SUMMARY_FIELD_NUMBER: builtins.int
     CROSSTAB_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
@@ -133,6 +134,8 @@ class Relation(google.protobuf.message.Message):
     def fill_na(self) -> global___NAFill:
         """NA functions"""
     @property
+    def drop_na(self) -> global___NADrop: ...
+    @property
     def summary(self) -> global___StatSummary:
         """stat functions"""
     @property
@@ -165,6 +168,7 @@ class Relation(google.protobuf.message.Message):
         drop: global___Drop | None = ...,
         tail: global___Tail | None = ...,
         fill_na: global___NAFill | None = ...,
+        drop_na: global___NADrop | None = ...,
         summary: global___StatSummary | None = ...,
         crosstab: global___StatCrosstab | None = ...,
         unknown: global___Unknown | None = ...,
@@ -182,6 +186,8 @@ class Relation(google.protobuf.message.Message):
             b"deduplicate",
             "drop",
             b"drop",
+            "drop_na",
+            b"drop_na",
             "fill_na",
             b"fill_na",
             "filter",
@@ -241,6 +247,8 @@ class Relation(google.protobuf.message.Message):
             b"deduplicate",
             "drop",
             b"drop",
+            "drop_na",
+            b"drop_na",
             "fill_na",
             b"fill_na",
             "filter",
@@ -312,6 +320,7 @@ class Relation(google.protobuf.message.Message):
         "drop",
         "tail",
         "fill_na",
+        "drop_na",
         "summary",
         "crosstab",
         "unknown",
@@ -1547,6 +1556,74 @@ class NAFill(google.protobuf.message.Message):
 
 global___NAFill = NAFill
 
+class NADrop(google.protobuf.message.Message):
+    """Drop rows containing null values.
+    It will invoke 'Dataset.na.drop' (same as 'DataFrameNaFunctions.drop') to compute the results.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    COLS_FIELD_NUMBER: builtins.int
+    MIN_NON_NULLS_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    @property
+    def cols(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+        """(Optional) Optional list of column names to consider.
+
+        When it is empty, all the columns in the input relation will be considered.
+        """
+    min_non_nulls: builtins.int
+    """(Optional) The minimum number of non-null and non-NaN values required to keep.
+
+    When not set, it is equivalent to the number of considered columns, which means
+    a row will be kept only if all columns are non-null.
+
+    'how' options ('all', 'any') can be easily converted to this field:
+      - 'all' -> set 'min_non_nulls' 1;
+      - 'any' -> keep 'min_non_nulls' unset;
+    """
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        cols: collections.abc.Iterable[builtins.str] | None = ...,
+        min_non_nulls: builtins.int | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_min_non_nulls",
+            b"_min_non_nulls",
+            "input",
+            b"input",
+            "min_non_nulls",
+            b"min_non_nulls",
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_min_non_nulls",
+            b"_min_non_nulls",
+            "cols",
+            b"cols",
+            "input",
+            b"input",
+            "min_non_nulls",
+            b"min_non_nulls",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_min_non_nulls", b"_min_non_nulls"]
+    ) -> typing_extensions.Literal["min_non_nulls"] | None: ...
+
+global___NADrop = NADrop
+
 class RenameColumnsBySameLengthNames(google.protobuf.message.Message):
     """Rename columns on the input relation by the same length of names."""
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 0fec48779ef..eb025fd5d04 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -498,6 +498,38 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             self.spark.sql(query).na.fill({"a": True, "b": 2}).toPandas(),
         )
 
+    def test_drop_na(self):
+        # SPARK-41148: Test drop na
+        query = """
+            SELECT * FROM VALUES
+            (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
+            AS tab(a, b, c)
+            """
+        # +-----+----+----+
+        # |    a|   b|   c|
+        # +-----+----+----+
+        # |false|   1|null|
+        # |false|null| 2.0|
+        # | null|   3| 3.0|
+        # +-----+----+----+
+
+        self.assert_eq(
+            self.connect.sql(query).dropna().toPandas(),
+            self.spark.sql(query).dropna().toPandas(),
+        )
+        self.assert_eq(
+            self.connect.sql(query).na.drop(how="all", thresh=1).toPandas(),
+            self.spark.sql(query).na.drop(how="all", thresh=1).toPandas(),
+        )
+        self.assert_eq(
+            self.connect.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
+            self.spark.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
+        )
+        self.assert_eq(
+            self.connect.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
+            self.spark.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
+        )
+
     def test_empty_dataset(self):
         # SPARK-41005: Test arrow based collection with empty dataset.
         self.assertTrue(
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 109702af7e8..367c514d0a8 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -102,6 +102,22 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.fill_na.values[1].string, "abc")
         self.assertEqual(plan.root.fill_na.cols, ["col_a", "col_b"])
 
+    def test_drop_na(self):
+        # SPARK-41148: Test drop na
+        df = self.connect.readTable(table_name=self.tbl_name)
+
+        plan = df.dropna()._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.drop_na.cols, [])
+        self.assertEqual(plan.root.drop_na.HasField("min_non_nulls"), False)
+
+        plan = df.na.drop(thresh=2, subset=("col_a", "col_b"))._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.drop_na.cols, ["col_a", "col_b"])
+        self.assertEqual(plan.root.drop_na.min_non_nulls, 2)
+
+        plan = df.dropna(how="all", subset="col_c")._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.drop_na.cols, ["col_c"])
+        self.assertEqual(plan.root.drop_na.min_non_nulls, 1)
+
     def test_summary(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org