You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this plain-text version.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/09/22 11:24:18 UTC

[airflow] branch main updated: Fix SageMakerEndpointConfigOperator's return value (#26541)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0c7b4cbf62 Fix SageMakerEndpointConfigOperator's return value (#26541)
0c7b4cbf62 is described below

commit 0c7b4cbf62925cf359648eff146f9f4b0c6e7775
Author: D. Ferruzzi <fe...@amazon.com>
AuthorDate: Thu Sep 22 04:24:07 2022 -0700

    Fix SageMakerEndpointConfigOperator's return value (#26541)
    
    * Fix SageMakerEndpointConfigOperator's JSON return value
    
    * Fix the unit tests for the changed operator
---
 airflow/providers/amazon/aws/operators/sagemaker.py                | 6 +++++-
 .../amazon/aws/operators/test_sagemaker_endpoint_config.py         | 7 +++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index a85c0ca5ba..541b6deffa 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -237,7 +237,11 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
         if response['ResponseMetadata']['HTTPStatusCode'] != 200:
             raise AirflowException(f'Sagemaker endpoint config creation failed: {response}')
         else:
-            return {'EndpointConfig': self.hook.describe_endpoint_config(self.config['EndpointConfigName'])}
+            return {
+                'EndpointConfig': serialize(
+                    self.hook.describe_endpoint_config(self.config['EndpointConfigName'])
+                )
+            }
 
 
 class SageMakerEndpointOperator(SageMakerBaseOperator):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
index c9c3bb5c74..c5f8085964 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
@@ -24,6 +24,7 @@ import pytest
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointConfigOperator
 
 CREATE_ENDPOINT_CONFIG_PARAMS: dict = {
@@ -50,7 +51,8 @@ class TestSageMakerEndpointConfigOperator(unittest.TestCase):
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'create_endpoint_config')
-    def test_integer_fields(self, mock_model, mock_client):
+    @mock.patch.object(sagemaker, 'serialize', return_value="")
+    def test_integer_fields(self, serialize, mock_model, mock_client):
         mock_model.return_value = {
             'EndpointConfigArn': 'test_arn',
             'ResponseMetadata': {'HTTPStatusCode': 200},
@@ -62,7 +64,8 @@ class TestSageMakerEndpointConfigOperator(unittest.TestCase):
 
     @mock.patch.object(SageMakerHook, 'get_conn')
     @mock.patch.object(SageMakerHook, 'create_endpoint_config')
-    def test_execute(self, mock_model, mock_client):
+    @mock.patch.object(sagemaker, 'serialize', return_value="")
+    def test_execute(self, serialize, mock_model, mock_client):
         mock_model.return_value = {
             'EndpointConfigArn': 'test_arn',
             'ResponseMetadata': {'HTTPStatusCode': 200},