You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by el...@apache.org on 2022/11/04 12:50:22 UTC
[airflow] branch main updated: Code quality improvements on sagemaker operators/hook (#27453)
This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 531f2d2116 Code quality improvements on sagemaker operators/hook (#27453)
531f2d2116 is described below
commit 531f2d211658e13583189b65470d164af81bc40a
Author: Raphaël Vandon <11...@users.noreply.github.com>
AuthorDate: Fri Nov 4 05:50:12 2022 -0700
Code quality improvements on sagemaker operators/hook (#27453)
---
.../providers/amazon/aws/operators/sagemaker.py | 6 ----
tests/providers/amazon/aws/hooks/test_sagemaker.py | 40 ++++------------------
.../amazon/aws/operators/test_sagemaker_model.py | 37 ++++++--------------
3 files changed, 17 insertions(+), 66 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 5fb9d93372..053ba89e01 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -223,7 +223,6 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.config = config
self.aws_conn_id = aws_conn_id
def _create_integer_fields(self) -> None:
@@ -304,7 +303,6 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.config = config
self.aws_conn_id = aws_conn_id
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
@@ -433,7 +431,6 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.config = config
self.aws_conn_id = aws_conn_id
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
@@ -546,7 +543,6 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
**kwargs,
):
super().__init__(config=config, **kwargs)
- self.config = config
self.aws_conn_id = aws_conn_id
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
@@ -609,7 +605,6 @@ class SageMakerModelOperator(SageMakerBaseOperator):
def __init__(self, *, config: dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
super().__init__(config=config, **kwargs)
- self.config = config
self.aws_conn_id = aws_conn_id
def expand_role(self) -> None:
@@ -746,7 +741,6 @@ class SageMakerDeleteModelOperator(SageMakerBaseOperator):
def __init__(self, *, config: dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
super().__init__(config=config, **kwargs)
- self.config = config
self.aws_conn_id = aws_conn_id
def execute(self, context: Context) -> Any:
diff --git a/tests/providers/amazon/aws/hooks/test_sagemaker.py b/tests/providers/amazon/aws/hooks/test_sagemaker.py
index 50f3a96c1f..72e58cf697 100644
--- a/tests/providers/amazon/aws/hooks/test_sagemaker.py
+++ b/tests/providers/amazon/aws/hooks/test_sagemaker.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import time
-import unittest
from datetime import datetime
from unittest import mock
from unittest.mock import patch
@@ -241,7 +240,7 @@ test_evaluation_config = {
}
-class TestSageMakerHook(unittest.TestCase):
+class TestSageMakerHook:
@mock.patch.object(AwsLogsHook, "get_log_events")
def test_multi_stream_iter(self, mock_log_stream):
event = {"timestamp": 1}
@@ -298,8 +297,7 @@ class TestSageMakerHook(unittest.TestCase):
hook.check_tuning_config(create_tuning_params)
mock_check_url.assert_called_once_with(data_url)
- @mock.patch.object(SageMakerHook, "get_client_type")
- def test_conn(self, mock_get_client_type):
+ def test_conn(self):
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")
assert hook.aws_conn_id == "sagemaker_test_conn_id"
@@ -564,9 +562,10 @@ class TestSageMakerHook(unittest.TestCase):
)
assert response == (LogState.JOB_COMPLETE, {}, 50)
+ @pytest.mark.parametrize("log_state", [LogState.JOB_COMPLETE, LogState.COMPLETE])
@mock.patch.object(AwsLogsHook, "get_conn")
@mock.patch.object(SageMakerHook, "get_conn")
- def test_describe_training_job_with_logs_job_complete(self, mock_client, mock_log_client):
+ def test_describe_training_job_with_complete_states(self, mock_client, mock_log_client, log_state):
mock_session = mock.Mock()
mock_log_session = mock.Mock()
attrs = {"describe_training_job.return_value": DESCRIBE_TRAINING_COMPLETED_RETURN}
@@ -584,33 +583,7 @@ class TestSageMakerHook(unittest.TestCase):
positions={},
stream_names=[],
instance_count=1,
- state=LogState.JOB_COMPLETE,
- last_description={},
- last_describe_job_call=0,
- )
- assert response == (LogState.COMPLETE, {}, 0)
-
- @mock.patch.object(AwsLogsHook, "get_conn")
- @mock.patch.object(SageMakerHook, "get_conn")
- def test_describe_training_job_with_logs_complete(self, mock_client, mock_log_client):
- mock_session = mock.Mock()
- mock_log_session = mock.Mock()
- attrs = {"describe_training_job.return_value": DESCRIBE_TRAINING_COMPLETED_RETURN}
- log_attrs = {
- "describe_log_streams.side_effect": LIFECYCLE_LOG_STREAMS,
- "get_log_events.side_effect": STREAM_LOG_EVENTS,
- }
- mock_session.configure_mock(**attrs)
- mock_client.return_value = mock_session
- mock_log_session.configure_mock(**log_attrs)
- mock_log_client.return_value = mock_log_session
- hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")
- response = hook.describe_training_job_with_log(
- job_name=job_name,
- positions={},
- stream_names=[],
- instance_count=1,
- state=LogState.COMPLETE,
+ state=log_state,
last_description={},
last_describe_job_call=0,
)
@@ -649,9 +622,8 @@ class TestSageMakerHook(unittest.TestCase):
assert mock_session.describe_training_job.call_count == 1
@mock.patch.object(SageMakerHook, "get_conn")
- def test_find_processing_job_by_name(self, mock_conn):
+ def test_find_processing_job_by_name(self, _):
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")
- mock_conn.describe_processing_job.return_value = {}
ret = hook.find_processing_job_by_name("existing_job")
assert ret
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_model.py b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
index f134b4ab4e..8716221a1f 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_model.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
@@ -24,7 +24,6 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import (
SageMakerDeleteModelOperator,
SageMakerModelOperator,
@@ -46,40 +45,26 @@ class TestSageMakerModelOperator(unittest.TestCase):
def setUp(self):
self.sagemaker = SageMakerModelOperator(task_id="test_sagemaker_operator", config=CREATE_MODEL_PARAMS)
- @mock.patch.object(SageMakerHook, "get_conn")
+ @mock.patch.object(SageMakerHook, "describe_model", return_value="")
@mock.patch.object(SageMakerHook, "create_model")
- @mock.patch.object(sagemaker, "serialize", return_value="")
- def test_integer_fields(self, serialize, mock_model, mock_client):
- mock_model.return_value = {"ModelArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}
+ def test_execute(self, mock_create_model, _):
+ mock_create_model.return_value = {"ModelArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}
self.sagemaker.execute(None)
+ mock_create_model.assert_called_once_with(CREATE_MODEL_PARAMS)
assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
- @mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "create_model")
- @mock.patch.object(sagemaker, "serialize", return_value="")
- def test_execute(self, serialize, mock_model, mock_client):
- mock_model.return_value = {"ModelArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}
- self.sagemaker.execute(None)
- mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
-
- @mock.patch.object(SageMakerHook, "get_conn")
- @mock.patch.object(SageMakerHook, "create_model")
- def test_execute_with_failure(self, mock_model, mock_client):
- mock_model.return_value = {"ModelArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}}
+ def test_execute_with_failure(self, mock_create_model):
+ mock_create_model.return_value = {"ModelArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}}
with pytest.raises(AirflowException):
self.sagemaker.execute(None)
class TestSageMakerDeleteModelOperator(unittest.TestCase):
- def setUp(self):
- delete_model_params = {"ModelName": "model_name"}
- self.sagemaker = SageMakerDeleteModelOperator(
- task_id="test_sagemaker_operator", config=delete_model_params
- )
-
- @mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "delete_model")
- def test_execute(self, delete_model, mock_client):
- delete_model.return_value = None
- self.sagemaker.execute(None)
+ def test_execute(self, delete_model):
+ op = SageMakerDeleteModelOperator(
+ task_id="test_sagemaker_operator", config={"ModelName": "model_name"}
+ )
+ op.execute(None)
delete_model.assert_called_once_with(model_name="model_name")