Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/08 06:47:18 UTC
[spark] branch master updated: [SPARK-41775][PYTHON][FOLLOW-UP] Updating error message for training using PyTorch functions
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5db84b539e0 [SPARK-41775][PYTHON][FOLLOW-UP] Updating error message for training using PyTorch functions
5db84b539e0 is described below
commit 5db84b539e0c8bb7980d3d9499e51372dbe25341
Author: Rithwik Ediga Lakhamsani <ri...@databricks.com>
AuthorDate: Wed Mar 8 14:46:58 2023 +0800
[SPARK-41775][PYTHON][FOLLOW-UP] Updating error message for training using PyTorch functions
### What changes were proposed in this pull request?
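I made the error raised by `TorchDistributor` clearer when training with a PyTorch function fails. If the output file is missing after training, a `RuntimeError` now reports that training failed. If the file exists but cannot be loaded, a separate `RuntimeError` reports a pickling error.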
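The ordering of the two checks can be illustrated with a minimal, standalone sketch. It is not the Spark implementation: `load_training_output` and `describe_failure` are hypothetical helpers, and reading the file with `pickle.load` is an assumption about what the private `_get_pickled_output` does. Only the two error messages are taken from this commit.

```python
import os
import pickle
import tempfile


def load_training_output(output_file_path):
    """Hypothetical stand-in for the post-training step changed by this commit."""
    # Check 1: a missing output file means the training process never
    # wrote a result, so report a training failure.
    if not os.path.exists(output_file_path):
        raise RuntimeError(
            "TorchDistributor failed during training. "
            "View stdout logs for detailed error message."
        )
    # Check 2: the file exists, so any failure from here on happens while
    # loading it. Using pickle.load is an assumption made for this sketch.
    try:
        with open(output_file_path, "rb") as f:
            return pickle.load(f)
    except Exception as e:
        raise RuntimeError(
            "TorchDistributor failed due to a pickling error. "
            "View stdout logs for detailed error message."
        ) from e


def describe_failure(output_file_path):
    """Return the error message for a path, or "ok" if loading succeeds."""
    try:
        load_training_output(output_file_path)
    except RuntimeError as e:
        return str(e)
    return "ok"


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp_dir:
        # No file was written, so the training-failure message is reported.
        print(describe_failure(os.path.join(tmp_dir, "missing.pkl")))
```

Checking for the file before entering the `try` block lets the `except` clause catch any exception while still attributing it to loading rather than to training.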
### Why are the changes needed?
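User experience: the error message now indicates whether the failure occurred during training or while loading the pickled output.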
### Does this PR introduce _any_ user-facing change?
Yes, but only the error message shown to the user.
### How was this patch tested?
Tested it out locally.
Closes #40322 from rithwik-db/torch-distributor-error-fix.
Authored-by: Rithwik Ediga Lakhamsani <ri...@databricks.com>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/ml/torch/distributor.py | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index a0a9c5aa932..157cc96717f 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -627,11 +627,16 @@ class TorchDistributor(Distributor):
         with TorchDistributor._setup_files(train_fn, *args) as (train_file_path, output_file_path):
             args = []  # type: ignore
             TorchDistributor._run_training_on_pytorch_file(input_params, train_file_path, *args)
+            if not os.path.exists(output_file_path):
+                raise RuntimeError(
+                    "TorchDistributor failed during training. "
+                    "View stdout logs for detailed error message."
+                )
             try:
                 output = TorchDistributor._get_pickled_output(output_file_path)
-            except FileNotFoundError as e:
+            except Exception as e:
                 raise RuntimeError(
-                    "TorchDistributor failed during training. "
+                    "TorchDistributor failed due to a pickling error. "
                     "View stdout logs for detailed error message."
                 ) from e
         return output