Posted to commits@airflow.apache.org by mo...@apache.org on 2023/08/02 10:59:40 UTC

[airflow] branch openlineage-sagemaker-operators updated (21da40d741 -> a7bccaadd7)

This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a change to branch openlineage-sagemaker-operators
in repository https://gitbox.apache.org/repos/asf/airflow.git


 discard 21da40d741 openlineage, sagemaker: add OpenLineage support for SageMaker's Processing, Transform and Training operators
     add 0c894dbb24 Handle multiple connections using exceptions (#32365)
     add b45cc1493c Rephrase scheduler process doc (#32983)
     add 6ada88a407 Always show gantt and code tabs (#33029)
     add 9cbe494e23 Change log level from ERROR to INFO (#32979)
     new a7bccaadd7 openlineage, sagemaker: add OpenLineage support for SageMaker's Processing, Transform and Training operators

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user --force pushes a change and generates a repository
containing something like this:

 * -- * -- B -- O -- O -- O   (21da40d741)
            \
             N -- N -- N   refs/heads/openlineage-sagemaker-operators (a7bccaadd7)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 airflow/jobs/triggerer_job_runner.py               |   2 +-
 .../providers/google/cloud/hooks/compute_ssh.py    |  79 ++++++++-----
 airflow/www/static/js/dag/details/gantt/index.tsx  |   8 +-
 airflow/www/static/js/dag/details/index.tsx        |  68 +++++-------
 .../core-concepts/executor/index.rst               |   2 +-
 .../google/cloud/hooks/test_compute_ssh.py         | 123 ++++++++++++++++++++-
 .../google/cloud/compute/example_compute_ssh.py    |   4 +-
 ...pute_ssh.py => example_compute_ssh_os_login.py} |  38 ++++---
 ...pute_ssh.py => example_compute_ssh_parallel.py} |  33 +++---
 9 files changed, 254 insertions(+), 103 deletions(-)
 copy tests/system/providers/google/cloud/compute/{example_compute_ssh.py => example_compute_ssh_os_login.py} (86%)
 copy tests/system/providers/google/cloud/compute/{example_compute_ssh.py => example_compute_ssh_parallel.py} (86%)


[airflow] 01/01: openlineage, sagemaker: add OpenLineage support for SageMaker's Processing, Transform and Training operators

Posted by mo...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch openlineage-sagemaker-operators
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit a7bccaadd7bf31cd471dfc1f022a77093403a5c9
Author: Maciej Obuchowski <ob...@gmail.com>
AuthorDate: Tue Aug 1 16:59:19 2023 +0200

    openlineage, sagemaker: add OpenLineage support for SageMaker's Processing, Transform and Training operators
    
    Signed-off-by: Maciej Obuchowski <ob...@gmail.com>
---
 .../providers/amazon/aws/operators/sagemaker.py    | 156 +++++++++++++++++++--
 dev/breeze/tests/test_selective_checks.py          |   9 +-
 generated/provider_dependencies.json               |   1 +
 .../aws/operators/test_sagemaker_processing.py     |  50 ++++++-
 .../aws/operators/test_sagemaker_training.py       |  32 ++++-
 .../aws/operators/test_sagemaker_transform.py      |  30 ++++
 6 files changed, 251 insertions(+), 27 deletions(-)

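For reviewers unfamiliar with the OpenLineage integration: after a task
finishes, the OpenLineage provider's default extractor looks for a
get_openlineage_facets_on_complete(task_instance) method on the operator
and, when present, emits the returned OperatorLineage as lineage events.
A minimal sketch of the structures the methods below produce (assumes the
openlineage-python client and the openlineage provider are installed;
bucket and key names are made up):

    from openlineage.client.run import Dataset

    from airflow.providers.openlineage.extractors import OperatorLineage

    # Each S3 URI is mapped to a Dataset: the namespace is the bucket,
    # the name is the object key (see path_to_s3_dataset in the diff).
    lineage = OperatorLineage(
        inputs=[Dataset(namespace="s3://input-bucket", name="input-path", facets={})],
        outputs=[Dataset(namespace="s3://output-bucket", name="output-path", facets={})],
    )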
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 83a1e4f3d2..7e1a0a3567 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -40,13 +40,16 @@ from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.utils.json import AirflowJsonEncoder
 
 if TYPE_CHECKING:
+    from openlineage.client.run import Dataset
+
+    from airflow.providers.openlineage.extractors.base import OperatorLineage
     from airflow.utils.context import Context
 
 DEFAULT_CONN_ID: str = "aws_default"
 CHECK_INTERVAL_SECOND: int = 30
 
 
-def serialize(result: dict) -> str:
+def serialize(result: dict) -> dict:
     return json.loads(json.dumps(result, cls=AirflowJsonEncoder))
 
 
@@ -158,6 +161,14 @@ class SageMakerBaseOperator(BaseOperator):
         """Return SageMakerHook."""
         return SageMakerHook(aws_conn_id=self.aws_conn_id)
 
+    @staticmethod
+    def path_to_s3_dataset(path: str) -> Dataset:
+        from openlineage.client.run import Dataset
+
+        path = path.replace("s3://", "")
+        split_path = path.split("/")
+        return Dataset(namespace=f"s3://{split_path[0]}", name="/".join(split_path[1:]), facets={})
+
 
 class SageMakerProcessingOperator(SageMakerBaseOperator):
     """
@@ -225,6 +236,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         self.max_attempts = max_attempts or 60
         self.max_ingestion_time = max_ingestion_time
         self.deferrable = deferrable
+        self.serialized_job: dict
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -282,14 +294,48 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+        self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        return {"Processing": self.serialized_job}
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         else:
             self.log.info(event["message"])
-        return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+        self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        return {"Processing": self.serialized_job}
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
+        """Returns OpenLineage data gathered from SageMaker's API response saved by processing job."""
+        from airflow.providers.openlineage.extractors.base import OperatorLineage
+
+        inputs = []
+        outputs = []
+        try:
+            inputs, outputs = self._extract_s3_dataset_identifiers(
+                processing_inputs=self.serialized_job["ProcessingInputs"],
+                processing_outputs=self.serialized_job["ProcessingOutputConfig"]["Outputs"],
+            )
+        except KeyError:
+            self.log.exception("Could not find input/output information in Xcom.")
+
+        return OperatorLineage(inputs=inputs, outputs=outputs)
+
+    def _extract_s3_dataset_identifiers(self, processing_inputs, processing_outputs):
+        inputs = []
+        outputs = []
+        try:
+            for processing_input in processing_inputs:
+                inputs.append(self.path_to_s3_dataset(processing_input["S3Input"]["S3Uri"]))
+        except KeyError:
+            self.log.exception("Cannot find S3 input details", exc_info=True)
+
+        try:
+            for processing_output in processing_outputs:
+                outputs.append(self.path_to_s3_dataset(processing_output["S3Output"]["S3Uri"]))
+        except KeyError:
+            self.log.exception("Cannot find S3 output details.", exc_info=True)
+        return inputs, outputs
 
 
 class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
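The keys walked by _extract_s3_dataset_identifiers come straight from
SageMaker's DescribeProcessingJob response, which execute() serializes into
self.serialized_job. A trimmed example of the payload shape the extraction
relies on (same shape as the test fixtures further down; values hypothetical):

    serialized_job = {
        "ProcessingInputs": [{"S3Input": {"S3Uri": "s3://input-bucket/input-path"}}],
        "ProcessingOutputConfig": {
            "Outputs": [{"S3Output": {"S3Uri": "s3://output-bucket/output-path"}}],
        },
    }
    # A missing key at any level raises KeyError, which is caught and only
    # logged, so lineage collection cannot fail the task itself.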
@@ -579,6 +625,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 Provided value: '{action_if_job_exists}'."
             )
         self.deferrable = deferrable
+        self.serialized_model: dict
+        self.serialized_transform: dict
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -650,10 +698,11 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        return {
-            "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
-            "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
-        }
+        self.serialized_model = serialize(self.hook.describe_model(transform_config["ModelName"]))
+        self.serialized_transform = serialize(
+            self.hook.describe_transform_job(transform_config["TransformJobName"])
+        )
+        return {"Model": self.serialized_model, "Transform": self.serialized_transform}
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
@@ -661,10 +710,62 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         else:
             self.log.info(event["message"])
         transform_config = self.config.get("Transform", self.config)
-        return {
-            "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
-            "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
-        }
+        self.serialized_model = serialize(self.hook.describe_model(transform_config["ModelName"]))
+        self.serialized_transform = serialize(
+            self.hook.describe_transform_job(transform_config["TransformJobName"])
+        )
+        return {"Model": self.serialized_model, "Transform": self.serialized_transform}
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
+        """Returns OpenLineage data gathered from SageMaker's API response saved by transform job."""
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        model_package_arn = None
+        transform_input = None
+        transform_output = None
+
+        try:
+            model_package_arn = self.serialized_model["PrimaryContainer"]["ModelPackageName"]
+        except KeyError:
+            self.log.error("Cannot find Model Package Name.", exc_info=True)
+
+        try:
+            transform_input = self.serialized_transform["TransformInput"]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ]
+            transform_output = self.serialized_transform["TransformOutput"]["S3OutputPath"]
+        except KeyError:
+            self.log.error("Cannot find some required input/output details.", exc_info=True)
+
+        inputs = []
+
+        if transform_input is not None:
+            inputs.append(self.path_to_s3_dataset(transform_input))
+
+        if model_package_arn is not None:
+            model_data_urls = self._get_model_data_urls(model_package_arn)
+            for model_data_url in model_data_urls:
+                inputs.append(self.path_to_s3_dataset(model_data_url))
+
+        outputs = []
+        if transform_output is not None:
+            outputs.append(self.path_to_s3_dataset(transform_output))
+
+        return OperatorLineage(inputs=inputs, outputs=outputs)
+
+    def _get_model_data_urls(self, model_package_arn) -> list:
+        model_data_urls = []
+        try:
+            model_containers = self.hook.get_conn().describe_model_package(
+                ModelPackageName=model_package_arn
+            )["InferenceSpecification"]["Containers"]
+
+            for container in model_containers:
+                model_data_urls.append(container["ModelDataUrl"])
+        except KeyError:
+            self.log.exception("Cannot retrieve model details.", exc_info=True)
+
+        return model_data_urls
 
 
 class SageMakerTuningOperator(SageMakerBaseOperator):
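For the transform job, the input side combines the TransformInput S3 URI
with the model artifacts resolved through DescribeModelPackage. The slice
of the API response that _get_model_data_urls reads (shape mirrored in the
tests below; values hypothetical):

    response = {
        "InferenceSpecification": {
            "Containers": [{"ModelDataUrl": "s3://model-bucket/model-path"}],
        },
    }
    model_data_urls = [
        container["ModelDataUrl"]
        for container in response["InferenceSpecification"]["Containers"]
    ]
    # Each ModelDataUrl becomes an additional input Dataset alongside the
    # TransformInput S3Uri; TransformOutput's S3OutputPath is the output.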
@@ -891,6 +992,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
                 Provided value: '{action_if_job_exists}'."
             )
         self.deferrable = deferrable
+        self.serialized_training_data: dict
 
     def expand_role(self) -> None:
         """Expands an IAM role name into an ARN."""
@@ -951,16 +1053,40 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        result = {"Training": serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
-        return result
+        self.serialized_training_data = serialize(
+            self.hook.describe_training_job(self.config["TrainingJobName"])
+        )
+        return {"Training": self.serialized_training_data}
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
         else:
             self.log.info(event["message"])
-        result = {"Training": serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
-        return result
+        self.serialized_training_data = serialize(
+            self.hook.describe_training_job(self.config["TrainingJobName"])
+        )
+        return {"Training": self.serialized_training_data}
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
+        """Returns OpenLineage data gathered from SageMaker's API response saved by training job."""
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        inputs = []
+        outputs = []
+        try:
+            for input_data in self.serialized_training_data["InputDataConfig"]:
+                inputs.append(self.path_to_s3_dataset(input_data["DataSource"]["S3DataSource"]["S3Uri"]))
+        except KeyError:
+            self.log.exception("Issues extracting inputs.")
+
+        try:
+            outputs.append(
+                self.path_to_s3_dataset(self.serialized_training_data["ModelArtifacts"]["S3ModelArtifacts"])
+            )
+        except KeyError:
+            self.log.exception("Issues extracting inputs.")
+        return OperatorLineage(inputs=inputs, outputs=outputs)
 
 
 class SageMakerDeleteModelOperator(SageMakerBaseOperator):
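The training operator's mapping is the simplest of the three: each
InputDataConfig entry contributes one input Dataset and the single
ModelArtifacts location becomes the output Dataset. The fields read from
the DescribeTrainingJob response (same shape as the test fixture below;
values hypothetical):

    serialized_training_data = {
        "InputDataConfig": [
            {"DataSource": {"S3DataSource": {"S3Uri": "s3://input-bucket/input-path"}}},
        ],
        "ModelArtifacts": {"S3ModelArtifacts": "s3://model-bucket/model-path"},
    }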
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index e3f9bacb98..3cdf3dc7bb 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -312,7 +312,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             {
                 "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
                 "common.sql exasol ftp google http imap microsoft.azure "
-                "mongo mysql postgres salesforce ssh",
+                "mongo mysql openlineage postgres salesforce ssh",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "python-versions": "['3.8']",
@@ -326,7 +326,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "run-amazon-tests": "true",
                 "parallel-test-types-list-as-string": "Providers[amazon] Always "
                 "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,imap,microsoft.azure,"
-                "mongo,mysql,postgres,salesforce,ssh] Providers[google]",
+                "mongo,mysql,openlineage,postgres,salesforce,ssh] Providers[google]",
             },
             id="Providers tests run including amazon tests if amazon provider files changed",
         ),
@@ -354,7 +354,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             {
                 "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
                 "common.sql exasol ftp google http imap microsoft.azure "
-                "mongo mysql postgres salesforce ssh",
+                "mongo mysql openlineage postgres salesforce ssh",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
                 "python-versions": "['3.8']",
@@ -368,7 +368,8 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "upgrade-to-newer-dependencies": "false",
                 "parallel-test-types-list-as-string": "Providers[amazon] Always "
                 "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,"
-                "http,imap,microsoft.azure,mongo,mysql,postgres,salesforce,ssh] Providers[google]",
+                "http,imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh] "
+                "Providers[google]",
             },
             id="Providers tests run including amazon tests if amazon provider files changed",
         ),
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 42146c2767..02b5d59a13 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -47,6 +47,7 @@
       "imap",
       "microsoft.azure",
       "mongo",
+      "openlineage",
       "salesforce",
       "ssh"
     ],
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 1d73d44bdf..817761c014 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -20,12 +20,17 @@ from unittest import mock
 
 import pytest
 from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerProcessingOperator
+from airflow.providers.amazon.aws.operators.sagemaker import (
+    SageMakerBaseOperator,
+    SageMakerProcessingOperator,
+)
 from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 CREATE_PROCESSING_PARAMS: dict = {
     "AppSpecification": {
@@ -238,14 +243,16 @@ class TestSageMakerProcessingOperator:
                 action_if_job_exists="not_fail_or_increment",
             )
 
-    @mock.patch.object(SageMakerHook, "create_processing_job")
-    @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator._check_if_job_exists")
-    def test_operator_defer(self, mock_job_exists, mock_processing):
-        mock_processing.return_value = {
+    @mock.patch.object(
+        SageMakerHook,
+        "create_processing_job",
+        return_value={
             "ProcessingJobArn": "test_arn",
             "ResponseMetadata": {"HTTPStatusCode": 200},
-        }
-        mock_job_exists.return_value = False
+        },
+    )
+    @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
+    def test_operator_defer(self, mock_job_exists, mock_processing):
         sagemaker_operator = SageMakerProcessingOperator(
             **self.processing_config_kwargs,
             config=CREATE_PROCESSING_PARAMS,
@@ -255,3 +262,32 @@ class TestSageMakerProcessingOperator:
         with pytest.raises(TaskDeferred) as exc:
             sagemaker_operator.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
+
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_processing_job",
+        return_value={
+            "ProcessingInputs": [{"S3Input": {"S3Uri": "s3://input-bucket/input-path"}}],
+            "ProcessingOutputConfig": {
+                "Outputs": [{"S3Output": {"S3Uri": "s3://output-bucket/output-path"}}]
+            },
+        },
+    )
+    @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
+    @mock.patch.object(
+        SageMakerHook,
+        "create_processing_job",
+        return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
+    )
+    @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
+    def test_operator_openlineage_data(self, check_job_exists, mock_processing, _, mock_desc):
+        sagemaker_operator = SageMakerProcessingOperator(
+            **self.processing_config_kwargs,
+            config=CREATE_PROCESSING_PARAMS,
+            deferrable=True,
+        )
+        sagemaker_operator.execute(context=None)
+        assert sagemaker_operator.get_openlineage_facets_on_complete(None) == OperatorLineage(
+            inputs=[Dataset(namespace="s3://input-bucket", name="input-path")],
+            outputs=[Dataset(namespace="s3://output-bucket", name="output-path")],
+        )
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index e551317d33..9cb50de7c8 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -20,12 +20,14 @@ from unittest import mock
 
 import pytest
 from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator
+from airflow.providers.amazon.aws.operators.sagemaker import SageMakerBaseOperator, SageMakerTrainingOperator
 from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["ResourceConfig", "InstanceCount"],
@@ -127,3 +129,31 @@ class TestSageMakerTrainingOperator:
         with pytest.raises(TaskDeferred) as exc:
             self.sagemaker.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
+
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_training_job",
+        return_value={
+            "InputDataConfig": [
+                {
+                    "DataSource": {"S3DataSource": {"S3Uri": "s3://input-bucket/input-path"}},
+                }
+            ],
+            "ModelArtifacts": {"S3ModelArtifacts": "s3://model-bucket/model-path"},
+        },
+    )
+    @mock.patch.object(
+        SageMakerHook,
+        "create_training_job",
+        return_value={
+            "TrainingJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        },
+    )
+    @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
+    def test_execute_openlineage_data(self, mock_exists, mock_training, mock_desc):
+        self.sagemaker.execute(None)
+        assert self.sagemaker.get_openlineage_facets_on_complete(None) == OperatorLineage(
+            inputs=[Dataset(namespace="s3://input-bucket", name="input-path")],
+            outputs=[Dataset(namespace="s3://model-bucket", name="model-path")],
+        )
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 76a4d877b6..9a9af38b36 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -22,12 +22,14 @@ from unittest import mock
 
 import pytest
 from botocore.exceptions import ClientError
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator
 from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["Transform", "TransformResources", "InstanceCount"],
@@ -178,3 +180,31 @@ class TestSageMakerTransformOperator:
         with pytest.raises(TaskDeferred) as exc:
             self.sagemaker.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
+
+    @mock.patch.object(SageMakerHook, "describe_transform_job")
+    @mock.patch.object(SageMakerHook, "create_model")
+    @mock.patch.object(SageMakerHook, "describe_model")
+    @mock.patch.object(SageMakerHook, "get_conn")
+    @mock.patch.object(SageMakerHook, "create_transform_job")
+    def test_operator_lineage_data(self, mock_transform, mock_conn, mock_model, _, mock_desc):
+        self.sagemaker.check_if_job_exists = False
+        mock_conn.return_value.describe_model_package.return_value = {
+            "InferenceSpecification": {"Containers": [{"ModelDataUrl": "s3://model-bucket/model-path"}]},
+        }
+        mock_model.return_value = {"PrimaryContainer": {"ModelPackageName": "package-name"}}
+        mock_desc.return_value = {
+            "TransformInput": {"DataSource": {"S3DataSource": {"S3Uri": "s3://input-bucket/input-path"}}},
+            "TransformOutput": {"S3OutputPath": "s3://output-bucket/output-path"},
+        }
+        mock_transform.return_value = {
+            "TransformJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        self.sagemaker.execute(None)
+        assert self.sagemaker.get_openlineage_facets_on_complete(None) == OperatorLineage(
+            inputs=[
+                Dataset(namespace="s3://input-bucket", name="input-path"),
+                Dataset(namespace="s3://model-bucket", name="model-path"),
+            ],
+            outputs=[Dataset(namespace="s3://output-bucket", name="output-path")],
+        )