Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2023/01/03 18:12:34 UTC

[GitHub] [airflow] vandonr-amz commented on a diff in pull request #28472: Add AWS Sagemaker Auto ML operator and sensor

vandonr-amz commented on code in PR #28472:
URL: https://github.com/apache/airflow/pull/28472#discussion_r1060836486


##########
airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -958,3 +958,96 @@ def execute(self, context: Context):
             if group_created:
                 self.hook.conn.delete_model_package_group(ModelPackageGroupName=self.package_group_name)
             raise
+
+
+class SageMakerAutoMLOperator(SageMakerBaseOperator):
+    """
+    Creates an auto ML job, learning to predict the given column from the data provided through S3.
+    The learning output is written to the specified S3 location.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:SageMakerAutoMLOperator`
+
+    :param job_name: Name of the job to create, needs to be unique within the account.
+    :param s3_input: The S3 location (folder or file) from which to fetch the data.
+        By default, it expects CSV with headers.
+    :param target_attribute: The name of the column containing the values to predict.
+    :param s3_output: The S3 folder in which to write the model artifacts. Must be 128 characters or fewer.
+    :param role_arn: The ARN of the IAM role to use when interacting with S3.
+        Must have read access to the input, and write access to the output folder.
+    :param compressed_input: Set to True if the input is gzipped.
+    :param time_limit: The maximum amount of time in seconds to spend training the model(s).
+    :param autodeploy_endpoint_name: If specified, the best model will be deployed to an endpoint with
+        that name. No deployment made otherwise.
+    :param extras: Use this dictionary to set any input variable for job creation that is not
+        offered through the parameters of this operator. The format is described in:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_auto_ml_job
+    :param wait_for_completion: Whether to wait for the job to finish before returning. Defaults to True.
+    :param check_interval: Interval in seconds between two status checks when waiting for completion.
+
+    :returns: Only if waiting for completion, a dictionary detailing the best model. The structure is that of
+        the "BestCandidate" key in:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
+    """
+
+    template_fields: Sequence[str] = (
+        "job_name",
+        "s3_input",
+        "target_attribute",
+        "s3_output",
+        "role_arn",
+        "compressed_input",
+        "time_limit",
+        "autodeploy_endpoint_name",
+        "extras",
+        "wait_for_completion",
+        "check_interval",
+    )
+
+    def __init__(
+        self,
+        *,
+        job_name: str,
+        s3_input: str,
+        target_attribute: str,
+        s3_output: str,
+        role_arn: str,
+        compressed_input: bool = False,
+        time_limit: int | None = None,
+        autodeploy_endpoint_name: str | None = None,
+        extras: dict | None = None,
+        wait_for_completion: bool = True,
+        check_interval: int = 30,
+        aws_conn_id: str = DEFAULT_CONN_ID,
+        config: dict | None = None,
+        **kwargs,
+    ):
+        super().__init__(config=config or {}, aws_conn_id=aws_conn_id, **kwargs)
+        self.job_name = job_name
+        self.s3_input = s3_input
+        self.target_attribute = target_attribute
+        self.s3_output = s3_output
+        self.role_arn = role_arn
+        self.compressed_input = compressed_input
+        self.time_limit = time_limit
+        self.autodeploy_endpoint_name = autodeploy_endpoint_name
+        self.extras = extras
+        self.wait_for_completion = wait_for_completion
+        self.check_interval = check_interval
+
+    def execute(self, context: Context) -> dict | None:
+        best = self.hook.create_auto_ml_job(
+            self.job_name,
+            self.s3_input,
+            self.target_attribute,
+            self.s3_output,
+            self.role_arn,
+            self.compressed_input,
+            self.time_limit,
+            self.autodeploy_endpoint_name,
+            self.extras,
+            self.wait_for_completion,
+            self.check_interval,
+        )
+        return best

Review Comment:
   Well, best is going to be empty if we don't wait for completion: if we fire and forget, we don't really have anything to return. Do you think that would be a problem?
   The docstring already states that this function returns something only when waiting for completion.
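The behaviour discussed in this comment can be sketched as follows. run_auto_ml_job is a hypothetical helper, not the hook from this PR: it starts a job, returns None right away in fire-and-forget mode, and otherwise polls every check_interval seconds until a terminal status, then returns the BestCandidate dict. The status names and the BestCandidate key follow the boto3 describe_auto_ml_job reference linked in the docstring; raising on a failed or stopped job is an assumption, and the start/describe/sleep callables are injected so the sketch runs without AWS.

```python
import time
from typing import Any, Callable, Dict, Optional

# Statuses after which an AutoML job no longer changes ("Stopping" is still transient).
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def run_auto_ml_job(
    start_job: Callable[[], None],
    describe_job: Callable[[], Dict[str, Any]],
    wait_for_completion: bool = True,
    check_interval: float = 30,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[Dict[str, Any]]:
    """Hypothetical helper: start a job, optionally poll it, return the best candidate."""
    start_job()
    if not wait_for_completion:
        # Fire and forget: the job is still running, so there is no best candidate yet.
        return None
    while True:
        description = describe_job()
        status = description["AutoMLJobStatus"]
        if status in TERMINAL_STATUSES:
            break
        sleep(check_interval)
    if status != "Completed":
        # Assumption: surface failed/stopped jobs as errors rather than returning None.
        raise RuntimeError(f"AutoML job ended with status {status}")
    return description["BestCandidate"]


# Usage with fake describe responses, so no AWS call is made and no real sleeping happens.
responses = iter(
    [
        {"AutoMLJobStatus": "InProgress"},
        {"AutoMLJobStatus": "Completed", "BestCandidate": {"CandidateName": "demo-candidate"}},
    ]
)
best = run_auto_ml_job(lambda: None, lambda: next(responses), sleep=lambda _: None)
```

With wait_for_completion=False the same helper returns None, which is exactly the "nothing to return" case the comment describes; callers pulling the result from XCom would need to handle that.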



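For readers mapping the operator's documented parameters onto the boto3 call linked in the docstring, here is a minimal sketch of how they could translate into create_auto_ml_job keyword arguments. build_auto_ml_job_params and deep_merge are hypothetical helpers, not the hook code from this PR; the request keys follow the boto3 create_auto_ml_job reference, and letting extras override the derived values through a recursive merge is an assumption about how extras should behave. The bucket, role ARN and job names are placeholders.

```python
from typing import Any, Dict, Optional


def deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of base with overrides merged in, recursing into nested dicts only."""
    merged = dict(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def build_auto_ml_job_params(
    job_name: str,
    s3_input: str,
    target_attribute: str,
    s3_output: str,
    role_arn: str,
    compressed_input: bool = False,
    time_limit: Optional[int] = None,
    autodeploy_endpoint_name: Optional[str] = None,
    extras: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper mapping the operator's parameters to create_auto_ml_job kwargs."""
    params: Dict[str, Any] = {
        "AutoMLJobName": job_name,
        "InputDataConfig": [
            {
                "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": s3_input}},
                "TargetAttributeName": target_attribute,
                # boto3 expects the strings "Gzip" or "None" here, not a Python bool.
                "CompressionType": "Gzip" if compressed_input else "None",
            }
        ],
        "OutputDataConfig": {"S3OutputPath": s3_output},
        "RoleArn": role_arn,
    }
    if time_limit is not None:
        params["AutoMLJobConfig"] = {"CompletionCriteria": {"MaxAutoMLJobRuntimeInSeconds": time_limit}}
    if autodeploy_endpoint_name is not None:
        params["ModelDeployConfig"] = {
            "AutoGenerateEndpointName": False,
            "EndpointName": autodeploy_endpoint_name,
        }
    # Assumption: values in extras take precedence over those derived from the parameters.
    return deep_merge(params, extras or {})


params = build_auto_ml_job_params(
    job_name="demo-job",
    s3_input="s3://my-bucket/input/",
    target_attribute="label",
    s3_output="s3://my-bucket/output/",
    role_arn="arn:aws:iam::123456789012:role/demo",
    time_limit=3600,
    extras={"AutoMLJobConfig": {"CompletionCriteria": {"MaxCandidates": 5}}},
)
# params would then be passed as sagemaker_client.create_auto_ml_job(**params).
```

Because deep_merge only recurses into dicts, an InputDataConfig list passed through extras replaces the generated one instead of being merged element by element.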