Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/16 21:20:09 UTC

[airflow] branch main updated: Migrate datasync sample dag to system tests (AIP-47) (#24354)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 77f51dcf72 Migrate datasync sample dag to system tests (AIP-47) (#24354)
77f51dcf72 is described below

commit 77f51dcf72eca01721379c3fe59d20ba701d7db7
Author: Vincent <97...@users.noreply.github.com>
AuthorDate: Sat Jul 16 17:20:00 2022 -0400

    Migrate datasync sample dag to system tests (AIP-47) (#24354)
    
    * Migrate datasync sample dag to system tests (AIP-47)
    
    * Use SystemTestContextBuilder to pass along data
---
 .../amazon/aws/example_dags/example_datasync.py    |  81 -------
 airflow/providers/amazon/aws/operators/datasync.py |   6 +
 .../operators/datasync.rst                         |   6 +-
 .../amazon/aws/operators/test_datasync.py          |  17 +-
 .../providers/amazon/aws/example_datasync.py       | 239 +++++++++++++++++++++
 5 files changed, 264 insertions(+), 85 deletions(-)

diff --git a/airflow/providers/amazon/aws/example_dags/example_datasync.py b/airflow/providers/amazon/aws/example_dags/example_datasync.py
deleted file mode 100644
index 09c474079e..0000000000
--- a/airflow/providers/amazon/aws/example_dags/example_datasync.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import json
-import re
-from datetime import datetime
-from os import getenv
-
-from airflow import models
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator
-
-TASK_ARN = getenv("TASK_ARN", "my_aws_datasync_task_arn")
-SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/")
-DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix")
-CREATE_TASK_KWARGS = json.loads(getenv("CREATE_TASK_KWARGS", '{"Name": "Created by Airflow"}'))
-CREATE_SOURCE_LOCATION_KWARGS = json.loads(getenv("CREATE_SOURCE_LOCATION_KWARGS", '{}'))
-default_destination_location_kwargs = """\
-{"S3BucketArn": "arn:aws:s3:::mybucket",
-    "S3Config": {"BucketAccessRoleArn":
-    "arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role"}
-}"""
-CREATE_DESTINATION_LOCATION_KWARGS = json.loads(
-    getenv("CREATE_DESTINATION_LOCATION_KWARGS", re.sub(r"[\s+]", '', default_destination_location_kwargs))
-)
-UPDATE_TASK_KWARGS = json.loads(getenv("UPDATE_TASK_KWARGS", '{"Name": "Updated by Airflow"}'))
-
-with models.DAG(
-    "example_datasync",
-    schedule_interval=None,  # Override to match your needs
-    start_date=datetime(2021, 1, 1),
-    catchup=False,
-    tags=['example'],
-) as dag:
-    # [START howto_operator_datasync_specific_task]
-    # Execute a specific task
-    datasync_specific_task = DataSyncOperator(task_id="datasync_specific_task", task_arn=TASK_ARN)
-    # [END howto_operator_datasync_specific_task]
-
-    # [START howto_operator_datasync_search_task]
-    # Search and execute a task
-    datasync_search_task = DataSyncOperator(
-        task_id="datasync_search_task",
-        source_location_uri=SOURCE_LOCATION_URI,
-        destination_location_uri=DESTINATION_LOCATION_URI,
-    )
-    # [END howto_operator_datasync_search_task]
-
-    # [START howto_operator_datasync_create_task]
-    # Create a task (the task does not exist)
-    datasync_create_task = DataSyncOperator(
-        task_id="datasync_create_task",
-        source_location_uri=SOURCE_LOCATION_URI,
-        destination_location_uri=DESTINATION_LOCATION_URI,
-        create_task_kwargs=CREATE_TASK_KWARGS,
-        create_source_location_kwargs=CREATE_SOURCE_LOCATION_KWARGS,
-        create_destination_location_kwargs=CREATE_DESTINATION_LOCATION_KWARGS,
-        update_task_kwargs=UPDATE_TASK_KWARGS,
-        delete_task_after_execution=True,
-    )
-    # [END howto_operator_datasync_create_task]
-
-    chain(
-        datasync_specific_task,
-        datasync_search_task,
-        datasync_create_task,
-    )
diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py
index 5a0c860711..ffe864d5ce 100644
--- a/airflow/providers/amazon/aws/operators/datasync.py
+++ b/airflow/providers/amazon/aws/operators/datasync.py
@@ -49,6 +49,7 @@ class DataSyncOperator(BaseOperator):
         consecutive calls to check TaskExecution status.
     :param max_iterations: Maximum number of
         consecutive calls to check TaskExecution status.
+    :param wait_for_completion: If True, wait for the task execution to reach a final state
     :param task_arn: AWS DataSync TaskArn to use. If None, then this operator will
         attempt to either search for an existing Task or attempt to create a new Task.
     :param source_location_uri: Source location URI to search for. All DataSync
@@ -122,6 +123,7 @@ class DataSyncOperator(BaseOperator):
         aws_conn_id: str = "aws_default",
         wait_interval_seconds: int = 30,
         max_iterations: int = 60,
+        wait_for_completion: bool = True,
         task_arn: Optional[str] = None,
         source_location_uri: Optional[str] = None,
         destination_location_uri: Optional[str] = None,
@@ -141,6 +143,7 @@ class DataSyncOperator(BaseOperator):
         self.aws_conn_id = aws_conn_id
         self.wait_interval_seconds = wait_interval_seconds
         self.max_iterations = max_iterations
+        self.wait_for_completion = wait_for_completion
 
         self.task_arn = task_arn
 
@@ -346,6 +349,9 @@ class DataSyncOperator(BaseOperator):
         self.task_execution_arn = hook.start_task_execution(self.task_arn, **self.task_execution_kwargs)
         self.log.info("Started TaskExecutionArn %s", self.task_execution_arn)
 
+        if not self.wait_for_completion:
+            return
+
         # Wait for task execution to complete
         self.log.info("Waiting for TaskExecutionArn %s", self.task_execution_arn)
         try:
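
A minimal usage sketch (not part of this commit) of the new ``wait_for_completion`` flag added above; the DAG id and the task ARN below are placeholders. With ``wait_for_completion=False`` the operator starts the TaskExecution and returns immediately instead of polling it to a final state, and its return value still carries the ``TaskExecutionArn``, as the new unit test further down checks.

    from datetime import datetime

    from airflow import models
    from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator

    with models.DAG(
        "datasync_fire_and_forget",  # placeholder DAG id
        schedule_interval=None,
        start_date=datetime(2022, 1, 1),
        catchup=False,
    ) as dag:
        # Start the DataSync TaskExecution and return without polling its status.
        start_task_execution = DataSyncOperator(
            task_id="start_task_execution",
            task_arn="arn:aws:datasync:us-east-1:111222333444:task/task-example",  # placeholder ARN
            wait_for_completion=False,
        )
        # The operator's return value (pushed to XCom by default) still includes
        # the "TaskExecutionArn" of the execution it started.
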
diff --git a/docs/apache-airflow-providers-amazon/operators/datasync.rst b/docs/apache-airflow-providers-amazon/operators/datasync.rst
index 5380aac5fa..b0b6100f44 100644
--- a/docs/apache-airflow-providers-amazon/operators/datasync.rst
+++ b/docs/apache-airflow-providers-amazon/operators/datasync.rst
@@ -59,7 +59,7 @@ Execute a task
 
 To execute a specific task, you can pass the ``task_arn`` to the operator.
 
-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_datasync.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_datasync.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_datasync_specific_task]
@@ -73,7 +73,7 @@ If one task is found, this one will be executed.
 If more than one task is found, the operator will raise an Exception. To avoid this, you can set
 ``allow_random_task_choice`` to ``True`` to randomly choose from candidate tasks.
 
-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_datasync.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_datasync.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_datasync_search_task]
@@ -92,7 +92,7 @@ existing Task was found. If these are left to their default value (None) then no
 Also, because ``delete_task_after_execution`` is set to ``True``, the task will be deleted
 from AWS DataSync after it completes successfully.
 
-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_datasync.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_datasync.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_datasync_create_task]
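
The documentation above mentions ``allow_random_task_choice`` but the linked examples do not set it; a minimal sketch (not part of this commit) of the search mode with that flag, where the DAG id and location URIs are placeholders:

    from datetime import datetime

    from airflow import models
    from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator

    with models.DAG(
        "datasync_search_random_choice",  # placeholder DAG id
        schedule_interval=None,
        start_date=datetime(2021, 1, 1),
        catchup=False,
    ) as dag:
        # Search existing DataSync Tasks by their location URIs; if more than one
        # Task matches, allow_random_task_choice picks one at random instead of
        # raising an exception.
        execute_matching_task = DataSyncOperator(
            task_id="execute_matching_task",
            source_location_uri="smb://hostname/directory/",  # placeholder
            destination_location_uri="s3://mybucket/prefix",  # placeholder
            allow_random_task_choice=True,
        )
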
diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py
index 30caf91b18..30a8666cfe 100644
--- a/tests/providers/amazon/aws/operators/test_datasync.py
+++ b/tests/providers/amazon/aws/operators/test_datasync.py
@@ -619,7 +619,9 @@ class TestDataSyncOperatorUpdate(DataSyncTestCaseBase):
 @mock_datasync
 @mock.patch.object(DataSyncHook, "get_conn")
 class TestDataSyncOperator(DataSyncTestCaseBase):
-    def set_up_operator(self, task_id="test_datasync_task_operator", task_arn="self"):
+    def set_up_operator(
+        self, task_id="test_datasync_task_operator", task_arn="self", wait_for_completion=True
+    ):
         if task_arn == "self":
             task_arn = self.task_arn
         # Create operator
@@ -627,6 +629,7 @@ class TestDataSyncOperator(DataSyncTestCaseBase):
             task_id=task_id,
             dag=self.dag,
             wait_interval_seconds=0,
+            wait_for_completion=wait_for_completion,
             task_arn=task_arn,
         )
 
@@ -693,6 +696,18 @@ class TestDataSyncOperator(DataSyncTestCaseBase):
         # ### Check mocks:
         mock_get_conn.assert_called()
 
+    @mock.patch.object(DataSyncHook, "wait_for_task_execution")
+    def test_execute_task_without_wait_for_completion(self, mock_wait, mock_get_conn):
+        self.set_up_operator(wait_for_completion=False)
+
+        # Execute the task
+        result = self.datasync.execute(None)
+        assert result is not None
+        task_execution_arn = result["TaskExecutionArn"]
+        assert task_execution_arn is not None
+
+        mock_wait.assert_not_called()
+
     @mock.patch.object(DataSyncHook, "wait_for_task_execution")
     def test_failed_task(self, mock_wait, mock_get_conn):
         # ### Set up mocks:
diff --git a/tests/system/providers/amazon/aws/example_datasync.py b/tests/system/providers/amazon/aws/example_datasync.py
new file mode 100644
index 0000000000..ac1e791f06
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_datasync.py
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+
+import boto3
+
+from airflow import models
+from airflow.decorators import task
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator
+from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator
+from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
+
+DAG_ID = 'example_datasync'
+
+# Externally fetched variables:
+ROLE_ARN_KEY = 'ROLE_ARN'
+
+sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+
+
+def get_s3_bucket_arn(bucket_name):
+    return f'arn:aws:s3:::{bucket_name}'
+
+
+def create_location(bucket_name, role_arn):
+    client = boto3.client('datasync')
+    response = client.create_location_s3(
+        Subdirectory='test',
+        S3BucketArn=get_s3_bucket_arn(bucket_name),
+        S3Config={
+            'BucketAccessRoleArn': role_arn,
+        },
+    )
+    return response['LocationArn']
+
+
+@task
+def create_source_location(bucket_source, role_arn):
+    return create_location(bucket_source, role_arn)
+
+
+@task
+def create_destination_location(bucket_destination, role_arn):
+    return create_location(bucket_destination, role_arn)
+
+
+@task
+def create_task(**kwargs):
+    client = boto3.client('datasync')
+    response = client.create_task(
+        SourceLocationArn=kwargs['ti'].xcom_pull('create_source_location'),
+        DestinationLocationArn=kwargs['ti'].xcom_pull('create_destination_location'),
+    )
+    return response['TaskArn']
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_task(task_arn):
+    client = boto3.client('datasync')
+    client.delete_task(
+        TaskArn=task_arn,
+    )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_task_created_by_operator(**kwargs):
+    client = boto3.client('datasync')
+    client.delete_task(
+        TaskArn=kwargs['ti'].xcom_pull('create_and_execute_task')['TaskArn'],
+    )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def list_locations(bucket_source, bucket_destination):
+    client = boto3.client('datasync')
+    return client.list_locations(
+        Filters=[
+            {
+                'Name': 'LocationUri',
+                'Values': [
+                    f's3://{bucket_source}/test/',
+                    f's3://{bucket_destination}/test/',
+                    f's3://{bucket_source}/test_create/',
+                    f's3://{bucket_destination}/test_create/',
+                ],
+                'Operator': 'In',
+            }
+        ]
+    )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_locations(locations):
+    client = boto3.client('datasync')
+    for location in locations['Locations']:
+        client.delete_location(
+            LocationArn=location['LocationArn'],
+        )
+
+
+with models.DAG(
+    DAG_ID,
+    schedule_interval='@once',
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=['example'],
+) as dag:
+    test_context = sys_test_context_task()
+
+    s3_bucket_source: str = f'{test_context[ENV_ID_KEY]}-datasync-bucket-source'
+    s3_bucket_destination: str = f'{test_context[ENV_ID_KEY]}-datasync-bucket-destination'
+
+    create_s3_bucket_source = S3CreateBucketOperator(
+        task_id='create_s3_bucket_source', bucket_name=s3_bucket_source
+    )
+
+    create_s3_bucket_destination = S3CreateBucketOperator(
+        task_id='create_s3_bucket_destination', bucket_name=s3_bucket_destination
+    )
+
+    source_location = create_source_location(s3_bucket_source, test_context[ROLE_ARN_KEY])
+    destination_location = create_destination_location(s3_bucket_destination, test_context[ROLE_ARN_KEY])
+
+    created_task_arn = create_task()
+
+    # [START howto_operator_datasync_specific_task]
+    # Execute a specific task
+    execute_task_by_arn = DataSyncOperator(
+        task_id='execute_task_by_arn',
+        task_arn=created_task_arn,
+        wait_for_completion=False,
+    )
+    # [END howto_operator_datasync_specific_task]
+
+    # [START howto_operator_datasync_search_task]
+    # Search and execute a task
+    execute_task_by_locations = DataSyncOperator(
+        task_id='execute_task_by_locations',
+        source_location_uri=f's3://{s3_bucket_source}/test',
+        destination_location_uri=f's3://{s3_bucket_destination}/test',
+        # Only transfer files from /test/subdir folder
+        task_execution_kwargs={
+            'Includes': [{'FilterType': 'SIMPLE_PATTERN', 'Value': '/test/subdir'}],
+        },
+        wait_for_completion=False,
+    )
+    # [END howto_operator_datasync_search_task]
+
+    # [START howto_operator_datasync_create_task]
+    # Create a task (the task does not exist)
+    create_and_execute_task = DataSyncOperator(
+        task_id='create_and_execute_task',
+        source_location_uri=f's3://{s3_bucket_source}/test_create',
+        destination_location_uri=f's3://{s3_bucket_destination}/test_create',
+        create_task_kwargs={"Name": "Created by Airflow"},
+        create_source_location_kwargs={
+            'Subdirectory': 'test_create',
+            'S3BucketArn': get_s3_bucket_arn(s3_bucket_source),
+            'S3Config': {
+                'BucketAccessRoleArn': test_context[ROLE_ARN_KEY],
+            },
+        },
+        create_destination_location_kwargs={
+            'Subdirectory': 'test_create',
+            'S3BucketArn': get_s3_bucket_arn(s3_bucket_destination),
+            'S3Config': {
+                'BucketAccessRoleArn': test_context[ROLE_ARN_KEY],
+            },
+        },
+        delete_task_after_execution=False,
+        wait_for_completion=False,
+    )
+    # [END howto_operator_datasync_create_task]
+
+    locations_task = list_locations(s3_bucket_source, s3_bucket_destination)
+    delete_locations_task = delete_locations(locations_task)
+
+    delete_s3_bucket_source = S3DeleteBucketOperator(
+        task_id='delete_s3_bucket_source',
+        bucket_name=s3_bucket_source,
+        force_delete=True,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    delete_s3_bucket_destination = S3DeleteBucketOperator(
+        task_id='delete_s3_bucket_destination',
+        bucket_name=s3_bucket_destination,
+        force_delete=True,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    chain(
+        # TEST SETUP
+        test_context,
+        create_s3_bucket_source,
+        create_s3_bucket_destination,
+        source_location,
+        destination_location,
+        created_task_arn,
+        # TEST BODY
+        execute_task_by_arn,
+        execute_task_by_locations,
+        create_and_execute_task,
+        # TEST TEARDOWN
+        delete_task(created_task_arn),
+        delete_task_created_by_operator(),
+        locations_task,
+        delete_locations_task,
+        delete_s3_bucket_source,
+        delete_s3_bucket_destination,
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)