You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by "maddiedawson (via GitHub)" <gi...@apache.org> on 2023/07/06 21:15:33 UTC

[GitHub] [spark] maddiedawson commented on a diff in pull request #41770: [WIP] Write a Deepspeed Distributed Learning Class DeepspeedTorchDistributor

maddiedawson commented on code in PR #41770:
URL: https://github.com/apache/spark/pull/41770#discussion_r1254923674


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -1003,3 +1007,97 @@ def _get_spark_partition_data_loader(
         # if num_workers is zero, we cannot set `prefetch_factor` otherwise
         # torch will raise error.
         return DataLoader(dataset, batch_size, num_workers=num_workers)
+
+
+class DeepspeedTorchDistributor(TorchDistributor):
+    
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True, deepspeed_config = None):
+        super().__init__(num_processes, local_mode, use_gpu)
+        self.deepspeed_config = deepspeed_config 
+        self.ssl_conf = "deepspeed.spark.distributor.ignoreSsl"
+        self._validate_input_params()
+        self.input_params = self._create_input_params()
+
+    @staticmethod
+    def _get_deepspeed_config_path(deepspeed_config):
+        if isinstance(deepspeed_config, dict):
+            with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.json') as fil:
+                json.dump(deepspeed_config, fil)
+                deepspeed_config_path = fil.name

Review Comment:
   Just return `fil.name` here, inside the `with` block. Then you can remove the `else` branch and simply return `deepspeed_config` at the end.



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -1003,3 +1007,97 @@ def _get_spark_partition_data_loader(
         # if num_workers is zero, we cannot set `prefetch_factor` otherwise
         # torch will raise error.
         return DataLoader(dataset, batch_size, num_workers=num_workers)
+
+
+class DeepspeedTorchDistributor(TorchDistributor):
+    
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True, deepspeed_config = None):
+        super().__init__(num_processes, local_mode, use_gpu)
+        self.deepspeed_config = deepspeed_config 
+        self.ssl_conf = "deepspeed.spark.distributor.ignoreSsl"
+        self._validate_input_params()
+        self.input_params = self._create_input_params()
+
+    @staticmethod
+    def _get_deepspeed_config_path(deepspeed_config):
+        if isinstance(deepspeed_config, dict):
+            with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.json') as fil:
+                json.dump(deepspeed_config, fil)
+                deepspeed_config_path = fil.name
+        else:
+            deepspeed_config_path = deepspeed_config
+        return deepspeed_config_path
+
+
+    @staticmethod 
+    def _get_torchrun_args(local_mode, num_processes):
+        if local_mode:
+            torchrun_args = ["--standalone", "--nnodes=1"]
+            processes_per_node = num_processes

Review Comment:
   Again, return directly from this branch; then you can remove the `else`.



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -1003,3 +1007,97 @@ def _get_spark_partition_data_loader(
         # if num_workers is zero, we cannot set `prefetch_factor` otherwise
         # torch will raise error.
         return DataLoader(dataset, batch_size, num_workers=num_workers)
+
+
+class DeepspeedTorchDistributor(TorchDistributor):
+    
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True, deepspeed_config = None):
+        super().__init__(num_processes, local_mode, use_gpu)
+        self.deepspeed_config = deepspeed_config 
+        self.ssl_conf = "deepspeed.spark.distributor.ignoreSsl"
+        self._validate_input_params()
+        self.input_params = self._create_input_params()
+
+    @staticmethod
+    def _get_deepspeed_config_path(deepspeed_config):
+        if isinstance(deepspeed_config, dict):
+            with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.json') as fil:
+                json.dump(deepspeed_config, fil)
+                deepspeed_config_path = fil.name
+        else:
+            deepspeed_config_path = deepspeed_config
+        return deepspeed_config_path
+
+
+    @staticmethod 
+    def _get_torchrun_args(local_mode, num_processes):

Review Comment:
   Add a docstring describing what this function returns.



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -1003,3 +1007,97 @@ def _get_spark_partition_data_loader(
         # if num_workers is zero, we cannot set `prefetch_factor` otherwise
         # torch will raise error.
         return DataLoader(dataset, batch_size, num_workers=num_workers)
+
+
+class DeepspeedTorchDistributor(TorchDistributor):
+    
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True, deepspeed_config = None):
+        super().__init__(num_processes, local_mode, use_gpu)
+        self.deepspeed_config = deepspeed_config 
+        self.ssl_conf = "deepspeed.spark.distributor.ignoreSsl"
+        self._validate_input_params()
+        self.input_params = self._create_input_params()
+
+    @staticmethod
+    def _get_deepspeed_config_path(deepspeed_config):
+        if isinstance(deepspeed_config, dict):
+            with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.json') as fil:
+                json.dump(deepspeed_config, fil)
+                deepspeed_config_path = fil.name
+        else:
+            deepspeed_config_path = deepspeed_config
+        return deepspeed_config_path
+
+
+    @staticmethod 
+    def _get_torchrun_args(local_mode, num_processes):
+        if local_mode:
+            torchrun_args = ["--standalone", "--nnodes=1"]
+            processes_per_node = num_processes
+        else:
+            master_addr, master_port = (
+                os.environ["MASTER_ADDR"],
+                os.environ["MASTER_PORT"],
+            )
+            node_rank = os.environ["RANK"]
+            torchrun_args = [
+                f"--nnodes={num_processes}",
+                f"--node_rank={node_rank}",
+                f"--rdzv_endpoint={master_addr}:{master_port}",
+                "--rdzv_id=0",
+            ]
+            processes_per_node = 1
+        return torchrun_args, processes_per_node
+
+    @staticmethod
+    def _create_torchrun_command(
+            input_params: Dict[str, Any], train_path: str, *args: Any) -> List[str]:
+        local_mode = input_params["local_mode"]
+        num_processes = input_params["num_processes"]
+        deepspeed_config = input_params["deepspeed_config"]
+        
+        deepspeed_config_path = DeepspeedTorchDistributor._get_deepspeed_config_path(deepspeed_config)
+
+        torchrun_args, processes_per_node = DeepspeedTorchDistributor._get_torchrun_args(local_mode, num_processes)
+
+        args_string = list(map(str, args))
+        
+        command_to_run = [ 
+                          sys.executable,
+                          "-m",
+                          "torch.distributed.run",
+                          *torchrun_args,
+                          f"--nproc_per_node={processes_per_node}",
+                          train_path,
+                          *args_string,
+                          "-deepspeed",
+                          "--deepspeed_config",
+                          deepspeed_config_path
+                        ]
+        print(command_to_run)

