You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2019/01/06 18:36:02 UTC

[GitHub] jgao54 closed pull request #4428: [AIRFLOW-3624] Add masterType parameter to MLEngineTrainingOperator

jgao54 closed pull request #4428: [AIRFLOW-3624] Add masterType parameter to MLEngineTrainingOperator
URL: https://github.com/apache/airflow/pull/4428
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below; GitHub would not otherwise display it:

diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py
index 8091ceefff..101c99d41b 100644
--- a/airflow/contrib/operators/mlengine_operator.py
+++ b/airflow/contrib/operators/mlengine_operator.py
@@ -473,6 +473,10 @@ class MLEngineTrainingOperator(BaseOperator):
     :param scale_tier: Resource tier for MLEngine training job. (templated)
     :type scale_tier: str
 
+    :param master_type: Cloud ML Engine machine name.
+        Must be set when scale_tier is CUSTOM. (templated)
+    :type master_type: str
+
     :param runtime_version: The Google Cloud ML runtime version to use for
         training. (templated)
     :type runtime_version: str
@@ -507,6 +511,7 @@ class MLEngineTrainingOperator(BaseOperator):
         '_training_args',
         '_region',
         '_scale_tier',
+        '_master_type',
         '_runtime_version',
         '_python_version',
         '_job_dir'
@@ -521,6 +526,7 @@ def __init__(self,
                  training_args,
                  region,
                  scale_tier=None,
+                 master_type=None,
                  runtime_version=None,
                  python_version=None,
                  job_dir=None,
@@ -537,6 +543,7 @@ def __init__(self,
         self._training_args = training_args
         self._region = region
         self._scale_tier = scale_tier
+        self._master_type = master_type
         self._runtime_version = runtime_version
         self._python_version = python_version
         self._job_dir = job_dir
@@ -560,6 +567,9 @@ def __init__(self,
                 'packages is required.')
         if not self._region:
             raise AirflowException('Google Compute Engine region is required.')
+        if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM" and not self._master_type:
+            raise AirflowException(
+                'master_type must be set when scale_tier is CUSTOM')
 
     def execute(self, context):
         job_id = _normalize_mlengine_job_id(self._job_id)
@@ -583,6 +593,9 @@ def execute(self, context):
         if self._job_dir:
             training_request['trainingInput']['jobDir'] = self._job_dir
 
+        if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM":
+            training_request['trainingInput']['masterType'] = self._master_type
+
         if self._mode == 'DRY_RUN':
             self.log.info('In dry_run mode.')
             self.log.info('MLEngine Training job request is: {}'.format(


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services