You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/08/02 18:21:01 UTC

[airflow] branch main updated: Enable Auto-incrementing Transform job name in SageMakerTransformOperator (#25263)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 007b1920dd Enable Auto-incrementing Transform job name in SageMakerTransformOperator (#25263)
007b1920dd is described below

commit 007b1920ddcee1d78f871d039a6ed8f4d0d4089d
Author: celeriev <62...@users.noreply.github.com>
AuthorDate: Tue Aug 2 20:20:53 2022 +0200

    Enable Auto-incrementing Transform job name in SageMakerTransformOperator (#25263)
---
 airflow/providers/amazon/aws/hooks/sagemaker.py    | 75 +++++++++++++++++-----
 .../providers/amazon/aws/operators/sagemaker.py    | 34 +++++++++-
 .../aws/operators/test_sagemaker_transform.py      | 57 ++++++++++++++++
 3 files changed, 150 insertions(+), 16 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 2c8c28a738..d4b07ef23c 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -23,7 +23,7 @@ import time
 import warnings
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Dict, Generator, List, Optional, Set, cast
+from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, cast
 
 from botocore.exceptions import ClientError
 
@@ -844,24 +844,38 @@ class SageMakerHook(AwsBaseHook):
         :param kwargs: (optional) kwargs to boto3's list_training_jobs method
         :return: results of the list_training_jobs request
         """
-        config = {}
+        config, max_results = self._preprocess_list_request_args(name_contains, max_results, **kwargs)
+        list_training_jobs_request = partial(self.get_conn().list_training_jobs, **config)
+        results = self._list_request(
+            list_training_jobs_request, "TrainingJobSummaries", max_results=max_results
+        )
+        return results
 
-        if name_contains:
-            if "NameContains" in kwargs:
-                raise AirflowException("Either name_contains or NameContains can be provided, not both.")
-            config["NameContains"] = name_contains
+    def list_transform_jobs(
+        self, name_contains: Optional[str] = None, max_results: Optional[int] = None, **kwargs
+    ) -> List[Dict]:
+        """
+        This method wraps boto3's `list_transform_jobs`.
+        The transform job name and max results are configurable via arguments.
+        Other arguments are not, and should be provided via kwargs. Note boto3 expects these in
+        CamelCase format, for example:
 
-        if "MaxResults" in kwargs and kwargs["MaxResults"] is not None:
-            if max_results:
-                raise AirflowException("Either max_results or MaxResults can be provided, not both.")
-            # Unset MaxResults, we'll use the SageMakerHook's internal method for iteratively fetching results
-            max_results = kwargs["MaxResults"]
-            del kwargs["MaxResults"]
+        .. code-block:: python
 
-        config.update(kwargs)
-        list_training_jobs_request = partial(self.get_conn().list_training_jobs, **config)
+            list_transform_jobs(name_contains="myjob", StatusEquals="Failed")
+
+        .. seealso::
+            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_transform_jobs
+
+        :param name_contains: (optional) partial name to match
+        :param max_results: (optional) maximum number of results to return. None returns infinite results
+        :param kwargs: (optional) kwargs to boto3's list_transform_jobs method
+        :return: results of the list_transform_jobs request
+        """
+        config, max_results = self._preprocess_list_request_args(name_contains, max_results, **kwargs)
+        list_transform_jobs_request = partial(self.get_conn().list_transform_jobs, **config)
         results = self._list_request(
-            list_training_jobs_request, "TrainingJobSummaries", max_results=max_results
+            list_transform_jobs_request, "TransformJobSummaries", max_results=max_results
         )
         return results
 
@@ -886,6 +900,37 @@ class SageMakerHook(AwsBaseHook):
         )
         return results
 
+    def _preprocess_list_request_args(
+        self, name_contains: Optional[str] = None, max_results: Optional[int] = None, **kwargs
+    ) -> Tuple[Dict[str, Any], Optional[int]]:
+        """
+        This method preprocesses the arguments to the boto3's list_* methods.
+        It turns the arguments name_contains and max_results into boto3-compliant CamelCase keys.
+        This method also makes sure that each of these two arguments is only set once.
+
+        :param name_contains: (optional) partial job name to match, sent to boto3 as NameContains
+        :param max_results: (optional) maximum number of results to return; None fetches all results
+        :param kwargs: (optional) kwargs to boto3's list_* method
+        :return: Tuple with config dict to be passed to boto3's list_* method and max_results parameter
+        """
+        config = {}
+
+        if name_contains:
+            if "NameContains" in kwargs:
+                raise AirflowException("Either name_contains or NameContains can be provided, not both.")
+            config["NameContains"] = name_contains
+
+        if kwargs.get("MaxResults") is not None:
+            if max_results:
+                raise AirflowException("Either max_results or MaxResults can be provided, not both.")
+            # Unset MaxResults, we'll use the SageMakerHook's internal method for iteratively fetching results
+            max_results = kwargs["MaxResults"]
+        kwargs.pop("MaxResults", None)  # never forward MaxResults (even None) to boto3 directly
+
+        config.update(kwargs)
+
+        return config, max_results
+
     def _list_request(
         self, partial_func: Callable, result_key: str, max_results: Optional[int] = None
     ) -> List[Dict]:
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 8da36c58dc..791000ed78 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -400,6 +400,11 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
     :param max_ingestion_time: If wait is set to True, the operation fails
         if the transform job doesn't finish within max_ingestion_time seconds. If you
         set this parameter to None, the operation does not timeout.
+    :param check_if_job_exists: If set to true, then the operator will check whether a transform job
+        already exists for the name in the config.
+    :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
+        (default) and "fail".
+        This is only relevant if check_if_job_exists is True.
     :return Dict: Returns The ARN of the model created in Amazon SageMaker.
     """
 
@@ -411,6 +416,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
         max_ingestion_time: Optional[int] = None,
+        check_if_job_exists: bool = True,
+        action_if_job_exists: str = 'increment',
         **kwargs,
     ):
         super().__init__(config=config, **kwargs)
@@ -419,6 +426,14 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
         self.max_ingestion_time = max_ingestion_time
+        self.check_if_job_exists = check_if_job_exists
+        if action_if_job_exists in ('increment', 'fail'):
+            self.action_if_job_exists = action_if_job_exists
+        else:
+            raise AirflowException(
+                "Argument action_if_job_exists accepts only 'increment' and 'fail'. "
+                f"Provided value: '{action_if_job_exists}'."
+            )
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -444,6 +459,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         self.preprocess_config()
         model_config = self.config.get('Model')
         transform_config = self.config.get('Transform', self.config)
+        if self.check_if_job_exists:
+            self._check_if_transform_job_exists()
         if model_config:
             self.log.info('Creating SageMaker Model %s for transform job', model_config['ModelName'])
             self.hook.create_model(model_config)
@@ -462,6 +479,21 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 'Transform': self.hook.describe_transform_job(transform_config['TransformJobName']),
             }
 
