Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/13 19:03:27 UTC
[spark] branch branch-3.4 updated: [SPARK-41775][PYTHON][FOLLOW-UP] Updating error message for training using PyTorch functions
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 76f9594d296 [SPARK-41775][PYTHON][FOLLOW-UP] Updating error message for training using PyTorch functions
76f9594d296 is described below
commit 76f9594d296b1b957d7638d3b7c020b90b3a27ed
Author: Rithwik Ediga Lakhamsani <ri...@databricks.com>
AuthorDate: Tue Feb 14 04:03:02 2023 +0900
[SPARK-41775][PYTHON][FOLLOW-UP] Updating error message for training using PyTorch functions
### What changes were proposed in this pull request?
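When training fails, `TorchDistributor` now raises a `RuntimeError` that points the user to the stdout logs. Previously, the failure surfaced as an uninformative `FileNotFoundError` when the missing pickled output file was read. The original `FileNotFoundError` is kept as the cause of the new exception.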
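The change below uses Python exception chaining. It catches the low-level `FileNotFoundError` and re-raises a `RuntimeError` with `raise ... from e`, which stores the original error as `__cause__`. The following is a minimal sketch of that pattern outside Spark. `get_pickled_output` and `collect_training_output` are hypothetical stand-ins, not PySpark APIs. The sketch also assumes, as the diff suggests, that a failed training run leaves no pickled output file behind.

```python
import os
import pickle
import tempfile


def get_pickled_output(output_file_path: str):
    # Hypothetical stand-in for TorchDistributor._get_pickled_output:
    # load the object that a successful training run pickled to disk.
    with open(output_file_path, "rb") as f:
        return pickle.load(f)


def collect_training_output(output_file_path: str):
    # Assumption: if training crashed, the output file was never written,
    # so opening it raises FileNotFoundError. Translate that into a
    # RuntimeError that tells the user where to look, and chain the
    # original error with `from e` so it is preserved as __cause__.
    try:
        return get_pickled_output(output_file_path)
    except FileNotFoundError as e:
        raise RuntimeError(
            "TorchDistributor failed during training. "
            "View stdout logs for detailed error message."
        ) from e


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        missing = os.path.join(tmp, "output.pickle")
        try:
            collect_training_output(missing)
        except RuntimeError as err:
            # Prints: RuntimeError caused by FileNotFoundError
            print(type(err).__name__, "caused by", type(err.__cause__).__name__)
```

Because the cause is chained rather than discarded, the traceback still shows the original `FileNotFoundError` under "The above exception was the direct cause of the following exception".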
### Why are the changes needed?
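To improve the user experience. A bare `FileNotFoundError` did not tell the user that training itself had failed or where to find the underlying error.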
### Does this PR introduce _any_ user-facing change?
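Only the error shown to the user. A failed training run now raises a `RuntimeError` (chained from the original `FileNotFoundError`) instead of a bare `FileNotFoundError`.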
### How was this patch tested?
N/A
Closes #39987 from rithwik-db/error-bug-fix.
Authored-by: Rithwik Ediga Lakhamsani <ri...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
(cherry picked from commit 889889e993ad1621f73440ff287dfd0a54c0ea4f)
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/ml/torch/distributor.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 5f0e930515c..92b63ab2da4 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -613,7 +613,13 @@ class TorchDistributor(Distributor):
         with TorchDistributor._setup_files(train_fn, *args) as (train_file_path, output_file_path):
             args = []  # type: ignore
             TorchDistributor._run_training_on_pytorch_file(input_params, train_file_path, *args)
-            output = TorchDistributor._get_pickled_output(output_file_path)
+            try:
+                output = TorchDistributor._get_pickled_output(output_file_path)
+            except FileNotFoundError as e:
+                raise RuntimeError(
+                    "TorchDistributor failed during training. "
+                    "View stdout logs for detailed error message."
+                ) from e
         return output
 
     @staticmethod