You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/03/05 15:32:13 UTC
[airflow] branch master updated: S3DataSource is not required
(#14220)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 697abf3 S3DataSource is not required (#14220)
697abf3 is described below
commit 697abf399de107eb4bafb730acf23d868e107a08
Author: Marcus Moyses <ma...@gmail.com>
AuthorDate: Fri Mar 5 07:31:58 2021 -0800
S3DataSource is not required (#14220)
A channel's datasource can be a file system or S3. The code shouldn't assume all datasources
are S3 datasources and try to verify the URI.
Documentation showing the field "S3DataSource" is not required: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataSource.html
---
airflow/providers/amazon/aws/hooks/sagemaker.py | 9 +++--
tests/providers/amazon/aws/hooks/test_sagemaker.py | 46 +++++++++++++++++++++-
2 files changed, 51 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index d6548ad..756d888 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -229,7 +229,8 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
"""
if "InputDataConfig" in training_config:
for channel in training_config['InputDataConfig']:
- self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
+ if "S3DataSource" in channel['DataSource']:
+ self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
def check_tuning_config(self, tuning_config: dict) -> None:
"""
@@ -240,7 +241,8 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
:return: None
"""
for channel in tuning_config['TrainingJobDefinition']['InputDataConfig']:
- self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
+ if "S3DataSource" in channel['DataSource']:
+ self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
def get_log_conn(self):
"""
@@ -421,7 +423,8 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
:type max_ingestion_time: int
:return: A response to transform job creation
"""
- self.check_s3_url(config['TransformInput']['DataSource']['S3DataSource']['S3Uri'])
+ if "S3DataSource" in config['TransformInput']['DataSource']:
+ self.check_s3_url(config['TransformInput']['DataSource']['S3DataSource']['S3Uri'])
response = self.get_conn().create_transform_job(**config)
if wait_for_completion:
diff --git a/tests/providers/amazon/aws/hooks/test_sagemaker.py b/tests/providers/amazon/aws/hooks/test_sagemaker.py
index b714ab3..786e6e3 100644
--- a/tests/providers/amazon/aws/hooks/test_sagemaker.py
+++ b/tests/providers/amazon/aws/hooks/test_sagemaker.py
@@ -71,7 +71,20 @@ create_training_params = {
},
'CompressionType': 'None',
'RecordWrapperType': 'None',
- }
+ },
+ {
+ 'ChannelName': 'train_fs',
+ 'DataSource': {
+ 'FileSystemDataSource': {
+ 'DirectoryPath': '/tmp',
+ 'FileSystemAccessMode': 'ro',
+ 'FileSystemId': 'fs-abc',
+ 'FileSystemType': 'FSxLustre',
+ }
+ },
+ 'CompressionType': 'None',
+ 'RecordWrapperType': 'None',
+ },
],
}
@@ -109,6 +122,26 @@ create_transform_params = {
'TransformResources': {'InstanceType': 'ml.m4.xlarge', 'InstanceCount': 123},
}
+create_transform_params_fs = {
+ 'TransformJobName': job_name,
+ 'ModelName': model_name,
+ 'BatchStrategy': 'MultiRecord',
+ 'TransformInput': {
+ 'DataSource': {
+ 'FileSystemDataSource': {
+ 'DirectoryPath': '/tmp',
+ 'FileSystemAccessMode': 'ro',
+ 'FileSystemId': 'fs-abc',
+ 'FileSystemType': 'FSxLustre',
+ }
+ }
+ },
+ 'TransformOutput': {
+ 'S3OutputPath': output_url,
+ },
+ 'TransformResources': {'InstanceType': 'ml.m4.xlarge', 'InstanceCount': 123},
+}
+
create_model_params = {
'ModelName': model_name,
'PrimaryContainer': {
@@ -356,6 +389,17 @@ class TestSageMakerHook(unittest.TestCase):
assert response == test_arn_return
@mock.patch.object(SageMakerHook, 'get_conn')
+ def test_create_transform_job_fs(self, mock_client):
+ mock_session = mock.Mock()
+ attrs = {'create_transform_job.return_value': test_arn_return}
+ mock_session.configure_mock(**attrs)
+ mock_client.return_value = mock_session
+ hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
+ response = hook.create_transform_job(create_transform_params_fs, wait_for_completion=False)
+ mock_session.create_transform_job.assert_called_once_with(**create_transform_params_fs)
+ assert response == test_arn_return
+
+ @mock.patch.object(SageMakerHook, 'get_conn')
def test_create_model(self, mock_client):
mock_session = mock.Mock()
attrs = {'create_model.return_value': test_arn_return}