You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/16 21:20:09 UTC
[airflow] branch main updated: Migrate datasync sample dag to system tests (AIP-47) (#24354)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 77f51dcf72 Migrate datasync sample dag to system tests (AIP-47) (#24354)
77f51dcf72 is described below
commit 77f51dcf72eca01721379c3fe59d20ba701d7db7
Author: Vincent <97...@users.noreply.github.com>
AuthorDate: Sat Jul 16 17:20:00 2022 -0400
Migrate datasync sample dag to system tests (AIP-47) (#24354)
* Migrate datasync sample dag to system tests (AIP-47)
* Use SystemTestContextBuilder to pass along data
---
.../amazon/aws/example_dags/example_datasync.py | 81 -------
airflow/providers/amazon/aws/operators/datasync.py | 6 +
.../operators/datasync.rst | 6 +-
.../amazon/aws/operators/test_datasync.py | 17 +-
.../providers/amazon/aws/example_datasync.py | 239 +++++++++++++++++++++
5 files changed, 264 insertions(+), 85 deletions(-)
diff --git a/airflow/providers/amazon/aws/example_dags/example_datasync.py b/airflow/providers/amazon/aws/example_dags/example_datasync.py
deleted file mode 100644
index 09c474079e..0000000000
--- a/airflow/providers/amazon/aws/example_dags/example_datasync.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import json
-import re
-from datetime import datetime
-from os import getenv
-
-from airflow import models
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator
-
-TASK_ARN = getenv("TASK_ARN", "my_aws_datasync_task_arn")
-SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/")
-DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix")
-CREATE_TASK_KWARGS = json.loads(getenv("CREATE_TASK_KWARGS", '{"Name": "Created by Airflow"}'))
-CREATE_SOURCE_LOCATION_KWARGS = json.loads(getenv("CREATE_SOURCE_LOCATION_KWARGS", '{}'))
-default_destination_location_kwargs = """\
-{"S3BucketArn": "arn:aws:s3:::mybucket",
- "S3Config": {"BucketAccessRoleArn":
- "arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role"}
-}"""
-CREATE_DESTINATION_LOCATION_KWARGS = json.loads(
- getenv("CREATE_DESTINATION_LOCATION_KWARGS", re.sub(r"[\s+]", '', default_destination_location_kwargs))
-)
-UPDATE_TASK_KWARGS = json.loads(getenv("UPDATE_TASK_KWARGS", '{"Name": "Updated by Airflow"}'))
-
-with models.DAG(
- "example_datasync",
- schedule_interval=None, # Override to match your needs
- start_date=datetime(2021, 1, 1),
- catchup=False,
- tags=['example'],
-) as dag:
- # [START howto_operator_datasync_specific_task]
- # Execute a specific task
- datasync_specific_task = DataSyncOperator(task_id="datasync_specific_task", task_arn=TASK_ARN)
- # [END howto_operator_datasync_specific_task]
-
- # [START howto_operator_datasync_search_task]
- # Search and execute a task
- datasync_search_task = DataSyncOperator(
- task_id="datasync_search_task",
- source_location_uri=SOURCE_LOCATION_URI,
- destination_location_uri=DESTINATION_LOCATION_URI,
- )
- # [END howto_operator_datasync_search_task]
-
- # [START howto_operator_datasync_create_task]
- # Create a task (the task does not exist)
- datasync_create_task = DataSyncOperator(
- task_id="datasync_create_task",
- source_location_uri=SOURCE_LOCATION_URI,
- destination_location_uri=DESTINATION_LOCATION_URI,
- create_task_kwargs=CREATE_TASK_KWARGS,
- create_source_location_kwargs=CREATE_SOURCE_LOCATION_KWARGS,
- create_destination_location_kwargs=CREATE_DESTINATION_LOCATION_KWARGS,
- update_task_kwargs=UPDATE_TASK_KWARGS,
- delete_task_after_execution=True,
- )
- # [END howto_operator_datasync_create_task]
-
- chain(
- datasync_specific_task,
- datasync_search_task,
- datasync_create_task,
- )
diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py
index 5a0c860711..ffe864d5ce 100644
--- a/airflow/providers/amazon/aws/operators/datasync.py
+++ b/airflow/providers/amazon/aws/operators/datasync.py
@@ -49,6 +49,7 @@ class DataSyncOperator(BaseOperator):
consecutive calls to check TaskExecution status.
:param max_iterations: Maximum number of
consecutive calls to check TaskExecution status.
+ :param wait_for_completion: If True, wait for the task execution to reach a final state
:param task_arn: AWS DataSync TaskArn to use. If None, then this operator will
attempt to either search for an existing Task or attempt to create a new Task.
:param source_location_uri: Source location URI to search for. All DataSync
@@ -122,6 +123,7 @@ class DataSyncOperator(BaseOperator):
aws_conn_id: str = "aws_default",
wait_interval_seconds: int = 30,
max_iterations: int = 60,
+ wait_for_completion: bool = True,
task_arn: Optional[str] = None,
source_location_uri: Optional[str] = None,
destination_location_uri: Optional[str] = None,
@@ -141,6 +143,7 @@ class DataSyncOperator(BaseOperator):
self.aws_conn_id = aws_conn_id
self.wait_interval_seconds = wait_interval_seconds
self.max_iterations = max_iterations
+ self.wait_for_completion = wait_for_completion
self.task_arn = task_arn
@@ -346,6 +349,9 @@ class DataSyncOperator(BaseOperator):
self.task_execution_arn = hook.start_task_execution(self.task_arn, **self.task_execution_kwargs)
self.log.info("Started TaskExecutionArn %s", self.task_execution_arn)
+ if not self.wait_for_completion:
+ return
+
# Wait for task execution to complete
self.log.info("Waiting for TaskExecutionArn %s", self.task_execution_arn)
try:
diff --git a/docs/apache-airflow-providers-amazon/operators/datasync.rst b/docs/apache-airflow-providers-amazon/operators/datasync.rst
index 5380aac5fa..b0b6100f44 100644
--- a/docs/apache-airflow-providers-amazon/operators/datasync.rst
+++ b/docs/apache-airflow-providers-amazon/operators/datasync.rst
@@ -59,7 +59,7 @@ Execute a task
To execute a specific task, you can pass the ``task_arn`` to the operator.
-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_datasync.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_datasync.py
:language: python
:dedent: 4
:start-after: [START howto_operator_datasync_specific_task]
@@ -73,7 +73,7 @@ If one task is found, this one will be executed.
If more than one task is found, the operator will raise an Exception. To avoid this, you can set
``allow_random_task_choice`` to ``True`` to randomly choose from candidate tasks.
-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_datasync.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_datasync.py
:language: python
:dedent: 4
:start-after: [START howto_operator_datasync_search_task]
@@ -92,7 +92,7 @@ existing Task was found. If these are left to their default value (None) then no
Also, because ``delete_task_after_execution`` is set to ``True``, the task will be deleted
from AWS DataSync after it completes successfully.
-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_datasync.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_datasync.py
:language: python
:dedent: 4
:start-after: [START howto_operator_datasync_create_task]
diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py
index 30caf91b18..30a8666cfe 100644
--- a/tests/providers/amazon/aws/operators/test_datasync.py
+++ b/tests/providers/amazon/aws/operators/test_datasync.py
@@ -619,7 +619,9 @@ class TestDataSyncOperatorUpdate(DataSyncTestCaseBase):
@mock_datasync
@mock.patch.object(DataSyncHook, "get_conn")
class TestDataSyncOperator(DataSyncTestCaseBase):
- def set_up_operator(self, task_id="test_datasync_task_operator", task_arn="self"):
+ def set_up_operator(
+ self, task_id="test_datasync_task_operator", task_arn="self", wait_for_completion=True
+ ):
if task_arn == "self":
task_arn = self.task_arn
# Create operator
@@ -627,6 +629,7 @@ class TestDataSyncOperator(DataSyncTestCaseBase):
task_id=task_id,
dag=self.dag,
wait_interval_seconds=0,
+ wait_for_completion=wait_for_completion,
task_arn=task_arn,
)
@@ -693,6 +696,18 @@ class TestDataSyncOperator(DataSyncTestCaseBase):
# ### Check mocks:
mock_get_conn.assert_called()
+ @mock.patch.object(DataSyncHook, "wait_for_task_execution")
+ def test_execute_task_without_wait_for_completion(self, mock_wait, mock_get_conn):
+ self.set_up_operator(wait_for_completion=False)
+
+ # Execute the task
+ result = self.datasync.execute(None)
+ assert result is not None
+ task_execution_arn = result["TaskExecutionArn"]
+ assert task_execution_arn is not None
+
+ mock_wait.assert_not_called()
+
@mock.patch.object(DataSyncHook, "wait_for_task_execution")
def test_failed_task(self, mock_wait, mock_get_conn):
# ### Set up mocks:
diff --git a/tests/system/providers/amazon/aws/example_datasync.py b/tests/system/providers/amazon/aws/example_datasync.py
new file mode 100644
index 0000000000..ac1e791f06
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_datasync.py
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+
+import boto3
+
+from airflow import models
+from airflow.decorators import task
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator
+from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator
+from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
+
+DAG_ID = 'example_datasync'
+
+# Externally fetched variables:
+ROLE_ARN_KEY = 'ROLE_ARN'
+
+sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+
+
+def get_s3_bucket_arn(bucket_name):
+ return f'arn:aws:s3:::{bucket_name}'
+
+
+def create_location(bucket_name, role_arn):
+ client = boto3.client('datasync')
+ response = client.create_location_s3(
+ Subdirectory='test',
+ S3BucketArn=get_s3_bucket_arn(bucket_name),
+ S3Config={
+ 'BucketAccessRoleArn': role_arn,
+ },
+ )
+ return response['LocationArn']
+
+
+@task
+def create_source_location(bucket_source, role_arn):
+ return create_location(bucket_source, role_arn)
+
+
+@task
+def create_destination_location(bucket_destination, role_arn):
+ return create_location(bucket_destination, role_arn)
+
+
+@task
+def create_task(**kwargs):
+ client = boto3.client('datasync')
+ response = client.create_task(
+ SourceLocationArn=kwargs['ti'].xcom_pull('create_source_location'),
+ DestinationLocationArn=kwargs['ti'].xcom_pull('create_destination_location'),
+ )
+ return response['TaskArn']
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_task(task_arn):
+ client = boto3.client('datasync')
+ client.delete_task(
+ TaskArn=task_arn,
+ )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_task_created_by_operator(**kwargs):
+ client = boto3.client('datasync')
+ client.delete_task(
+ TaskArn=kwargs['ti'].xcom_pull('create_and_execute_task')['TaskArn'],
+ )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def list_locations(bucket_source, bucket_destination):
+ client = boto3.client('datasync')
+ return client.list_locations(
+ Filters=[
+ {
+ 'Name': 'LocationUri',
+ 'Values': [
+ f's3://{bucket_source}/test/',
+ f's3://{bucket_destination}/test/',
+ f's3://{bucket_source}/test_create/',
+ f's3://{bucket_destination}/test_create/',
+ ],
+ 'Operator': 'In',
+ }
+ ]
+ )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_locations(locations):
+ client = boto3.client('datasync')
+ for location in locations['Locations']:
+ client.delete_location(
+ LocationArn=location['LocationArn'],
+ )
+
+
+with models.DAG(
+ DAG_ID,
+ schedule_interval='@once',
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=['example'],
+) as dag:
+ test_context = sys_test_context_task()
+
+ s3_bucket_source: str = f'{test_context[ENV_ID_KEY]}-datasync-bucket-source'
+ s3_bucket_destination: str = f'{test_context[ENV_ID_KEY]}-datasync-bucket-destination'
+
+ create_s3_bucket_source = S3CreateBucketOperator(
+ task_id='create_s3_bucket_source', bucket_name=s3_bucket_source
+ )
+
+ create_s3_bucket_destination = S3CreateBucketOperator(
+ task_id='create_s3_bucket_destination', bucket_name=s3_bucket_destination
+ )
+
+ source_location = create_source_location(s3_bucket_source, test_context[ROLE_ARN_KEY])
+ destination_location = create_destination_location(s3_bucket_destination, test_context[ROLE_ARN_KEY])
+
+ created_task_arn = create_task()
+
+ # [START howto_operator_datasync_specific_task]
+ # Execute a specific task
+ execute_task_by_arn = DataSyncOperator(
+ task_id='execute_task_by_arn',
+ task_arn=created_task_arn,
+ wait_for_completion=False,
+ )
+ # [END howto_operator_datasync_specific_task]
+
+ # [START howto_operator_datasync_search_task]
+ # Search and execute a task
+ execute_task_by_locations = DataSyncOperator(
+ task_id='execute_task_by_locations',
+ source_location_uri=f's3://{s3_bucket_source}/test',
+ destination_location_uri=f's3://{s3_bucket_destination}/test',
+ # Only transfer files from /test/subdir folder
+ task_execution_kwargs={
+ 'Includes': [{'FilterType': 'SIMPLE_PATTERN', 'Value': '/test/subdir'}],
+ },
+ wait_for_completion=False,
+ )
+ # [END howto_operator_datasync_search_task]
+
+ # [START howto_operator_datasync_create_task]
+ # Create a task (the task does not exist)
+ create_and_execute_task = DataSyncOperator(
+ task_id='create_and_execute_task',
+ source_location_uri=f's3://{s3_bucket_source}/test_create',
+ destination_location_uri=f's3://{s3_bucket_destination}/test_create',
+ create_task_kwargs={"Name": "Created by Airflow"},
+ create_source_location_kwargs={
+ 'Subdirectory': 'test_create',
+ 'S3BucketArn': get_s3_bucket_arn(s3_bucket_source),
+ 'S3Config': {
+ 'BucketAccessRoleArn': test_context[ROLE_ARN_KEY],
+ },
+ },
+ create_destination_location_kwargs={
+ 'Subdirectory': 'test_create',
+ 'S3BucketArn': get_s3_bucket_arn(s3_bucket_destination),
+ 'S3Config': {
+ 'BucketAccessRoleArn': test_context[ROLE_ARN_KEY],
+ },
+ },
+ delete_task_after_execution=False,
+ wait_for_completion=False,
+ )
+ # [END howto_operator_datasync_create_task]
+
+ locations_task = list_locations(s3_bucket_source, s3_bucket_destination)
+ delete_locations_task = delete_locations(locations_task)
+
+ delete_s3_bucket_source = S3DeleteBucketOperator(
+ task_id='delete_s3_bucket_source',
+ bucket_name=s3_bucket_source,
+ force_delete=True,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ delete_s3_bucket_destination = S3DeleteBucketOperator(
+ task_id='delete_s3_bucket_destination',
+ bucket_name=s3_bucket_destination,
+ force_delete=True,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ chain(
+ # TEST SETUP
+ test_context,
+ create_s3_bucket_source,
+ create_s3_bucket_destination,
+ source_location,
+ destination_location,
+ created_task_arn,
+ # TEST BODY
+ execute_task_by_arn,
+ execute_task_by_locations,
+ create_and_execute_task,
+ # TEST TEARDOWN
+ delete_task(created_task_arn),
+ delete_task_created_by_operator(),
+ locations_task,
+ delete_locations_task,
+ delete_s3_bucket_source,
+ delete_s3_bucket_destination,
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)