+    def _check_if_transform_job_exists(self) -> None:
+        transform_config = self.config.get('Transform', self.config)
+        job_name = transform_config['TransformJobName']
+        existing = {tj['TransformJobName'] for tj in self.hook.list_transform_jobs(name_contains=job_name)}
+        if job_name in existing:
+            if self.action_if_job_exists == 'increment':
+                self.log.info("Found existing transform job with name '%s'.", job_name)
+                suffix = len(existing) + 1
+                while f'{job_name}-{suffix}' in existing:  # count-based suffix may already be taken
+                    suffix += 1
+                transform_config['TransformJobName'] = f'{job_name}-{suffix}'
+                self.log.info("Incremented transform job name to '%s'.", transform_config['TransformJobName'])
+            elif self.action_if_job_exists == 'fail':
+                raise AirflowException(f'A SageMaker transform job with name {job_name} already exists.')
+
 
 class SageMakerTuningOperator(SageMakerBaseOperator):
     """
@@ -605,7 +637,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
         already exists for the name in the config.
     :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
         (default) and "fail".
-        This is only relevant if check_if
+        This is only relevant if check_if_job_exists is True.
     :return Dict: Returns The ARN of the training job created in Amazon SageMaker.
     """
 
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index b622698478..871079b3d4 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -103,3 +103,60 @@ class TestSageMakerTransformOperator(unittest.TestCase):
         }
         with pytest.raises(AirflowException):
             self.sagemaker.execute(None)
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_transform_job')
+    def test_execute_with_check_if_job_exists(self, mock_transform, mock_client):
+        mock_transform.return_value = {
+            'TransformJobArn': 'test_arn',
+            'ResponseMetadata': {'HTTPStatusCode': 200},
+        }
+        self.sagemaker._check_if_transform_job_exists = mock.MagicMock()
+        self.sagemaker.execute(None)
+        self.sagemaker._check_if_transform_job_exists.assert_called_once()
+        mock_transform.assert_called_once_with(
+            CREATE_TRANSFORM_PARAMS,
+            wait_for_completion=False,
+            check_interval=5,
+            max_ingestion_time=None,
+        )
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_transform_job')
+    def test_execute_without_check_if_job_exists(self, mock_transform, mock_client):
+        mock_transform.return_value = {
+            'TransformJobArn': 'test_arn',
+            'ResponseMetadata': {'HTTPStatusCode': 200},
+        }
+        self.sagemaker.check_if_job_exists = False
+        self.sagemaker._check_if_transform_job_exists = mock.MagicMock()
+        self.sagemaker.execute(None)
+        self.sagemaker._check_if_transform_job_exists.assert_not_called()
+        mock_transform.assert_called_once_with(
+            CREATE_TRANSFORM_PARAMS,
+            wait_for_completion=False,
+            check_interval=5,
+            max_ingestion_time=None,
+        )
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'list_transform_jobs')
+    def test_check_if_job_exists_increment(self, mock_list_transform_jobs, mock_client):
+        self.sagemaker.check_if_job_exists = True
+        self.sagemaker.action_if_job_exists = 'increment'
+        mock_list_transform_jobs.return_value = [{'TransformJobName': 'job_name'}]
+        self.sagemaker._check_if_transform_job_exists()
+
+        # Build expected_config without mutating CONFIG: a shallow .copy() would alias its nested dict.
+        # Expect to see TransformJobName suffixed with "-2" because we return one existing job
+        expected_config = {**CONFIG, "Transform": {**CONFIG["Transform"], "TransformJobName": "job_name-2"}}
+        assert self.sagemaker.config == expected_config
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'list_transform_jobs')
+    def test_check_if_job_exists_fail(self, mock_list_transform_jobs, mock_client):
+        self.sagemaker.check_if_job_exists = True
+        self.sagemaker.action_if_job_exists = 'fail'
+        mock_list_transform_jobs.return_value = [{'TransformJobName': 'job_name'}]
+        with pytest.raises(AirflowException):
+            self.sagemaker._check_if_transform_job_exists()