Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/12/28 22:30:59 UTC

[GitHub] [spark] mengxr commented on a diff in pull request #37734: [SPARK-40264][ML] add batch_infer_udf function to pyspark.ml.functions

mengxr commented on code in PR #37734:
URL: https://github.com/apache/spark/pull/37734#discussion_r1003749352


##########
python/pyspark/ml/functions.py:
##########
@@ -106,6 +117,474 @@ def array_to_vector(col: Column) -> Column:
     return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
 
 
+def _batched(
+    data: pd.Series | pd.DataFrame | Tuple[pd.Series], batch_size: int
+) -> Iterator[pd.DataFrame]:
+    """Generator that splits a pandas dataframe/series into batches."""
+    if isinstance(data, pd.DataFrame):
+        for _, batch in data.groupby(np.arange(len(data)) // batch_size):
+            yield batch
+    else:
+        # convert (tuple of) pd.Series into pd.DataFrame
+        if isinstance(data, pd.Series):
+            df = pd.concat((data,), axis=1)
+        else:  # isinstance(data, Tuple[pd.Series]):
+            df = pd.concat(data, axis=1)
+        for _, batch in df.groupby(np.arange(len(df)) // batch_size):
+            yield batch
+
+
+def _has_tensor_cols(data: pd.Series | pd.DataFrame | Tuple[pd.Series]) -> bool:
+    """Check if input DataFrame contains any tensor-valued columns"""
+    if isinstance(data, pd.Series):
+        return data.dtype == np.object_ and isinstance(data.iloc[0], (np.ndarray, list))
+    elif isinstance(data, pd.DataFrame):
+        return any(data.dtypes == np.object_) and any(
+            [isinstance(d, (np.ndarray, list)) for d in data.iloc[0]]
+        )
+    else:  # isinstance(data, Tuple):
+        return any([d.dtype == np.object_ for d in data]) and any(
+            [isinstance(d.iloc[0], (np.ndarray, list)) for d in data]
+        )
+
+
+def predict_batch_udf(
+    predict_batch_fn: Callable[
+        [],
+        Callable[
+            [np.ndarray | List[np.ndarray]],
+            np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, np.dtype]],
+        ],
+    ],
+    *,
+    return_type: DataType,
+    batch_size: int,
+    input_tensor_shapes: list[list[int] | None] | Mapping[int, list[int]] | None = None,
+) -> UserDefinedFunctionLike:
+    """Given a function which loads a model, returns a pandas_udf for inferencing over that model.

Review Comment:
   * `loads a model` -> `loads a model and makes batch predictions`
   * `pandas_udf` -> `Pandas UDF`



##########
python/pyspark/ml/functions.py:
##########
@@ -106,6 +117,474 @@ def array_to_vector(col: Column) -> Column:
     return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
 
 
+def _batched(
+    data: pd.Series | pd.DataFrame | Tuple[pd.Series], batch_size: int
+) -> Iterator[pd.DataFrame]:
+    """Generator that splits a pandas dataframe/series into batches."""
+    if isinstance(data, pd.DataFrame):
+        for _, batch in data.groupby(np.arange(len(data)) // batch_size):
+            yield batch
+    else:
+        # convert (tuple of) pd.Series into pd.DataFrame
+        if isinstance(data, pd.Series):
+            df = pd.concat((data,), axis=1)
+        else:  # isinstance(data, Tuple[pd.Series]):
+            df = pd.concat(data, axis=1)
+        for _, batch in df.groupby(np.arange(len(df)) // batch_size):
+            yield batch
+
+
+def _has_tensor_cols(data: pd.Series | pd.DataFrame | Tuple[pd.Series]) -> bool:
+    """Check if input DataFrame contains any tensor-valued columns"""
+    if isinstance(data, pd.Series):
+        return data.dtype == np.object_ and isinstance(data.iloc[0], (np.ndarray, list))
+    elif isinstance(data, pd.DataFrame):
+        return any(data.dtypes == np.object_) and any(
+            [isinstance(d, (np.ndarray, list)) for d in data.iloc[0]]
+        )
+    else:  # isinstance(data, Tuple):
+        return any([d.dtype == np.object_ for d in data]) and any(
+            [isinstance(d.iloc[0], (np.ndarray, list)) for d in data]
+        )
+
+
+def predict_batch_udf(
+    predict_batch_fn: Callable[
+        [],
+        Callable[
+            [np.ndarray | List[np.ndarray]],

Review Comment:
   What is the scenario for `List[np.ndarray]`? If this is a single column of np arrays, we should expect a single np array after batching. If this is a single column of non-array typed values, we should expect `List[Any]`. If this is multiple columns, we should use `[np.ndarray, np.ndarray]`. We won't be able to enumerate all possible combinations. We can consider the following:
   
   ~~~
   [ndarray] | [List[Any]] | [ndarray, ndarray] | Any
   ~~~
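   For illustration, a rough sketch of the `predict` signatures each case implies (the names and bodies here are hypothetical, not part of the PR):

   ~~~
   import numpy as np
   from typing import Any, List

   # single tensor/scalar column -> one batched ndarray
   def predict_single(inputs: np.ndarray) -> np.ndarray:
       return inputs * 2

   # single column of non-array-typed values (e.g. strings) -> plain Python list
   def predict_opaque(inputs: List[Any]) -> np.ndarray:
       return np.array([len(str(v)) for v in inputs], dtype=np.float64)

   # multiple columns -> one positional ndarray per column
   def predict_multi(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
       return x1 + x2
   ~~~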



##########
python/pyspark/ml/functions.py:
##########
@@ -106,6 +117,474 @@ def array_to_vector(col: Column) -> Column:
     return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
 
 
+def _batched(
+    data: pd.Series | pd.DataFrame | Tuple[pd.Series], batch_size: int
+) -> Iterator[pd.DataFrame]:
+    """Generator that splits a pandas dataframe/series into batches."""
+    if isinstance(data, pd.DataFrame):
+        for _, batch in data.groupby(np.arange(len(data)) // batch_size):
+            yield batch
+    else:
+        # convert (tuple of) pd.Series into pd.DataFrame
+        if isinstance(data, pd.Series):
+            df = pd.concat((data,), axis=1)
+        else:  # isinstance(data, Tuple[pd.Series]):
+            df = pd.concat(data, axis=1)
+        for _, batch in df.groupby(np.arange(len(df)) // batch_size):
+            yield batch
+
+
+def _has_tensor_cols(data: pd.Series | pd.DataFrame | Tuple[pd.Series]) -> bool:
+    """Check if input DataFrame contains any tensor-valued columns"""
+    if isinstance(data, pd.Series):
+        return data.dtype == np.object_ and isinstance(data.iloc[0], (np.ndarray, list))
+    elif isinstance(data, pd.DataFrame):
+        return any(data.dtypes == np.object_) and any(
+            [isinstance(d, (np.ndarray, list)) for d in data.iloc[0]]
+        )
+    else:  # isinstance(data, Tuple):
+        return any([d.dtype == np.object_ for d in data]) and any(
+            [isinstance(d.iloc[0], (np.ndarray, list)) for d in data]
+        )
+
+
+def predict_batch_udf(
+    predict_batch_fn: Callable[
+        [],
+        Callable[
+            [np.ndarray | List[np.ndarray]],
+            np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, np.dtype]],
+        ],
+    ],
+    *,
+    return_type: DataType,
+    batch_size: int,
+    input_tensor_shapes: list[list[int] | None] | Mapping[int, list[int]] | None = None,
+) -> UserDefinedFunctionLike:
+    """Given a function which loads a model, returns a pandas_udf for inferencing over that model.
+
+    This will handle:
+    - conversion of the Spark DataFrame to numpy arrays.
+    - batching of the inputs sent to the model predict() function.
+    - caching of the model and prediction function on the executors.
+
+    This assumes that the `predict_batch_fn` encapsulates all necessary dependencies for running
+    the model, or that the Spark executor environment already satisfies all runtime requirements.
+
+    For the conversion of Spark DataFrame to numpy, the following table describes the behavior,
+    where tensor columns in the Spark DataFrame must be represented as a flattened 1-D array/list.
+
+    | dataframe \\ model | single input | multiple inputs |

Review Comment:
   I still find this table hard to understand. The conversion logic of the inputs is quite simple:
   
   There is a 1:1 mapping between the input args of the `predict` method (returned by `make_predict_fn()`) and the input columns to the UDF (returned by `predict_batch_udf`) at runtime. We batch the input values and invoke `predict` on each batch.
   
   Then we only need to describe per-column conversion:
   
   * scalar column -> np.ndarray
   * tensor column + tensor shape -> np.ndarray
   * non-numeric column -> list
   
   and how multiple columns are handled.
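   As a rough sketch of those per-column rules (a hypothetical helper, not the PR's actual code):

   ~~~
   import numpy as np
   import pandas as pd

   def convert_column(col: pd.Series, tensor_shape=None):
       """Hypothetical per-column conversion following the rules above."""
       if tensor_shape:
           # tensor column + tensor shape: stack the flattened rows, restore the shape
           return np.vstack(col).reshape([-1] + list(tensor_shape))
       if pd.api.types.is_numeric_dtype(col.dtype):
           # scalar column -> 1-D np.ndarray
           return col.to_numpy()
       # non-numeric column -> list
       return col.tolist()
   ~~~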



##########
python/pyspark/ml/functions.py:
##########
@@ -106,6 +138,605 @@ def array_to_vector(col: Column) -> Column:
     return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
 
 
+def _batched(
+    data: pd.Series | pd.DataFrame | Tuple[pd.Series], batch_size: int
+) -> Iterator[pd.DataFrame]:
+    """Generator that splits a pandas dataframe/series into batches."""
+    if isinstance(data, pd.DataFrame):
+        df = data
+    elif isinstance(data, pd.Series):
+        df = pd.concat((data,), axis=1)
+    else:  # isinstance(data, Tuple[pd.Series]):
+        df = pd.concat(data, axis=1)
+
+    index = 0
+    data_size = len(df)
+    while index < data_size:
+        yield df.iloc[index : index + batch_size]
+        index += batch_size
+
+
+def _is_tensor_col(data: pd.Series | pd.DataFrame) -> bool:
+    if isinstance(data, pd.Series):
+        return data.dtype == np.object_ and isinstance(data.iloc[0], (np.ndarray, list))
+    elif isinstance(data, pd.DataFrame):
+        return any(data.dtypes == np.object_) and any(
+            [isinstance(d, (np.ndarray, list)) for d in data.iloc[0]]
+        )
+    else:
+        raise ValueError(
+            "Unexpected data type: {}, expected pd.Series or pd.DataFrame.".format(type(data))
+        )
+
+
+def _has_tensor_cols(data: pd.Series | pd.DataFrame | Tuple[pd.Series]) -> bool:
+    """Check if input Series/DataFrame/Tuple contains any tensor-valued columns."""
+    if isinstance(data, (pd.Series, pd.DataFrame)):
+        return _is_tensor_col(data)
+    else:  # isinstance(data, Tuple):
+        return any(_is_tensor_col(elem) for elem in data)
+
+
+def _validate_and_transform_multiple_inputs(
+    batch: pd.DataFrame, input_shapes: List[List[int] | None], num_input_cols: int
+) -> List[np.ndarray]:
+    multi_inputs = [batch[col].to_numpy() for col in batch.columns]
+    if input_shapes:
+        if len(input_shapes) == num_input_cols:
+            multi_inputs = [
+                np.vstack(v).reshape([-1] + input_shapes[i])  # type: ignore
+                if input_shapes[i]
+                else v
+                for i, v in enumerate(multi_inputs)
+            ]
+            if not all([len(x) == len(batch) for x in multi_inputs]):
+                raise ValueError("Input data does not match expected shape.")
+        else:
+            raise ValueError("input_tensor_shapes must match columns")
+
+    return multi_inputs
+
+
+def _validate_and_transform_single_input(
+    batch: pd.DataFrame,
+    input_shapes: List[List[int] | None],
+    has_tensors: bool,
+    has_tuple: bool,
+) -> np.ndarray:
+    # multiple input columns for single expected input
+    if has_tensors:
+        # tensor columns
+        if len(batch.columns) == 1:
+            # one tensor column and one expected input, vstack rows
+            single_input = np.vstack(batch.iloc[:, 0])
+        else:
+            raise ValueError(
+                "Multiple input columns found, but model expected a single "
+                "input, use `struct` or `array` to combine columns into tensors."
+            )
+    else:
+        # scalar columns
+        if len(batch.columns) == 1:
+            # single scalar column, remove extra dim
+            single_input = np.squeeze(batch.to_numpy())
+            if input_shapes and input_shapes[0] not in [None, [], [1]]:
+                raise ValueError("Invalid input_tensor_shape for scalar column.")
+        elif not has_tuple:
+            # columns grouped via struct/array, convert to single tensor
+            single_input = batch.to_numpy()
+            if input_shapes and input_shapes[0] != [len(batch.columns)]:
+                raise ValueError("Input data does not match expected shape.")
+        else:
+            raise ValueError(
+                "Multiple input columns found, but model expected a single "
+                "input, use `struct` or `array` to combine columns into tensors."
+            )
+
+    # if input_tensor_shapes provided, try to reshape input
+    if input_shapes:
+        if len(input_shapes) == 1:
+            single_input = single_input.reshape([-1] + input_shapes[0])  # type: ignore
+            if len(single_input) != len(batch):
+                raise ValueError("Input data does not match expected shape.")
+        else:
+            raise ValueError("Multiple input_tensor_shapes found, but model expected one input")
+
+    return single_input
+
+
+def _validate_and_transform_prediction_result(
+    preds: np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, Any]],
+    num_input_rows: int,
+    return_type: DataType,
+) -> pd.DataFrame | pd.Series:
+    """Validate numpy-based model predictions against the expected pandas_udf return_type and
+    transform the predictions into an equivalent pandas DataFrame or Series."""
+    if isinstance(return_type, StructType):
+        struct_rtype: StructType = return_type
+        fieldNames = struct_rtype.names
+        if isinstance(preds, dict):
+            # dictionary of columns
+            predNames = list(preds.keys())
+            for field in struct_rtype.fields:
+                if isinstance(field.dataType, ArrayType):
+                    if len(preds[field.name].shape) == 2:
+                        preds[field.name] = list(preds[field.name])
+                    else:
+                        raise ValueError(
+                            "Prediction results for ArrayType must be two-dimensional."
+                        )
+                elif isinstance(field.dataType, supported_scalar_types):
+                    if len(preds[field.name].shape) != 1:
+                        raise ValueError(
+                            "Prediction results for scalar types must be one-dimensional."
+                        )
+                else:
+                    raise ValueError("Unsupported field type in return struct type.")
+
+                if len(preds[field.name]) != num_input_rows:
+                    raise ValueError("Prediction results must have same length as input data")
+
+        elif isinstance(preds, list) and isinstance(preds[0], dict):
+            # rows of dictionaries
+            predNames = list(preds[0].keys())
+            if len(preds) != num_input_rows:
+                raise ValueError("Prediction results must have same length as input data.")
+            for field in struct_rtype.fields:
+                if isinstance(field.dataType, ArrayType):
+                    if len(preds[0][field.name].shape) != 1:
+                        raise ValueError(
+                            "Prediction results for ArrayType must be one-dimensional."
+                        )
+                elif isinstance(field.dataType, supported_scalar_types):
+                    if not np.isscalar(preds[0][field.name]):
+                        raise ValueError("Invalid scalar prediction result.")
+                else:
+                    raise ValueError("Unsupported field type in return struct type.")
+        else:
+            raise ValueError(
+                "Prediction results for StructType must be a dictionary or "
+                "a list of dictionary, got: {}".format(type(preds))
+            )
+
+        # check column names
+        if set(predNames) != set(fieldNames):
+            raise ValueError(
+                "Prediction result columns did not match expected return_type "
+                "columns: expected {}, got: {}".format(fieldNames, predNames)
+            )
+
+        return pd.DataFrame(preds)
+    elif isinstance(return_type, ArrayType):
+        if isinstance(preds, np.ndarray):
+            if len(preds) != num_input_rows:
+                raise ValueError("Prediction results must have same length as input data.")
+            if len(preds.shape) != 2:
+                raise ValueError("Prediction results for ArrayType must be two-dimensional.")
+        else:
+            raise ValueError("Prediction results for ArrayType must be an ndarray.")
+
+        return pd.Series(list(preds))
+    elif isinstance(return_type, supported_scalar_types):
+        preds_array: np.ndarray = preds  # type: ignore
+        if len(preds_array) != num_input_rows:
+            raise ValueError("Prediction results must have same length as input data.")
+        if not (
+            (len(preds_array.shape) == 2 and preds_array.shape[1] == 1)
+            or len(preds_array.shape) == 1
+        ):
+            raise ValueError("Invalid shape for scalar prediction result.")
+
+        return pd.Series(np.squeeze(preds))  # type: ignore
+    else:
+        raise ValueError("Unsupported return type")
+
+
+def predict_batch_udf(
+    predict_batch_fn: Callable[

Review Comment:
   The name is a bit misleading because the value itself is not a method that predicts batches but a method that returns the actual batch-prediction function. For example, `model_fn` from `tf.Estimator` is a method that returns an `EstimatorSpec`, not another method. How about renaming to `make_predict_batch_fn` or `make_predict_fn`? I slightly prefer the latter because it is shorter and batch is inferred from the context. Feel free to suggest better names.
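   Under the suggested rename, a user-facing sketch would read (the model loader here is hypothetical):

   ~~~
   from pyspark.ml.functions import predict_batch_udf
   from pyspark.sql.types import FloatType

   def make_predict_fn():
       # runs once per worker: load the model here
       model = load_model("/path/to/model")  # hypothetical loader

       def predict(inputs):
           # runs on each batch, reusing the loaded model
           return model.predict(inputs)

       return predict

   udf = predict_batch_udf(make_predict_fn, return_type=FloatType(), batch_size=100)
   ~~~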



##########
python/pyspark/ml/model_cache.py:
##########
@@ -0,0 +1,44 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+from threading import Lock
+from typing import Callable, Optional
+from uuid import UUID
+
+
+class ModelCache:
+    """Cache for model prediction functions on executors."""
+
+    _models: OrderedDict[UUID, Callable] = OrderedDict()
+    _capacity: int = 8

Review Comment:
   I'm a bit concerned about the cache usage. The default behavior shouldn't cause any memory leaks, which is not the case here. A couple of options:
   
   * Recommend using broadcast if users want fast model loading/re-loading (see the sketch below).
   * Make the caching behavior configurable (Spark conf) and leave the default as no caching. And document the potential memory leak issues clearly.
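   A sketch of the broadcast option (this assumes the model object is picklable; `trained_model` is illustrative):

   ~~~
   # driver side: broadcast the (picklable) model once
   # `trained_model` is an illustrative model object
   bc_model = spark.sparkContext.broadcast(trained_model)

   def make_predict_fn():
       # executor side: reuse the broadcast value instead of re-loading from disk
       model = bc_model.value

       def predict(inputs):
           return model.predict(inputs)

       return predict
   ~~~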



##########
python/pyspark/ml/functions.py:
##########
@@ -106,6 +138,605 @@ def array_to_vector(col: Column) -> Column:
     return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
 
 
+def _batched(
+    data: pd.Series | pd.DataFrame | Tuple[pd.Series], batch_size: int
+) -> Iterator[pd.DataFrame]:
+    """Generator that splits a pandas dataframe/series into batches."""
+    if isinstance(data, pd.DataFrame):
+        df = data
+    elif isinstance(data, pd.Series):
+        df = pd.concat((data,), axis=1)
+    else:  # isinstance(data, Tuple[pd.Series]):
+        df = pd.concat(data, axis=1)
+
+    index = 0
+    data_size = len(df)
+    while index < data_size:
+        yield df.iloc[index : index + batch_size]
+        index += batch_size
+
+
+def _is_tensor_col(data: pd.Series | pd.DataFrame) -> bool:
+    if isinstance(data, pd.Series):
+        return data.dtype == np.object_ and isinstance(data.iloc[0], (np.ndarray, list))
+    elif isinstance(data, pd.DataFrame):
+        return any(data.dtypes == np.object_) and any(
+            [isinstance(d, (np.ndarray, list)) for d in data.iloc[0]]
+        )
+    else:
+        raise ValueError(
+            "Unexpected data type: {}, expected pd.Series or pd.DataFrame.".format(type(data))
+        )
+
+
+def _has_tensor_cols(data: pd.Series | pd.DataFrame | Tuple[pd.Series]) -> bool:
+    """Check if input Series/DataFrame/Tuple contains any tensor-valued columns."""
+    if isinstance(data, (pd.Series, pd.DataFrame)):
+        return _is_tensor_col(data)
+    else:  # isinstance(data, Tuple):
+        return any(_is_tensor_col(elem) for elem in data)
+
+
+def _validate_and_transform_multiple_inputs(
+    batch: pd.DataFrame, input_shapes: List[List[int] | None], num_input_cols: int
+) -> List[np.ndarray]:
+    multi_inputs = [batch[col].to_numpy() for col in batch.columns]
+    if input_shapes:
+        if len(input_shapes) == num_input_cols:
+            multi_inputs = [
+                np.vstack(v).reshape([-1] + input_shapes[i])  # type: ignore
+                if input_shapes[i]
+                else v
+                for i, v in enumerate(multi_inputs)
+            ]
+            if not all([len(x) == len(batch) for x in multi_inputs]):
+                raise ValueError("Input data does not match expected shape.")
+        else:
+            raise ValueError("input_tensor_shapes must match columns")
+
+    return multi_inputs
+
+
+def _validate_and_transform_single_input(
+    batch: pd.DataFrame,
+    input_shapes: List[List[int] | None],
+    has_tensors: bool,
+    has_tuple: bool,
+) -> np.ndarray:
+    # multiple input columns for single expected input
+    if has_tensors:
+        # tensor columns
+        if len(batch.columns) == 1:
+            # one tensor column and one expected input, vstack rows
+            single_input = np.vstack(batch.iloc[:, 0])
+        else:
+            raise ValueError(
+                "Multiple input columns found, but model expected a single "
+                "input, use `struct` or `array` to combine columns into tensors."
+            )
+    else:
+        # scalar columns
+        if len(batch.columns) == 1:
+            # single scalar column, remove extra dim
+            single_input = np.squeeze(batch.to_numpy())
+            if input_shapes and input_shapes[0] not in [None, [], [1]]:
+                raise ValueError("Invalid input_tensor_shape for scalar column.")
+        elif not has_tuple:
+            # columns grouped via struct/array, convert to single tensor
+            single_input = batch.to_numpy()
+            if input_shapes and input_shapes[0] != [len(batch.columns)]:
+                raise ValueError("Input data does not match expected shape.")
+        else:
+            raise ValueError(
+                "Multiple input columns found, but model expected a single "
+                "input, use `struct` or `array` to combine columns into tensors."
+            )
+
+    # if input_tensor_shapes provided, try to reshape input
+    if input_shapes:
+        if len(input_shapes) == 1:
+            single_input = single_input.reshape([-1] + input_shapes[0])  # type: ignore
+            if len(single_input) != len(batch):
+                raise ValueError("Input data does not match expected shape.")
+        else:
+            raise ValueError("Multiple input_tensor_shapes found, but model expected one input")
+
+    return single_input
+
+
+def _validate_and_transform_prediction_result(
+    preds: np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, Any]],
+    num_input_rows: int,
+    return_type: DataType,
+) -> pd.DataFrame | pd.Series:
+    """Validate numpy-based model predictions against the expected pandas_udf return_type and
+    transform the predictions into an equivalent pandas DataFrame or Series."""
+    if isinstance(return_type, StructType):
+        struct_rtype: StructType = return_type
+        fieldNames = struct_rtype.names
+        if isinstance(preds, dict):
+            # dictionary of columns
+            predNames = list(preds.keys())
+            for field in struct_rtype.fields:
+                if isinstance(field.dataType, ArrayType):
+                    if len(preds[field.name].shape) == 2:
+                        preds[field.name] = list(preds[field.name])
+                    else:
+                        raise ValueError(
+                            "Prediction results for ArrayType must be two-dimensional."
+                        )
+                elif isinstance(field.dataType, supported_scalar_types):
+                    if len(preds[field.name].shape) != 1:
+                        raise ValueError(
+                            "Prediction results for scalar types must be one-dimensional."
+                        )
+                else:
+                    raise ValueError("Unsupported field type in return struct type.")
+
+                if len(preds[field.name]) != num_input_rows:
+                    raise ValueError("Prediction results must have same length as input data")
+
+        elif isinstance(preds, list) and isinstance(preds[0], dict):
+            # rows of dictionaries
+            predNames = list(preds[0].keys())
+            if len(preds) != num_input_rows:
+                raise ValueError("Prediction results must have same length as input data.")
+            for field in struct_rtype.fields:
+                if isinstance(field.dataType, ArrayType):
+                    if len(preds[0][field.name].shape) != 1:
+                        raise ValueError(
+                            "Prediction results for ArrayType must be one-dimensional."
+                        )
+                elif isinstance(field.dataType, supported_scalar_types):
+                    if not np.isscalar(preds[0][field.name]):
+                        raise ValueError("Invalid scalar prediction result.")
+                else:
+                    raise ValueError("Unsupported field type in return struct type.")
+        else:
+            raise ValueError(
+                "Prediction results for StructType must be a dictionary or "
+                "a list of dictionary, got: {}".format(type(preds))
+            )
+
+        # check column names
+        if set(predNames) != set(fieldNames):
+            raise ValueError(
+                "Prediction result columns did not match expected return_type "
+                "columns: expected {}, got: {}".format(fieldNames, predNames)
+            )
+
+        return pd.DataFrame(preds)
+    elif isinstance(return_type, ArrayType):
+        if isinstance(preds, np.ndarray):
+            if len(preds) != num_input_rows:
+                raise ValueError("Prediction results must have same length as input data.")
+            if len(preds.shape) != 2:
+                raise ValueError("Prediction results for ArrayType must be two-dimensional.")
+        else:
+            raise ValueError("Prediction results for ArrayType must be an ndarray.")
+
+        return pd.Series(list(preds))
+    elif isinstance(return_type, supported_scalar_types):
+        preds_array: np.ndarray = preds  # type: ignore
+        if len(preds_array) != num_input_rows:
+            raise ValueError("Prediction results must have same length as input data.")
+        if not (
+            (len(preds_array.shape) == 2 and preds_array.shape[1] == 1)
+            or len(preds_array.shape) == 1
+        ):
+            raise ValueError("Invalid shape for scalar prediction result.")
+
+        return pd.Series(np.squeeze(preds))  # type: ignore
+    else:
+        raise ValueError("Unsupported return type")
+
+
+def predict_batch_udf(
+    predict_batch_fn: Callable[
+        [],
+        Callable[
+            [np.ndarray | List[np.ndarray]],
+            np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, np.dtype]],
+        ],
+    ],
+    *,
+    return_type: DataType,
+    batch_size: int,
+    input_tensor_shapes: List[List[int] | None] | Mapping[int, List[int]] | None = None,
+) -> UserDefinedFunctionLike:
+    """Given a function which loads a model, returns a pandas_udf for inferencing over that model.
+
+    This will handle:

Review Comment:
   The returned UDF does the following on each DataFrame partition (see the sketch below):
   - calls `make_predict_fn` to load the model and cache its `predict_fn`
   - batches the input records and invokes `predict_fn` on each batch
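   Roughly (simplified; the real UDF also converts each batch to numpy, reshapes tensor inputs, and validates results against `return_type`):

   ~~~
   def partition_eval(pandas_batches):
       predict_fn = make_predict_fn()      # load the model once, then cache predict_fn
       for pdf in pandas_batches:          # Arrow-converted pandas batches
           for batch in _batched(pdf, batch_size):
               yield predict_fn(batch)
   ~~~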



##########
python/pyspark/ml/functions.py:
##########
@@ -106,6 +138,605 @@ def array_to_vector(col: Column) -> Column:
     return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
 
 
+def _batched(
+    data: pd.Series | pd.DataFrame | Tuple[pd.Series], batch_size: int
+) -> Iterator[pd.DataFrame]:
+    """Generator that splits a pandas dataframe/series into batches."""
+    if isinstance(data, pd.DataFrame):
+        df = data
+    elif isinstance(data, pd.Series):
+        df = pd.concat((data,), axis=1)
+    else:  # isinstance(data, Tuple[pd.Series]):
+        df = pd.concat(data, axis=1)
+
+    index = 0
+    data_size = len(df)
+    while index < data_size:
+        yield df.iloc[index : index + batch_size]
+        index += batch_size
+
+
+def _is_tensor_col(data: pd.Series | pd.DataFrame) -> bool:
+    if isinstance(data, pd.Series):
+        return data.dtype == np.object_ and isinstance(data.iloc[0], (np.ndarray, list))
+    elif isinstance(data, pd.DataFrame):
+        return any(data.dtypes == np.object_) and any(
+            [isinstance(d, (np.ndarray, list)) for d in data.iloc[0]]
+        )
+    else:
+        raise ValueError(
+            "Unexpected data type: {}, expected pd.Series or pd.DataFrame.".format(type(data))
+        )
+
+
+def _has_tensor_cols(data: pd.Series | pd.DataFrame | Tuple[pd.Series]) -> bool:
+    """Check if input Series/DataFrame/Tuple contains any tensor-valued columns."""
+    if isinstance(data, (pd.Series, pd.DataFrame)):
+        return _is_tensor_col(data)
+    else:  # isinstance(data, Tuple):
+        return any(_is_tensor_col(elem) for elem in data)
+
+
+def _validate_and_transform_multiple_inputs(
+    batch: pd.DataFrame, input_shapes: List[List[int] | None], num_input_cols: int
+) -> List[np.ndarray]:
+    multi_inputs = [batch[col].to_numpy() for col in batch.columns]
+    if input_shapes:
+        if len(input_shapes) == num_input_cols:
+            multi_inputs = [
+                np.vstack(v).reshape([-1] + input_shapes[i])  # type: ignore
+                if input_shapes[i]
+                else v
+                for i, v in enumerate(multi_inputs)
+            ]
+            if not all([len(x) == len(batch) for x in multi_inputs]):
+                raise ValueError("Input data does not match expected shape.")
+        else:
+            raise ValueError("input_tensor_shapes must match columns")
+
+    return multi_inputs
+
+
+def _validate_and_transform_single_input(
+    batch: pd.DataFrame,
+    input_shapes: List[List[int] | None],
+    has_tensors: bool,
+    has_tuple: bool,
+) -> np.ndarray:
+    # multiple input columns for single expected input
+    if has_tensors:
+        # tensor columns
+        if len(batch.columns) == 1:
+            # one tensor column and one expected input, vstack rows
+            single_input = np.vstack(batch.iloc[:, 0])
+        else:
+            raise ValueError(
+                "Multiple input columns found, but model expected a single "
+                "input, use `struct` or `array` to combine columns into tensors."
+            )
+    else:
+        # scalar columns
+        if len(batch.columns) == 1:
+            # single scalar column, remove extra dim
+            single_input = np.squeeze(batch.to_numpy())
+            if input_shapes and input_shapes[0] not in [None, [], [1]]:
+                raise ValueError("Invalid input_tensor_shape for scalar column.")
+        elif not has_tuple:
+            # columns grouped via struct/array, convert to single tensor
+            single_input = batch.to_numpy()
+            if input_shapes and input_shapes[0] != [len(batch.columns)]:
+                raise ValueError("Input data does not match expected shape.")
+        else:
+            raise ValueError(
+                "Multiple input columns found, but model expected a single "
+                "input, use `struct` or `array` to combine columns into tensors."
+            )
+
+    # if input_tensor_shapes provided, try to reshape input
+    if input_shapes:
+        if len(input_shapes) == 1:
+            single_input = single_input.reshape([-1] + input_shapes[0])  # type: ignore
+            if len(single_input) != len(batch):
+                raise ValueError("Input data does not match expected shape.")
+        else:
+            raise ValueError("Multiple input_tensor_shapes found, but model expected one input")
+
+    return single_input
+
+
+def _validate_and_transform_prediction_result(
+    preds: np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, Any]],
+    num_input_rows: int,
+    return_type: DataType,
+) -> pd.DataFrame | pd.Series:
+    """Validate numpy-based model predictions against the expected pandas_udf return_type and
+    transform the predictions into an equivalent pandas DataFrame or Series."""
+    if isinstance(return_type, StructType):
+        struct_rtype: StructType = return_type
+        fieldNames = struct_rtype.names
+        if isinstance(preds, dict):
+            # dictionary of columns
+            predNames = list(preds.keys())
+            for field in struct_rtype.fields:
+                if isinstance(field.dataType, ArrayType):
+                    if len(preds[field.name].shape) == 2:
+                        preds[field.name] = list(preds[field.name])
+                    else:
+                        raise ValueError(
+                            "Prediction results for ArrayType must be two-dimensional."
+                        )
+                elif isinstance(field.dataType, supported_scalar_types):
+                    if len(preds[field.name].shape) != 1:
+                        raise ValueError(
+                            "Prediction results for scalar types must be one-dimensional."
+                        )
+                else:
+                    raise ValueError("Unsupported field type in return struct type.")
+
+                if len(preds[field.name]) != num_input_rows:
+                    raise ValueError("Prediction results must have same length as input data")
+
+        elif isinstance(preds, list) and isinstance(preds[0], dict):
+            # rows of dictionaries
+            predNames = list(preds[0].keys())
+            if len(preds) != num_input_rows:
+                raise ValueError("Prediction results must have same length as input data.")
+            for field in struct_rtype.fields:
+                if isinstance(field.dataType, ArrayType):
+                    if len(preds[0][field.name].shape) != 1:
+                        raise ValueError(
+                            "Prediction results for ArrayType must be one-dimensional."
+                        )
+                elif isinstance(field.dataType, supported_scalar_types):
+                    if not np.isscalar(preds[0][field.name]):
+                        raise ValueError("Invalid scalar prediction result.")
+                else:
+                    raise ValueError("Unsupported field type in return struct type.")
+        else:
+            raise ValueError(
+                "Prediction results for StructType must be a dictionary or "
+                "a list of dictionary, got: {}".format(type(preds))
+            )
+
+        # check column names
+        if set(predNames) != set(fieldNames):
+            raise ValueError(
+                "Prediction result columns did not match expected return_type "
+                "columns: expected {}, got: {}".format(fieldNames, predNames)
+            )
+
+        return pd.DataFrame(preds)
+    elif isinstance(return_type, ArrayType):
+        if isinstance(preds, np.ndarray):
+            if len(preds) != num_input_rows:
+                raise ValueError("Prediction results must have same length as input data.")
+            if len(preds.shape) != 2:
+                raise ValueError("Prediction results for ArrayType must be two-dimensional.")
+        else:
+            raise ValueError("Prediction results for ArrayType must be an ndarray.")
+
+        return pd.Series(list(preds))
+    elif isinstance(return_type, supported_scalar_types):
+        preds_array: np.ndarray = preds  # type: ignore
+        if len(preds_array) != num_input_rows:
+            raise ValueError("Prediction results must have same length as input data.")
+        if not (
+            (len(preds_array.shape) == 2 and preds_array.shape[1] == 1)
+            or len(preds_array.shape) == 1
+        ):
+            raise ValueError("Invalid shape for scalar prediction result.")
+
+        return pd.Series(np.squeeze(preds))  # type: ignore
+    else:
+        raise ValueError("Unsupported return type")
+
+
+def predict_batch_udf(
+    predict_batch_fn: Callable[
+        [],
+        Callable[
+            [np.ndarray | List[np.ndarray]],
+            np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, np.dtype]],
+        ],
+    ],
+    *,
+    return_type: DataType,
+    batch_size: int,
+    input_tensor_shapes: List[List[int] | None] | Mapping[int, List[int]] | None = None,
+) -> UserDefinedFunctionLike:
+    """Given a function which loads a model, returns a pandas_udf for inferencing over that model.
+
+    This will handle:
+    - conversion of the Spark DataFrame to numpy arrays.
+    - batching of the inputs sent to the model predict() function.
+    - caching of the model and prediction function on the executors.
+
+    This assumes that the `predict_batch_fn` encapsulates all necessary dependencies for running
+    the model, or that the Spark executor environment already satisfies all runtime requirements.
+
+    For the conversion of Spark DataFrame to numpy, the following table describes the behavior,
+    where tensor columns in the Spark DataFrame must be represented as a flattened 1-D array/list.
+
+    | dataframe \\ model | single input | multiple inputs |
+    | :----------------- | :----------- | :-------------- |
+    | single-col scalar  | 1            | N/A             |
+    | single-col tensor  | 1,2          | N/A             |
+    | multi-col scalar   | 3            | 4               |
+    | multi-col tensor   | N/A          | 4,2             |
+
+    Notes:
+    1. pass through dataframe column => model input as single numpy array.
+    2. reshape flattened tensors into expected tensor shapes.
+    3. user must use `pyspark.sql.functions.struct()` or `pyspark.sql.functions.array()` to
+       combine multiple input columns into the equivalent of a single-col tensor.
+    4. pass through dataframe columns => model inputs as an ordered list of numpy arrays.
+
+    Example (single-col tensor):
+
+    Input DataFrame has a single column with a flattened tensor value, represented as an array of
+    float.
+    ```
+    import numpy as np
+    from pyspark.ml.functions import predict_batch_udf
+    from pyspark.sql.types import ArrayType, FloatType
+
+    def predict_batch_fn():
+        # load/init happens once per python worker
+        import tensorflow as tf
+        model = tf.keras.models.load_model('/path/to/mnist_model')
+
+        # predict on batches of tasks/partitions, using cached model
+        def predict(inputs: np.ndarray) -> np.ndarray:
+            # inputs.shape = [batch_size, 784]
+            # outputs.shape = [batch_size, 10], return_type = ArrayType(FloatType())
+            return model.predict(inputs)
+
+        return predict
+
+    mnist = predict_batch_udf(predict_batch_fn,
+                              return_type=ArrayType(FloatType()),
+                              batch_size=100,
+                              input_tensor_shapes=[[784]])
+
+    df = spark.read.parquet("/path/to/mnist_data")
+    df.show(5)
+    # +--------------------+
+    # |                data|
+    # +--------------------+
+    # |[0.0, 0.0, 0.0, 0...|
+    # |[0.0, 0.0, 0.0, 0...|
+    # |[0.0, 0.0, 0.0, 0...|
+    # |[0.0, 0.0, 0.0, 0...|
+    # |[0.0, 0.0, 0.0, 0...|
+    # +--------------------+
+
+    df.withColumn("preds", mnist("data")).show(5)
+    # +--------------------+--------------------+
+    # |                data|               preds|
+    # +--------------------+--------------------+
+    # |[0.0, 0.0, 0.0, 0...|[-13.511008, 8.84...|
+    # |[0.0, 0.0, 0.0, 0...|[-5.3957458, -2.2...|
+    # |[0.0, 0.0, 0.0, 0...|[-7.2014456, -8.8...|
+    # |[0.0, 0.0, 0.0, 0...|[-19.466187, -13....|
+    # |[0.0, 0.0, 0.0, 0...|[-5.7757926, -7.8...|
+    # +--------------------+--------------------+
+    ```
+
+    Example (single-col scalar):
+
+    Input DataFrame has a single scalar column, which will be passed to the `predict` function as
+    a 1-D numpy array.
+    ```
+    import numpy as np
+    import pandas as pd
+    from pyspark.ml.functions import predict_batch_udf
+    from pyspark.sql.types import FloatType
+
+    df = spark.createDataFrame(pd.DataFrame(np.arange(100)))
+    df.show(5)
+    # +---+
+    # |  0|
+    # +---+
+    # |  0|
+    # |  1|
+    # |  2|
+    # |  3|
+    # |  4|
+    # +---+
+
+    def predict_batch_fn():
+        def predict(inputs: np.ndarray) -> np.ndarray:
+            # inputs.shape = [batch_size]
+            # outputs.shape = [batch_size], return_type = FloatType()
+            return inputs * 2
+
+        return predict
+
+    times_two = predict_batch_udf(predict_batch_fn,
+                                  return_type=FloatType(),
+                                  batch_size=10)
+
+    df = spark.createDataFrame(pd.DataFrame(np.arange(100)))
+    df.withColumn("x2", times_two("0")).show(5)
+    # +---+---+
+    # |  0| x2|
+    # +---+---+
+    # |  0|0.0|
+    # |  1|2.0|
+    # |  2|4.0|
+    # |  3|6.0|
+    # |  4|8.0|
+    # +---+---+
+    ```
+
+    Example (multi-col scalar):
+
+    Input DataFrame has multiple columns of scalar values.  If the user-provided `predict` function
+    expects a single input, then the user should combine multiple columns into a single tensor using
+    `pyspark.sql.functions.struct` or `pyspark.sql.functions.array`.
+    ```
+    import numpy as np
+    import pandas as pd
+    from pyspark.ml.functions import predict_batch_udf
+    from pyspark.sql.functions import struct
+    from pyspark.sql.types import FloatType
+
+    data = np.arange(0, 1000, dtype=np.float64).reshape(-1, 4)
+    pdf = pd.DataFrame(data, columns=['a','b','c','d'])
+    df = spark.createDataFrame(pdf)
+    # +----+----+----+----+
+    # |   a|   b|   c|   d|
+    # +----+----+----+----+
+    # | 0.0| 1.0| 2.0| 3.0|
+    # | 4.0| 5.0| 6.0| 7.0|
+    # | 8.0| 9.0|10.0|11.0|
+    # |12.0|13.0|14.0|15.0|
+    # |16.0|17.0|18.0|19.0|
+    # +----+----+----+----+
+
+    def predict_batch_fn():
+        def predict(inputs: np.ndarray) -> np.ndarray:
+            # inputs.shape = [batch_size, 4]
+            # outputs.shape = [batch_size], return_type = FloatType()
+            return np.sum(inputs, axis=1)
+
+        return predict
+
+    sum_rows = predict_batch_udf(predict_batch_fn,
+                                 return_type=FloatType(),
+                                 batch_size=10,
+                                 input_tensor_shapes=[[4]])
+
+    df.withColumn("sum", sum_rows(struct("a", "b", "c", "d"))).show(5)
+    # +----+----+----+----+----+
+    # |   a|   b|   c|   d| sum|
+    # +----+----+----+----+----+
+    # | 0.0| 1.0| 2.0| 3.0| 6.0|
+    # | 4.0| 5.0| 6.0| 7.0|22.0|
+    # | 8.0| 9.0|10.0|11.0|38.0|
+    # |12.0|13.0|14.0|15.0|54.0|
+    # |16.0|17.0|18.0|19.0|70.0|
+    # +----+----+----+----+----+
+
+    # Note: if the `predict` function expects multiple inputs, then the number of selected columns
+    # must match the number of expected inputs.
+
+    def predict_batch_fn():
+        def predict(x1: np.ndarray, x2: np.ndarray, x3: np.ndarray, x4: np.ndarray) -> np.ndarray:
+            # xN.shape = [batch_size]
+            # outputs.shape = [batch_size], return_type = FloatType()
+            return x1 + x2 + x3 + x4
+
+        return predict
+
+    sum_rows = predict_batch_udf(predict_batch_fn,
+                                 return_type=FloatType(),
+                                 batch_size=10)
+
+    df.withColumn("sum", sum_rows("a", "b", "c", "d")).show(5)
+    # +----+----+----+----+----+
+    # |   a|   b|   c|   d| sum|
+    # +----+----+----+----+----+
+    # | 0.0| 1.0| 2.0| 3.0| 6.0|
+    # | 4.0| 5.0| 6.0| 7.0|22.0|
+    # | 8.0| 9.0|10.0|11.0|38.0|
+    # |12.0|13.0|14.0|15.0|54.0|
+    # |16.0|17.0|18.0|19.0|70.0|
+    # +----+----+----+----+----+
+    ```
+
+    Example (multi-col tensor):
+
+    Input DataFrame has multiple columns, where each column is a tensor.  The number of columns
+    should match the number of expected inputs for the user-provided `predict` function.
+    ```
+    import numpy as np
+    import pandas as pd
+    from pyspark.ml.functions import predict_batch_udf
+    from pyspark.sql.types import FloatType, StructType, StructField
+    from typing import Mapping
+
+    data = np.arange(0, 1000, dtype=np.float64).reshape(-1, 4)
+    pdf = pd.DataFrame(data, columns=['a','b','c','d'])
+    pdf_tensor = pd.DataFrame()
+    pdf_tensor['t1'] = pdf.values.tolist()
+    pdf_tensor['t2'] = pdf.drop(columns='d').values.tolist()
+    df = spark.createDataFrame(pdf_tensor)
+    df.show(5)
+    # +--------------------+------------------+
+    # |                  t1|                t2|
+    # +--------------------+------------------+
+    # |[0.0, 1.0, 2.0, 3.0]|   [0.0, 1.0, 2.0]|
+    # |[4.0, 5.0, 6.0, 7.0]|   [4.0, 5.0, 6.0]|
+    # |[8.0, 9.0, 10.0, ...|  [8.0, 9.0, 10.0]|
+    # |[12.0, 13.0, 14.0...|[12.0, 13.0, 14.0]|
+    # |[16.0, 17.0, 18.0...|[16.0, 17.0, 18.0]|
+    # +--------------------+------------------+
+
+    def multi_sum_fn():
+        def predict(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
+            # x1.shape = [batch_size, 4]
+            # x2.shape = [batch_size, 3]
+            # outputs.shape = [batch_size], return_type = FloatType()
+            return np.sum(x1, axis=1) + np.sum(x2, axis=1)
+
+        return predict
+
+    sum_cols = predict_batch_udf(
+        multi_sum_fn,
+        return_type=FloatType(),
+        batch_size=5,
+        input_tensor_shapes=[[4], [3]],
+    )
+
+    df.withColumn("sum", sum_cols("t1", "t2")).show(5)
+    # +--------------------+------------------+-----+
+    # |                  t1|                t2|  sum|
+    # +--------------------+------------------+-----+
+    # |[0.0, 1.0, 2.0, 3.0]|   [0.0, 1.0, 2.0]|  9.0|
+    # |[4.0, 5.0, 6.0, 7.0]|   [4.0, 5.0, 6.0]| 37.0|
+    # |[8.0, 9.0, 10.0, ...|  [8.0, 9.0, 10.0]| 65.0|
+    # |[12.0, 13.0, 14.0...|[12.0, 13.0, 14.0]| 93.0|
+    # |[16.0, 17.0, 18.0...|[16.0, 17.0, 18.0]|121.0|
+    # +--------------------+------------------+-----+
+
+    # Note that some models can provide multiple outputs.  These can be returned as a dictionary
+    # of named values, which can be represented in columnar (or row-based) formats.
+
+    def multi_sum_fn():
+        def predict_columnar(x1: np.ndarray, x2: np.ndarray) -> Mapping[str, np.ndarray]:
+            # x1.shape = [batch_size, 4]
+            # x2.shape = [batch_size, 3]
+            return {
+                "sum1": np.sum(x1, axis=1),
+                "sum2": np.sum(x2, axis=1)
+            }  # return_type = StructType()
+
+        return predict_columnar
+
+    sum_cols = predict_batch_udf(
+        multi_sum_fn,
+        return_type=StructType([
+            StructField("sum1", FloatType(), True),
+            StructField("sum2", FloatType(), True)
+        ]),
+        batch_size=5,
+        input_tensor_shapes=[[4], [3]],
+    )
+
+    df.withColumn("preds", sum_cols("t1", "t2")).select("t1", "t2", "preds.*").show(5)
+    # +--------------------+------------------+----+----+
+    # |                  t1|                t2|sum1|sum2|
+    # +--------------------+------------------+----+----+
+    # |[0.0, 1.0, 2.0, 3.0]|   [0.0, 1.0, 2.0]| 6.0| 3.0|
+    # |[4.0, 5.0, 6.0, 7.0]|   [4.0, 5.0, 6.0]|22.0|15.0|
+    # |[8.0, 9.0, 10.0, ...|  [8.0, 9.0, 10.0]|38.0|27.0|
+    # |[12.0, 13.0, 14.0...|[12.0, 13.0, 14.0]|54.0|39.0|
+    # |[16.0, 17.0, 18.0...|[16.0, 17.0, 18.0]|70.0|51.0|
+    # +--------------------+------------------+----+----+
+    ```
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    predict_batch_fn : Callable[[],
+        Callable[..., np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, np.dtype]] ]
+        Function which is responsible for loading a model and returning a `predict` function which
+        takes one or more numpy arrays as input and returns one of the following:

Review Comment:
   Is it better if we adopt the same output conversion as in Pandas UDFs? It means we would only expect a Series/DataFrame as output here. I'm trying to see whether the extra learning curve is worth the output conversion code we save for users.
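   For comparison, the two output conventions side by side (a sketch):

   ~~~
   import numpy as np
   import pandas as pd

   def predict_numpy(inputs: np.ndarray) -> np.ndarray:
       # current proposal: return numpy; the framework converts it to pandas
       return inputs * 2

   def predict_pandas(inputs: np.ndarray) -> pd.Series:
       # Pandas-UDF-style alternative: the user returns a Series/DataFrame directly
       return pd.Series(inputs * 2)
   ~~~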



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

