You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by el...@apache.org on 2022/11/04 12:50:22 UTC

[airflow] branch main updated: Code quality improvements on sagemaker operators/hook (#27453)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 531f2d2116 Code quality improvements on sagemaker operators/hook (#27453)
531f2d2116 is described below

commit 531f2d211658e13583189b65470d164af81bc40a
Author: Raphaël Vandon <11...@users.noreply.github.com>
AuthorDate: Fri Nov 4 05:50:12 2022 -0700

    Code quality improvements on sagemaker operators/hook (#27453)
---
 .../providers/amazon/aws/operators/sagemaker.py    |  6 ----
 tests/providers/amazon/aws/hooks/test_sagemaker.py | 40 ++++------------------
 .../amazon/aws/operators/test_sagemaker_model.py   | 37 ++++++--------------
 3 files changed, 17 insertions(+), 66 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 5fb9d93372..053ba89e01 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -223,7 +223,6 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
         **kwargs,
     ):
         super().__init__(config=config, **kwargs)
-        self.config = config
         self.aws_conn_id = aws_conn_id
 
     def _create_integer_fields(self) -> None:
@@ -304,7 +303,6 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
         **kwargs,
     ):
         super().__init__(config=config, **kwargs)
-        self.config = config
         self.aws_conn_id = aws_conn_id
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
@@ -433,7 +431,6 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         **kwargs,
     ):
         super().__init__(config=config, **kwargs)
-        self.config = config
         self.aws_conn_id = aws_conn_id
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
@@ -546,7 +543,6 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
         **kwargs,
     ):
         super().__init__(config=config, **kwargs)
-        self.config = config
         self.aws_conn_id = aws_conn_id
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
@@ -609,7 +605,6 @@ class SageMakerModelOperator(SageMakerBaseOperator):
 
     def __init__(self, *, config: dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
         super().__init__(config=config, **kwargs)
-        self.config = config
         self.aws_conn_id = aws_conn_id
 
     def expand_role(self) -> None:
@@ -746,7 +741,6 @@ class SageMakerDeleteModelOperator(SageMakerBaseOperator):
 
     def __init__(self, *, config: dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
         super().__init__(config=config, **kwargs)
-        self.config = config
         self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context) -> Any:
diff --git a/tests/providers/amazon/aws/hooks/test_sagemaker.py b/tests/providers/amazon/aws/hooks/test_sagemaker.py
index 50f3a96c1f..72e58cf697 100644
--- a/tests/providers/amazon/aws/hooks/test_sagemaker.py
+++ b/tests/providers/amazon/aws/hooks/test_sagemaker.py
@@ -18,7 +18,6 @@
 from __future__ import annotations
 
 import time
-import unittest
 from datetime import datetime
 from unittest import mock
 from unittest.mock import patch
@@ -241,7 +240,7 @@ test_evaluation_config = {
 }
 
 
-class TestSageMakerHook(unittest.TestCase):
+class TestSageMakerHook:
     @mock.patch.object(AwsLogsHook, "get_log_events")
     def test_multi_stream_iter(self, mock_log_stream):
         event = {"timestamp": 1}
@@ -298,8 +297,7 @@ class TestSageMakerHook(unittest.TestCase):
         hook.check_tuning_config(create_tuning_params)
         mock_check_url.assert_called_once_with(data_url)
 
-    @mock.patch.object(SageMakerHook, "get_client_type")
-    def test_conn(self, mock_get_client_type):
+    def test_conn(self):
         hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")
         assert hook.aws_conn_id == "sagemaker_test_conn_id"
 
@@ -564,9 +562,10 @@ class TestSageMakerHook(unittest.TestCase):
         )
         assert response == (LogState.JOB_COMPLETE, {}, 50)
 
+    @pytest.mark.parametrize("log_state", [LogState.JOB_COMPLETE, LogState.COMPLETE])
     @mock.patch.object(AwsLogsHook, "get_conn")
     @mock.patch.object(SageMakerHook, "get_conn")
-    def test_describe_training_job_with_logs_job_complete(self, mock_client, mock_log_client):
+    def test_describe_training_job_with_complete_states(self, mock_client, mock_log_client, log_state):
         mock_session = mock.Mock()
         mock_log_session = mock.Mock()
         attrs = {"describe_training_job.return_value": DESCRIBE_TRAINING_COMPLETED_RETURN}
@@ -584,33 +583,7 @@ class TestSageMakerHook(unittest.TestCase):
             positions={},
             stream_names=[],
             instance_count=1,
