You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/10/19 12:44:01 UTC
[airflow] branch master updated: Add service_account to Google ML
Engine operator (#11619)
This is an automated email from the ASF dual-hosted git repository.
kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 2d854c3 Add service_account to Google ML Engine operator (#11619)
2d854c3 is described below
commit 2d854c3505ccad66e9a7d94267e51bed800433c2
Author: Daniel Burkhardt Cerigo <db...@gmail.com>
AuthorDate: Mon Oct 19 13:42:50 2020 +0100
Add service_account to Google ML Engine operator (#11619)
---
airflow/providers/google/cloud/operators/mlengine.py | 13 +++++++++++++
tests/providers/google/cloud/operators/test_mlengine.py | 2 ++
2 files changed, 15 insertions(+)
diff --git a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py
index 8dbc306..2fd34e7 100644
--- a/airflow/providers/google/cloud/operators/mlengine.py
+++ b/airflow/providers/google/cloud/operators/mlengine.py
@@ -1115,6 +1115,13 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
:param job_dir: A Google Cloud Storage path in which to store training
outputs and other data needed for training. (templated)
:type job_dir: str
+ :param service_account: Optional service account to use when running the training application.
+ (templated)
+ The account submitting the job must have the `iam.serviceAccounts.actAs` permission for the
+ specified service account, and the Google-managed Cloud ML Engine service account must have
+ the `roles/iam.serviceAccountAdmin` role for the specified service account.
+ If set to None or missing, the Google-managed Cloud ML Engine service account will be used.
+ :type service_account: str
:param project_id: The Google Cloud project name within which MLEngine training job should run.
If set to None or missing, the default project_id from the Google Cloud connection is used.
(templated)
@@ -1156,6 +1163,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
'_runtime_version',
'_python_version',
'_job_dir',
+ '_service_account',
'_impersonation_chain',
]
@@ -1176,6 +1184,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
runtime_version: Optional[str] = None,
python_version: Optional[str] = None,
job_dir: Optional[str] = None,
+ service_account: Optional[str] = None,
project_id: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
@@ -1197,6 +1206,7 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
self._runtime_version = runtime_version
self._python_version = python_version
self._job_dir = job_dir
+ self._service_account = service_account
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
self._mode = mode
@@ -1244,6 +1254,9 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
if self._job_dir:
training_request['trainingInput']['jobDir'] = self._job_dir
+ if self._service_account:
+ training_request['trainingInput']['serviceAccount'] = self._service_account
+
if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM":
training_request['trainingInput']['masterType'] = self._master_type
diff --git a/tests/providers/google/cloud/operators/test_mlengine.py b/tests/providers/google/cloud/operators/test_mlengine.py
index 99353c8..75e07cf 100644
--- a/tests/providers/google/cloud/operators/test_mlengine.py
+++ b/tests/providers/google/cloud/operators/test_mlengine.py
@@ -413,6 +413,7 @@ class TestMLEngineTrainingOperator(unittest.TestCase):
training_input['trainingInput']['runtimeVersion'] = '1.6'
training_input['trainingInput']['pythonVersion'] = '3.5'
training_input['trainingInput']['jobDir'] = 'gs://some-bucket/jobs/test_training'
+ training_input['trainingInput']['serviceAccount'] = 'test@serviceaccount.com'
success_response = self.TRAINING_INPUT.copy()
success_response['state'] = 'SUCCEEDED'
@@ -423,6 +424,7 @@ class TestMLEngineTrainingOperator(unittest.TestCase):
runtime_version='1.6',
python_version='3.5',
job_dir='gs://some-bucket/jobs/test_training',
+ service_account='test@serviceaccount.com',
**self.TRAINING_DEFAULT_ARGS,
)
training_op.execute(MagicMock())