Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/08 06:47:34 UTC

[spark] branch branch-3.4 updated: [SPARK-41775][PYTHON][FOLLOW-UP] Updating error message for training using PyTorch functions

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 60a40416f50 [SPARK-41775][PYTHON][FOLLOW-UP] Updating error message for training using PyTorch functions
60a40416f50 is described below

commit 60a40416f50327e914f9bdeca36071d9a3c7973e
Author: Rithwik Ediga Lakhamsani <ri...@databricks.com>
AuthorDate: Wed Mar 8 14:46:58 2023 +0800

    [SPARK-41775][PYTHON][FOLLOW-UP] Updating error message for training using PyTorch functions
    
    ### What changes were proposed in this pull request?
    
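    Check explicitly whether the training output file exists and raise a RuntimeError saying that training failed if it does not. Any exception raised while loading the pickled output is now reported as a pickling error, so the message no longer misleads the reader.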
    
    ### Why are the changes needed?
    
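    To improve the user experience: the error message now distinguishes a training failure from a failure to load the pickled output.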
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, but only the error message shown to the user.
    
    ### How was this patch tested?
    
    Tested it out locally.
    
    Closes #40322 from rithwik-db/torch-distributor-error-fix.
    
    Authored-by: Rithwik Ediga Lakhamsani <ri...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
    (cherry picked from commit 5db84b539e0c8bb7980d3d9499e51372dbe25341)
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/ml/torch/distributor.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index a0a9c5aa932..157cc96717f 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -627,11 +627,16 @@ class TorchDistributor(Distributor):
         with TorchDistributor._setup_files(train_fn, *args) as (train_file_path, output_file_path):
             args = []  # type: ignore
             TorchDistributor._run_training_on_pytorch_file(input_params, train_file_path, *args)
+            if not os.path.exists(output_file_path):
+                raise RuntimeError(
+                    "TorchDistributor failed during training. "
+                    "View stdout logs for detailed error message."
+                )
             try:
                 output = TorchDistributor._get_pickled_output(output_file_path)
-            except FileNotFoundError as e:
+            except Exception as e:
                 raise RuntimeError(
-                    "TorchDistributor failed during training. "
+                    "TorchDistributor failed due to a pickling error. "
                     "View stdout logs for detailed error message."
                 ) from e
         return output
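
The control flow introduced by this patch can be sketched in isolation as follows. This is a minimal illustration, not the Spark implementation: `load_training_output` is a hypothetical helper, and the standard-library `pickle.load` is used as a stand-in for `TorchDistributor._get_pickled_output`, whose body is not shown in this diff. Only the two error messages are taken from the patch.

```python
import os
import pickle
import tempfile


def load_training_output(output_file_path):
    """Hypothetical helper mirroring the two checks added in the patch."""
    # A missing output file means the training process never wrote a result,
    # so this is reported as a training failure.
    if not os.path.exists(output_file_path):
        raise RuntimeError(
            "TorchDistributor failed during training. "
            "View stdout logs for detailed error message."
        )
    try:
        # Assumption: plain pickle.load stands in for _get_pickled_output.
        with open(output_file_path, "rb") as f:
            return pickle.load(f)
    except Exception as e:
        # The file exists but could not be loaded; chain the original
        # exception so the underlying cause stays visible in the traceback.
        raise RuntimeError(
            "TorchDistributor failed due to a pickling error. "
            "View stdout logs for detailed error message."
        ) from e


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        missing = os.path.join(tmp, "missing.pkl")
        corrupt = os.path.join(tmp, "corrupt.pkl")
        with open(corrupt, "wb") as f:
            f.write(b"not a pickle")
        for path in (missing, corrupt):
            try:
                load_training_output(path)
            except RuntimeError as err:
                print(err)
```

Checking for the file before loading it keeps the two failure modes apart: without the check, a broad `except Exception` would also catch the `FileNotFoundError` from a run that never produced output and mislabel it as a pickling error.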