-            state=LogState.JOB_COMPLETE,
-            last_description={},
-            last_describe_job_call=0,
-        )
-        assert response == (LogState.COMPLETE, {}, 0)
-
-    @mock.patch.object(AwsLogsHook, "get_conn")
-    @mock.patch.object(SageMakerHook, "get_conn")
-    def test_describe_training_job_with_logs_complete(self, mock_client, mock_log_client):
-        mock_session = mock.Mock()
-        mock_log_session = mock.Mock()
-        attrs = {"describe_training_job.return_value": DESCRIBE_TRAINING_COMPLETED_RETURN}
-        log_attrs = {
-            "describe_log_streams.side_effect": LIFECYCLE_LOG_STREAMS,
-            "get_log_events.side_effect": STREAM_LOG_EVENTS,
-        }
-        mock_session.configure_mock(**attrs)
-        mock_client.return_value = mock_session
-        mock_log_session.configure_mock(**log_attrs)
-        mock_log_client.return_value = mock_log_session
-        hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")
-        response = hook.describe_training_job_with_log(
-            job_name=job_name,
-            positions={},
-            stream_names=[],
-            instance_count=1,
-            state=LogState.COMPLETE,
+            state=log_state,
             last_description={},
             last_describe_job_call=0,
         )
@@ -649,9 +622,8 @@ class TestSageMakerHook(unittest.TestCase):
         assert mock_session.describe_training_job.call_count == 1
 
     @mock.patch.object(SageMakerHook, "get_conn")
-    def test_find_processing_job_by_name(self, mock_conn):
+    def test_find_processing_job_by_name(self, _):
         hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")
-        mock_conn.describe_processing_job.return_value = {}
         ret = hook.find_processing_job_by_name("existing_job")
         assert ret
 
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_model.py b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
index f134b4ab4e..8716221a1f 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_model.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_model.py
@@ -24,7 +24,6 @@ import pytest
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import (
     SageMakerDeleteModelOperator,
     SageMakerModelOperator,
@@ -46,40 +45,26 @@ class TestSageMakerModelOperator(unittest.TestCase):
     def setUp(self):
         self.sagemaker = SageMakerModelOperator(task_id="test_sagemaker_operator", config=CREATE_MODEL_PARAMS)
 
-    @mock.patch.object(SageMakerHook, "get_conn")
+    @mock.patch.object(SageMakerHook, "describe_model", return_value="")
     @mock.patch.object(SageMakerHook, "create_model")
-    @mock.patch.object(sagemaker, "serialize", return_value="")
-    def test_integer_fields(self, serialize, mock_model, mock_client):
-        mock_model.return_value = {"ModelArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}
+    def test_execute(self, mock_create_model, _):
+        mock_create_model.return_value = {"ModelArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}
         self.sagemaker.execute(None)
+        mock_create_model.assert_called_once_with(CREATE_MODEL_PARAMS)
         assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
 
-    @mock.patch.object(SageMakerHook, "get_conn")
     @mock.patch.object(SageMakerHook, "create_model")
-    @mock.patch.object(sagemaker, "serialize", return_value="")
-    def test_execute(self, serialize, mock_model, mock_client):
-        mock_model.return_value = {"ModelArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}
-        self.sagemaker.execute(None)
-        mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
-
-    @mock.patch.object(SageMakerHook, "get_conn")
-    @mock.patch.object(SageMakerHook, "create_model")
-    def test_execute_with_failure(self, mock_model, mock_client):
-        mock_model.return_value = {"ModelArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}}
+    def test_execute_with_failure(self, mock_create_model):
+        mock_create_model.return_value = {"ModelArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}}
         with pytest.raises(AirflowException):
             self.sagemaker.execute(None)
 
 
 class TestSageMakerDeleteModelOperator(unittest.TestCase):
-    def setUp(self):
-        delete_model_params = {"ModelName": "model_name"}
-        self.sagemaker = SageMakerDeleteModelOperator(
-            task_id="test_sagemaker_operator", config=delete_model_params
-        )
-
-    @mock.patch.object(SageMakerHook, "get_conn")
     @mock.patch.object(SageMakerHook, "delete_model")
-    def test_execute(self, delete_model, mock_client):
-        delete_model.return_value = None
-        self.sagemaker.execute(None)
+    def test_execute(self, delete_model):
+        op = SageMakerDeleteModelOperator(
+            task_id="test_sagemaker_operator", config={"ModelName": "model_name"}
+        )
+        op.execute(None)
         delete_model.assert_called_once_with(model_name="model_name")