You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/15 06:07:09 UTC

[airflow] branch main updated: SageMaker system tests - Part 1 of 3 - Prep Work (AIP-47) (#25078)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 543161a1af SageMaker system tests - Part 1 of 3 - Prep Work (AIP-47) (#25078)
543161a1af is described below

commit 543161a1afe84400dbc3c0409bbf4ff8110919f8
Author: D. Ferruzzi <fe...@amazon.com>
AuthorDate: Fri Jul 15 06:07:00 2022 +0000

    SageMaker system tests - Part 1 of 3 - Prep Work (AIP-47) (#25078)
    
    * SageMaker Operator - Improve type hints and docstrings
    
    * SageMaker Operator Unit Test Improvements
    
    - Add explicit unit testing for integer_fields
    - Improve type hinting
    - Standardize some variable naming
    
    * SageMaker Operators - Configure integer fields at runtime
---
 .../providers/amazon/aws/operators/sagemaker.py    | 132 +++++++------
 .../amazon/aws/operators/test_sagemaker_base.py    |  16 +-
 .../aws/operators/test_sagemaker_endpoint.py       |  71 +++----
 .../operators/test_sagemaker_endpoint_config.py    |  32 ++--
 .../amazon/aws/operators/test_sagemaker_model.py   |  45 +++--
 .../aws/operators/test_sagemaker_processing.py     | 211 ++++++++++++---------
 .../aws/operators/test_sagemaker_training.py       |  85 ++++-----
 .../aws/operators/test_sagemaker_transform.py      |  79 ++++----
 .../amazon/aws/operators/test_sagemaker_tuning.py  |  79 +++-----
 9 files changed, 391 insertions(+), 359 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 6a6ef69df4..8da36c58dc 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -16,7 +16,7 @@
 # under the License.
 
 import json
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 
 from botocore.exceptions import ClientError
 
@@ -29,8 +29,8 @@ from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
-DEFAULT_CONN_ID = 'aws_default'
-CHECK_INTERVAL_SECOND = 30
+DEFAULT_CONN_ID: str = 'aws_default'
+CHECK_INTERVAL_SECOND: int = 30
 
 
 class SageMakerBaseOperator(BaseOperator):
@@ -41,15 +41,15 @@ class SageMakerBaseOperator(BaseOperator):
 
     template_fields: Sequence[str] = ('config',)
     template_ext: Sequence[str] = ()
-    template_fields_renderers = {'config': 'json'}
-    ui_color = '#ededed'
+    template_fields_renderers: Dict = {'config': 'json'}
+    ui_color: str = '#ededed'
     integer_fields: List[List[Any]] = []
 
-    def __init__(self, *, config: dict, **kwargs):
+    def __init__(self, *, config: Dict, **kwargs):
         super().__init__(**kwargs)
         self.config = config
 
-    def parse_integer(self, config, field):
+    def parse_integer(self, config: Dict, field: Union[List[str], str]) -> None:
         """Recursive method for parsing string fields holding integer values to integers."""
         if len(field) == 1:
             if isinstance(config, list):
@@ -69,19 +69,17 @@ class SageMakerBaseOperator(BaseOperator):
             self.parse_integer(config[head], tail)
         return
 
-    def parse_config_integers(self):
-        """
-        Parse the integer fields of training config to integers in case the config is rendered by Jinja and
-        all fields are str
-        """
+    def parse_config_integers(self) -> None:
+        """Parse the integer fields to ints in case the config is rendered by Jinja and all fields are str."""
         for field in self.integer_fields:
             self.parse_integer(self.config, field)
 
-    def expand_role(self):
+    def expand_role(self) -> None:
         """Placeholder for calling boto3's `expand_role`, which expands an IAM role name into an ARN."""
 
-    def preprocess_config(self):
+    def preprocess_config(self) -> None:
         """Process the config into a usable form."""
+        self._create_integer_fields()
         self.log.info('Preprocessing the config and doing required s3_operations')
         self.hook.configure_s3_resources(self.config)
         self.parse_config_integers()
@@ -91,12 +89,19 @@ class SageMakerBaseOperator(BaseOperator):
             json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': ')),
         )
 
-    def execute(self, context: 'Context'):
+    def _create_integer_fields(self) -> None:
+        """
+        Set fields which should be cast to integers.
+        Child classes should override this method if they need integer fields parsed.
+        """
+        self.integer_fields = []
+
+    def execute(self, context: 'Context') -> Union[None, Dict]:
         raise NotImplementedError('Please implement execute() in sub class!')
 
     @cached_property
     def hook(self):
-        """Return SageMakerHook"""
+        """Return SageMakerHook."""
         return SageMakerHook(aws_conn_id=self.aws_conn_id)
 
 