Review Comment:
   This debug `print` can be removed.



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -1003,3 +1007,97 @@ def _get_spark_partition_data_loader(
         # if num_workers is zero, we cannot set `prefetch_factor` otherwise
         # torch will raise error.
         return DataLoader(dataset, batch_size, num_workers=num_workers)
+
+
+class DeepspeedTorchDistributor(TorchDistributor):
+    
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True, deepspeed_config = None):
+        super().__init__(num_processes, local_mode, use_gpu)
+        self.deepspeed_config = deepspeed_config 
+        self.ssl_conf = "deepspeed.spark.distributor.ignoreSsl"
+        self._validate_input_params()
+        self.input_params = self._create_input_params()
+
+    @staticmethod
+    def _get_deepspeed_config_path(deepspeed_config):
+        if isinstance(deepspeed_config, dict):
+            with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.json') as fil:
+                json.dump(deepspeed_config, fil)
+                deepspeed_config_path = fil.name
+        else:
+            deepspeed_config_path = deepspeed_config
+        return deepspeed_config_path
+
+
+    @staticmethod 
+    def _get_torchrun_args(local_mode, num_processes):
+        if local_mode:
+            torchrun_args = ["--standalone", "--nnodes=1"]
+            processes_per_node = num_processes
+        else:
+            master_addr, master_port = (
+                os.environ["MASTER_ADDR"],
+                os.environ["MASTER_PORT"],
+            )
+            node_rank = os.environ["RANK"]
+            torchrun_args = [
+                f"--nnodes={num_processes}",
+                f"--node_rank={node_rank}",
+                f"--rdzv_endpoint={master_addr}:{master_port}",
+                "--rdzv_id=0",
+            ]
+            processes_per_node = 1
+        return torchrun_args, processes_per_node
+
+    @staticmethod
+    def _create_torchrun_command(
+            input_params: Dict[str, Any], train_path: str, *args: Any) -> List[str]:
+        local_mode = input_params["local_mode"]
+        num_processes = input_params["num_processes"]
+        deepspeed_config = input_params["deepspeed_config"]
+        
+        deepspeed_config_path = DeepspeedTorchDistributor._get_deepspeed_config_path(deepspeed_config)
+
+        torchrun_args, processes_per_node = DeepspeedTorchDistributor._get_torchrun_args(local_mode, num_processes)
+
+        args_string = list(map(str, args))
+        
+        command_to_run = [ 
+                          sys.executable,
+                          "-m",
+                          "torch.distributed.run",
+                          *torchrun_args,
+                          f"--nproc_per_node={processes_per_node}",
+                          train_path,
+                          *args_string,
+                          "-deepspeed",
+                          "--deepspeed_config",
+                          deepspeed_config_path
+                        ]
+        print(command_to_run)
+        return command_to_run
+
+
+    @staticmethod
+    def _run_training_on_pytorch_file(input_params: Dict[str, Any], train_path: str, *args: Any, **kwargs : Any) -> None :
+        if kwargs:
+            raise ValueError("DeepspeedTorchDistributor with pytorch file doesn't support key-word type arguments")
+
+        log_streaming_client = input_params.get("log_streaming_client", None)
+        training_command = DeepspeedTorchDistributor._create_torchrun_command(input_params, train_path, *args)
+        DeepspeedTorchDistributor._execute_command(training_command, log_streaming_client=log_streaming_client)
+
+    def run(self, train_object: Union[Callable, str], *args : Any, **kwargs: Any) -> Optional[Any]:
+        # if the "train_object" is a string, then we assume it's a filepath. Otherwise, we assume it's a function
+        if isinstance(train_object, str):
+            framework_wrapper_fn = DeepspeedTorchDistributor._run_training_on_pytorch_file
+        else:
+            framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_file

Review Comment:
   What does this line do? Shouldn't this be `_run_training_on_pytorch_function`? But that wouldn't use DeepSpeed, right?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org