Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/03/22 04:35:01 UTC

[GitHub] [spark] xinrong-databricks commented on a change in pull request #35888: [SPARK-38608][PYTHON] Implement `bool_only` parameter of `DataFrame.all` and `DataFrame.any`

xinrong-databricks commented on a change in pull request #35888:
URL: https://github.com/apache/spark/pull/35888#discussion_r831761305



##########
File path: python/pyspark/pandas/frame.py
##########
@@ -10195,32 +10190,87 @@ def any(self, axis: Axis = 0) -> "Series":
         col5    False
         col6     True
         dtype: bool
-        """
-        from pyspark.pandas.series import first_series
 
+        Include only boolean columns when `bool_only` is set to `True`.
+
+        >>> df.any(bool_only=True)
+        col1    False
+        col2     True
+        dtype: bool
+        """
         axis = validate_axis(axis)
         if axis != 0:
             raise NotImplementedError('axis should be either 0 or "index" currently.')
 
-        applied = []
         column_labels = self._internal.column_labels
+        if bool_only:
+            column_labels = self._bool_column_labels(column_labels)
+        if len(column_labels) == 0:
+            return ps.Series([], dtype=bool)
+
+        applied = []
         for label in column_labels:
             scol = self._internal.spark_column_for(label)
-            all_col = F.max(F.coalesce(scol.cast("boolean"), SF.lit(False)))
-            applied.append(F.when(all_col.isNull(), False).otherwise(all_col))
+            any_col = F.max(F.coalesce(scol.cast("boolean"), SF.lit(False)))
+            applied.append(F.when(any_col.isNull(), False).otherwise(any_col))
+
+        return self._result_aggregated(column_labels, applied)
+
+    def _bool_column_labels(self, column_labels: List[Label]) -> List[Label]:
+        """
+        Filter the column labels, keeping only those of boolean columns (columns containing None are excluded).
+        """
+        bool_column_labels = []
+        for label in column_labels:
+            psser = self._psser_for(label)
+            if is_bool_dtype(psser):
+                # Rely on the pandas dtype rather than the Spark type because
+                # columns that contain both bools and Nones should be excluded
+                # when bool_only is True.
+                bool_column_labels.append(label)
+        return bool_column_labels
+
+    def _result_aggregated(self, column_labels: List[Label], scols: List[Column]) -> ps.Series:

Review comment:
       Good catch! Thanks.
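
For context on the diff above: the per-column aggregation builds the `any` semantics out of plain Spark SQL functions. `max()` over a boolean column acts as a logical OR, and `coalesce(..., lit(False))` maps NULLs to False so missing values never count as True. Below is a minimal standalone sketch of the same trick; the session setup, data, and column names are illustrative, not taken from the patch.

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()

    sdf = spark.createDataFrame(
        [(True, None), (False, False), (None, None)],
        "col1 boolean, col2 boolean",
    )

    # max() over booleans behaves as a logical OR;
    # coalesce() turns NULL into False before aggregating.
    agg = sdf.select(
        F.max(F.coalesce(F.col("col1").cast("boolean"), F.lit(False))).alias("col1"),
        F.max(F.coalesce(F.col("col2").cast("boolean"), F.lit(False))).alias("col2"),
    )
    agg.show()  # col1 -> true, col2 -> false

The extra `F.when(any_col.isNull(), False).otherwise(any_col)` wrapper in the patch covers the empty-DataFrame case, where `max()` over zero rows yields NULL.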




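The dtype-based filtering in `_bool_column_labels` mirrors pandas' own `bool_only` semantics: a column holding both bools and Nones is promoted to object dtype in pandas, so `is_bool_dtype` rejects it even though its Spark type may still be BooleanType. A quick plain-pandas illustration of that rule (the frame below is made up for the example):

    import pandas as pd
    from pandas.api.types import is_bool_dtype

    pdf = pd.DataFrame({
        "col1": [False, False, False],   # dtype bool   -> kept
        "col2": [True, False, False],    # dtype bool   -> kept
        "col3": [0, 0, 1],               # dtype int64  -> dropped
        "col4": [True, None, False],     # dtype object -> dropped
    })

    bool_cols = [c for c in pdf.columns if is_bool_dtype(pdf[c])]
    print(bool_cols)  # ['col1', 'col2']

With the patch applied, `df.any(bool_only=True)` on the pandas-on-Spark side therefore aggregates only col1 and col2, exactly as the new docstring example shows.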
-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


