Posted to commits@airflow.apache.org by as...@apache.org on 2021/03/05 15:32:13 UTC

[airflow] branch master updated: S3DataSource is not required (#14220)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 697abf3  S3DataSource is not required (#14220)
697abf3 is described below

commit 697abf399de107eb4bafb730acf23d868e107a08
Author: Marcus Moyses <ma...@gmail.com>
AuthorDate: Fri Mar 5 07:31:58 2021 -0800

    S3DataSource is not required (#14220)
    
    A channel's data source can be a file system or S3. The code shouldn't assume every
    data source is an S3 data source and unconditionally try to verify its S3 URI.
    
    Documentation showing that the field "S3DataSource" is not required: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataSource.html
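
For readers of the archive: a minimal, hypothetical sketch (not part of this commit) of the two shapes a channel's DataSource can take per the API reference above. Only the S3 variant carries an S3Uri, so only that variant gives the hook anything to validate; the field values below mirror the test fixtures in the diff.

    # Hypothetical channel definitions, illustrating the two API shapes only.
    s3_channel = {
        'ChannelName': 'train',
        'DataSource': {
            'S3DataSource': {
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://my-bucket/train/',  # a URI the hook can validate
            }
        },
    }

    fs_channel = {
        'ChannelName': 'train_fs',
        'DataSource': {
            'FileSystemDataSource': {
                'FileSystemId': 'fs-abc',
                'FileSystemType': 'FSxLustre',
                'DirectoryPath': '/tmp',
                'FileSystemAccessMode': 'ro',  # no S3Uri, nothing to validate
            }
        },
    }
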
---
 airflow/providers/amazon/aws/hooks/sagemaker.py    |  9 +++--
 tests/providers/amazon/aws/hooks/test_sagemaker.py | 46 +++++++++++++++++++++-
 2 files changed, 51 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index d6548ad..756d888 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -229,7 +229,8 @@ class SageMakerHook(AwsBaseHook):  # pylint: disable=too-many-public-methods
         """
         if "InputDataConfig" in training_config:
             for channel in training_config['InputDataConfig']:
-                self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
+                if "S3DataSource" in channel['DataSource']:
+                    self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
 
     def check_tuning_config(self, tuning_config: dict) -> None:
         """
@@ -240,7 +241,8 @@ class SageMakerHook(AwsBaseHook):  # pylint: disable=too-many-public-methods
         :return: None
         """
         for channel in tuning_config['TrainingJobDefinition']['InputDataConfig']:
-            self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
+            if "S3DataSource" in channel['DataSource']:
+                self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
 
     def get_log_conn(self):
         """
@@ -421,7 +423,8 @@ class SageMakerHook(AwsBaseHook):  # pylint: disable=too-many-public-methods
         :type max_ingestion_time: int
         :return: A response to transform job creation
         """
-        self.check_s3_url(config['TransformInput']['DataSource']['S3DataSource']['S3Uri'])
+        if "S3DataSource" in config['TransformInput']['DataSource']:
+            self.check_s3_url(config['TransformInput']['DataSource']['S3DataSource']['S3Uri'])
 
         response = self.get_conn().create_transform_job(**config)
         if wait_for_completion:
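
To make the effect of the guard concrete, here is a sketch with hypothetical values (the hook name and config shape come from the diff above; no AWS call is made because no S3 channel is present). Before this commit, check_training_config raised KeyError('S3DataSource') on a file-system channel; with the guard it simply skips the URI check.

    from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

    # Hypothetical training config containing a single file-system channel.
    training_config = {
        'InputDataConfig': [
            {
                'ChannelName': 'train_fs',
                'DataSource': {
                    'FileSystemDataSource': {
                        'FileSystemId': 'fs-abc',
                        'FileSystemType': 'FSxLustre',
                        'DirectoryPath': '/tmp',
                        'FileSystemAccessMode': 'ro',
                    }
                },
            }
        ]
    }

    hook = SageMakerHook(aws_conn_id='aws_default')
    # Pre-patch: KeyError('S3DataSource'). Post-patch: the 'S3DataSource' key
    # is absent, so check_s3_url() is never called for this channel.
    hook.check_training_config(training_config)

The same guard pattern protects check_tuning_config and create_transform_job in the hunks above.
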
diff --git a/tests/providers/amazon/aws/hooks/test_sagemaker.py b/tests/providers/amazon/aws/hooks/test_sagemaker.py
index b714ab3..786e6e3 100644
--- a/tests/providers/amazon/aws/hooks/test_sagemaker.py
+++ b/tests/providers/amazon/aws/hooks/test_sagemaker.py
@@ -71,7 +71,20 @@ create_training_params = {
             },
             'CompressionType': 'None',
             'RecordWrapperType': 'None',
-        }
+        },
+        {
+            'ChannelName': 'train_fs',
+            'DataSource': {
+                'FileSystemDataSource': {
+                    'DirectoryPath': '/tmp',
+                    'FileSystemAccessMode': 'ro',
+                    'FileSystemId': 'fs-abc',
+                    'FileSystemType': 'FSxLustre',
+                }
+            },
+            'CompressionType': 'None',
+            'RecordWrapperType': 'None',
+        },
     ],
 }
 
@@ -109,6 +122,26 @@ create_transform_params = {
     'TransformResources': {'InstanceType': 'ml.m4.xlarge', 'InstanceCount': 123},
 }
 
+create_transform_params_fs = {
+    'TransformJobName': job_name,
+    'ModelName': model_name,
+    'BatchStrategy': 'MultiRecord',
+    'TransformInput': {
+        'DataSource': {
+            'FileSystemDataSource': {
+                'DirectoryPath': '/tmp',
+                'FileSystemAccessMode': 'ro',
+                'FileSystemId': 'fs-abc',
+                'FileSystemType': 'FSxLustre',
+            }
+        }
+    },
+    'TransformOutput': {
+        'S3OutputPath': output_url,
+    },
+    'TransformResources': {'InstanceType': 'ml.m4.xlarge', 'InstanceCount': 123},
+}
+
 create_model_params = {
     'ModelName': model_name,
     'PrimaryContainer': {
@@ -356,6 +389,17 @@ class TestSageMakerHook(unittest.TestCase):
         assert response == test_arn_return
 
     @mock.patch.object(SageMakerHook, 'get_conn')
+    def test_create_transform_job_fs(self, mock_client):
+        mock_session = mock.Mock()
+        attrs = {'create_transform_job.return_value': test_arn_return}
+        mock_session.configure_mock(**attrs)
+        mock_client.return_value = mock_session
+        hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
+        response = hook.create_transform_job(create_transform_params_fs, wait_for_completion=False)
+        mock_session.create_transform_job.assert_called_once_with(**create_transform_params_fs)
+        assert response == test_arn_return
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
     def test_create_model(self, mock_client):
         mock_session = mock.Mock()
         attrs = {'create_model.return_value': test_arn_return}