Posted to commits@airflow.apache.org by tu...@apache.org on 2022/02/20 21:49:20 UTC
[airflow] branch main updated: Add Auto ML operators for Vertex AI service (#21470)
This is an automated email from the ASF dual-hosted git repository.
turbaszek pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6061cc4 Add Auto ML operators for Vertex AI service (#21470)
6061cc4 is described below
commit 6061cc42196053e3540d35f5fdcdedf7bb72cb4a
Author: Maksim <ma...@google.com>
AuthorDate: Sun Feb 20 21:48:39 2022 +0000
Add Auto ML operators for Vertex AI service (#21470)
---
.../google/cloud/example_dags/example_vertex_ai.py | 153 +++
.../google/cloud/hooks/vertex_ai/auto_ml.py | 1256 ++++++++++++++++++++
airflow/providers/google/cloud/links/vertex_ai.py | 182 +++
.../google/cloud/operators/vertex_ai/auto_ml.py | 623 ++++++++++
.../google/cloud/operators/vertex_ai/custom_job.py | 90 +-
.../google/cloud/operators/vertex_ai/dataset.py | 88 +-
airflow/providers/google/provider.yaml | 10 +-
.../operators/cloud/vertex_ai.rst | 90 ++
.../google/cloud/hooks/vertex_ai/test_auto_ml.py | 175 +++
.../google/cloud/operators/test_vertex_ai.py | 345 ++++++
.../cloud/operators/test_vertex_ai_system.py | 4 +
11 files changed, 2857 insertions(+), 159 deletions(-)
diff --git a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py
index 8c1f0d7..37e454c 100644
--- a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+++ b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py
@@ -36,6 +36,15 @@ from uuid import uuid4
from google.protobuf.struct_pb2 import Value
from airflow import models
+from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
+ CreateAutoMLForecastingTrainingJobOperator,
+ CreateAutoMLImageTrainingJobOperator,
+ CreateAutoMLTabularTrainingJobOperator,
+ CreateAutoMLTextTrainingJobOperator,
+ CreateAutoMLVideoTrainingJobOperator,
+ DeleteAutoMLTrainingJobOperator,
+ ListAutoMLTrainingJobOperator,
+)
from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
CreateCustomContainerTrainingJobOperator,
CreateCustomPythonPackageTrainingJobOperator,
@@ -121,6 +130,33 @@ TEST_IMPORT_CONFIG = [
DATASET_TO_UPDATE = {"display_name": "test-name"}
TEST_UPDATE_MASK = {"paths": ["displayName"]}
+TEST_TIME_COLUMN = "date"
+TEST_TIME_SERIES_IDENTIFIER_COLUMN = "store_name"
+TEST_TARGET_COLUMN = "sale_dollars"
+
+COLUMN_SPECS = {
+ TEST_TIME_COLUMN: "timestamp",
+ TEST_TARGET_COLUMN: "numeric",
+ "city": "categorical",
+ "zip_code": "categorical",
+ "county": "categorical",
+}
+
+COLUMN_TRANSFORMATIONS = [
+ {"categorical": {"column_name": "Type"}},
+ {"numeric": {"column_name": "Age"}},
+ {"categorical": {"column_name": "Breed1"}},
+ {"categorical": {"column_name": "Color1"}},
+ {"categorical": {"column_name": "Color2"}},
+ {"categorical": {"column_name": "MaturitySize"}},
+ {"categorical": {"column_name": "FurLength"}},
+ {"categorical": {"column_name": "Vaccinated"}},
+ {"categorical": {"column_name": "Sterilized"}},
+ {"categorical": {"column_name": "Health"}},
+ {"numeric": {"column_name": "Fee"}},
+ {"numeric": {"column_name": "PhotoAmt"}},
+]
+
with models.DAG(
"example_gcp_vertex_ai_custom_jobs",
schedule_interval="@once",
@@ -313,3 +349,120 @@ with models.DAG(
create_image_dataset_job >> import_data_job >> export_data_job
create_video_dataset_job >> update_dataset_job
list_dataset_job
+
+with models.DAG(
+ "example_gcp_vertex_ai_auto_ml",
+ schedule_interval="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+) as auto_ml_dag:
+ # [START how_to_cloud_vertex_ai_create_auto_ml_forecasting_training_job_operator]
+ create_auto_ml_forecasting_training_job = CreateAutoMLForecastingTrainingJobOperator(
+ task_id="auto_ml_forecasting_task",
+ display_name=f"auto-ml-forecasting-{DISPLAY_NAME}",
+ optimization_objective="minimize-rmse",
+ column_specs=COLUMN_SPECS,
+ # run params
+ dataset_id=DATASET_ID,
+ target_column=TEST_TARGET_COLUMN,
+ time_column=TEST_TIME_COLUMN,
+ time_series_identifier_column=TEST_TIME_SERIES_IDENTIFIER_COLUMN,
+ available_at_forecast_columns=[TEST_TIME_COLUMN],
+ unavailable_at_forecast_columns=[TEST_TARGET_COLUMN],
+ time_series_attribute_columns=["city", "zip_code", "county"],
+ forecast_horizon=30,
+ context_window=30,
+ data_granularity_unit="day",
+ data_granularity_count=1,
+ weight_column=None,
+ budget_milli_node_hours=1000,
+ model_display_name=f"auto-ml-forecasting-model-{DISPLAY_NAME}",
+ predefined_split_column_name=None,
+ region=REGION,
+ project_id=PROJECT_ID,
+ )
+ # [END how_to_cloud_vertex_ai_create_auto_ml_forecasting_training_job_operator]
+
+ # [START how_to_cloud_vertex_ai_create_auto_ml_image_training_job_operator]
+ create_auto_ml_image_training_job = CreateAutoMLImageTrainingJobOperator(
+ task_id="auto_ml_image_task",
+ display_name=f"auto-ml-image-{DISPLAY_NAME}",
+ dataset_id=DATASET_ID,
+ prediction_type="classification",
+ multi_label=False,
+ model_type="CLOUD",
+ training_fraction_split=0.6,
+ validation_fraction_split=0.2,
+ test_fraction_split=0.2,
+ budget_milli_node_hours=8000,
+ model_display_name=f"auto-ml-image-model-{DISPLAY_NAME}",
+ disable_early_stopping=False,
+ region=REGION,
+ project_id=PROJECT_ID,
+ )
+ # [END how_to_cloud_vertex_ai_create_auto_ml_image_training_job_operator]
+
+ # [START how_to_cloud_vertex_ai_create_auto_ml_tabular_training_job_operator]
+ create_auto_ml_tabular_training_job = CreateAutoMLTabularTrainingJobOperator(
+ task_id="auto_ml_tabular_task",
+ display_name=f"auto-ml-tabular-{DISPLAY_NAME}",
+ optimization_prediction_type="classification",
+ column_transformations=COLUMN_TRANSFORMATIONS,
+ dataset_id=DATASET_ID,
+ target_column="Adopted",
+ training_fraction_split=0.8,
+ validation_fraction_split=0.1,
+ test_fraction_split=0.1,
+ model_display_name="adopted-prediction-model",
+ disable_early_stopping=False,
+ region=REGION,
+ project_id=PROJECT_ID,
+ )
+ # [END how_to_cloud_vertex_ai_create_auto_ml_tabular_training_job_operator]
+
+ # [START how_to_cloud_vertex_ai_create_auto_ml_text_training_job_operator]
+ create_auto_ml_text_training_job = CreateAutoMLTextTrainingJobOperator(
+ task_id="auto_ml_text_task",
+ display_name=f"auto-ml-text-{DISPLAY_NAME}",
+ prediction_type="classification",
+ multi_label=False,
+ dataset_id=DATASET_ID,
+ model_display_name=f"auto-ml-text-model-{DISPLAY_NAME}",
+ training_fraction_split=0.7,
+ validation_fraction_split=0.2,
+ test_fraction_split=0.1,
+ sync=True,
+ region=REGION,
+ project_id=PROJECT_ID,
+ )
+ # [END how_to_cloud_vertex_ai_create_auto_ml_text_training_job_operator]
+
+ # [START how_to_cloud_vertex_ai_create_auto_ml_video_training_job_operator]
+ create_auto_ml_video_training_job = CreateAutoMLVideoTrainingJobOperator(
+ task_id="auto_ml_video_task",
+ display_name=f"auto-ml-video-{DISPLAY_NAME}",
+ prediction_type="classification",
+ model_type="CLOUD",
+ dataset_id=DATASET_ID,
+ model_display_name=f"auto-ml-video-model-{DISPLAY_NAME}",
+ region=REGION,
+ project_id=PROJECT_ID,
+ )
+ # [END how_to_cloud_vertex_ai_create_auto_ml_video_training_job_operator]
+
+ # [START how_to_cloud_vertex_ai_delete_auto_ml_training_job_operator]
+ delete_auto_ml_training_job = DeleteAutoMLTrainingJobOperator(
+ task_id="delete_auto_ml_training_job",
+ training_pipeline_id=TRAINING_PIPELINE_ID,
+ region=REGION,
+ project_id=PROJECT_ID,
+ )
+ # [END how_to_cloud_vertex_ai_delete_auto_ml_training_job_operator]
+
+ # [START how_to_cloud_vertex_ai_list_auto_ml_training_job_operator]
+ list_auto_ml_training_job = ListAutoMLTrainingJobOperator(
+ task_id="list_auto_ml_training_job",
+ region=REGION,
+ project_id=PROJECT_ID,
+ )
+ # [END how_to_cloud_vertex_ai_list_auto_ml_training_job_operator]
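The tasks in the new example_gcp_vertex_ai_auto_ml DAG above are declared without explicit
dependencies. A minimal, hypothetical wiring, mirroring the chaining style used in the
custom-jobs DAG earlier in this file, could look like the following (the ordering is
illustrative only):

    create_auto_ml_forecasting_training_job >> delete_auto_ml_training_job
    create_auto_ml_image_training_job >> list_auto_ml_training_job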
diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
new file mode 100644
index 0000000..5d1f705
--- /dev/null
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
@@ -0,0 +1,1256 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+"""
+This module contains a Google Cloud Vertex AI hook.
+
+.. spelling::
+
+ aiplatform
+ au
+ codepoints
+ milli
+ mae
+ quantile
+ quantiles
+ Quantiles
+ rmse
+ rmsle
+ rmspe
+ wape
+ prc
+ roc
+ Jetson
+ forecasted
+ Struct
+ sentimentMax
+ TrainingPipeline
+ targetColumn
+ optimizationObjective
+"""
+
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+from google.api_core.operation import Operation
+from google.api_core.retry import Retry
+from google.cloud.aiplatform import (
+ AutoMLForecastingTrainingJob,
+ AutoMLImageTrainingJob,
+ AutoMLTabularTrainingJob,
+ AutoMLTextTrainingJob,
+ AutoMLVideoTrainingJob,
+ datasets,
+ models,
+)
+from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient
+from google.cloud.aiplatform_v1.services.pipeline_service.pagers import ListTrainingPipelinesPager
+from google.cloud.aiplatform_v1.types import TrainingPipeline
+
+from airflow import AirflowException
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+
+
+class AutoMLHook(GoogleBaseHook):
+ """Hook for Google Cloud Vertex AI Auto ML APIs."""
+
+ def __init__(
+ self,
+ gcp_conn_id: str = "google_cloud_default",
+ delegate_to: Optional[str] = None,
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ ) -> None:
+ super().__init__(
+ gcp_conn_id=gcp_conn_id,
+ delegate_to=delegate_to,
+ impersonation_chain=impersonation_chain,
+ )
+ self._job: Optional[
+ Union[
+ AutoMLForecastingTrainingJob,
+ AutoMLImageTrainingJob,
+ AutoMLTabularTrainingJob,
+ AutoMLTextTrainingJob,
+ AutoMLVideoTrainingJob,
+ ]
+ ] = None
+
+ def get_pipeline_service_client(
+ self,
+ region: Optional[str] = None,
+ ) -> PipelineServiceClient:
+ """Returns PipelineServiceClient."""
+ client_options = None
+ if region and region != 'global':
+ client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'}
+
+ return PipelineServiceClient(
+ credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
+ )
+
+ def get_job_service_client(
+ self,
+ region: Optional[str] = None,
+ ) -> JobServiceClient:
+ """Returns JobServiceClient"""
+ client_options = None
+ if region and region != 'global':
+ client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'}
+
+ return JobServiceClient(
+ credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
+ )
+
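Both client getters above derive a regional API endpoint from the region argument; any region
other than None or "global" is routed to a region-specific host. A minimal, hypothetical usage
sketch (the region is illustrative):

    from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook

    hook = AutoMLHook(gcp_conn_id="google_cloud_default")
    # "us-central1" resolves to the "us-central1-aiplatform.googleapis.com:443" endpoint.
    pipeline_client = hook.get_pipeline_service_client(region="us-central1")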
+ def get_auto_ml_tabular_training_job(
+ self,
+ display_name: str,
+ optimization_prediction_type: str,
+ optimization_objective: Optional[str] = None,
+ column_specs: Optional[Dict[str, str]] = None,
+ column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
+ optimization_objective_recall_value: Optional[float] = None,
+ optimization_objective_precision_value: Optional[float] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ training_encryption_spec_key_name: Optional[str] = None,
+ model_encryption_spec_key_name: Optional[str] = None,
+ ) -> AutoMLTabularTrainingJob:
+ """Returns AutoMLTabularTrainingJob object"""
+ return AutoMLTabularTrainingJob(
+ display_name=display_name,
+ optimization_prediction_type=optimization_prediction_type,
+ optimization_objective=optimization_objective,
+ column_specs=column_specs,
+ column_transformations=column_transformations,
+ optimization_objective_recall_value=optimization_objective_recall_value,
+ optimization_objective_precision_value=optimization_objective_precision_value,
+ project=project,
+ location=location,
+ credentials=self._get_credentials(),
+ labels=labels,
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
+ )
+
+ def get_auto_ml_forecasting_training_job(
+ self,
+ display_name: str,
+ optimization_objective: Optional[str] = None,
+ column_specs: Optional[Dict[str, str]] = None,
+ column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ training_encryption_spec_key_name: Optional[str] = None,
+ model_encryption_spec_key_name: Optional[str] = None,
+ ) -> AutoMLForecastingTrainingJob:
+ """Returns AutoMLForecastingTrainingJob object"""
+ return AutoMLForecastingTrainingJob(
+ display_name=display_name,
+ optimization_objective=optimization_objective,
+ column_specs=column_specs,
+ column_transformations=column_transformations,
+ project=project,
+ location=location,
+ credentials=self._get_credentials(),
+ labels=labels,
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
+ )
+
+ def get_auto_ml_image_training_job(
+ self,
+ display_name: str,
+ prediction_type: str = "classification",
+ multi_label: bool = False,
+ model_type: str = "CLOUD",
+ base_model: Optional[models.Model] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ training_encryption_spec_key_name: Optional[str] = None,
+ model_encryption_spec_key_name: Optional[str] = None,
+ ) -> AutoMLImageTrainingJob:
+ """Returns AutoMLImageTrainingJob object"""
+ return AutoMLImageTrainingJob(
+ display_name=display_name,
+ prediction_type=prediction_type,
+ multi_label=multi_label,
+ model_type=model_type,
+ base_model=base_model,
+ project=project,
+ location=location,
+ credentials=self._get_credentials(),
+ labels=labels,
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
+ )
+
+ def get_auto_ml_text_training_job(
+ self,
+ display_name: str,
+ prediction_type: str,
+ multi_label: bool = False,
+ sentiment_max: int = 10,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ training_encryption_spec_key_name: Optional[str] = None,
+ model_encryption_spec_key_name: Optional[str] = None,
+ ) -> AutoMLTextTrainingJob:
+ """Returns AutoMLTextTrainingJob object"""
+ return AutoMLTextTrainingJob(
+ display_name=display_name,
+ prediction_type=prediction_type,
+ multi_label=multi_label,
+ sentiment_max=sentiment_max,
+ project=project,
+ location=location,
+ credentials=self._get_credentials(),
+ labels=labels,
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
+ )
+
+ def get_auto_ml_video_training_job(
+ self,
+ display_name: str,
+ prediction_type: str = "classification",
+ model_type: str = "CLOUD",
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ training_encryption_spec_key_name: Optional[str] = None,
+ model_encryption_spec_key_name: Optional[str] = None,
+ ) -> AutoMLVideoTrainingJob:
+ """Returns AutoMLVideoTrainingJob object"""
+ return AutoMLVideoTrainingJob(
+ display_name=display_name,
+ prediction_type=prediction_type,
+ model_type=model_type,
+ project=project,
+ location=location,
+ credentials=self._get_credentials(),
+ labels=labels,
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
+ )
+
+ @staticmethod
+ def extract_model_id(obj: Dict) -> str:
+ """Returns unique id of the Model."""
+ return obj["name"].rpartition("/")[-1]
+
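extract_model_id returns the trailing segment of the model resource name, e.g. (the resource
name below is illustrative):

    # "projects/my-project/locations/us-central1/models/123" -> "123"
    model_id = AutoMLHook.extract_model_id(
        {"name": "projects/my-project/locations/us-central1/models/123"}
    )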
+ def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None):
+ """Waits for long-lasting operation to complete."""
+ try:
+ return operation.result(timeout=timeout)
+ except Exception:
+ error = operation.exception(timeout=timeout)
+ raise AirflowException(error)
+
+ def cancel_auto_ml_job(self) -> None:
+ """Cancel Auto ML Job for training pipeline"""
+ if self._job:
+ self._job.cancel()
+
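cancel_auto_ml_job is a no-op until one of the create_* methods below has stored a running job
in self._job. A hypothetical hook-up from an operator's on_kill (the operator code is not shown
in this excerpt):

    def on_kill(self) -> None:
        # Assumes the operator keeps a reference to its AutoMLHook in self.hook.
        self.hook.cancel_auto_ml_job()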
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_auto_ml_tabular_training_job(
+ self,
+ project_id: str,
+ region: str,
+ display_name: str,
+ dataset: datasets.TabularDataset,
+ target_column: str,
+ optimization_prediction_type: str,
+ optimization_objective: Optional[str] = None,
+ column_specs: Optional[Dict[str, str]] = None,
+ column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
+ optimization_objective_recall_value: Optional[float] = None,
+ optimization_objective_precision_value: Optional[float] = None,
+ labels: Optional[Dict[str, str]] = None,
+ training_encryption_spec_key_name: Optional[str] = None,
+ model_encryption_spec_key_name: Optional[str] = None,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
+ weight_column: Optional[str] = None,
+ budget_milli_node_hours: int = 1000,
+ model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
+ disable_early_stopping: bool = False,
+ export_evaluated_data_items: bool = False,
+ export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
+ export_evaluated_data_items_override_destination: bool = False,
+ sync: bool = True,
+ ) -> models.Model:
+ """
+ Create an AutoML Tabular Training Job.
+
+ :param project_id: Required. Project to run training in.
+ :param region: Required. Location to run training in.
+ :param display_name: Required. The user-defined name of this TrainingPipeline.
+ :param dataset: Required. The dataset within the same Project from which data will be used to train
+ the Model. The Dataset must use schema compatible with Model being trained, and what is
+ compatible should be described in the used TrainingPipeline's [training_task_definition]
+ [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. For tabular
+ Datasets, all their data is exported to training, to pick and choose from.
+ :param target_column: Required. The name of the column values of which the Model is to predict.
+ :param optimization_prediction_type: The type of prediction the Model is to produce.
+ "classification" - Predict one out of multiple target values is picked for each row.
+ "regression" - Predict a value based on its relation to other values. This type is available only
+ to columns that contain semantically numeric values, i.e. integers or floating point number, even
+ if stored as e.g. strings.
+ :param optimization_objective: Optional. Objective function the Model is to be optimized towards.
+ The training task creates a Model that maximizes/minimizes the value of the objective function
+ over the validation set.
+
+ The supported optimization objectives depend on the prediction type, and in the case of
+ classification also the number of distinct values in the target column (two distinct values
+ -> binary, 3 or more distinct values -> multi class). If the field is not set, the default
+ objective function is used.
+
+ Classification (binary):
+ "maximize-au-roc" (default) - Maximize the area under the receiver operating characteristic (ROC)
+ curve.
+ "minimize-log-loss" - Minimize log loss.
+ "maximize-au-prc" - Maximize the area under the precision-recall curve.
+ "maximize-precision-at-recall" - Maximize precision for a specified recall value.
+ "maximize-recall-at-precision" - Maximize recall for a specified precision value.
+
+ Classification (multi class):
+ "minimize-log-loss" (default) - Minimize log loss.
+
+ Regression:
+ "minimize-rmse" (default) - Minimize root-mean-squared error (RMSE).
+ "minimize-mae" - Minimize mean-absolute error (MAE).
+ "minimize-rmsle" - Minimize root-mean-squared log error (RMSLE).
+ :param column_specs: Optional. Alternative to column_transformations where the keys of the dict are
+ column names and their respective values are one of AutoMLTabularTrainingJob.column_data_types.
+ When creating transformation for BigQuery Struct column, the column should be flattened using "."
+ as the delimiter. Only columns with no child should have a transformation. If an input column has
+ no transformations on it, such a column is ignored by the training, except for the targetColumn,
+ which should have no transformations defined on it. Only one of column_transformations or
+ column_specs should be passed.
+ :param column_transformations: Optional. Transformations to apply to the input columns (i.e. columns
+ other than the targetColumn). Each transformation may produce multiple result values from the
+ column's value, and all are used for training. When creating transformation for BigQuery Struct
+ column, the column should be flattened using "." as the delimiter. Only columns with no child
+ should have a transformation. If an input column has no transformations on it, such a column is
+ ignored by the training, except for the targetColumn, which should have no transformations
+ defined on it. Only one of column_transformations or column_specs should be passed. Consider using
+ column_specs as column_transformations will be deprecated eventually.
+ :param optimization_objective_recall_value: Optional. Required when maximize-precision-at-recall
+ optimizationObjective was picked, represents the recall value at which the optimization is done.
+ The minimum value is 0 and the maximum is 1.0.
+ :param optimization_objective_precision_value: Optional. Required when maximize-recall-at-precision
+ optimizationObjective was picked, represents the precision value at which the optimization is
+ done.
+ The minimum value is 0 and the maximum is 1.0.
+ :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label
+ keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
+ lowercase letters, numeric characters, underscores and dashes. International characters are
+ allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
+ :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the training pipeline. Has the form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. The key needs to be
+ in the same region as where the compute resource is created. If set, this TrainingPipeline will
+ be secured by this key.
+ Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload``
+ is not set separately.
+ :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the model. Has the form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. The key needs to be
+ in the same region as where the compute resource is created. If set, the trained Model will be
+ secured by this key.
+ :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
+ :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
+ validate the Model. This is ignored if Dataset is not provided.
+ :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
+ columns. The value of the key (either the label's value or value in the column) must be one of
+ {``training``, ``validation``, ``test``}, and it defines to which set the given piece of data is
+ assigned. If for a piece of data the key is not present or has an invalid value, that piece is
+ ignored by the pipeline. Supported only for tabular and time series Datasets.
+ :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data columns.
+ The values of the key (the values in the column) must be in RFC 3339 `date-time`
+ format, where `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a piece of data the
+ key is not present or has an invalid value, that piece is ignored by the pipeline. Supported only
+ for tabular and time series Datasets. This parameter must be used with training_fraction_split,
+ validation_fraction_split and test_fraction_split.
+ :param weight_column: Optional. Name of the column that should be used as the weight column. Higher
+ values in this column give more importance to the row during Model training. The column must have
+ numeric values between 0 and 10000 inclusively, and 0 value means that the row is ignored. If the
+ weight column field is not set, then all rows are assumed to have equal weight of 1.
+ :param budget_milli_node_hours: Optional. The train budget of creating this Model, expressed in
+ milli node hours, i.e. a value of 1,000 means 1 node hour. The training cost of the model
+ will not exceed this budget. The final cost will be attempted to be close to the budget, though
+ may end up being (even) noticeably smaller - at the backend's discretion. This especially may
+ happen when further model training ceases to provide any improvements. If the budget is set to a
+ value known to be insufficient to train a Model for the given training set, the training won't be
+ attempted and will error. The minimum value is 1000 and the maximum is 72000.
+ :param model_display_name: Optional. If the script produces a managed Vertex AI Model, the display
+ name of the Model. The name can be up to 128 characters long and can consist of any UTF-8
+ characters. If not provided upon creation, the job's display_name is used.
+ :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label
+ keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
+ lowercase letters, numeric characters, underscores and dashes. International characters are
+ allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
+ :param disable_early_stopping: Required. If true, the entire budget is used. This disables the early
+ stopping feature. By default, the early stopping feature is enabled, which means that training
+ might stop before the entire training budget has been used, if further training no longer
+ brings significant improvement to the model.
+ :param export_evaluated_data_items: Whether to export the test set predictions to a BigQuery table.
+ If False, then the export is not performed.
+ :param export_evaluated_data_items_bigquery_destination_uri: Optional. URI of desired destination
+ BigQuery table for exported test set predictions.
+
+ Expected format: ``bq://<project_id>:<dataset_id>:<table>``
+
+ If not specified, then results are exported to the following auto-created BigQuery table:
+ ``<project_id>:export_evaluated_examples_<model_name>_<yyyy_MM_dd'T'HH_mm_ss_SSS'Z'>
+ .evaluated_examples``
+
+ Applies only if [export_evaluated_data_items] is True.
+ :param export_evaluated_data_items_override_destination: Whether to override the contents of
+ [export_evaluated_data_items_bigquery_destination_uri], if the table exists, for exported test
+ set predictions. If False, and the table exists, then the training job will fail. Applies only if
+ [export_evaluated_data_items] is True and [export_evaluated_data_items_bigquery_destination_uri]
+ is specified.
+ :param sync: Whether to execute this method synchronously. If False, this method will be executed in
+ concurrent Future and any downstream object will be immediately returned and synced when the
+ Future has completed.
+ """
+ self._job = self.get_auto_ml_tabular_training_job(
+ project=project_id,
+ location=region,
+ display_name=display_name,
+ optimization_prediction_type=optimization_prediction_type,
+ optimization_objective=optimization_objective,
+ column_specs=column_specs,
+ column_transformations=column_transformations,
+ optimization_objective_recall_value=optimization_objective_recall_value,
+ optimization_objective_precision_value=optimization_objective_precision_value,
+ labels=labels,
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
+ )
+
+ if not self._job:
+ raise AirflowException("AutoMLTabularTrainingJob was not created")
+
+ model = self._job.run(
+ dataset=dataset,
+ target_column=target_column,
+ training_fraction_split=training_fraction_split,
+ validation_fraction_split=validation_fraction_split,
+ test_fraction_split=test_fraction_split,
+ predefined_split_column_name=predefined_split_column_name,
+ timestamp_split_column_name=timestamp_split_column_name,
+ weight_column=weight_column,
+ budget_milli_node_hours=budget_milli_node_hours,
+ model_display_name=model_display_name,
+ model_labels=model_labels,
+ disable_early_stopping=disable_early_stopping,
+ export_evaluated_data_items=export_evaluated_data_items,
+ export_evaluated_data_items_bigquery_destination_uri=(
+ export_evaluated_data_items_bigquery_destination_uri
+ ),
+ export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
+ sync=sync,
+ )
+ model.wait()
+
+ return model
+
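Called directly, outside the operator layer, a minimal, hypothetical invocation could look like
this (project, dataset ID and column name are illustrative):

    from google.cloud.aiplatform import datasets

    from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook

    hook = AutoMLHook(gcp_conn_id="google_cloud_default")
    model = hook.create_auto_ml_tabular_training_job(
        project_id="my-project",
        region="us-central1",
        display_name="adoption-training",
        dataset=datasets.TabularDataset(dataset_name="1234567890"),
        target_column="Adopted",
        optimization_prediction_type="classification",
    )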
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_auto_ml_forecasting_training_job(
+ self,
+ project_id: str,
+ region: str,
+ display_name: str,
+ dataset: datasets.TimeSeriesDataset,
+ target_column: str,
+ time_column: str,
+ time_series_identifier_column: str,
+ unavailable_at_forecast_columns: List[str],
+ available_at_forecast_columns: List[str],
+ forecast_horizon: int,
+ data_granularity_unit: str,
+ data_granularity_count: int,
+ optimization_objective: Optional[str] = None,
+ column_specs: Optional[Dict[str, str]] = None,
+ column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
+ labels: Optional[Dict[str, str]] = None,
+ training_encryption_spec_key_name: Optional[str] = None,
+ model_encryption_spec_key_name: Optional[str] = None,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ predefined_split_column_name: Optional[str] = None,
+ weight_column: Optional[str] = None,
+ time_series_attribute_columns: Optional[List[str]] = None,
+ context_window: Optional[int] = None,
+ export_evaluated_data_items: bool = False,
+ export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
+ export_evaluated_data_items_override_destination: bool = False,
+ quantiles: Optional[List[float]] = None,
+ validation_options: Optional[str] = None,
+ budget_milli_node_hours: int = 1000,
+ model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
+ sync: bool = True,
+ ) -> models.Model:
+ """
+ Create an AutoML Forecasting Training Job.
+
+ :param project_id: Required. Project to run training in.
+ :param region: Required. Location to run training in.
+ :param display_name: Required. The user-defined name of this TrainingPipeline.
+ :param dataset: Required. The dataset within the same Project from which data will be used to train
+ the Model. The Dataset must use schema compatible with Model being trained, and what is
+ compatible should be described in the used TrainingPipeline's [training_task_definition]
+ [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. For time series
+ Datasets, all their data is exported to training, to pick and choose from.
+ :param target_column: Required. Name of the column that the Model is to predict values for.
+ :param time_column: Required. Name of the column that identifies time order in the time series.
+ :param time_series_identifier_column: Required. Name of the column that identifies the time series.
+ :param unavailable_at_forecast_columns: Required. Column names of columns that are unavailable at
+ forecast. Each column contains information for the given entity (identified by the
+ [time_series_identifier_column]) that is unknown before the forecast (e.g. population of a city
+ in a given year, or weather on a given day).
+ :param available_at_forecast_columns: Required. Column names of columns that are available at
+ forecast. Each column contains information for the given entity (identified by the
+ [time_series_identifier_column]) that is known at forecast.
+ :param forecast_horizon: Required. The amount of time into the future for which forecasted values for
+ the target are returned. Expressed in number of units defined by the [data_granularity_unit] and
+ [data_granularity_count] fields. Inclusive.
+ :param data_granularity_unit: Required. The data granularity unit. Accepted values are ``minute``,
+ ``hour``, ``day``, ``week``, ``month``, ``year``.
+ :param data_granularity_count: Required. The number of data granularity units between data points in
+ the training data. If [data_granularity_unit] is `minute`, can be 1, 5, 10, 15, or 30. For all
+ other values of [data_granularity_unit], must be 1.
+ :param optimization_objective: Optional. Objective function the model is to be optimized towards. The
+ training process creates a Model that optimizes the value of the objective function over the
+ validation set. The supported optimization objectives:
+ "minimize-rmse" (default) - Minimize root-mean-squared error (RMSE).
+ "minimize-mae" - Minimize mean-absolute error (MAE).
+ "minimize-rmsle" - Minimize root-mean-squared log error (RMSLE).
+ "minimize-rmspe" - Minimize root-mean-squared percentage error (RMSPE).
+ "minimize-wape-mae" - Minimize the combination of weighted absolute percentage error (WAPE) and
+ mean-absolute-error (MAE).
+ "minimize-quantile-loss" - Minimize the quantile loss at the defined quantiles. (Set this
+ objective to build quantile forecasts.)
+ :param column_specs: Optional. Alternative to column_transformations where the keys of the dict are
+ column names and their respective values are one of AutoMLTabularTrainingJob.column_data_types.
+ When creating transformation for BigQuery Struct column, the column should be flattened using "."
+ as the delimiter. Only columns with no child should have a transformation. If an input column has
+ no transformations on it, such a column is ignored by the training, except for the targetColumn,
+ which should have no transformations defined on it. Only one of column_transformations or
+ column_specs should be passed.
+ :param column_transformations: Optional. Transformations to apply to the input columns (i.e. columns
+ other than the targetColumn). Each transformation may produce multiple result values from the
+ column's value, and all are used for training. When creating transformation for BigQuery Struct
+ column, the column should be flattened using "." as the delimiter. Only columns with no child
+ should have a transformation. If an input column has no transformations on it, such a column is
+ ignored by the training, except for the targetColumn, which should have no transformations
+ defined on it. Only one of column_transformations or column_specs should be passed. Consider using
+ column_specs as column_transformations will be deprecated eventually.
+ :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label
+ keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
+ lowercase letters, numeric characters, underscores and dashes. International characters are
+ allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
+ :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the training pipeline. Has the form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. The key needs to be
+ in the same region as where the compute resource is created. If set, this TrainingPipeline will
+ be secured by this key.
+ Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload``
+ is not set separately.
+ :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the model. Has the form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. The key needs to be
+ in the same region as where the compute resource is created.
+ If set, the trained Model will be secured by this key.
+ :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
+ :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
+ validate the Model. This is ignored if Dataset is not provided.
+ :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
+ columns. The value of the key (either the label's value or value in the column) must be one of
+ {``TRAIN``, ``VALIDATE``, ``TEST``}, and it defines to which set the given piece of data is
+ assigned. If for a piece of data the key is not present or has an invalid value, that piece is
+ ignored by the pipeline.
+ Supported only for tabular and time series Datasets.
+ :param weight_column: Optional. Name of the column that should be used as the weight column. Higher
+ values in this column give more importance to the row during Model training. The column must have
+ numeric values between 0 and 10000 inclusively, and 0 value means that the row is ignored. If the
+ weight column field is not set, then all rows are assumed to have equal weight of 1.
+ :param time_series_attribute_columns: Optional. Column names that should be used as attribute
+ columns. Each column is constant within a time series.
+ :param context_window: Optional. The amount of time into the past that training and prediction data
+ is used for model training and prediction respectively. Expressed in number of units defined by the
+ [data_granularity_unit] and [data_granularity_count] fields. When not provided, the default value
+ of 0 is used, which means the model sets each series context window to be 0 (also known as "cold
+ start"). Inclusive.
+ :param export_evaluated_data_items: Whether to export the test set predictions to a BigQuery table.
+ If False, then the export is not performed.
+ :param export_evaluated_data_items_bigquery_destination_uri: Optional. URI of desired destination
+ BigQuery table for exported test set predictions. Expected format:
+ ``bq://<project_id>:<dataset_id>:<table>``
+ If not specified, then results are exported to the following auto-created BigQuery table:
+ ``<project_id>:export_evaluated_examples_<model_name>_<yyyy_MM_dd'T'HH_mm_ss_SSS'Z'>
+ .evaluated_examples``
+ Applies only if [export_evaluated_data_items] is True.
+ :param export_evaluated_data_items_override_destination: Whether to override the contents of
+ [export_evaluated_data_items_bigquery_destination_uri], if the table exists, for exported test
+ set predictions. If False, and the table exists, then the training job will fail.
+ Applies only if [export_evaluated_data_items] is True and
+ [export_evaluated_data_items_bigquery_destination_uri] is specified.
+ :param quantiles: Quantiles to use for the `minimize-quantile-loss`
+ [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in this case.
+ Accepts up to 5 quantiles in the form of a double from 0 to 1, exclusive. Each quantile must be
+ unique.
+ :param validation_options: Validation options for the data validation component. The available
+ options are: "fail-pipeline" - (default), will validate against the validation and fail the
+ pipeline if it fails. "ignore-validation" - ignore the results of the validation and continue the
+ pipeline
+ :param budget_milli_node_hours: Optional. The train budget of creating this Model, expressed in milli
+ node hours, i.e. a value of 1,000 means 1 node hour. The training cost of the model will
+ not exceed this budget. The final cost will be attempted to be close to the budget, though may
+ end up being (even) noticeably smaller - at the backend's discretion. This especially may happen
+ when further model training ceases to provide any improvements. If the budget is set to a value
+ known to be insufficient to train a Model for the given training set, the training won't be
+ attempted and will error. The minimum value is 1000 and the maximum is 72000.
+ :param model_display_name: Optional. If the script produces a managed Vertex AI Model, the display
+ name of the Model. The name can be up to 128 characters long and can consist of any UTF-8
+ characters. If not provided upon creation, the job's display_name is used.
+ :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label
+ keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
+ lowercase letters, numeric characters, underscores and dashes. International characters are
+ allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
+ :param sync: Whether to execute this method synchronously. If False, this method will be executed in
+ concurrent Future and any downstream object will be immediately returned and synced when the
+ Future has completed.
+ """
+ self._job = self.get_auto_ml_forecasting_training_job(
+ project=project_id,
+ location=region,
+ display_name=display_name,
+ optimization_objective=optimization_objective,
+ column_specs=column_specs,
+ column_transformations=column_transformations,
+ labels=labels,
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
+ )
+
+ if not self._job:
+ raise AirflowException("AutoMLForecastingTrainingJob was not created")
+
+ model = self._job.run(
+ dataset=dataset,
+ target_column=target_column,
+ time_column=time_column,
+ time_series_identifier_column=time_series_identifier_column,
+ unavailable_at_forecast_columns=unavailable_at_forecast_columns,
+ available_at_forecast_columns=available_at_forecast_columns,
+ forecast_horizon=forecast_horizon,
+ data_granularity_unit=data_granularity_unit,
+ data_granularity_count=data_granularity_count,
+ training_fraction_split=training_fraction_split,
+ validation_fraction_split=validation_fraction_split,
+ test_fraction_split=test_fraction_split,
+ predefined_split_column_name=predefined_split_column_name,
+ weight_column=weight_column,
+ time_series_attribute_columns=time_series_attribute_columns,
+ context_window=context_window,
+ export_evaluated_data_items=export_evaluated_data_items,
+ export_evaluated_data_items_bigquery_destination_uri=(
+ export_evaluated_data_items_bigquery_destination_uri
+ ),
+ export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
+ quantiles=quantiles,
+ validation_options=validation_options,
+ budget_milli_node_hours=budget_milli_node_hours,
+ model_display_name=model_display_name,
+ model_labels=model_labels,
+ sync=sync,
+ )
+ model.wait()
+
+ return model
+
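To build the quantile forecasts mentioned under quantiles above, a hypothetical call pairs the
minimize-quantile-loss objective with explicit quantiles (all identifiers are illustrative;
imports and hook as in the tabular sketch above):

    model = hook.create_auto_ml_forecasting_training_job(
        project_id="my-project",
        region="us-central1",
        display_name="sales-forecast",
        dataset=datasets.TimeSeriesDataset(dataset_name="1234567890"),
        target_column="sale_dollars",
        time_column="date",
        time_series_identifier_column="store_name",
        unavailable_at_forecast_columns=["sale_dollars"],
        available_at_forecast_columns=["date"],
        forecast_horizon=30,
        data_granularity_unit="day",
        data_granularity_count=1,
        optimization_objective="minimize-quantile-loss",
        quantiles=[0.1, 0.5, 0.9],  # up to 5 unique values in (0, 1)
    )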
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_auto_ml_image_training_job(
+ self,
+ project_id: str,
+ region: str,
+ display_name: str,
+ dataset: datasets.ImageDataset,
+ prediction_type: str = "classification",
+ multi_label: bool = False,
+ model_type: str = "CLOUD",
+ base_model: Optional[models.Model] = None,
+ labels: Optional[Dict[str, str]] = None,
+ training_encryption_spec_key_name: Optional[str] = None,
+ model_encryption_spec_key_name: Optional[str] = None,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
+ budget_milli_node_hours: Optional[int] = None,
+ model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
+ disable_early_stopping: bool = False,
+ sync: bool = True,
+ ) -> models.Model:
+ """
+ Create an AutoML Image Training Job.
+
+ :param project_id: Required. Project to run training in.
+ :param region: Required. Location to run training in.
+ :param display_name: Required. The user-defined name of this TrainingPipeline.
+ :param dataset: Required. The dataset within the same Project from which data will be used to train
+ the Model. The Dataset must use schema compatible with Model being trained, and what is
+ compatible should be described in the used TrainingPipeline's [training_task_definition]
+ [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. For tabular
+ Datasets, all their data is exported to training, to pick and choose from.
+ :param prediction_type: The type of prediction the Model is to produce, one of:
+ "classification" - Predict one out of multiple target values for each image.
+ "object_detection" - Predict bounding boxes and the matching labels for the objects found in
+ each image.
+ :param multi_label: Required. Default is False. If false, a single-label (multi-class) Model will be
+ trained (i.e. assuming that for each image just up to one annotation may be applicable). If true,
+ a multi-label Model will be trained (i.e. assuming that for each image multiple annotations may
+ be applicable).
+ This is only applicable for the "classification" prediction_type and will be ignored otherwise.
+ :param model_type: Required. One of the following:
+ "CLOUD" - Default for Image Classification. A Model best tailored to be used within Google Cloud,
+ and which cannot be exported.
+ "CLOUD_HIGH_ACCURACY_1" - Default for Image Object Detection. A model best tailored to be used
+ within Google Cloud, and which cannot be exported. Expected to have a higher latency, but should
+ also have a higher prediction quality than other cloud models.
+ "CLOUD_LOW_LATENCY_1" - A model best tailored to be used within Google Cloud, and which cannot be
+ exported. Expected to have a low latency, but may have lower prediction quality than other cloud
+ models.
+ "MOBILE_TF_LOW_LATENCY_1" - A model that, in addition to being available within Google Cloud, can
+ also be exported as TensorFlow or Core ML model and used on a mobile or edge device afterwards.
+ Expected to have low latency, but may have lower prediction quality than other mobile models.
+ "MOBILE_TF_VERSATILE_1" - A model that, in addition to being available within Google Cloud, can
+ also be exported as TensorFlow or Core ML model and used on a mobile or edge device
+ afterwards.
+ "MOBILE_TF_HIGH_ACCURACY_1" - A model that, in addition to being available within Google Cloud,
+ can also be exported as TensorFlow or Core ML model and used on a mobile or edge device
+ afterwards. Expected to have a higher latency, but should also have a higher prediction quality
+ than other mobile models.
+ :param base_model: Optional. Only permitted for Image Classification models. If it is specified, the
+ new model will be trained based on the `base` model. Otherwise, the new model will be trained
+ from scratch. The `base` model must be in the same Project and Location as the new Model to
+ train, and have the same model_type.
+ :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label
+ keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
+ lowercase letters, numeric characters, underscores and dashes. International characters are
+ allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
+ :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the training pipeline. Has the form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. The key needs to be
+ in the same region as where the compute resource is created. If set, this TrainingPipeline will
+ be secured by this key.
+ Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload``
+ is not set separately.
+ :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the model. Has the form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute resource is created.
+ If set, the trained Model will be secured by this key.
+ :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
+ :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
+ validate the Model. This is ignored if Dataset is not provided.
+ :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with same syntax as the one used in
+ DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
+ FilterSplit filters, then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with same syntax as the one used in
+ DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
+ FilterSplit filters, then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this
+ filter are used to test the Model. A filter with same syntax as the one used in
+ DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
+ FilterSplit filters, then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ :param budget_milli_node_hours: Optional. The train budget of creating this Model, expressed in milli
+ node hours, i.e. a value of 1,000 means 1 node hour.
+ Defaults by `prediction_type`:
+ `classification` - For Cloud models the budget must be: 8,000 - 800,000 milli node hours
+ (inclusive). The default value is 192,000 which represents one day in wall time, assuming 8 nodes
+ are used.
+ `object_detection` - For Cloud models the budget must be: 20,000 - 900,000 milli node hours
+ (inclusive). The default value is 216,000 which represents one day in wall time, assuming 9 nodes
+ are used.
+ The training cost of the model will not exceed this budget. The final cost will be attempted to
+ be close to the budget, though may end up being (even) noticeably smaller - at the backend's
+ discretion. This especially may happen when further model training ceases to provide any
+ improvements. If the budget is set to a value known to be insufficient to train a Model for the
+ given training set, the training won't be attempted and will error.
+ :param model_display_name: Optional. The display name of the managed Vertex AI Model. The name can be
+ up to 128 characters long and can consist of any UTF-8 characters. If not provided upon
+ creation, the job's display_name is used.
+ :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label
+ keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
+ lowercase letters, numeric characters, underscores and dashes. International characters are
+ allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
+ :param disable_early_stopping: Required. If true, the entire budget is used. This disables the early
+ stopping feature. By default, the early stopping feature is enabled, which means that training
+ might stop before the entire training budget has been used, if further training no longer
+ brings significant improvement to the model.
+ :param sync: Whether to execute this method synchronously. If False, this method will be executed in
+ concurrent Future and any downstream object will be immediately returned and synced when the
+ Future has completed.
+ """
+ self._job = self.get_auto_ml_image_training_job(
+ project=project_id,
+ location=region,
+ display_name=display_name,
+ prediction_type=prediction_type,
+ multi_label=multi_label,
+ model_type=model_type,
+ base_model=base_model,
+ labels=labels,
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
+ )
+
+ if not self._job:
+ raise AirflowException("AutoMLImageTrainingJob was not created")
+
+ model = self._job.run(
+ dataset=dataset,
+ training_fraction_split=training_fraction_split,
+ validation_fraction_split=validation_fraction_split,
+ test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
+ budget_milli_node_hours=budget_milli_node_hours,
+ model_display_name=model_display_name,
+ model_labels=model_labels,
+ disable_early_stopping=disable_early_stopping,
+ sync=sync,
+ )
+ model.wait()
+
+ return model
+
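For image classification the docstring above bounds the budget to 8,000 - 800,000 milli node
hours; a minimal, hypothetical call (identifiers are illustrative; imports and hook as above):

    model = hook.create_auto_ml_image_training_job(
        project_id="my-project",
        region="us-central1",
        display_name="image-classification",
        dataset=datasets.ImageDataset(dataset_name="1234567890"),
        prediction_type="classification",
        model_type="CLOUD",
        budget_milli_node_hours=8000,  # lower bound of the classification range
    )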
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_auto_ml_text_training_job(
+ self,
+ project_id: str,
+ region: str,
+ display_name: str,
+ dataset: datasets.TextDataset,
+ prediction_type: str,
+ multi_label: bool = False,
+ sentiment_max: int = 10,
+ labels: Optional[Dict[str, str]] = None,
+ training_encryption_spec_key_name: Optional[str] = None,
+ model_encryption_spec_key_name: Optional[str] = None,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
+ model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
+ sync: bool = True,
+ ) -> models.Model:
+ """
+ Create an AutoML Text Training Job.
+
+ :param project_id: Required. Project to run training in.
+ :param region: Required. Location to run training in.
+ :param display_name: Required. The user-defined name of this TrainingPipeline.
+ :param dataset: Required. The dataset within the same Project from which data will be used to train
+ the Model. The Dataset must use schema compatible with Model being trained, and what is
+ compatible should be described in the used TrainingPipeline's [training_task_definition]
+ [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
+ :param prediction_type: The type of prediction the Model is to produce, one of:
+ "classification" - A classification model analyzes text data and returns a list of categories
+ that apply to the text found in the data. Vertex AI offers both single-label and multi-label text
+ classification models.
+ "extraction" - An entity extraction model inspects text data for known entities referenced in the
+ data and labels those entities in the text.
+ "sentiment" - A sentiment analysis model inspects text data and identifies the prevailing
+ emotional opinion within it, especially to determine a writer's attitude as positive, negative,
+ or neutral.
+ :param multi_label: Required and only applicable for the text classification task. If false, a
+ single-label (multi-class) Model will be trained (i.e. assuming that for each text snippet at
+ most one annotation may be applicable). If true, a multi-label Model will be trained (i.e.
+ assuming that for each text snippet multiple annotations may be applicable).
+ :param sentiment_max: Required and only applicable for the sentiment task. A sentiment is expressed as an
+ integer ordinal, where higher value means a more positive sentiment. The range of sentiments that
+ will be used is between 0 and sentimentMax (inclusive on both ends), and all the values in the
+ range must be represented in the dataset before a model can be created. Only the Annotations with
+ this sentimentMax will be used for training. sentimentMax value must be between 1 and 10
+ (inclusive).
+ :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label
+ keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
+ lowercase letters, numeric characters, underscores and dashes. International characters are
+ allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
+ :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the training pipeline. Has the form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute resource is created.
+ If set, this TrainingPipeline will be secured by this key.
+ Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload``
+ is not set separately.
+ :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the model. Has the form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute resource is created.
+ If set, the trained Model will be secured by this key.
+ :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
+ :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
+ validate the Model. This is ignored if Dataset is not provided.
+ :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax as the one used in
+ DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
+ FilterSplit filters, then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with the same syntax as the one used in
+ DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
+ FilterSplit filters, then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this
+ filter are used to test the Model. A filter with the same syntax as the one used in
+ DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
+ FilterSplit filters, then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ :param model_display_name: Optional. The display name of the managed Vertex AI Model. The name can be
+ up to 128 characters long and can consist of any UTF-8 characters.
+ If not provided upon creation, the job's display_name is used.
+ :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label
+ keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
+ lowercase letters, numeric characters, underscores and dashes. International characters are
+ allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
+ :param sync: Whether to execute this method synchronously. If False, this method will be executed in
+ a concurrent Future and any downstream object will be immediately returned and synced when the
+ Future has completed.
+ """
+ self._job = self.get_auto_ml_text_training_job(
+ project=project_id,
+ location=region,
+ display_name=display_name,
+ prediction_type=prediction_type,
+ multi_label=multi_label,
+ sentiment_max=sentiment_max,
+ labels=labels,
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
+ )
+
+ if not self._job:
+ raise AirflowException("AutoMLTextTrainingJob was not created")
+
+ model = self._job.run(
+ dataset=dataset,
+ training_fraction_split=training_fraction_split,
+ validation_fraction_split=validation_fraction_split,
+ test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
+ model_display_name=model_display_name,
+ model_labels=model_labels,
+ sync=sync,
+ )
+ model.wait()
+
+ return model
+
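+ # A minimal usage sketch of create_auto_ml_text_training_job (illustrative; the connection
+ # id shown is the provider default, and the project, region and dataset id values are
+ # placeholders, not defaults):
+ #
+ #     from google.cloud.aiplatform import datasets
+ #     from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
+ #
+ #     hook = AutoMLHook(gcp_conn_id="google_cloud_default")
+ #     model = hook.create_auto_ml_text_training_job(
+ #         project_id="my-project",
+ #         region="us-central1",
+ #         display_name="text-classification-job",
+ #         dataset=datasets.TextDataset(dataset_name="my-text-dataset-id"),
+ #         prediction_type="classification",
+ #         multi_label=False,
+ #     )
+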
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_auto_ml_video_training_job(
+ self,
+ project_id: str,
+ region: str,
+ display_name: str,
+ dataset: datasets.VideoDataset,
+ prediction_type: str = "classification",
+ model_type: str = "CLOUD",
+ labels: Optional[Dict[str, str]] = None,
+ training_encryption_spec_key_name: Optional[str] = None,
+ model_encryption_spec_key_name: Optional[str] = None,
+ training_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
+ model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
+ sync: bool = True,
+ ) -> models.Model:
+ """
+ Create an AutoML Video Training Job.
+
+ :param project_id: Required. Project to run training in.
+ :param region: Required. Location to run training in.
+ :param display_name: Required. The user-defined name of this TrainingPipeline.
+ :param dataset: Required. The dataset within the same Project from which data will be used to train
+ the Model. The Dataset must use a schema compatible with the Model being trained; what is
+ compatible is described in the used TrainingPipeline's [training_task_definition]
+ [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. For tabular
+ Datasets, all of their data is exported for training, to pick and choose from.
+ :param prediction_type: The type of prediction the Model is to produce, one of:
+ "classification" - A video classification model classifies shots and segments in your videos
+ according to your own defined labels.
+ "object_tracking" - A video object tracking model detects and tracks multiple objects in shots
+ and segments. You can use these models to track objects in your videos according to your own
+ pre-defined, custom labels.
+ "action_recognition" - A video action recognition model pinpoints the location of actions with
+ short temporal durations (~1 second).
+ :param model_type: Required. One of the following:
+ "CLOUD" - available for "classification", "object_tracking" and "action_recognition" A Model best
+ tailored to be used within Google Cloud, and which cannot be exported.
+ "MOBILE_VERSATILE_1" - available for "classification", "object_tracking" and "action_recognition"
+ A model that, in addition to being available within Google Cloud, can also be exported (see
+ ModelService.ExportModel) as a TensorFlow or TensorFlow Lite model and used on a mobile or edge
+ device with afterwards.
+ "MOBILE_CORAL_VERSATILE_1" - available only for "object_tracking" A versatile model that is meant
+ to be exported (see ModelService.ExportModel) and used on a Google Coral device.
+ "MOBILE_CORAL_LOW_LATENCY_1" - available only for "object_tracking" A model that trades off
+ quality for low latency, to be exported (see ModelService.ExportModel) and used on a Google Coral
+ device.
+ "MOBILE_JETSON_VERSATILE_1" - available only for "object_tracking" A versatile model that is
+ meant to be exported (see ModelService.ExportModel) and used on an NVIDIA Jetson device.
+ "MOBILE_JETSON_LOW_LATENCY_1" - available only for "object_tracking" A model that trades off
+ quality for low latency, to be exported (see ModelService.ExportModel) and used on an NVIDIA
+ Jetson device.
+ :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label
+ keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
+ lowercase letters, numeric characters, underscores and dashes. International characters are
+ allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
+ :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the training pipeline. Has the form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute resource is created.
+ If set, this TrainingPipeline will be secured by this key.
+ Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload``
+ is not set separately.
+ :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the model. Has the form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute resource is created.
+ If set, the trained Model will be secured by this key.
+ :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
+ :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax as the one used in
+ DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
+ FilterSplit filters, then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this
+ filter are used to test the Model. A filter with the same syntax as the one used in
+ DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
+ FilterSplit filters, then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ :param model_display_name: Optional. The display name of the managed Vertex AI Model. The name can be
+ up to 128 characters long and can consist of any UTF-8 characters. If not provided upon
+ creation, the job's display_name is used.
+ :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label
+ keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
+ lowercase letters, numeric characters, underscores and dashes. International characters are
+ allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
+ :param sync: Whether to execute this method synchronously. If False, this method will be executed in
+ a concurrent Future and any downstream object will be immediately returned and synced when the
+ Future has completed.
+ """
+ self._job = self.get_auto_ml_video_training_job(
+ project=project_id,
+ location=region,
+ display_name=display_name,
+ prediction_type=prediction_type,
+ model_type=model_type,
+ labels=labels,
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
+ )
+
+ if not self._job:
+ raise AirflowException("AutoMLVideoTrainingJob was not created")
+
+ model = self._job.run(
+ dataset=dataset,
+ training_fraction_split=training_fraction_split,
+ test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ test_filter_split=test_filter_split,
+ model_display_name=model_display_name,
+ model_labels=model_labels,
+ sync=sync,
+ )
+ model.wait()
+
+ return model
+
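+ # A minimal usage sketch of create_auto_ml_video_training_job (illustrative; all ids are
+ # placeholders). Unlike the image/text variants, video training takes no validation split,
+ # so only training/test fractions appear:
+ #
+ #     from google.cloud.aiplatform import datasets
+ #     from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
+ #
+ #     hook = AutoMLHook(gcp_conn_id="google_cloud_default")
+ #     model = hook.create_auto_ml_video_training_job(
+ #         project_id="my-project",
+ #         region="us-central1",
+ #         display_name="video-classification-job",
+ #         dataset=datasets.VideoDataset(dataset_name="my-video-dataset-id"),
+ #         prediction_type="classification",
+ #         model_type="CLOUD",
+ #         training_fraction_split=0.8,
+ #         test_fraction_split=0.2,
+ #     )
+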
+ @GoogleBaseHook.fallback_to_default_project_id
+ def delete_training_pipeline(
+ self,
+ project_id: str,
+ region: str,
+ training_pipeline: str,
+ retry: Optional[Retry] = None,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> Operation:
+ """
+ Deletes a TrainingPipeline.
+
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+ :param region: Required. The ID of the Google Cloud region that the service belongs to.
+ :param training_pipeline: Required. The name of the TrainingPipeline resource to be deleted.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as metadata.
+ """
+ client = self.get_pipeline_service_client(region)
+ name = client.training_pipeline_path(project_id, region, training_pipeline)
+
+ result = client.delete_training_pipeline(
+ request={
+ 'name': name,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def get_training_pipeline(
+ self,
+ project_id: str,
+ region: str,
+ training_pipeline: str,
+ retry: Optional[Retry] = None,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> TrainingPipeline:
+ """
+ Gets a TrainingPipeline.
+
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+ :param region: Required. The ID of the Google Cloud region that the service belongs to.
+ :param training_pipeline: Required. The name of the TrainingPipeline resource.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as metadata.
+ """
+ client = self.get_pipeline_service_client(region)
+ name = client.training_pipeline_path(project_id, region, training_pipeline)
+
+ result = client.get_training_pipeline(
+ request={
+ 'name': name,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def list_training_pipelines(
+ self,
+ project_id: str,
+ region: str,
+ page_size: Optional[int] = None,
+ page_token: Optional[str] = None,
+ filter: Optional[str] = None,
+ read_mask: Optional[str] = None,
+ retry: Optional[Retry] = None,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> ListTrainingPipelinesPager:
+ """
+ Lists TrainingPipelines in a Location.
+
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+ :param region: Required. The ID of the Google Cloud region that the service belongs to.
+ :param filter: Optional. The standard list filter. Supported fields:
+
+ - ``display_name`` supports = and !=.
+
+ - ``state`` supports = and !=.
+
+ Some examples of using the filter are:
+
+ - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"``
+
+ - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"``
+
+ - ``NOT display_name="my_pipeline"``
+
+ - ``state="PIPELINE_STATE_FAILED"``
+ :param page_size: Optional. The standard list page size.
+ :param page_token: Optional. The standard list page token. Typically obtained via
+ [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token]
+ of the previous
+ [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines]
+ call.
+ :param read_mask: Optional. Mask specifying which fields to read.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as metadata.
+ """
+ client = self.get_pipeline_service_client(region)
+ parent = client.common_location_path(project_id, region)
+
+ result = client.list_training_pipelines(
+ request={
+ 'parent': parent,
+ 'page_size': page_size,
+ 'page_token': page_token,
+ 'filter': filter,
+ 'read_mask': read_mask,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
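+
+ # A minimal usage sketch of the pipeline helpers above (illustrative; the filter string,
+ # project and region are placeholders). The pager yields TrainingPipeline messages:
+ #
+ #     hook = AutoMLHook(gcp_conn_id="google_cloud_default")
+ #     pager = hook.list_training_pipelines(
+ #         project_id="my-project",
+ #         region="us-central1",
+ #         filter='state="PIPELINE_STATE_SUCCEEDED"',
+ #     )
+ #     for pipeline in pager:
+ #         print(pipeline.display_name)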
diff --git a/airflow/providers/google/cloud/links/vertex_ai.py b/airflow/providers/google/cloud/links/vertex_ai.py
new file mode 100644
index 0000000..6ffa21b
--- /dev/null
+++ b/airflow/providers/google/cloud/links/vertex_ai.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains a links for Vertex AI assets."""
+
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+from airflow.models import BaseOperator, BaseOperatorLink
+from airflow.models.xcom import XCom
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai"
+VERTEX_AI_MODEL_LINK = (
+ VERTEX_AI_BASE_LINK + "/locations/{region}/models/{model_id}/deploy?project={project_id}"
+)
+VERTEX_AI_TRAINING_PIPELINES_LINK = VERTEX_AI_BASE_LINK + "/training/training-pipelines?project={project_id}"
+VERTEX_AI_DATASET_LINK = (
+ VERTEX_AI_BASE_LINK + "/locations/{region}/datasets/{dataset_id}/analyze?project={project_id}"
+)
+VERTEX_AI_DATASET_LIST_LINK = VERTEX_AI_BASE_LINK + "/datasets?project={project_id}"
+
+
+class VertexAIModelLink(BaseOperatorLink):
+ """Helper class for constructing Vertex AI Model link"""
+
+ name = "Vertex AI Model"
+ key = "model_conf"
+
+ @staticmethod
+ def persist(
+ context: "Context",
+ task_instance,
+ model_id: str,
+ ):
+ task_instance.xcom_push(
+ context=context,
+ key=VertexAIModelLink.key,
+ value={
+ "model_id": model_id,
+ "region": task_instance.region,
+ "project_id": task_instance.project_id,
+ },
+ )
+
+ def get_link(self, operator: BaseOperator, dttm: datetime):
+ model_conf = XCom.get_one(
+ key=VertexAIModelLink.key,
+ dag_id=operator.dag.dag_id,
+ task_id=operator.task_id,
+ execution_date=dttm,
+ )
+ return (
+ VERTEX_AI_MODEL_LINK.format(
+ region=model_conf["region"],
+ model_id=model_conf["model_id"],
+ project_id=model_conf["project_id"],
+ )
+ if model_conf
+ else ""
+ )
+
+
+class VertexAITrainingPipelinesLink(BaseOperatorLink):
+ """Helper class for constructing Vertex AI Training Pipelines link"""
+
+ name = "Vertex AI Training Pipelines"
+ key = "pipelines_conf"
+
+ @staticmethod
+ def persist(
+ context: "Context",
+ task_instance,
+ ):
+ task_instance.xcom_push(
+ context=context,
+ key=VertexAITrainingPipelinesLink.key,
+ value={
+ "project_id": task_instance.project_id,
+ },
+ )
+
+ def get_link(self, operator: BaseOperator, dttm: datetime):
+ pipelines_conf = XCom.get_one(
+ key=VertexAITrainingPipelinesLink.key,
+ dag_id=operator.dag.dag_id,
+ task_id=operator.task_id,
+ execution_date=dttm,
+ )
+ return (
+ VERTEX_AI_TRAINING_PIPELINES_LINK.format(
+ project_id=pipelines_conf["project_id"],
+ )
+ if pipelines_conf
+ else ""
+ )
+
+
+class VertexAIDatasetLink(BaseOperatorLink):
+ """Helper class for constructing Vertex AI Dataset link"""
+
+ name = "Dataset"
+ key = "dataset_conf"
+
+ @staticmethod
+ def persist(context: "Context", task_instance, dataset_id: str):
+ task_instance.xcom_push(
+ context=context,
+ key=VertexAIDatasetLink.key,
+ value={
+ "dataset_id": dataset_id,
+ "region": task_instance.region,
+ "project_id": task_instance.project_id,
+ },
+ )
+
+ def get_link(self, operator: BaseOperator, dttm: datetime):
+ dataset_conf = XCom.get_one(
+ key=VertexAIDatasetLink.key,
+ dag_id=operator.dag.dag_id,
+ task_id=operator.task_id,
+ execution_date=dttm,
+ )
+ return (
+ VERTEX_AI_DATASET_LINK.format(
+ region=dataset_conf["region"],
+ dataset_id=dataset_conf["dataset_id"],
+ project_id=dataset_conf["project_id"],
+ )
+ if dataset_conf
+ else ""
+ )
+
+
+class VertexAIDatasetListLink(BaseOperatorLink):
+ """Helper class for constructing Vertex AI Datasets Link"""
+
+ name = "Dataset List"
+ key = "datasets_conf"
+
+ @staticmethod
+ def persist(
+ context: "Context",
+ task_instance,
+ ):
+ task_instance.xcom_push(
+ context=context,
+ key=VertexAIDatasetListLink.key,
+ value={
+ "project_id": task_instance.project_id,
+ },
+ )
+
+ def get_link(self, operator: BaseOperator, dttm: datetime):
+ datasets_conf = XCom.get_one(
+ key=VertexAIDatasetListLink.key,
+ dag_id=operator.dag.dag_id,
+ task_id=operator.task_id,
+ execution_date=dttm,
+ )
+ return (
+ VERTEX_AI_DATASET_LIST_LINK.format(
+ project_id=datasets_conf["project_id"],
+ )
+ if datasets_conf
+ else ""
+ )
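+
+ # A sketch of the contract these link classes share (illustrative; the operator name below
+ # is a placeholder): an operator's execute() calls persist() to push the link payload to
+ # XCom, and the webserver later calls get_link(), which reads the payload back and formats
+ # the console URL, or returns "" when nothing was persisted.
+ #
+ #     class MyVertexAIOperator(BaseOperator):
+ #         operator_extra_links = (VertexAIModelLink(),)
+ #
+ #         def execute(self, context):
+ #             model_id = ...  # obtained from the trained model
+ #             VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)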
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
new file mode 100644
index 0000000..370211a
--- /dev/null
+++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
@@ -0,0 +1,623 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+"""This module contains Google Vertex AI operators."""
+
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
+
+from google.api_core.exceptions import NotFound
+from google.api_core.retry import Retry
+from google.cloud.aiplatform import datasets
+from google.cloud.aiplatform.models import Model
+from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
+
+from airflow.models import BaseOperator
+from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
+from airflow.providers.google.cloud.links.vertex_ai import VertexAIModelLink, VertexAITrainingPipelinesLink
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+class AutoMLTrainingJobBaseOperator(BaseOperator):
+ """The base class for operators that launch AutoML jobs on VertexAI."""
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ region: str,
+ display_name: str,
+ labels: Optional[Dict[str, str]] = None,
+ training_encryption_spec_key_name: Optional[str] = None,
+ model_encryption_spec_key_name: Optional[str] = None,
+ # RUN
+ training_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
+ sync: bool = True,
+ gcp_conn_id: str = "google_cloud_default",
+ delegate_to: Optional[str] = None,
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.project_id = project_id
+ self.region = region
+ self.display_name = display_name
+ self.labels = labels
+ self.training_encryption_spec_key_name = training_encryption_spec_key_name
+ self.model_encryption_spec_key_name = model_encryption_spec_key_name
+ # START Run param
+ self.training_fraction_split = training_fraction_split
+ self.test_fraction_split = test_fraction_split
+ self.model_display_name = model_display_name
+ self.model_labels = model_labels
+ self.sync = sync
+ # END Run param
+ self.gcp_conn_id = gcp_conn_id
+ self.delegate_to = delegate_to
+ self.impersonation_chain = impersonation_chain
+ self.hook = None # type: Optional[AutoMLHook]
+
+ def on_kill(self) -> None:
+ """
+ Callback called when the operator is killed.
+ Cancel any running job.
+ """
+ if self.hook:
+ self.hook.cancel_auto_ml_job()
+
+
+class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
+ """Create AutoML Forecasting Training job"""
+
+ template_fields = [
+ 'region',
+ 'impersonation_chain',
+ ]
+ operator_extra_links = (VertexAIModelLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset_id: str,
+ target_column: str,
+ time_column: str,
+ time_series_identifier_column: str,
+ unavailable_at_forecast_columns: List[str],
+ available_at_forecast_columns: List[str],
+ forecast_horizon: int,
+ data_granularity_unit: str,
+ data_granularity_count: int,
+ optimization_objective: Optional[str] = None,
+ column_specs: Optional[Dict[str, str]] = None,
+ column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
+ validation_fraction_split: Optional[float] = None,
+ predefined_split_column_name: Optional[str] = None,
+ weight_column: Optional[str] = None,
+ time_series_attribute_columns: Optional[List[str]] = None,
+ context_window: Optional[int] = None,
+ export_evaluated_data_items: bool = False,
+ export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
+ export_evaluated_data_items_override_destination: bool = False,
+ quantiles: Optional[List[float]] = None,
+ validation_options: Optional[str] = None,
+ budget_milli_node_hours: int = 1000,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.dataset_id = dataset_id
+ self.target_column = target_column
+ self.time_column = time_column
+ self.time_series_identifier_column = time_series_identifier_column
+ self.unavailable_at_forecast_columns = unavailable_at_forecast_columns
+ self.available_at_forecast_columns = available_at_forecast_columns
+ self.forecast_horizon = forecast_horizon
+ self.data_granularity_unit = data_granularity_unit
+ self.data_granularity_count = data_granularity_count
+ self.optimization_objective = optimization_objective
+ self.column_specs = column_specs
+ self.column_transformations = column_transformations
+ self.validation_fraction_split = validation_fraction_split
+ self.predefined_split_column_name = predefined_split_column_name
+ self.weight_column = weight_column
+ self.time_series_attribute_columns = time_series_attribute_columns
+ self.context_window = context_window
+ self.export_evaluated_data_items = export_evaluated_data_items
+ self.export_evaluated_data_items_bigquery_destination_uri = (
+ export_evaluated_data_items_bigquery_destination_uri
+ )
+ self.export_evaluated_data_items_override_destination = (
+ export_evaluated_data_items_override_destination
+ )
+ self.quantiles = quantiles
+ self.validation_options = validation_options
+ self.budget_milli_node_hours = budget_milli_node_hours
+
+ def execute(self, context: "Context"):
+ self.hook = AutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ impersonation_chain=self.impersonation_chain,
+ )
+ model = self.hook.create_auto_ml_forecasting_training_job(
+ project_id=self.project_id,
+ region=self.region,
+ display_name=self.display_name,
+ dataset=datasets.TimeSeriesDataset(dataset_name=self.dataset_id),
+ target_column=self.target_column,
+ time_column=self.time_column,
+ time_series_identifier_column=self.time_series_identifier_column,
+ unavailable_at_forecast_columns=self.unavailable_at_forecast_columns,
+ available_at_forecast_columns=self.available_at_forecast_columns,
+ forecast_horizon=self.forecast_horizon,
+ data_granularity_unit=self.data_granularity_unit,
+ data_granularity_count=self.data_granularity_count,
+ optimization_objective=self.optimization_objective,
+ column_specs=self.column_specs,
+ column_transformations=self.column_transformations,
+ labels=self.labels,
+ training_encryption_spec_key_name=self.training_encryption_spec_key_name,
+ model_encryption_spec_key_name=self.model_encryption_spec_key_name,
+ training_fraction_split=self.training_fraction_split,
+ validation_fraction_split=self.validation_fraction_split,
+ test_fraction_split=self.test_fraction_split,
+ predefined_split_column_name=self.predefined_split_column_name,
+ weight_column=self.weight_column,
+ time_series_attribute_columns=self.time_series_attribute_columns,
+ context_window=self.context_window,
+ export_evaluated_data_items=self.export_evaluated_data_items,
+ export_evaluated_data_items_bigquery_destination_uri=(
+ self.export_evaluated_data_items_bigquery_destination_uri
+ ),
+ export_evaluated_data_items_override_destination=(
+ self.export_evaluated_data_items_override_destination
+ ),
+ quantiles=self.quantiles,
+ validation_options=self.validation_options,
+ budget_milli_node_hours=self.budget_milli_node_hours,
+ model_display_name=self.model_display_name,
+ model_labels=self.model_labels,
+ sync=self.sync,
+ )
+
+ result = Model.to_dict(model)
+ model_id = self.hook.extract_model_id(result)
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+ return result
+
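+ # A minimal usage sketch of the forecasting operator above (illustrative; every id and
+ # column name below is a placeholder, not a default):
+ #
+ #     create_forecast_training_job = CreateAutoMLForecastingTrainingJobOperator(
+ #         task_id="auto_ml_forecasting_task",
+ #         project_id="my-project",
+ #         region="us-central1",
+ #         display_name="forecast-job",
+ #         dataset_id="my-time-series-dataset-id",
+ #         target_column="sales",
+ #         time_column="date",
+ #         time_series_identifier_column="store_id",
+ #         unavailable_at_forecast_columns=["sales"],
+ #         available_at_forecast_columns=["date"],
+ #         forecast_horizon=30,
+ #         data_granularity_unit="day",
+ #         data_granularity_count=1,
+ #     )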
+
+class CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
+ """Create Auto ML Image Training job"""
+
+ template_fields = [
+ 'region',
+ 'impersonation_chain',
+ ]
+ operator_extra_links = (VertexAIModelLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset_id: str,
+ prediction_type: str = "classification",
+ multi_label: bool = False,
+ model_type: str = "CLOUD",
+ base_model: Optional[Model] = None,
+ validation_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
+ budget_milli_node_hours: Optional[int] = None,
+ disable_early_stopping: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.dataset_id = dataset_id
+ self.prediction_type = prediction_type
+ self.multi_label = multi_label
+ self.model_type = model_type
+ self.base_model = base_model
+ self.validation_fraction_split = validation_fraction_split
+ self.training_filter_split = training_filter_split
+ self.validation_filter_split = validation_filter_split
+ self.test_filter_split = test_filter_split
+ self.budget_milli_node_hours = budget_milli_node_hours
+ self.disable_early_stopping = disable_early_stopping
+
+ def execute(self, context: "Context"):
+ self.hook = AutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ impersonation_chain=self.impersonation_chain,
+ )
+ model = self.hook.create_auto_ml_image_training_job(
+ project_id=self.project_id,
+ region=self.region,
+ display_name=self.display_name,
+ dataset=datasets.ImageDataset(dataset_name=self.dataset_id),
+ prediction_type=self.prediction_type,
+ multi_label=self.multi_label,
+ model_type=self.model_type,
+ base_model=self.base_model,
+ labels=self.labels,
+ training_encryption_spec_key_name=self.training_encryption_spec_key_name,
+ model_encryption_spec_key_name=self.model_encryption_spec_key_name,
+ training_fraction_split=self.training_fraction_split,
+ validation_fraction_split=self.validation_fraction_split,
+ test_fraction_split=self.test_fraction_split,
+ training_filter_split=self.training_filter_split,
+ validation_filter_split=self.validation_filter_split,
+ test_filter_split=self.test_filter_split,
+ budget_milli_node_hours=self.budget_milli_node_hours,
+ model_display_name=self.model_display_name,
+ model_labels=self.model_labels,
+ disable_early_stopping=self.disable_early_stopping,
+ sync=self.sync,
+ )
+
+ result = Model.to_dict(model)
+ model_id = self.hook.extract_model_id(result)
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+ return result
+
+
+class CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
+ """Create Auto ML Tabular Training job"""
+
+ template_fields = [
+ 'region',
+ 'impersonation_chain',
+ ]
+ operator_extra_links = (VertexAIModelLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset_id: str,
+ target_column: str,
+ optimization_prediction_type: str,
+ optimization_objective: Optional[str] = None,
+ column_specs: Optional[Dict[str, str]] = None,
+ column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
+ optimization_objective_recall_value: Optional[float] = None,
+ optimization_objective_precision_value: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
+ weight_column: Optional[str] = None,
+ budget_milli_node_hours: int = 1000,
+ disable_early_stopping: bool = False,
+ export_evaluated_data_items: bool = False,
+ export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
+ export_evaluated_data_items_override_destination: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.dataset_id = dataset_id
+ self.target_column = target_column
+ self.optimization_prediction_type = optimization_prediction_type
+ self.optimization_objective = optimization_objective
+ self.column_specs = column_specs
+ self.column_transformations = column_transformations
+ self.optimization_objective_recall_value = optimization_objective_recall_value
+ self.optimization_objective_precision_value = optimization_objective_precision_value
+ self.validation_fraction_split = validation_fraction_split
+ self.predefined_split_column_name = predefined_split_column_name
+ self.timestamp_split_column_name = timestamp_split_column_name
+ self.weight_column = weight_column
+ self.budget_milli_node_hours = budget_milli_node_hours
+ self.disable_early_stopping = disable_early_stopping
+ self.export_evaluated_data_items = export_evaluated_data_items
+ self.export_evaluated_data_items_bigquery_destination_uri = (
+ export_evaluated_data_items_bigquery_destination_uri
+ )
+ self.export_evaluated_data_items_override_destination = (
+ export_evaluated_data_items_override_destination
+ )
+
+ def execute(self, context: "Context"):
+ self.hook = AutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ impersonation_chain=self.impersonation_chain,
+ )
+ model = self.hook.create_auto_ml_tabular_training_job(
+ project_id=self.project_id,
+ region=self.region,
+ display_name=self.display_name,
+ dataset=datasets.TabularDataset(dataset_name=self.dataset_id),
+ target_column=self.target_column,
+ optimization_prediction_type=self.optimization_prediction_type,
+ optimization_objective=self.optimization_objective,
+ column_specs=self.column_specs,
+ column_transformations=self.column_transformations,
+ optimization_objective_recall_value=self.optimization_objective_recall_value,
+ optimization_objective_precision_value=self.optimization_objective_precision_value,
+ labels=self.labels,
+ training_encryption_spec_key_name=self.training_encryption_spec_key_name,
+ model_encryption_spec_key_name=self.model_encryption_spec_key_name,
+ training_fraction_split=self.training_fraction_split,
+ validation_fraction_split=self.validation_fraction_split,
+ test_fraction_split=self.test_fraction_split,
+ predefined_split_column_name=self.predefined_split_column_name,
+ timestamp_split_column_name=self.timestamp_split_column_name,
+ weight_column=self.weight_column,
+ budget_milli_node_hours=self.budget_milli_node_hours,
+ model_display_name=self.model_display_name,
+ model_labels=self.model_labels,
+ disable_early_stopping=self.disable_early_stopping,
+ export_evaluated_data_items=self.export_evaluated_data_items,
+ export_evaluated_data_items_bigquery_destination_uri=(
+ self.export_evaluated_data_items_bigquery_destination_uri
+ ),
+ export_evaluated_data_items_override_destination=(
+ self.export_evaluated_data_items_override_destination
+ ),
+ sync=self.sync,
+ )
+
+ result = Model.to_dict(model)
+ model_id = self.hook.extract_model_id(result)
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+ return result
+
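+ # A minimal usage sketch of the tabular operator above (illustrative; ids and column names
+ # are placeholders). column_transformations mirrors the List[Dict[str, Dict[str, str]]]
+ # shape declared in __init__:
+ #
+ #     create_tabular_training_job = CreateAutoMLTabularTrainingJobOperator(
+ #         task_id="auto_ml_tabular_task",
+ #         project_id="my-project",
+ #         region="us-central1",
+ #         display_name="tabular-job",
+ #         dataset_id="my-tabular-dataset-id",
+ #         target_column="target",
+ #         optimization_prediction_type="classification",
+ #         column_transformations=[
+ #             {"categorical": {"column_name": "type"}},
+ #             {"numeric": {"column_name": "age"}},
+ #         ],
+ #     )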
+
+class CreateAutoMLTextTrainingJobOperator(AutoMLTrainingJobBaseOperator):
+ """Create Auto ML Text Training job"""
+
+ template_fields = [
+ 'region',
+ 'impersonation_chain',
+ ]
+ operator_extra_links = (VertexAIModelLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset_id: str,
+ prediction_type: str,
+ multi_label: bool = False,
+ sentiment_max: int = 10,
+ validation_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.dataset_id = dataset_id
+ self.prediction_type = prediction_type
+ self.multi_label = multi_label
+ self.sentiment_max = sentiment_max
+ self.validation_fraction_split = validation_fraction_split
+ self.training_filter_split = training_filter_split
+ self.validation_filter_split = validation_filter_split
+ self.test_filter_split = test_filter_split
+
+ def execute(self, context: "Context"):
+ self.hook = AutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ impersonation_chain=self.impersonation_chain,
+ )
+ model = self.hook.create_auto_ml_text_training_job(
+ project_id=self.project_id,
+ region=self.region,
+ display_name=self.display_name,
+ dataset=datasets.TextDataset(dataset_name=self.dataset_id),
+ prediction_type=self.prediction_type,
+ multi_label=self.multi_label,
+ sentiment_max=self.sentiment_max,
+ labels=self.labels,
+ training_encryption_spec_key_name=self.training_encryption_spec_key_name,
+ model_encryption_spec_key_name=self.model_encryption_spec_key_name,
+ training_fraction_split=self.training_fraction_split,
+ validation_fraction_split=self.validation_fraction_split,
+ test_fraction_split=self.test_fraction_split,
+ training_filter_split=self.training_filter_split,
+ validation_filter_split=self.validation_filter_split,
+ test_filter_split=self.test_filter_split,
+ model_display_name=self.model_display_name,
+ model_labels=self.model_labels,
+ sync=self.sync,
+ )
+
+ result = Model.to_dict(model)
+ model_id = self.hook.extract_model_id(result)
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+ return result
+
+
+class CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
+ """Create Auto ML Video Training job"""
+
+ template_fields = [
+ 'region',
+ 'impersonation_chain',
+ ]
+ operator_extra_links = (VertexAIModelLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset_id: str,
+ prediction_type: str = "classification",
+ model_type: str = "CLOUD",
+ training_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.dataset_id = dataset_id
+ self.prediction_type = prediction_type
+ self.model_type = model_type
+ self.training_filter_split = training_filter_split
+ self.test_filter_split = test_filter_split
+
+ def execute(self, context: "Context"):
+ self.hook = AutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ impersonation_chain=self.impersonation_chain,
+ )
+ model = self.hook.create_auto_ml_video_training_job(
+ project_id=self.project_id,
+ region=self.region,
+ display_name=self.display_name,
+ dataset=datasets.VideoDataset(dataset_name=self.dataset_id),
+ prediction_type=self.prediction_type,
+ model_type=self.model_type,
+ labels=self.labels,
+ training_encryption_spec_key_name=self.training_encryption_spec_key_name,
+ model_encryption_spec_key_name=self.model_encryption_spec_key_name,
+ training_fraction_split=self.training_fraction_split,
+ test_fraction_split=self.test_fraction_split,
+ training_filter_split=self.training_filter_split,
+ test_filter_split=self.test_filter_split,
+ model_display_name=self.model_display_name,
+ model_labels=self.model_labels,
+ sync=self.sync,
+ )
+
+ result = Model.to_dict(model)
+ model_id = self.hook.extract_model_id(result)
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+ return result
+
+
+class DeleteAutoMLTrainingJobOperator(BaseOperator):
+ """Deletes an AutoMLForecastingTrainingJob, AutoMLImageTrainingJob, AutoMLTabularTrainingJob,
+ AutoMLTextTrainingJob, or AutoMLVideoTrainingJob.
+ """
+
+ template_fields = ("region", "project_id", "impersonation_chain")
+
+ def __init__(
+ self,
+ *,
+ training_pipeline_id: str,
+ region: str,
+ project_id: str,
+ retry: Optional[Retry] = None,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ gcp_conn_id: str = "google_cloud_default",
+ delegate_to: Optional[str] = None,
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.training_pipeline = training_pipeline_id
+ self.region = region
+ self.project_id = project_id
+ self.retry = retry
+ self.timeout = timeout
+ self.metadata = metadata
+ self.gcp_conn_id = gcp_conn_id
+ self.delegate_to = delegate_to
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: "Context"):
+ hook = AutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ impersonation_chain=self.impersonation_chain,
+ )
+ try:
+ self.log.info("Deleting Auto ML training pipeline: %s", self.training_pipeline)
+ training_pipeline_operation = hook.delete_training_pipeline(
+ training_pipeline=self.training_pipeline,
+ region=self.region,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ hook.wait_for_operation(timeout=self.timeout, operation=training_pipeline_operation)
+ self.log.info("Training pipeline was deleted.")
+ except NotFound:
+ self.log.info("The Training Pipeline ID %s does not exist.", self.training_pipeline)
+
+
+class ListAutoMLTrainingJobOperator(BaseOperator):
+ """Lists AutoMLForecastingTrainingJob, AutoMLImageTrainingJob, AutoMLTabularTrainingJob,
+ AutoMLTextTrainingJob, or AutoMLVideoTrainingJob in a Location.
+ """
+
+ template_fields = [
+ "region",
+ "project_id",
+ "impersonation_chain",
+ ]
+ operator_extra_links = [
+ VertexAITrainingPipelinesLink(),
+ ]
+
+ def __init__(
+ self,
+ *,
+ region: str,
+ project_id: str,
+ page_size: Optional[int] = None,
+ page_token: Optional[str] = None,
+ filter: Optional[str] = None,
+ read_mask: Optional[str] = None,
+ retry: Optional[Retry] = None,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ gcp_conn_id: str = "google_cloud_default",
+ delegate_to: Optional[str] = None,
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.region = region
+ self.project_id = project_id
+ self.page_size = page_size
+ self.page_token = page_token
+ self.filter = filter
+ self.read_mask = read_mask
+ self.retry = retry
+ self.timeout = timeout
+ self.metadata = metadata
+ self.gcp_conn_id = gcp_conn_id
+ self.delegate_to = delegate_to
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: "Context"):
+ hook = AutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ impersonation_chain=self.impersonation_chain,
+ )
+ results = hook.list_training_pipelines(
+ region=self.region,
+ project_id=self.project_id,
+ page_size=self.page_size,
+ page_token=self.page_token,
+ filter=self.filter,
+ read_mask=self.read_mask,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ VertexAITrainingPipelinesLink.persist(context=context, task_instance=self)
+ return [TrainingPipeline.to_dict(result) for result in results]
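+
+ # A minimal usage sketch of the list operator above (illustrative; project, region and the
+ # filter string are placeholders). The returned XCom value is a list of TrainingPipeline
+ # dicts:
+ #
+ #     list_training_jobs = ListAutoMLTrainingJobOperator(
+ #         task_id="list_auto_ml_training_jobs",
+ #         project_id="my-project",
+ #         region="us-central1",
+ #         filter='state="PIPELINE_STATE_SUCCEEDED"',
+ #     )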
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
index 822cf12..1f52c2c 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
@@ -26,57 +26,13 @@ from google.cloud.aiplatform.models import Model
from google.cloud.aiplatform_v1.types.dataset import Dataset
from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
-from airflow.models import BaseOperator, BaseOperatorLink
-from airflow.models.xcom import XCom
+from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook
+from airflow.providers.google.cloud.links.vertex_ai import VertexAIModelLink, VertexAITrainingPipelinesLink
if TYPE_CHECKING:
from airflow.utils.context import Context
-VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai"
-VERTEX_AI_MODEL_LINK = (
- VERTEX_AI_BASE_LINK + "/locations/{region}/models/{model_id}/deploy?project={project_id}"
-)
-VERTEX_AI_TRAINING_PIPELINES_LINK = VERTEX_AI_BASE_LINK + "/training/training-pipelines?project={project_id}"
-
-
-class VertexAIModelLink(BaseOperatorLink):
- """Helper class for constructing Vertex AI Model link"""
-
- name = "Vertex AI Model"
-
- def get_link(self, operator, dttm):
- model_conf = XCom.get_one(
- key='model_conf', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
- )
- return (
- VERTEX_AI_MODEL_LINK.format(
- region=model_conf["region"],
- model_id=model_conf["model_id"],
- project_id=model_conf["project_id"],
- )
- if model_conf
- else ""
- )
-
-
-class VertexAITrainingPipelinesLink(BaseOperatorLink):
- """Helper class for constructing Vertex AI Training Pipelines link"""
-
- name = "Vertex AI Training Pipelines"
-
- def get_link(self, operator, dttm):
- project_id = XCom.get_one(
- key='project_id', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
- )
- return (
- VERTEX_AI_TRAINING_PIPELINES_LINK.format(
- project_id=project_id,
- )
- if project_id
- else ""
- )
-
class CustomTrainingJobBaseOperator(BaseOperator):
"""The base class for operators that launch Custom jobs on VertexAI."""
@@ -465,7 +421,7 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
super().__init__(**kwargs)
self.command = command
- def execute(self, context: 'Context'):
+ def execute(self, context: "Context"):
self.hook = CustomJobHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -523,15 +479,7 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
- self.xcom_push(
- context,
- key="model_conf",
- value={
- "model_id": model_id,
- "region": self.region,
- "project_id": self.project_id,
- },
- )
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
return result
def on_kill(self) -> None:
@@ -819,7 +767,7 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
self.python_package_gcs_uri = python_package_gcs_uri
self.python_module_name = python_module_name
- def execute(self, context: 'Context'):
+ def execute(self, context: "Context"):
self.hook = CustomJobHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -878,15 +826,7 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
- self.xcom_push(
- context,
- key="model_conf",
- value={
- "model_id": model_id,
- "region": self.region,
- "project_id": self.project_id,
- },
- )
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
return result
def on_kill(self) -> None:
@@ -1176,7 +1116,7 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
self.requirements = requirements
self.script_path = script_path
- def execute(self, context: 'Context'):
+ def execute(self, context: "Context"):
self.hook = CustomJobHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -1235,15 +1175,7 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
- self.xcom_push(
- context,
- key="model_conf",
- value={
- "model_id": model_id,
- "region": self.region,
- "project_id": self.project_id,
- },
- )
+ VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
return result
def on_kill(self) -> None:
@@ -1308,7 +1240,7 @@ class DeleteCustomTrainingJobOperator(BaseOperator):
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
- def execute(self, context: 'Context'):
+ def execute(self, context: "Context"):
hook = CustomJobHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -1428,7 +1360,7 @@ class ListCustomTrainingJobOperator(BaseOperator):
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
- def execute(self, context: 'Context'):
+ def execute(self, context: "Context"):
hook = CustomJobHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -1445,5 +1377,5 @@ class ListCustomTrainingJobOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- self.xcom_push(context, key="project_id", value=self.project_id)
+ VertexAITrainingPipelinesLink.persist(context=context, task_instance=self)
return [TrainingPipeline.to_dict(result) for result in results]
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
index 34fc46b..ff409ab 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
@@ -25,57 +25,13 @@ from google.api_core.retry import Retry
from google.cloud.aiplatform_v1.types import Dataset, ExportDataConfig, ImportDataConfig
from google.protobuf.field_mask_pb2 import FieldMask
-from airflow.models import BaseOperator, BaseOperatorLink
-from airflow.models.xcom import XCom
+from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook
+from airflow.providers.google.cloud.links.vertex_ai import VertexAIDatasetLink, VertexAIDatasetListLink
if TYPE_CHECKING:
from airflow.utils.context import Context
-VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai"
-VERTEX_AI_DATASET_LINK = (
- VERTEX_AI_BASE_LINK + "/locations/{region}/datasets/{dataset_id}/analyze?project={project_id}"
-)
-VERTEX_AI_DATASET_LIST_LINK = VERTEX_AI_BASE_LINK + "/datasets?project={project_id}"
-
-
-class VertexAIDatasetLink(BaseOperatorLink):
- """Helper class for constructing Vertex AI Dataset link"""
-
- name = "Dataset"
-
- def get_link(self, operator, dttm):
- dataset_conf = XCom.get_one(
- key='dataset_conf', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
- )
- return (
- VERTEX_AI_DATASET_LINK.format(
- region=dataset_conf["region"],
- dataset_id=dataset_conf["dataset_id"],
- project_id=dataset_conf["project_id"],
- )
- if dataset_conf
- else ""
- )
-
-
-class VertexAIDatasetListLink(BaseOperatorLink):
- """Helper class for constructing Vertex AI Datasets Link"""
-
- name = "Dataset List"
-
- def get_link(self, operator, dttm):
- project_id = XCom.get_one(
- key='project_id', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
- )
- return (
- VERTEX_AI_DATASET_LIST_LINK.format(
- project_id=project_id,
- )
- if project_id
- else ""
- )
-
class CreateDatasetOperator(BaseOperator):
"""
@@ -130,7 +86,7 @@ class CreateDatasetOperator(BaseOperator):
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
- def execute(self, context: 'Context'):
+ def execute(self, context: "Context"):
hook = DatasetHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -153,15 +109,7 @@ class CreateDatasetOperator(BaseOperator):
self.log.info("Dataset was created. Dataset id: %s", dataset_id)
self.xcom_push(context, key="dataset_id", value=dataset_id)
- self.xcom_push(
- context,
- key="dataset_conf",
- value={
- "dataset_id": dataset_id,
- "region": self.region,
- "project_id": self.project_id,
- },
- )
+ VertexAIDatasetLink.persist(context=context, task_instance=self, dataset_id=dataset_id)
return dataset
@@ -219,7 +167,7 @@ class GetDatasetOperator(BaseOperator):
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
- def execute(self, context: 'Context'):
+ def execute(self, context: "Context"):
hook = DatasetHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -237,15 +185,7 @@ class GetDatasetOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- self.xcom_push(
- context,
- key="dataset_conf",
- value={
- "dataset_id": self.dataset_id,
- "project_id": self.project_id,
- "region": self.region,
- },
- )
+ VertexAIDatasetLink.persist(context=context, task_instance=self, dataset_id=self.dataset_id)
self.log.info("Dataset was gotten.")
return Dataset.to_dict(dataset_obj)
except NotFound:
@@ -303,7 +243,7 @@ class DeleteDatasetOperator(BaseOperator):
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
- def execute(self, context: 'Context'):
+ def execute(self, context: "Context"):
hook = DatasetHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -380,7 +320,7 @@ class ExportDataOperator(BaseOperator):
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
- def execute(self, context: 'Context'):
+ def execute(self, context: "Context"):
hook = DatasetHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -456,7 +396,7 @@ class ImportDataOperator(BaseOperator):
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
- def execute(self, context: 'Context'):
+ def execute(self, context: "Context"):
hook = DatasetHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -542,7 +482,7 @@ class ListDatasetsOperator(BaseOperator):
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
- def execute(self, context: 'Context'):
+ def execute(self, context: "Context"):
hook = DatasetHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -560,11 +500,7 @@ class ListDatasetsOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- self.xcom_push(
- context,
- key="project_id",
- value=self.project_id,
- )
+ VertexAIDatasetListLink.persist(context=context, task_instance=self)
return [Dataset.to_dict(result) for result in results]
@@ -625,7 +561,7 @@ class UpdateDatasetOperator(BaseOperator):
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
- def execute(self, context: 'Context'):
+ def execute(self, context: "Context"):
hook = DatasetHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 873170f..2034e81 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -482,6 +482,7 @@ operators:
python-modules:
- airflow.providers.google.cloud.operators.vertex_ai.dataset
- airflow.providers.google.cloud.operators.vertex_ai.custom_job
+ - airflow.providers.google.cloud.operators.vertex_ai.auto_ml
sensors:
- integration-name: Google BigQuery
@@ -683,6 +684,7 @@ hooks:
python-modules:
- airflow.providers.google.cloud.hooks.vertex_ai.dataset
- airflow.providers.google.cloud.hooks.vertex_ai.custom_job
+ - airflow.providers.google.cloud.hooks.vertex_ai.auto_ml
transfers:
- source-integration-name: Presto
@@ -839,10 +841,10 @@ extra-links:
- airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink
- airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreDetailedLink
- airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreLink
- - airflow.providers.google.cloud.operators.vertex_ai.custom_job.VertexAIModelLink
- - airflow.providers.google.cloud.operators.vertex_ai.custom_job.VertexAITrainingPipelinesLink
- - airflow.providers.google.cloud.operators.vertex_ai.dataset.VertexAIDatasetLink
- - airflow.providers.google.cloud.operators.vertex_ai.dataset.VertexAIDatasetListLink
+ - airflow.providers.google.cloud.links.vertex_ai.VertexAIModelLink
+ - airflow.providers.google.cloud.links.vertex_ai.VertexAITrainingPipelinesLink
+ - airflow.providers.google.cloud.links.vertex_ai.VertexAIDatasetLink
+ - airflow.providers.google.cloud.links.vertex_ai.VertexAIDatasetListLink
- airflow.providers.google.cloud.operators.cloud_composer.CloudComposerEnvironmentLink
- airflow.providers.google.cloud.operators.cloud_composer.CloudComposerEnvironmentsLink
- airflow.providers.google.cloud.links.dataflow.DataflowJobLink
diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
index 92c22af..3a95ef9 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
@@ -164,6 +164,96 @@ If you wish to delete a Custom Training Job you can use
:start-after: [START how_to_cloud_vertex_ai_delete_custom_training_job_operator]
:end-before: [END how_to_cloud_vertex_ai_delete_custom_training_job_operator]
+Creating AutoML Training Jobs
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To create Google Vertex AI AutoML training jobs you have five operators:
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLForecastingTrainingJobOperator`,
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator`,
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTabularTrainingJobOperator`,
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`, and
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator`.
+Each of them waits for the operation to complete, and the result of each operator is the model
+trained on the provided dataset.
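+
+All five operators share the same construction shape. A minimal sketch (the values
+below are placeholders, and the operator is assumed to run inside a DAG context):
+
+.. code-block:: python
+
+    from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
+        CreateAutoMLImageTrainingJobOperator,
+    )
+
+    create_auto_ml_image_training_job = CreateAutoMLImageTrainingJobOperator(
+        task_id="auto_ml_image_task",
+        display_name="my-image-model",  # placeholder display name
+        dataset_id="your-image-dataset-id",  # id of a previously created Image dataset
+        prediction_type="classification",
+        multi_label=False,
+        model_type="CLOUD",
+        region="us-central1",  # placeholder region
+        project_id="my-project",  # placeholder project id
+    )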
+
+How to run an AutoML Forecasting Training Job:
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLForecastingTrainingJobOperator`
+
+Before running this job you must prepare and create a ``TimeSeries`` dataset. After that you should
+pass the dataset id to the operator's ``dataset_id`` parameter.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_vertex_ai_create_auto_ml_forecasting_training_job_operator]
+ :end-before: [END how_to_cloud_vertex_ai_create_auto_ml_forecasting_training_job_operator]
+
+How to run an AutoML Image Training Job:
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator`
+
+Before running this job you must prepare and create an ``Image`` dataset. After that you should
+pass the dataset id to the operator's ``dataset_id`` parameter.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_vertex_ai_create_auto_ml_image_training_job_operator]
+ :end-before: [END how_to_cloud_vertex_ai_create_auto_ml_image_training_job_operator]
+
+How to run an AutoML Tabular Training Job:
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTabularTrainingJobOperator`
+
+Before running this job you must prepare and create a ``Tabular`` dataset. After that you should
+pass the dataset id to the operator's ``dataset_id`` parameter.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_vertex_ai_create_auto_ml_tabular_training_job_operator]
+ :end-before: [END how_to_cloud_vertex_ai_create_auto_ml_tabular_training_job_operator]
+
+How to run an AutoML Text Training Job:
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`
+
+Before running this job you must prepare and create a ``Text`` dataset. After that you should
+pass the dataset id to the operator's ``dataset_id`` parameter.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_vertex_ai_create_auto_ml_text_training_job_operator]
+ :end-before: [END how_to_cloud_vertex_ai_create_auto_ml_text_training_job_operator]
+
+How to run an AutoML Video Training Job:
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator`
+
+Before running this job you must prepare and create a ``Video`` dataset. After that you should
+pass the dataset id to the operator's ``dataset_id`` parameter.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_vertex_ai_create_auto_ml_video_training_job_operator]
+ :end-before: [END how_to_cloud_vertex_ai_create_auto_ml_video_training_job_operator]
+
+You can get a list of AutoML Training Jobs using
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.ListAutoMLTrainingJobOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_vertex_ai_list_auto_ml_training_job_operator]
+ :end-before: [END how_to_cloud_vertex_ai_list_auto_ml_training_job_operator]
+
+If you wish to delete an AutoML Training Job you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.DeleteAutoMLTrainingJobOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_vertex_ai_delete_auto_ml_training_job_operator]
+ :end-before: [END how_to_cloud_vertex_ai_delete_auto_ml_training_job_operator]
+
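+A typical cleanup chain first lists the training pipelines and then deletes the one
+that is no longer needed. A sketch (the task ids and all values are placeholders):
+
+.. code-block:: python
+
+    from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
+        DeleteAutoMLTrainingJobOperator,
+        ListAutoMLTrainingJobOperator,
+    )
+
+    list_auto_ml_training_job = ListAutoMLTrainingJobOperator(
+        task_id="list_auto_ml_training_job",
+        region="us-central1",  # placeholder region
+        project_id="my-project",  # placeholder project id
+    )
+
+    delete_auto_ml_training_job = DeleteAutoMLTrainingJobOperator(
+        task_id="delete_auto_ml_training_job",
+        training_pipeline_id="your-training-pipeline-id",  # id of the pipeline to remove
+        region="us-central1",
+        project_id="my-project",
+    )
+
+    list_auto_ml_training_job >> delete_auto_ml_training_job
+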
Reference
^^^^^^^^^
diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_auto_ml.py b/tests/providers/google/cloud/hooks/vertex_ai/test_auto_ml.py
new file mode 100644
index 0000000..2d9346d
--- /dev/null
+++ b/tests/providers/google/cloud/hooks/vertex_ai/test_auto_ml.py
@@ -0,0 +1,175 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from unittest import TestCase, mock
+
+from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook
+from tests.providers.google.cloud.utils.base_gcp_mock import (
+ mock_base_gcp_hook_default_project_id,
+ mock_base_gcp_hook_no_default_project_id,
+)
+
+TEST_GCP_CONN_ID: str = "test-gcp-conn-id"
+TEST_REGION: str = "test-region"
+TEST_PROJECT_ID: str = "test-project-id"
+TEST_PIPELINE_JOB: dict = {}
+TEST_PIPELINE_JOB_ID: str = "test-pipeline-job-id"
+TEST_TRAINING_PIPELINE: dict = {}
+TEST_TRAINING_PIPELINE_NAME: str = "test-training-pipeline"
+
+BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
+CUSTOM_JOB_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.auto_ml.{}"
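+# Note: despite its name, CUSTOM_JOB_STRING targets the auto_ml hook module under test.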
+
+
+class TestAutoMLWithDefaultProjectIdHook(TestCase):
+ def setUp(self):
+ with mock.patch(
+ BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id
+ ):
+ self.hook = AutoMLHook(gcp_conn_id=TEST_GCP_CONN_ID)
+
+ @mock.patch(CUSTOM_JOB_STRING.format("AutoMLHook.get_pipeline_service_client"))
+ def test_delete_training_pipeline(self, mock_client) -> None:
+ self.hook.delete_training_pipeline(
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ training_pipeline=TEST_TRAINING_PIPELINE_NAME,
+ )
+ mock_client.assert_called_once_with(TEST_REGION)
+ mock_client.return_value.delete_training_pipeline.assert_called_once_with(
+ request=dict(
+ name=mock_client.return_value.training_pipeline_path.return_value,
+ ),
+ metadata=(),
+ retry=None,
+ timeout=None,
+ )
+ mock_client.return_value.training_pipeline_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME
+ )
+
+ @mock.patch(CUSTOM_JOB_STRING.format("AutoMLHook.get_pipeline_service_client"))
+ def test_get_training_pipeline(self, mock_client) -> None:
+ self.hook.get_training_pipeline(
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ training_pipeline=TEST_TRAINING_PIPELINE_NAME,
+ )
+ mock_client.assert_called_once_with(TEST_REGION)
+ mock_client.return_value.get_training_pipeline.assert_called_once_with(
+ request=dict(
+ name=mock_client.return_value.training_pipeline_path.return_value,
+ ),
+ metadata=(),
+ retry=None,
+ timeout=None,
+ )
+ mock_client.return_value.training_pipeline_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME
+ )
+
+ @mock.patch(CUSTOM_JOB_STRING.format("AutoMLHook.get_pipeline_service_client"))
+ def test_list_training_pipelines(self, mock_client) -> None:
+ self.hook.list_training_pipelines(
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ )
+ mock_client.assert_called_once_with(TEST_REGION)
+ mock_client.return_value.list_training_pipelines.assert_called_once_with(
+ request=dict(
+ parent=mock_client.return_value.common_location_path.return_value,
+ page_size=None,
+ page_token=None,
+ filter=None,
+ read_mask=None,
+ ),
+ metadata=(),
+ retry=None,
+ timeout=None,
+ )
+ mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
+
+
+class TestAutoMLWithoutDefaultProjectIdHook(TestCase):
+ def setUp(self):
+ with mock.patch(
+ BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_no_default_project_id
+ ):
+ self.hook = AutoMLHook(gcp_conn_id=TEST_GCP_CONN_ID)
+
+ @mock.patch(CUSTOM_JOB_STRING.format("AutoMLHook.get_pipeline_service_client"))
+ def test_delete_training_pipeline(self, mock_client) -> None:
+ self.hook.delete_training_pipeline(
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ training_pipeline=TEST_TRAINING_PIPELINE_NAME,
+ )
+ mock_client.assert_called_once_with(TEST_REGION)
+ mock_client.return_value.delete_training_pipeline.assert_called_once_with(
+ request=dict(
+ name=mock_client.return_value.training_pipeline_path.return_value,
+ ),
+ metadata=(),
+ retry=None,
+ timeout=None,
+ )
+ mock_client.return_value.training_pipeline_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME
+ )
+
+ @mock.patch(CUSTOM_JOB_STRING.format("AutoMLHook.get_pipeline_service_client"))
+ def test_get_training_pipeline(self, mock_client) -> None:
+ self.hook.get_training_pipeline(
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ training_pipeline=TEST_TRAINING_PIPELINE_NAME,
+ )
+ mock_client.assert_called_once_with(TEST_REGION)
+ mock_client.return_value.get_training_pipeline.assert_called_once_with(
+ request=dict(
+ name=mock_client.return_value.training_pipeline_path.return_value,
+ ),
+ metadata=(),
+ retry=None,
+ timeout=None,
+ )
+ mock_client.return_value.training_pipeline_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME
+ )
+
+ @mock.patch(CUSTOM_JOB_STRING.format("AutoMLHook.get_pipeline_service_client"))
+ def test_list_training_pipelines(self, mock_client) -> None:
+ self.hook.list_training_pipelines(
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ )
+ mock_client.assert_called_once_with(TEST_REGION)
+ mock_client.return_value.list_training_pipelines.assert_called_once_with(
+ request=dict(
+ parent=mock_client.return_value.common_location_path.return_value,
+ page_size=None,
+ page_token=None,
+ filter=None,
+ read_mask=None,
+ ),
+ metadata=(),
+ retry=None,
+ timeout=None,
+ )
+ mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
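
These tests pin down a single call shape for all of the ``AutoMLHook`` pipeline
methods. A sketch of that shape as implied by the assertions above (an illustration,
not the committed hook code):

    # Assumed shape of a method on AutoMLHook:
    def delete_training_pipeline(
        self, project_id, region, training_pipeline, retry=None, timeout=None, metadata=()
    ):
        # Build a regional client, resolve the resource path, and forward the
        # retry/timeout/metadata arguments to the client unchanged.
        client = self.get_pipeline_service_client(region)
        name = client.training_pipeline_path(project_id, region, training_pipeline)
        return client.delete_training_pipeline(
            request=dict(name=name), retry=retry, timeout=timeout, metadata=metadata
        )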
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
index ec5a63d..cfb99fd 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -15,10 +15,20 @@
# specific language governing permissions and limitations
# under the License.
+from typing import List
from unittest import mock
from google.api_core.retry import Retry
+from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
+ CreateAutoMLForecastingTrainingJobOperator,
+ CreateAutoMLImageTrainingJobOperator,
+ CreateAutoMLTabularTrainingJobOperator,
+ CreateAutoMLTextTrainingJobOperator,
+ CreateAutoMLVideoTrainingJobOperator,
+ DeleteAutoMLTrainingJobOperator,
+ ListAutoMLTrainingJobOperator,
+)
from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
CreateCustomContainerTrainingJobOperator,
CreateCustomPythonPackageTrainingJobOperator,
@@ -95,6 +105,15 @@ TEST_IMPORT_CONFIG = [
]
TEST_UPDATE_MASK = "test-update-mask"
+TEST_TRAINING_TARGET_COLUMN = "target"
+TEST_TRAINING_TIME_COLUMN = "time"
+TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN = "time_series_identifier"
+TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS: List[str] = []
+TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS: List[str] = []
+TEST_TRAINING_FORECAST_HORIZON = 10
+TEST_TRAINING_DATA_GRANULARITY_UNIT = "day"
+TEST_TRAINING_DATA_GRANULARITY_COUNT = 1
+
class TestVertexAICreateCustomContainerTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
@@ -611,3 +630,329 @@ class TestVertexAIUpdateDatasetOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
+
+
+class TestVertexAICreateAutoMLForecastingTrainingJobOperator:
+ @mock.patch("google.cloud.aiplatform.datasets.TimeSeriesDataset")
+ @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+ def test_execute(self, mock_hook, mock_dataset):
+ op = CreateAutoMLForecastingTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ delegate_to=DELEGATE_TO,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ display_name=DISPLAY_NAME,
+ dataset_id=TEST_DATASET_ID,
+ target_column=TEST_TRAINING_TARGET_COLUMN,
+ time_column=TEST_TRAINING_TIME_COLUMN,
+ time_series_identifier_column=TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
+ unavailable_at_forecast_columns=TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
+ available_at_forecast_columns=TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
+ forecast_horizon=TEST_TRAINING_FORECAST_HORIZON,
+ data_granularity_unit=TEST_TRAINING_DATA_GRANULARITY_UNIT,
+ data_granularity_count=TEST_TRAINING_DATA_GRANULARITY_COUNT,
+ sync=True,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ )
+ op.execute(context={'ti': mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
+ )
+ mock_dataset.assert_called_once_with(dataset_name=TEST_DATASET_ID)
+ mock_hook.return_value.create_auto_ml_forecasting_training_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ display_name=DISPLAY_NAME,
+ dataset=mock_dataset.return_value,
+ target_column=TEST_TRAINING_TARGET_COLUMN,
+ time_column=TEST_TRAINING_TIME_COLUMN,
+ time_series_identifier_column=TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
+ unavailable_at_forecast_columns=TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
+ available_at_forecast_columns=TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
+ forecast_horizon=TEST_TRAINING_FORECAST_HORIZON,
+ data_granularity_unit=TEST_TRAINING_DATA_GRANULARITY_UNIT,
+ data_granularity_count=TEST_TRAINING_DATA_GRANULARITY_COUNT,
+ optimization_objective=None,
+ column_specs=None,
+ column_transformations=None,
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ training_fraction_split=None,
+ validation_fraction_split=None,
+ test_fraction_split=None,
+ predefined_split_column_name=None,
+ weight_column=None,
+ time_series_attribute_columns=None,
+ context_window=None,
+ export_evaluated_data_items=False,
+ export_evaluated_data_items_bigquery_destination_uri=None,
+ export_evaluated_data_items_override_destination=False,
+ quantiles=None,
+ validation_options=None,
+ budget_milli_node_hours=1000,
+ model_display_name=None,
+ model_labels=None,
+ sync=True,
+ )
+
+
+class TestVertexAICreateAutoMLImageTrainingJobOperator:
+ @mock.patch("google.cloud.aiplatform.datasets.ImageDataset")
+ @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+ def test_execute(self, mock_hook, mock_dataset):
+ op = CreateAutoMLImageTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ delegate_to=DELEGATE_TO,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ display_name=DISPLAY_NAME,
+ dataset_id=TEST_DATASET_ID,
+ prediction_type="classification",
+ multi_label=False,
+ model_type="CLOUD",
+ sync=True,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ )
+ op.execute(context={'ti': mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
+ )
+ mock_dataset.assert_called_once_with(dataset_name=TEST_DATASET_ID)
+ mock_hook.return_value.create_auto_ml_image_training_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ display_name=DISPLAY_NAME,
+ dataset=mock_dataset.return_value,
+ prediction_type="classification",
+ multi_label=False,
+ model_type="CLOUD",
+ base_model=None,
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ training_fraction_split=None,
+ validation_fraction_split=None,
+ test_fraction_split=None,
+ training_filter_split=None,
+ validation_filter_split=None,
+ test_filter_split=None,
+ budget_milli_node_hours=None,
+ model_display_name=None,
+ model_labels=None,
+ disable_early_stopping=False,
+ sync=True,
+ )
+
+
+class TestVertexAICreateAutoMLTabularTrainingJobOperator:
+ @mock.patch("google.cloud.aiplatform.datasets.TabularDataset")
+ @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+ def test_execute(self, mock_hook, mock_dataset):
+ op = CreateAutoMLTabularTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ delegate_to=DELEGATE_TO,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ display_name=DISPLAY_NAME,
+ dataset_id=TEST_DATASET_ID,
+ target_column=None,
+ optimization_prediction_type=None,
+ sync=True,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ )
+ op.execute(context={'ti': mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
+ )
+ mock_dataset.assert_called_once_with(dataset_name=TEST_DATASET_ID)
+ mock_hook.return_value.create_auto_ml_tabular_training_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ display_name=DISPLAY_NAME,
+ dataset=mock_dataset.return_value,
+ target_column=None,
+ optimization_prediction_type=None,
+ optimization_objective=None,
+ column_specs=None,
+ column_transformations=None,
+ optimization_objective_recall_value=None,
+ optimization_objective_precision_value=None,
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ training_fraction_split=None,
+ validation_fraction_split=None,
+ test_fraction_split=None,
+ predefined_split_column_name=None,
+ timestamp_split_column_name=None,
+ weight_column=None,
+ budget_milli_node_hours=1000,
+ model_display_name=None,
+ model_labels=None,
+ disable_early_stopping=False,
+ export_evaluated_data_items=False,
+ export_evaluated_data_items_bigquery_destination_uri=None,
+ export_evaluated_data_items_override_destination=False,
+ sync=True,
+ )
+
+
+class TestVertexAICreateAutoMLTextTrainingJobOperator:
+ @mock.patch("google.cloud.aiplatform.datasets.TextDataset")
+ @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+ def test_execute(self, mock_hook, mock_dataset):
+ op = CreateAutoMLTextTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ delegate_to=DELEGATE_TO,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ display_name=DISPLAY_NAME,
+ dataset_id=TEST_DATASET_ID,
+ prediction_type=None,
+ multi_label=False,
+ sentiment_max=10,
+ sync=True,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ )
+ op.execute(context={'ti': mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
+ )
+ mock_dataset.assert_called_once_with(dataset_name=TEST_DATASET_ID)
+ mock_hook.return_value.create_auto_ml_text_training_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ display_name=DISPLAY_NAME,
+ dataset=mock_dataset.return_value,
+ prediction_type=None,
+ multi_label=False,
+ sentiment_max=10,
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ training_fraction_split=None,
+ validation_fraction_split=None,
+ test_fraction_split=None,
+ training_filter_split=None,
+ validation_filter_split=None,
+ test_filter_split=None,
+ model_display_name=None,
+ model_labels=None,
+ sync=True,
+ )
+
+
+class TestVertexAICreateAutoMLVideoTrainingJobOperator:
+ @mock.patch("google.cloud.aiplatform.datasets.VideoDataset")
+ @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+ def test_execute(self, mock_hook, mock_dataset):
+ op = CreateAutoMLVideoTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ delegate_to=DELEGATE_TO,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ display_name=DISPLAY_NAME,
+ dataset_id=TEST_DATASET_ID,
+ prediction_type="classification",
+ model_type="CLOUD",
+ sync=True,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ )
+ op.execute(context={'ti': mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
+ )
+ mock_dataset.assert_called_once_with(dataset_name=TEST_DATASET_ID)
+ mock_hook.return_value.create_auto_ml_video_training_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ display_name=DISPLAY_NAME,
+ dataset=mock_dataset.return_value,
+ prediction_type="classification",
+ model_type="CLOUD",
+ labels=None,
+ training_encryption_spec_key_name=None,
+ model_encryption_spec_key_name=None,
+ training_fraction_split=None,
+ test_fraction_split=None,
+ training_filter_split=None,
+ test_filter_split=None,
+ model_display_name=None,
+ model_labels=None,
+ sync=True,
+ )
+
+
+class TestVertexAIDeleteAutoMLTrainingJobOperator:
+ @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+ def test_execute(self, mock_hook):
+ op = DeleteAutoMLTrainingJobOperator(
+ task_id=TASK_ID,
+ training_pipeline_id=TRAINING_PIPELINE_ID,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ gcp_conn_id=GCP_CONN_ID,
+ delegate_to=DELEGATE_TO,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context={})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
+ )
+ mock_hook.return_value.delete_training_pipeline.assert_called_once_with(
+ training_pipeline=TRAINING_PIPELINE_ID,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+
+
+class TestVertexAIListAutoMLTrainingJobOperator:
+ @mock.patch(VERTEX_AI_PATH.format("auto_ml.AutoMLHook"))
+ def test_execute(self, mock_hook):
+ page_token = "page_token"
+ page_size = 42
+ filter = "filter"
+ read_mask = "read_mask"
+
+ op = ListAutoMLTrainingJobOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ delegate_to=DELEGATE_TO,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ page_size=page_size,
+ page_token=page_token,
+ filter=filter,
+ read_mask=read_mask,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ op.execute(context={'ti': mock.MagicMock()})
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
+ )
+ mock_hook.return_value.list_training_pipelines.assert_called_once_with(
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ page_size=page_size,
+ page_token=page_token,
+ filter=filter,
+ read_mask=read_mask,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
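
The operator tests above all assert the same execute-time flow: build the hook from the
connection settings, wrap the dataset id in the matching ``google.cloud.aiplatform``
dataset class, and delegate to the corresponding hook method. A sketch of that flow for
the image case (an illustration, not the committed operator code):

    from google.cloud import aiplatform

    from airflow.providers.google.cloud.hooks.vertex_ai.auto_ml import AutoMLHook


    # Assumed shape of CreateAutoMLImageTrainingJobOperator.execute:
    def execute(self, context):
        hook = AutoMLHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )
        dataset = aiplatform.datasets.ImageDataset(dataset_name=self.dataset_id)
        return hook.create_auto_ml_image_training_job(
            project_id=self.project_id,
            region=self.region,
            display_name=self.display_name,
            dataset=dataset,
            prediction_type=self.prediction_type,
            multi_label=self.multi_label,
            model_type=self.model_type,
            sync=self.sync,
            # ...the remaining AutoML parameters are forwarded unchanged
        )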
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai_system.py b/tests/providers/google/cloud/operators/test_vertex_ai_system.py
index 84b84c3..bf81030 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai_system.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai_system.py
@@ -39,3 +39,7 @@ class VertexAIExampleDagsTest(GoogleSystemTest):
@provide_gcp_context(GCP_VERTEX_AI_KEY)
def test_run_dataset_example_dag(self):
self.run_dag(dag_id="example_gcp_vertex_ai_dataset", dag_folder=CLOUD_DAG_FOLDER)
+
+ @provide_gcp_context(GCP_VERTEX_AI_KEY)
+ def test_run_auto_ml_example_dag(self):
+ self.run_dag(dag_id="example_gcp_vertex_ai_auto_ml", dag_folder=CLOUD_DAG_FOLDER)