Posted to commits@airflow.apache.org by el...@apache.org on 2023/01/05 09:51:31 UTC
[airflow] branch main updated: Add AWS Sagemaker Auto ML operator and sensor (#28472)
This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e8533d295e Add AWS Sagemaker Auto ML operator and sensor (#28472)
e8533d295e is described below
commit e8533d295e6d25296e23d8e1b8c07a441df55964
Author: Raphaël Vandon <11...@users.noreply.github.com>
AuthorDate: Thu Jan 5 01:51:23 2023 -0800
Add AWS Sagemaker Auto ML operator and sensor (#28472)
* add an operator to create autoML jobs
---
airflow/providers/amazon/aws/hooks/sagemaker.py | 82 +++++++++++++++++++
.../providers/amazon/aws/operators/sagemaker.py | 91 ++++++++++++++++++++++
airflow/providers/amazon/aws/sensors/sagemaker.py | 32 ++++++++
.../operators/sagemaker.rst | 30 +++++++
tests/providers/amazon/aws/hooks/test_sagemaker.py | 47 +++++++++++
.../amazon/aws/sensors/test_sagemaker_automl.py | 63 +++++++++++++++
.../providers/amazon/aws/example_sagemaker.py | 77 +++++++++++-------
7 files changed, 392 insertions(+), 30 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 6e880b8c91..17133829f8 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -1163,3 +1163,85 @@ class SageMakerHook(AwsBaseHook):
else:
self.log.error("Error when trying to create Model Package Group: %s", e)
raise
+
+ def _describe_auto_ml_job(self, job_name: str):
+ res = self.conn.describe_auto_ml_job(AutoMLJobName=job_name)
+ self.log.info("%s's current step: %s", job_name, res["AutoMLJobSecondaryStatus"])
+ return res
+
+ def create_auto_ml_job(
+ self,
+ job_name: str,
+ s3_input: str,
+ target_attribute: str,
+ s3_output: str,
+ role_arn: str,
+ compressed_input: bool = False,
+ time_limit: int | None = None,
+ autodeploy_endpoint_name: str | None = None,
+ extras: dict | None = None,
+ wait_for_completion: bool = True,
+ check_interval: int = 30,
+ ) -> dict | None:
+ """
+ Creates an auto ML job, learning to predict the given column from the data provided through S3.
+ The learning output is written to the specified S3 location.
+
+ :param job_name: Name of the job to create, needs to be unique within the account.
+ :param s3_input: The S3 location (folder or file) where to fetch the data.
+ By default, it expects csv with headers.
+ :param target_attribute: The name of the column containing the values to predict.
+ :param s3_output: The S3 folder where to write the model artifacts. Must be 128 characters or fewer.
+ :param role_arn: The ARN of the IAM role to use when interacting with S3.
+ Must have read access to the input, and write access to the output folder.
+ :param compressed_input: Set to True if the input is gzipped.
+ :param time_limit: The maximum amount of time in seconds to spend training the model(s).
+ :param autodeploy_endpoint_name: If specified, the best model will be deployed to an endpoint with
+ that name. No deployment made otherwise.
+ :param extras: Use this dictionary to set any input variable for job creation that is not
+ offered through the parameters of this function. The format is described in:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_auto_ml_job
+ :param wait_for_completion: Whether to wait for the job to finish before returning. Defaults to True.
+ :param check_interval: Interval in seconds between 2 status checks when waiting for completion.
+
+ :returns: Only if waiting for completion, a dictionary detailing the best model. The structure is that
+ of the "BestCandidate" key in:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
+ """
+ input_data = [
+ {
+ "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": s3_input}},
+ "TargetAttributeName": target_attribute,
+ },
+ ]
+ params_dict = {
+ "AutoMLJobName": job_name,
+ "InputDataConfig": input_data,
+ "OutputDataConfig": {"S3OutputPath": s3_output},
+ "RoleArn": role_arn,
+ }
+ if compressed_input:
+ input_data[0]["CompressionType"] = "Gzip"
+ if time_limit:
+ params_dict.update(
+ {"AutoMLJobConfig": {"CompletionCriteria": {"MaxAutoMLJobRuntimeInSeconds": time_limit}}}
+ )
+ if autodeploy_endpoint_name:
+ params_dict.update({"ModelDeployConfig": {"EndpointName": autodeploy_endpoint_name}})
+ if extras:
+ params_dict.update(extras)
+
+ # returns the job ARN, but we don't need it because we access it by its name
+ self.conn.create_auto_ml_job(**params_dict)
+
+ if wait_for_completion:
+ res = self.check_status(
+ job_name,
+ "AutoMLJobStatus",
+ # cannot pass the function directly because the parameter needs to be named
+ self._describe_auto_ml_job,
+ check_interval,
+ )
+ if "BestCandidate" in res:
+ return res["BestCandidate"]
+ return None
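
For context, here is a minimal sketch of calling the new hook method directly; it is not part of this commit, and the connection id, bucket paths, role ARN and job name are placeholder values:

    from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

    hook = SageMakerHook(aws_conn_id="aws_default")
    best_candidate = hook.create_auto_ml_job(
        job_name="example-automl-job",                  # must be unique within the account
        s3_input="s3://example-bucket/input/data.csv",  # CSV with headers by default
        target_attribute="class",                       # column the models learn to predict
        s3_output="s3://example-bucket/automl-output",
        role_arn="arn:aws:iam::123456789012:role/example-sagemaker-role",
        time_limit=600,                                 # cap total training time at 10 minutes
        wait_for_completion=True,
        check_interval=60,                              # seconds between status checks
    )
    # With wait_for_completion=True, the "BestCandidate" dict is returned (or None if absent).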
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 0335f8d15b..06d175a6d1 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -958,3 +958,94 @@ class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
if group_created:
self.hook.conn.delete_model_package_group(ModelPackageGroupName=self.package_group_name)
raise
+
+
+class SageMakerAutoMLOperator(SageMakerBaseOperator):
+ """
+ Creates an auto ML job, learning to predict the given column from the data provided through S3.
+ The learning output is written to the specified S3 location.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:SageMakerAutoMLOperator`
+
+ :param job_name: Name of the job to create, needs to be unique within the account.
+ :param s3_input: The S3 location (folder or file) where to fetch the data.
+ By default, it expects csv with headers.
+ :param target_attribute: The name of the column containing the values to predict.
+ :param s3_output: The S3 folder where to write the model artifacts. Must be 128 characters or fewer.
+ :param role_arn: The ARN of the IAM role to use when interacting with S3.
+ Must have read access to the input, and write access to the output folder.
+ :param compressed_input: Set to True if the input is gzipped.
+ :param time_limit: The maximum amount of time in seconds to spend training the model(s).
+ :param autodeploy_endpoint_name: If specified, the best model will be deployed to an endpoint with
+ that name. No deployment made otherwise.
+ :param extras: Use this dictionary to set any input variable for job creation that is not
+ offered through the parameters of this function. The format is described in:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_auto_ml_job
+ :param wait_for_completion: Whether to wait for the job to finish before returning. Defaults to True.
+ :param check_interval: Interval in seconds between 2 status checks when waiting for completion.
+
+ :returns: Only if waiting for completion, a dictionary detailing the best model. The structure is that of
+ the "BestCandidate" key in:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
+ """
+
+ template_fields: Sequence[str] = (
+ "job_name",
+ "s3_input",
+ "target_attribute",
+ "s3_output",
+ "role_arn",
+ "compressed_input",
+ "time_limit",
+ "autodeploy_endpoint_name",
+ "extras",
+ )
+
+ def __init__(
+ self,
+ *,
+ job_name: str,
+ s3_input: str,
+ target_attribute: str,
+ s3_output: str,
+ role_arn: str,
+ compressed_input: bool = False,
+ time_limit: int | None = None,
+ autodeploy_endpoint_name: str | None = None,
+ extras: dict | None = None,
+ wait_for_completion: bool = True,
+ check_interval: int = 30,
+ aws_conn_id: str = DEFAULT_CONN_ID,
+ config: dict | None = None,
+ **kwargs,
+ ):
+ super().__init__(config=config or {}, aws_conn_id=aws_conn_id, **kwargs)
+ self.job_name = job_name
+ self.s3_input = s3_input
+ self.target_attribute = target_attribute
+ self.s3_output = s3_output
+ self.role_arn = role_arn
+ self.compressed_input = compressed_input
+ self.time_limit = time_limit
+ self.autodeploy_endpoint_name = autodeploy_endpoint_name
+ self.extras = extras
+ self.wait_for_completion = wait_for_completion
+ self.check_interval = check_interval
+
+ def execute(self, context: Context) -> dict | None:
+ best = self.hook.create_auto_ml_job(
+ self.job_name,
+ self.s3_input,
+ self.target_attribute,
+ self.s3_output,
+ self.role_arn,
+ self.compressed_input,
+ self.time_limit,
+ self.autodeploy_endpoint_name,
+ self.extras,
+ self.wait_for_completion,
+ self.check_interval,
+ )
+ return best
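
As a rough usage sketch (not part of this commit, assuming a recent Airflow where DAG accepts a schedule argument), the operator wires into a DAG like the other SageMaker operators; the DAG id, S3 URIs and role ARN are placeholders:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.sagemaker import SageMakerAutoMLOperator

    with DAG(dag_id="example_sagemaker_automl", start_date=datetime(2023, 1, 1), schedule=None):
        create_automl_job = SageMakerAutoMLOperator(
            task_id="create_automl_job",
            job_name="example-automl-job",
            s3_input="s3://example-bucket/input/data.csv",
            target_attribute="class",
            s3_output="s3://example-bucket/automl-output",
            role_arn="arn:aws:iam::123456789012:role/example-sagemaker-role",
        )
        # execute() returns the "BestCandidate" dict, so it is also pushed to XCom.

Because job_name, s3_input, s3_output and the other listed template_fields are templated, the values above could equally come from Jinja expressions.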
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py
index a93135b48c..f8527fbb2c 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker.py
@@ -307,3 +307,35 @@ class SageMakerPipelineSensor(SageMakerBaseSensor):
def state_from_response(self, response: dict) -> str:
return response["PipelineExecutionStatus"]
+
+
+class SageMakerAutoMLSensor(SageMakerBaseSensor):
+ """
+ Polls the auto ML job until it reaches a terminal state.
+ Raises an AirflowException with the failure reason if a failed state is reached.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:SageMakerAutoMLSensor`
+
+ :param job_name: unique name of the AutoML job to watch.
+ """
+
+ template_fields: Sequence[str] = ("job_name",)
+
+ def __init__(self, *, job_name: str, **kwargs):
+ super().__init__(resource_type="autoML job", **kwargs)
+ self.job_name = job_name
+
+ def non_terminal_states(self) -> set[str]:
+ return SageMakerHook.non_terminal_states
+
+ def failed_states(self) -> set[str]:
+ return SageMakerHook.failed_states
+
+ def get_sagemaker_response(self) -> dict:
+ self.log.info("Poking Sagemaker AutoML Execution %s", self.job_name)
+ return self.get_hook()._describe_auto_ml_job(self.job_name)
+
+ def state_from_response(self, response: dict) -> str:
+ return response["AutoMLJobStatus"]
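
A short sketch (again not part of the commit, with placeholder identifiers) of the intended pairing, mirroring what the system test below does: submit the job without blocking, then let the sensor poll it to a terminal state:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.sagemaker import SageMakerAutoMLOperator
    from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerAutoMLSensor

    with DAG(dag_id="example_sagemaker_automl_sensor", start_date=datetime(2023, 1, 1), schedule=None):
        submit = SageMakerAutoMLOperator(
            task_id="submit_automl",
            job_name="example-automl-job",
            s3_input="s3://example-bucket/input/data.csv",
            target_attribute="class",
            s3_output="s3://example-bucket/automl-output",
            role_arn="arn:aws:iam::123456789012:role/example-sagemaker-role",
            wait_for_completion=False,  # return immediately; the sensor does the waiting
        )
        await_automl = SageMakerAutoMLSensor(
            task_id="await_automl",
            job_name="example-automl-job",
            poke_interval=60,  # seconds between describe calls
        )
        submit >> await_automl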
diff --git a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
index 8d8319e17e..d33441fcd5 100644
--- a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
+++ b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
@@ -192,6 +192,22 @@ You can use this operator to add a new version and model package to the group fo
:start-after: [START howto_operator_sagemaker_register]
:end-before: [END howto_operator_sagemaker_register]
+.. _howto/operator:SageMakerAutoMLOperator:
+
+Launch an AutoML experiment
+===========================
+
+To launch an AutoML experiment, a.k.a. SageMaker Autopilot, you can use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerAutoMLOperator`.
+An AutoML experiment takes input data in CSV format and the name of the column it should learn to predict,
+and trains models on it without human supervision.
+The output is placed in an S3 bucket, and the best model is automatically deployed if configured to do so.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_auto_ml]
+ :end-before: [END howto_operator_sagemaker_auto_ml]
+
Sensors
-------
@@ -265,6 +281,20 @@ you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerPip
:start-after: [START howto_sensor_sagemaker_pipeline]
:end-before: [END howto_sensor_sagemaker_pipeline]
+.. _howto/sensor:SageMakerAutoMLSensor:
+
+Wait on an Amazon SageMaker AutoML experiment state
+===================================================
+
+To check the state of an Amazon SageMaker AutoML job until it reaches a terminal state,
+you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerAutoMLSensor`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_sagemaker_auto_ml]
+ :end-before: [END howto_sensor_sagemaker_auto_ml]
+
Reference
---------
diff --git a/tests/providers/amazon/aws/hooks/test_sagemaker.py b/tests/providers/amazon/aws/hooks/test_sagemaker.py
index f2a97450fe..2aefe7a0f9 100644
--- a/tests/providers/amazon/aws/hooks/test_sagemaker.py
+++ b/tests/providers/amazon/aws/hooks/test_sagemaker.py
@@ -849,3 +849,50 @@ class TestSageMakerHook:
created = hook.create_model_package_group("group-name")
assert created is False
+
+ @patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.conn", new_callable=mock.PropertyMock)
+ def test_create_auto_ml_parameter_structure(self, conn_mock):
+ hook = SageMakerHook()
+
+ hook.create_auto_ml_job(
+ job_name="a",
+ s3_input="b",
+ target_attribute="c",
+ s3_output="d",
+ role_arn="e",
+ compressed_input=True,
+ time_limit=30,
+ wait_for_completion=False,
+ )
+
+ assert conn_mock().create_auto_ml_job.call_args[1] == {
+ "AutoMLJobConfig": {"CompletionCriteria": {"MaxAutoMLJobRuntimeInSeconds": 30}},
+ "AutoMLJobName": "a",
+ "InputDataConfig": [
+ {
+ "CompressionType": "Gzip",
+ "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "b"}},
+ "TargetAttributeName": "c",
+ }
+ ],
+ "OutputDataConfig": {"S3OutputPath": "d"},
+ "RoleArn": "e",
+ }
+
+ @patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.conn", new_callable=mock.PropertyMock)
+ def test_create_auto_ml_waits_for_completion(self, conn_mock):
+ hook = SageMakerHook()
+ conn_mock().describe_auto_ml_job.side_effect = [
+ {"AutoMLJobStatus": "InProgress", "AutoMLJobSecondaryStatus": "a"},
+ {"AutoMLJobStatus": "InProgress", "AutoMLJobSecondaryStatus": "b"},
+ {
+ "AutoMLJobStatus": "Completed",
+ "AutoMLJobSecondaryStatus": "c",
+ "BestCandidate": {"name": "me"},
+ },
+ ]
+
+ ret = hook.create_auto_ml_job("a", "b", "c", "d", "e", check_interval=0)
+
+ assert conn_mock().describe_auto_ml_job.call_count == 3
+ assert ret == {"name": "me"}
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_automl.py b/tests/providers/amazon/aws/sensors/test_sagemaker_automl.py
new file mode 100644
index 0000000000..454d8ab1be
--- /dev/null
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_automl.py
@@ -0,0 +1,63 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerAutoMLSensor
+
+
+class TestSageMakerAutoMLSensor:
+ @staticmethod
+ def get_response_with_state(state: str):
+ states = {"Completed", "InProgress", "Failed", "Stopped", "Stopping"}
+ assert state in states
+ return {
+ "AutoMLJobStatus": state,
+ "AutoMLJobSecondaryStatus": "Starting",
+ "ResponseMetadata": {
+ "HTTPStatusCode": 200,
+ },
+ }
+
+ @mock.patch.object(SageMakerHook, "_describe_auto_ml_job")
+ def test_sensor_with_failure(self, mock_describe):
+ mock_describe.return_value = self.get_response_with_state("Failed")
+ sensor = SageMakerAutoMLSensor(job_name="job_job", task_id="test_task")
+
+ with pytest.raises(AirflowException):
+ sensor.execute(None)
+
+ mock_describe.assert_called_once_with("job_job")
+
+ @mock.patch.object(SageMakerHook, "_describe_auto_ml_job")
+ def test_sensor(self, mock_describe):
+ mock_describe.side_effect = [
+ self.get_response_with_state("InProgress"),
+ self.get_response_with_state("Stopping"),
+ self.get_response_with_state("Stopped"),
+ ]
+ sensor = SageMakerAutoMLSensor(job_name="job_job", task_id="test_task", poke_interval=0)
+
+ sensor.execute(None)
+
+ assert mock_describe.call_count == 3
diff --git a/tests/system/providers/amazon/aws/example_sagemaker.py b/tests/system/providers/amazon/aws/example_sagemaker.py
index ab13b02fa2..009352d234 100644
--- a/tests/system/providers/amazon/aws/example_sagemaker.py
+++ b/tests/system/providers/amazon/aws/example_sagemaker.py
@@ -35,6 +35,7 @@ from airflow.providers.amazon.aws.operators.s3 import (
S3DeleteBucketOperator,
)
from airflow.providers.amazon.aws.operators.sagemaker import (
+ SageMakerAutoMLOperator,
SageMakerDeleteModelOperator,
SageMakerModelOperator,
SageMakerProcessingOperator,
@@ -46,6 +47,7 @@ from airflow.providers.amazon.aws.operators.sagemaker import (
SageMakerTuningOperator,
)
from airflow.providers.amazon.aws.sensors.sagemaker import (
+ SageMakerAutoMLSensor,
SageMakerPipelineSensor,
SageMakerTrainingSensor,
SageMakerTransformSensor,
@@ -71,18 +73,7 @@ KNN_IMAGES_BY_REGION = {
"us-west-2": "174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1",
}
-# For this example we are using a subset of Fischer's Iris Data Set.
-# The full dataset can be found at UC Irvine's machine learning repository:
-# https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
-DATASET = """
- 5.1,3.5,1.4,0.2,Iris-setosa
- 4.9,3.0,1.4,0.2,Iris-setosa
- 7.0,3.2,4.7,1.4,Iris-versicolor
- 6.4,3.2,4.5,1.5,Iris-versicolor
- 4.9,2.5,4.5,1.7,Iris-virginica
- 7.3,2.9,6.3,1.8,Iris-virginica
- """
-SAMPLE_SIZE = DATASET.count("\n") - 1
+SAMPLE_SIZE = 600
# This script will be the entrypoint for the docker image which will handle preprocessing the raw data
# NOTE: The following string must remain dedented as it is being written to a file.
@@ -92,34 +83,28 @@ import numpy as np
import pandas as pd
def main():
- # Load the Iris dataset from {input_path}/input.csv, split it into train/test
+ # Load the dataset from {input_path}/input.csv, split it into train/test
# subsets, and write them to {output_path}/ for the Processing Operator.
- columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']
- iris = pd.read_csv('{input_path}/input.csv', names=columns)
-
- # Process data
- iris['species'] = iris['species'].replace({{'Iris-virginica': 0, 'Iris-versicolor': 1, 'Iris-setosa': 2}})
- iris = iris[['species', 'sepal_length', 'sepal_width', 'petal_length', 'petal_width']]
+ data = pd.read_csv('{input_path}/input.csv')
# Split into test and train data
- iris_train, iris_test = np.split(
- iris.sample(frac=1, random_state=np.random.RandomState()), [int(0.7 * len(iris))]
+ data_train, data_test = np.split(
+ data.sample(frac=1, random_state=np.random.RandomState()), [int(0.7 * len(data))]
)
# Remove the "answers" from the test set
- iris_test.drop(['species'], axis=1, inplace=True)
+ data_test.drop(['class'], axis=1, inplace=True)
# Write the splits to disk
- iris_train.to_csv('{output_path}/train.csv', index=False, header=False)
- iris_test.to_csv('{output_path}/test.csv', index=False, header=False)
+ data_train.to_csv('{output_path}/train.csv', index=False, header=False)
+ data_test.to_csv('{output_path}/test.csv', index=False, header=False)
print('Preprocessing Done.')
if __name__ == "__main__":
main()
-
- """
+"""
def _create_ecr_repository(repo_name):
@@ -195,6 +180,14 @@ def _build_and_upload_docker_image(preprocess_script, repository_uri):
)
+def generate_data() -> str:
+ """generates a very simple csv dataset with headers"""
+ content = "class,x,y\n" # headers
+ for i in range(SAMPLE_SIZE):
+ content += f"{i%100},{i},{SAMPLE_SIZE-i}\n"
+ return content
+
+
@task
def set_up(env_id, role_arn):
bucket_name = f"{env_id}-sagemaker-example"
@@ -206,6 +199,7 @@ def set_up(env_id, role_arn):
tuning_job_name = f"{env_id}-tune"
model_package_group_name = f"{env_id}-group"
pipeline_name = f"{env_id}-pipe"
+ auto_ml_job_name = f"{env_id}-automl"
input_data_S3_key = f"{env_id}/processed-input-data"
prediction_output_s3_key = f"{env_id}/transform"
@@ -240,6 +234,7 @@ def set_up(env_id, role_arn):
"InstanceType": "ml.m5.large",
"VolumeSizeInGB": 1,
}
+ input_data_uri = f"s3://{bucket_name}/{raw_data_s3_key}"
processing_config = {
"ProcessingJobName": processing_job_name,
"ProcessingInputs": [
@@ -247,7 +242,7 @@ def set_up(env_id, role_arn):
"InputName": "input",
"AppManaged": False,
"S3Input": {
- "S3Uri": f"s3://{bucket_name}/{raw_data_s3_key}",
+ "S3Uri": input_data_uri,
"LocalPath": processing_local_input_path,
"S3DataType": "S3Prefix",
"S3InputMode": "File",
@@ -297,7 +292,7 @@ def set_up(env_id, role_arn):
},
"HyperParameters": {
"predictor_type": "classifier",
- "feature_dim": "4",
+ "feature_dim": "2",
"k": "3",
"sample_size": str(SAMPLE_SIZE),
},
@@ -357,7 +352,7 @@ def set_up(env_id, role_arn):
"TrainingJobDefinition": {
"StaticHyperParameters": {
"predictor_type": "classifier",
- "feature_dim": "4",
+ "feature_dim": "2",
},
"AlgorithmSpecification": {"TrainingImage": knn_image_uri, "TrainingInputMode": "File"},
"InputDataConfig": [
@@ -407,10 +402,13 @@ def set_up(env_id, role_arn):
ti.xcom_push(key="raw_data_s3_key", value=raw_data_s3_key)
ti.xcom_push(key="ecr_repository_name", value=ecr_repository_name)
ti.xcom_push(key="processing_config", value=processing_config)
+ ti.xcom_push(key="input_data_uri", value=input_data_uri)
+ ti.xcom_push(key="output_data_uri", value=f"s3://{bucket_name}/{training_output_s3_key}")
ti.xcom_push(key="training_config", value=training_config)
ti.xcom_push(key="training_job_name", value=training_job_name)
ti.xcom_push(key="model_package_group_name", value=model_package_group_name)
ti.xcom_push(key="pipeline_name", value=pipeline_name)
+ ti.xcom_push(key="auto_ml_job_name", value=auto_ml_job_name)
ti.xcom_push(key="model_config", value=model_config)
ti.xcom_push(key="model_name", value=model_name)
ti.xcom_push(key="inference_code_image", value=knn_image_uri)
@@ -499,10 +497,27 @@ with DAG(
task_id="upload_dataset",
s3_bucket=test_setup["bucket_name"],
s3_key=test_setup["raw_data_s3_key"],
- data=DATASET,
+ data=generate_data(),
replace=True,
)
+ # [START howto_operator_sagemaker_auto_ml]
+ automl = SageMakerAutoMLOperator(
+ task_id="auto_ML",
+ job_name=test_setup["auto_ml_job_name"],
+ s3_input=test_setup["input_data_uri"],
+ target_attribute="class",
+ s3_output=test_setup["output_data_uri"],
+ role_arn=test_context[ROLE_ARN_KEY],
+ time_limit=30, # will stop the job before it can do anything, but it's not the point here
+ )
+ # [END howto_operator_sagemaker_auto_ml]
+ automl.wait_for_completion = False # just to be able to test the sensor next
+
+ # [START howto_sensor_sagemaker_auto_ml]
+ await_automl = SageMakerAutoMLSensor(job_name=test_setup["auto_ml_job_name"], task_id="await_auto_ML")
+ # [END howto_sensor_sagemaker_auto_ml]
+
# [START howto_operator_sagemaker_start_pipeline]
start_pipeline1 = SageMakerStartPipelineOperator(
task_id="start_pipeline1",
@@ -625,6 +640,8 @@ with DAG(
create_bucket,
upload_dataset,
# TEST BODY
+ automl,
+ await_automl,
start_pipeline1,
start_pipeline2,
stop_pipeline1,