You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/15 06:07:09 UTC
[airflow] branch main updated: SageMaker system tests - Part 1 of 3 - Prep Work (AIP-47) (#25078)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 543161a1af SageMaker system tests - Part 1 of 3 - Prep Work (AIP-47) (#25078)
543161a1af is described below
commit 543161a1afe84400dbc3c0409bbf4ff8110919f8
Author: D. Ferruzzi <fe...@amazon.com>
AuthorDate: Fri Jul 15 06:07:00 2022 +0000
SageMaker system tests - Part 1 of 3 - Prep Work (AIP-47) (#25078)
* SageMaker Operator - Improve type hints and docstrings
* SageMaker Operator Unit Test Improvements
- Add explicit unit testing for integer_fields
- Improve type hinting
- Standardize some variable naming
* SageMaker Operators - Configure integer fields at runtime
---
.../providers/amazon/aws/operators/sagemaker.py | 132 +++++++------
.../amazon/aws/operators/test_sagemaker_base.py | 16 +-
.../aws/operators/test_sagemaker_endpoint.py | 71 +++----
.../operators/test_sagemaker_endpoint_config.py | 32 ++--
.../amazon/aws/operators/test_sagemaker_model.py | 45 +++--
.../aws/operators/test_sagemaker_processing.py | 211 ++++++++++++---------
.../aws/operators/test_sagemaker_training.py | 85 ++++-----
.../aws/operators/test_sagemaker_transform.py | 79 ++++----
.../amazon/aws/operators/test_sagemaker_tuning.py | 79 +++-----
9 files changed, 391 insertions(+), 359 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 6a6ef69df4..8da36c58dc 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -16,7 +16,7 @@
# under the License.
import json
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from botocore.exceptions import ClientError
@@ -29,8 +29,8 @@ from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
if TYPE_CHECKING:
from airflow.utils.context import Context
-DEFAULT_CONN_ID = 'aws_default'
-CHECK_INTERVAL_SECOND = 30
+DEFAULT_CONN_ID: str = 'aws_default'
+CHECK_INTERVAL_SECOND: int = 30
class SageMakerBaseOperator(BaseOperator):
@@ -41,15 +41,15 @@ class SageMakerBaseOperator(BaseOperator):
template_fields: Sequence[str] = ('config',)
template_ext: Sequence[str] = ()
- template_fields_renderers = {'config': 'json'}
- ui_color = '#ededed'
+ template_fields_renderers: Dict = {'config': 'json'}
+ ui_color: str = '#ededed'
integer_fields: List[List[Any]] = []
- def __init__(self, *, config: dict, **kwargs):
+ def __init__(self, *, config: Dict, **kwargs):
super().__init__(**kwargs)
self.config = config
- def parse_integer(self, config, field):
+ def parse_integer(self, config: Dict, field: Union[List[str], str]) -> None:
"""Recursive method for parsing string fields holding integer values to integers."""
if len(field) == 1:
if isinstance(config, list):
@@ -69,19 +69,17 @@ class SageMakerBaseOperator(BaseOperator):
self.parse_integer(config[head], tail)
return
- def parse_config_integers(self):
- """
- Parse the integer fields of training config to integers in case the config is rendered by Jinja and
- all fields are str
- """
+ def parse_config_integers(self) -> None:
+ """Parse the integer fields to ints in case the config is rendered by Jinja and all fields are str."""
for field in self.integer_fields:
self.parse_integer(self.config, field)
- def expand_role(self):
+ def expand_role(self) -> None:
"""Placeholder for calling boto3's `expand_role`, which expands an IAM role name into an ARN."""
- def preprocess_config(self):
+ def preprocess_config(self) -> None:
"""Process the config into a usable form."""
+ self._create_integer_fields()
self.log.info('Preprocessing the config and doing required s3_operations')
self.hook.configure_s3_resources(self.config)
self.parse_config_integers()
@@ -91,12 +89,19 @@ class SageMakerBaseOperator(BaseOperator):
json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': ')),
)
- def execute(self, context: 'Context'):
+ def _create_integer_fields(self) -> None:
+ """
+ Set fields which should be cast to integers.
+ Child classes should override this method if they need integer fields parsed.
+ """
+ self.integer_fields = []
+
+ def execute(self, context: 'Context') -> Union[None, Dict]:
raise NotImplementedError('Please implement execute() in sub class!')
@cached_property
def hook(self):
- """Return SageMakerHook"""
+ """Return SageMakerHook."""
return SageMakerHook(aws_conn_id=self.aws_conn_id)
@@ -130,7 +135,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
def __init__(
self,
*,
- config: dict,
+ config: Dict,
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
print_log: bool = True,
@@ -151,23 +156,23 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
self.print_log = print_log
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
- self._create_integer_fields()
def _create_integer_fields(self) -> None:
- """Set fields which should be casted to integers."""
- self.integer_fields = [
+ """Set fields which should be cast to integers."""
+ self.integer_fields: List[Union[List[str], List[List[str]]]] = [
['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
]
if 'StoppingCondition' in self.config:
- self.integer_fields += [['StoppingCondition', 'MaxRuntimeInSeconds']]
+ self.integer_fields.append(['StoppingCondition', 'MaxRuntimeInSeconds'])
def expand_role(self) -> None:
+ """Expands an IAM role name into an ARN."""
if 'RoleArn' in self.config:
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
- def execute(self, context: 'Context') -> dict:
+ def execute(self, context: 'Context') -> Dict:
self.preprocess_config()
processing_job_name = self.config['ProcessingJobName']
if self.hook.find_processing_job_by_name(processing_job_name):
@@ -204,12 +209,10 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
:return Dict: Returns The ARN of the endpoint config created in Amazon SageMaker.
"""
- integer_fields = [['ProductionVariants', 'InitialInstanceCount']]
-
def __init__(
self,
*,
- config: dict,
+ config: Dict,
aws_conn_id: str = DEFAULT_CONN_ID,
**kwargs,
):
@@ -217,7 +220,11 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
self.config = config
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context') -> dict:
+ def _create_integer_fields(self) -> None:
+ """Set fields which should be cast to integers."""
+ self.integer_fields: List[List[str]] = [['ProductionVariants', 'InitialInstanceCount']]
+
+ def execute(self, context: 'Context') -> Dict:
self.preprocess_config()
self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName'])
response = self.hook.create_endpoint_config(self.config)
@@ -278,7 +285,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
def __init__(
self,
*,
- config: dict,
+ config: Dict,
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
@@ -295,14 +302,16 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
self.operation = operation.lower()
if self.operation not in ['create', 'update']:
raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
- self.create_integer_fields()
- def create_integer_fields(self) -> None:
- """Set fields which should be casted to integers."""
+ def _create_integer_fields(self) -> None:
+ """Set fields which should be cast to integers."""
if 'EndpointConfig' in self.config:
- self.integer_fields = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']]
+ self.integer_fields: List[List[str]] = [
+ ['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']
+ ]
def expand_role(self) -> None:
+ """Expands an IAM role name into an ARN."""
if 'Model' not in self.config:
return
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
@@ -310,7 +319,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
if 'ExecutionRoleArn' in config:
config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
- def execute(self, context: 'Context') -> dict:
+ def execute(self, context: 'Context') -> Dict:
self.preprocess_config()
model_info = self.config.get('Model')
endpoint_config_info = self.config.get('EndpointConfig')
@@ -397,7 +406,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
def __init__(
self,
*,
- config: dict,
+ config: Dict,
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
@@ -410,10 +419,9 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
- self.create_integer_fields()
- def create_integer_fields(self) -> None:
- """Set fields which should be casted to integers."""
+ def _create_integer_fields(self) -> None:
+ """Set fields which should be cast to integers."""
self.integer_fields: List[List[str]] = [
['Transform', 'TransformResources', 'InstanceCount'],
['Transform', 'MaxConcurrentTransforms'],
@@ -424,6 +432,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
field.pop(0)
def expand_role(self) -> None:
+ """Expands an IAM role name into an ARN."""
if 'Model' not in self.config:
return
config = self.config['Model']
@@ -431,7 +440,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
- def execute(self, context: 'Context') -> dict:
+ def execute(self, context: 'Context') -> Dict:
self.preprocess_config()
model_config = self.config.get('Model')
transform_config = self.config.get('Transform', self.config)
@@ -480,18 +489,10 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
:return Dict: Returns The ARN of the tuning job created in Amazon SageMaker.
"""
- integer_fields = [
- ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
- ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
- ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
- ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
- ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'],
- ]
-
def __init__(
self,
*,
- config: dict,
+ config: Dict,
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
@@ -506,13 +507,24 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
self.max_ingestion_time = max_ingestion_time
def expand_role(self) -> None:
+ """Expands an IAM role name into an ARN."""
if 'TrainingJobDefinition' in self.config:
config = self.config['TrainingJobDefinition']
if 'RoleArn' in config:
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
config['RoleArn'] = hook.expand_role(config['RoleArn'])
- def execute(self, context: 'Context') -> dict:
+ def _create_integer_fields(self) -> None:
+ """Set fields which should be cast to integers."""
+ self.integer_fields: List[List[str]] = [
+ ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
+ ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
+ ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
+ ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
+ ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'],
+ ]
+
+ def execute(self, context: 'Context') -> Dict:
self.preprocess_config()
self.log.info(
'Creating SageMaker Hyper-Parameter Tuning Job %s', self.config['HyperParameterTuningJobName']
@@ -547,17 +559,18 @@ class SageMakerModelOperator(SageMakerBaseOperator):
:return Dict: Returns The ARN of the model created in Amazon SageMaker.
"""
- def __init__(self, *, config, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
+ def __init__(self, *, config: Dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
super().__init__(config=config, **kwargs)
self.config = config
self.aws_conn_id = aws_conn_id
def expand_role(self) -> None:
+ """Expands an IAM role name into an ARN."""
if 'ExecutionRoleArn' in self.config:
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
self.config['ExecutionRoleArn'] = hook.expand_role(self.config['ExecutionRoleArn'])
- def execute(self, context: 'Context') -> dict:
+ def execute(self, context: 'Context') -> Dict:
self.preprocess_config()
self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
response = self.hook.create_model(self.config)
@@ -596,16 +609,10 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
:return Dict: Returns The ARN of the training job created in Amazon SageMaker.
"""
- integer_fields = [
- ['ResourceConfig', 'InstanceCount'],
- ['ResourceConfig', 'VolumeSizeInGB'],
- ['StoppingCondition', 'MaxRuntimeInSeconds'],
- ]
-
def __init__(
self,
*,
- config: dict,
+ config: Dict,
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
print_log: bool = True,
@@ -631,11 +638,20 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
)
def expand_role(self) -> None:
+ """Expands an IAM role name into an ARN."""
if 'RoleArn' in self.config:
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
- def execute(self, context: 'Context') -> dict:
+ def _create_integer_fields(self) -> None:
+ """Set fields which should be cast to integers."""
+ self.integer_fields: List[List[str]] = [
+ ['ResourceConfig', 'InstanceCount'],
+ ['ResourceConfig', 'VolumeSizeInGB'],
+ ['StoppingCondition', 'MaxRuntimeInSeconds'],
+ ]
+
+ def execute(self, context: 'Context') -> Dict:
self.preprocess_config()
if self.check_if_job_exists:
self._check_if_job_exists()
@@ -680,7 +696,7 @@ class SageMakerDeleteModelOperator(SageMakerBaseOperator):
:param aws_conn_id: The AWS connection ID to use.
"""
- def __init__(self, *, config, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
+ def __init__(self, *, config: Dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
super().__init__(config=config, **kwargs)
self.config = config
self.aws_conn_id = aws_conn_id
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_base.py b/tests/providers/amazon/aws/operators/test_sagemaker_base.py
index 5fecaf9c94..9c96177eae 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_base.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_base.py
@@ -16,19 +16,27 @@
# under the License.
import unittest
+from typing import Any, Dict, List
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerBaseOperator
-config = {'key1': '1', 'key2': {'key3': '3', 'key4': '4'}, 'key5': [{'key6': '6'}, {'key6': '7'}]}
+CONFIG: Dict = {'key1': '1', 'key2': {'key3': '3', 'key4': '4'}, 'key5': [{'key6': '6'}, {'key6': '7'}]}
+PARSED_CONFIG: Dict = {'key1': 1, 'key2': {'key3': 3, 'key4': 4}, 'key5': [{'key6': 6}, {'key6': 7}]}
-parsed_config = {'key1': 1, 'key2': {'key3': 3, 'key4': 4}, 'key5': [{'key6': 6}, {'key6': 7}]}
+EXPECTED_INTEGER_FIELDS: List[List[Any]] = []
class TestSageMakerBaseOperator(unittest.TestCase):
def setUp(self):
- self.sagemaker = SageMakerBaseOperator(task_id='test_sagemaker_operator', config=config)
+ self.sagemaker = SageMakerBaseOperator(task_id='test_sagemaker_operator', config=CONFIG)
+ self.sagemaker.aws_conn_id = 'aws_default'
def test_parse_integer(self):
self.sagemaker.integer_fields = [['key1'], ['key2', 'key3'], ['key2', 'key4'], ['key5', 'key6']]
self.sagemaker.parse_config_integers()
- assert self.sagemaker.config == parsed_config
+ assert self.sagemaker.config == PARSED_CONFIG
+
+ def test_default_integer_fields(self):
+ self.sagemaker.preprocess_config()
+
+ assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
index 11e1431c69..3f252ab4d8 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
@@ -17,6 +17,7 @@
# under the License.
import unittest
+from typing import Dict, List
from unittest import mock
import pytest
@@ -26,57 +27,54 @@ from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator
-role = 'arn:aws:iam:role/test-role'
-bucket = 'test-bucket'
-image = 'test-image'
-output_url = f's3://{bucket}/test/output'
-model_name = 'test-model-name'
-config_name = 'test-endpoint-config-name'
-endpoint_name = 'test-endpoint-name'
-
-create_model_params = {
- 'ModelName': model_name,
+CREATE_MODEL_PARAMS: Dict = {
+ 'ModelName': 'model_name',
'PrimaryContainer': {
- 'Image': image,
- 'ModelDataUrl': output_url,
+ 'Image': 'image_name',
+ 'ModelDataUrl': 'output_path',
},
- 'ExecutionRoleArn': role,
+ 'ExecutionRoleArn': 'arn:aws:iam:role/test-role',
}
-
-create_endpoint_config_params = {
- 'EndpointConfigName': config_name,
+CREATE_ENDPOINT_CONFIG_PARAMS: Dict = {
+ 'EndpointConfigName': 'config_name',
'ProductionVariants': [
{
'VariantName': 'AllTraffic',
- 'ModelName': model_name,
+ 'ModelName': 'model_name',
'InitialInstanceCount': '1',
'InstanceType': 'ml.c4.xlarge',
}
],
}
+CREATE_ENDPOINT_PARAMS: Dict = {'EndpointName': 'endpoint_name', 'EndpointConfigName': 'config_name'}
-create_endpoint_params = {'EndpointName': endpoint_name, 'EndpointConfigName': config_name}
-
-config = {
- 'Model': create_model_params,
- 'EndpointConfig': create_endpoint_config_params,
- 'Endpoint': create_endpoint_params,
+CONFIG: Dict = {
+ 'Model': CREATE_MODEL_PARAMS,
+ 'EndpointConfig': CREATE_ENDPOINT_CONFIG_PARAMS,
+ 'Endpoint': CREATE_ENDPOINT_PARAMS,
}
+EXPECTED_INTEGER_FIELDS: List[List[str]] = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']]
+
class TestSageMakerEndpointOperator(unittest.TestCase):
def setUp(self):
self.sagemaker = SageMakerEndpointOperator(
task_id='test_sagemaker_operator',
- aws_conn_id='sagemaker_test_id',
- config=config,
+ config=CONFIG,
wait_for_completion=False,
check_interval=5,
operation='create',
)
- def test_parse_config_integers(self):
- self.sagemaker.parse_config_integers()
+ @mock.patch.object(SageMakerHook, 'get_conn')
+ @mock.patch.object(SageMakerHook, 'create_model')
+ @mock.patch.object(SageMakerHook, 'create_endpoint_config')
+ @mock.patch.object(SageMakerHook, 'create_endpoint')
+ def test_integer_fields(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
+ mock_endpoint.return_value = {'EndpointArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
+ self.sagemaker.execute(None)
+ assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
for variant in self.sagemaker.config['EndpointConfig']['ProductionVariants']:
assert variant['InitialInstanceCount'] == int(variant['InitialInstanceCount'])
@@ -85,20 +83,23 @@ class TestSageMakerEndpointOperator(unittest.TestCase):
@mock.patch.object(SageMakerHook, 'create_endpoint_config')
@mock.patch.object(SageMakerHook, 'create_endpoint')
def test_execute(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
- mock_endpoint.return_value = {'EndpointArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
+ mock_endpoint.return_value = {'EndpointArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
self.sagemaker.execute(None)
- mock_model.assert_called_once_with(create_model_params)
- mock_endpoint_config.assert_called_once_with(create_endpoint_config_params)
+ mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
+ mock_endpoint_config.assert_called_once_with(CREATE_ENDPOINT_CONFIG_PARAMS)
mock_endpoint.assert_called_once_with(
- create_endpoint_params, wait_for_completion=False, check_interval=5, max_ingestion_time=None
+ CREATE_ENDPOINT_PARAMS, wait_for_completion=False, check_interval=5, max_ingestion_time=None
)
+ assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
+ for variant in self.sagemaker.config['EndpointConfig']['ProductionVariants']:
+ assert variant['InitialInstanceCount'] == int(variant['InitialInstanceCount'])
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_model')
@mock.patch.object(SageMakerHook, 'create_endpoint_config')
@mock.patch.object(SageMakerHook, 'create_endpoint')
def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
- mock_endpoint.return_value = {'EndpointArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
+ mock_endpoint.return_value = {'EndpointArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
with pytest.raises(AirflowException):
self.sagemaker.execute(None)
@@ -111,11 +112,11 @@ class TestSageMakerEndpointOperator(unittest.TestCase):
self, mock_endpoint_update, mock_endpoint, mock_endpoint_config, mock_model, mock_client
):
response = {
- "Error": {"Code": "ValidationException", "Message": "Cannot create already existing endpoint."}
+ 'Error': {'Code': 'ValidationException', 'Message': 'Cannot create already existing endpoint.'}
}
- mock_endpoint.side_effect = ClientError(error_response=response, operation_name="CreateEndpoint")
+ mock_endpoint.side_effect = ClientError(error_response=response, operation_name='CreateEndpoint')
mock_endpoint_update.return_value = {
- 'EndpointArn': 'testarn',
+ 'EndpointArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
}
self.sagemaker.execute(None)
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
index 40fe16fd82..b43b30d10e 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
@@ -17,6 +17,7 @@
# under the License.
import unittest
+from typing import Dict, List
from unittest import mock
import pytest
@@ -25,32 +26,37 @@ from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointConfigOperator
-model_name = 'test-model-name'
-config_name = 'test-config-name'
-
-create_endpoint_config_params = {
- 'EndpointConfigName': config_name,
+CREATE_ENDPOINT_CONFIG_PARAMS: Dict = {
+ 'EndpointConfigName': 'config_name',
'ProductionVariants': [
{
'VariantName': 'AllTraffic',
- 'ModelName': model_name,
+ 'ModelName': 'model_name',
'InitialInstanceCount': '1',
'InstanceType': 'ml.c4.xlarge',
}
],
}
+EXPECTED_INTEGER_FIELDS: List[List[str]] = [['ProductionVariants', 'InitialInstanceCount']]
+
class TestSageMakerEndpointConfigOperator(unittest.TestCase):
def setUp(self):
self.sagemaker = SageMakerEndpointConfigOperator(
task_id='test_sagemaker_operator',
- aws_conn_id='sagemaker_test_id',
- config=create_endpoint_config_params,
+ config=CREATE_ENDPOINT_CONFIG_PARAMS,
)
- def test_parse_config_integers(self):
- self.sagemaker.parse_config_integers()
+ @mock.patch.object(SageMakerHook, 'get_conn')
+ @mock.patch.object(SageMakerHook, 'create_endpoint_config')
+ def test_integer_fields(self, mock_model, mock_client):
+ mock_model.return_value = {
+ 'EndpointConfigArn': 'test_arn',
+ 'ResponseMetadata': {'HTTPStatusCode': 200},
+ }
+ self.sagemaker.execute(None)
+ assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
for variant in self.sagemaker.config['ProductionVariants']:
assert variant['InitialInstanceCount'] == int(variant['InitialInstanceCount'])
@@ -58,17 +64,17 @@ class TestSageMakerEndpointConfigOperator(unittest.TestCase):
@mock.patch.object(SageMakerHook, 'create_endpoint_config')
def test_execute(self, mock_model, mock_client):
mock_model.return_value = {
- 'EndpointConfigArn': 'testarn',
+ 'EndpointConfigArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
}
self.sagemaker.execute(None)
- mock_model.assert_called_once_with(create_endpoint_config_params)
+ mock_model.assert_called_once_with(CREATE_ENDPOINT_CONFIG_PARAMS)
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_model')
def test_execute_with_failure(self, mock_model, mock_client):
mock_model.return_value = {
- 'EndpointConfigArn': 'testarn',
+ 'EndpointConfigArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
}
with pytest.raises(AirflowException):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_model.py b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
index c2689d8fd5..98990f7c74 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_model.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
@@ -17,6 +17,7 @@
# under the License.
import unittest
+from typing import Dict, List
from unittest import mock
import pytest
@@ -28,51 +29,49 @@ from airflow.providers.amazon.aws.operators.sagemaker import (
SageMakerModelOperator,
)
-role = 'arn:aws:iam:role/test-role'
-
-bucket = 'test-bucket'
-
-model_name = 'test-model-name'
-
-image = 'test-image'
-
-output_url = f's3://{bucket}/test/output'
-create_model_params = {
- 'ModelName': model_name,
+CREATE_MODEL_PARAMS: Dict = {
+ 'ModelName': 'model_name',
'PrimaryContainer': {
- 'Image': image,
- 'ModelDataUrl': output_url,
+ 'Image': 'image_name',
+ 'ModelDataUrl': 'output_path',
},
- 'ExecutionRoleArn': role,
+ 'ExecutionRoleArn': 'arn:aws:iam:role/test-role',
}
+EXPECTED_INTEGER_FIELDS: List[List[str]] = []
+
class TestSageMakerModelOperator(unittest.TestCase):
def setUp(self):
- self.sagemaker = SageMakerModelOperator(
- task_id='test_sagemaker_operator', aws_conn_id='sagemaker_test_id', config=create_model_params
- )
+ self.sagemaker = SageMakerModelOperator(task_id='test_sagemaker_operator', config=CREATE_MODEL_PARAMS)
+
+ @mock.patch.object(SageMakerHook, 'get_conn')
+ @mock.patch.object(SageMakerHook, 'create_model')
+ def test_integer_fields(self, mock_model, mock_client):
+ mock_model.return_value = {'ModelArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
+ self.sagemaker.execute(None)
+ assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_model')
def test_execute(self, mock_model, mock_client):
- mock_model.return_value = {'ModelArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
+ mock_model.return_value = {'ModelArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
self.sagemaker.execute(None)
- mock_model.assert_called_once_with(create_model_params)
+ mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_model')
def test_execute_with_failure(self, mock_model, mock_client):
- mock_model.return_value = {'ModelArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
+ mock_model.return_value = {'ModelArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
with pytest.raises(AirflowException):
self.sagemaker.execute(None)
class TestSageMakerDeleteModelOperator(unittest.TestCase):
def setUp(self):
- delete_model_params = {'ModelName': 'test'}
+ delete_model_params = {'ModelName': 'model_name'}
self.sagemaker = SageMakerDeleteModelOperator(
- task_id='test_sagemaker_operator', aws_conn_id='sagemaker_test_id', config=delete_model_params
+ task_id='test_sagemaker_operator', config=delete_model_params
)
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -80,4 +79,4 @@ class TestSageMakerDeleteModelOperator(unittest.TestCase):
def test_execute(self, delete_model, mock_client):
delete_model.return_value = None
self.sagemaker.execute(None)
- delete_model.assert_called_once_with(model_name='test')
+ delete_model.assert_called_once_with(model_name='model_name')
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 7855bb6a44..10e4f3feda 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -16,151 +16,186 @@
# under the License.
import unittest
+from typing import Dict, List
from unittest import mock
import pytest
-from parameterized import parameterized
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerProcessingOperator
-job_name = 'test-job-name'
-
-create_processing_params = {
- "AppSpecification": {
- "ContainerArguments": ["container_arg"],
- "ContainerEntrypoint": ["container_entrypoint"],
- "ImageUri": "{{ image_uri }}",
+CREATE_PROCESSING_PARAMS: Dict = {
+ 'AppSpecification': {
+ 'ContainerArguments': ['container_arg'],
+ 'ContainerEntrypoint': ['container_entrypoint'],
+ 'ImageUri': 'image_uri',
},
- "Environment": {"{{ key }}": "{{ value }}"},
- "ExperimentConfig": {
- "ExperimentName": "ExperimentName",
- "TrialComponentDisplayName": "TrialComponentDisplayName",
- "TrialName": "TrialName",
+ 'Environment': {'key': 'value'},
+ 'ExperimentConfig': {
+ 'ExperimentName': 'experiment_name',
+ 'TrialComponentDisplayName': 'trial_component_display_name',
+ 'TrialName': 'trial_name',
},
- "ProcessingInputs": [
+ 'ProcessingInputs': [
{
- "InputName": "AnalyticsInputName",
- "S3Input": {
- "LocalPath": "{{ Local Path }}",
- "S3CompressionType": "None",
- "S3DataDistributionType": "FullyReplicated",
- "S3DataType": "S3Prefix",
- "S3InputMode": "File",
- "S3Uri": "{{ S3Uri }}",
+ 'InputName': 'analytics_input_name',
+ 'S3Input': {
+ 'LocalPath': 'local_path',
+ 'S3CompressionType': 'None',
+ 'S3DataDistributionType': 'FullyReplicated',
+ 'S3DataType': 'S3Prefix',
+ 'S3InputMode': 'File',
+ 'S3Uri': 's3_uri',
},
}
],
- "ProcessingJobName": job_name,
- "ProcessingOutputConfig": {
- "KmsKeyId": "KmsKeyID",
- "Outputs": [
+ 'ProcessingJobName': 'job_name',
+ 'ProcessingOutputConfig': {
+ 'KmsKeyId': 'kms_key_ID',
+ 'Outputs': [
{
- "OutputName": "AnalyticsOutputName",
- "S3Output": {
- "LocalPath": "{{ Local Path }}",
- "S3UploadMode": "EndOfJob",
- "S3Uri": "{{ S3Uri }}",
+ 'OutputName': 'analytics_output_name',
+ 'S3Output': {
+ 'LocalPath': 'local_path',
+ 'S3UploadMode': 'EndOfJob',
+ 'S3Uri': 's3_uri',
},
}
],
},
- "ProcessingResources": {
- "ClusterConfig": {
- "InstanceCount": 2,
- "InstanceType": "ml.p2.xlarge",
- "VolumeSizeInGB": 30,
- "VolumeKmsKeyId": "{{ kms_key }}",
+ 'ProcessingResources': {
+ 'ClusterConfig': {
+ 'InstanceCount': '2',
+ 'InstanceType': 'ml.p2.xlarge',
+ 'VolumeSizeInGB': '30',
+ 'VolumeKmsKeyId': 'kms_key',
}
},
- "RoleArn": "arn:aws:iam::0122345678910:role/SageMakerPowerUser",
- "Tags": [{"{{ key }}": "{{ value }}"}],
+ 'RoleArn': 'arn:aws:iam::0122345678910:role/SageMakerPowerUser',
+ 'Tags': [{'key': 'value'}],
}
-create_processing_params_with_stopping_condition = create_processing_params.copy()
-create_processing_params_with_stopping_condition.update(StoppingCondition={"MaxRuntimeInSeconds": 3600})
+CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION: Dict = CREATE_PROCESSING_PARAMS.copy()
+CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION.update(StoppingCondition={'MaxRuntimeInSeconds': '3600'})
+
+EXPECTED_INTEGER_FIELDS: List[List[str]] = [
+ ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
+ ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
+]
+EXPECTED_STOPPING_CONDITION_INTEGER_FIELDS: List[List[str]] = [['StoppingCondition', 'MaxRuntimeInSeconds']]
class TestSageMakerProcessingOperator(unittest.TestCase):
def setUp(self):
self.processing_config_kwargs = dict(
- task_id='test_sagemaker_operator',
- aws_conn_id='sagemaker_test_id',
- wait_for_completion=False,
- check_interval=5,
+ task_id='test_sagemaker_operator', wait_for_completion=False, check_interval=5
)
- @parameterized.expand(
- [
- (
- create_processing_params,
- [
- ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
- ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
- ],
- ),
- (
- create_processing_params_with_stopping_condition,
- [
- ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
- ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
- ['StoppingCondition', 'MaxRuntimeInSeconds'],
- ],
- ),
- ]
+ @mock.patch.object(SageMakerHook, 'get_conn')
+ @mock.patch.object(SageMakerHook, 'find_processing_job_by_name', return_value=False)
+ @mock.patch.object(
+ SageMakerHook,
+ 'create_processing_job',
+ return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
)
- def test_integer_fields_are_set(self, config, expected_fields):
- sagemaker = SageMakerProcessingOperator(**self.processing_config_kwargs, config=config)
- assert sagemaker.integer_fields == expected_fields
+ def test_integer_fields_without_stopping_condition(self, mock_processing, mock_hook, mock_client):
+ sagemaker = SageMakerProcessingOperator(
+ **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
+ )
+ sagemaker.execute(None)
+ assert sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
+ for (key1, key2, key3) in EXPECTED_INTEGER_FIELDS:
+ assert sagemaker.config[key1][key2][key3] == int(sagemaker.config[key1][key2][key3])
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'find_processing_job_by_name', return_value=False)
+    @mock.patch.object(
+        SageMakerHook,
+        'create_processing_job',
+        return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
+    )
+    def test_integer_fields_with_stopping_condition(self, mock_processing, mock_hook, mock_client):
+        sagemaker = SageMakerProcessingOperator(
+            **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION
+        )
+        sagemaker.execute(None)
+        assert (
+            sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS + EXPECTED_STOPPING_CONDITION_INTEGER_FIELDS
+        )
+        for (key1, key2, *key3) in EXPECTED_INTEGER_FIELDS + EXPECTED_STOPPING_CONDITION_INTEGER_FIELDS:
+            if key3:
+                (key3,) = key3
+                assert sagemaker.config[key1][key2][key3] == int(sagemaker.config[key1][key2][key3])
+            else:
+                assert sagemaker.config[key1][key2] == int(sagemaker.config[key1][key2])
@mock.patch.object(SageMakerHook, 'get_conn')
- @mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=False)
+ @mock.patch.object(SageMakerHook, 'find_processing_job_by_name', return_value=False)
@mock.patch.object(
SageMakerHook,
'create_processing_job',
- return_value={'ProcessingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
+ return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
)
def test_execute(self, mock_processing, mock_hook, mock_client):
sagemaker = SageMakerProcessingOperator(
- **self.processing_config_kwargs, config=create_processing_params
+ **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
)
sagemaker.execute(None)
mock_processing.assert_called_once_with(
- create_processing_params, wait_for_completion=False, check_interval=5, max_ingestion_time=None
+ CREATE_PROCESSING_PARAMS, wait_for_completion=False, check_interval=5, max_ingestion_time=None
)
@mock.patch.object(SageMakerHook, 'get_conn')
+ @mock.patch.object(SageMakerHook, 'find_processing_job_by_name', return_value=False)
@mock.patch.object(
SageMakerHook,
'create_processing_job',
- return_value={'ProcessingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}},
+ return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
+ )
+ def test_execute_with_stopping_condition(self, mock_processing, mock_hook, mock_client):
+ sagemaker = SageMakerProcessingOperator(
+ **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION
+ )
+ sagemaker.execute(None)
+ mock_processing.assert_called_once_with(
+ CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION,
+ wait_for_completion=False,
+ check_interval=5,
+ max_ingestion_time=None,
+ )
+
+ @mock.patch.object(SageMakerHook, 'get_conn')
+ @mock.patch.object(
+ SageMakerHook,
+ 'create_processing_job',
+ return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 404}},
)
def test_execute_with_failure(self, mock_processing, mock_client):
sagemaker = SageMakerProcessingOperator(
- **self.processing_config_kwargs, config=create_processing_params
+ **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
)
with pytest.raises(AirflowException):
sagemaker.execute(None)
- @unittest.skip("Currently, the auto-increment jobname functionality is not missing.")
- @mock.patch.object(SageMakerHook, "get_conn")
- @mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=True)
+ @unittest.skip('Currently, the auto-increment jobname functionality is not missing.')
+ @mock.patch.object(SageMakerHook, 'get_conn')
+ @mock.patch.object(SageMakerHook, 'find_processing_job_by_name', return_value=True)
@mock.patch.object(
- SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}}
+ SageMakerHook, 'create_processing_job', return_value={'ResponseMetadata': {'HTTPStatusCode': 200}}
)
def test_execute_with_existing_job_increment(
self, mock_create_processing_job, find_processing_job_by_name, mock_client
):
sagemaker = SageMakerProcessingOperator(
- **self.processing_config_kwargs, config=create_processing_params
+ **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
)
- sagemaker.action_if_job_exists = "increment"
+ sagemaker.action_if_job_exists = 'increment'
sagemaker.execute(None)
- expected_config = create_processing_params.copy()
+ expected_config = CREATE_PROCESSING_PARAMS.copy()
# Expect to see ProcessingJobName suffixed with "-2" because we return one existing job
- expected_config["ProcessingJobName"] = f"{job_name}-2"
+ expected_config['ProcessingJobName'] = 'job_name-2'
mock_create_processing_job.assert_called_once_with(
expected_config,
wait_for_completion=False,
@@ -168,26 +203,26 @@ class TestSageMakerProcessingOperator(unittest.TestCase):
max_ingestion_time=None,
)
- @mock.patch.object(SageMakerHook, "get_conn")
- @mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=True)
+ @mock.patch.object(SageMakerHook, 'get_conn')
+ @mock.patch.object(SageMakerHook, 'find_processing_job_by_name', return_value=True)
@mock.patch.object(
- SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}}
+ SageMakerHook, 'create_processing_job', return_value={'ResponseMetadata': {'HTTPStatusCode': 200}}
)
def test_execute_with_existing_job_fail(
self, mock_create_processing_job, mock_list_processing_jobs, mock_client
):
sagemaker = SageMakerProcessingOperator(
- **self.processing_config_kwargs, config=create_processing_params
+ **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
)
- sagemaker.action_if_job_exists = "fail"
+ sagemaker.action_if_job_exists = 'fail'
with pytest.raises(AirflowException):
sagemaker.execute(None)
- @mock.patch.object(SageMakerHook, "get_conn")
+ @mock.patch.object(SageMakerHook, 'get_conn')
def test_action_if_job_exists_validation(self, mock_client):
with pytest.raises(AirflowException):
SageMakerProcessingOperator(
**self.processing_config_kwargs,
- config=create_processing_params,
- action_if_job_exists="not_fail_or_increment",
+ config=CREATE_PROCESSING_PARAMS,
+ action_if_job_exists='not_fail_or_increment',
)
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index 3c29c960b0..f8f16e3cfe 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -16,6 +16,7 @@
# under the License.
import unittest
+from typing import List
from unittest import mock
import pytest
@@ -24,24 +25,18 @@ from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator
-role = 'arn:aws:iam:role/test-role'
+EXPECTED_INTEGER_FIELDS: List[List[str]] = [
+ ['ResourceConfig', 'InstanceCount'],
+ ['ResourceConfig', 'VolumeSizeInGB'],
+ ['StoppingCondition', 'MaxRuntimeInSeconds'],
+]
-bucket = 'test-bucket'
-
-key = 'test/data'
-data_url = f's3://{bucket}/{key}'
-
-job_name = 'test-job-name'
-
-image = 'test-image'
-
-output_url = f's3://{bucket}/test/output'
-create_training_params = {
- 'AlgorithmSpecification': {'TrainingImage': image, 'TrainingInputMode': 'File'},
- 'RoleArn': role,
- 'OutputDataConfig': {'S3OutputPath': output_url},
+CREATE_TRAINING_PARAMS = {
+ 'AlgorithmSpecification': {'TrainingImage': 'image_name', 'TrainingInputMode': 'File'},
+ 'RoleArn': 'arn:aws:iam:role/test-role',
+ 'OutputDataConfig': {'S3OutputPath': 'output_path'},
'ResourceConfig': {'InstanceCount': '2', 'InstanceType': 'ml.c4.8xlarge', 'VolumeSizeInGB': '50'},
- 'TrainingJobName': job_name,
+ 'TrainingJobName': 'job_name',
'HyperParameters': {'k': '10', 'feature_dim': '784', 'mini_batch_size': '500', 'force_dense': 'True'},
'StoppingCondition': {'MaxRuntimeInSeconds': '3600'},
'InputDataConfig': [
@@ -50,7 +45,7 @@ create_training_params = {
'DataSource': {
'S3DataSource': {
'S3DataType': 'S3Prefix',
- 'S3Uri': data_url,
+ 'S3Uri': 's3_uri',
'S3DataDistributionType': 'FullyReplicated',
}
},
@@ -65,36 +60,36 @@ class TestSageMakerTrainingOperator(unittest.TestCase):
def setUp(self):
self.sagemaker = SageMakerTrainingOperator(
task_id='test_sagemaker_operator',
- aws_conn_id='sagemaker_test_id',
- config=create_training_params,
+ config=CREATE_TRAINING_PARAMS,
wait_for_completion=False,
check_interval=5,
)
- def test_parse_config_integers(self):
- self.sagemaker.parse_config_integers()
- assert self.sagemaker.config['ResourceConfig']['InstanceCount'] == int(
- self.sagemaker.config['ResourceConfig']['InstanceCount']
- )
- assert self.sagemaker.config['ResourceConfig']['VolumeSizeInGB'] == int(
- self.sagemaker.config['ResourceConfig']['VolumeSizeInGB']
- )
- assert self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds'] == int(
- self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds']
- )
+ @mock.patch.object(SageMakerHook, 'get_conn')
+ @mock.patch.object(SageMakerHook, 'create_training_job')
+ def test_integer_fields(self, mock_training, mock_client):
+ mock_training.return_value = {
+ 'TrainingJobArn': 'test_arn',
+ 'ResponseMetadata': {'HTTPStatusCode': 200},
+ }
+ self.sagemaker._check_if_job_exists = mock.MagicMock()
+ self.sagemaker.execute(None)
+ assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
+ for (key1, key2) in EXPECTED_INTEGER_FIELDS:
+ assert self.sagemaker.config[key1][key2] == int(self.sagemaker.config[key1][key2])
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_training_job')
def test_execute_with_check_if_job_exists(self, mock_training, mock_client):
mock_training.return_value = {
- 'TrainingJobArn': 'testarn',
+ 'TrainingJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
}
self.sagemaker._check_if_job_exists = mock.MagicMock()
self.sagemaker.execute(None)
self.sagemaker._check_if_job_exists.assert_called_once()
mock_training.assert_called_once_with(
- create_training_params,
+ CREATE_TRAINING_PARAMS,
wait_for_completion=False,
print_log=True,
check_interval=5,
@@ -105,7 +100,7 @@ class TestSageMakerTrainingOperator(unittest.TestCase):
@mock.patch.object(SageMakerHook, 'create_training_job')
def test_execute_without_check_if_job_exists(self, mock_training, mock_client):
mock_training.return_value = {
- 'TrainingJobArn': 'testarn',
+ 'TrainingJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
}
self.sagemaker.check_if_job_exists = False
@@ -113,7 +108,7 @@ class TestSageMakerTrainingOperator(unittest.TestCase):
self.sagemaker.execute(None)
self.sagemaker._check_if_job_exists.assert_not_called()
mock_training.assert_called_once_with(
- create_training_params,
+ CREATE_TRAINING_PARAMS,
wait_for_completion=False,
print_log=True,
check_interval=5,
@@ -124,30 +119,30 @@ class TestSageMakerTrainingOperator(unittest.TestCase):
@mock.patch.object(SageMakerHook, 'create_training_job')
def test_execute_with_failure(self, mock_training, mock_client):
mock_training.return_value = {
- 'TrainingJobArn': 'testarn',
+ 'TrainingJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 404},
}
with pytest.raises(AirflowException):
self.sagemaker.execute(None)
- @mock.patch.object(SageMakerHook, "get_conn")
- @mock.patch.object(SageMakerHook, "list_training_jobs")
+ @mock.patch.object(SageMakerHook, 'get_conn')
+ @mock.patch.object(SageMakerHook, 'list_training_jobs')
def test_check_if_job_exists_increment(self, mock_list_training_jobs, mock_client):
self.sagemaker.check_if_job_exists = True
- self.sagemaker.action_if_job_exists = "increment"
- mock_list_training_jobs.return_value = [{"TrainingJobName": job_name}]
+ self.sagemaker.action_if_job_exists = 'increment'
+ mock_list_training_jobs.return_value = [{'TrainingJobName': 'job_name'}]
self.sagemaker._check_if_job_exists()
- expected_config = create_training_params.copy()
+ expected_config = CREATE_TRAINING_PARAMS.copy()
# Expect to see TrainingJobName suffixed with "-2" because we return one existing job
- expected_config["TrainingJobName"] = f"{job_name}-2"
+ expected_config['TrainingJobName'] = 'job_name-2'
assert self.sagemaker.config == expected_config
- @mock.patch.object(SageMakerHook, "get_conn")
- @mock.patch.object(SageMakerHook, "list_training_jobs")
+ @mock.patch.object(SageMakerHook, 'get_conn')
+ @mock.patch.object(SageMakerHook, 'list_training_jobs')
def test_check_if_job_exists_fail(self, mock_list_training_jobs, mock_client):
self.sagemaker.check_if_job_exists = True
- self.sagemaker.action_if_job_exists = "fail"
- mock_list_training_jobs.return_value = [{"TrainingJobName": job_name}]
+ self.sagemaker.action_if_job_exists = 'fail'
+ mock_list_training_jobs.return_value = [{'TrainingJobName': 'job_name'}]
with pytest.raises(AirflowException):
self.sagemaker._check_if_job_exists()
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 76baa71add..b622698478 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -17,6 +17,7 @@
# under the License.
import unittest
+from typing import Dict, List
from unittest import mock
import pytest
@@ -25,44 +26,30 @@ from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator
-role = 'arn:aws:iam:role/test-role'
+EXPECTED_INTEGER_FIELDS: List[List[str]] = [
+ ['Transform', 'TransformResources', 'InstanceCount'],
+ ['Transform', 'MaxConcurrentTransforms'],
+ ['Transform', 'MaxPayloadInMB'],
+]
-bucket = 'test-bucket'
-
-key = 'test/data'
-data_url = f's3://{bucket}/{key}'
-
-job_name = 'test-job-name'
-
-model_name = 'test-model-name'
-
-image = 'test-image'
-
-output_url = f's3://{bucket}/test/output'
-
-create_transform_params = {
- 'TransformJobName': job_name,
- 'ModelName': model_name,
+CREATE_TRANSFORM_PARAMS: Dict = {
+ 'TransformJobName': 'job_name',
+ 'ModelName': 'model_name',
'MaxConcurrentTransforms': '12',
'MaxPayloadInMB': '6',
'BatchStrategy': 'MultiRecord',
- 'TransformInput': {'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': data_url}}},
- 'TransformOutput': {
- 'S3OutputPath': output_url,
- },
+ 'TransformInput': {'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3_uri'}}},
+ 'TransformOutput': {'S3OutputPath': 'output_path'},
'TransformResources': {'InstanceType': 'ml.m4.xlarge', 'InstanceCount': '3'},
}
-create_model_params = {
- 'ModelName': model_name,
- 'PrimaryContainer': {
- 'Image': image,
- 'ModelDataUrl': output_url,
- },
- 'ExecutionRoleArn': role,
+CREATE_MODEL_PARAMS: Dict = {
+ 'ModelName': 'model_name',
+ 'PrimaryContainer': {'Image': 'test_image', 'ModelDataUrl': 'output_path'},
+ 'ExecutionRoleArn': 'arn:aws:iam:role/test-role',
}
-config = {'Model': create_model_params, 'Transform': create_transform_params}
+CONFIG: Dict = {'Model': CREATE_MODEL_PARAMS, 'Transform': CREATE_TRANSFORM_PARAMS}
class TestSageMakerTransformOperator(unittest.TestCase):
@@ -70,32 +57,40 @@ class TestSageMakerTransformOperator(unittest.TestCase):
self.sagemaker = SageMakerTransformOperator(
task_id='test_sagemaker_operator',
aws_conn_id='sagemaker_test_id',
- config=config,
+ config=CONFIG,
wait_for_completion=False,
check_interval=5,
)
- def test_parse_config_integers(self):
- self.sagemaker.parse_config_integers()
- test_config = self.sagemaker.config['Transform']
- assert test_config['TransformResources']['InstanceCount'] == int(
- test_config['TransformResources']['InstanceCount']
- )
- assert test_config['MaxConcurrentTransforms'] == int(test_config['MaxConcurrentTransforms'])
- assert test_config['MaxPayloadInMB'] == int(test_config['MaxPayloadInMB'])
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_model')
+    @mock.patch.object(SageMakerHook, 'create_transform_job')
+    def test_integer_fields(self, mock_transform, mock_model, mock_client):
+        mock_transform.return_value = {
+            'TransformJobArn': 'test_arn',
+            'ResponseMetadata': {'HTTPStatusCode': 200},
+        }
+        self.sagemaker.execute(None)
+        assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
+        for (key1, key2, *key3) in EXPECTED_INTEGER_FIELDS:
+            if key3:
+                (key3,) = key3
+                assert self.sagemaker.config[key1][key2][key3] == int(self.sagemaker.config[key1][key2][key3])
+            else:
+                assert self.sagemaker.config[key1][key2] == int(self.sagemaker.config[key1][key2])
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_model')
@mock.patch.object(SageMakerHook, 'create_transform_job')
def test_execute(self, mock_transform, mock_model, mock_client):
mock_transform.return_value = {
- 'TransformJobArn': 'testarn',
+ 'TransformJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
}
self.sagemaker.execute(None)
- mock_model.assert_called_once_with(create_model_params)
+ mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
mock_transform.assert_called_once_with(
- create_transform_params, wait_for_completion=False, check_interval=5, max_ingestion_time=None
+ CREATE_TRANSFORM_PARAMS, wait_for_completion=False, check_interval=5, max_ingestion_time=None
)
@mock.patch.object(SageMakerHook, 'get_conn')
@@ -103,7 +98,7 @@ class TestSageMakerTransformOperator(unittest.TestCase):
@mock.patch.object(SageMakerHook, 'create_transform_job')
def test_execute_with_failure(self, mock_transform, mock_model, mock_client):
mock_transform.return_value = {
- 'TransformJobArn': 'testarn',
+ 'TransformJobArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 404},
}
with pytest.raises(AirflowException):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
index cb63357c67..9d3efb0c49 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
@@ -17,6 +17,7 @@
# under the License.
import unittest
+from typing import Dict, List
from unittest import mock
import pytest
@@ -25,30 +26,21 @@ from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator
-role = 'arn:aws:iam:role/test-role'
+EXPECTED_INTEGER_FIELDS: List[List[str]] = [
+ ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
+ ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
+ ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
+ ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
+ ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'],
+]
-bucket = 'test-bucket'
-
-key = 'test/data'
-data_url = f's3://{bucket}/{key}'
-
-job_name = 'test-job-name'
-
-image = 'test-image'
-
-output_url = f's3://{bucket}/test/output'
-
-create_tuning_params = {
- 'HyperParameterTuningJobName': job_name,
+CREATE_TUNING_PARAMS: Dict = {
+ 'HyperParameterTuningJobName': 'job_name',
'HyperParameterTuningJobConfig': {
'Strategy': 'Bayesian',
'HyperParameterTuningJobObjective': {'Type': 'Maximize', 'MetricName': 'test_metric'},
'ResourceLimits': {'MaxNumberOfTrainingJobs': '123', 'MaxParallelTrainingJobs': '123'},
- 'ParameterRanges': {
- 'IntegerParameterRanges': [
- {'Name': 'k', 'MinValue': '2', 'MaxValue': '10'},
- ]
- },
+ 'ParameterRanges': {'IntegerParameterRanges': [{'Name': 'k', 'MinValue': '2', 'MaxValue': '10'}]},
},
'TrainingJobDefinition': {
'StaticHyperParameters': {
@@ -57,15 +49,15 @@ create_tuning_params = {
'mini_batch_size': '500',
'force_dense': 'True',
},
- 'AlgorithmSpecification': {'TrainingImage': image, 'TrainingInputMode': 'File'},
- 'RoleArn': role,
+ 'AlgorithmSpecification': {'TrainingImage': 'image_name', 'TrainingInputMode': 'File'},
+ 'RoleArn': 'arn:aws:iam:role/test-role',
'InputDataConfig': [
{
'ChannelName': 'train',
'DataSource': {
'S3DataSource': {
'S3DataType': 'S3Prefix',
- 'S3Uri': data_url,
+ 'S3Uri': 's3_uri',
'S3DataDistributionType': 'FullyReplicated',
}
},
@@ -73,9 +65,9 @@ create_tuning_params = {
'RecordWrapperType': 'None',
}
],
- 'OutputDataConfig': {'S3OutputPath': output_url},
+ 'OutputDataConfig': {'S3OutputPath': 'output_path'},
'ResourceConfig': {'InstanceCount': '2', 'InstanceType': 'ml.c4.8xlarge', 'VolumeSizeInGB': '50'},
- 'StoppingCondition': dict(MaxRuntimeInSeconds=60 * 60),
+ 'StoppingCondition': {'MaxRuntimeInSeconds': '3600'},
},
}
@@ -84,47 +76,32 @@ class TestSageMakerTuningOperator(unittest.TestCase):
def setUp(self):
self.sagemaker = SageMakerTuningOperator(
task_id='test_sagemaker_operator',
- aws_conn_id='sagemaker_test_conn',
- config=create_tuning_params,
+ config=CREATE_TUNING_PARAMS,
wait_for_completion=False,
check_interval=5,
)
- def test_parse_config_integers(self):
- self.sagemaker.parse_config_integers()
- assert self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['InstanceCount'] == int(
- self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['InstanceCount']
- )
- assert self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['VolumeSizeInGB'] == int(
- self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['VolumeSizeInGB']
- )
- assert self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
- 'MaxNumberOfTrainingJobs'
- ] == int(
- self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
- 'MaxNumberOfTrainingJobs'
- ]
- )
- assert self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
- 'MaxParallelTrainingJobs'
- ] == int(
- self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
- 'MaxParallelTrainingJobs'
- ]
- )
+ @mock.patch.object(SageMakerHook, 'get_conn')
+ @mock.patch.object(SageMakerHook, 'create_tuning_job')
+ def test_integer_fields(self, mock_tuning, mock_client):
+ mock_tuning.return_value = {'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
+ self.sagemaker.execute(None)
+ assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
+ for (key1, key2, key3) in EXPECTED_INTEGER_FIELDS:
+ assert self.sagemaker.config[key1][key2][key3] == int(self.sagemaker.config[key1][key2][key3])
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_tuning_job')
def test_execute(self, mock_tuning, mock_client):
- mock_tuning.return_value = {'TrainingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
+ mock_tuning.return_value = {'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
self.sagemaker.execute(None)
mock_tuning.assert_called_once_with(
- create_tuning_params, wait_for_completion=False, check_interval=5, max_ingestion_time=None
+ CREATE_TUNING_PARAMS, wait_for_completion=False, check_interval=5, max_ingestion_time=None
)
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_tuning_job')
def test_execute_with_failure(self, mock_tuning, mock_client):
- mock_tuning.return_value = {'TrainingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
+ mock_tuning.return_value = {'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
with pytest.raises(AirflowException):
self.sagemaker.execute(None)