You are viewing a plain-text version of this content. The canonical link for it was a hyperlink in the original archive page and is not preserved in this text.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/28 13:42:49 UTC
[airflow] branch main updated: Fix Vertex AI Custom Job training issue (#25367)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a8e4519815 Fix Vertex AI Custom Job training issue (#25367)
a8e4519815 is described below
commit a8e451981572fa09a96660992e68e046c4baa75f
Author: Maksim <ma...@gmail.com>
AuthorDate: Thu Jul 28 16:42:42 2022 +0300
Fix Vertex AI Custom Job training issue (#25367)
---
.../google/cloud/hooks/vertex_ai/custom_job.py | 35 ++++++++++------
.../google/cloud/operators/vertex_ai/custom_job.py | 48 ++++++++++++++--------
.../google/cloud/operators/test_vertex_ai.py | 3 ++
3 files changed, 58 insertions(+), 28 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py
index bd69878685..1f35c4a6e8 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py
@@ -246,6 +246,11 @@ class CustomJobHook(GoogleBaseHook):
"""Returns unique id of the Model."""
return obj["name"].rpartition("/")[-1]
+ @staticmethod
+ def extract_training_id(resource_name: str) -> str:
+ """Returns unique id of the Training pipeline."""
+ return resource_name.rpartition("/")[-1]
+
def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None):
"""Waits for long-lasting operation to complete."""
try:
@@ -299,7 +304,7 @@ class CustomJobHook(GoogleBaseHook):
timestamp_split_column_name: Optional[str] = None,
tensorboard: Optional[str] = None,
sync=True,
- ) -> models.Model:
+ ) -> Tuple[Optional[models.Model], str]:
"""Run Job for training pipeline"""
model = job.run(
dataset=dataset,
@@ -329,11 +334,17 @@ class CustomJobHook(GoogleBaseHook):
tensorboard=tensorboard,
sync=sync,
)
+ training_id = self.extract_training_id(job.resource_name)
if model:
model.wait()
- return model
else:
- raise AirflowException("Training did not produce a Managed Model returning None.")
+ self.log.warning(
+ "Training did not produce a Managed Model returning None. Training Pipeline is not "
+ "configured to upload a Model. Create the Training Pipeline with "
+ "model_serving_container_image_uri and model_display_name passed in. "
+ "Ensure that your training script saves to model to os.environ['AIP_MODEL_DIR']."
+ )
+ return model, training_id
@GoogleBaseHook.fallback_to_default_project_id
def cancel_pipeline_job(
@@ -618,7 +629,7 @@ class CustomJobHook(GoogleBaseHook):
timestamp_split_column_name: Optional[str] = None,
tensorboard: Optional[str] = None,
sync=True,
- ) -> models.Model:
+ ) -> Tuple[Optional[models.Model], str]:
"""
Create Custom Container Training Job
@@ -890,7 +901,7 @@ class CustomJobHook(GoogleBaseHook):
if not self._job:
raise AirflowException("CustomJob was not created")
- model = self._run_job(
+ model, training_id = self._run_job(
job=self._job,
dataset=dataset,
annotation_schema_uri=annotation_schema_uri,
@@ -920,7 +931,7 @@ class CustomJobHook(GoogleBaseHook):
sync=sync,
)
- return model
+ return model, training_id
@GoogleBaseHook.fallback_to_default_project_id
def create_custom_python_package_training_job(
@@ -980,7 +991,7 @@ class CustomJobHook(GoogleBaseHook):
timestamp_split_column_name: Optional[str] = None,
tensorboard: Optional[str] = None,
sync=True,
- ) -> models.Model:
+ ) -> Tuple[Optional[models.Model], str]:
"""
Create Custom Python Package Training Job
@@ -1252,7 +1263,7 @@ class CustomJobHook(GoogleBaseHook):
if not self._job:
raise AirflowException("CustomJob was not created")
- model = self._run_job(
+ model, training_id = self._run_job(
job=self._job,
dataset=dataset,
annotation_schema_uri=annotation_schema_uri,
@@ -1282,7 +1293,7 @@ class CustomJobHook(GoogleBaseHook):
sync=sync,
)
- return model
+ return model, training_id
@GoogleBaseHook.fallback_to_default_project_id
def create_custom_training_job(
@@ -1342,7 +1353,7 @@ class CustomJobHook(GoogleBaseHook):
timestamp_split_column_name: Optional[str] = None,
tensorboard: Optional[str] = None,
sync=True,
- ) -> models.Model:
+ ) -> Tuple[Optional[models.Model], str]:
"""
Create Custom Training Job
@@ -1614,7 +1625,7 @@ class CustomJobHook(GoogleBaseHook):
if not self._job:
raise AirflowException("CustomJob was not created")
- model = self._run_job(
+ model, training_id = self._run_job(
job=self._job,
dataset=dataset,
annotation_schema_uri=annotation_schema_uri,
@@ -1644,7 +1655,7 @@ class CustomJobHook(GoogleBaseHook):
sync=sync,
)
- return model
+ return model, training_id
@GoogleBaseHook.fallback_to_default_project_id
def delete_pipeline_job(
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
index 6196bf3880..71e6ab9967 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
@@ -29,7 +29,11 @@ from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook
-from airflow.providers.google.cloud.links.vertex_ai import VertexAIModelLink, VertexAITrainingPipelinesLink
+from airflow.providers.google.cloud.links.vertex_ai import (
+ VertexAIModelLink,
+ VertexAITrainingLink,
+ VertexAITrainingPipelinesLink,
+)
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -411,7 +415,7 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
'command',
'impersonation_chain',
]
- operator_extra_links = (VertexAIModelLink(),)
+ operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
def __init__(
self,
@@ -428,7 +432,7 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
- model = self.hook.create_custom_container_training_job(
+ model, training_id = self.hook.create_custom_container_training_job(
project_id=self.project_id,
region=self.region,
display_name=self.display_name,
@@ -478,9 +482,13 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
sync=True,
)
- result = Model.to_dict(model)
- model_id = self.hook.extract_model_id(result)
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+ if model:
+ result = Model.to_dict(model)
+ model_id = self.hook.extract_model_id(result)
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+ else:
+ result = model # type: ignore
+ VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
return result
def on_kill(self) -> None:
@@ -755,7 +763,7 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
'region',
'impersonation_chain',
]
- operator_extra_links = (VertexAIModelLink(),)
+ operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
def __init__(
self,
@@ -774,7 +782,7 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
- model = self.hook.create_custom_python_package_training_job(
+ model, training_id = self.hook.create_custom_python_package_training_job(
project_id=self.project_id,
region=self.region,
display_name=self.display_name,
@@ -825,9 +833,13 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
sync=True,
)
- result = Model.to_dict(model)
- model_id = self.hook.extract_model_id(result)
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+ if model:
+ result = Model.to_dict(model)
+ model_id = self.hook.extract_model_id(result)
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+ else:
+ result = model # type: ignore
+ VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
return result
def on_kill(self) -> None:
@@ -1104,7 +1116,7 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
'requirements',
'impersonation_chain',
]
- operator_extra_links = (VertexAIModelLink(),)
+ operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
def __init__(
self,
@@ -1123,7 +1135,7 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
- model = self.hook.create_custom_training_job(
+ model, training_id = self.hook.create_custom_training_job(
project_id=self.project_id,
region=self.region,
display_name=self.display_name,
@@ -1174,9 +1186,13 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
sync=True,
)
- result = Model.to_dict(model)
- model_id = self.hook.extract_model_id(result)
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+ if model:
+ result = Model.to_dict(model)
+ model_id = self.hook.extract_model_id(result)
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+ else:
+ result = model # type: ignore
+ VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
return result
def on_kill(self) -> None:
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
index 51e6674344..34d27d87ae 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -170,6 +170,7 @@ TEST_OUTPUT_CONFIG = {
class TestVertexAICreateCustomContainerTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
def test_execute(self, mock_hook):
+ mock_hook.return_value.create_custom_container_training_job.return_value = (None, 'training_id')
op = CreateCustomContainerTrainingJobOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -250,6 +251,7 @@ class TestVertexAICreateCustomContainerTrainingJobOperator:
class TestVertexAICreateCustomPythonPackageTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
def test_execute(self, mock_hook):
+ mock_hook.return_value.create_custom_python_package_training_job.return_value = (None, 'training_id')
op = CreateCustomPythonPackageTrainingJobOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -332,6 +334,7 @@ class TestVertexAICreateCustomPythonPackageTrainingJobOperator:
class TestVertexAICreateCustomTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
def test_execute(self, mock_hook):
+ mock_hook.return_value.create_custom_training_job.return_value = (None, 'training_id')
op = CreateCustomTrainingJobOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,