You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by el...@apache.org on 2021/12/21 12:32:08 UTC
[airflow] branch main updated: Organize Sagemaker classes in Amazon provider (#20370)
This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d557965 Organize Sagemaker classes in Amazon provider (#20370)
d557965 is described below
commit d557965bab8e03d48922ed233ed3b9551cf65cec
Author: Bhavani Ravi <bh...@users.noreply.github.com>
AuthorDate: Tue Dec 21 18:01:42 2021 +0530
Organize Sagemaker classes in Amazon provider (#20370)
Organize Sagemaker classes in Amazon provider (#20370)
---
.../providers/amazon/aws/operators/sagemaker.py | 643 +++++++++++++++++++++
.../amazon/aws/operators/sagemaker_base.py | 95 +--
.../amazon/aws/operators/sagemaker_endpoint.py | 143 +----
.../aws/operators/sagemaker_endpoint_config.py | 39 +-
.../amazon/aws/operators/sagemaker_model.py | 43 +-
.../amazon/aws/operators/sagemaker_processing.py | 100 +---
.../amazon/aws/operators/sagemaker_training.py | 115 +---
.../amazon/aws/operators/sagemaker_transform.py | 109 +---
.../amazon/aws/operators/sagemaker_tuning.py | 82 +--
airflow/providers/amazon/aws/sensors/sagemaker.py | 268 +++++++++
.../providers/amazon/aws/sensors/sagemaker_base.py | 72 +--
.../amazon/aws/sensors/sagemaker_endpoint.py | 40 +-
.../amazon/aws/sensors/sagemaker_training.py | 89 +--
.../amazon/aws/sensors/sagemaker_transform.py | 41 +-
.../amazon/aws/sensors/sagemaker_tuning.py | 41 +-
airflow/providers/amazon/provider.yaml | 2 +
dev/provider_packages/prepare_provider_packages.py | 2 +
tests/deprecated_classes.py | 62 +-
.../amazon/aws/operators/test_sagemaker_base.py | 2 +-
.../aws/operators/test_sagemaker_endpoint.py | 2 +-
.../operators/test_sagemaker_endpoint_config.py | 2 +-
.../amazon/aws/operators/test_sagemaker_model.py | 2 +-
.../aws/operators/test_sagemaker_processing.py | 2 +-
.../aws/operators/test_sagemaker_training.py | 2 +-
.../aws/operators/test_sagemaker_transform.py | 2 +-
.../amazon/aws/operators/test_sagemaker_tuning.py | 2 +-
.../amazon/aws/sensors/test_sagemaker_base.py | 2 +-
.../amazon/aws/sensors/test_sagemaker_endpoint.py | 2 +-
.../amazon/aws/sensors/test_sagemaker_training.py | 2 +-
.../amazon/aws/sensors/test_sagemaker_transform.py | 2 +-
.../amazon/aws/sensors/test_sagemaker_tuning.py | 2 +-
31 files changed, 1085 insertions(+), 927 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
new file mode 100644
index 0000000..9eceec6
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -0,0 +1,643 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+import sys
+from typing import List, Optional
+
+from botocore.exceptions import ClientError
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+
+if sys.version_info >= (3, 8):
+ from functools import cached_property
+else:
+ from cached_property import cached_property
+
+
+class SageMakerBaseOperator(BaseOperator):
+ """This is the base operator for all SageMaker operators.
+
+ :param config: The configuration necessary to start a training job (templated)
+ :type config: dict
+ :param aws_conn_id: The AWS connection ID to use.
+ :type aws_conn_id: str
+ """
+
+ template_fields = ['config']
+ template_ext = ()
+ template_fields_renderers = {'config': 'json'}
+ ui_color = '#ededed'
+ integer_fields = []
+
+ def __init__(self, *, config: dict, aws_conn_id: str = 'aws_default', **kwargs):
+ super().__init__(**kwargs)
+ self.aws_conn_id = aws_conn_id
+ self.config = config
+
+ def parse_integer(self, config, field):
+ """Recursive method for parsing string fields holding integer values to integers."""
+ if len(field) == 1:
+ if isinstance(config, list):
+ for sub_config in config:
+ self.parse_integer(sub_config, field)
+ return
+ head = field[0]
+ if head in config:
+ config[head] = int(config[head])
+ return
+ if isinstance(config, list):
+ for sub_config in config:
+ self.parse_integer(sub_config, field)
+ return
+ (head, tail) = (field[0], field[1:])
+ if head in config:
+ self.parse_integer(config[head], tail)
+ return
+
+ def parse_config_integers(self):
+ """
+ Parse the integer fields of training config to integers in case the config is rendered by Jinja and
+ all fields are str
+ """
+ for field in self.integer_fields:
+ self.parse_integer(self.config, field)
+
+ def expand_role(self):
+ """Placeholder for calling boto3's `expand_role`, which expands an IAM role name into an ARN."""
+
+ def preprocess_config(self):
+ """Process the config into a usable form."""
+ self.log.info('Preprocessing the config and doing required s3_operations')
+ self.hook.configure_s3_resources(self.config)
+ self.parse_config_integers()
+ self.expand_role()
+ self.log.info(
+ 'After preprocessing the config is:\n %s',
+ json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': ')),
+ )
+
+ def execute(self, context):
+ raise NotImplementedError('Please implement execute() in sub class!')
+
+ @cached_property
+ def hook(self):
+ """Return SageMakerHook"""
+ return SageMakerHook(aws_conn_id=self.aws_conn_id)
+
+
+class SageMakerProcessingOperator(SageMakerBaseOperator):
+ """Initiate a SageMaker processing job.
+
+ This operator returns The ARN of the processing job created in Amazon SageMaker.
+
+ :param config: The configuration necessary to start a processing job (templated).
+
+ For details of the configuration parameter see :py:meth:`SageMaker.Client.create_processing_job`
+ :type config: dict
+ :param aws_conn_id: The AWS connection ID to use.
+ :type aws_conn_id: str
+ :param wait_for_completion: Whether the operator should wait until the processing job
+ finishes.
+ :type wait_for_completion: bool
+ :param print_log: if the operator should print the cloudwatch log during processing
+ :type print_log: bool
+ :param check_interval: if wait is set to be true, this is the time interval
+ in seconds which the operator will check the status of the processing job
+ :type check_interval: int
+ :param max_ingestion_time: If wait is set to True, the operation fails if the processing job
+ doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
+ the operation does not timeout.
+ :type max_ingestion_time: int
+ :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
+ (default) and "fail".
+ :type action_if_job_exists: str
+ """
+
+ def __init__(
+ self,
+ *,
+ config: dict,
+ aws_conn_id: str,
+ wait_for_completion: bool = True,
+ print_log: bool = True,
+ check_interval: int = 30,
+ max_ingestion_time: Optional[int] = None,
+ action_if_job_exists: str = 'increment',
+ **kwargs,
+ ):
+ super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+ if action_if_job_exists not in ('increment', 'fail'):
+ raise AirflowException(
+ f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \
+ Provided value: '{action_if_job_exists}'."
+ )
+ self.action_if_job_exists = action_if_job_exists
+ self.wait_for_completion = wait_for_completion
+ self.print_log = print_log
+ self.check_interval = check_interval
+ self.max_ingestion_time = max_ingestion_time
+ self._create_integer_fields()
+
+ def _create_integer_fields(self) -> None:
+ """Set fields which should be casted to integers."""
+ self.integer_fields = [
+ ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
+ ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
+ ]
+ if 'StoppingCondition' in self.config:
+ self.integer_fields += [['StoppingCondition', 'MaxRuntimeInSeconds']]
+
+ def expand_role(self) -> None:
+ if 'RoleArn' in self.config:
+ hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
+ self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
+
+ def execute(self, context) -> dict:
+ self.preprocess_config()
+ processing_job_name = self.config['ProcessingJobName']
+ if self.hook.find_processing_job_by_name(processing_job_name):
+ raise AirflowException(
+ f'A SageMaker processing job with name {processing_job_name} already exists.'
+ )
+ self.log.info('Creating SageMaker processing job %s.', self.config['ProcessingJobName'])
+ response = self.hook.create_processing_job(
+ self.config,
+ wait_for_completion=self.wait_for_completion,
+ check_interval=self.check_interval,
+ max_ingestion_time=self.max_ingestion_time,
+ )
+ if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+ raise AirflowException(f'Sagemaker Processing Job creation failed: {response}')
+ return {'Processing': self.hook.describe_processing_job(self.config['ProcessingJobName'])}
+
+
+class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
+ """
+ Create a SageMaker endpoint config.
+
+ This operator returns The ARN of the endpoint config created in Amazon SageMaker
+
+ :param config: The configuration necessary to create an endpoint config.
+
+ For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config`
+ :type config: dict
+ :param aws_conn_id: The AWS connection ID to use.
+ :type aws_conn_id: str
+ """
+
+ integer_fields = [['ProductionVariants', 'InitialInstanceCount']]
+
+ def __init__(self, *, config: dict, **kwargs):
+ super().__init__(config=config, **kwargs)
+ self.config = config
+
+ def execute(self, context) -> dict:
+ self.preprocess_config()
+ self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName'])
+ response = self.hook.create_endpoint_config(self.config)
+ if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+ raise AirflowException(f'Sagemaker endpoint config creation failed: {response}')
+ else:
+ return {'EndpointConfig': self.hook.describe_endpoint_config(self.config['EndpointConfigName'])}
+
+
+class SageMakerEndpointOperator(SageMakerBaseOperator):
+ """
+ Create a SageMaker endpoint.
+
+ This operator returns The ARN of the endpoint created in Amazon SageMaker
+
+ :param config:
+ The configuration necessary to create an endpoint.
+
+ If you need to create a SageMaker endpoint based on an existing
+ SageMaker model and an existing SageMaker endpoint config::
+
+ config = endpoint_configuration;
+
+ If you need to create all of SageMaker model, SageMaker endpoint-config and SageMaker endpoint::
+
+ config = {
+ 'Model': model_configuration,
+ 'EndpointConfig': endpoint_config_configuration,
+ 'Endpoint': endpoint_configuration
+ }
+
+ For details of the configuration parameter of model_configuration see
+ :py:meth:`SageMaker.Client.create_model`
+
+ For details of the configuration parameter of endpoint_config_configuration see
+ :py:meth:`SageMaker.Client.create_endpoint_config`
+
+ For details of the configuration parameter of endpoint_configuration see
+ :py:meth:`SageMaker.Client.create_endpoint`
+
+ :type config: dict
+ :param aws_conn_id: The AWS connection ID to use.
+ :type aws_conn_id: str
+ :param wait_for_completion: Whether the operator should wait until the endpoint creation finishes.
+ :type wait_for_completion: bool
+ :param check_interval: If wait is set to True, this is the time interval, in seconds, that this operation
+ waits before polling the status of the endpoint creation.
+ :type check_interval: int
+ :param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't
+ finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
+ :type max_ingestion_time: int
+ :param operation: Whether to create an endpoint or update an endpoint. Must be either 'create' or 'update'.
+ :type operation: str
+ """
+
+ def __init__(
+ self,
+ *,
+ config: dict,
+ wait_for_completion: bool = True,
+ check_interval: int = 30,
+ max_ingestion_time: Optional[int] = None,
+ operation: str = 'create',
+ **kwargs,
+ ):
+ super().__init__(config=config, **kwargs)
+ self.config = config
+ self.wait_for_completion = wait_for_completion
+ self.check_interval = check_interval
+ self.max_ingestion_time = max_ingestion_time
+ self.operation = operation.lower()
+ if self.operation not in ['create', 'update']:
+ raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
+ self.create_integer_fields()
+
+ def create_integer_fields(self) -> None:
+ """Set fields which should be casted to integers."""
+ if 'EndpointConfig' in self.config:
+ self.integer_fields = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']]
+
+ def expand_role(self) -> None:
+ if 'Model' not in self.config:
+ return
+ hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
+ config = self.config['Model']
+ if 'ExecutionRoleArn' in config:
+ config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
+
+ def execute(self, context) -> dict:
+ self.preprocess_config()
+ model_info = self.config.get('Model')
+ endpoint_config_info = self.config.get('EndpointConfig')
+ endpoint_info = self.config.get('Endpoint', self.config)
+ if model_info:
+ self.log.info('Creating SageMaker model %s.', model_info['ModelName'])
+ self.hook.create_model(model_info)
+ if endpoint_config_info:
+ self.log.info('Creating endpoint config %s.', endpoint_config_info['EndpointConfigName'])
+ self.hook.create_endpoint_config(endpoint_config_info)
+ if self.operation == 'create':
+ sagemaker_operation = self.hook.create_endpoint
+ log_str = 'Creating'
+ elif self.operation == 'update':
+ sagemaker_operation = self.hook.update_endpoint
+ log_str = 'Updating'
+ else:
+ raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
+ self.log.info('%s SageMaker endpoint %s.', log_str, endpoint_info['EndpointName'])
+ try:
+ response = sagemaker_operation(
+ endpoint_info,
+ wait_for_completion=self.wait_for_completion,
+ check_interval=self.check_interval,
+ max_ingestion_time=self.max_ingestion_time,
+ )
+ except ClientError:
+ self.operation = 'update'
+ sagemaker_operation = self.hook.update_endpoint
+ log_str = 'Updating'
+ response = sagemaker_operation(
+ endpoint_info,
+ wait_for_completion=self.wait_for_completion,
+ check_interval=self.check_interval,
+ max_ingestion_time=self.max_ingestion_time,
+ )
+ if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+ raise AirflowException(f'Sagemaker endpoint creation failed: {response}')
+ else:
+ return {
+ 'EndpointConfig': self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']),
+ 'Endpoint': self.hook.describe_endpoint(endpoint_info['EndpointName']),
+ }
+
+
+class SageMakerTransformOperator(SageMakerBaseOperator):
+ """Initiate a SageMaker transform job.
+
+ This operator returns a dict with the described model and transform job (keys 'Model' and 'Transform').
+
+ :param config: The configuration necessary to start a transform job (templated).
+
+ If you need to create a SageMaker transform job based on an existing SageMaker model::
+
+ config = transform_config
+
+ If you need to create both SageMaker model and SageMaker Transform job::
+
+ config = {
+ 'Model': model_config,
+ 'Transform': transform_config
+ }
+
+ For details of the configuration parameter of transform_config see
+ :py:meth:`SageMaker.Client.create_transform_job`
+
+ For details of the configuration parameter of model_config, See:
+ :py:meth:`SageMaker.Client.create_model`
+
+ :type config: dict
+ :param aws_conn_id: The AWS connection ID to use.
+ :type aws_conn_id: str
+ :param wait_for_completion: Set to True to wait until the transform job finishes.
+ :type wait_for_completion: bool
+ :param check_interval: If wait is set to True, the time interval, in seconds,
+ that this operation waits to check the status of the transform job.
+ :type check_interval: int
+ :param max_ingestion_time: If wait is set to True, the operation fails
+ if the transform job doesn't finish within max_ingestion_time seconds. If you
+ set this parameter to None, the operation does not timeout.
+ :type max_ingestion_time: int
+ """
+
+ def __init__(
+ self,
+ *,
+ config: dict,
+ wait_for_completion: bool = True,
+ check_interval: int = 30,
+ max_ingestion_time: Optional[int] = None,
+ **kwargs,
+ ):
+ super().__init__(config=config, **kwargs)
+ self.config = config
+ self.wait_for_completion = wait_for_completion
+ self.check_interval = check_interval
+ self.max_ingestion_time = max_ingestion_time
+ self.create_integer_fields()
+
+ def create_integer_fields(self) -> None:
+ """Set fields which should be casted to integers."""
+ self.integer_fields: List[List[str]] = [
+ ['Transform', 'TransformResources', 'InstanceCount'],
+ ['Transform', 'MaxConcurrentTransforms'],
+ ['Transform', 'MaxPayloadInMB'],
+ ]
+ if 'Transform' not in self.config:
+ for field in self.integer_fields:
+ field.pop(0)
+
+ def expand_role(self) -> None:
+ if 'Model' not in self.config:
+ return
+ config = self.config['Model']
+ if 'ExecutionRoleArn' in config:
+ hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
+ config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
+
+ def execute(self, context) -> dict:
+ self.preprocess_config()
+ model_config = self.config.get('Model')
+ transform_config = self.config.get('Transform', self.config)
+ if model_config:
+ self.log.info('Creating SageMaker Model %s for transform job', model_config['ModelName'])
+ self.hook.create_model(model_config)
+ self.log.info('Creating SageMaker transform Job %s.', transform_config['TransformJobName'])
+ response = self.hook.create_transform_job(
+ transform_config,
+ wait_for_completion=self.wait_for_completion,
+ check_interval=self.check_interval,
+ max_ingestion_time=self.max_ingestion_time,
+ )
+ if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+ raise AirflowException(f'Sagemaker transform Job creation failed: {response}')
+ else:
+ return {
+ 'Model': self.hook.describe_model(transform_config['ModelName']),
+ 'Transform': self.hook.describe_transform_job(transform_config['TransformJobName']),
+ }
+
+
+class SageMakerTuningOperator(SageMakerBaseOperator):
+ """Initiate a SageMaker hyperparameter tuning job.
+
+ This operator returns The ARN of the tuning job created in Amazon SageMaker.
+
+ :param config: The configuration necessary to start a tuning job (templated).
+
+ For details of the configuration parameter see
+ :py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job`
+ :type config: dict
+ :param aws_conn_id: The AWS connection ID to use.
+ :type aws_conn_id: str
+ :param wait_for_completion: Set to True to wait until the tuning job finishes.
+ :type wait_for_completion: bool
+ :param check_interval: If wait is set to True, the time interval, in seconds,
+ that this operation waits to check the status of the tuning job.
+ :type check_interval: int
+ :param max_ingestion_time: If wait is set to True, the operation fails
+ if the tuning job doesn't finish within max_ingestion_time seconds. If you
+ set this parameter to None, the operation does not timeout.
+ :type max_ingestion_time: int
+ """
+
+ integer_fields = [
+ ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
+ ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
+ ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
+ ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
+ ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'],
+ ]
+
+ def __init__(
+ self,
+ *,
+ config: dict,
+ wait_for_completion: bool = True,
+ check_interval: int = 30,
+ max_ingestion_time: Optional[int] = None,
+ **kwargs,
+ ):
+ super().__init__(config=config, **kwargs)
+ self.config = config
+ self.wait_for_completion = wait_for_completion
+ self.check_interval = check_interval
+ self.max_ingestion_time = max_ingestion_time
+
+ def expand_role(self) -> None:
+ if 'TrainingJobDefinition' in self.config:
+ config = self.config['TrainingJobDefinition']
+ if 'RoleArn' in config:
+ hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
+ config['RoleArn'] = hook.expand_role(config['RoleArn'])
+
+ def execute(self, context) -> dict:
+ self.preprocess_config()
+ self.log.info(
+ 'Creating SageMaker Hyper-Parameter Tuning Job %s', self.config['HyperParameterTuningJobName']
+ )
+ response = self.hook.create_tuning_job(
+ self.config,
+ wait_for_completion=self.wait_for_completion,
+ check_interval=self.check_interval,
+ max_ingestion_time=self.max_ingestion_time,
+ )
+ if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+ raise AirflowException(f'Sagemaker Tuning Job creation failed: {response}')
+ else:
+ return {'Tuning': self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])}
+
+
+class SageMakerModelOperator(SageMakerBaseOperator):
+ """Create a SageMaker model.
+
+ This operator returns The ARN of the model created in Amazon SageMaker
+
+ :param config: The configuration necessary to create a model.
+
+ For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model`
+ :type config: dict
+ :param aws_conn_id: The AWS connection ID to use.
+ :type aws_conn_id: str
+ """
+
+ def __init__(self, *, config, **kwargs):
+ super().__init__(config=config, **kwargs)
+ self.config = config
+
+ def expand_role(self) -> None:
+ if 'ExecutionRoleArn' in self.config:
+ hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
+ self.config['ExecutionRoleArn'] = hook.expand_role(self.config['ExecutionRoleArn'])
+
+ def execute(self, context) -> dict:
+ self.preprocess_config()
+ self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
+ response = self.hook.create_model(self.config)
+ if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+ raise AirflowException(f'Sagemaker model creation failed: {response}')
+ else:
+ return {'Model': self.hook.describe_model(self.config['ModelName'])}
+
+
+class SageMakerTrainingOperator(SageMakerBaseOperator):
+ """
+ Initiate a SageMaker training job.
+
+ This operator returns The ARN of the training job created in Amazon SageMaker.
+
+ :param config: The configuration necessary to start a training job (templated).
+
+ For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job`
+ :type config: dict
+ :param aws_conn_id: The AWS connection ID to use.
+ :type aws_conn_id: str
+ :param wait_for_completion: Whether the operator should wait until the training job
+ finishes.
+ :type wait_for_completion: bool
+ :param print_log: if the operator should print the cloudwatch log during training
+ :type print_log: bool
+ :param check_interval: if wait is set to be true, this is the time interval
+ in seconds which the operator will check the status of the training job
+ :type check_interval: int
+ :param max_ingestion_time: If wait is set to True, the operation fails if the training job
+ doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
+ the operation does not timeout.
+ :type max_ingestion_time: int
+ :param check_if_job_exists: If set to true, then the operator will check whether a training job
+ already exists for the name in the config.
+ :type check_if_job_exists: bool
+ :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
+ (default) and "fail".
+ This is only relevant if check_if_job_exists is True.
+ """
+
+ integer_fields = [
+ ['ResourceConfig', 'InstanceCount'],
+ ['ResourceConfig', 'VolumeSizeInGB'],
+ ['StoppingCondition', 'MaxRuntimeInSeconds'],
+ ]
+
+ def __init__(
+ self,
+ *,
+ config: dict,
+ wait_for_completion: bool = True,
+ print_log: bool = True,
+ check_interval: int = 30,
+ max_ingestion_time: Optional[int] = None,
+ check_if_job_exists: bool = True,
+ action_if_job_exists: str = 'increment',
+ **kwargs,
+ ):
+ super().__init__(config=config, **kwargs)
+ self.wait_for_completion = wait_for_completion
+ self.print_log = print_log
+ self.check_interval = check_interval
+ self.max_ingestion_time = max_ingestion_time
+ self.check_if_job_exists = check_if_job_exists
+ if action_if_job_exists in ('increment', 'fail'):
+ self.action_if_job_exists = action_if_job_exists
+ else:
+ raise AirflowException(
+ f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \
+ Provided value: '{action_if_job_exists}'."
+ )
+
+ def expand_role(self) -> None:
+ if 'RoleArn' in self.config:
+ hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
+ self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
+
+ def execute(self, context) -> dict:
+ self.preprocess_config()
+ if self.check_if_job_exists:
+ self._check_if_job_exists()
+ self.log.info('Creating SageMaker training job %s.', self.config['TrainingJobName'])
+ response = self.hook.create_training_job(
+ self.config,
+ wait_for_completion=self.wait_for_completion,
+ print_log=self.print_log,
+ check_interval=self.check_interval,
+ max_ingestion_time=self.max_ingestion_time,
+ )
+ if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+ raise AirflowException(f'Sagemaker Training Job creation failed: {response}')
+ else:
+ return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])}
+
+ def _check_if_job_exists(self) -> None:
+ training_job_name = self.config['TrainingJobName']
+ training_jobs = self.hook.list_training_jobs(name_contains=training_job_name)
+ if training_job_name in [tj['TrainingJobName'] for tj in training_jobs]:
+ if self.action_if_job_exists == 'increment':
+ self.log.info("Found existing training job with name '%s'.", training_job_name)
+ new_training_job_name = f'{training_job_name}-{(len(training_jobs) + 1)}'
+ self.config['TrainingJobName'] = new_training_job_name
+ self.log.info("Incremented training job name to '%s'.", new_training_job_name)
+ elif self.action_if_job_exists == 'fail':
+ raise AirflowException(
+ f'A SageMaker training job with name {training_job_name} already exists.'
+ )
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_base.py b/airflow/providers/amazon/aws/operators/sagemaker_base.py
index b91c7b4..22e44d7 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_base.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_base.py
@@ -16,93 +16,14 @@
# specific language governing permissions and limitations
# under the License.
-import json
-import sys
-from typing import Iterable
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
+import warnings
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerBaseOperator # noqa
-from airflow.models import BaseOperator
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-
-
-class SageMakerBaseOperator(BaseOperator):
- """
- This is the base operator for all SageMaker operators.
-
- :param config: The configuration necessary to start a training job (templated)
- :type config: dict
- :param aws_conn_id: The AWS connection ID to use.
- :type aws_conn_id: str
- """
-
- template_fields = ['config']
- template_ext = ()
- template_fields_renderers = {"config": "json"}
- ui_color = '#ededed'
-
- integer_fields = [] # type: Iterable[Iterable[str]]
-
- def __init__(self, *, config: dict, aws_conn_id: str = 'aws_default', **kwargs):
- super().__init__(**kwargs)
-
- self.aws_conn_id = aws_conn_id
- self.config = config
-
- def parse_integer(self, config, field):
- """Recursive method for parsing string fields holding integer values to integers."""
- if len(field) == 1:
- if isinstance(config, list):
- for sub_config in config:
- self.parse_integer(sub_config, field)
- return
- head = field[0]
- if head in config:
- config[head] = int(config[head])
- return
-
- if isinstance(config, list):
- for sub_config in config:
- self.parse_integer(sub_config, field)
- return
-
- head, tail = field[0], field[1:]
- if head in config:
- self.parse_integer(config[head], tail)
- return
-
- def parse_config_integers(self):
- """
- Parse the integer fields of training config to integers in case the config is rendered by Jinja and
- all fields are str.
- """
- for field in self.integer_fields:
- self.parse_integer(self.config, field)
-
- def expand_role(self):
- """Placeholder for calling boto3's `expand_role`, which expands an IAM role name into an ARN."""
-
- def preprocess_config(self):
- """Process the config into a usable form."""
- self.log.info('Preprocessing the config and doing required s3_operations')
-
- self.hook.configure_s3_resources(self.config)
- self.parse_config_integers()
- self.expand_role()
-
- self.log.info(
- "After preprocessing the config is:\n %s",
- json.dumps(self.config, sort_keys=True, indent=4, separators=(",", ": ")),
- )
-
- def execute(self, context):
- raise NotImplementedError('Please implement execute() in sub class!')
-
- @cached_property
- def hook(self):
- """Return SageMakerHook"""
- return SageMakerHook(aws_conn_id=self.aws_conn_id)
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py b/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
index 352c88d..5351431 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
@@ -15,142 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional
-from botocore.exceptions import ClientError
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
+import warnings
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator # noqa
-class SageMakerEndpointOperator(SageMakerBaseOperator):
- """
- Create a SageMaker endpoint.
-
- This operator returns The ARN of the endpoint created in Amazon SageMaker
-
- :param config:
- The configuration necessary to create an endpoint.
-
- If you need to create a SageMaker endpoint based on an existed
- SageMaker model and an existed SageMaker endpoint config::
-
- config = endpoint_configuration;
-
- If you need to create all of SageMaker model, SageMaker endpoint-config and SageMaker endpoint::
-
- config = {
- 'Model': model_configuration,
- 'EndpointConfig': endpoint_config_configuration,
- 'Endpoint': endpoint_configuration
- }
-
- For details of the configuration parameter of model_configuration see
- :py:meth:`SageMaker.Client.create_model`
-
- For details of the configuration parameter of endpoint_config_configuration see
- :py:meth:`SageMaker.Client.create_endpoint_config`
-
- For details of the configuration parameter of endpoint_configuration see
- :py:meth:`SageMaker.Client.create_endpoint`
-
- :type config: dict
- :param aws_conn_id: The AWS connection ID to use.
- :type aws_conn_id: str
- :param wait_for_completion: Whether the operator should wait until the endpoint creation finishes.
- :type wait_for_completion: bool
- :param check_interval: If wait is set to True, this is the time interval, in seconds, that this operation
- waits before polling the status of the endpoint creation.
- :type check_interval: int
- :param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't
- finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
- :type max_ingestion_time: int
- :param operation: Whether to create an endpoint or update an endpoint. Must be either 'create or 'update'.
- :type operation: str
- """
-
- def __init__(
- self,
- *,
- config: dict,
- wait_for_completion: bool = True,
- check_interval: int = 30,
- max_ingestion_time: Optional[int] = None,
- operation: str = 'create',
- **kwargs,
- ):
- super().__init__(config=config, **kwargs)
-
- self.config = config
- self.wait_for_completion = wait_for_completion
- self.check_interval = check_interval
- self.max_ingestion_time = max_ingestion_time
- self.operation = operation.lower()
- if self.operation not in ['create', 'update']:
- raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
- self.create_integer_fields()
-
- def create_integer_fields(self) -> None:
- """Set fields which should be casted to integers."""
- if 'EndpointConfig' in self.config:
- self.integer_fields = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']]
-
- def expand_role(self) -> None:
- if 'Model' not in self.config:
- return
- hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
- config = self.config['Model']
- if 'ExecutionRoleArn' in config:
- config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
-
- def execute(self, context) -> dict:
- self.preprocess_config()
-
- model_info = self.config.get('Model')
- endpoint_config_info = self.config.get('EndpointConfig')
- endpoint_info = self.config.get('Endpoint', self.config)
-
- if model_info:
- self.log.info('Creating SageMaker model %s.', model_info['ModelName'])
- self.hook.create_model(model_info)
-
- if endpoint_config_info:
- self.log.info('Creating endpoint config %s.', endpoint_config_info['EndpointConfigName'])
- self.hook.create_endpoint_config(endpoint_config_info)
-
- if self.operation == 'create':
- sagemaker_operation = self.hook.create_endpoint
- log_str = 'Creating'
- elif self.operation == 'update':
- sagemaker_operation = self.hook.update_endpoint
- log_str = 'Updating'
- else:
- raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
-
- self.log.info('%s SageMaker endpoint %s.', log_str, endpoint_info['EndpointName'])
- try:
- response = sagemaker_operation(
- endpoint_info,
- wait_for_completion=self.wait_for_completion,
- check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time,
- )
- except ClientError: # Botocore throws a ClientError if the endpoint is already created
- self.operation = 'update'
- sagemaker_operation = self.hook.update_endpoint
- log_str = 'Updating'
- response = sagemaker_operation(
- endpoint_info,
- wait_for_completion=self.wait_for_completion,
- check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time,
- )
-
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker endpoint creation failed: {response}')
- else:
- return {
- 'EndpointConfig': self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']),
- 'Endpoint': self.hook.describe_endpoint(endpoint_info['EndpointName']),
- }
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py b/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
index 448a1a5..737e5b6 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
@@ -16,37 +16,14 @@
# specific language governing permissions and limitations
# under the License.
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
+import warnings
-class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
- """
- Create a SageMaker endpoint config.
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointConfigOperator # noqa
- This operator returns The ARN of the endpoint config created in Amazon SageMaker
-
- :param config: The configuration necessary to create an endpoint config.
-
- For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config`
- :type config: dict
- :param aws_conn_id: The AWS connection ID to use.
- :type aws_conn_id: str
- """
-
- integer_fields = [['ProductionVariants', 'InitialInstanceCount']]
-
- def __init__(self, *, config: dict, **kwargs):
- super().__init__(config=config, **kwargs)
-
- self.config = config
-
- def execute(self, context) -> dict:
- self.preprocess_config()
-
- self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName'])
- response = self.hook.create_endpoint_config(self.config)
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker endpoint config creation failed: {response}')
- else:
- return {'EndpointConfig': self.hook.describe_endpoint_config(self.config['EndpointConfigName'])}
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_model.py b/airflow/providers/amazon/aws/operators/sagemaker_model.py
index c29f8a9..fffe8d9 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_model.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_model.py
@@ -16,41 +16,14 @@
# specific language governing permissions and limitations
# under the License.
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
+import warnings
-class SageMakerModelOperator(SageMakerBaseOperator):
- """
- Create a SageMaker model.
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerModelOperator # noqa
- This operator returns The ARN of the model created in Amazon SageMaker
-
- :param config: The configuration necessary to create a model.
-
- For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model`
- :type config: dict
- :param aws_conn_id: The AWS connection ID to use.
- :type aws_conn_id: str
- """
-
- def __init__(self, *, config, **kwargs):
- super().__init__(config=config, **kwargs)
-
- self.config = config
-
- def expand_role(self) -> None:
- if 'ExecutionRoleArn' in self.config:
- hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
- self.config['ExecutionRoleArn'] = hook.expand_role(self.config['ExecutionRoleArn'])
-
- def execute(self, context) -> dict:
- self.preprocess_config()
-
- self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
- response = self.hook.create_model(self.config)
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker model creation failed: {response}')
- else:
- return {'Model': self.hook.describe_model(self.config['ModelName'])}
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_processing.py b/airflow/providers/amazon/aws/operators/sagemaker_processing.py
index 529fb5d..b3a4be8 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_processing.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_processing.py
@@ -15,99 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
+import warnings
-class SageMakerProcessingOperator(SageMakerBaseOperator):
- """
- Initiate a SageMaker processing job.
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerProcessingOperator # noqa
- This operator returns The ARN of the processing job created in Amazon SageMaker.
-
- :param config: The configuration necessary to start a processing job (templated).
-
- For details of the configuration parameter see :py:meth:`SageMaker.Client.create_processing_job`
- :type config: dict
- :param aws_conn_id: The AWS connection ID to use.
- :type aws_conn_id: str
- :param wait_for_completion: If wait is set to True, the time interval, in seconds,
- that the operation waits to check the status of the processing job.
- :type wait_for_completion: bool
- :param print_log: if the operator should print the cloudwatch log during processing
- :type print_log: bool
- :param check_interval: if wait is set to be true, this is the time interval
- in seconds which the operator will check the status of the processing job
- :type check_interval: int
- :param max_ingestion_time: If wait is set to True, the operation fails if the processing job
- doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
- the operation does not timeout.
- :type max_ingestion_time: int
- :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
- (default) and "fail".
- :type action_if_job_exists: str
- """
-
- def __init__(
- self,
- *,
- config: dict,
- aws_conn_id: str,
- wait_for_completion: bool = True,
- print_log: bool = True,
- check_interval: int = 30,
- max_ingestion_time: Optional[int] = None,
- action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8
- **kwargs,
- ):
- super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
-
- if action_if_job_exists not in ("increment", "fail"):
- raise AirflowException(
- "Argument action_if_job_exists accepts only 'increment' and 'fail'. "
- f"Provided value: '{action_if_job_exists}'."
- )
- self.action_if_job_exists = action_if_job_exists
- self.wait_for_completion = wait_for_completion
- self.print_log = print_log
- self.check_interval = check_interval
- self.max_ingestion_time = max_ingestion_time
- self._create_integer_fields()
-
- def _create_integer_fields(self) -> None:
- """Set fields which should be casted to integers."""
- self.integer_fields = [
- ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
- ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
- ]
- if 'StoppingCondition' in self.config:
- self.integer_fields += [['StoppingCondition', 'MaxRuntimeInSeconds']]
-
- def expand_role(self) -> None:
- if 'RoleArn' in self.config:
- hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
- self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
-
- def execute(self, context) -> dict:
- self.preprocess_config()
-
- processing_job_name = self.config["ProcessingJobName"]
-
- if self.hook.find_processing_job_by_name(processing_job_name):
- raise AirflowException(
- f"A SageMaker processing job with name {processing_job_name} already exists."
- )
-
- self.log.info("Creating SageMaker processing job %s.", self.config["ProcessingJobName"])
- response = self.hook.create_processing_job(
- self.config,
- wait_for_completion=self.wait_for_completion,
- check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time,
- )
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker Processing Job creation failed: {response}')
- return {'Processing': self.hook.describe_processing_job(self.config['ProcessingJobName'])}
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_training.py b/airflow/providers/amazon/aws/operators/sagemaker_training.py
index db60bde..40f13b4 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_training.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_training.py
@@ -15,114 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
+import warnings
-class SageMakerTrainingOperator(SageMakerBaseOperator):
- """
- Initiate a SageMaker training job.
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator # noqa
- This operator returns The ARN of the training job created in Amazon SageMaker.
-
- :param config: The configuration necessary to start a training job (templated).
-
- For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job`
- :type config: dict
- :param aws_conn_id: The AWS connection ID to use.
- :type aws_conn_id: str
- :param wait_for_completion: If wait is set to True, the time interval, in seconds,
- that the operation waits to check the status of the training job.
- :type wait_for_completion: bool
- :param print_log: if the operator should print the cloudwatch log during training
- :type print_log: bool
- :param check_interval: if wait is set to be true, this is the time interval
- in seconds which the operator will check the status of the training job
- :type check_interval: int
- :param max_ingestion_time: If wait is set to True, the operation fails if the training job
- doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
- the operation does not timeout.
- :type max_ingestion_time: int
- :param check_if_job_exists: If set to true, then the operator will check whether a training job
- already exists for the name in the config.
- :type check_if_job_exists: bool
- :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
- (default) and "fail".
- This is only relevant if check_if_job_exists is True.
- :type action_if_job_exists: str
- """
-
- integer_fields = [
- ['ResourceConfig', 'InstanceCount'],
- ['ResourceConfig', 'VolumeSizeInGB'],
- ['StoppingCondition', 'MaxRuntimeInSeconds'],
- ]
-
- def __init__(
- self,
- *,
- config: dict,
- wait_for_completion: bool = True,
- print_log: bool = True,
- check_interval: int = 30,
- max_ingestion_time: Optional[int] = None,
- check_if_job_exists: bool = True,
- action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8
- **kwargs,
- ):
- super().__init__(config=config, **kwargs)
-
- self.wait_for_completion = wait_for_completion
- self.print_log = print_log
- self.check_interval = check_interval
- self.max_ingestion_time = max_ingestion_time
- self.check_if_job_exists = check_if_job_exists
-
- if action_if_job_exists in ("increment", "fail"):
- self.action_if_job_exists = action_if_job_exists
- else:
- raise AirflowException(
- "Argument action_if_job_exists accepts only 'increment' and 'fail'. "
- f"Provided value: '{action_if_job_exists}'."
- )
-
- def expand_role(self) -> None:
- if 'RoleArn' in self.config:
- hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
- self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
-
- def execute(self, context) -> dict:
- self.preprocess_config()
- if self.check_if_job_exists:
- self._check_if_job_exists()
- self.log.info("Creating SageMaker training job %s.", self.config["TrainingJobName"])
- response = self.hook.create_training_job(
- self.config,
- wait_for_completion=self.wait_for_completion,
- print_log=self.print_log,
- check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time,
- )
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker Training Job creation failed: {response}')
- else:
- return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])}
-
- def _check_if_job_exists(self) -> None:
- training_job_name = self.config["TrainingJobName"]
- training_jobs = self.hook.list_training_jobs(name_contains=training_job_name)
-
- # Check if given TrainingJobName already exists
- if training_job_name in [tj["TrainingJobName"] for tj in training_jobs]:
- if self.action_if_job_exists == "increment":
- self.log.info("Found existing training job with name '%s'.", training_job_name)
- new_training_job_name = f"{training_job_name}-{len(training_jobs) + 1}"
- self.config["TrainingJobName"] = new_training_job_name
- self.log.info("Incremented training job name to '%s'.", new_training_job_name)
- elif self.action_if_job_exists == "fail":
- raise AirflowException(
- f"A SageMaker training job with name {training_job_name} already exists."
- )
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_transform.py b/airflow/providers/amazon/aws/operators/sagemaker_transform.py
index 7449e86..1e833ed 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_transform.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_transform.py
@@ -15,108 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import List, Optional
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
+import warnings
-class SageMakerTransformOperator(SageMakerBaseOperator):
- """
- Initiate a SageMaker transform job.
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator # noqa
- This operator returns The ARN of the model created in Amazon SageMaker.
-
- :param config: The configuration necessary to start a transform job (templated).
-
- If you need to create a SageMaker transform job based on an existed SageMaker model::
-
- config = transform_config
-
- If you need to create both SageMaker model and SageMaker Transform job::
-
- config = {
- 'Model': model_config,
- 'Transform': transform_config
- }
-
- For details of the configuration parameter of transform_config see
- :py:meth:`SageMaker.Client.create_transform_job`
-
- For details of the configuration parameter of model_config, See:
- :py:meth:`SageMaker.Client.create_model`
-
- :type config: dict
- :param aws_conn_id: The AWS connection ID to use.
- :type aws_conn_id: str
- :param wait_for_completion: Set to True to wait until the transform job finishes.
- :type wait_for_completion: bool
- :param check_interval: If wait is set to True, the time interval, in seconds,
- that this operation waits to check the status of the transform job.
- :type check_interval: int
- :param max_ingestion_time: If wait is set to True, the operation fails
- if the transform job doesn't finish within max_ingestion_time seconds. If you
- set this parameter to None, the operation does not timeout.
- :type max_ingestion_time: int
- """
-
- def __init__(
- self,
- *,
- config: dict,
- wait_for_completion: bool = True,
- check_interval: int = 30,
- max_ingestion_time: Optional[int] = None,
- **kwargs,
- ):
- super().__init__(config=config, **kwargs)
- self.config = config
- self.wait_for_completion = wait_for_completion
- self.check_interval = check_interval
- self.max_ingestion_time = max_ingestion_time
- self.create_integer_fields()
-
- def create_integer_fields(self) -> None:
- """Set fields which should be casted to integers."""
- self.integer_fields: List[List[str]] = [
- ['Transform', 'TransformResources', 'InstanceCount'],
- ['Transform', 'MaxConcurrentTransforms'],
- ['Transform', 'MaxPayloadInMB'],
- ]
- if 'Transform' not in self.config:
- for field in self.integer_fields:
- field.pop(0)
-
- def expand_role(self) -> None:
- if 'Model' not in self.config:
- return
- config = self.config['Model']
- if 'ExecutionRoleArn' in config:
- hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
- config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
-
- def execute(self, context) -> dict:
- self.preprocess_config()
-
- model_config = self.config.get('Model')
- transform_config = self.config.get('Transform', self.config)
-
- if model_config:
- self.log.info('Creating SageMaker Model %s for transform job', model_config['ModelName'])
- self.hook.create_model(model_config)
-
- self.log.info('Creating SageMaker transform Job %s.', transform_config['TransformJobName'])
- response = self.hook.create_transform_job(
- transform_config,
- wait_for_completion=self.wait_for_completion,
- check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time,
- )
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker transform Job creation failed: {response}')
- else:
- return {
- 'Model': self.hook.describe_model(transform_config['ModelName']),
- 'Transform': self.hook.describe_transform_job(transform_config['TransformJobName']),
- }
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py b/airflow/providers/amazon/aws/operators/sagemaker_tuning.py
index 2bb3857..18a8263 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_tuning.py
@@ -15,81 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
+import warnings
-class SageMakerTuningOperator(SageMakerBaseOperator):
- """
- Initiate a SageMaker hyperparameter tuning job.
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator # noqa
- This operator returns The ARN of the tuning job created in Amazon SageMaker.
-
- :param config: The configuration necessary to start a tuning job (templated).
-
- For details of the configuration parameter see
- :py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job`
- :type config: dict
- :param aws_conn_id: The AWS connection ID to use.
- :type aws_conn_id: str
- :param wait_for_completion: Set to True to wait until the tuning job finishes.
- :type wait_for_completion: bool
- :param check_interval: If wait is set to True, the time interval, in seconds,
- that this operation waits to check the status of the tuning job.
- :type check_interval: int
- :param max_ingestion_time: If wait is set to True, the operation fails
- if the tuning job doesn't finish within max_ingestion_time seconds. If you
- set this parameter to None, the operation does not timeout.
- :type max_ingestion_time: int
- """
-
- integer_fields = [
- ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
- ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
- ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
- ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
- ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'],
- ]
-
- def __init__(
- self,
- *,
- config: dict,
- wait_for_completion: bool = True,
- check_interval: int = 30,
- max_ingestion_time: Optional[int] = None,
- **kwargs,
- ):
- super().__init__(config=config, **kwargs)
- self.config = config
- self.wait_for_completion = wait_for_completion
- self.check_interval = check_interval
- self.max_ingestion_time = max_ingestion_time
-
- def expand_role(self) -> None:
- if 'TrainingJobDefinition' in self.config:
- config = self.config['TrainingJobDefinition']
- if 'RoleArn' in config:
- hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
- config['RoleArn'] = hook.expand_role(config['RoleArn'])
-
- def execute(self, context) -> dict:
- self.preprocess_config()
-
- self.log.info(
- 'Creating SageMaker Hyper-Parameter Tuning Job %s', self.config['HyperParameterTuningJobName']
- )
-
- response = self.hook.create_tuning_job(
- self.config,
- wait_for_completion=self.wait_for_completion,
- check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time,
- )
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker Tuning Job creation failed: {response}')
- else:
- return {'Tuning': self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])}
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py
new file mode 100644
index 0000000..c7370fb
--- /dev/null
+++ b/airflow/providers/amazon/aws/sensors/sagemaker.py
@@ -0,0 +1,268 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import time
+from typing import Optional, Set
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
+from airflow.sensors.base import BaseSensorOperator
+
+
+class SageMakerBaseSensor(BaseSensorOperator):
+ """Contains general sensor behavior for SageMaker.
+
+ Subclasses should implement get_sagemaker_response()
+ and state_from_response() methods.
+ Subclasses should also implement the non_terminal_states() and failed_states() methods.
+ """
+
+ ui_color = '#ededed'
+
+ def __init__(self, *, aws_conn_id: str = 'aws_default', **kwargs):
+ super().__init__(**kwargs)
+ self.aws_conn_id = aws_conn_id
+ self.hook: Optional[SageMakerHook] = None
+
+ def get_hook(self) -> SageMakerHook:
+ """Get SageMakerHook."""
+ if self.hook:
+ return self.hook
+ self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
+ return self.hook
+
+ def poke(self, context):
+ response = self.get_sagemaker_response()
+ if not (response['ResponseMetadata']['HTTPStatusCode'] == 200):
+ self.log.info('Bad HTTP response: %s', response)
+ return False
+ state = self.state_from_response(response)
+ self.log.info('Job currently %s', state)
+ if state in self.non_terminal_states():
+ return False
+ if state in self.failed_states():
+ failed_reason = self.get_failed_reason_from_response(response)
+ raise AirflowException(f'Sagemaker job failed for the following reason: {failed_reason}')
+ return True
+
+ def non_terminal_states(self) -> Set[str]:
+ """Placeholder for returning states which should not terminate."""
+ raise NotImplementedError('Please implement non_terminal_states() in subclass')
+
+ def failed_states(self) -> Set[str]:
+ """Placeholder for returning states which are considered failed."""
+ raise NotImplementedError('Please implement failed_states() in subclass')
+
+ def get_sagemaker_response(self) -> Optional[dict]:
+ """Placeholder for checking status of a SageMaker task."""
+ raise NotImplementedError('Please implement get_sagemaker_response() in subclass')
+
+ def get_failed_reason_from_response(self, response: dict) -> str:
+ """Placeholder for extracting the reason for failure from an AWS response."""
+ return 'Unknown'
+
+ def state_from_response(self, response: dict) -> str:
+ """Placeholder for extracting the state from an AWS response."""
+ raise NotImplementedError('Please implement state_from_response() in subclass')
+
+
+class SageMakerEndpointSensor(SageMakerBaseSensor):
+ """Asks for the state of the endpoint until it reaches a
+ terminal state.
+ If the endpoint fails, the sensor raises an error and the task fails.
+
+
+ :param endpoint_name: name of the endpoint instance to check the state of
+
+ :type endpoint_name: str
+ """
+
+ template_fields = ['endpoint_name']
+ template_ext = ()
+
+ def __init__(self, *, endpoint_name, **kwargs):
+ super().__init__(**kwargs)
+ self.endpoint_name = endpoint_name
+
+ def non_terminal_states(self):
+ return SageMakerHook.endpoint_non_terminal_states
+
+ def failed_states(self):
+ return SageMakerHook.failed_states
+
+ def get_sagemaker_response(self):
+ self.log.info('Poking Sagemaker Endpoint %s', self.endpoint_name)
+ return self.get_hook().describe_endpoint(self.endpoint_name)
+
+ def get_failed_reason_from_response(self, response):
+ return response['FailureReason']
+
+ def state_from_response(self, response):
+ return response['EndpointStatus']
+
+
+class SageMakerTransformSensor(SageMakerBaseSensor):
+ """Asks for the state of the transform job until it reaches a
+ terminal state.
+ The sensor will error if the job errors, throwing an
+ AirflowException
+ containing the failure reason.
+
+ :param
+ job_name: job_name of the transform job instance to check the state of
+
+ :type job_name: str
+ """
+
+ template_fields = ['job_name']
+ template_ext = ()
+
+ def __init__(self, *, job_name: str, **kwargs):
+ super().__init__(**kwargs)
+ self.job_name = job_name
+
+ def non_terminal_states(self):
+ return SageMakerHook.non_terminal_states
+
+ def failed_states(self):
+ return SageMakerHook.failed_states
+
+ def get_sagemaker_response(self):
+ self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
+ return self.get_hook().describe_transform_job(self.job_name)
+
+ def get_failed_reason_from_response(self, response):
+ return response['FailureReason']
+
+ def state_from_response(self, response):
+ return response['TransformJobStatus']
+
+
+class SageMakerTuningSensor(SageMakerBaseSensor):
+ """Asks for the state of the tuning job until it reaches a terminal
+ state.
+ The sensor will error if the job errors, throwing an
+ AirflowException
+ containing the failure reason.
+
+ :param
+ job_name: job_name of the tuning instance to check the state of
+ :type
+ job_name: str
+ """
+
+ template_fields = ['job_name']
+ template_ext = ()
+
+ def __init__(self, *, job_name: str, **kwargs):
+ super().__init__(**kwargs)
+ self.job_name = job_name
+
+ def non_terminal_states(self):
+ return SageMakerHook.non_terminal_states
+
+ def failed_states(self):
+ return SageMakerHook.failed_states
+
+ def get_sagemaker_response(self):
+ self.log.info('Poking Sagemaker Tuning Job %s', self.job_name)
+ return self.get_hook().describe_tuning_job(self.job_name)
+
+ def get_failed_reason_from_response(self, response):
+ return response['FailureReason']
+
+ def state_from_response(self, response):
+ return response['HyperParameterTuningJobStatus']
+
+
+class SageMakerTrainingSensor(SageMakerBaseSensor):
+ """Asks for the state of the training job until it reaches a
+ terminal state.
+ If the job fails, the sensor raises an error, failing the task.
+
+
+ :param job_name: name of the SageMaker training job to check the state of
+
+ :type job_name: str
+ :param print_log: if the sensor should print the CloudWatch log
+ :type print_log: bool
+ """
+
+ template_fields = ['job_name']
+ template_ext = ()
+
+ def __init__(self, *, job_name, print_log=True, **kwargs):
+ super().__init__(**kwargs)
+ self.job_name = job_name
+ self.print_log = print_log
+ self.positions = {}
+ self.stream_names = []
+ self.instance_count: Optional[int] = None
+ self.state: Optional[int] = None
+ self.last_description = None
+ self.last_describe_job_call = None
+ self.log_resource_inited = False
+
+ def init_log_resource(self, hook: SageMakerHook) -> None:
+ """Set tailing LogState for associated training job."""
+ description = hook.describe_training_job(self.job_name)
+ self.instance_count = description['ResourceConfig']['InstanceCount']
+ status = description['TrainingJobStatus']
+ job_already_completed = status not in self.non_terminal_states()
+ self.state = LogState.TAILING if (not job_already_completed) else LogState.COMPLETE
+ self.last_description = description
+ self.last_describe_job_call = time.monotonic()
+ self.log_resource_inited = True
+
+ def non_terminal_states(self):
+ return SageMakerHook.non_terminal_states
+
+ def failed_states(self):
+ return SageMakerHook.failed_states
+
+ def get_sagemaker_response(self):
+ if self.print_log:
+ if not self.log_resource_inited:
+ self.init_log_resource(self.get_hook())
+ (
+ self.state,
+ self.last_description,
+ self.last_describe_job_call,
+ ) = self.get_hook().describe_training_job_with_log(
+ self.job_name,
+ self.positions,
+ self.stream_names,
+ self.instance_count,
+ self.state,
+ self.last_description,
+ self.last_describe_job_call,
+ )
+ else:
+ self.last_description = self.get_hook().describe_training_job(self.job_name)
+ status = self.state_from_response(self.last_description)
+ if (status not in self.non_terminal_states()) and (status not in self.failed_states()):
+ billable_time = (
+ self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime']
+ ) * self.last_description['ResourceConfig']['InstanceCount']
+ self.log.info('Billable seconds: %s', (int(billable_time.total_seconds()) + 1))
+ return self.last_description
+
+ def get_failed_reason_from_response(self, response):
+ return response['FailureReason']
+
+ def state_from_response(self, response):
+ return response['TrainingJobStatus']
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_base.py b/airflow/providers/amazon/aws/sensors/sagemaker_base.py
index 8a0956e..102c410 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_base.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_base.py
@@ -15,71 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional, Set
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.sensors.base import BaseSensorOperator
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`."""
+import warnings
-class SageMakerBaseSensor(BaseSensorOperator):
- """
- Contains general sensor behavior for SageMaker.
- Subclasses should implement get_sagemaker_response()
- and state_from_response() methods.
- Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods.
- """
+from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerBaseSensor # noqa
- ui_color = '#ededed'
-
- def __init__(self, *, aws_conn_id: str = 'aws_default', **kwargs):
- super().__init__(**kwargs)
- self.aws_conn_id = aws_conn_id
- self.hook: Optional[SageMakerHook] = None
-
- def get_hook(self) -> SageMakerHook:
- """Get SageMakerHook"""
- if self.hook:
- return self.hook
-
- self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
- return self.hook
-
- def poke(self, context):
- response = self.get_sagemaker_response()
-
- if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
- self.log.info('Bad HTTP response: %s', response)
- return False
-
- state = self.state_from_response(response)
-
- self.log.info('Job currently %s', state)
-
- if state in self.non_terminal_states():
- return False
-
- if state in self.failed_states():
- failed_reason = self.get_failed_reason_from_response(response)
- raise AirflowException(f'Sagemaker job failed for the following reason: {failed_reason}')
- return True
-
- def non_terminal_states(self) -> Set[str]:
- """Placeholder for returning states with should not terminate."""
- raise NotImplementedError('Please implement non_terminal_states() in subclass')
-
- def failed_states(self) -> Set[str]:
- """Placeholder for returning states with are considered failed."""
- raise NotImplementedError('Please implement failed_states() in subclass')
-
- def get_sagemaker_response(self) -> Optional[dict]:
- """Placeholder for checking status of a SageMaker task."""
- raise NotImplementedError('Please implement get_sagemaker_response() in subclass')
-
- def get_failed_reason_from_response(self, response: dict) -> str:
- """Placeholder for extracting the reason for failure from an AWS response."""
- return 'Unknown'
-
- def state_from_response(self, response: dict) -> str:
- """Placeholder for extracting the state from an AWS response."""
- raise NotImplementedError('Please implement state_from_response() in subclass')
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py b/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py
index bb3885f..00ed844 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py
@@ -16,38 +16,14 @@
# specific language governing permissions and limitations
# under the License.
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`."""
+import warnings
-class SageMakerEndpointSensor(SageMakerBaseSensor):
- """
- Asks for the state of the endpoint state until it reaches a terminal state.
- If it fails the sensor errors, the task fails.
+from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerEndpointSensor # noqa
- :param job_name: job_name of the endpoint instance to check the state of
- :type job_name: str
- """
-
- template_fields = ['endpoint_name']
- template_ext = ()
-
- def __init__(self, *, endpoint_name, **kwargs):
- super().__init__(**kwargs)
- self.endpoint_name = endpoint_name
-
- def non_terminal_states(self):
- return SageMakerHook.endpoint_non_terminal_states
-
- def failed_states(self):
- return SageMakerHook.failed_states
-
- def get_sagemaker_response(self):
- self.log.info('Poking Sagemaker Endpoint %s', self.endpoint_name)
- return self.get_hook().describe_endpoint(self.endpoint_name)
-
- def get_failed_reason_from_response(self, response):
- return response['FailureReason']
-
- def state_from_response(self, response):
- return response['EndpointStatus']
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_training.py b/airflow/providers/amazon/aws/sensors/sagemaker_training.py
index 12e1264..d194996 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_training.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_training.py
@@ -15,88 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import time
-from typing import Optional
-from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
-from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`."""
+import warnings
-class SageMakerTrainingSensor(SageMakerBaseSensor):
- """
- Asks for the state of the training state until it reaches a terminal state.
- If it fails the sensor errors, failing the task.
+from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTrainingSensor # noqa
- :param job_name: name of the SageMaker training job to check the state of
- :type job_name: str
- :param print_log: if the operator should print the cloudwatch log
- :type print_log: bool
- """
-
- template_fields = ['job_name']
- template_ext = ()
-
- def __init__(self, *, job_name, print_log=True, **kwargs):
- super().__init__(**kwargs)
- self.job_name = job_name
- self.print_log = print_log
- self.positions = {}
- self.stream_names = []
- self.instance_count: Optional[int] = None
- self.state: Optional[int] = None
- self.last_description = None
- self.last_describe_job_call = None
- self.log_resource_inited = False
-
- def init_log_resource(self, hook: SageMakerHook) -> None:
- """Set tailing LogState for associated training job."""
- description = hook.describe_training_job(self.job_name)
- self.instance_count = description['ResourceConfig']['InstanceCount']
-
- status = description['TrainingJobStatus']
- job_already_completed = status not in self.non_terminal_states()
- self.state = LogState.TAILING if not job_already_completed else LogState.COMPLETE
- self.last_description = description
- self.last_describe_job_call = time.monotonic()
- self.log_resource_inited = True
-
- def non_terminal_states(self):
- return SageMakerHook.non_terminal_states
-
- def failed_states(self):
- return SageMakerHook.failed_states
-
- def get_sagemaker_response(self):
- if self.print_log:
- if not self.log_resource_inited:
- self.init_log_resource(self.get_hook())
- (
- self.state,
- self.last_description,
- self.last_describe_job_call,
- ) = self.get_hook().describe_training_job_with_log(
- self.job_name,
- self.positions,
- self.stream_names,
- self.instance_count,
- self.state,
- self.last_description,
- self.last_describe_job_call,
- )
- else:
- self.last_description = self.get_hook().describe_training_job(self.job_name)
-
- status = self.state_from_response(self.last_description)
- if status not in self.non_terminal_states() and status not in self.failed_states():
- billable_time = (
- self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime']
- ) * self.last_description['ResourceConfig']['InstanceCount']
- self.log.info('Billable seconds: %s', int(billable_time.total_seconds()) + 1)
-
- return self.last_description
-
- def get_failed_reason_from_response(self, response):
- return response['FailureReason']
-
- def state_from_response(self, response):
- return response['TrainingJobStatus']
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py b/airflow/providers/amazon/aws/sensors/sagemaker_transform.py
index 6e03066..7a48f3e 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_transform.py
@@ -16,39 +16,14 @@
# specific language governing permissions and limitations
# under the License.
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`."""
+import warnings
-class SageMakerTransformSensor(SageMakerBaseSensor):
- """
- Asks for the state of the transform state until it reaches a terminal state.
- The sensor will error if the job errors, throwing a AirflowException
- containing the failure reason.
+from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTransformSensor # noqa
- :param job_name: job_name of the transform job instance to check the state of
- :type job_name: str
- """
-
- template_fields = ['job_name']
- template_ext = ()
-
- def __init__(self, *, job_name: str, **kwargs):
- super().__init__(**kwargs)
- self.job_name = job_name
-
- def non_terminal_states(self):
- return SageMakerHook.non_terminal_states
-
- def failed_states(self):
- return SageMakerHook.failed_states
-
- def get_sagemaker_response(self):
- self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
- return self.get_hook().describe_transform_job(self.job_name)
-
- def get_failed_reason_from_response(self, response):
- return response['FailureReason']
-
- def state_from_response(self, response):
- return response['TransformJobStatus']
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py b/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
index 9f05fa8..d5f0d90 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
@@ -16,39 +16,14 @@
# specific language governing permissions and limitations
# under the License.
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`."""
+import warnings
-class SageMakerTuningSensor(SageMakerBaseSensor):
- """
- Asks for the state of the tuning state until it reaches a terminal state.
- The sensor will error if the job errors, throwing a AirflowException
- containing the failure reason.
+from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTuningSensor # noqa
- :param job_name: job_name of the tuning instance to check the state of
- :type job_name: str
- """
-
- template_fields = ['job_name']
- template_ext = ()
-
- def __init__(self, *, job_name: str, **kwargs):
- super().__init__(**kwargs)
- self.job_name = job_name
-
- def non_terminal_states(self):
- return SageMakerHook.non_terminal_states
-
- def failed_states(self):
- return SageMakerHook.failed_states
-
- def get_sagemaker_response(self):
- self.log.info('Poking Sagemaker Tuning Job %s', self.job_name)
- return self.get_hook().describe_tuning_job(self.job_name)
-
- def get_failed_reason_from_response(self, response):
- return response['FailureReason']
-
- def state_from_response(self, response):
- return response['HyperParameterTuningJobStatus']
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 5664ba4..7190dac 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -231,6 +231,7 @@ operators:
- airflow.providers.amazon.aws.operators.s3
- integration-name: Amazon SageMaker
python-modules:
+ - airflow.providers.amazon.aws.operators.sagemaker
- airflow.providers.amazon.aws.operators.sagemaker_base
- airflow.providers.amazon.aws.operators.sagemaker_endpoint
- airflow.providers.amazon.aws.operators.sagemaker_endpoint_config
@@ -307,6 +308,7 @@ sensors:
- airflow.providers.amazon.aws.sensors.s3
- integration-name: Amazon SageMaker
python-modules:
+ - airflow.providers.amazon.aws.sensors.sagemaker
- airflow.providers.amazon.aws.sensors.sagemaker_base
- airflow.providers.amazon.aws.sensors.sagemaker_endpoint
- airflow.providers.amazon.aws.sensors.sagemaker_training
diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py
index 8c1204d..42c6d59 100755
--- a/dev/provider_packages/prepare_provider_packages.py
+++ b/dev/provider_packages/prepare_provider_packages.py
@@ -2161,6 +2161,8 @@ KNOWN_DEPRECATED_DIRECT_IMPORTS: Set[str] = {
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.redshift_sql` "
"or `airflow.providers.amazon.aws.operators.redshift_cluster` as appropriate.",
"This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.redshift_cluster`.",
+ 'This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.',
+ 'This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.',
'This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`.',
}
diff --git a/tests/deprecated_classes.py b/tests/deprecated_classes.py
index 11edfdd..310c7d6 100644
--- a/tests/deprecated_classes.py
+++ b/tests/deprecated_classes.py
@@ -1136,31 +1136,59 @@ OPERATORS = [
'airflow.operators.s3_file_transform_operator.S3FileTransformOperator',
),
(
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator',
'airflow.providers.amazon.aws.operators.sagemaker_base.SageMakerBaseOperator',
- 'airflow.contrib.operators.sagemaker_base_operator.SageMakerBaseOperator',
),
(
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerEndpointConfigOperator',
'airflow.providers.amazon.aws.operators.sagemaker_endpoint_config.SageMakerEndpointConfigOperator',
- 'airflow.contrib.operators.sagemaker_endpoint_config_operator.SageMakerEndpointConfigOperator',
),
(
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerEndpointOperator',
'airflow.providers.amazon.aws.operators.sagemaker_endpoint.SageMakerEndpointOperator',
- 'airflow.contrib.operators.sagemaker_endpoint_operator.SageMakerEndpointOperator',
),
(
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerModelOperator',
'airflow.providers.amazon.aws.operators.sagemaker_model.SageMakerModelOperator',
- 'airflow.contrib.operators.sagemaker_model_operator.SageMakerModelOperator',
),
(
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator',
'airflow.providers.amazon.aws.operators.sagemaker_training.SageMakerTrainingOperator',
- 'airflow.contrib.operators.sagemaker_training_operator.SageMakerTrainingOperator',
),
(
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator',
'airflow.providers.amazon.aws.operators.sagemaker_transform.SageMakerTransformOperator',
- 'airflow.contrib.operators.sagemaker_transform_operator.SageMakerTransformOperator',
),
(
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerTuningOperator',
'airflow.providers.amazon.aws.operators.sagemaker_tuning.SageMakerTuningOperator',
+ ),
+ (
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator',
+ 'airflow.contrib.operators.sagemaker_base_operator.SageMakerBaseOperator',
+ ),
+ (
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerEndpointConfigOperator',
+ 'airflow.contrib.operators.sagemaker_endpoint_config_operator.SageMakerEndpointConfigOperator',
+ ),
+ (
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerEndpointOperator',
+ 'airflow.contrib.operators.sagemaker_endpoint_operator.SageMakerEndpointOperator',
+ ),
+ (
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerModelOperator',
+ 'airflow.contrib.operators.sagemaker_model_operator.SageMakerModelOperator',
+ ),
+ (
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator',
+ 'airflow.contrib.operators.sagemaker_training_operator.SageMakerTrainingOperator',
+ ),
+ (
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator',
+ 'airflow.contrib.operators.sagemaker_transform_operator.SageMakerTransformOperator',
+ ),
+ (
+ 'airflow.providers.amazon.aws.operators.sagemaker.SageMakerTuningOperator',
'airflow.contrib.operators.sagemaker_tuning_operator.SageMakerTuningOperator',
),
(
@@ -1615,19 +1643,35 @@ SENSORS = [
'airflow.contrib.sensors.emr_step_sensor.EmrStepSensor',
),
(
+ 'airflow.providers.amazon.aws.sensors.sagemaker.SageMakerBaseSensor',
'airflow.providers.amazon.aws.sensors.sagemaker_base.SageMakerBaseSensor',
- 'airflow.contrib.sensors.sagemaker_base_sensor.SageMakerBaseSensor',
),
(
+ 'airflow.providers.amazon.aws.sensors.sagemaker.SageMakerEndpointSensor',
'airflow.providers.amazon.aws.sensors.sagemaker_endpoint.SageMakerEndpointSensor',
- 'airflow.contrib.sensors.sagemaker_endpoint_sensor.SageMakerEndpointSensor',
),
(
+ 'airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTransformSensor',
'airflow.providers.amazon.aws.sensors.sagemaker_transform.SageMakerTransformSensor',
- 'airflow.contrib.sensors.sagemaker_transform_sensor.SageMakerTransformSensor',
),
(
+ 'airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTuningSensor',
'airflow.providers.amazon.aws.sensors.sagemaker_tuning.SageMakerTuningSensor',
+ ),
+ (
+ 'airflow.providers.amazon.aws.sensors.sagemaker.SageMakerBaseSensor',
+ 'airflow.contrib.sensors.sagemaker_base_sensor.SageMakerBaseSensor',
+ ),
+ (
+ 'airflow.providers.amazon.aws.sensors.sagemaker.SageMakerEndpointSensor',
+ 'airflow.contrib.sensors.sagemaker_endpoint_sensor.SageMakerEndpointSensor',
+ ),
+ (
+ 'airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTransformSensor',
+ 'airflow.contrib.sensors.sagemaker_transform_sensor.SageMakerTransformSensor',
+ ),
+ (
+ 'airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTuningSensor',
'airflow.contrib.sensors.sagemaker_tuning_sensor.SageMakerTuningSensor',
),
(
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_base.py b/tests/providers/amazon/aws/operators/test_sagemaker_base.py
index 6b128da..74b30e2 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_base.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_base.py
@@ -17,7 +17,7 @@
import unittest
-from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerBaseOperator
config = {'key1': '1', 'key2': {'key3': '3', 'key4': '4'}, 'key5': [{'key6': '6'}, {'key6': '7'}]}
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
index 9c68ad4..11e1431 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
@@ -24,7 +24,7 @@ from botocore.exceptions import ClientError
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.operators.sagemaker_endpoint import SageMakerEndpointOperator
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator
role = 'arn:aws:iam:role/test-role'
bucket = 'test-bucket'
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
index b8bf18c..40fe16f 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
@@ -23,7 +23,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.operators.sagemaker_endpoint_config import SageMakerEndpointConfigOperator
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointConfigOperator
model_name = 'test-model-name'
config_name = 'test-config-name'
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_model.py b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
index 17ba3d9..075805e 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_model.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
@@ -23,7 +23,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.operators.sagemaker_model import SageMakerModelOperator
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerModelOperator
role = 'arn:aws:iam:role/test-role'
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 4aa108d..7855bb6 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -23,7 +23,7 @@ from parameterized import parameterized
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.operators.sagemaker_processing import SageMakerProcessingOperator
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerProcessingOperator
job_name = 'test-job-name'
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index 86c87bf..3c29c96 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -22,7 +22,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.operators.sagemaker_training import SageMakerTrainingOperator
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator
role = 'arn:aws:iam:role/test-role'
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 6ca4dc9..76baa71 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -23,7 +23,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.operators.sagemaker_transform import SageMakerTransformOperator
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator
role = 'arn:aws:iam:role/test-role'
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
index 3982bce..cb63357 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
@@ -23,7 +23,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.operators.sagemaker_tuning import SageMakerTuningOperator
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator
role = 'arn:aws:iam:role/test-role'
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_base.py b/tests/providers/amazon/aws/sensors/test_sagemaker_base.py
index 2acaef9..b12ebb6 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_base.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_base.py
@@ -21,7 +21,7 @@ import unittest
import pytest
from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor
+from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerBaseSensor
class TestSagemakerBaseSensor(unittest.TestCase):
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py
index 410d7ec..0331bca 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_endpoint.py
@@ -23,7 +23,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.sensors.sagemaker_endpoint import SageMakerEndpointSensor
+from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerEndpointSensor
DESCRIBE_ENDPOINT_CREATING_RESPONSE = {
'EndpointStatus': 'Creating',
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_training.py b/tests/providers/amazon/aws/sensors/test_sagemaker_training.py
index 09d98f7..0811cdd 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_training.py
@@ -25,7 +25,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
-from airflow.providers.amazon.aws.sensors.sagemaker_training import SageMakerTrainingSensor
+from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTrainingSensor
DESCRIBE_TRAINING_COMPLETED_RESPONSE = {
'TrainingJobStatus': 'Completed',
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py b/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py
index a3e23d8..88179fb 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_transform.py
@@ -23,7 +23,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.sensors.sagemaker_transform import SageMakerTransformSensor
+from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTransformSensor
DESCRIBE_TRANSFORM_INPROGRESS_RESPONSE = {
'TransformJobStatus': 'InProgress',
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py b/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py
index 9b79b60..bbdd4aa 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_tuning.py
@@ -23,7 +23,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.sensors.sagemaker_tuning import SageMakerTuningSensor
+from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTuningSensor
DESCRIBE_TUNING_INPROGRESS_RESPONSE = {
'HyperParameterTuningJobStatus': 'InProgress',