You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/28 13:42:49 UTC

[airflow] branch main updated: Fix Vertex AI Custom Job training issue (#25367)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a8e4519815 Fix Vertex AI Custom Job training issue (#25367)
a8e4519815 is described below

commit a8e451981572fa09a96660992e68e046c4baa75f
Author: Maksim <ma...@gmail.com>
AuthorDate: Thu Jul 28 16:42:42 2022 +0300

    Fix Vertex AI Custom Job training issue (#25367)
---
 .../google/cloud/hooks/vertex_ai/custom_job.py     | 35 ++++++++++------
 .../google/cloud/operators/vertex_ai/custom_job.py | 48 ++++++++++++++--------
 .../google/cloud/operators/test_vertex_ai.py       |  3 ++
 3 files changed, 58 insertions(+), 28 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py
index bd69878685..1f35c4a6e8 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py
@@ -246,6 +246,11 @@ class CustomJobHook(GoogleBaseHook):
         """Returns unique id of the Model."""
         return obj["name"].rpartition("/")[-1]
 
+    @staticmethod
+    def extract_training_id(resource_name: str) -> str:
+        """Returns unique id of the Training pipeline."""
+        return resource_name.rpartition("/")[-1]
+
     def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None):
         """Waits for long-lasting operation to complete."""
         try:
@@ -299,7 +304,7 @@ class CustomJobHook(GoogleBaseHook):
         timestamp_split_column_name: Optional[str] = None,
         tensorboard: Optional[str] = None,
         sync=True,
-    ) -> models.Model:
+    ) -> Tuple[Optional[models.Model], str]:
         """Run Job for training pipeline"""
         model = job.run(
             dataset=dataset,
@@ -329,11 +334,17 @@ class CustomJobHook(GoogleBaseHook):
             tensorboard=tensorboard,
             sync=sync,
         )
+        training_id = self.extract_training_id(job.resource_name)
         if model:
             model.wait()
-            return model
         else:
-            raise AirflowException("Training did not produce a Managed Model returning None.")
+            self.log.warning(
+                "Training did not produce a Managed Model returning None. Training Pipeline is not "
+                "configured to upload a Model. Create the Training Pipeline with "
+                "model_serving_container_image_uri and model_display_name passed in. "
+                "Ensure that your training script saves to model to os.environ['AIP_MODEL_DIR']."
+            )
+        return model, training_id
 
     @GoogleBaseHook.fallback_to_default_project_id
     def cancel_pipeline_job(
@@ -618,7 +629,7 @@ class CustomJobHook(GoogleBaseHook):
         timestamp_split_column_name: Optional[str] = None,
         tensorboard: Optional[str] = None,
         sync=True,
-    ) -> models.Model:
+    ) -> Tuple[Optional[models.Model], str]:
         """
         Create Custom Container Training Job
 
@@ -890,7 +901,7 @@ class CustomJobHook(GoogleBaseHook):
         if not self._job:
             raise AirflowException("CustomJob was not created")
 
-        model = self._run_job(
+        model, training_id = self._run_job(
             job=self._job,
             dataset=dataset,
             annotation_schema_uri=annotation_schema_uri,
@@ -920,7 +931,7 @@ class CustomJobHook(GoogleBaseHook):
             sync=sync,
         )
 
-        return model
+        return model, training_id
 
     @GoogleBaseHook.fallback_to_default_project_id
     def create_custom_python_package_training_job(
@@ -980,7 +991,7 @@ class CustomJobHook(GoogleBaseHook):
         timestamp_split_column_name: Optional[str] = None,
         tensorboard: Optional[str] = None,
         sync=True,
-    ) -> models.Model:
+    ) -> Tuple[Optional[models.Model], str]:
         """
         Create Custom Python Package Training Job
 
@@ -1252,7 +1263,7 @@ class CustomJobHook(GoogleBaseHook):
         if not self._job:
             raise AirflowException("CustomJob was not created")
 
