You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@airflow.apache.org by po...@apache.org on 2022/09/22 11:24:18 UTC
[airflow] branch main updated: Fix SageMakerEndpointConfigOperator's return value (#26541)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0c7b4cbf62 Fix SageMakerEndpointConfigOperator's return value (#26541)
0c7b4cbf62 is described below
commit 0c7b4cbf62925cf359648eff146f9f4b0c6e7775
Author: D. Ferruzzi <fe...@amazon.com>
AuthorDate: Thu Sep 22 04:24:07 2022 -0700
Fix SageMakerEndpointConfigOperator's return value (#26541)
* Fix SageMakerEndpointConfigOperator's JSON return value
* Fix the unit tests for the changed operator
---
airflow/providers/amazon/aws/operators/sagemaker.py | 6 +++++-
.../amazon/aws/operators/test_sagemaker_endpoint_config.py | 7 +++++--
2 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index a85c0ca5ba..541b6deffa 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -237,7 +237,11 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException(f'Sagemaker endpoint config creation failed: {response}')
else:
- return {'EndpointConfig': self.hook.describe_endpoint_config(self.config['EndpointConfigName'])}
+ return {
+ 'EndpointConfig': serialize(
+ self.hook.describe_endpoint_config(self.config['EndpointConfigName'])
+ )
+ }
class SageMakerEndpointOperator(SageMakerBaseOperator):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
index c9c3bb5c74..c5f8085964 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint_config.py
@@ -24,6 +24,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointConfigOperator
CREATE_ENDPOINT_CONFIG_PARAMS: dict = {
@@ -50,7 +51,8 @@ class TestSageMakerEndpointConfigOperator(unittest.TestCase):
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_endpoint_config')
- def test_integer_fields(self, mock_model, mock_client):
+ @mock.patch.object(sagemaker, 'serialize', return_value="")
+ def test_integer_fields(self, serialize, mock_model, mock_client):
mock_model.return_value = {
'EndpointConfigArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},
@@ -62,7 +64,8 @@ class TestSageMakerEndpointConfigOperator(unittest.TestCase):
@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'create_endpoint_config')
- def test_execute(self, mock_model, mock_client):
+ @mock.patch.object(sagemaker, 'serialize', return_value="")
+ def test_execute(self, serialize, mock_model, mock_client):
mock_model.return_value = {
'EndpointConfigArn': 'test_arn',
'ResponseMetadata': {'HTTPStatusCode': 200},