You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by vi...@apache.org on 2023/08/03 21:41:15 UTC
[airflow] branch main updated: Extract sagemaker pipeline to their own system test (#33086)
This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e9a2bb3d4a Extract sagemaker pipeline to their own system test (#33086)
e9a2bb3d4a is described below
commit e9a2bb3d4a4231e203c4c4a5cd417bdfcbf2cf6b
Author: Raphaël Vandon <va...@amazon.com>
AuthorDate: Thu Aug 3 14:41:05 2023 -0700
Extract sagemaker pipeline to their own system test (#33086)
---
.../operators/sagemaker.rst | 6 +-
.../providers/amazon/aws/example_sagemaker.py | 53 ---------
.../amazon/aws/example_sagemaker_pipeline.py | 126 +++++++++++++++++++++
3 files changed, 129 insertions(+), 56 deletions(-)
diff --git a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
index 933d162df8..ca4ab34d69 100644
--- a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
+++ b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
@@ -154,7 +154,7 @@ Start an Amazon SageMaker pipeline execution
To trigger an execution run for an already-defined Amazon Sagemaker pipeline, you can use
:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerStartPipelineOperator`.
-.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_pipeline.py
:language: python
:dedent: 4
:start-after: [START howto_operator_sagemaker_start_pipeline]
@@ -168,7 +168,7 @@ Stop an Amazon SageMaker pipeline execution
To stop an Amazon Sagemaker pipeline execution that is currently running, you can use
:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerStopPipelineOperator`.
-.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_pipeline.py
:language: python
:dedent: 4
:start-after: [START howto_operator_sagemaker_stop_pipeline]
@@ -289,7 +289,7 @@ Wait on an Amazon SageMaker pipeline execution state
To check the state of an Amazon Sagemaker pipeline execution until it reaches a terminal state
you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerPipelineSensor`.
-.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_pipeline.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_sagemaker_pipeline]
diff --git a/tests/system/providers/amazon/aws/example_sagemaker.py b/tests/system/providers/amazon/aws/example_sagemaker.py
index c1f5745caf..e471d69b38 100644
--- a/tests/system/providers/amazon/aws/example_sagemaker.py
+++ b/tests/system/providers/amazon/aws/example_sagemaker.py
@@ -40,15 +40,12 @@ from airflow.providers.amazon.aws.operators.sagemaker import (
SageMakerModelOperator,
SageMakerProcessingOperator,
SageMakerRegisterModelVersionOperator,
- SageMakerStartPipelineOperator,
- SageMakerStopPipelineOperator,
SageMakerTrainingOperator,
SageMakerTransformOperator,
SageMakerTuningOperator,
)
from airflow.providers.amazon.aws.sensors.sagemaker import (
SageMakerAutoMLSensor,
- SageMakerPipelineSensor,
SageMakerTrainingSensor,
SageMakerTransformSensor,
SageMakerTuningSensor,
@@ -199,7 +196,6 @@ def set_up(env_id, role_arn):
transform_job_name = f"{env_id}-transform"
tuning_job_name = f"{env_id}-tune"
model_package_group_name = f"{env_id}-group"
- pipeline_name = f"{env_id}-pipe"
auto_ml_job_name = f"{env_id}-automl"
experiment_name = f"{env_id}-experiment"
@@ -221,16 +217,6 @@ def set_up(env_id, role_arn):
f"the directions at the top of the system testfile "
)
- # Json definition for a dummy pipeline of 30 chained "conditional step" checking that 3 < 6
- # Each step takes roughly 1 second to execute, so the pipeline runtimes is ~30 seconds, which should be
- # enough to test stopping and awaiting without race conditions.
- # Built using sagemaker sdk, and using json.loads(pipeline.definition())
- pipeline_json_definition = """{"Version": "2020-12-01", "Metadata": {}, "Parameters": [], "PipelineExperimentConfig": {"ExperimentName": {"Get": "Execution.PipelineName"}, "TrialName": {"Get": "Execution.PipelineExecutionId"}}, "Steps": [{"Name": "DummyCond29", "Type": "Condition", "Arguments": {"Conditions": [{"Type": "LessThanOrEqualTo", "LeftValue": 3.0, "RightValue": 6.0}], "IfSteps": [{"Name": "DummyCond28", "Type": "Condition", "Arguments": {"Conditions": [{"Type": "LessThanOrE [...]
- sgmk_client = boto3.client("sagemaker")
- sgmk_client.create_pipeline(
- PipelineName=pipeline_name, PipelineDefinition=pipeline_json_definition, RoleArn=role_arn
- )
-
resource_config = {
"InstanceCount": 1,
"InstanceType": "ml.m5.large",
@@ -410,7 +396,6 @@ def set_up(env_id, role_arn):
ti.xcom_push(key="training_config", value=training_config)
ti.xcom_push(key="training_job_name", value=training_job_name)
ti.xcom_push(key="model_package_group_name", value=model_package_group_name)
- ti.xcom_push(key="pipeline_name", value=pipeline_name)
ti.xcom_push(key="auto_ml_job_name", value=auto_ml_job_name)
ti.xcom_push(key="experiment_name", value=experiment_name)
ti.xcom_push(key="model_config", value=model_config)
@@ -444,12 +429,6 @@ def delete_model_group(group_name, model_version_arn):
sgmk_client.delete_model_package_group(ModelPackageGroupName=group_name)
-@task(trigger_rule=TriggerRule.ALL_DONE)
-def delete_pipeline(pipeline_name):
- sgmk_client = boto3.client("sagemaker")
- sgmk_client.delete_pipeline(PipelineName=pipeline_name)
-
-
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_experiment(name):
sgmk_client = boto3.client("sagemaker")
@@ -528,33 +507,6 @@ with DAG(
# [END howto_sensor_sagemaker_auto_ml]
await_automl.poke_interval = 10
- # [START howto_operator_sagemaker_start_pipeline]
- start_pipeline1 = SageMakerStartPipelineOperator(
- task_id="start_pipeline1",
- pipeline_name=test_setup["pipeline_name"],
- )
- # [END howto_operator_sagemaker_start_pipeline]
-
- # [START howto_operator_sagemaker_stop_pipeline]
- stop_pipeline1 = SageMakerStopPipelineOperator(
- task_id="stop_pipeline1",
- pipeline_exec_arn=start_pipeline1.output,
- )
- # [END howto_operator_sagemaker_stop_pipeline]
-
- start_pipeline2 = SageMakerStartPipelineOperator(
- task_id="start_pipeline2",
- pipeline_name=test_setup["pipeline_name"],
- )
-
- # [START howto_sensor_sagemaker_pipeline]
- await_pipeline2 = SageMakerPipelineSensor(
- task_id="await_pipeline2",
- pipeline_exec_arn=start_pipeline2.output,
- )
- # [END howto_sensor_sagemaker_pipeline]
- await_pipeline2.poke_interval = 10
-
# [START howto_operator_sagemaker_experiment]
create_experiment = SageMakerCreateExperimentOperator(
task_id="create_experiment", name=test_setup["experiment_name"]
@@ -668,10 +620,6 @@ with DAG(
# TEST BODY
automl,
await_automl,
- start_pipeline1,
- start_pipeline2,
- stop_pipeline1,
- await_pipeline2,
create_experiment,
preprocess_raw_data,
train_model,
@@ -688,7 +636,6 @@ with DAG(
delete_model,
delete_bucket,
delete_experiment(test_setup["experiment_name"]),
- delete_pipeline(test_setup["pipeline_name"]),
delete_docker_image(test_setup["docker_image"]),
log_cleanup,
)
diff --git a/tests/system/providers/amazon/aws/example_sagemaker_pipeline.py b/tests/system/providers/amazon/aws/example_sagemaker_pipeline.py
new file mode 100644
index 0000000000..204df76287
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_sagemaker_pipeline.py
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+import boto3
+
+from airflow import DAG
+from airflow.decorators import task
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.operators.sagemaker import (
+ SageMakerStartPipelineOperator,
+ SageMakerStopPipelineOperator,
+)
+from airflow.providers.amazon.aws.sensors.sagemaker import (
+ SageMakerPipelineSensor,
+)
+from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
+
+DAG_ID = "example_sagemaker_pipeline"
+
+# Externally fetched variables:
+ROLE_ARN_KEY = "ROLE_ARN"
+
+sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+
+
+@task
+def create_pipeline(name: str, role_arn: str):
+ # Json definition for a dummy pipeline of 30 chained "conditional step" checking that 3 < 6
+ # Each step takes roughly 1 second to execute, so the pipeline runtime is ~30 seconds, which should be
+ # enough to test stopping and awaiting without race conditions.
+ # Built using sagemaker sdk, and using json.loads(pipeline.definition())
+ pipeline_json_definition = """{"Version": "2020-12-01", "Metadata": {}, "Parameters": [], "PipelineExperimentConfig": {"ExperimentName": {"Get": "Execution.PipelineName"}, "TrialName": {"Get": "Execution.PipelineExecutionId"}}, "Steps": [{"Name": "DummyCond29", "Type": "Condition", "Arguments": {"Conditions": [{"Type": "LessThanOrEqualTo", "LeftValue": 3.0, "RightValue": 6.0}], "IfSteps": [{"Name": "DummyCond28", "Type": "Condition", "Arguments": {"Conditions": [{"Type": "LessThanOrE [...]
+ sgmk_client = boto3.client("sagemaker")
+ sgmk_client.create_pipeline(
+ PipelineName=name, PipelineDefinition=pipeline_json_definition, RoleArn=role_arn
+ )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_pipeline(name: str):
+ sgmk_client = boto3.client("sagemaker")
+ sgmk_client.delete_pipeline(PipelineName=name)
+
+
+with DAG(
+ dag_id=DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ tags=["example"],
+ catchup=False,
+) as dag:
+ test_context = sys_test_context_task()
+ env_id = test_context[ENV_ID_KEY]
+
+ pipeline_name = f"{env_id}-pipeline"
+
+ create_pipeline = create_pipeline(pipeline_name, test_context[ROLE_ARN_KEY])
+
+ # [START howto_operator_sagemaker_start_pipeline]
+ start_pipeline1 = SageMakerStartPipelineOperator(
+ task_id="start_pipeline1",
+ pipeline_name=pipeline_name,
+ )
+ # [END howto_operator_sagemaker_start_pipeline]
+
+ # [START howto_operator_sagemaker_stop_pipeline]
+ stop_pipeline1 = SageMakerStopPipelineOperator(
+ task_id="stop_pipeline1",
+ pipeline_exec_arn=start_pipeline1.output,
+ )
+ # [END howto_operator_sagemaker_stop_pipeline]
+
+ start_pipeline2 = SageMakerStartPipelineOperator(
+ task_id="start_pipeline2",
+ pipeline_name=pipeline_name,
+ )
+
+ # [START howto_sensor_sagemaker_pipeline]
+ await_pipeline2 = SageMakerPipelineSensor(
+ task_id="await_pipeline2",
+ pipeline_exec_arn=start_pipeline2.output,
+ )
+ # [END howto_sensor_sagemaker_pipeline]
+ await_pipeline2.poke_interval = 10
+
+ chain(
+ # TEST SETUP
+ test_context,
+ create_pipeline,
+ # TEST BODY
+ start_pipeline1,
+ start_pipeline2,
+ stop_pipeline1,
+ await_pipeline2,
+ # TEST TEARDOWN
+ delete_pipeline(pipeline_name),
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)