Posted to commits@airflow.apache.org by vi...@apache.org on 2023/08/03 21:41:15 UTC

[airflow] branch main updated: Extract sagemaker pipeline to their own system test (#33086)

This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e9a2bb3d4a Extract sagemaker pipeline to their own system test (#33086)
e9a2bb3d4a is described below

commit e9a2bb3d4a4231e203c4c4a5cd417bdfcbf2cf6b
Author: Raphaël Vandon <va...@amazon.com>
AuthorDate: Thu Aug 3 14:41:05 2023 -0700

    Extract sagemaker pipeline to their own system test (#33086)
---
 .../operators/sagemaker.rst                        |   6 +-
 .../providers/amazon/aws/example_sagemaker.py      |  53 ---------
 .../amazon/aws/example_sagemaker_pipeline.py       | 126 +++++++++++++++++++++
 3 files changed, 129 insertions(+), 56 deletions(-)

diff --git a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
index 933d162df8..ca4ab34d69 100644
--- a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
+++ b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
@@ -154,7 +154,7 @@ Start an Amazon SageMaker pipeline execution
 To trigger an execution run for an already-defined Amazon Sagemaker pipeline, you can use
 :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerStartPipelineOperator`.
 
-.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_pipeline.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_sagemaker_start_pipeline]
@@ -168,7 +168,7 @@ Stop an Amazon SageMaker pipeline execution
 To stop an Amazon Sagemaker pipeline execution that is currently running, you can use
 :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerStopPipelineOperator`.
 
-.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_pipeline.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_sagemaker_stop_pipeline]
@@ -289,7 +289,7 @@ Wait on an Amazon SageMaker pipeline execution state
 To check the state of an Amazon Sagemaker pipeline execution until it reaches a terminal state
 you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerPipelineSensor`.
 
-.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_pipeline.py
     :language: python
     :dedent: 4
     :start-after: [START howto_sensor_sagemaker_pipeline]
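
The three documentation snippets above now pull their examples from the new example_sagemaker_pipeline.py system test instead of example_sagemaker.py. For orientation, here is a minimal sketch (not part of this commit) of how the referenced operators and sensor fit together in a DAG; the DAG id and the pipeline name "my-pipeline" are hypothetical, and an already-created SageMaker pipeline plus default AWS credentials/region are assumed.

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.sagemaker import (
        SageMakerStartPipelineOperator,
        SageMakerStopPipelineOperator,
    )
    from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerPipelineSensor

    with DAG(
        dag_id="sagemaker_pipeline_sketch",  # hypothetical DAG id
        schedule="@once",
        start_date=datetime(2021, 1, 1),
        catchup=False,
    ):
        # Start an execution of the existing pipeline; the operator's output is the execution ARN.
        start_first = SageMakerStartPipelineOperator(
            task_id="start_pipeline_1",
            pipeline_name="my-pipeline",  # hypothetical, must already exist in SageMaker
        )

        # Stop that execution using the ARN pushed by the start operator ...
        stop_first = SageMakerStopPipelineOperator(
            task_id="stop_pipeline_1",
            pipeline_exec_arn=start_first.output,
        )

        # ... and start a second execution that is awaited until it reaches a terminal state.
        start_second = SageMakerStartPipelineOperator(
            task_id="start_pipeline_2",
            pipeline_name="my-pipeline",
        )
        await_second = SageMakerPipelineSensor(
            task_id="await_pipeline_2",
            pipeline_exec_arn=start_second.output,
        )

The .output references create the task dependencies implicitly, which is the same pattern the extracted system test uses below.
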
diff --git a/tests/system/providers/amazon/aws/example_sagemaker.py b/tests/system/providers/amazon/aws/example_sagemaker.py
index c1f5745caf..e471d69b38 100644
--- a/tests/system/providers/amazon/aws/example_sagemaker.py
+++ b/tests/system/providers/amazon/aws/example_sagemaker.py
@@ -40,15 +40,12 @@ from airflow.providers.amazon.aws.operators.sagemaker import (
     SageMakerModelOperator,
     SageMakerProcessingOperator,
     SageMakerRegisterModelVersionOperator,
-    SageMakerStartPipelineOperator,
-    SageMakerStopPipelineOperator,
     SageMakerTrainingOperator,
     SageMakerTransformOperator,
     SageMakerTuningOperator,
 )
 from airflow.providers.amazon.aws.sensors.sagemaker import (
     SageMakerAutoMLSensor,
-    SageMakerPipelineSensor,
     SageMakerTrainingSensor,
     SageMakerTransformSensor,
     SageMakerTuningSensor,
@@ -199,7 +196,6 @@ def set_up(env_id, role_arn):
     transform_job_name = f"{env_id}-transform"
     tuning_job_name = f"{env_id}-tune"
     model_package_group_name = f"{env_id}-group"
-    pipeline_name = f"{env_id}-pipe"
     auto_ml_job_name = f"{env_id}-automl"
     experiment_name = f"{env_id}-experiment"
 
@@ -221,16 +217,6 @@ def set_up(env_id, role_arn):
             f"the directions at the top of the system testfile "
         )
 
-    # JSON definition for a dummy pipeline of 30 chained "conditional steps", each checking that 3 < 6.
-    # Each step takes roughly 1 second to execute, so the pipeline runtime is ~30 seconds, which should be
-    # enough to test stopping and awaiting without race conditions.
-    # Built with the SageMaker SDK, using json.loads(pipeline.definition())
-    pipeline_json_definition = """{"Version": "2020-12-01", "Metadata": {}, "Parameters": [], "PipelineExperimentConfig": {"ExperimentName": {"Get": "Execution.PipelineName"}, "TrialName": {"Get": "Execution.PipelineExecutionId"}}, "Steps": [{"Name": "DummyCond29", "Type": "Condition", "Arguments": {"Conditions": [{"Type": "LessThanOrEqualTo", "LeftValue": 3.0, "RightValue": 6.0}], "IfSteps": [{"Name": "DummyCond28", "Type": "Condition", "Arguments": {"Conditions": [{"Type": "LessThanOrE [...]
-    sgmk_client = boto3.client("sagemaker")
-    sgmk_client.create_pipeline(
-        PipelineName=pipeline_name, PipelineDefinition=pipeline_json_definition, RoleArn=role_arn
-    )
-
     resource_config = {
         "InstanceCount": 1,
         "InstanceType": "ml.m5.large",
@@ -410,7 +396,6 @@ def set_up(env_id, role_arn):
     ti.xcom_push(key="training_config", value=training_config)
     ti.xcom_push(key="training_job_name", value=training_job_name)
     ti.xcom_push(key="model_package_group_name", value=model_package_group_name)
-    ti.xcom_push(key="pipeline_name", value=pipeline_name)
     ti.xcom_push(key="auto_ml_job_name", value=auto_ml_job_name)
     ti.xcom_push(key="experiment_name", value=experiment_name)
     ti.xcom_push(key="model_config", value=model_config)
@@ -444,12 +429,6 @@ def delete_model_group(group_name, model_version_arn):
     sgmk_client.delete_model_package_group(ModelPackageGroupName=group_name)
 
 
-@task(trigger_rule=TriggerRule.ALL_DONE)
-def delete_pipeline(pipeline_name):
-    sgmk_client = boto3.client("sagemaker")
-    sgmk_client.delete_pipeline(PipelineName=pipeline_name)
-
-
 @task(trigger_rule=TriggerRule.ALL_DONE)
 def delete_experiment(name):
     sgmk_client = boto3.client("sagemaker")
@@ -528,33 +507,6 @@ with DAG(
     # [END howto_sensor_sagemaker_auto_ml]
     await_automl.poke_interval = 10
 
-    # [START howto_operator_sagemaker_start_pipeline]
-    start_pipeline1 = SageMakerStartPipelineOperator(
-        task_id="start_pipeline1",
-        pipeline_name=test_setup["pipeline_name"],
-    )
-    # [END howto_operator_sagemaker_start_pipeline]
-
-    # [START howto_operator_sagemaker_stop_pipeline]
-    stop_pipeline1 = SageMakerStopPipelineOperator(
-        task_id="stop_pipeline1",
-        pipeline_exec_arn=start_pipeline1.output,
-    )
-    # [END howto_operator_sagemaker_stop_pipeline]
-
-    start_pipeline2 = SageMakerStartPipelineOperator(
-        task_id="start_pipeline2",
-        pipeline_name=test_setup["pipeline_name"],
-    )
-
-    # [START howto_sensor_sagemaker_pipeline]
-    await_pipeline2 = SageMakerPipelineSensor(
-        task_id="await_pipeline2",
-        pipeline_exec_arn=start_pipeline2.output,
-    )
-    # [END howto_sensor_sagemaker_pipeline]
-    await_pipeline2.poke_interval = 10
-
     # [START howto_operator_sagemaker_experiment]
     create_experiment = SageMakerCreateExperimentOperator(
         task_id="create_experiment", name=test_setup["experiment_name"]
@@ -668,10 +620,6 @@ with DAG(
         # TEST BODY
         automl,
         await_automl,
-        start_pipeline1,
-        start_pipeline2,
-        stop_pipeline1,
-        await_pipeline2,
         create_experiment,
         preprocess_raw_data,
         train_model,
@@ -688,7 +636,6 @@ with DAG(
         delete_model,
         delete_bucket,
         delete_experiment(test_setup["experiment_name"]),
-        delete_pipeline(test_setup["pipeline_name"]),
         delete_docker_image(test_setup["docker_image"]),
         log_cleanup,
     )
diff --git a/tests/system/providers/amazon/aws/example_sagemaker_pipeline.py b/tests/system/providers/amazon/aws/example_sagemaker_pipeline.py
new file mode 100644
index 0000000000..204df76287
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_sagemaker_pipeline.py
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+import boto3
+
+from airflow import DAG
+from airflow.decorators import task
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.operators.sagemaker import (
+    SageMakerStartPipelineOperator,
+    SageMakerStopPipelineOperator,
+)
+from airflow.providers.amazon.aws.sensors.sagemaker import (
+    SageMakerPipelineSensor,
+)
+from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
+
+DAG_ID = "example_sagemaker_pipeline"
+
+# Externally fetched variables:
+ROLE_ARN_KEY = "ROLE_ARN"
+
+sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+
+
+@task
+def create_pipeline(name: str, role_arn: str):
+    # JSON definition for a dummy pipeline of 30 chained "conditional steps", each checking that 3 < 6.
+    # Each step takes roughly 1 second to execute, so the pipeline runtime is ~30 seconds, which should be
+    # enough to test stopping and awaiting without race conditions.
+    # Built with the SageMaker SDK, using json.loads(pipeline.definition())
+    pipeline_json_definition = """{"Version": "2020-12-01", "Metadata": {}, "Parameters": [], "PipelineExperimentConfig": {"ExperimentName": {"Get": "Execution.PipelineName"}, "TrialName": {"Get": "Execution.PipelineExecutionId"}}, "Steps": [{"Name": "DummyCond29", "Type": "Condition", "Arguments": {"Conditions": [{"Type": "LessThanOrEqualTo", "LeftValue": 3.0, "RightValue": 6.0}], "IfSteps": [{"Name": "DummyCond28", "Type": "Condition", "Arguments": {"Conditions": [{"Type": "LessThanOrE [...]
+    sgmk_client = boto3.client("sagemaker")
+    sgmk_client.create_pipeline(
+        PipelineName=name, PipelineDefinition=pipeline_json_definition, RoleArn=role_arn
+    )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_pipeline(name: str):
+    sgmk_client = boto3.client("sagemaker")
+    sgmk_client.delete_pipeline(PipelineName=name)
+
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    tags=["example"],
+    catchup=False,
+) as dag:
+    test_context = sys_test_context_task()
+    env_id = test_context[ENV_ID_KEY]
+
+    pipeline_name = f"{env_id}-pipeline"
+
+    create_pipeline = create_pipeline(pipeline_name, test_context[ROLE_ARN_KEY])
+
+    # [START howto_operator_sagemaker_start_pipeline]
+    start_pipeline1 = SageMakerStartPipelineOperator(
+        task_id="start_pipeline1",
+        pipeline_name=pipeline_name,
+    )
+    # [END howto_operator_sagemaker_start_pipeline]
+
+    # [START howto_operator_sagemaker_stop_pipeline]
+    stop_pipeline1 = SageMakerStopPipelineOperator(
+        task_id="stop_pipeline1",
+        pipeline_exec_arn=start_pipeline1.output,
+    )
+    # [END howto_operator_sagemaker_stop_pipeline]
+
+    start_pipeline2 = SageMakerStartPipelineOperator(
+        task_id="start_pipeline2",
+        pipeline_name=pipeline_name,
+    )
+
+    # [START howto_sensor_sagemaker_pipeline]
+    await_pipeline2 = SageMakerPipelineSensor(
+        task_id="await_pipeline2",
+        pipeline_exec_arn=start_pipeline2.output,
+    )
+    # [END howto_sensor_sagemaker_pipeline]
+    await_pipeline2.poke_interval = 10
+
+    chain(
+        # TEST SETUP
+        test_context,
+        create_pipeline,
+        # TEST BODY
+        start_pipeline1,
+        start_pipeline2,
+        stop_pipeline1,
+        await_pipeline2,
+        # TEST TEARDOWN
+        delete_pipeline(pipeline_name),
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
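
The pipeline_json_definition used in create_pipeline is truncated in this archive; per its comment, it was produced with the SageMaker Python SDK. Below is a rough sketch (not part of this commit) of how such a 30-deep chain of conditional steps could be generated. The class and parameter names come from the sagemaker.workflow API as best understood here; verify them against the installed SDK version, and note that Pipeline.definition() assumes default AWS credentials/region are configured.

    from sagemaker.workflow.condition_step import ConditionStep
    from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
    from sagemaker.workflow.pipeline import Pipeline

    # Build the innermost conditional step first, then wrap it 29 more times so the
    # outermost step ends up named "DummyCond29", matching the truncated definition above.
    step = ConditionStep(
        name="DummyCond0",
        conditions=[ConditionLessThanOrEqualTo(left=3.0, right=6.0)],
        if_steps=[],
        else_steps=[],
    )
    for i in range(1, 30):
        step = ConditionStep(
            name=f"DummyCond{i}",
            conditions=[ConditionLessThanOrEqualTo(left=3.0, right=6.0)],
            if_steps=[step],
            else_steps=[],
        )

    pipeline = Pipeline(name="dummy-pipeline", parameters=[], steps=[step])
    # definition() returns the JSON string passed as PipelineDefinition to create_pipeline above.
    pipeline_json_definition = pipeline.definition()

As the final comment in the new file notes, the resulting DAG is exercised under pytest as described in tests/system/README.md#run_via_pytest.
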