You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/10/19 12:44:01 UTC

[airflow] branch master updated: Add service_account to Google ML Engine operator (#11619)

This is an automated email from the ASF dual-hosted git repository.

kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 2d854c3  Add service_account to Google ML Engine operator (#11619)
2d854c3 is described below

commit 2d854c3505ccad66e9a7d94267e51bed800433c2
Author: Daniel Burkhardt Cerigo <db...@gmail.com>
AuthorDate: Mon Oct 19 13:42:50 2020 +0100

    Add service_account to Google ML Engine operator (#11619)
---
 airflow/providers/google/cloud/operators/mlengine.py    | 13 +++++++++++++
 tests/providers/google/cloud/operators/test_mlengine.py |  2 ++
 2 files changed, 15 insertions(+)

diff --git a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py
index 8dbc306..2fd34e7 100644
--- a/airflow/providers/google/cloud/operators/mlengine.py
+++ b/airflow/providers/google/cloud/operators/mlengine.py
@@ -1115,6 +1115,13 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
     :param job_dir: A Google Cloud Storage path in which to store training
         outputs and other data needed for training. (templated)
     :type job_dir: str
+    :param service_account: Optional service account to use when running the training application.
+        (templated)
+        The caller must have the `iam.serviceAccounts.actAs` permission for the specified
+        service account. The Google-managed Cloud ML Engine service account must have the
+        `roles/iam.serviceAccountAdmin` role for the specified service account.
+        If set to None or missing, the Google-managed Cloud ML Engine service account will be used.
+    :type service_account: str
     :param project_id: The Google Cloud project name within which MLEngine training job should run.
         If set to None or missing, the default project_id from the Google Cloud connection is used.
         (templated)
@@ -1156,6 +1163,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
         '_runtime_version',
         '_python_version',
         '_job_dir',
+        '_service_account',
         '_impersonation_chain',
     ]
 
@@ -1176,6 +1184,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
         runtime_version: Optional[str] = None,
         python_version: Optional[str] = None,
         job_dir: Optional[str] = None,
+        service_account: Optional[str] = None,
         project_id: Optional[str] = None,
         gcp_conn_id: str = 'google_cloud_default',
         delegate_to: Optional[str] = None,
@@ -1197,6 +1206,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
         self._runtime_version = runtime_version
         self._python_version = python_version
         self._job_dir = job_dir
+        self._service_account = service_account
         self._gcp_conn_id = gcp_conn_id
         self._delegate_to = delegate_to
         self._mode = mode
@@ -1244,6 +1254,9 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
         if self._job_dir:
             training_request['trainingInput']['jobDir'] = self._job_dir
 
+        if self._service_account:
+            training_request['trainingInput']['serviceAccount'] = self._service_account
+
         if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM":
             training_request['trainingInput']['masterType'] = self._master_type
 
diff --git a/tests/providers/google/cloud/operators/test_mlengine.py b/tests/providers/google/cloud/operators/test_mlengine.py
index 99353c8..75e07cf 100644
--- a/tests/providers/google/cloud/operators/test_mlengine.py
+++ b/tests/providers/google/cloud/operators/test_mlengine.py
@@ -413,6 +413,7 @@ class TestMLEngineTrainingOperator(unittest.TestCase):
         training_input['trainingInput']['runtimeVersion'] = '1.6'
         training_input['trainingInput']['pythonVersion'] = '3.5'
         training_input['trainingInput']['jobDir'] = 'gs://some-bucket/jobs/test_training'
+        training_input['trainingInput']['serviceAccount'] = 'test@serviceaccount.com'
 
         success_response = self.TRAINING_INPUT.copy()
         success_response['state'] = 'SUCCEEDED'
@@ -423,6 +424,7 @@ class TestMLEngineTrainingOperator(unittest.TestCase):
             runtime_version='1.6',
             python_version='3.5',
             job_dir='gs://some-bucket/jobs/test_training',
+            service_account='test@serviceaccount.com',
             **self.TRAINING_DEFAULT_ARGS,
         )
         training_op.execute(MagicMock())