@@ -130,7 +135,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
     def __init__(
         self,
         *,
-        config: dict,
+        config: Dict,
         aws_conn_id: str = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         print_log: bool = True,
@@ -151,23 +156,23 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         self.print_log = print_log
         self.check_interval = check_interval
         self.max_ingestion_time = max_ingestion_time
-        self._create_integer_fields()
 
     def _create_integer_fields(self) -> None:
-        """Set fields which should be casted to integers."""
-        self.integer_fields = [
+        """Set fields which should be cast to integers."""
+        self.integer_fields: List[Union[List[str], List[List[str]]]] = [
             ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
             ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
         ]
         if 'StoppingCondition' in self.config:
-            self.integer_fields += [['StoppingCondition', 'MaxRuntimeInSeconds']]
+            self.integer_fields.append(['StoppingCondition', 'MaxRuntimeInSeconds'])
 
     def expand_role(self) -> None:
+        """Expands an IAM role name into an ARN."""
         if 'RoleArn' in self.config:
             hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
             self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
 
-    def execute(self, context: 'Context') -> dict:
+    def execute(self, context: 'Context') -> Dict:
         self.preprocess_config()
         processing_job_name = self.config['ProcessingJobName']
         if self.hook.find_processing_job_by_name(processing_job_name):
@@ -204,12 +209,10 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
     :return Dict: Returns The ARN of the endpoint config created in Amazon SageMaker.
     """
 
-    integer_fields = [['ProductionVariants', 'InitialInstanceCount']]
-
     def __init__(
         self,
         *,
-        config: dict,
+        config: Dict,
         aws_conn_id: str = DEFAULT_CONN_ID,
         **kwargs,
     ):
@@ -217,7 +220,11 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
         self.config = config
         self.aws_conn_id = aws_conn_id
 
-    def execute(self, context: 'Context') -> dict:
+    def _create_integer_fields(self) -> None:
+        """Set fields which should be cast to integers."""
+        self.integer_fields: List[List[str]] = [['ProductionVariants', 'InitialInstanceCount']]
+
+    def execute(self, context: 'Context') -> Dict:
         self.preprocess_config()
         self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName'])
         response = self.hook.create_endpoint_config(self.config)
@@ -278,7 +285,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
     def __init__(
         self,
         *,
-        config: dict,
+        config: Dict,
         aws_conn_id: str = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
@@ -295,14 +302,16 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
         self.operation = operation.lower()
         if self.operation not in ['create', 'update']:
             raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
-        self.create_integer_fields()
 
-    def create_integer_fields(self) -> None:
-        """Set fields which should be casted to integers."""
+    def _create_integer_fields(self) -> None:
+        """Set fields which should be cast to integers."""
         if 'EndpointConfig' in self.config:
-            self.integer_fields = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']]
+            self.integer_fields: List[List[str]] = [
+                ['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']
+            ]
 
     def expand_role(self) -> None:
+        """Expands an IAM role name into an ARN."""
         if 'Model' not in self.config:
             return
         hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
@@ -310,7 +319,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
         if 'ExecutionRoleArn' in config:
             config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
 
-    def execute(self, context: 'Context') -> dict:
+    def execute(self, context: 'Context') -> Dict:
         self.preprocess_config()
         model_info = self.config.get('Model')
         endpoint_config_info = self.config.get('EndpointConfig')
@@ -397,7 +406,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
     def __init__(
         self,
         *,
-        config: dict,
+        config: Dict,
         aws_conn_id: str = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
@@ -410,10 +419,9 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
         self.max_ingestion_time = max_ingestion_time
-        self.create_integer_fields()
 
-    def create_integer_fields(self) -> None:
-        """Set fields which should be casted to integers."""
+    def _create_integer_fields(self) -> None:
+        """Set fields which should be cast to integers."""
         self.integer_fields: List[List[str]] = [
             ['Transform', 'TransformResources', 'InstanceCount'],
             ['Transform', 'MaxConcurrentTransforms'],
@@ -424,6 +432,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 field.pop(0)
 
     def expand_role(self) -> None:
+        """Expands an IAM role name into an ARN."""
         if 'Model' not in self.config:
             return
         config = self.config['Model']
@@ -431,7 +440,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
             hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
             config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
 
-    def execute(self, context: 'Context') -> dict:
+    def execute(self, context: 'Context') -> Dict:
         self.preprocess_config()
         model_config = self.config.get('Model')
         transform_config = self.config.get('Transform', self.config)
@@ -480,18 +489,10 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
     :return Dict: Returns The ARN of the tuning job created in Amazon SageMaker.
     """
 
-    integer_fields = [
-        ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
-        ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
-        ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
-        ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
-        ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'],
-    ]
-
     def __init__(
         self,
         *,
-        config: dict,
+        config: Dict,
         aws_conn_id: str = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
@@ -506,13 +507,24 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
         self.max_ingestion_time = max_ingestion_time
 
     def expand_role(self) -> None:
+        """Expands an IAM role name into an ARN."""
         if 'TrainingJobDefinition' in self.config:
             config = self.config['TrainingJobDefinition']
             if 'RoleArn' in config:
                 hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
                 config['RoleArn'] = hook.expand_role(config['RoleArn'])
 
-    def execute(self, context: 'Context') -> dict:
+    def _create_integer_fields(self) -> None:
+        """Set fields which should be cast to integers."""
+        self.integer_fields: List[List[str]] = [
+            ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
+            ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
+            ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
+            ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
+            ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'],
+        ]
+
+    def execute(self, context: 'Context') -> Dict:
         self.preprocess_config()
         self.log.info(
             'Creating SageMaker Hyper-Parameter Tuning Job %s', self.config['HyperParameterTuningJobName']
@@ -547,17 +559,18 @@ class SageMakerModelOperator(SageMakerBaseOperator):
     :return Dict: Returns The ARN of the model created in Amazon SageMaker.
     """
 
-    def __init__(self, *, config, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
+    def __init__(self, *, config: Dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
         super().__init__(config=config, **kwargs)
         self.config = config
         self.aws_conn_id = aws_conn_id
 
     def expand_role(self) -> None:
+        """Expands an IAM role name into an ARN."""
         if 'ExecutionRoleArn' in self.config:
             hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
             self.config['ExecutionRoleArn'] = hook.expand_role(self.config['ExecutionRoleArn'])
 
-    def execute(self, context: 'Context') -> dict:
+    def execute(self, context: 'Context') -> Dict:
         self.preprocess_config()
         self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
         response = self.hook.create_model(self.config)
@@ -596,16 +609,10 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
     :return Dict: Returns The ARN of the training job created in Amazon SageMaker.
     """
 
-    integer_fields = [
-        ['ResourceConfig', 'InstanceCount'],
-        ['ResourceConfig', 'VolumeSizeInGB'],
-        ['StoppingCondition', 'MaxRuntimeInSeconds'],
-    ]
-
     def __init__(
         self,
         *,
-        config: dict,
+        config: Dict,
         aws_conn_id: str = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         print_log: bool = True,
@@ -631,11 +638,20 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
             )
 
     def expand_role(self) -> None:
+        """Expands an IAM role name into an ARN."""
         if 'RoleArn' in self.config:
             hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
             self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
 
-    def execute(self, context: 'Context') -> dict:
+    def _create_integer_fields(self) -> None:
+        """Set fields which should be cast to integers."""
+        self.integer_fields: List[List[str]] = [
+            ['ResourceConfig', 'InstanceCount'],
+            ['ResourceConfig', 'VolumeSizeInGB'],
+            ['StoppingCondition', 'MaxRuntimeInSeconds'],
+        ]
+
+    def execute(self, context: 'Context') -> Dict:
         self.preprocess_config()
         if self.check_if_job_exists:
             self._check_if_job_exists()
@@ -680,7 +696,7 @@ class SageMakerDeleteModelOperator(SageMakerBaseOperator):
     :param aws_conn_id: The AWS connection ID to use.
     """
 
-    def __init__(self, *, config, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
+    def __init__(self, *, config: Dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
         super().__init__(config=config, **kwargs)
         self.config = config
         self.aws_conn_id = aws_conn_id
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_base.py b/tests/providers/amazon/aws/operators/test_sagemaker_base.py
index 5fecaf9c94..9c96177eae 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_base.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_base.py
@@ -16,19 +16,27 @@
 # under the License.
 
 import unittest
+from typing import Any, Dict, List
 
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerBaseOperator
 
-config = {'key1': '1', 'key2': {'key3': '3', 'key4': '4'}, 'key5': [{'key6': '6'}, {'key6': '7'}]}
+CONFIG: Dict = {'key1': '1', 'key2': {'key3': '3', 'key4': '4'}, 'key5': [{'key6': '6'}, {'key6': '7'}]}
+PARSED_CONFIG: Dict = {'key1': 1, 'key2': {'key3': 3, 'key4': 4}, 'key5': [{'key6': 6}, {'key6': 7}]}
 
-parsed_config = {'key1': 1, 'key2': {'key3': 3, 'key4': 4}, 'key5': [{'key6': 6}, {'key6': 7}]}
+EXPECTED_INTEGER_FIELDS: List[List[Any]] = []
 
 
 class TestSageMakerBaseOperator(unittest.TestCase):
     def setUp(self):
-        self.sagemaker = SageMakerBaseOperator(task_id='test_sagemaker_operator', config=config)
+        self.sagemaker = SageMakerBaseOperator(task_id='test_sagemaker_operator', config=CONFIG)
+        self.sagemaker.aws_conn_id = 'aws_default'
 
     def test_parse_integer(self):
         self.sagemaker.integer_fields = [['key1'], ['key2', 'key3'], ['key2', 'key4'], ['key5', 'key6']]
         self.sagemaker.parse_config_integers()
-        assert self.sagemaker.config == parsed_config
+        assert self.sagemaker.config == PARSED_CONFIG
+
+    def test_default_integer_fields(self):
+        self.sagemaker.preprocess_config()
+
+        assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
index 11e1431c69..3f252ab4d8 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
@@ -17,6 +17,7 @@
 # under the License.
 
 import unittest
+from typing import Dict, List
 from unittest import mock
 
 import pytest
@@ -26,57 +27,54 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator
 
-role = 'arn:aws:iam:role/test-role'
-bucket = 'test-bucket'
-image = 'test-image'
-output_url = f's3://{bucket}/test/output'
-model_name = 'test-model-name'
-config_name = 'test-endpoint-config-name'
-endpoint_name = 'test-endpoint-name'
-
-create_model_params = {
-    'ModelName': model_name,
+CREATE_MODEL_PARAMS: Dict = {
+    'ModelName': 'model_name',
     'PrimaryContainer': {
-        'Image': image,
-        'ModelDataUrl': output_url,
+        'Image': 'image_name',
+        'ModelDataUrl': 'output_path',
     },
-    'ExecutionRoleArn': role,
+    'ExecutionRoleArn': 'arn:aws:iam:role/test-role',
 }
-
-create_endpoint_config_params = {
-    'EndpointConfigName': config_name,
+CREATE_ENDPOINT_CONFIG_PARAMS: Dict = {
+    'EndpointConfigName': 'config_name',
     'ProductionVariants': [
         {
             'VariantName': 'AllTraffic',
-            'ModelName': model_name,
+            'ModelName': 'model_name',
             'InitialInstanceCount': '1',
             'InstanceType': 'ml.c4.xlarge',
         }
     ],
 }
+CREATE_ENDPOINT_PARAMS: Dict = {'EndpointName': 'endpoint_name', 'EndpointConfigName': 'config_name'}
 
-create_endpoint_params = {'EndpointName': endpoint_name, 'EndpointConfigName': config_name}
-
-config = {
-    'Model': create_model_params,
-    'EndpointConfig': create_endpoint_config_params,
-    'Endpoint': create_endpoint_params,
+CONFIG: Dict = {
+    'Model': CREATE_MODEL_PARAMS,
+    'EndpointConfig': CREATE_ENDPOINT_CONFIG_PARAMS,
+    'Endpoint': CREATE_ENDPOINT_PARAMS,
 }
 
+EXPECTED_INTEGER_FIELDS: List[List[str]] = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']]
+
 
 class TestSageMakerEndpointOperator(unittest.TestCase):
     def setUp(self):
         self.sagemaker = SageMakerEndpointOperator(
             task_id='test_sagemaker_operator',
-            aws_conn_id='sagemaker_test_id',
-            config=config,
+            config=CONFIG,
             wait_for_completion=False,
             check_interval=5,
             operation='create',
         )
 
-    def test_parse_config_integers(self):
-        self.sagemaker.parse_config_integers()
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_model')
+    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
+    @mock.patch.object(SageMakerHook, 'create_endpoint')
+    def test_integer_fields(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
+        mock_endpoint.return_value = {'EndpointArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
+        self.sagemaker.execute(None)
+        assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
         for variant in self.sagemaker.config['EndpointConfig']['ProductionVariants']:
             assert variant['InitialInstanceCount'] == int(variant['InitialInstanceCount'])
 
@@ -85,20 +83,23 @@ class TestSageMakerEndpointOperator(unittest.TestCase):
     @mock.patch.object(SageMakerHook, 'create_endpoint_config')
     @mock.patch.object(SageMakerHook, 'create_endpoint')
     def test_execute(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
-        mock_endpoint.return_value = {'EndpointArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
+        mock_endpoint.return_value = {'EndpointArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
         self.sagemaker.execute(None)
-        mock_model.assert_called_once_with(create_model_params)
-        mock_endpoint_config.assert_called_once_with(create_endpoint_config_params)
+        mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
+        mock_endpoint_config.assert_called_once_with(CREATE_ENDPOINT_CONFIG_PARAMS)
         mock_endpoint.assert_called_once_with(
-            create_endpoint_params, wait_for_completion=False, check_interval=5, max_ingestion_time=None
+            CREATE_ENDPOINT_PARAMS, wait_for_completion=False, check_interval=5, max_ingestion_time=None
         )
+        assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
+        for variant in self.sagemaker.config['EndpointConfig']['ProductionVariants']:
+            assert variant['InitialInstanceCount'] == int(variant['InitialInstanceCount'])
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'create_model')
     @mock.patch.object(SageMakerHook, 'create_endpoint_config')
     @mock.patch.object(SageMakerHook, 'create_endpoint')
     def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
-        mock_endpoint.return_value = {'EndpointArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
+        mock_endpoint.return_value = {'EndpointArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
         with pytest.raises(AirflowException):
             self.sagemaker.execute(None)
 
@@ -111,11 +112,11 @@ class TestSageMakerEndpointOperator(unittest.TestCase):
         self, mock_endpoint_update, mock_endpoint, mock_endpoint_config, mock_model, mock_client
     ):
         response = {
-            "Error": {"Code": "ValidationException", "Message": "Cannot create already existing endpoint."}
+            'Error': {'Code': 'ValidationException', 'Message': 'Cannot create already existing endpoint.'}
         }
-        mock_endpoint.side_effect = ClientError(error_response=response, operation_name="CreateEndpoint")
+        mock_endpoint.side_effect = ClientError(error_response=response, operation_name='CreateEndpoint')
         mock_endpoint_update.return_value = {
-            'EndpointArn': 'testarn',
+            'EndpointArn': 'test_arn',
             'ResponseMetadata': {'HTTPStatusCode': 200},
         }
         self.sagemaker.execute(None)
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
index 40fe16fd82..b43b30d10e 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
@@ -17,6 +17,7 @@
 # under the License.
 
 import unittest
+from typing import Dict, List
 from unittest import mock
 
 import pytest
@@ -25,32 +26,37 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointConfigOperator
 
-model_name = 'test-model-name'
-config_name = 'test-config-name'
-
-create_endpoint_config_params = {
-    'EndpointConfigName': config_name,
+CREATE_ENDPOINT_CONFIG_PARAMS: Dict = {
+    'EndpointConfigName': 'config_name',
     'ProductionVariants': [
         {
             'VariantName': 'AllTraffic',
-            'ModelName': model_name,
+            'ModelName': 'model_name',
             'InitialInstanceCount': '1',
             'InstanceType': 'ml.c4.xlarge',
         }
     ],
 }
 
+EXPECTED_INTEGER_FIELDS: List[List[str]] = [['ProductionVariants', 'InitialInstanceCount']]
+
 
 class TestSageMakerEndpointConfigOperator(unittest.TestCase):
     def setUp(self):
         self.sagemaker = SageMakerEndpointConfigOperator(
             task_id='test_sagemaker_operator',
-            aws_conn_id='sagemaker_test_id',
-            config=create_endpoint_config_params,
+            config=CREATE_ENDPOINT_CONFIG_PARAMS,
         )
 
-    def test_parse_config_integers(self):
-        self.sagemaker.parse_config_integers()
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
+    def test_integer_fields(self, mock_model, mock_client):
+        mock_model.return_value = {
+            'EndpointConfigArn': 'test_arn',
+            'ResponseMetadata': {'HTTPStatusCode': 200},
+        }
+        self.sagemaker.execute(None)
+        assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
         for variant in self.sagemaker.config['ProductionVariants']:
             assert variant['InitialInstanceCount'] == int(variant['InitialInstanceCount'])
 
@@ -58,17 +64,17 @@ class TestSageMakerEndpointConfigOperator(unittest.TestCase):
     @mock.patch.object(SageMakerHook, 'create_endpoint_config')
     def test_execute(self, mock_model, mock_client):
         mock_model.return_value = {
-            'EndpointConfigArn': 'testarn',
+            'EndpointConfigArn': 'test_arn',
             'ResponseMetadata': {'HTTPStatusCode': 200},
         }
         self.sagemaker.execute(None)
-        mock_model.assert_called_once_with(create_endpoint_config_params)
+        mock_model.assert_called_once_with(CREATE_ENDPOINT_CONFIG_PARAMS)
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'create_model')
     def test_execute_with_failure(self, mock_model, mock_client):
         mock_model.return_value = {
-            'EndpointConfigArn': 'testarn',
+            'EndpointConfigArn': 'test_arn',
             'ResponseMetadata': {'HTTPStatusCode': 200},
         }
         with pytest.raises(AirflowException):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_model.py b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
index c2689d8fd5..98990f7c74 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_model.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
@@ -17,6 +17,7 @@
 # under the License.
 
 import unittest
+from typing import Dict, List
 from unittest import mock
 
 import pytest
@@ -28,51 +29,49 @@ from airflow.providers.amazon.aws.operators.sagemaker import (
     SageMakerModelOperator,
 )
 
-role = 'arn:aws:iam:role/test-role'
-
-bucket = 'test-bucket'
-
-model_name = 'test-model-name'
-
-image = 'test-image'
-
-output_url = f's3://{bucket}/test/output'
-create_model_params = {
-    'ModelName': model_name,
+CREATE_MODEL_PARAMS: Dict = {
+    'ModelName': 'model_name',
     'PrimaryContainer': {
-        'Image': image,
-        'ModelDataUrl': output_url,
+        'Image': 'image_name',
+        'ModelDataUrl': 'output_path',
     },
-    'ExecutionRoleArn': role,
+    'ExecutionRoleArn': 'arn:aws:iam:role/test-role',
 }
 
+EXPECTED_INTEGER_FIELDS: List[List[str]] = []
+
 
 class TestSageMakerModelOperator(unittest.TestCase):
     def setUp(self):
-        self.sagemaker = SageMakerModelOperator(
-            task_id='test_sagemaker_operator', aws_conn_id='sagemaker_test_id', config=create_model_params
-        )
+        self.sagemaker = SageMakerModelOperator(task_id='test_sagemaker_operator', config=CREATE_MODEL_PARAMS)
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_model')
+    def test_integer_fields(self, mock_model, mock_client):
+        mock_model.return_value = {'ModelArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
+        self.sagemaker.execute(None)
+        assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'create_model')
     def test_execute(self, mock_model, mock_client):
-        mock_model.return_value = {'ModelArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
+        mock_model.return_value = {'ModelArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
         self.sagemaker.execute(None)
-        mock_model.assert_called_once_with(create_model_params)
+        mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'create_model')
     def test_execute_with_failure(self, mock_model, mock_client):
-        mock_model.return_value = {'ModelArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
+        mock_model.return_value = {'ModelArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
         with pytest.raises(AirflowException):
             self.sagemaker.execute(None)
 
 
 class TestSageMakerDeleteModelOperator(unittest.TestCase):
     def setUp(self):
-        delete_model_params = {'ModelName': 'test'}
+        delete_model_params = {'ModelName': 'model_name'}  # hook receives this as model_name=
         self.sagemaker = SageMakerDeleteModelOperator(
-            task_id='test_sagemaker_operator', aws_conn_id='sagemaker_test_id', config=delete_model_params
+            task_id='test_sagemaker_operator', config=delete_model_params  # default aws_conn_id
         )
 
     @mock.patch.object(SageMakerHook, 'get_conn')
@@ -80,4 +79,4 @@ class TestSageMakerDeleteModelOperator(unittest.TestCase):
     def test_execute(self, delete_model, mock_client):
         delete_model.return_value = None
         self.sagemaker.execute(None)
-        delete_model.assert_called_once_with(model_name='test')
+        delete_model.assert_called_once_with(model_name='model_name')
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 7855bb6a44..10e4f3feda 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -16,151 +16,186 @@
 # under the License.
 
 import unittest
+from typing import Dict, List
 from unittest import mock
 
 import pytest
-from parameterized import parameterized
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerProcessingOperator
 
-job_name = 'test-job-name'
-
-create_processing_params = {
-    "AppSpecification": {
-        "ContainerArguments": ["container_arg"],
-        "ContainerEntrypoint": ["container_entrypoint"],
-        "ImageUri": "{{ image_uri }}",
+CREATE_PROCESSING_PARAMS: Dict = {
+    'AppSpecification': {
+        'ContainerArguments': ['container_arg'],
+        'ContainerEntrypoint': ['container_entrypoint'],
+        'ImageUri': 'image_uri',
     },
-    "Environment": {"{{ key }}": "{{ value }}"},
-    "ExperimentConfig": {
-        "ExperimentName": "ExperimentName",
-        "TrialComponentDisplayName": "TrialComponentDisplayName",
-        "TrialName": "TrialName",
+    'Environment': {'key': 'value'},
+    'ExperimentConfig': {
+        'ExperimentName': 'experiment_name',
+        'TrialComponentDisplayName': 'trial_component_display_name',
+        'TrialName': 'trial_name',
     },
-    "ProcessingInputs": [
+    'ProcessingInputs': [
         {
-            "InputName": "AnalyticsInputName",
-            "S3Input": {
-                "LocalPath": "{{ Local Path }}",
-                "S3CompressionType": "None",
-                "S3DataDistributionType": "FullyReplicated",
-                "S3DataType": "S3Prefix",
-                "S3InputMode": "File",
-                "S3Uri": "{{ S3Uri }}",
+            'InputName': 'analytics_input_name',
+            'S3Input': {
+                'LocalPath': 'local_path',
+                'S3CompressionType': 'None',
+                'S3DataDistributionType': 'FullyReplicated',
+                'S3DataType': 'S3Prefix',
+                'S3InputMode': 'File',
+                'S3Uri': 's3_uri',
             },
         }
     ],
-    "ProcessingJobName": job_name,
-    "ProcessingOutputConfig": {
-        "KmsKeyId": "KmsKeyID",
-        "Outputs": [
+    'ProcessingJobName': 'job_name',
+    'ProcessingOutputConfig': {
+        'KmsKeyId': 'kms_key_ID',
+        'Outputs': [
             {
-                "OutputName": "AnalyticsOutputName",
-                "S3Output": {
-                    "LocalPath": "{{ Local Path }}",
-                    "S3UploadMode": "EndOfJob",
-                    "S3Uri": "{{ S3Uri }}",
+                'OutputName': 'analytics_output_name',
+                'S3Output': {
+                    'LocalPath': 'local_path',
+                    'S3UploadMode': 'EndOfJob',
+                    'S3Uri': 's3_uri',
                 },
             }
         ],
     },
-    "ProcessingResources": {
-        "ClusterConfig": {
-            "InstanceCount": 2,
-            "InstanceType": "ml.p2.xlarge",
-            "VolumeSizeInGB": 30,
-            "VolumeKmsKeyId": "{{ kms_key }}",
+    'ProcessingResources': {
+        'ClusterConfig': {
+            'InstanceCount': '2',  # str here; tests expect an int after execute()
+            'InstanceType': 'ml.p2.xlarge',
+            'VolumeSizeInGB': '30',  # str here; tests expect an int after execute()
+            'VolumeKmsKeyId': 'kms_key',
         }
     },
-    "RoleArn": "arn:aws:iam::0122345678910:role/SageMakerPowerUser",
-    "Tags": [{"{{ key }}": "{{ value }}"}],
+    'RoleArn': 'arn:aws:iam::0122345678910:role/SageMakerPowerUser',
+    'Tags': [{'key': 'value'}],
 }
 
-create_processing_params_with_stopping_condition = create_processing_params.copy()
-create_processing_params_with_stopping_condition.update(StoppingCondition={"MaxRuntimeInSeconds": 3600})
+CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION: Dict = CREATE_PROCESSING_PARAMS.copy()  # shallow copy
+CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION.update(StoppingCondition={'MaxRuntimeInSeconds': '3600'})
+
+EXPECTED_INTEGER_FIELDS: List[List[str]] = [  # key paths into the config dict
+    ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
+    ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
+]
+EXPECTED_STOPPING_CONDITION_INTEGER_FIELDS: List[List[str]] = [['StoppingCondition', 'MaxRuntimeInSeconds']]
 
 
 class TestSageMakerProcessingOperator(unittest.TestCase):
     def setUp(self):
         self.processing_config_kwargs = dict(
-            task_id='test_sagemaker_operator',
-            aws_conn_id='sagemaker_test_id',
-            wait_for_completion=False,
-            check_interval=5,
+            task_id='test_sagemaker_operator', wait_for_completion=False, check_interval=5
         )
 
-    @parameterized.expand(
-        [
-            (
-                create_processing_params,
-                [
-                    ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
-                    ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
-                ],
-            ),
-            (
-                create_processing_params_with_stopping_condition,
-                [
-                    ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
-                    ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
-                    ['StoppingCondition', 'MaxRuntimeInSeconds'],
-                ],
-            ),
-        ]
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'find_processing_job_by_name', return_value=False)
+    @mock.patch.object(
+        SageMakerHook,
+        'create_processing_job',
+        return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
     )
-    def test_integer_fields_are_set(self, config, expected_fields):
-        sagemaker = SageMakerProcessingOperator(**self.processing_config_kwargs, config=config)
-        assert sagemaker.integer_fields == expected_fields
+    def test_integer_fields_without_stopping_condition(self, mock_processing, mock_hook, mock_client):
+        sagemaker = SageMakerProcessingOperator(
+            **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
+        )
+        sagemaker.execute(None)  # integer fields are configured at execute() time
+        assert sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
+        for (key1, key2, key3) in EXPECTED_INTEGER_FIELDS:
+            assert sagemaker.config[key1][key2][key3] == int(sagemaker.config[key1][key2][key3])
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'find_processing_job_by_name', return_value=False)
+    @mock.patch.object(
+        SageMakerHook,
+        'create_processing_job',
+        return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
+    )
+    def test_integer_fields_with_stopping_condition(self, mock_processing, mock_hook, mock_client):
+        sagemaker = SageMakerProcessingOperator(
+            **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION
+        )
+        sagemaker.execute(None)  # casts the integer fields in sagemaker.config
+        assert (
+            sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS + EXPECTED_STOPPING_CONDITION_INTEGER_FIELDS
+        )
+        for (key1, key2, *key3) in EXPECTED_INTEGER_FIELDS + EXPECTED_STOPPING_CONDITION_INTEGER_FIELDS:
+            if key3:
+                (key3,) = key3  # unwrap the optional third key
+                assert sagemaker.config[key1][key2][key3] == int(sagemaker.config[key1][key2][key3])
+            else:  # two-level path, e.g. StoppingCondition.MaxRuntimeInSeconds
+                assert sagemaker.config[key1][key2] == int(sagemaker.config[key1][key2])
 
     @mock.patch.object(SageMakerHook, 'get_conn')
-    @mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=False)
+    @mock.patch.object(SageMakerHook, 'find_processing_job_by_name', return_value=False)
     @mock.patch.object(
         SageMakerHook,
         'create_processing_job',
-        return_value={'ProcessingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
+        return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
     )
     def test_execute(self, mock_processing, mock_hook, mock_client):
         sagemaker = SageMakerProcessingOperator(
-            **self.processing_config_kwargs, config=create_processing_params
+            **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
         )
         sagemaker.execute(None)
         mock_processing.assert_called_once_with(
-            create_processing_params, wait_for_completion=False, check_interval=5, max_ingestion_time=None
+            CREATE_PROCESSING_PARAMS, wait_for_completion=False, check_interval=5, max_ingestion_time=None
         )
 
     @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'find_processing_job_by_name', return_value=False)
     @mock.patch.object(
         SageMakerHook,
         'create_processing_job',
-        return_value={'ProcessingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}},
+        return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}},
+    )
+    def test_execute_with_stopping_condition(self, mock_processing, mock_hook, mock_client):
+        sagemaker = SageMakerProcessingOperator(
+            **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION
+        )
+        sagemaker.execute(None)
+        mock_processing.assert_called_once_with(
+            CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION,
+            wait_for_completion=False,
+            check_interval=5,
+            max_ingestion_time=None,
+        )
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(
+        SageMakerHook,
+        'create_processing_job',
+        return_value={'ProcessingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 404}},
     )
     def test_execute_with_failure(self, mock_processing, mock_client):
         sagemaker = SageMakerProcessingOperator(
-            **self.processing_config_kwargs, config=create_processing_params
+            **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
         )
         with pytest.raises(AirflowException):
             sagemaker.execute(None)
 
-    @unittest.skip("Currently, the auto-increment jobname functionality is not missing.")
-    @mock.patch.object(SageMakerHook, "get_conn")
-    @mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=True)
+    @unittest.skip('Currently, the auto-increment jobname functionality is not missing.')
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'find_processing_job_by_name', return_value=True)
     @mock.patch.object(
-        SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}}
+        SageMakerHook, 'create_processing_job', return_value={'ResponseMetadata': {'HTTPStatusCode': 200}}
     )
     def test_execute_with_existing_job_increment(
         self, mock_create_processing_job, find_processing_job_by_name, mock_client
     ):
         sagemaker = SageMakerProcessingOperator(
-            **self.processing_config_kwargs, config=create_processing_params
+            **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
         )
-        sagemaker.action_if_job_exists = "increment"
+        sagemaker.action_if_job_exists = 'increment'
         sagemaker.execute(None)
 
-        expected_config = create_processing_params.copy()
+        expected_config = CREATE_PROCESSING_PARAMS.copy()
         # Expect to see ProcessingJobName suffixed with "-2" because we return one existing job
-        expected_config["ProcessingJobName"] = f"{job_name}-2"
+        expected_config['ProcessingJobName'] = 'job_name-2'
         mock_create_processing_job.assert_called_once_with(
             expected_config,
             wait_for_completion=False,
@@ -168,26 +203,26 @@ class TestSageMakerProcessingOperator(unittest.TestCase):
             max_ingestion_time=None,
         )
 
-    @mock.patch.object(SageMakerHook, "get_conn")
-    @mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=True)
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'find_processing_job_by_name', return_value=True)
     @mock.patch.object(
-        SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}}
+        SageMakerHook, 'create_processing_job', return_value={'ResponseMetadata': {'HTTPStatusCode': 200}}
     )
     def test_execute_with_existing_job_fail(
         self, mock_create_processing_job, mock_list_processing_jobs, mock_client
     ):
         sagemaker = SageMakerProcessingOperator(
-            **self.processing_config_kwargs, config=create_processing_params
+            **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
         )
-        sagemaker.action_if_job_exists = "fail"
+        sagemaker.action_if_job_exists = 'fail'
         with pytest.raises(AirflowException):
             sagemaker.execute(None)
 
-    @mock.patch.object(SageMakerHook, "get_conn")
+    @mock.patch.object(SageMakerHook, 'get_conn')
     def test_action_if_job_exists_validation(self, mock_client):
         with pytest.raises(AirflowException):
             SageMakerProcessingOperator(
                 **self.processing_config_kwargs,
-                config=create_processing_params,
-                action_if_job_exists="not_fail_or_increment",
+                config=CREATE_PROCESSING_PARAMS,
+                action_if_job_exists='not_fail_or_increment',
             )
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index 3c29c960b0..f8f16e3cfe 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import unittest
+from typing import List
 from unittest import mock
 
 import pytest
@@ -24,24 +25,18 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator
 
-role = 'arn:aws:iam:role/test-role'
+EXPECTED_INTEGER_FIELDS: List[List[str]] = [  # key paths into CREATE_TRAINING_PARAMS
+    ['ResourceConfig', 'InstanceCount'],
+    ['ResourceConfig', 'VolumeSizeInGB'],
+    ['StoppingCondition', 'MaxRuntimeInSeconds'],
+]
 
-bucket = 'test-bucket'
-
-key = 'test/data'
-data_url = f's3://{bucket}/{key}'
-
-job_name = 'test-job-name'
-
-image = 'test-image'
-
-output_url = f's3://{bucket}/test/output'
-create_training_params = {
-    'AlgorithmSpecification': {'TrainingImage': image, 'TrainingInputMode': 'File'},
-    'RoleArn': role,
-    'OutputDataConfig': {'S3OutputPath': output_url},
+CREATE_TRAINING_PARAMS = {  # numeric fields are str; see EXPECTED_INTEGER_FIELDS
+    'AlgorithmSpecification': {'TrainingImage': 'image_name', 'TrainingInputMode': 'File'},
+    'RoleArn': 'arn:aws:iam:role/test-role',
+    'OutputDataConfig': {'S3OutputPath': 'output_path'},
     'ResourceConfig': {'InstanceCount': '2', 'InstanceType': 'ml.c4.8xlarge', 'VolumeSizeInGB': '50'},
-    'TrainingJobName': job_name,
+    'TrainingJobName': 'job_name',
     'HyperParameters': {'k': '10', 'feature_dim': '784', 'mini_batch_size': '500', 'force_dense': 'True'},
     'StoppingCondition': {'MaxRuntimeInSeconds': '3600'},
     'InputDataConfig': [
@@ -50,7 +45,7 @@ create_training_params = {
             'DataSource': {
                 'S3DataSource': {
                     'S3DataType': 'S3Prefix',
-                    'S3Uri': data_url,
+                    'S3Uri': 's3_uri',
                     'S3DataDistributionType': 'FullyReplicated',
                 }
             },
@@ -65,36 +60,36 @@ class TestSageMakerTrainingOperator(unittest.TestCase):
     def setUp(self):
         self.sagemaker = SageMakerTrainingOperator(
             task_id='test_sagemaker_operator',
-            aws_conn_id='sagemaker_test_id',
-            config=create_training_params,
+            config=CREATE_TRAINING_PARAMS,  # module-level dict, passed without copying
             wait_for_completion=False,
             check_interval=5,
         )
 
-    def test_parse_config_integers(self):
-        self.sagemaker.parse_config_integers()
-        assert self.sagemaker.config['ResourceConfig']['InstanceCount'] == int(
-            self.sagemaker.config['ResourceConfig']['InstanceCount']
-        )
-        assert self.sagemaker.config['ResourceConfig']['VolumeSizeInGB'] == int(
-            self.sagemaker.config['ResourceConfig']['VolumeSizeInGB']
-        )
-        assert self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds'] == int(
-            self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds']
-        )
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_training_job')
+    def test_integer_fields(self, mock_training, mock_client):
+        mock_training.return_value = {
+            'TrainingJobArn': 'test_arn',
+            'ResponseMetadata': {'HTTPStatusCode': 200},
+        }
+        self.sagemaker._check_if_job_exists = mock.MagicMock()  # bypass the existing-job lookup
+        self.sagemaker.execute(None)  # integer fields are configured at execute() time
+        assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
+        for (key1, key2) in EXPECTED_INTEGER_FIELDS:
+            assert self.sagemaker.config[key1][key2] == int(self.sagemaker.config[key1][key2])
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'create_training_job')
     def test_execute_with_check_if_job_exists(self, mock_training, mock_client):
         mock_training.return_value = {
-            'TrainingJobArn': 'testarn',
+            'TrainingJobArn': 'test_arn',
             'ResponseMetadata': {'HTTPStatusCode': 200},
         }
         self.sagemaker._check_if_job_exists = mock.MagicMock()
         self.sagemaker.execute(None)
         self.sagemaker._check_if_job_exists.assert_called_once()
         mock_training.assert_called_once_with(
-            create_training_params,
+            CREATE_TRAINING_PARAMS,  # the same dict passed in setUp
             wait_for_completion=False,
             print_log=True,
             check_interval=5,
@@ -105,7 +100,7 @@ class TestSageMakerTrainingOperator(unittest.TestCase):
     @mock.patch.object(SageMakerHook, 'create_training_job')
     def test_execute_without_check_if_job_exists(self, mock_training, mock_client):
         mock_training.return_value = {
-            'TrainingJobArn': 'testarn',
+            'TrainingJobArn': 'test_arn',
             'ResponseMetadata': {'HTTPStatusCode': 200},
         }
         self.sagemaker.check_if_job_exists = False
@@ -113,7 +108,7 @@ class TestSageMakerTrainingOperator(unittest.TestCase):
         self.sagemaker.execute(None)
         self.sagemaker._check_if_job_exists.assert_not_called()
         mock_training.assert_called_once_with(
-            create_training_params,
+            CREATE_TRAINING_PARAMS,  # the same dict passed in setUp
             wait_for_completion=False,
             print_log=True,
             check_interval=5,
@@ -124,30 +119,30 @@ class TestSageMakerTrainingOperator(unittest.TestCase):
     @mock.patch.object(SageMakerHook, 'create_training_job')
     def test_execute_with_failure(self, mock_training, mock_client):
         mock_training.return_value = {
-            'TrainingJobArn': 'testarn',
+            'TrainingJobArn': 'test_arn',
             'ResponseMetadata': {'HTTPStatusCode': 404},
         }
         with pytest.raises(AirflowException):
             self.sagemaker.execute(None)
 
-    @mock.patch.object(SageMakerHook, "get_conn")
-    @mock.patch.object(SageMakerHook, "list_training_jobs")
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'list_training_jobs')
     def test_check_if_job_exists_increment(self, mock_list_training_jobs, mock_client):
         self.sagemaker.check_if_job_exists = True
-        self.sagemaker.action_if_job_exists = "increment"
-        mock_list_training_jobs.return_value = [{"TrainingJobName": job_name}]
+        self.sagemaker.action_if_job_exists = 'increment'
+        mock_list_training_jobs.return_value = [{'TrainingJobName': 'job_name'}]  # clashes with config
         self.sagemaker._check_if_job_exists()
 
-        expected_config = create_training_params.copy()
+        expected_config = CREATE_TRAINING_PARAMS.copy()  # shallow copy
         # Expect to see TrainingJobName suffixed with "-2" because we return one existing job
-        expected_config["TrainingJobName"] = f"{job_name}-2"
+        expected_config['TrainingJobName'] = 'job_name-2'
         assert self.sagemaker.config == expected_config
 
-    @mock.patch.object(SageMakerHook, "get_conn")
-    @mock.patch.object(SageMakerHook, "list_training_jobs")
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'list_training_jobs')
     def test_check_if_job_exists_fail(self, mock_list_training_jobs, mock_client):
         self.sagemaker.check_if_job_exists = True
-        self.sagemaker.action_if_job_exists = "fail"
-        mock_list_training_jobs.return_value = [{"TrainingJobName": job_name}]
+        self.sagemaker.action_if_job_exists = 'fail'
+        mock_list_training_jobs.return_value = [{'TrainingJobName': 'job_name'}]  # clashes with config
         with pytest.raises(AirflowException):
             self.sagemaker._check_if_job_exists()
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 76baa71add..b622698478 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -17,6 +17,7 @@
 # under the License.
 
 import unittest
+from typing import Dict, List
 from unittest import mock
 
 import pytest
@@ -25,44 +26,30 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator
 
-role = 'arn:aws:iam:role/test-role'
+EXPECTED_INTEGER_FIELDS: List[List[str]] = [  # key paths into CONFIG
+    ['Transform', 'TransformResources', 'InstanceCount'],
+    ['Transform', 'MaxConcurrentTransforms'],
+    ['Transform', 'MaxPayloadInMB'],
+]
 
-bucket = 'test-bucket'
-
-key = 'test/data'
-data_url = f's3://{bucket}/{key}'
-
-job_name = 'test-job-name'
-
-model_name = 'test-model-name'
-
-image = 'test-image'
-
-output_url = f's3://{bucket}/test/output'
-
-create_transform_params = {
-    'TransformJobName': job_name,
-    'ModelName': model_name,
+CREATE_TRANSFORM_PARAMS: Dict = {  # numeric fields are str; see EXPECTED_INTEGER_FIELDS
+    'TransformJobName': 'job_name',
+    'ModelName': 'model_name',
     'MaxConcurrentTransforms': '12',
     'MaxPayloadInMB': '6',
     'BatchStrategy': 'MultiRecord',
-    'TransformInput': {'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': data_url}}},
-    'TransformOutput': {
-        'S3OutputPath': output_url,
-    },
+    'TransformInput': {'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3_uri'}}},
+    'TransformOutput': {'S3OutputPath': 'output_path'},
     'TransformResources': {'InstanceType': 'ml.m4.xlarge', 'InstanceCount': '3'},
 }
 
-create_model_params = {
-    'ModelName': model_name,
-    'PrimaryContainer': {
-        'Image': image,
-        'ModelDataUrl': output_url,
-    },
-    'ExecutionRoleArn': role,
+CREATE_MODEL_PARAMS: Dict = {  # also embedded as CONFIG['Model']
+    'ModelName': 'model_name',
+    'PrimaryContainer': {'Image': 'test_image', 'ModelDataUrl': 'output_path'},
+    'ExecutionRoleArn': 'arn:aws:iam:role/test-role',
 }
 
-config = {'Model': create_model_params, 'Transform': create_transform_params}
+CONFIG: Dict = {'Model': CREATE_MODEL_PARAMS, 'Transform': CREATE_TRANSFORM_PARAMS}
 
 
 class TestSageMakerTransformOperator(unittest.TestCase):
@@ -70,32 +57,40 @@ class TestSageMakerTransformOperator(unittest.TestCase):
         self.sagemaker = SageMakerTransformOperator(
             task_id='test_sagemaker_operator',
             aws_conn_id='sagemaker_test_id',
-            config=config,
+            config=CONFIG,  # nested {'Model': ..., 'Transform': ...} dict
             wait_for_completion=False,
             check_interval=5,
         )
 
-    def test_parse_config_integers(self):
-        self.sagemaker.parse_config_integers()
-        test_config = self.sagemaker.config['Transform']
-        assert test_config['TransformResources']['InstanceCount'] == int(
-            test_config['TransformResources']['InstanceCount']
-        )
-        assert test_config['MaxConcurrentTransforms'] == int(test_config['MaxConcurrentTransforms'])
-        assert test_config['MaxPayloadInMB'] == int(test_config['MaxPayloadInMB'])
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_model')
+    @mock.patch.object(SageMakerHook, 'create_transform_job')
+    def test_integer_fields(self, mock_transform, mock_model, mock_client):
+        mock_transform.return_value = {
+            'TransformJobArn': 'test_arn',
+            'ResponseMetadata': {'HTTPStatusCode': 200},
+        }
+        self.sagemaker.execute(None)  # integer fields are configured at execute() time
+        assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
+        for (key1, key2, *key3) in EXPECTED_INTEGER_FIELDS:
+            if key3:
+                (key3,) = key3  # unwrap the optional third key
+                assert self.sagemaker.config[key1][key2][key3] == int(self.sagemaker.config[key1][key2][key3])
+            else:  # two-level path under 'Transform'
+                assert self.sagemaker.config[key1][key2] == int(self.sagemaker.config[key1][key2])
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'create_model')
     @mock.patch.object(SageMakerHook, 'create_transform_job')
     def test_execute(self, mock_transform, mock_model, mock_client):
         mock_transform.return_value = {
-            'TransformJobArn': 'testarn',
+            'TransformJobArn': 'test_arn',
             'ResponseMetadata': {'HTTPStatusCode': 200},
         }
         self.sagemaker.execute(None)
-        mock_model.assert_called_once_with(create_model_params)
+        mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)  # CONFIG['Model']
         mock_transform.assert_called_once_with(
-            create_transform_params, wait_for_completion=False, check_interval=5, max_ingestion_time=None
+            CREATE_TRANSFORM_PARAMS, wait_for_completion=False, check_interval=5, max_ingestion_time=None
         )
 
     @mock.patch.object(SageMakerHook, 'get_conn')
@@ -103,7 +98,7 @@ class TestSageMakerTransformOperator(unittest.TestCase):
     @mock.patch.object(SageMakerHook, 'create_transform_job')
     def test_execute_with_failure(self, mock_transform, mock_model, mock_client):
         mock_transform.return_value = {
-            'TransformJobArn': 'testarn',
+            'TransformJobArn': 'test_arn',
             'ResponseMetadata': {'HTTPStatusCode': 404},
         }
         with pytest.raises(AirflowException):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
index cb63357c67..9d3efb0c49 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
@@ -17,6 +17,7 @@
 # under the License.
 
 import unittest
+from typing import Dict, List
 from unittest import mock
 
 import pytest
@@ -25,30 +26,21 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator
 
-role = 'arn:aws:iam:role/test-role'
+EXPECTED_INTEGER_FIELDS: List[List[str]] = [  # key paths into CREATE_TUNING_PARAMS
+    ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
+    ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
+    ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
+    ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
+    ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'],
+]
 
-bucket = 'test-bucket'
-
-key = 'test/data'
-data_url = f's3://{bucket}/{key}'
-
-job_name = 'test-job-name'
-
-image = 'test-image'
-
-output_url = f's3://{bucket}/test/output'
-
-create_tuning_params = {
-    'HyperParameterTuningJobName': job_name,
+CREATE_TUNING_PARAMS: Dict = {  # numeric fields are str; see EXPECTED_INTEGER_FIELDS
+    'HyperParameterTuningJobName': 'job_name',
     'HyperParameterTuningJobConfig': {
         'Strategy': 'Bayesian',
         'HyperParameterTuningJobObjective': {'Type': 'Maximize', 'MetricName': 'test_metric'},
         'ResourceLimits': {'MaxNumberOfTrainingJobs': '123', 'MaxParallelTrainingJobs': '123'},
-        'ParameterRanges': {
-            'IntegerParameterRanges': [
-                {'Name': 'k', 'MinValue': '2', 'MaxValue': '10'},
-            ]
-        },
+        'ParameterRanges': {'IntegerParameterRanges': [{'Name': 'k', 'MinValue': '2', 'MaxValue': '10'}]},
     },
     'TrainingJobDefinition': {
         'StaticHyperParameters': {
@@ -57,15 +49,15 @@ create_tuning_params = {
             'mini_batch_size': '500',
             'force_dense': 'True',
         },
-        'AlgorithmSpecification': {'TrainingImage': image, 'TrainingInputMode': 'File'},
-        'RoleArn': role,
+        'AlgorithmSpecification': {'TrainingImage': 'image_name', 'TrainingInputMode': 'File'},
+        'RoleArn': 'arn:aws:iam:role/test-role',
         'InputDataConfig': [
             {
                 'ChannelName': 'train',
                 'DataSource': {
                     'S3DataSource': {
                         'S3DataType': 'S3Prefix',
-                        'S3Uri': data_url,
+                        'S3Uri': 's3_uri',
                         'S3DataDistributionType': 'FullyReplicated',
                     }
                 },
@@ -73,9 +65,9 @@ create_tuning_params = {
                 'RecordWrapperType': 'None',
             }
         ],
-        'OutputDataConfig': {'S3OutputPath': output_url},
+        'OutputDataConfig': {'S3OutputPath': 'output_path'},
         'ResourceConfig': {'InstanceCount': '2', 'InstanceType': 'ml.c4.8xlarge', 'VolumeSizeInGB': '50'},
-        'StoppingCondition': dict(MaxRuntimeInSeconds=60 * 60),
+        'StoppingCondition': {'MaxRuntimeInSeconds': '3600'},  # str like the other numeric fields
     },
 }
 
@@ -84,47 +76,32 @@ class TestSageMakerTuningOperator(unittest.TestCase):
     def setUp(self):
         self.sagemaker = SageMakerTuningOperator(
             task_id='test_sagemaker_operator',
-            aws_conn_id='sagemaker_test_conn',
-            config=create_tuning_params,
+            config=CREATE_TUNING_PARAMS,
             wait_for_completion=False,
             check_interval=5,
         )
 
-    def test_parse_config_integers(self):
-        self.sagemaker.parse_config_integers()
-        assert self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['InstanceCount'] == int(
-            self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['InstanceCount']
-        )
-        assert self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['VolumeSizeInGB'] == int(
-            self.sagemaker.config['TrainingJobDefinition']['ResourceConfig']['VolumeSizeInGB']
-        )
-        assert self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
-            'MaxNumberOfTrainingJobs'
-        ] == int(
-            self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
-                'MaxNumberOfTrainingJobs'
-            ]
-        )
-        assert self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
-            'MaxParallelTrainingJobs'
-        ] == int(
-            self.sagemaker.config['HyperParameterTuningJobConfig']['ResourceLimits'][
-                'MaxParallelTrainingJobs'
-            ]
-        )
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_tuning_job')
+    def test_integer_fields(self, mock_tuning, mock_client):
+        mock_tuning.return_value = {'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
+        self.sagemaker.execute(None)
+        assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
+        for (key1, key2, key3) in EXPECTED_INTEGER_FIELDS:
+            assert self.sagemaker.config[key1][key2][key3] == int(self.sagemaker.config[key1][key2][key3])
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'create_tuning_job')
     def test_execute(self, mock_tuning, mock_client):
-        mock_tuning.return_value = {'TrainingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
+        mock_tuning.return_value = {'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 200}}
         self.sagemaker.execute(None)
         mock_tuning.assert_called_once_with(
-            create_tuning_params, wait_for_completion=False, check_interval=5, max_ingestion_time=None
+            CREATE_TUNING_PARAMS, wait_for_completion=False, check_interval=5, max_ingestion_time=None
         )
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'create_tuning_job')
     def test_execute_with_failure(self, mock_tuning, mock_client):
-        mock_tuning.return_value = {'TrainingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
+        mock_tuning.return_value = {'TrainingJobArn': 'test_arn', 'ResponseMetadata': {'HTTPStatusCode': 404}}
         with pytest.raises(AirflowException):
             self.sagemaker.execute(None)