You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by sa...@apache.org on 2017/07/13 21:33:37 UTC
incubator-airflow git commit: [AIRFLOW-1401] Standardize cloud ml
operator arguments
Repository: incubator-airflow
Updated Branches:
refs/heads/master 9fd0beaac -> b6d363104
[AIRFLOW-1401] Standardize cloud ml operator arguments
Standardize on project_id, for consistency with the other cloud
operators and to better support default arguments.
This is one of multiple commits that will be required to resolve
AIRFLOW-1401.
Closes #2439 from peterjdolan/cloudml_project_id
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/b6d36310
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/b6d36310
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/b6d36310
Branch: refs/heads/master
Commit: b6d3631043ceb896dd1f8b7ade84751a284770b0
Parents: 9fd0bea
Author: Peter Dolan <pe...@google.com>
Authored: Thu Jul 13 14:33:32 2017 -0700
Committer: Alex Guziel <al...@airbnb.com>
Committed: Thu Jul 13 14:33:32 2017 -0700
----------------------------------------------------------------------
airflow/contrib/hooks/gcp_cloudml_hook.py | 44 +++++++++---------
airflow/contrib/operators/cloudml_operator.py | 47 ++++++++++----------
tests/contrib/hooks/test_gcp_cloudml_hook.py | 20 ++++-----
.../contrib/operators/test_cloudml_operator.py | 2 +-
4 files changed, 57 insertions(+), 56 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b6d36310/airflow/contrib/hooks/gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py b/airflow/contrib/hooks/gcp_cloudml_hook.py
index 6f634b2..e1ff155 100644
--- a/airflow/contrib/hooks/gcp_cloudml_hook.py
+++ b/airflow/contrib/hooks/gcp_cloudml_hook.py
@@ -62,13 +62,13 @@ class CloudMLHook(GoogleCloudBaseHook):
credentials = GoogleCredentials.get_application_default()
return build('ml', 'v1', credentials=credentials)
- def create_job(self, project_name, job, use_existing_job_fn=None):
+ def create_job(self, project_id, job, use_existing_job_fn=None):
"""
Launches a CloudML job and wait for it to reach a terminal state.
- :param project_name: The Google Cloud project name within which CloudML
+ :param project_id: The Google Cloud project id within which CloudML
job will be launched.
- :type project_name: string
+ :type project_id: string
:param job: CloudML Job object that should be provided to the CloudML
API, such as:
@@ -95,7 +95,7 @@ class CloudMLHook(GoogleCloudBaseHook):
:rtype: dict
"""
request = self._cloudml.projects().jobs().create(
- parent='projects/{}'.format(project_name),
+ parent='projects/{}'.format(project_id),
body=job)
job_id = job['jobId']
@@ -105,7 +105,7 @@ class CloudMLHook(GoogleCloudBaseHook):
# 409 means there is an existing job with the same job ID.
if e.resp.status == 409:
if use_existing_job_fn is not None:
- existing_job = self._get_job(project_name, job_id)
+ existing_job = self._get_job(project_id, job_id)
if not use_existing_job_fn(existing_job):
logging.error(
'Job with job_id {} already exist, but it does '
@@ -118,9 +118,9 @@ class CloudMLHook(GoogleCloudBaseHook):
else:
logging.error('Failed to create CloudML job: {}'.format(e))
raise
- return self._wait_for_job_done(project_name, job_id)
+ return self._wait_for_job_done(project_id, job_id)
- def _get_job(self, project_name, job_id):
+ def _get_job(self, project_id, job_id):
"""
Gets a CloudML job based on the job name.
@@ -130,7 +130,7 @@ class CloudMLHook(GoogleCloudBaseHook):
Raises:
apiclient.errors.HttpError: if HTTP error is returned from server
"""
- job_name = 'projects/{}/jobs/{}'.format(project_name, job_id)
+ job_name = 'projects/{}/jobs/{}'.format(project_id, job_id)
request = self._cloudml.projects().jobs().get(name=job_name)
while True:
try:
@@ -143,7 +143,7 @@ class CloudMLHook(GoogleCloudBaseHook):
logging.error('Failed to get CloudML job: {}'.format(e))
raise
- def _wait_for_job_done(self, project_name, job_id, interval=30):
+ def _wait_for_job_done(self, project_id, job_id, interval=30):
"""
Waits for the Job to reach a terminal state.
@@ -156,19 +156,19 @@ class CloudMLHook(GoogleCloudBaseHook):
"""
assert interval > 0
while True:
- job = self._get_job(project_name, job_id)
+ job = self._get_job(project_id, job_id)
if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
return job
time.sleep(interval)
- def create_version(self, project_name, model_name, version_spec):
+ def create_version(self, project_id, model_name, version_spec):
"""
Creates the Version on Cloud ML.
Returns the operation if the version was created successfully and
raises an error otherwise.
"""
- parent_name = 'projects/{}/models/{}'.format(project_name, model_name)
+ parent_name = 'projects/{}/models/{}'.format(project_id, model_name)
create_request = self._cloudml.projects().models().versions().create(
parent=parent_name, body=version_spec)
response = create_request.execute()
@@ -181,12 +181,12 @@ class CloudMLHook(GoogleCloudBaseHook):
is_done_func=lambda resp: resp.get('done', False),
is_error_func=lambda resp: resp.get('error', None) is not None)
- def set_default_version(self, project_name, model_name, version_name):
+ def set_default_version(self, project_id, model_name, version_name):
"""
Sets a version to be the default. Blocks until finished.
"""
full_version_name = 'projects/{}/models/{}/versions/{}'.format(
- project_name, model_name, version_name)
+ project_id, model_name, version_name)
request = self._cloudml.projects().models().versions().setDefault(
name=full_version_name, body={})
@@ -199,13 +199,13 @@ class CloudMLHook(GoogleCloudBaseHook):
logging.error('Something went wrong: {}'.format(e))
raise
- def list_versions(self, project_name, model_name):
+ def list_versions(self, project_id, model_name):
"""
Lists all available versions of a model. Blocks until finished.
"""
result = []
full_parent_name = 'projects/{}/models/{}'.format(
- project_name, model_name)
+ project_id, model_name)
request = self._cloudml.projects().models().versions().list(
parent=full_parent_name, pageSize=100)
@@ -223,12 +223,12 @@ class CloudMLHook(GoogleCloudBaseHook):
time.sleep(5)
return result
- def delete_version(self, project_name, model_name, version_name):
+ def delete_version(self, project_id, model_name, version_name):
"""
Deletes the given version of a model. Blocks until finished.
"""
full_name = 'projects/{}/models/{}/versions/{}'.format(
- project_name, model_name, version_name)
+ project_id, model_name, version_name)
delete_request = self._cloudml.projects().models().versions().delete(
name=full_name)
response = delete_request.execute()
@@ -241,24 +241,24 @@ class CloudMLHook(GoogleCloudBaseHook):
is_done_func=lambda resp: resp.get('done', False),
is_error_func=lambda resp: resp.get('error', None) is not None)
- def create_model(self, project_name, model):
+ def create_model(self, project_id, model):
"""
Create a Model. Blocks until finished.
"""
assert model['name'] is not None and model['name'] is not ''
- project = 'projects/{}'.format(project_name)
+ project = 'projects/{}'.format(project_id)
request = self._cloudml.projects().models().create(
parent=project, body=model)
return request.execute()
- def get_model(self, project_name, model_name):
+ def get_model(self, project_id, model_name):
"""
Gets a Model. Blocks until finished.
"""
assert model_name is not None and model_name is not ''
full_model_name = 'projects/{}/models/{}'.format(
- project_name, model_name)
+ project_id, model_name)
request = self._cloudml.projects().models().get(name=full_model_name)
try:
return request.execute()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b6d36310/airflow/contrib/operators/cloudml_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py
index 3ad6f5a..34b2e83 100644
--- a/airflow/contrib/operators/cloudml_operator.py
+++ b/airflow/contrib/operators/cloudml_operator.py
@@ -272,9 +272,9 @@ class CloudMLModelOperator(BaseOperator):
should contain the `name` of the model.
:type model: dict
- :param project_name: The Google Cloud project name to which CloudML
+ :param project_id: The Google Cloud project id to which CloudML
model belongs.
- :type project_name: string
+ :type project_id: string
:param gcp_conn_id: The connection ID to use when fetching connection info.
:type gcp_conn_id: string
@@ -291,12 +291,13 @@ class CloudMLModelOperator(BaseOperator):
template_fields = [
'_model',
+ '_model_name',
]
@apply_defaults
def __init__(self,
+ project_id,
model,
- project_name,
gcp_conn_id='google_cloud_default',
operation='create',
delegate_to=None,
@@ -307,15 +308,15 @@ class CloudMLModelOperator(BaseOperator):
self._operation = operation
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
- self._project_name = project_name
+ self._project_id = project_id
def execute(self, context):
hook = CloudMLHook(
gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
if self._operation == 'create':
- hook.create_model(self._project_name, self._model)
+ hook.create_model(self._project_id, self._model)
elif self._operation == 'get':
- hook.get_model(self._project_name, self._model['name'])
+ hook.get_model(self._project_id, self._model['name'])
else:
raise ValueError('Unknown operation: {}'.format(self._operation))
@@ -328,9 +329,9 @@ class CloudMLVersionOperator(BaseOperator):
belongs to.
:type model_name: string
- :param project_name: The Google Cloud project name to which CloudML
+ :param project_id: The Google Cloud project id to which CloudML
model belongs.
- :type project_name: string
+ :type project_id: string
:param version: A dictionary containing the information about the version.
If the `operation` is `create`, `version` should contain all the
@@ -376,8 +377,8 @@ class CloudMLVersionOperator(BaseOperator):
@apply_defaults
def __init__(self,
model_name,
- project_name,
- version=None,
+ project_id,
+ version,
gcp_conn_id='google_cloud_default',
operation='create',
delegate_to=None,
@@ -389,7 +390,7 @@ class CloudMLVersionOperator(BaseOperator):
self._version = version
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
- self._project_name = project_name
+ self._project_id = project_id
self._operation = operation
def execute(self, context):
@@ -398,16 +399,16 @@ class CloudMLVersionOperator(BaseOperator):
if self._operation == 'create':
assert self._version is not None
- return hook.create_version(self._project_name, self._model_name,
+ return hook.create_version(self._project_id, self._model_name,
self._version)
elif self._operation == 'set_default':
return hook.set_default_version(
- self._project_name, self._model_name,
+ self._project_id, self._model_name,
self._version['name'])
elif self._operation == 'list':
- return hook.list_versions(self._project_name, self._model_name)
+ return hook.list_versions(self._project_id, self._model_name)
elif self._operation == 'delete':
- return hook.delete_version(self._project_name, self._model_name,
+ return hook.delete_version(self._project_id, self._model_name,
self._version['name'])
else:
raise ValueError('Unknown operation: {}'.format(self._operation))
@@ -417,9 +418,9 @@ class CloudMLTrainingOperator(BaseOperator):
"""
Operator for launching a CloudML training job.
- :param project_name: The Google Cloud project name within which CloudML
+ :param project_id: The Google Cloud project id within which CloudML
training job should run. This field could be templated.
- :type project_name: string
+ :type project_id: string
:param job_id: A unique templated id for the submitted Google CloudML
training job.
@@ -461,7 +462,7 @@ class CloudMLTrainingOperator(BaseOperator):
"""
template_fields = [
- '_project_name',
+ '_project_id',
'_job_id',
'_package_uris',
'_training_python_module',
@@ -472,7 +473,7 @@ class CloudMLTrainingOperator(BaseOperator):
@apply_defaults
def __init__(self,
- project_name,
+ project_id,
job_id,
package_uris,
training_python_module,
@@ -485,7 +486,7 @@ class CloudMLTrainingOperator(BaseOperator):
*args,
**kwargs):
super(CloudMLTrainingOperator, self).__init__(*args, **kwargs)
- self._project_name = project_name
+ self._project_id = project_id
self._job_id = job_id
self._package_uris = package_uris
self._training_python_module = training_python_module
@@ -496,8 +497,8 @@ class CloudMLTrainingOperator(BaseOperator):
self._delegate_to = delegate_to
self._mode = mode
- if not self._project_name:
- raise AirflowException('Google Cloud project name is required.')
+ if not self._project_id:
+ raise AirflowException('Google Cloud project id is required.')
if not self._job_id:
raise AirflowException(
'An unique job id is required for Google CloudML training '
@@ -542,7 +543,7 @@ class CloudMLTrainingOperator(BaseOperator):
training_request['trainingInput']
try:
finished_training_job = hook.create_job(
- self._project_name, training_request, check_existing_job)
+ self._project_id, training_request, check_existing_job)
except errors.HttpError:
raise
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b6d36310/tests/contrib/hooks/test_gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py b/tests/contrib/hooks/test_gcp_cloudml_hook.py
index 53aba41..f56018d 100644
--- a/tests/contrib/hooks/test_gcp_cloudml_hook.py
+++ b/tests/contrib/hooks/test_gcp_cloudml_hook.py
@@ -121,7 +121,7 @@ class TestCloudMLHook(unittest.TestCase):
responses=[succeeded_response] * 2,
expected_requests=expected_requests) as cml_hook:
create_version_response = cml_hook.create_version(
- project_name=project, model_name=model_name,
+ project_id=project, model_name=model_name,
version_spec=version)
self.assertEquals(create_version_response, response_body)
@@ -147,7 +147,7 @@ class TestCloudMLHook(unittest.TestCase):
responses=[succeeded_response],
expected_requests=expected_requests) as cml_hook:
set_default_version_response = cml_hook.set_default_version(
- project_name=project, model_name=model_name,
+ project_id=project, model_name=model_name,
version_name=version)
self.assertEquals(set_default_version_response, response_body)
@@ -187,7 +187,7 @@ class TestCloudMLHook(unittest.TestCase):
responses=responses,
expected_requests=expected_requests) as cml_hook:
list_versions_response = cml_hook.list_versions(
- project_name=project, model_name=model_name)
+ project_id=project, model_name=model_name)
self.assertEquals(list_versions_response, versions)
@_SKIP_IF
@@ -220,7 +220,7 @@ class TestCloudMLHook(unittest.TestCase):
responses=[not_done_response, succeeded_response],
expected_requests=expected_requests) as cml_hook:
delete_version_response = cml_hook.delete_version(
- project_name=project, model_name=model_name,
+ project_id=project, model_name=model_name,
version_name=version)
self.assertEquals(delete_version_response, done_response_body)
@@ -245,7 +245,7 @@ class TestCloudMLHook(unittest.TestCase):
responses=[succeeded_response],
expected_requests=expected_requests) as cml_hook:
create_model_response = cml_hook.create_model(
- project_name=project, model=model)
+ project_id=project, model=model)
self.assertEquals(create_model_response, response_body)
@_SKIP_IF
@@ -266,7 +266,7 @@ class TestCloudMLHook(unittest.TestCase):
responses=[succeeded_response],
expected_requests=expected_requests) as cml_hook:
get_model_response = cml_hook.get_model(
- project_name=project, model_name=model_name)
+ project_id=project, model_name=model_name)
self.assertEquals(get_model_response, response_body)
@_SKIP_IF
@@ -302,7 +302,7 @@ class TestCloudMLHook(unittest.TestCase):
responses=responses,
expected_requests=expected_requests) as cml_hook:
create_job_response = cml_hook.create_job(
- project_name=project, job=my_job)
+ project_id=project, job=my_job)
self.assertEquals(create_job_response, my_job)
@_SKIP_IF
@@ -334,7 +334,7 @@ class TestCloudMLHook(unittest.TestCase):
responses=responses,
expected_requests=expected_requests) as cml_hook:
create_job_response = cml_hook.create_job(
- project_name=project, job=my_job)
+ project_id=project, job=my_job)
self.assertEquals(create_job_response, my_job)
@_SKIP_IF
@@ -386,7 +386,7 @@ class TestCloudMLHook(unittest.TestCase):
expected_requests=expected_requests) as cml_hook:
with self.assertRaises(errors.HttpError):
cml_hook.create_job(
- project_name=project, job=my_job,
+ project_id=project, job=my_job,
use_existing_job_fn=check_input)
my_job_response = ({'status': '200'}, my_job_response_body)
@@ -404,7 +404,7 @@ class TestCloudMLHook(unittest.TestCase):
responses=responses,
expected_requests=expected_requests) as cml_hook:
create_job_response = cml_hook.create_job(
- project_name=project, job=my_job,
+ project_id=project, job=my_job,
use_existing_job_fn=check_input)
self.assertEquals(create_job_response, my_job)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b6d36310/tests/contrib/operators/test_cloudml_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator.py b/tests/contrib/operators/test_cloudml_operator.py
index dc8c204..dc2366e 100644
--- a/tests/contrib/operators/test_cloudml_operator.py
+++ b/tests/contrib/operators/test_cloudml_operator.py
@@ -285,7 +285,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
class CloudMLTrainingOperatorTest(unittest.TestCase):
TRAINING_DEFAULT_ARGS = {
- 'project_name': 'test-project',
+ 'project_id': 'test-project',
'job_id': 'test_training',
'package_uris': ['gs://some-bucket/package1'],
'training_python_module': 'trainer',