Posted to commits@airflow.apache.org by cr...@apache.org on 2017/09/06 16:51:30 UTC
[3/3] incubator-airflow git commit: [AIRFLOW-1567] Renamed cloudml hook and operator to mlengine
[AIRFLOW-1567] Renamed cloudml hook and operator to mlengine
Closes #2567 from yk5/cmle
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/af91e2ac
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/af91e2ac
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/af91e2ac
Branch: refs/heads/master
Commit: af91e2ac0636685c0c1c25ddeba97f78b7009b88
Parents: 86063ba
Author: Younghee Kwon <yo...@google.com>
Authored: Wed Sep 6 09:51:17 2017 -0700
Committer: Chris Riccomini <cr...@apache.org>
Committed: Wed Sep 6 09:51:17 2017 -0700
----------------------------------------------------------------------
airflow/contrib/hooks/gcp_cloudml_hook.py | 269 ---------
airflow/contrib/hooks/gcp_mlengine_hook.py | 269 +++++++++
airflow/contrib/operators/cloudml_operator.py | 565 -------------------
.../contrib/operators/cloudml_operator_utils.py | 245 --------
.../operators/cloudml_prediction_summary.py | 177 ------
airflow/contrib/operators/mlengine_operator.py | 564 ++++++++++++++++++
.../operators/mlengine_operator_utils.py | 245 ++++++++
.../operators/mlengine_prediction_summary.py | 177 ++++++
tests/contrib/hooks/test_gcp_cloudml_hook.py | 413 --------------
tests/contrib/hooks/test_gcp_mlengine_hook.py | 413 ++++++++++++++
.../contrib/operators/test_cloudml_operator.py | 373 ------------
.../operators/test_cloudml_operator_utils.py | 183 ------
.../contrib/operators/test_mlengine_operator.py | 373 ++++++++++++
.../operators/test_mlengine_operator_utils.py | 183 ++++++
14 files changed, 2224 insertions(+), 2225 deletions(-)
----------------------------------------------------------------------
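Note for downstream users: per the diffstat above, this commit renames the modules outright rather than keeping the old names as aliases, so DAGs importing the old paths must be updated in the same release. A minimal sketch of the import change (the downstream DAG itself is hypothetical):

    # Before this commit (hypothetical downstream DAG):
    # from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook
    # from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator

    # After this commit:
    from airflow.contrib.hooks.gcp_mlengine_hook import MLEngineHook
    from airflow.contrib.operators.mlengine_operator import MLEngineBatchPredictionOperator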
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/hooks/gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py b/airflow/contrib/hooks/gcp_cloudml_hook.py
deleted file mode 100644
index e1ff155..0000000
--- a/airflow/contrib/hooks/gcp_cloudml_hook.py
+++ /dev/null
@@ -1,269 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-import random
-import time
-from airflow import settings
-from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
-from apiclient.discovery import build
-from apiclient import errors
-from oauth2client.client import GoogleCredentials
-
-logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL)
-
-
-def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
-
- for i in range(0, max_n):
- try:
- response = request.execute()
- if is_error_func(response):
- raise ValueError(
- 'The response contained an error: {}'.format(response))
- elif is_done_func(response):
- logging.info('Operation is done: {}'.format(response))
- return response
- else:
- time.sleep((2**i) + (random.randint(0, 1000) / 1000))
- except errors.HttpError as e:
- if e.resp.status != 429:
- logging.info(
- 'Something went wrong. Not retrying: {}'.format(e))
- raise
- else:
- time.sleep((2**i) + (random.randint(0, 1000) / 1000))
-
-
-class CloudMLHook(GoogleCloudBaseHook):
-
- def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None):
- super(CloudMLHook, self).__init__(gcp_conn_id, delegate_to)
- self._cloudml = self.get_conn()
-
- def get_conn(self):
- """
- Returns a Google CloudML service object.
- """
- credentials = GoogleCredentials.get_application_default()
- return build('ml', 'v1', credentials=credentials)
-
- def create_job(self, project_id, job, use_existing_job_fn=None):
- """
- Launches a CloudML job and waits for it to reach a terminal state.
-
- :param project_id: The Google Cloud project id within which the
- CloudML job will be launched.
- :type project_id: string
-
- :param job: CloudML Job object that should be provided to the CloudML
- API, such as:
- {
- 'jobId': 'my_job_id',
- 'trainingInput': {
- 'scaleTier': 'STANDARD_1',
- ...
- }
- }
- :type job: dict
-
- :param use_existing_job_fn: In case a CloudML job with the same
- job_id already exists, this function (if provided) decides whether
- to reuse the existing job: continue waiting for it to finish and
- return the job object. It should accept a CloudML job object and
- return a boolean indicating whether it is OK to reuse the existing
- job. If 'use_existing_job_fn' is not provided, the existing CloudML
- job is reused by default.
- :type use_existing_job_fn: function
-
- :return: The CloudML job object if the job successfully reaches a
- terminal state (which might be the FAILED or CANCELLED state).
- :rtype: dict
- """
- request = self._cloudml.projects().jobs().create(
- parent='projects/{}'.format(project_id),
- body=job)
- job_id = job['jobId']
-
- try:
- request.execute()
- except errors.HttpError as e:
- # 409 means there is an existing job with the same job ID.
- if e.resp.status == 409:
- if use_existing_job_fn is not None:
- existing_job = self._get_job(project_id, job_id)
- if not use_existing_job_fn(existing_job):
- logging.error(
- 'Job with job_id {} already exists, but it does '
- 'not match our expectation: {}'.format(
- job_id, existing_job))
- raise
- logging.info(
- 'Job with job_id {} already exists. Will wait for it to '
- 'finish.'.format(job_id))
- else:
- logging.error('Failed to create CloudML job: {}'.format(e))
- raise
- return self._wait_for_job_done(project_id, job_id)
-
- def _get_job(self, project_id, job_id):
- """
- Gets a CloudML job based on the job name.
-
- :return: CloudML job object if successful.
- :rtype: dict
-
- Raises:
- apiclient.errors.HttpError: if HTTP error is returned from server
- """
- job_name = 'projects/{}/jobs/{}'.format(project_id, job_id)
- request = self._cloudml.projects().jobs().get(name=job_name)
- while True:
- try:
- return request.execute()
- except errors.HttpError as e:
- if e.resp.status == 429:
- # polling after 30 seconds when quota failure occurs
- time.sleep(30)
- else:
- logging.error('Failed to get CloudML job: {}'.format(e))
- raise
-
- def _wait_for_job_done(self, project_id, job_id, interval=30):
- """
- Waits for the Job to reach a terminal state.
-
- This method will periodically check the job state until the job
- reaches a terminal state.
-
- Raises:
- apiclient.errors.HttpError: if HTTP error is returned when getting
- the job
- """
- assert interval > 0
- while True:
- job = self._get_job(project_id, job_id)
- if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
- return job
- time.sleep(interval)
-
- def create_version(self, project_id, model_name, version_spec):
- """
- Creates the Version on Cloud ML.
-
- Returns the operation if the version was created successfully and
- raises an error otherwise.
- """
- parent_name = 'projects/{}/models/{}'.format(project_id, model_name)
- create_request = self._cloudml.projects().models().versions().create(
- parent=parent_name, body=version_spec)
- response = create_request.execute()
- get_request = self._cloudml.projects().operations().get(
- name=response['name'])
-
- return _poll_with_exponential_delay(
- request=get_request,
- max_n=9,
- is_done_func=lambda resp: resp.get('done', False),
- is_error_func=lambda resp: resp.get('error', None) is not None)
-
- def set_default_version(self, project_id, model_name, version_name):
- """
- Sets a version to be the default. Blocks until finished.
- """
- full_version_name = 'projects/{}/models/{}/versions/{}'.format(
- project_id, model_name, version_name)
- request = self._cloudml.projects().models().versions().setDefault(
- name=full_version_name, body={})
-
- try:
- response = request.execute()
- logging.info(
- 'Successfully set version: {} to default'.format(response))
- return response
- except errors.HttpError as e:
- logging.error('Something went wrong: {}'.format(e))
- raise
-
- def list_versions(self, project_id, model_name):
- """
- Lists all available versions of a model. Blocks until finished.
- """
- result = []
- full_parent_name = 'projects/{}/models/{}'.format(
- project_id, model_name)
- request = self._cloudml.projects().models().versions().list(
- parent=full_parent_name, pageSize=100)
-
- response = request.execute()
- next_page_token = response.get('nextPageToken', None)
- result.extend(response.get('versions', []))
- while next_page_token is not None:
- next_request = self._cloudml.projects().models().versions().list(
- parent=full_parent_name,
- pageToken=next_page_token,
- pageSize=100)
- response = next_request.execute()
- next_page_token = response.get('nextPageToken', None)
- result.extend(response.get('versions', []))
- time.sleep(5)
- return result
-
- def delete_version(self, project_id, model_name, version_name):
- """
- Deletes the given version of a model. Blocks until finished.
- """
- full_name = 'projects/{}/models/{}/versions/{}'.format(
- project_id, model_name, version_name)
- delete_request = self._cloudml.projects().models().versions().delete(
- name=full_name)
- response = delete_request.execute()
- get_request = self._cloudml.projects().operations().get(
- name=response['name'])
-
- return _poll_with_exponential_delay(
- request=get_request,
- max_n=9,
- is_done_func=lambda resp: resp.get('done', False),
- is_error_func=lambda resp: resp.get('error', None) is not None)
-
- def create_model(self, project_id, model):
- """
- Create a Model. Blocks until finished.
- """
- assert model['name'] is not None and model['name'] != ''
- project = 'projects/{}'.format(project_id)
-
- request = self._cloudml.projects().models().create(
- parent=project, body=model)
- return request.execute()
-
- def get_model(self, project_id, model_name):
- """
- Gets a Model. Blocks until finished.
- """
- assert model_name is not None and model_name != ''
- full_model_name = 'projects/{}/models/{}'.format(
- project_id, model_name)
- request = self._cloudml.projects().models().get(name=full_model_name)
- try:
- return request.execute()
- except errors.HttpError as e:
- if e.resp.status == 404:
- logging.error('Model was not found: {}'.format(e))
- return None
- raise
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/hooks/gcp_mlengine_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_mlengine_hook.py b/airflow/contrib/hooks/gcp_mlengine_hook.py
new file mode 100644
index 0000000..47d9700
--- /dev/null
+++ b/airflow/contrib/hooks/gcp_mlengine_hook.py
@@ -0,0 +1,269 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import random
+import time
+from airflow import settings
+from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
+from apiclient.discovery import build
+from apiclient import errors
+from oauth2client.client import GoogleCredentials
+
+logging.getLogger('GoogleCloudMLEngine').setLevel(settings.LOGGING_LEVEL)
+
+
+def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
+
+ for i in range(0, max_n):
+ try:
+ response = request.execute()
+ if is_error_func(response):
+ raise ValueError(
+ 'The response contained an error: {}'.format(response))
+ elif is_done_func(response):
+ logging.info('Operation is done: {}'.format(response))
+ return response
+ else:
+ time.sleep((2**i) + (random.randint(0, 1000) / 1000))
+ except errors.HttpError as e:
+ if e.resp.status != 429:
+ logging.info(
+ 'Something went wrong. Not retrying: {}'.format(e))
+ raise
+ else:
+ time.sleep((2**i) + (random.randint(0, 1000) / 1000))
+
+
+class MLEngineHook(GoogleCloudBaseHook):
+
+ def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None):
+ super(MLEngineHook, self).__init__(gcp_conn_id, delegate_to)
+ self._mlengine = self.get_conn()
+
+ def get_conn(self):
+ """
+ Returns a Google MLEngine service object.
+ """
+ credentials = GoogleCredentials.get_application_default()
+ return build('ml', 'v1', credentials=credentials)
+
+ def create_job(self, project_id, job, use_existing_job_fn=None):
+ """
+ Launches an MLEngine job and waits for it to reach a terminal state.
+
+ :param project_id: The Google Cloud project id within which the
+ MLEngine job will be launched.
+ :type project_id: string
+
+ :param job: MLEngine Job object that should be provided to the MLEngine
+ API, such as:
+ {
+ 'jobId': 'my_job_id',
+ 'trainingInput': {
+ 'scaleTier': 'STANDARD_1',
+ ...
+ }
+ }
+ :type job: dict
+
+ :param use_existing_job_fn: In case an MLEngine job with the same
+ job_id already exists, this function (if provided) decides whether
+ to reuse the existing job: continue waiting for it to finish and
+ return the job object. It should accept an MLEngine job object and
+ return a boolean indicating whether it is OK to reuse the existing
+ job. If 'use_existing_job_fn' is not provided, the existing MLEngine
+ job is reused by default.
+ :type use_existing_job_fn: function
+
+ :return: The MLEngine job object if the job successfully reaches a
+ terminal state (which might be the FAILED or CANCELLED state).
+ :rtype: dict
+ """
+ request = self._mlengine.projects().jobs().create(
+ parent='projects/{}'.format(project_id),
+ body=job)
+ job_id = job['jobId']
+
+ try:
+ request.execute()
+ except errors.HttpError as e:
+ # 409 means there is an existing job with the same job ID.
+ if e.resp.status == 409:
+ if use_existing_job_fn is not None:
+ existing_job = self._get_job(project_id, job_id)
+ if not use_existing_job_fn(existing_job):
+ logging.error(
+ 'Job with job_id {} already exists, but it does '
+ 'not match our expectation: {}'.format(
+ job_id, existing_job))
+ raise
+ logging.info(
+ 'Job with job_id {} already exists. Will wait for it to '
+ 'finish.'.format(job_id))
+ else:
+ logging.error('Failed to create MLEngine job: {}'.format(e))
+ raise
+ return self._wait_for_job_done(project_id, job_id)
+
+ def _get_job(self, project_id, job_id):
+ """
+ Gets an MLEngine job based on the job name.
+
+ :return: MLEngine job object if successful.
+ :rtype: dict
+
+ Raises:
+ apiclient.errors.HttpError: if HTTP error is returned from server
+ """
+ job_name = 'projects/{}/jobs/{}'.format(project_id, job_id)
+ request = self._mlengine.projects().jobs().get(name=job_name)
+ while True:
+ try:
+ return request.execute()
+ except errors.HttpError as e:
+ if e.resp.status == 429:
+ # polling after 30 seconds when quota failure occurs
+ time.sleep(30)
+ else:
+ logging.error('Failed to get MLEngine job: {}'.format(e))
+ raise
+
+ def _wait_for_job_done(self, project_id, job_id, interval=30):
+ """
+ Waits for the Job to reach a terminal state.
+
+ This method will periodically check the job state until the job
+ reaches a terminal state.
+
+ Raises:
+ apiclient.errors.HttpError: if HTTP error is returned when getting
+ the job
+ """
+ assert interval > 0
+ while True:
+ job = self._get_job(project_id, job_id)
+ if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
+ return job
+ time.sleep(interval)
+
+ def create_version(self, project_id, model_name, version_spec):
+ """
+ Creates the Version on Google Cloud ML Engine.
+
+ Returns the operation if the version was created successfully and
+ raises an error otherwise.
+ """
+ parent_name = 'projects/{}/models/{}'.format(project_id, model_name)
+ create_request = self._mlengine.projects().models().versions().create(
+ parent=parent_name, body=version_spec)
+ response = create_request.execute()
+ get_request = self._mlengine.projects().operations().get(
+ name=response['name'])
+
+ return _poll_with_exponential_delay(
+ request=get_request,
+ max_n=9,
+ is_done_func=lambda resp: resp.get('done', False),
+ is_error_func=lambda resp: resp.get('error', None) is not None)
+
+ def set_default_version(self, project_id, model_name, version_name):
+ """
+ Sets a version to be the default. Blocks until finished.
+ """
+ full_version_name = 'projects/{}/models/{}/versions/{}'.format(
+ project_id, model_name, version_name)
+ request = self._mlengine.projects().models().versions().setDefault(
+ name=full_version_name, body={})
+
+ try:
+ response = request.execute()
+ logging.info(
+ 'Successfully set version: {} to default'.format(response))
+ return response
+ except errors.HttpError as e:
+ logging.error('Something went wrong: {}'.format(e))
+ raise
+
+ def list_versions(self, project_id, model_name):
+ """
+ Lists all available versions of a model. Blocks until finished.
+ """
+ result = []
+ full_parent_name = 'projects/{}/models/{}'.format(
+ project_id, model_name)
+ request = self._mlengine.projects().models().versions().list(
+ parent=full_parent_name, pageSize=100)
+
+ response = request.execute()
+ next_page_token = response.get('nextPageToken', None)
+ result.extend(response.get('versions', []))
+ while next_page_token is not None:
+ next_request = self._mlengine.projects().models().versions().list(
+ parent=full_parent_name,
+ pageToken=next_page_token,
+ pageSize=100)
+ response = next_request.execute()
+ next_page_token = response.get('nextPageToken', None)
+ result.extend(response.get('versions', []))
+ time.sleep(5)
+ return result
+
+ def delete_version(self, project_id, model_name, version_name):
+ """
+ Deletes the given version of a model. Blocks until finished.
+ """
+ full_name = 'projects/{}/models/{}/versions/{}'.format(
+ project_id, model_name, version_name)
+ delete_request = self._mlengine.projects().models().versions().delete(
+ name=full_name)
+ response = delete_request.execute()
+ get_request = self._mlengine.projects().operations().get(
+ name=response['name'])
+
+ return _poll_with_exponential_delay(
+ request=get_request,
+ max_n=9,
+ is_done_func=lambda resp: resp.get('done', False),
+ is_error_func=lambda resp: resp.get('error', None) is not None)
+
+ def create_model(self, project_id, model):
+ """
+ Create a Model. Blocks until finished.
+ """
+ assert model['name'] is not None and model['name'] != ''
+ project = 'projects/{}'.format(project_id)
+
+ request = self._mlengine.projects().models().create(
+ parent=project, body=model)
+ return request.execute()
+
+ def get_model(self, project_id, model_name):
+ """
+ Gets a Model. Blocks until finished.
+ """
+ assert model_name is not None and model_name != ''
+ full_model_name = 'projects/{}/models/{}'.format(
+ project_id, model_name)
+ request = self._mlengine.projects().models().get(name=full_model_name)
+ try:
+ return request.execute()
+ except errors.HttpError as e:
+ if e.resp.status == 404:
+ logging.error('Model was not found: {}'.format(e))
+ return None
+ raise
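For reference, a minimal sketch of driving the renamed hook directly, following the create_job docstring above (the project id, job id, and trainingInput values are hypothetical placeholders):

    from airflow.contrib.hooks.gcp_mlengine_hook import MLEngineHook

    hook = MLEngineHook(gcp_conn_id='google_cloud_default')
    training_job = {
        'jobId': 'my_job_id',                    # hypothetical
        'trainingInput': {
            'scaleTier': 'STANDARD_1',
            'packageUris': ['gs://my-bucket/trainer-0.1.tar.gz'],  # hypothetical
            'pythonModule': 'trainer.task',      # hypothetical
            'region': 'us-central1',
        },
    }
    # Blocks until the job reaches a terminal state
    # (SUCCEEDED, FAILED, or CANCELLED).
    finished_job = hook.create_job('my-project-id', training_job)
    print(finished_job['state'])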
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/cloudml_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py
deleted file mode 100644
index 6bdd516..0000000
--- a/airflow/contrib/operators/cloudml_operator.py
+++ /dev/null
@@ -1,565 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the 'License'); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import re
-
-from airflow import settings
-from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook
-from airflow.exceptions import AirflowException
-from airflow.operators import BaseOperator
-from airflow.utils.decorators import apply_defaults
-from apiclient import errors
-
-
-logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL)
-
-
-def _create_prediction_input(project_id,
- region,
- data_format,
- input_paths,
- output_path,
- model_name=None,
- version_name=None,
- uri=None,
- max_worker_count=None,
- runtime_version=None):
- """
- Create the batch prediction input from the given parameters.
-
- Args:
- A subset of arguments documented in __init__ method of class
- CloudMLBatchPredictionOperator
-
- Returns:
- A dictionary representing the predictionInput object as documented
- in https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs.
-
- Raises:
- ValueError: if a unique model/version origin cannot be determined.
- """
-
- prediction_input = {
- 'dataFormat': data_format,
- 'inputPaths': input_paths,
- 'outputPath': output_path,
- 'region': region
- }
-
- if uri:
- if model_name or version_name:
- logging.error(
- 'Ambiguous model origin: Both uri and model/version name are '
- 'provided.')
- raise ValueError('Ambiguous model origin.')
- prediction_input['uri'] = uri
- elif model_name:
- origin_name = 'projects/{}/models/{}'.format(project_id, model_name)
- if not version_name:
- prediction_input['modelName'] = origin_name
- else:
- prediction_input['versionName'] = \
- origin_name + '/versions/{}'.format(version_name)
- else:
- logging.error(
- 'Missing model origin: Batch prediction expects a model, '
- 'a model & version combination, or a URI to savedModel.')
- raise ValueError('Missing model origin.')
-
- if max_worker_count:
- prediction_input['maxWorkerCount'] = max_worker_count
- if runtime_version:
- prediction_input['runtimeVersion'] = runtime_version
-
- return prediction_input
-
-
-def _normalize_cloudml_job_id(job_id):
- """
- Replaces invalid CloudML job_id characters with '_'.
-
- This also adds a leading 'z' in case job_id starts with an invalid
- character.
-
- Args:
- job_id: A job_id str that may have invalid characters.
-
- Returns:
- A valid job_id representation.
- """
- match = re.search(r'\d', job_id)
- if match and match.start() == 0:
- job_id = 'z_{}'.format(job_id)
- return re.sub('[^0-9a-zA-Z]+', '_', job_id)
-
-
-class CloudMLBatchPredictionOperator(BaseOperator):
- """
- Start a Cloud ML prediction job.
-
- NOTE: For model origin, users should consider exactly one from the
- three options below:
- 1. Populate 'uri' field only, which should be a GCS location that
- points to a tensorflow savedModel directory.
- 2. Populate 'model_name' field only, which refers to an existing
- model, and the default version of the model will be used.
- 3. Populate both 'model_name' and 'version_name' fields, which
- refers to a specific version of a specific model.
-
- In options 2 and 3, both model and version name should contain the
- minimal identifier. For instance, call
- CloudMLBatchPredictionOperator(
- ...,
- model_name='my_model',
- version_name='my_version',
- ...)
- if the desired model version is
- "projects/my_project/models/my_model/versions/my_version".
-
-
- :param project_id: The Google Cloud project name where the
- prediction job is submitted.
- :type project_id: string
-
- :param job_id: A unique id for the prediction job on Google Cloud
- ML Engine.
- :type job_id: string
-
- :param data_format: The format of the input data.
- It will default to 'DATA_FORMAT_UNSPECIFIED' if it is not provided
- or is not one of ["TEXT", "TF_RECORD", "TF_RECORD_GZIP"].
- :type data_format: string
-
- :param input_paths: A list of GCS paths of input data for batch
- prediction. The wildcard operator * is accepted, but only at the end.
- :type input_paths: list of string
-
- :param output_path: The GCS path where the prediction results are
- written to.
- :type output_path: string
-
- :param region: The Google Compute Engine region to run the
- prediction job in.
- :type region: string
-
- :param model_name: The Google Cloud ML model to use for prediction.
- If version_name is not provided, the default version of this
- model will be used.
- Should not be None if version_name is provided.
- Should be None if uri is provided.
- :type model_name: string
-
- :param version_name: The Google Cloud ML model version to use for
- prediction.
- Should be None if uri is provided.
- :type version_name: string
-
- :param uri: The GCS path of the saved model to use for prediction.
- Should be None if model_name is provided.
- It should be a GCS path pointing to a tensorflow SavedModel.
- :type uri: string
-
- :param max_worker_count: The maximum number of workers to be used
- for parallel processing. Defaults to 10 if not specified.
- :type max_worker_count: int
-
- :param runtime_version: The Google Cloud ML runtime version to use
- for batch prediction.
- :type runtime_version: string
-
- :param gcp_conn_id: The connection ID used for connection to Google
- Cloud Platform.
- :type gcp_conn_id: string
-
- :param delegate_to: The account to impersonate, if any.
- For this to work, the service account making the request must
- have domain-wide delegation enabled.
- :type delegate_to: string
-
- Raises:
- ValueError: if a unique model/version origin cannot be determined.
- """
-
- template_fields = [
- "prediction_job_request",
- ]
-
- @apply_defaults
- def __init__(self,
- project_id,
- job_id,
- region,
- data_format,
- input_paths,
- output_path,
- model_name=None,
- version_name=None,
- uri=None,
- max_worker_count=None,
- runtime_version=None,
- gcp_conn_id='google_cloud_default',
- delegate_to=None,
- *args,
- **kwargs):
- super(CloudMLBatchPredictionOperator, self).__init__(*args, **kwargs)
-
- self.project_id = project_id
- self.gcp_conn_id = gcp_conn_id
- self.delegate_to = delegate_to
-
- try:
- prediction_input = _create_prediction_input(
- project_id, region, data_format, input_paths, output_path,
- model_name, version_name, uri, max_worker_count,
- runtime_version)
- except ValueError as e:
- logging.error(
- 'Cannot create batch prediction job request due to: {}'
- .format(str(e)))
- raise
-
- self.prediction_job_request = {
- 'jobId': _normalize_cloudml_job_id(job_id),
- 'predictionInput': prediction_input
- }
-
- def execute(self, context):
- hook = CloudMLHook(self.gcp_conn_id, self.delegate_to)
-
- def check_existing_job(existing_job):
- return existing_job.get('predictionInput', None) == \
- self.prediction_job_request['predictionInput']
- try:
- finished_prediction_job = hook.create_job(
- self.project_id,
- self.prediction_job_request,
- check_existing_job)
- except errors.HttpError:
- raise
-
- if finished_prediction_job['state'] != 'SUCCEEDED':
- logging.error(
- 'Batch prediction job failed: %s',
- str(finished_prediction_job))
- raise RuntimeError(finished_prediction_job['errorMessage'])
-
- return finished_prediction_job['predictionOutput']
-
-
-class CloudMLModelOperator(BaseOperator):
- """
- Operator for managing a Google Cloud ML model.
-
- :param model: A dictionary containing the information about the model.
- If the `operation` is `create`, then the `model` parameter should
- contain all the information about this model such as `name`.
-
- If the `operation` is `get`, the `model` parameter
- should contain the `name` of the model.
- :type model: dict
-
- :param project_id: The Google Cloud project name to which the
- CloudML model belongs.
- :type project_id: string
-
- :param gcp_conn_id: The connection ID to use when fetching connection info.
- :type gcp_conn_id: string
-
- :param operation: The operation to perform. Available operations are:
- 'create': Creates a new model as provided by the `model` parameter.
- 'get': Gets a particular model where the name is specified in `model`.
-
- :param delegate_to: The account to impersonate, if any.
- For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :type delegate_to: string
- """
-
- template_fields = [
- '_model',
- ]
-
- @apply_defaults
- def __init__(self,
- project_id,
- model,
- gcp_conn_id='google_cloud_default',
- operation='create',
- delegate_to=None,
- *args,
- **kwargs):
- super(CloudMLModelOperator, self).__init__(*args, **kwargs)
- self._model = model
- self._operation = operation
- self._gcp_conn_id = gcp_conn_id
- self._delegate_to = delegate_to
- self._project_id = project_id
-
- def execute(self, context):
- hook = CloudMLHook(
- gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
- if self._operation == 'create':
- hook.create_model(self._project_id, self._model)
- elif self._operation == 'get':
- hook.get_model(self._project_id, self._model['name'])
- else:
- raise ValueError('Unknown operation: {}'.format(self._operation))
-
-
-class CloudMLVersionOperator(BaseOperator):
- """
- Operator for managing a Google Cloud ML version.
-
- :param model_name: The name of the Google Cloud ML model that the version
- belongs to.
- :type model_name: string
-
- :param project_id: The Google Cloud project name to which the
- CloudML model belongs.
- :type project_id: string
-
- :param version: A dictionary containing the information about the version.
- If the `operation` is `create`, `version` should contain all the
- information about this version such as name, and deploymentUrl.
- If the `operation` is `get` or `delete`, the `version` parameter
- should contain the `name` of the version.
- If it is None, the only `operation` possible would be `list`.
- :type version: dict
-
- :param version_name: A name to use for the version being operated upon. If
- not None and the `version` argument is None or does not have a value for
- the `name` key, then this will be populated in the payload for the
- `name` key.
- :type version_name: string
-
- :param gcp_conn_id: The connection ID to use when fetching connection info.
- :type gcp_conn_id: string
-
- :param operation: The operation to perform. Available operations are:
- 'create': Creates a new version in the model specified by `model_name`,
- in which case the `version` parameter should contain all the
- information to create that version
- (e.g. `name`, `deploymentUrl`).
- 'get': Gets full information of a particular version in the model
- specified by `model_name`.
- The name of the version should be specified in the `version`
- parameter.
-
- 'list': Lists all available versions of the model specified
- by `model_name`.
-
- 'delete': Deletes the version specified in `version` parameter from the
- model specified by `model_name`.
- The name of the version should be specified in the `version`
- parameter.
- :type operation: string
-
- :param delegate_to: The account to impersonate, if any.
- For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :type delegate_to: string
- """
-
- template_fields = [
- '_model_name',
- '_version',
- '_version_name',
- ]
-
- @apply_defaults
- def __init__(self,
- model_name,
- project_id,
- version=None,
- version_name=None,
- gcp_conn_id='google_cloud_default',
- operation='create',
- delegate_to=None,
- *args,
- **kwargs):
-
- super(CloudMLVersionOperator, self).__init__(*args, **kwargs)
- self._model_name = model_name
- self._version = version or {}
- self._version_name = version_name
- self._gcp_conn_id = gcp_conn_id
- self._delegate_to = delegate_to
- self._project_id = project_id
- self._operation = operation
-
- def execute(self, context):
- if 'name' not in self._version:
- self._version['name'] = self._version_name
-
- hook = CloudMLHook(
- gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
-
- if self._operation == 'create':
- assert self._version is not None
- return hook.create_version(self._project_id, self._model_name,
- self._version)
- elif self._operation == 'set_default':
- return hook.set_default_version(
- self._project_id, self._model_name,
- self._version['name'])
- elif self._operation == 'list':
- return hook.list_versions(self._project_id, self._model_name)
- elif self._operation == 'delete':
- return hook.delete_version(self._project_id, self._model_name,
- self._version['name'])
- else:
- raise ValueError('Unknown operation: {}'.format(self._operation))
-
-
-class CloudMLTrainingOperator(BaseOperator):
- """
- Operator for launching a CloudML training job.
-
- :param project_id: The Google Cloud project name within which CloudML
- training job should run. This field could be templated.
- :type project_id: string
-
- :param job_id: A unique templated id for the submitted Google CloudML
- training job.
- :type job_id: string
-
- :param package_uris: A list of package locations for the CloudML training
- job, which should include the main training program + any additional
- dependencies.
- :type package_uris: list of string
-
- :param training_python_module: The Python module name to run within the
- CloudML training job after installing 'package_uris' packages.
- :type training_python_module: string
-
- :param training_args: A list of templated command line arguments to pass to
- the CloudML training program.
- :type training_args: list of string
-
- :param region: The Google Compute Engine region to run the CloudML training
- job in. This field could be templated.
- :type region: string
-
- :param scale_tier: Resource tier for CloudML training job.
- :type scale_tier: string
-
- :param gcp_conn_id: The connection ID to use when fetching connection info.
- :type gcp_conn_id: string
-
- :param delegate_to: The account to impersonate, if any.
- For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :type delegate_to: string
-
- :param mode: Can be one of 'DRY_RUN'/'CLOUD'. In 'DRY_RUN' mode, no real
- training job will be launched, but the CloudML training job request
- will be printed out. In 'CLOUD' mode, a real CloudML training job
- creation request will be issued.
- :type mode: string
- """
-
- template_fields = [
- '_project_id',
- '_job_id',
- '_package_uris',
- '_training_python_module',
- '_training_args',
- '_region',
- '_scale_tier',
- ]
-
- @apply_defaults
- def __init__(self,
- project_id,
- job_id,
- package_uris,
- training_python_module,
- training_args,
- region,
- scale_tier=None,
- gcp_conn_id='google_cloud_default',
- delegate_to=None,
- mode='PRODUCTION',
- *args,
- **kwargs):
- super(CloudMLTrainingOperator, self).__init__(*args, **kwargs)
- self._project_id = project_id
- self._job_id = job_id
- self._package_uris = package_uris
- self._training_python_module = training_python_module
- self._training_args = training_args
- self._region = region
- self._scale_tier = scale_tier
- self._gcp_conn_id = gcp_conn_id
- self._delegate_to = delegate_to
- self._mode = mode
-
- if not self._project_id:
- raise AirflowException('Google Cloud project id is required.')
- if not self._job_id:
- raise AirflowException(
- 'A unique job id is required for Google CloudML training '
- 'job.')
- if not package_uris:
- raise AirflowException(
- 'At least one Python package is required for the CloudML '
- 'training job.')
- if not training_python_module:
- raise AirflowException(
- 'A Python module name to run after installing the required '
- 'packages is required.')
- if not self._region:
- raise AirflowException('Google Compute Engine region is required.')
-
- def execute(self, context):
- job_id = _normalize_cloudml_job_id(self._job_id)
- training_request = {
- 'jobId': job_id,
- 'trainingInput': {
- 'scaleTier': self._scale_tier,
- 'packageUris': self._package_uris,
- 'pythonModule': self._training_python_module,
- 'region': self._region,
- 'args': self._training_args,
- }
- }
-
- if self._mode == 'DRY_RUN':
- logging.info('In dry_run mode.')
- logging.info(
- 'CloudML Training job request is: {}'.format(training_request))
- return
-
- hook = CloudMLHook(
- gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
-
- # Helper method to check if the existing job's training input is the
- # same as the request we get here.
- def check_existing_job(existing_job):
- return existing_job.get('trainingInput', None) == \
- training_request['trainingInput']
- try:
- finished_training_job = hook.create_job(
- self._project_id, training_request, check_existing_job)
- except errors.HttpError:
- raise
-
- if finished_training_job['state'] != 'SUCCEEDED':
- logging.error('CloudML training job failed: {}'.format(
- str(finished_training_job)))
- raise RuntimeError(finished_training_job['errorMessage'])
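Under its new name the batch prediction operator keeps the same arguments; a minimal DAG-level sketch (the project, bucket paths, job id, and dag object are hypothetical), assuming the renamed class preserves the signature shown above:

    from airflow.contrib.operators.mlengine_operator import MLEngineBatchPredictionOperator

    predict = MLEngineBatchPredictionOperator(
        task_id='batch-predict',
        project_id='my-project-id',               # hypothetical
        job_id='my_prediction_job',               # hypothetical
        region='us-central1',
        data_format='TEXT',
        input_paths=['gs://my-bucket/inputs/*'],  # hypothetical; wildcard only at the end
        output_path='gs://my-bucket/outputs/',    # hypothetical
        model_name='my_model',                    # default version of this model is used
        dag=dag)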
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator_utils.py b/airflow/contrib/operators/cloudml_operator_utils.py
deleted file mode 100644
index 81cd54f..0000000
--- a/airflow/contrib/operators/cloudml_operator_utils.py
+++ /dev/null
@@ -1,245 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the 'License'); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import base64
-import json
-import os
-import re
-
-import dill
-
-from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
-from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
-from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
-from airflow.exceptions import AirflowException
-from airflow.operators.python_operator import PythonOperator
-from six.moves.urllib.parse import urlsplit
-
-def create_evaluate_ops(task_prefix,
- data_format,
- input_paths,
- prediction_path,
- metric_fn_and_keys,
- validate_fn,
- batch_prediction_job_id=None,
- project_id=None,
- region=None,
- dataflow_options=None,
- model_uri=None,
- model_name=None,
- version_name=None,
- dag=None):
- """
- Creates the operators needed for model evaluation and returns them.
-
- It gets predictions over the inputs via the Cloud ML Engine BatchPrediction
- API by calling CloudMLBatchPredictionOperator, then summarizes and
- validates the result via Cloud Dataflow using DataFlowPythonOperator.
-
- For details and pricing about Batch prediction, please refer to the website
- https://cloud.google.com/ml-engine/docs/how-tos/batch-predict
- and for Cloud Dataflow, https://cloud.google.com/dataflow/docs/
-
- It returns three chained operators for prediction, summary, and validation,
- named <prefix>-prediction, <prefix>-summary, and <prefix>-validation,
- respectively.
- (<prefix> should contain only alphanumeric characters or hyphen.)
-
- Upstream and downstream dependencies can be set accordingly, e.g.:
- pred, _, val = create_evaluate_ops(...)
- pred.set_upstream(upstream_op)
- ...
- downstream_op.set_upstream(val)
-
- Callers will provide two python callables, metric_fn and validate_fn, in
- order to customize the evaluation behavior as they wish.
- - metric_fn receives a dictionary per instance derived from json in the
- batch prediction result. The keys might vary depending on the model.
- It should return a tuple of metrics.
- validate_fn receives a dictionary of the averaged metrics that metric_fn
- generated over all instances.
- The keys/values of the dictionary match what's given by the
- metric_fn_and_keys arg.
- The dictionary contains an additional metric, 'count', to represent the
- total number of instances received for evaluation.
- The function should raise an exception to mark the task as failed in
- case the validation result is not okay to proceed (i.e. to set the
- trained version as default).
-
- Typical examples are like this:
-
- def get_metric_fn_and_keys():
- import math # imports should be outside of the metric_fn below.
- def error_and_squared_error(inst):
- label = float(inst['input_label'])
- classes = float(inst['classes']) # 0 or 1
- err = abs(classes-label)
- squared_err = math.pow(classes-label, 2)
- return (err, squared_err) # returns a tuple.
- return error_and_squared_error, ['err', 'mse'] # key order must match.
-
- def validate_err_and_count(summary):
- if summary['err'] > 0.2:
- raise ValueError('Too high err>0.2; summary=%s' % summary)
- if summary['mse'] > 0.05:
- raise ValueError('Too high mse>0.05; summary=%s' % summary)
- if summary['count'] < 1000:
- raise ValueError('Too few instances<1000; summary=%s' % summary)
- return summary
-
- For the details on the other BatchPrediction-related arguments (project_id,
- job_id, region, data_format, input_paths, prediction_path, model_uri),
- please refer to CloudMLBatchPredictionOperator too.
-
- :param task_prefix: a prefix for the tasks. Only alphanumeric characters and
- hyphens are allowed (no underscores), since this will be used as the
- dataflow job name, which doesn't allow other characters.
- :type task_prefix: string
-
- :param data_format: either of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP'
- :type data_format: string
-
- :param input_paths: a list of input paths to be sent to BatchPrediction.
- :type input_paths: list of strings
-
- :param prediction_path: GCS path to put the prediction results in.
- :type prediction_path: string
-
- :param metric_fn_and_keys: a tuple of metric_fn and metric_keys:
- - metric_fn is a function that accepts a dictionary (for an instance),
- and returns a tuple of metric(s) that it calculates.
- - metric_keys is a list of strings to denote the key of each metric.
- :type metric_fn_and_keys: tuple of a function and a list of strings
-
- :param validate_fn: a function to validate whether the averaged metric(s) is
- good enough to push the model.
- :type validate_fn: function
-
- :param batch_prediction_job_id: the id to use for the Cloud ML Batch
- prediction job. Passed directly to the CloudMLBatchPredictionOperator as
- the job_id argument.
- :type batch_prediction_job_id: string
-
- :param project_id: the Google Cloud Platform project id in which to execute
- Cloud ML Batch Prediction and Dataflow jobs. If None, then the `dag`'s
- `default_args['project_id']` will be used.
- :type project_id: string
-
- :param region: the Google Cloud Platform region in which to execute Cloud ML
- Batch Prediction and Dataflow jobs. If None, then the `dag`'s
- `default_args['region']` will be used.
- :type region: string
-
- :param dataflow_options: options to run Dataflow jobs. If None, then the
- `dag`'s `default_args['dataflow_default_options']` will be used.
- :type dataflow_options: dictionary
-
- :param model_uri: GCS path of the model exported by Tensorflow using
- tensorflow.estimator.export_savedmodel(). It cannot be used with
- model_name or version_name below. See CloudMLBatchPredictionOperator for
- more detail.
- :type model_uri: string
-
- :param model_name: Used to indicate a model to use for prediction. Can be
- used in combination with version_name, but cannot be used together with
- model_uri. See CloudMLBatchPredictionOperator for more detail. If None,
- then the `dag`'s `default_args['model_name']` will be used.
- :type model_name: string
-
- :param version_name: Used to indicate a model version to use for prediction,
- in combination with model_name. Cannot be used together with model_uri.
- See CloudMLBatchPredictionOperator for more detail. If None, then the
- `dag`'s `default_args['version_name']` will be used.
- :type version_name: string
-
- :param dag: The `DAG` to use for all Operators.
- :type dag: airflow.DAG
-
- :returns: a tuple of three operators, (prediction, summary, validation)
- :rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator,
- PythonOperator)
- """
-
- # Verify that task_prefix doesn't have any special characters except hyphen
- # '-', which is the only allowed non-alphanumeric character by Dataflow.
- if not re.match(r"^[a-zA-Z][-A-Za-z0-9]*$", task_prefix):
- raise AirflowException(
- "Malformed task_id for DataFlowPythonOperator (only alphanumeric "
- "and hyphens are allowed but got: " + task_prefix)
-
- metric_fn, metric_keys = metric_fn_and_keys
- if not callable(metric_fn):
- raise AirflowException("`metric_fn` param must be callable.")
- if not callable(validate_fn):
- raise AirflowException("`validate_fn` param must be callable.")
-
- if dag is not None and dag.default_args is not None:
- default_args = dag.default_args
- project_id = project_id or default_args.get('project_id')
- region = region or default_args.get('region')
- model_name = model_name or default_args.get('model_name')
- version_name = version_name or default_args.get('version_name')
- dataflow_options = dataflow_options or \
- default_args.get('dataflow_default_options')
-
- evaluate_prediction = CloudMLBatchPredictionOperator(
- task_id=(task_prefix + "-prediction"),
- project_id=project_id,
- job_id=batch_prediction_job_id,
- region=region,
- data_format=data_format,
- input_paths=input_paths,
- output_path=prediction_path,
- uri=model_uri,
- model_name=model_name,
- version_name=version_name,
- dag=dag)
-
- metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True))
- evaluate_summary = DataFlowPythonOperator(
- task_id=(task_prefix + "-summary"),
- py_options=["-m"],
- py_file="airflow.contrib.operators.cloudml_prediction_summary",
- dataflow_default_options=dataflow_options,
- options={
- "prediction_path": prediction_path,
- "metric_fn_encoded": metric_fn_encoded,
- "metric_keys": ','.join(metric_keys)
- },
- dag=dag)
- evaluate_summary.set_upstream(evaluate_prediction)
-
- def apply_validate_fn(*args, **kwargs):
- prediction_path = kwargs["templates_dict"]["prediction_path"]
- scheme, bucket, obj, _, _ = urlsplit(prediction_path)
- if scheme != "gs" or not bucket or not obj:
- raise ValueError("Wrong format prediction_path: %s",
- prediction_path)
- summary = os.path.join(obj.strip("/"),
- "prediction.summary.json")
- gcs_hook = GoogleCloudStorageHook()
- summary = json.loads(gcs_hook.download(bucket, summary))
- return validate_fn(summary)
-
- evaluate_validation = PythonOperator(
- task_id=(task_prefix + "-validation"),
- python_callable=apply_validate_fn,
- provide_context=True,
- templates_dict={"prediction_path": prediction_path},
- dag=dag)
- evaluate_validation.set_upstream(evaluate_summary)
-
- return evaluate_prediction, evaluate_summary, evaluate_validation
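Pulling the docstring's pieces together, a minimal sketch of wiring the three returned operators into a DAG (the GCS paths, model name, upstream task, and dag object are hypothetical; after this commit the helper lives in mlengine_operator_utils):

    from airflow.contrib.operators.mlengine_operator_utils import create_evaluate_ops

    pred, summary, validate = create_evaluate_ops(
        task_prefix='eval-my-model',                   # alphanumeric and hyphens only
        data_format='TEXT',
        input_paths=['gs://my-bucket/eval/*'],         # hypothetical
        prediction_path='gs://my-bucket/eval-output/', # hypothetical
        metric_fn_and_keys=get_metric_fn_and_keys(),   # from the docstring example
        validate_fn=validate_err_and_count,            # from the docstring example
        model_name='my_model',                         # hypothetical
        dag=dag)
    pred.set_upstream(training_op)                     # hypothetical upstream task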
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/cloudml_prediction_summary.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_prediction_summary.py b/airflow/contrib/operators/cloudml_prediction_summary.py
deleted file mode 100644
index 3128dc3..0000000
--- a/airflow/contrib/operators/cloudml_prediction_summary.py
+++ /dev/null
@@ -1,177 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the 'License'); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""A template called by DataFlowPythonOperator to summarize BatchPrediction.
-
-It accepts a user function to calculate the metric(s) per instance in
-the prediction results, then aggregates to output as a summary.
-
-Args:
- --prediction_path:
- The GCS folder that contains BatchPrediction results, containing
- prediction.results-NNNNN-of-NNNNN files in the json format.
- Output will also be stored in this folder, as 'prediction.summary.json'.
-
- --metric_fn_encoded:
- An encoded function that calculates and returns a tuple of metric(s)
- for a given instance (as a dictionary). It should be encoded
- via base64.b64encode(dill.dumps(fn, recurse=True)).
-
- --metric_keys:
- Comma-separated key(s) of the aggregated metric(s) in the summary
- output. The order and the size of the keys must match the output
- of metric_fn.
- The summary will have an additional key, 'count', to represent the
- total number of instances, so the keys shouldn't include 'count'.
-
-# Usage example:
-def get_metric_fn():
- import math # all imports must be outside of the function to be passed.
- def metric_fn(inst):
- label = float(inst["input_label"])
- classes = float(inst["classes"])
- prediction = float(inst["scores"][1])
- log_loss = math.log(1 + math.exp(
- -(label * 2 - 1) * math.log(prediction / (1 - prediction))))
- squared_err = (classes-label)**2
- return (log_loss, squared_err)
- return metric_fn
-metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True))
-
-airflow.contrib.operators.DataFlowPythonOperator(
- task_id="summary-prediction",
- py_options=["-m"],
- py_file="airflow.contrib.operators.cloudml_prediction_summary",
- options={
- "prediction_path": prediction_path,
- "metric_fn_encoded": metric_fn_encoded,
- "metric_keys": "log_loss,mse"
- },
- dataflow_default_options={
- "project": "xxx", "region": "us-east1",
- "staging_location": "gs://yy", "temp_location": "gs://zz",
- })
- >> dag
-
-# When the input file is like the following:
-{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
-{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
-{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
-{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
-
-# The output file will be:
-{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
-
-# To test outside of the dag:
-subprocess.check_call(["python",
- "-m",
- "airflow.contrib.operators.cloudml_prediction_summary",
- "--prediction_path=gs://...",
- "--metric_fn_encoded=" + metric_fn_encoded,
- "--metric_keys=log_loss,mse",
- "--runner=DataflowRunner",
- "--staging_location=gs://...",
- "--temp_location=gs://...",
- ])
-
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import base64
-import json
-import logging
-import os
-
-import apache_beam as beam
-import dill
-
-
-class JsonCoder(object):
- def encode(self, x):
- return json.dumps(x)
-
- def decode(self, x):
- return json.loads(x)
-
-
-@beam.ptransform_fn
-def MakeSummary(pcoll, metric_fn, metric_keys): # pylint: disable=invalid-name
- return (
- pcoll
- | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn)
- | "PairWith1" >> beam.Map(lambda tup: tup + (1,))
- | "SumTuple" >> beam.CombineGlobally(beam.combiners.TupleCombineFn(
- *([sum] * (len(metric_keys) + 1))))
- | "AverageAndMakeDict" >> beam.Map(
- lambda tup: dict(
- [(name, tup[i]/tup[-1]) for i, name in enumerate(metric_keys)] +
- [("count", tup[-1])])))
-
-
-def run(argv=None):
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--prediction_path", required=True,
- help=(
- "The GCS folder that contains BatchPrediction results, containing "
- "prediction.results-NNNNN-of-NNNNN files in the json format. "
- "Output will be also stored in this folder, as a file"
- "'prediction.summary.json'."))
- parser.add_argument(
- "--metric_fn_encoded", required=True,
- help=(
- "An encoded function that calculates and returns a tuple of "
- "metric(s) for a given instance (as a dictionary). It should be "
- "encoded via base64.b64encode(dill.dumps(fn, recurse=True))."))
- parser.add_argument(
- "--metric_keys", required=True,
- help=(
- "A comma-separated keys of the aggregated metric(s) in the summary "
- "output. The order and the size of the keys must match to the "
- "output of metric_fn. The summary will have an additional key, "
- "'count', to represent the total number of instances, so this flag "
- "shouldn't include 'count'."))
- known_args, pipeline_args = parser.parse_known_args(argv)
-
- metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded))
- if not callable(metric_fn):
- raise ValueError("--metric_fn_encoded must be an encoded callable.")
- metric_keys = known_args.metric_keys.split(",")
-
- with beam.Pipeline(
- options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
- # This is apache-beam ptransform's convention
- # pylint: disable=no-value-for-parameter
- _ = (p
- | "ReadPredictionResult" >> beam.io.ReadFromText(
- os.path.join(known_args.prediction_path,
- "prediction.results-*-of-*"),
- coder=JsonCoder())
- | "Summary" >> MakeSummary(metric_fn, metric_keys)
- | "Write" >> beam.io.WriteToText(
- os.path.join(known_args.prediction_path,
- "prediction.summary.json"),
- shard_name_template='', # without trailing -NNNNN-of-NNNNN.
- coder=JsonCoder()))
- # pylint: enable=no-value-for-parameter
-
-
-if __name__ == "__main__":
- logging.getLogger().setLevel(logging.INFO)
- run()
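After the rename, the standalone test invocation from the docstring above changes only in the module path; a sketch, with the GCS paths left as placeholders exactly as in the original and metric_fn_encoded built as shown there:

    subprocess.check_call(["python",
                           "-m",
                           "airflow.contrib.operators.mlengine_prediction_summary",
                           "--prediction_path=gs://...",
                           "--metric_fn_encoded=" + metric_fn_encoded,
                           "--metric_keys=log_loss,mse",
                           "--runner=DataflowRunner",
                           "--staging_location=gs://...",
                           "--temp_location=gs://...",
                           ])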
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/mlengine_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py
new file mode 100644
index 0000000..7476825
--- /dev/null
+++ b/airflow/contrib/operators/mlengine_operator.py
@@ -0,0 +1,564 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from airflow import settings
+from airflow.contrib.hooks.gcp_mlengine_hook import MLEngineHook
+from airflow.exceptions import AirflowException
+from airflow.operators import BaseOperator
+from airflow.utils.decorators import apply_defaults
+from apiclient import errors
+
+
+logging.getLogger('GoogleCloudMLEngine').setLevel(settings.LOGGING_LEVEL)
+
+
+def _create_prediction_input(project_id,
+ region,
+ data_format,
+ input_paths,
+ output_path,
+ model_name=None,
+ version_name=None,
+ uri=None,
+ max_worker_count=None,
+ runtime_version=None):
+ """
+ Create the batch prediction input from the given parameters.
+
+ Args:
+        A subset of the arguments documented in the __init__ method of
+        MLEngineBatchPredictionOperator.
+
+ Returns:
+ A dictionary representing the predictionInput object as documented
+ in https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs.
+
+ Raises:
+ ValueError: if a unique model/version origin cannot be determined.
+ """
+
+ prediction_input = {
+ 'dataFormat': data_format,
+ 'inputPaths': input_paths,
+ 'outputPath': output_path,
+ 'region': region
+ }
+
+ if uri:
+ if model_name or version_name:
+ logging.error(
+ 'Ambiguous model origin: Both uri and model/version name are '
+ 'provided.')
+ raise ValueError('Ambiguous model origin.')
+ prediction_input['uri'] = uri
+ elif model_name:
+ origin_name = 'projects/{}/models/{}'.format(project_id, model_name)
+ if not version_name:
+ prediction_input['modelName'] = origin_name
+ else:
+ prediction_input['versionName'] = \
+ origin_name + '/versions/{}'.format(version_name)
+ else:
+ logging.error(
+ 'Missing model origin: Batch prediction expects a model, '
+ 'a model & version combination, or a URI to savedModel.')
+ raise ValueError('Missing model origin.')
+
+ if max_worker_count:
+ prediction_input['maxWorkerCount'] = max_worker_count
+ if runtime_version:
+ prediction_input['runtimeVersion'] = runtime_version
+
+ return prediction_input
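+
+# Illustrative example (all names hypothetical): with only model_name set,
+# this helper would build a predictionInput like the following.
+#
+#   _create_prediction_input(
+#       project_id='sample-project', region='us-central1',
+#       data_format='TEXT', input_paths=['gs://sample-bucket/inputs/*'],
+#       output_path='gs://sample-bucket/outputs', model_name='sample_model')
+#   # => {'dataFormat': 'TEXT',
+#   #     'inputPaths': ['gs://sample-bucket/inputs/*'],
+#   #     'outputPath': 'gs://sample-bucket/outputs',
+#   #     'region': 'us-central1',
+#   #     'modelName': 'projects/sample-project/models/sample_model'}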
+
+
+def _normalize_mlengine_job_id(job_id):
+ """
+    Replaces invalid MLEngine job_id characters with '_'.
+
+    This also adds a leading 'z_' in case job_id starts with a digit, since
+    MLEngine job ids must start with a letter.
+
+ Args:
+ job_id: A job_id str that may have invalid characters.
+
+ Returns:
+ A valid job_id representation.
+ """
+ match = re.search(r'\d', job_id)
+    if match and match.start() == 0:
+ job_id = 'z_{}'.format(job_id)
+ return re.sub('[^0-9a-zA-Z]+', '_', job_id)
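+
+# Illustrative example (hypothetical job id): the leading digit triggers the
+# 'z_' prefix and the '.' is replaced with '_', so
+#   _normalize_mlengine_job_id('8_sample.task')  # => 'z_8_sample_task'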
+
+
+class MLEngineBatchPredictionOperator(BaseOperator):
+ """
+ Start a Google Cloud ML Engine prediction job.
+
+    NOTE: For the model origin, users should provide exactly one of the
+    three options below:
+    1. Populate the 'uri' field only, which should be a GCS location that
+       points to a tensorflow savedModel directory.
+    2. Populate the 'model_name' field only, which refers to an existing
+       model; the default version of the model will be used.
+    3. Populate both 'model_name' and 'version_name' fields, which
+       refer to a specific version of a specific model.
+
+ In options 2 and 3, both model and version name should contain the
+ minimal identifier. For instance, call
+ MLEngineBatchPredictionOperator(
+ ...,
+ model_name='my_model',
+ version_name='my_version',
+ ...)
+ if the desired model version is
+ "projects/my_project/models/my_model/versions/my_version".
+
+
+ :param project_id: The Google Cloud project name where the
+ prediction job is submitted.
+ :type project_id: string
+
+ :param job_id: A unique id for the prediction job on Google Cloud
+ ML Engine.
+ :type job_id: string
+
+ :param data_format: The format of the input data.
+        It will default to 'DATA_FORMAT_UNSPECIFIED' if it is not provided
+ or is not one of ["TEXT", "TF_RECORD", "TF_RECORD_GZIP"].
+ :type data_format: string
+
+    :param input_paths: A list of GCS paths of input data for batch
+        prediction. The wildcard operator * is accepted, but only at the end.
+ :type input_paths: list of string
+
+ :param output_path: The GCS path where the prediction results are
+ written to.
+ :type output_path: string
+
+ :param region: The Google Compute Engine region to run the
+        prediction job in.
+ :type region: string
+
+ :param model_name: The Google Cloud ML Engine model to use for prediction.
+ If version_name is not provided, the default version of this
+ model will be used.
+ Should not be None if version_name is provided.
+ Should be None if uri is provided.
+ :type model_name: string
+
+ :param version_name: The Google Cloud ML Engine model version to use for
+ prediction.
+ Should be None if uri is provided.
+ :type version_name: string
+
+ :param uri: The GCS path of the saved model to use for prediction.
+ Should be None if model_name is provided.
+ It should be a GCS path pointing to a tensorflow SavedModel.
+ :type uri: string
+
+ :param max_worker_count: The maximum number of workers to be used
+ for parallel processing. Defaults to 10 if not specified.
+ :type max_worker_count: int
+
+ :param runtime_version: The Google Cloud ML Engine runtime version to use
+ for batch prediction.
+ :type runtime_version: string
+
+ :param gcp_conn_id: The connection ID used for connection to Google
+ Cloud Platform.
+ :type gcp_conn_id: string
+
+ :param delegate_to: The account to impersonate, if any.
+ For this to work, the service account making the request must
+        have domain-wide delegation enabled.
+ :type delegate_to: string
+
+ Raises:
+ ValueError: if a unique model/version origin cannot be determined.
+ """
+
+ template_fields = [
+ "prediction_job_request",
+ ]
+
+ @apply_defaults
+ def __init__(self,
+ project_id,
+ job_id,
+ region,
+ data_format,
+ input_paths,
+ output_path,
+ model_name=None,
+ version_name=None,
+ uri=None,
+ max_worker_count=None,
+ runtime_version=None,
+ gcp_conn_id='google_cloud_default',
+ delegate_to=None,
+ *args,
+ **kwargs):
+ super(MLEngineBatchPredictionOperator, self).__init__(*args, **kwargs)
+
+ self.project_id = project_id
+ self.gcp_conn_id = gcp_conn_id
+ self.delegate_to = delegate_to
+
+ try:
+ prediction_input = _create_prediction_input(
+ project_id, region, data_format, input_paths, output_path,
+ model_name, version_name, uri, max_worker_count,
+ runtime_version)
+ except ValueError as e:
+ logging.error(
+ 'Cannot create batch prediction job request due to: {}'
+ .format(str(e)))
+ raise
+
+ self.prediction_job_request = {
+ 'jobId': _normalize_mlengine_job_id(job_id),
+ 'predictionInput': prediction_input
+ }
+
+ def execute(self, context):
+ hook = MLEngineHook(self.gcp_conn_id, self.delegate_to)
+
+ def check_existing_job(existing_job):
+ return existing_job.get('predictionInput', None) == \
+ self.prediction_job_request['predictionInput']
+        try:
+            finished_prediction_job = hook.create_job(
+                self.project_id,
+                self.prediction_job_request,
+                check_existing_job)
+        except errors.HttpError as e:
+            logging.error(
+                'Failed to submit batch prediction job: %s', str(e))
+            raise
+
+ if finished_prediction_job['state'] != 'SUCCEEDED':
+ logging.error(
+ 'Batch prediction job failed: %s',
+ str(finished_prediction_job))
+ raise RuntimeError(finished_prediction_job['errorMessage'])
+
+ return finished_prediction_job['predictionOutput']
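+
+# A minimal usage sketch, not part of this module; the DAG, project, bucket,
+# and model names below are all hypothetical.
+#
+#   batch_predict = MLEngineBatchPredictionOperator(
+#       task_id='batch_predict',
+#       project_id='sample-project',
+#       job_id='sample_batch_predict_001',
+#       region='us-central1',
+#       data_format='TEXT',
+#       input_paths=['gs://sample-bucket/inputs/*'],
+#       output_path='gs://sample-bucket/outputs',
+#       model_name='sample_model',
+#       dag=dag)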
+
+
+class MLEngineModelOperator(BaseOperator):
+ """
+ Operator for managing a Google Cloud ML Engine model.
+
+    :param project_id: The Google Cloud project name to which the MLEngine
+        model belongs.
+ :type project_id: string
+
+ :param model: A dictionary containing the information about the model.
+ If the `operation` is `create`, then the `model` parameter should
+ contain all the information about this model such as `name`.
+
+ If the `operation` is `get`, the `model` parameter
+ should contain the `name` of the model.
+ :type model: dict
+
+    :param operation: The operation to perform. Available operations are:
+        'create': Creates a new model as provided by the `model` parameter.
+        'get': Gets a particular model where the name is specified in `model`.
+    :type operation: string
+
+ :param gcp_conn_id: The connection ID to use when fetching connection info.
+ :type gcp_conn_id: string
+
+ :param delegate_to: The account to impersonate, if any.
+ For this to work, the service account making the request must have
+ domain-wide delegation enabled.
+ :type delegate_to: string
+ """
+
+ template_fields = [
+ '_model',
+ ]
+
+ @apply_defaults
+ def __init__(self,
+ project_id,
+ model,
+ operation='create',
+ gcp_conn_id='google_cloud_default',
+ delegate_to=None,
+ *args,
+ **kwargs):
+ super(MLEngineModelOperator, self).__init__(*args, **kwargs)
+ self._project_id = project_id
+ self._model = model
+ self._operation = operation
+ self._gcp_conn_id = gcp_conn_id
+ self._delegate_to = delegate_to
+
+ def execute(self, context):
+ hook = MLEngineHook(
+ gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+ if self._operation == 'create':
+ return hook.create_model(self._project_id, self._model)
+ elif self._operation == 'get':
+ return hook.get_model(self._project_id, self._model['name'])
+ else:
+ raise ValueError('Unknown operation: {}'.format(self._operation))
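+
+# A minimal usage sketch with hypothetical names: 'create' consumes the whole
+# model dictionary, while 'get' only reads its 'name' key.
+#
+#   create_model = MLEngineModelOperator(
+#       task_id='create_model',
+#       project_id='sample-project',
+#       model={'name': 'sample_model'},
+#       operation='create',
+#       dag=dag)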
+
+
+class MLEngineVersionOperator(BaseOperator):
+ """
+ Operator for managing a Google Cloud ML Engine version.
+
+    :param project_id: The Google Cloud project name to which the MLEngine
+        model belongs.
+ :type project_id: string
+
+ :param model_name: The name of the Google Cloud ML Engine model that the version
+ belongs to.
+ :type model_name: string
+
+ :param version_name: A name to use for the version being operated upon. If
+ not None and the `version` argument is None or does not have a value for
+ the `name` key, then this will be populated in the payload for the
+ `name` key.
+ :type version_name: string
+
+    :param version: A dictionary containing the information about the version.
+        If the `operation` is `create`, `version` should contain all the
+        information about this version such as `name` and `deploymentUri`.
+        If the `operation` is `set_default` or `delete`, the `version`
+        parameter should contain the `name` of the version.
+        If it is None, the only `operation` possible would be `list`.
+    :type version: dict
+
+    :param operation: The operation to perform. Available operations are:
+        'create': Creates a new version in the model specified by `model_name`,
+            in which case the `version` parameter should contain all the
+            information to create that version
+            (e.g. `name`, `deploymentUri`).
+        'set_default': Sets the version specified in the `version` parameter
+            as the default version of the model specified by `model_name`.
+            The name of the version should be specified in the `version`
+            parameter.
+        'list': Lists all available versions of the model specified
+            by `model_name`.
+        'delete': Deletes the version specified in the `version` parameter
+            from the model specified by `model_name`.
+            The name of the version should be specified in the `version`
+            parameter.
+    :type operation: string
+
+ :param gcp_conn_id: The connection ID to use when fetching connection info.
+ :type gcp_conn_id: string
+
+ :param delegate_to: The account to impersonate, if any.
+ For this to work, the service account making the request must have
+ domain-wide delegation enabled.
+ :type delegate_to: string
+ """
+
+ template_fields = [
+ '_model_name',
+ '_version_name',
+ '_version',
+ ]
+
+ @apply_defaults
+ def __init__(self,
+ project_id,
+ model_name,
+ version_name=None,
+ version=None,
+ operation='create',
+ gcp_conn_id='google_cloud_default',
+ delegate_to=None,
+ *args,
+ **kwargs):
+
+ super(MLEngineVersionOperator, self).__init__(*args, **kwargs)
+ self._project_id = project_id
+ self._model_name = model_name
+ self._version_name = version_name
+ self._version = version or {}
+ self._operation = operation
+ self._gcp_conn_id = gcp_conn_id
+ self._delegate_to = delegate_to
+
+ def execute(self, context):
+ if 'name' not in self._version:
+ self._version['name'] = self._version_name
+
+ hook = MLEngineHook(
+ gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+
+ if self._operation == 'create':
+ assert self._version is not None
+ return hook.create_version(self._project_id, self._model_name,
+ self._version)
+ elif self._operation == 'set_default':
+ return hook.set_default_version(
+ self._project_id, self._model_name,
+ self._version['name'])
+ elif self._operation == 'list':
+ return hook.list_versions(self._project_id, self._model_name)
+ elif self._operation == 'delete':
+ return hook.delete_version(self._project_id, self._model_name,
+ self._version['name'])
+ else:
+ raise ValueError('Unknown operation: {}'.format(self._operation))
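+
+# A minimal usage sketch with hypothetical names; deploymentUri is assumed to
+# point at a GCS directory holding an exported SavedModel.
+#
+#   create_version = MLEngineVersionOperator(
+#       task_id='create_version',
+#       project_id='sample-project',
+#       model_name='sample_model',
+#       version={'name': 'v1',
+#                'deploymentUri': 'gs://sample-bucket/saved_model/'},
+#       operation='create',
+#       dag=dag)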
+
+
+class MLEngineTrainingOperator(BaseOperator):
+ """
+    Operator for launching an MLEngine training job.
+
+    :param project_id: The Google Cloud project name within which the MLEngine
+        training job should run. This field is templated.
+    :type project_id: string
+
+    :param job_id: A unique id for the submitted Google MLEngine training job.
+        This field is templated.
+    :type job_id: string
+
+    :param package_uris: A list of package locations for the MLEngine training
+        job, which should include the main training program and any additional
+        dependencies. This field is templated.
+    :type package_uris: list of string
+
+    :param training_python_module: The Python module name to run within the
+        MLEngine training job after installing the 'package_uris' packages.
+        This field is templated.
+    :type training_python_module: string
+
+    :param training_args: A list of command line arguments to pass to the
+        MLEngine training program. This field is templated.
+    :type training_args: list of string
+
+    :param region: The Google Compute Engine region to run the MLEngine
+        training job in. This field is templated.
+    :type region: string
+
+    :param scale_tier: Resource tier for the MLEngine training job.
+        This field is templated.
+    :type scale_tier: string
+
+ :param gcp_conn_id: The connection ID to use when fetching connection info.
+ :type gcp_conn_id: string
+
+ :param delegate_to: The account to impersonate, if any.
+ For this to work, the service account making the request must have
+ domain-wide delegation enabled.
+ :type delegate_to: string
+
+    :param mode: Can be one of 'DRY_RUN'/'CLOUD'. In 'DRY_RUN' mode, no real
+        training job will be launched, but the MLEngine training job request
+        will be printed out. In 'CLOUD' mode, a real MLEngine training job
+        creation request will be issued; any other value (including the
+        default, 'PRODUCTION') is currently treated the same way as 'CLOUD'.
+    :type mode: string
+ """
+
+ template_fields = [
+ '_project_id',
+ '_job_id',
+ '_package_uris',
+ '_training_python_module',
+ '_training_args',
+ '_region',
+ '_scale_tier',
+ ]
+
+ @apply_defaults
+ def __init__(self,
+ project_id,
+ job_id,
+ package_uris,
+ training_python_module,
+ training_args,
+ region,
+ scale_tier=None,
+ gcp_conn_id='google_cloud_default',
+ delegate_to=None,
+ mode='PRODUCTION',
+ *args,
+ **kwargs):
+ super(MLEngineTrainingOperator, self).__init__(*args, **kwargs)
+ self._project_id = project_id
+ self._job_id = job_id
+ self._package_uris = package_uris
+ self._training_python_module = training_python_module
+ self._training_args = training_args
+ self._region = region
+ self._scale_tier = scale_tier
+ self._gcp_conn_id = gcp_conn_id
+ self._delegate_to = delegate_to
+ self._mode = mode
+
+ if not self._project_id:
+ raise AirflowException('Google Cloud project id is required.')
+        if not self._job_id:
+            raise AirflowException(
+                'A unique job id is required for the Google MLEngine '
+                'training job.')
+        if not package_uris:
+            raise AirflowException(
+                'At least one python package is required for the MLEngine '
+                'training job.')
+        if not training_python_module:
+            raise AirflowException(
+                'A Python module name to run after installing the required '
+                'packages is required.')
+ if not self._region:
+ raise AirflowException('Google Compute Engine region is required.')
+
+ def execute(self, context):
+ job_id = _normalize_mlengine_job_id(self._job_id)
+ training_request = {
+ 'jobId': job_id,
+ 'trainingInput': {
+ 'scaleTier': self._scale_tier,
+ 'packageUris': self._package_uris,
+ 'pythonModule': self._training_python_module,
+ 'region': self._region,
+ 'args': self._training_args,
+ }
+ }
+
+ if self._mode == 'DRY_RUN':
+ logging.info('In dry_run mode.')
+ logging.info(
+ 'MLEngine Training job request is: {}'.format(training_request))
+ return
+
+ hook = MLEngineHook(
+ gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+
+ # Helper method to check if the existing job's training input is the
+ # same as the request we get here.
+ def check_existing_job(existing_job):
+ return existing_job.get('trainingInput', None) == \
+ training_request['trainingInput']
+        try:
+            finished_training_job = hook.create_job(
+                self._project_id, training_request, check_existing_job)
+        except errors.HttpError as e:
+            logging.error(
+                'Failed to submit MLEngine training job: %s', str(e))
+            raise
+
+ if finished_training_job['state'] != 'SUCCEEDED':
+ logging.error('MLEngine training job failed: {}'.format(
+ str(finished_training_job)))
+ raise RuntimeError(finished_training_job['errorMessage'])
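+
+# A minimal usage sketch with hypothetical names; the trainer package at the
+# given GCS path is assumed to expose a 'trainer.task' module.
+#
+#   train = MLEngineTrainingOperator(
+#       task_id='train',
+#       project_id='sample-project',
+#       job_id='sample_training_001',
+#       package_uris=['gs://sample-bucket/trainer-0.1.tar.gz'],
+#       training_python_module='trainer.task',
+#       training_args=['--train_data=gs://sample-bucket/data/'],
+#       region='us-central1',
+#       scale_tier='BASIC',
+#       dag=dag)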