Posted to commits@airflow.apache.org by el...@apache.org on 2023/01/05 09:51:31 UTC

[airflow] branch main updated: Add AWS Sagemaker Auto ML operator and sensor (#28472)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e8533d295e Add AWS Sagemaker Auto ML operator and sensor (#28472)
e8533d295e is described below

commit e8533d295e6d25296e23d8e1b8c07a441df55964
Author: Raphaël Vandon <11...@users.noreply.github.com>
AuthorDate: Thu Jan 5 01:51:23 2023 -0800

    Add AWS Sagemaker Auto ML operator and sensor (#28472)
    
    * add an operator to create autoML jobs
---
 airflow/providers/amazon/aws/hooks/sagemaker.py    | 82 +++++++++++++++++++
 .../providers/amazon/aws/operators/sagemaker.py    | 91 ++++++++++++++++++++++
 airflow/providers/amazon/aws/sensors/sagemaker.py  | 32 ++++++++
 .../operators/sagemaker.rst                        | 30 +++++++
 tests/providers/amazon/aws/hooks/test_sagemaker.py | 47 +++++++++++
 .../amazon/aws/sensors/test_sagemaker_automl.py    | 63 +++++++++++++++
 .../providers/amazon/aws/example_sagemaker.py      | 77 +++++++++++-------
 7 files changed, 392 insertions(+), 30 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 6e880b8c91..17133829f8 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -1163,3 +1163,85 @@ class SageMakerHook(AwsBaseHook):
             else:
                 self.log.error("Error when trying to create Model Package Group: %s", e)
                 raise
+
+    def _describe_auto_ml_job(self, job_name: str):
+        res = self.conn.describe_auto_ml_job(AutoMLJobName=job_name)
+        self.log.info("%s's current step: %s", job_name, res["AutoMLJobSecondaryStatus"])
+        return res
+
+    def create_auto_ml_job(
+        self,
+        job_name: str,
+        s3_input: str,
+        target_attribute: str,
+        s3_output: str,
+        role_arn: str,
+        compressed_input: bool = False,
+        time_limit: int | None = None,
+        autodeploy_endpoint_name: str | None = None,
+        extras: dict | None = None,
+        wait_for_completion: bool = True,
+        check_interval: int = 30,
+    ) -> dict | None:
+        """
+        Creates an auto ML job, learning to predict the given column from the data provided through S3.
+        The learning output is written to the specified S3 location.
+
+        :param job_name: Name of the job to create, needs to be unique within the account.
+        :param s3_input: The S3 location (folder or file) from which to fetch the data.
+            By default, it expects CSV with headers.
+        :param target_attribute: The name of the column containing the values to predict.
+        :param s3_output: The S3 folder where the model artifacts are written. Must be 128 characters or fewer.
+        :param role_arn: The ARN of the IAM role to use when interacting with S3.
+            Must have read access to the input, and write access to the output folder.
+        :param compressed_input: Set to True if the input is gzipped.
+        :param time_limit: The maximum amount of time in seconds to spend training the model(s).
+        :param autodeploy_endpoint_name: If specified, the best model will be deployed to an endpoint with
+            that name. No deployment is made otherwise.
+        :param extras: Use this dictionary to set any additional input parameter for job creation that is
+            not offered through the parameters of this function. The format is described in:
+            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_auto_ml_job
+        :param wait_for_completion: Whether to wait for the job to finish before returning. Defaults to True.
+        :param check_interval: Interval in seconds between 2 status checks when waiting for completion.
+
+        :returns: Only if waiting for completion, a dictionary detailing the best model. The structure is that
+            of the "BestCandidate" key in:
+            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
+        """
+        input_data = [
+            {
+                "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": s3_input}},
+                "TargetAttributeName": target_attribute,
+            },
+        ]
+        params_dict = {
+            "AutoMLJobName": job_name,
+            "InputDataConfig": input_data,
+            "OutputDataConfig": {"S3OutputPath": s3_output},
+            "RoleArn": role_arn,
+        }
+        if compressed_input:
+            input_data[0]["CompressionType"] = "Gzip"
+        if time_limit:
+            params_dict.update(
+                {"AutoMLJobConfig": {"CompletionCriteria": {"MaxAutoMLJobRuntimeInSeconds": time_limit}}}
+            )
+        if autodeploy_endpoint_name:
+            params_dict.update({"ModelDeployConfig": {"EndpointName": autodeploy_endpoint_name}})
+        if extras:
+            params_dict.update(extras)
+
+        # the call returns the job ARN, but we don't need it because we refer to the job by its name
+        self.conn.create_auto_ml_job(**params_dict)
+
+        if wait_for_completion:
+            res = self.check_status(
+                job_name,
+                "AutoMLJobStatus",
+                # cannot pass the function directly because the parameter needs to be named
+                self._describe_auto_ml_job,
+                check_interval,
+            )
+            if "BestCandidate" in res:
+                return res["BestCandidate"]
+        return None
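
For reference, a minimal sketch of calling the new hook method directly, outside of any operator. All names below (connection id, job name, S3 URIs, role ARN) are illustrative placeholders, not values from this commit:

    from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

    # Placeholder connection id; the hook defaults to "aws_default" if omitted.
    hook = SageMakerHook(aws_conn_id="aws_default")

    # Submit the AutoML job and block until it finishes, checking every 60 seconds.
    best_candidate = hook.create_auto_ml_job(
        job_name="example-automl-job",
        s3_input="s3://example-bucket/automl/input/",   # CSV data with headers
        target_attribute="class",                       # column the models learn to predict
        s3_output="s3://example-bucket/automl/output/",
        role_arn="arn:aws:iam::123456789012:role/SageMakerExampleRole",
        wait_for_completion=True,
        check_interval=60,
    )

    # When waiting for completion, the return value describes the best candidate model
    # (the "BestCandidate" structure from describe_auto_ml_job); otherwise it is None.
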
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 0335f8d15b..06d175a6d1 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -958,3 +958,94 @@ class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
             if group_created:
                 self.hook.conn.delete_model_package_group(ModelPackageGroupName=self.package_group_name)
             raise
+
+
+class SageMakerAutoMLOperator(SageMakerBaseOperator):
+    """
+    Creates an auto ML job, learning to predict the given column from the data provided through S3.
+    The learning output is written to the specified S3 location.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:SageMakerAutoMLOperator`
+
+    :param job_name: Name of the job to create, needs to be unique within the account.
+    :param s3_input: The S3 location (folder or file) from which to fetch the data.
+        By default, it expects CSV with headers.
+    :param target_attribute: The name of the column containing the values to predict.
+    :param s3_output: The S3 folder where the model artifacts are written. Must be 128 characters or fewer.
+    :param role_arn: The ARN of the IAM role to use when interacting with S3.
+        Must have read access to the input, and write access to the output folder.
+    :param compressed_input: Set to True if the input is gzipped.
+    :param time_limit: The maximum amount of time in seconds to spend training the model(s).
+    :param autodeploy_endpoint_name: If specified, the best model will be deployed to an endpoint with
+        that name. No deployment is made otherwise.
+    :param extras: Use this dictionary to set any additional input parameter for job creation that is
+        not offered through the parameters of this function. The format is described in:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_auto_ml_job
+    :param wait_for_completion: Whether to wait for the job to finish before returning. Defaults to True.
+    :param check_interval: Interval in seconds between 2 status checks when waiting for completion.
+
+    :returns: Only if waiting for completion, a dictionary detailing the best model. The structure is that of
+        the "BestCandidate" key in:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
+    """
+
+    template_fields: Sequence[str] = (
+        "job_name",
+        "s3_input",
+        "target_attribute",
+        "s3_output",
+        "role_arn",
+        "compressed_input",
+        "time_limit",
+        "autodeploy_endpoint_name",
+        "extras",
+    )
+
+    def __init__(
+        self,
+        *,
+        job_name: str,
+        s3_input: str,
+        target_attribute: str,
+        s3_output: str,
+        role_arn: str,
+        compressed_input: bool = False,
+        time_limit: int | None = None,
+        autodeploy_endpoint_name: str | None = None,
+        extras: dict | None = None,
+        wait_for_completion: bool = True,
+        check_interval: int = 30,
+        aws_conn_id: str = DEFAULT_CONN_ID,
+        config: dict | None = None,
+        **kwargs,
+    ):
+        super().__init__(config=config or {}, aws_conn_id=aws_conn_id, **kwargs)
+        self.job_name = job_name
+        self.s3_input = s3_input
+        self.target_attribute = target_attribute
+        self.s3_output = s3_output
+        self.role_arn = role_arn
+        self.compressed_input = compressed_input
+        self.time_limit = time_limit
+        self.autodeploy_endpoint_name = autodeploy_endpoint_name
+        self.extras = extras
+        self.wait_for_completion = wait_for_completion
+        self.check_interval = check_interval
+
+    def execute(self, context: Context) -> dict | None:
+        best = self.hook.create_auto_ml_job(
+            self.job_name,
+            self.s3_input,
+            self.target_attribute,
+            self.s3_output,
+            self.role_arn,
+            self.compressed_input,
+            self.time_limit,
+            self.autodeploy_endpoint_name,
+            self.extras,
+            self.wait_for_completion,
+            self.check_interval,
+        )
+        return best
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py
index a93135b48c..f8527fbb2c 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker.py
@@ -307,3 +307,35 @@ class SageMakerPipelineSensor(SageMakerBaseSensor):
 
     def state_from_response(self, response: dict) -> str:
         return response["PipelineExecutionStatus"]
+
+
+class SageMakerAutoMLSensor(SageMakerBaseSensor):
+    """
+    Polls the auto ML job until it reaches a terminal state.
+    Raises an AirflowException with the failure reason if a failed state is reached.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:SageMakerAutoMLSensor`
+
+    :param job_name: unique name of the AutoML job to watch.
+    """
+
+    template_fields: Sequence[str] = ("job_name",)
+
+    def __init__(self, *, job_name: str, **kwargs):
+        super().__init__(resource_type="autoML job", **kwargs)
+        self.job_name = job_name
+
+    def non_terminal_states(self) -> set[str]:
+        return SageMakerHook.non_terminal_states
+
+    def failed_states(self) -> set[str]:
+        return SageMakerHook.failed_states
+
+    def get_sagemaker_response(self) -> dict:
+        self.log.info("Poking Sagemaker AutoML Execution %s", self.job_name)
+        return self.get_hook()._describe_auto_ml_job(self.job_name)
+
+    def state_from_response(self, response: dict) -> str:
+        return response["AutoMLJobStatus"]
diff --git a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
index 8d8319e17e..d33441fcd5 100644
--- a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
+++ b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst
@@ -192,6 +192,22 @@ You can use this operator to add a new version and model package to the group fo
     :start-after: [START howto_operator_sagemaker_register]
     :end-before: [END howto_operator_sagemaker_register]
 
+.. _howto/operator:SageMakerAutoMLOperator:
+
+Launch an AutoML experiment
+===========================
+
+To launch an AutoML experiment, a.k.a. SageMaker Autopilot, you can use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerAutoMLOperator`.
+An AutoML experiment takes input data in CSV format and the name of the column it should learn to predict,
+and trains models on it without needing human supervision.
+The output is written to an S3 bucket, and the best model is automatically deployed to an endpoint if configured to do so.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_sagemaker_auto_ml]
+    :end-before: [END howto_operator_sagemaker_auto_ml]
+
 Sensors
 -------
 
@@ -265,6 +281,20 @@ you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerPip
     :start-after: [START howto_sensor_sagemaker_pipeline]
     :end-before: [END howto_sensor_sagemaker_pipeline]
 
+.. _howto/sensor:SageMakerAutoMLSensor:
+
+Wait on an Amazon SageMaker AutoML experiment state
+===================================================
+
+To check the state of an Amazon SageMaker AutoML job until it reaches a terminal state,
+you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerAutoMLSensor`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_sagemaker_auto_ml]
+    :end-before: [END howto_sensor_sagemaker_auto_ml]
+
 Reference
 ---------
 
diff --git a/tests/providers/amazon/aws/hooks/test_sagemaker.py b/tests/providers/amazon/aws/hooks/test_sagemaker.py
index f2a97450fe..2aefe7a0f9 100644
--- a/tests/providers/amazon/aws/hooks/test_sagemaker.py
+++ b/tests/providers/amazon/aws/hooks/test_sagemaker.py
@@ -849,3 +849,50 @@ class TestSageMakerHook:
         created = hook.create_model_package_group("group-name")
 
         assert created is False
+
+    @patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.conn", new_callable=mock.PropertyMock)
+    def test_create_auto_ml_parameter_structure(self, conn_mock):
+        hook = SageMakerHook()
+
+        hook.create_auto_ml_job(
+            job_name="a",
+            s3_input="b",
+            target_attribute="c",
+            s3_output="d",
+            role_arn="e",
+            compressed_input=True,
+            time_limit=30,
+            wait_for_completion=False,
+        )
+
+        assert conn_mock().create_auto_ml_job.call_args[1] == {
+            "AutoMLJobConfig": {"CompletionCriteria": {"MaxAutoMLJobRuntimeInSeconds": 30}},
+            "AutoMLJobName": "a",
+            "InputDataConfig": [
+                {
+                    "CompressionType": "Gzip",
+                    "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "b"}},
+                    "TargetAttributeName": "c",
+                }
+            ],
+            "OutputDataConfig": {"S3OutputPath": "d"},
+            "RoleArn": "e",
+        }
+
+    @patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.conn", new_callable=mock.PropertyMock)
+    def test_create_auto_ml_waits_for_completion(self, conn_mock):
+        hook = SageMakerHook()
+        conn_mock().describe_auto_ml_job.side_effect = [
+            {"AutoMLJobStatus": "InProgress", "AutoMLJobSecondaryStatus": "a"},
+            {"AutoMLJobStatus": "InProgress", "AutoMLJobSecondaryStatus": "b"},
+            {
+                "AutoMLJobStatus": "Completed",
+                "AutoMLJobSecondaryStatus": "c",
+                "BestCandidate": {"name": "me"},
+            },
+        ]
+
+        ret = hook.create_auto_ml_job("a", "b", "c", "d", "e", check_interval=0)
+
+        assert conn_mock().describe_auto_ml_job.call_count == 3
+        assert ret == {"name": "me"}
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_automl.py b/tests/providers/amazon/aws/sensors/test_sagemaker_automl.py
new file mode 100644
index 0000000000..454d8ab1be
--- /dev/null
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_automl.py
@@ -0,0 +1,63 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerAutoMLSensor
+
+
+class TestSageMakerAutoMLSensor:
+    @staticmethod
+    def get_response_with_state(state: str):
+        states = {"Completed", "InProgress", "Failed", "Stopped", "Stopping"}
+        assert state in states
+        return {
+            "AutoMLJobStatus": state,
+            "AutoMLJobSecondaryStatus": "Starting",
+            "ResponseMetadata": {
+                "HTTPStatusCode": 200,
+            },
+        }
+
+    @mock.patch.object(SageMakerHook, "_describe_auto_ml_job")
+    def test_sensor_with_failure(self, mock_describe):
+        mock_describe.return_value = self.get_response_with_state("Failed")
+        sensor = SageMakerAutoMLSensor(job_name="job_job", task_id="test_task")
+
+        with pytest.raises(AirflowException):
+            sensor.execute(None)
+
+        mock_describe.assert_called_once_with("job_job")
+
+    @mock.patch.object(SageMakerHook, "_describe_auto_ml_job")
+    def test_sensor(self, mock_describe):
+        mock_describe.side_effect = [
+            self.get_response_with_state("InProgress"),
+            self.get_response_with_state("Stopping"),
+            self.get_response_with_state("Stopped"),
+        ]
+        sensor = SageMakerAutoMLSensor(job_name="job_job", task_id="test_task", poke_interval=0)
+
+        sensor.execute(None)
+
+        assert mock_describe.call_count == 3
diff --git a/tests/system/providers/amazon/aws/example_sagemaker.py b/tests/system/providers/amazon/aws/example_sagemaker.py
index ab13b02fa2..009352d234 100644
--- a/tests/system/providers/amazon/aws/example_sagemaker.py
+++ b/tests/system/providers/amazon/aws/example_sagemaker.py
@@ -35,6 +35,7 @@ from airflow.providers.amazon.aws.operators.s3 import (
     S3DeleteBucketOperator,
 )
 from airflow.providers.amazon.aws.operators.sagemaker import (
+    SageMakerAutoMLOperator,
     SageMakerDeleteModelOperator,
     SageMakerModelOperator,
     SageMakerProcessingOperator,
@@ -46,6 +47,7 @@ from airflow.providers.amazon.aws.operators.sagemaker import (
     SageMakerTuningOperator,
 )
 from airflow.providers.amazon.aws.sensors.sagemaker import (
+    SageMakerAutoMLSensor,
     SageMakerPipelineSensor,
     SageMakerTrainingSensor,
     SageMakerTransformSensor,
@@ -71,18 +73,7 @@ KNN_IMAGES_BY_REGION = {
     "us-west-2": "174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1",
 }
 
-# For this example we are using a subset of Fischer's Iris Data Set.
-# The full dataset can be found at UC Irvine's machine learning repository:
-# https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
-DATASET = """
-        5.1,3.5,1.4,0.2,Iris-setosa
-        4.9,3.0,1.4,0.2,Iris-setosa
-        7.0,3.2,4.7,1.4,Iris-versicolor
-        6.4,3.2,4.5,1.5,Iris-versicolor
-        4.9,2.5,4.5,1.7,Iris-virginica
-        7.3,2.9,6.3,1.8,Iris-virginica
-        """
-SAMPLE_SIZE = DATASET.count("\n") - 1
+SAMPLE_SIZE = 600
 
 # This script will be the entrypoint for the docker image which will handle preprocessing the raw data
 # NOTE:  The following string must remain dedented as it is being written to a file.
@@ -92,34 +83,28 @@ import numpy as np
 import pandas as pd
 
 def main():
-    # Load the Iris dataset from {input_path}/input.csv, split it into train/test
+    # Load the dataset from {input_path}/input.csv, split it into train/test
     # subsets, and write them to {output_path}/ for the Processing Operator.
 
-    columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']
-    iris = pd.read_csv('{input_path}/input.csv', names=columns)
-
-    # Process data
-    iris['species'] = iris['species'].replace({{'Iris-virginica': 0, 'Iris-versicolor': 1, 'Iris-setosa': 2}})
-    iris = iris[['species', 'sepal_length', 'sepal_width', 'petal_length', 'petal_width']]
+    data = pd.read_csv('{input_path}/input.csv')
 
     # Split into test and train data
-    iris_train, iris_test = np.split(
-        iris.sample(frac=1, random_state=np.random.RandomState()), [int(0.7 * len(iris))]
+    data_train, data_test = np.split(
+        data.sample(frac=1, random_state=np.random.RandomState()), [int(0.7 * len(data))]
     )
 
     # Remove the "answers" from the test set
-    iris_test.drop(['species'], axis=1, inplace=True)
+    data_test.drop(['class'], axis=1, inplace=True)
 
     # Write the splits to disk
-    iris_train.to_csv('{output_path}/train.csv', index=False, header=False)
-    iris_test.to_csv('{output_path}/test.csv', index=False, header=False)
+    data_train.to_csv('{output_path}/train.csv', index=False, header=False)
+    data_test.to_csv('{output_path}/test.csv', index=False, header=False)
 
     print('Preprocessing Done.')
 
 if __name__ == "__main__":
     main()
-
-    """
+"""
 
 
 def _create_ecr_repository(repo_name):
@@ -195,6 +180,14 @@ def _build_and_upload_docker_image(preprocess_script, repository_uri):
             )
 
 
+def generate_data() -> str:
+    """generates a very simple csv dataset with headers"""
+    content = "class,x,y\n"  # headers
+    for i in range(SAMPLE_SIZE):
+        content += f"{i%100},{i},{SAMPLE_SIZE-i}\n"
+    return content
+
+
 @task
 def set_up(env_id, role_arn):
     bucket_name = f"{env_id}-sagemaker-example"
@@ -206,6 +199,7 @@ def set_up(env_id, role_arn):
     tuning_job_name = f"{env_id}-tune"
     model_package_group_name = f"{env_id}-group"
     pipeline_name = f"{env_id}-pipe"
+    auto_ml_job_name = f"{env_id}-automl"
 
     input_data_S3_key = f"{env_id}/processed-input-data"
     prediction_output_s3_key = f"{env_id}/transform"
@@ -240,6 +234,7 @@ def set_up(env_id, role_arn):
         "InstanceType": "ml.m5.large",
         "VolumeSizeInGB": 1,
     }
+    input_data_uri = f"s3://{bucket_name}/{raw_data_s3_key}"
     processing_config = {
         "ProcessingJobName": processing_job_name,
         "ProcessingInputs": [
@@ -247,7 +242,7 @@ def set_up(env_id, role_arn):
                 "InputName": "input",
                 "AppManaged": False,
                 "S3Input": {
-                    "S3Uri": f"s3://{bucket_name}/{raw_data_s3_key}",
+                    "S3Uri": input_data_uri,
                     "LocalPath": processing_local_input_path,
                     "S3DataType": "S3Prefix",
                     "S3InputMode": "File",
@@ -297,7 +292,7 @@ def set_up(env_id, role_arn):
         },
         "HyperParameters": {
             "predictor_type": "classifier",
-            "feature_dim": "4",
+            "feature_dim": "2",
             "k": "3",
             "sample_size": str(SAMPLE_SIZE),
         },
@@ -357,7 +352,7 @@ def set_up(env_id, role_arn):
         "TrainingJobDefinition": {
             "StaticHyperParameters": {
                 "predictor_type": "classifier",
-                "feature_dim": "4",
+                "feature_dim": "2",
             },
             "AlgorithmSpecification": {"TrainingImage": knn_image_uri, "TrainingInputMode": "File"},
             "InputDataConfig": [
@@ -407,10 +402,13 @@ def set_up(env_id, role_arn):
     ti.xcom_push(key="raw_data_s3_key", value=raw_data_s3_key)
     ti.xcom_push(key="ecr_repository_name", value=ecr_repository_name)
     ti.xcom_push(key="processing_config", value=processing_config)
+    ti.xcom_push(key="input_data_uri", value=input_data_uri)
+    ti.xcom_push(key="output_data_uri", value=f"s3://{bucket_name}/{training_output_s3_key}")
     ti.xcom_push(key="training_config", value=training_config)
     ti.xcom_push(key="training_job_name", value=training_job_name)
     ti.xcom_push(key="model_package_group_name", value=model_package_group_name)
     ti.xcom_push(key="pipeline_name", value=pipeline_name)
+    ti.xcom_push(key="auto_ml_job_name", value=auto_ml_job_name)
     ti.xcom_push(key="model_config", value=model_config)
     ti.xcom_push(key="model_name", value=model_name)
     ti.xcom_push(key="inference_code_image", value=knn_image_uri)
@@ -499,10 +497,27 @@ with DAG(
         task_id="upload_dataset",
         s3_bucket=test_setup["bucket_name"],
         s3_key=test_setup["raw_data_s3_key"],
-        data=DATASET,
+        data=generate_data(),
         replace=True,
     )
 
+    # [START howto_operator_sagemaker_auto_ml]
+    automl = SageMakerAutoMLOperator(
+        task_id="auto_ML",
+        job_name=test_setup["auto_ml_job_name"],
+        s3_input=test_setup["input_data_uri"],
+        target_attribute="class",
+        s3_output=test_setup["output_data_uri"],
+        role_arn=test_context[ROLE_ARN_KEY],
+        time_limit=30,  # will stop the job before it can do anything, but it's not the point here
+    )
+    # [END howto_operator_sagemaker_auto_ml]
+    automl.wait_for_completion = False  # just to be able to test the sensor next
+
+    # [START howto_sensor_sagemaker_auto_ml]
+    await_automl = SageMakerAutoMLSensor(job_name=test_setup["auto_ml_job_name"], task_id="await_auto_ML")
+    # [END howto_sensor_sagemaker_auto_ml]
+
     # [START howto_operator_sagemaker_start_pipeline]
     start_pipeline1 = SageMakerStartPipelineOperator(
         task_id="start_pipeline1",
@@ -625,6 +640,8 @@ with DAG(
         create_bucket,
         upload_dataset,
         # TEST BODY
+        automl,
+        await_automl,
         start_pipeline1,
         start_pipeline2,
         stop_pipeline1,