-        model = self._run_job(
+        model, training_id = self._run_job(
             job=self._job,
             dataset=dataset,
             annotation_schema_uri=annotation_schema_uri,
@@ -1282,7 +1293,7 @@ class CustomJobHook(GoogleBaseHook):
             sync=sync,
         )
 
-        return model
+        return model, training_id
 
     @GoogleBaseHook.fallback_to_default_project_id
     def create_custom_training_job(
@@ -1342,7 +1353,7 @@ class CustomJobHook(GoogleBaseHook):
         timestamp_split_column_name: Optional[str] = None,
         tensorboard: Optional[str] = None,
         sync=True,
-    ) -> models.Model:
+    ) -> Tuple[Optional[models.Model], str]:
         """
         Create Custom Training Job
 
@@ -1614,7 +1625,7 @@ class CustomJobHook(GoogleBaseHook):
         if not self._job:
             raise AirflowException("CustomJob was not created")
 
-        model = self._run_job(
+        model, training_id = self._run_job(
             job=self._job,
             dataset=dataset,
             annotation_schema_uri=annotation_schema_uri,
@@ -1644,7 +1655,7 @@ class CustomJobHook(GoogleBaseHook):
             sync=sync,
         )
 
-        return model
+        return model, training_id
 
     @GoogleBaseHook.fallback_to_default_project_id
     def delete_pipeline_job(
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
index 6196bf3880..71e6ab9967 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
@@ -29,7 +29,11 @@ from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
 
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook
-from airflow.providers.google.cloud.links.vertex_ai import VertexAIModelLink, VertexAITrainingPipelinesLink
+from airflow.providers.google.cloud.links.vertex_ai import (
+    VertexAIModelLink,
+    VertexAITrainingLink,
+    VertexAITrainingPipelinesLink,
+)
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -411,7 +415,7 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
         'command',
         'impersonation_chain',
     ]
-    operator_extra_links = (VertexAIModelLink(),)
+    operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
 
     def __init__(
         self,
@@ -428,7 +432,7 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
             delegate_to=self.delegate_to,
             impersonation_chain=self.impersonation_chain,
         )
-        model = self.hook.create_custom_container_training_job(
+        model, training_id = self.hook.create_custom_container_training_job(
             project_id=self.project_id,
             region=self.region,
             display_name=self.display_name,
@@ -478,9 +482,13 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
             sync=True,
         )
 
-        result = Model.to_dict(model)
-        model_id = self.hook.extract_model_id(result)
-        VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+        if model:
+            result = Model.to_dict(model)
+            model_id = self.hook.extract_model_id(result)
+            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+        else:
+            result = model  # type: ignore
+        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
         return result
 
     def on_kill(self) -> None:
@@ -755,7 +763,7 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
         'region',
         'impersonation_chain',
     ]
-    operator_extra_links = (VertexAIModelLink(),)
+    operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
 
     def __init__(
         self,
@@ -774,7 +782,7 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
             delegate_to=self.delegate_to,
             impersonation_chain=self.impersonation_chain,
         )
-        model = self.hook.create_custom_python_package_training_job(
+        model, training_id = self.hook.create_custom_python_package_training_job(
             project_id=self.project_id,
             region=self.region,
             display_name=self.display_name,
@@ -825,9 +833,13 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
             sync=True,
         )
 
-        result = Model.to_dict(model)
-        model_id = self.hook.extract_model_id(result)
-        VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+        if model:
+            result = Model.to_dict(model)
+            model_id = self.hook.extract_model_id(result)
+            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+        else:
+            result = model  # type: ignore
+        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
         return result
 
     def on_kill(self) -> None:
@@ -1104,7 +1116,7 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
         'requirements',
         'impersonation_chain',
     ]
-    operator_extra_links = (VertexAIModelLink(),)
+    operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
 
     def __init__(
         self,
@@ -1123,7 +1135,7 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
             delegate_to=self.delegate_to,
             impersonation_chain=self.impersonation_chain,
         )
-        model = self.hook.create_custom_training_job(
+        model, training_id = self.hook.create_custom_training_job(
             project_id=self.project_id,
             region=self.region,
             display_name=self.display_name,
@@ -1174,9 +1186,13 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
             sync=True,
         )
 
-        result = Model.to_dict(model)
-        model_id = self.hook.extract_model_id(result)
-        VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+        if model:
+            result = Model.to_dict(model)
+            model_id = self.hook.extract_model_id(result)
+            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+        else:
+            result = model  # type: ignore
+        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
         return result
 
     def on_kill(self) -> None:
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
index 51e6674344..34d27d87ae 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -170,6 +170,7 @@ TEST_OUTPUT_CONFIG = {
 class TestVertexAICreateCustomContainerTrainingJobOperator:
     @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
     def test_execute(self, mock_hook):
+        mock_hook.return_value.create_custom_container_training_job.return_value = (None, 'training_id')
         op = CreateCustomContainerTrainingJobOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -250,6 +251,7 @@ class TestVertexAICreateCustomContainerTrainingJobOperator:
 class TestVertexAICreateCustomPythonPackageTrainingJobOperator:
     @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
     def test_execute(self, mock_hook):
+        mock_hook.return_value.create_custom_python_package_training_job.return_value = (None, 'training_id')
         op = CreateCustomPythonPackageTrainingJobOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -332,6 +334,7 @@ class TestVertexAICreateCustomPythonPackageTrainingJobOperator:
 class TestVertexAICreateCustomTrainingJobOperator:
     @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
     def test_execute(self, mock_hook):
+        mock_hook.return_value.create_custom_training_job.return_value = (None, 'training_id')
         op = CreateCustomTrainingJobOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,