You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/01/29 00:43:45 UTC
[airflow] branch main updated: Revert "Create CustomJob and Datasets operators for Vertex AI service (#20077)" (#21203)
This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1d4b709 Revert "Create CustomJob and Datasets operators for Vertex AI service (#20077)" (#21203)
1d4b709 is described below
commit 1d4b709e20b07c6f0b5d1bab1935e19557df2913
Author: Jed Cunningham <66...@users.noreply.github.com>
AuthorDate: Fri Jan 28 17:43:10 2022 -0700
Revert "Create CustomJob and Datasets operators for Vertex AI service (#20077)" (#21203)
This reverts commit 640c0b67631c5f2c8ee866b0726fa7a8a452cd3c.
---
.../google/cloud/example_dags/example_vertex_ai.py | 313 ---
.../google/cloud/hooks/vertex_ai/__init__.py | 16 -
.../google/cloud/hooks/vertex_ai/custom_job.py | 2032 --------------------
.../google/cloud/hooks/vertex_ai/dataset.py | 460 -----
.../google/cloud/operators/vertex_ai/__init__.py | 16 -
.../google/cloud/operators/vertex_ai/custom_job.py | 1427 --------------
.../google/cloud/operators/vertex_ai/dataset.py | 644 -------
airflow/providers/google/provider.yaml | 17 -
docs/apache-airflow-providers-google/index.rst | 1 -
.../operators/cloud/vertex_ai.rst | 173 --
.../pre_commit_check_provider_yaml_files.py | 5 +-
setup.py | 1 -
.../google/cloud/hooks/vertex_ai/__init__.py | 16 -
.../cloud/hooks/vertex_ai/test_custom_job.py | 457 -----
.../google/cloud/hooks/vertex_ai/test_dataset.py | 504 -----
.../google/cloud/operators/test_vertex_ai.py | 613 ------
.../cloud/operators/test_vertex_ai_system.py | 41 -
.../google/cloud/utils/gcp_authenticator.py | 1 -
18 files changed, 1 insertion(+), 6736 deletions(-)
diff --git a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py b/airflow/providers/google/cloud/example_dags/example_vertex_ai.py
deleted file mode 100644
index 5a459e7..0000000
--- a/airflow/providers/google/cloud/example_dags/example_vertex_ai.py
+++ /dev/null
@@ -1,313 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""
-Example Airflow DAG that demonstrates operators for the Google Vertex AI service in the Google
-Cloud Platform.
-
-This DAG relies on the following OS environment variables:
-
-* GCP_VERTEX_AI_BUCKET - Google Cloud Storage bucket where the model will be saved
-after the training process has finished.
-* CUSTOM_CONTAINER_URI - path to container with model.
-* PYTHON_PACKAGE_GSC_URI - path to test model in archive.
-* LOCAL_TRAINING_SCRIPT_PATH - path to local training script.
-* DATASET_ID - ID of dataset which will be used in training process.
-"""
-import os
-from datetime import datetime
-from uuid import uuid4
-
-from airflow import models
-from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
- CreateCustomContainerTrainingJobOperator,
- CreateCustomPythonPackageTrainingJobOperator,
- CreateCustomTrainingJobOperator,
- DeleteCustomTrainingJobOperator,
- ListCustomTrainingJobOperator,
-)
-from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
- CreateDatasetOperator,
- DeleteDatasetOperator,
- ExportDataOperator,
- GetDatasetOperator,
- ImportDataOperator,
- ListDatasetsOperator,
- UpdateDatasetOperator,
-)
-
-PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id")
-REGION = os.environ.get("GCP_LOCATION", "us-central1")
-BUCKET = os.environ.get("GCP_VERTEX_AI_BUCKET", "vertex-ai-system-tests")
-
-STAGING_BUCKET = f"gs://{BUCKET}"
-DISPLAY_NAME = str(uuid4()) # Create random display name
-CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest"
-CUSTOM_CONTAINER_URI = os.environ.get("CUSTOM_CONTAINER_URI", "path_to_container_with_model")
-MODEL_SERVING_CONTAINER_URI = "gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-2:latest"
-REPLICA_COUNT = 1
-MACHINE_TYPE = "n1-standard-4"
-ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED"
-ACCELERATOR_COUNT = 0
-TRAINING_FRACTION_SPLIT = 0.7
-TEST_FRACTION_SPLIT = 0.15
-VALIDATION_FRACTION_SPLIT = 0.15
-
-PYTHON_PACKAGE_GCS_URI = os.environ.get("PYTHON_PACKAGE_GSC_URI", "path_to_test_model_in_arch")
-PYTHON_MODULE_NAME = "aiplatform_custom_trainer_script.task"
-
-LOCAL_TRAINING_SCRIPT_PATH = os.environ.get("LOCAL_TRAINING_SCRIPT_PATH", "path_to_training_script")
-
-TRAINING_PIPELINE_ID = "test-training-pipeline-id"
-CUSTOM_JOB_ID = "test-custom-job-id"
-
-IMAGE_DATASET = {
- "display_name": str(uuid4()),
- "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml",
- "metadata": "test-image-dataset",
-}
-TABULAR_DATASET = {
- "display_name": str(uuid4()),
- "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/tabular_1.0.0.yaml",
- "metadata": "test-tabular-dataset",
-}
-TEXT_DATASET = {
- "display_name": str(uuid4()),
- "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml",
- "metadata": "test-text-dataset",
-}
-VIDEO_DATASET = {
- "display_name": str(uuid4()),
- "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml",
- "metadata": "test-video-dataset",
-}
-TIME_SERIES_DATASET = {
- "display_name": str(uuid4()),
- "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/time_series_1.0.0.yaml",
- "metadata": "test-video-dataset",
-}
-DATASET_ID = os.environ.get("DATASET_ID", "test-dataset-id")
-TEST_EXPORT_CONFIG = {"gcs_destination": {"output_uri_prefix": "gs://test-vertex-ai-bucket/exports"}}
-TEST_IMPORT_CONFIG = [
- {
- "data_item_labels": {
- "test-labels-name": "test-labels-value",
- },
- "import_schema_uri": (
- "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml"
- ),
- "gcs_source": {
- "uris": ["gs://ucaip-test-us-central1/dataset/salads_oid_ml_use_public_unassigned.jsonl"]
- },
- },
-]
-DATASET_TO_UPDATE = {"display_name": "test-name"}
-TEST_UPDATE_MASK = {"paths": ["displayName"]}
-
-with models.DAG(
- "example_gcp_vertex_ai_custom_jobs",
- schedule_interval="@once",
- start_date=datetime(2021, 1, 1),
- catchup=False,
-) as custom_jobs_dag:
- # [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator]
- create_custom_container_training_job = CreateCustomContainerTrainingJobOperator(
- task_id="custom_container_task",
- staging_bucket=STAGING_BUCKET,
- display_name=f"train-housing-container-{DISPLAY_NAME}",
- container_uri=CUSTOM_CONTAINER_URI,
- model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI,
- # run params
- dataset_id=DATASET_ID,
- command=["python3", "task.py"],
- model_display_name=f"container-housing-model-{DISPLAY_NAME}",
- replica_count=REPLICA_COUNT,
- machine_type=MACHINE_TYPE,
- accelerator_type=ACCELERATOR_TYPE,
- accelerator_count=ACCELERATOR_COUNT,
- training_fraction_split=TRAINING_FRACTION_SPLIT,
- validation_fraction_split=VALIDATION_FRACTION_SPLIT,
- test_fraction_split=TEST_FRACTION_SPLIT,
- region=REGION,
- project_id=PROJECT_ID,
- )
- # [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator]
-
- # [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]
- create_custom_python_package_training_job = CreateCustomPythonPackageTrainingJobOperator(
- task_id="python_package_task",
- staging_bucket=STAGING_BUCKET,
- display_name=f"train-housing-py-package-{DISPLAY_NAME}",
- python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI,
- python_module_name=PYTHON_MODULE_NAME,
- container_uri=CONTAINER_URI,
- model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI,
- # run params
- dataset_id=DATASET_ID,
- model_display_name=f"py-package-housing-model-{DISPLAY_NAME}",
- replica_count=REPLICA_COUNT,
- machine_type=MACHINE_TYPE,
- accelerator_type=ACCELERATOR_TYPE,
- accelerator_count=ACCELERATOR_COUNT,
- training_fraction_split=TRAINING_FRACTION_SPLIT,
- validation_fraction_split=VALIDATION_FRACTION_SPLIT,
- test_fraction_split=TEST_FRACTION_SPLIT,
- region=REGION,
- project_id=PROJECT_ID,
- )
- # [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]
-
- # [START how_to_cloud_vertex_ai_create_custom_training_job_operator]
- create_custom_training_job = CreateCustomTrainingJobOperator(
- task_id="custom_task",
- staging_bucket=STAGING_BUCKET,
- display_name=f"train-housing-custom-{DISPLAY_NAME}",
- script_path=LOCAL_TRAINING_SCRIPT_PATH,
- container_uri=CONTAINER_URI,
- requirements=["gcsfs==0.7.1"],
- model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI,
- # run params
- dataset_id=DATASET_ID,
- replica_count=1,
- model_display_name=f"custom-housing-model-{DISPLAY_NAME}",
- sync=False,
- region=REGION,
- project_id=PROJECT_ID,
- )
- # [END how_to_cloud_vertex_ai_create_custom_training_job_operator]
-
- # [START how_to_cloud_vertex_ai_delete_custom_training_job_operator]
- delete_custom_training_job = DeleteCustomTrainingJobOperator(
- task_id="delete_custom_training_job",
- training_pipeline_id=TRAINING_PIPELINE_ID,
- custom_job_id=CUSTOM_JOB_ID,
- region=REGION,
- project_id=PROJECT_ID,
- )
- # [END how_to_cloud_vertex_ai_delete_custom_training_job_operator]
-
- # [START how_to_cloud_vertex_ai_list_custom_training_job_operator]
- list_custom_training_job = ListCustomTrainingJobOperator(
- task_id="list_custom_training_job",
- region=REGION,
- project_id=PROJECT_ID,
- )
- # [END how_to_cloud_vertex_ai_list_custom_training_job_operator]
-
-with models.DAG(
- "example_gcp_vertex_ai_dataset",
- schedule_interval="@once",
- start_date=datetime(2021, 1, 1),
- catchup=False,
-) as dataset_dag:
- # [START how_to_cloud_vertex_ai_create_dataset_operator]
- create_image_dataset_job = CreateDatasetOperator(
- task_id="image_dataset",
- dataset=IMAGE_DATASET,
- region=REGION,
- project_id=PROJECT_ID,
- )
- create_tabular_dataset_job = CreateDatasetOperator(
- task_id="tabular_dataset",
- dataset=TABULAR_DATASET,
- region=REGION,
- project_id=PROJECT_ID,
- )
- create_text_dataset_job = CreateDatasetOperator(
- task_id="text_dataset",
- dataset=TEXT_DATASET,
- region=REGION,
- project_id=PROJECT_ID,
- )
- create_video_dataset_job = CreateDatasetOperator(
- task_id="video_dataset",
- dataset=VIDEO_DATASET,
- region=REGION,
- project_id=PROJECT_ID,
- )
- create_time_series_dataset_job = CreateDatasetOperator(
- task_id="time_series_dataset",
- dataset=TIME_SERIES_DATASET,
- region=REGION,
- project_id=PROJECT_ID,
- )
- # [END how_to_cloud_vertex_ai_create_dataset_operator]
-
- # [START how_to_cloud_vertex_ai_delete_dataset_operator]
- delete_dataset_job = DeleteDatasetOperator(
- task_id="delete_dataset",
- dataset_id=create_text_dataset_job.output['dataset_id'],
- region=REGION,
- project_id=PROJECT_ID,
- )
- # [END how_to_cloud_vertex_ai_delete_dataset_operator]
-
- # [START how_to_cloud_vertex_ai_get_dataset_operator]
- get_dataset = GetDatasetOperator(
- task_id="get_dataset",
- project_id=PROJECT_ID,
- region=REGION,
- dataset_id=create_tabular_dataset_job.output['dataset_id'],
- )
- # [END how_to_cloud_vertex_ai_get_dataset_operator]
-
- # [START how_to_cloud_vertex_ai_export_data_operator]
- export_data_job = ExportDataOperator(
- task_id="export_data",
- dataset_id=create_image_dataset_job.output['dataset_id'],
- region=REGION,
- project_id=PROJECT_ID,
- export_config=TEST_EXPORT_CONFIG,
- )
- # [END how_to_cloud_vertex_ai_export_data_operator]
-
- # [START how_to_cloud_vertex_ai_import_data_operator]
- import_data_job = ImportDataOperator(
- task_id="import_data",
- dataset_id=create_image_dataset_job.output['dataset_id'],
- region=REGION,
- project_id=PROJECT_ID,
- import_configs=TEST_IMPORT_CONFIG,
- )
- # [END how_to_cloud_vertex_ai_import_data_operator]
-
- # [START how_to_cloud_vertex_ai_list_dataset_operator]
- list_dataset_job = ListDatasetsOperator(
- task_id="list_dataset",
- region=REGION,
- project_id=PROJECT_ID,
- )
- # [END how_to_cloud_vertex_ai_list_dataset_operator]
-
- # [START how_to_cloud_vertex_ai_update_dataset_operator]
- update_dataset_job = UpdateDatasetOperator(
- task_id="update_dataset",
- project_id=PROJECT_ID,
- region=REGION,
- dataset_id=create_video_dataset_job.output['dataset_id'],
- dataset=DATASET_TO_UPDATE,
- update_mask=TEST_UPDATE_MASK,
- )
- # [END how_to_cloud_vertex_ai_update_dataset_operator]
-
- create_time_series_dataset_job
- create_text_dataset_job >> delete_dataset_job
- create_tabular_dataset_job >> get_dataset
- create_image_dataset_job >> import_data_job >> export_data_job
- create_video_dataset_job >> update_dataset_job
- list_dataset_job
diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/__init__.py b/airflow/providers/google/cloud/hooks/vertex_ai/__init__.py
deleted file mode 100644
index 13a8339..0000000
--- a/airflow/providers/google/cloud/hooks/vertex_ai/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py
deleted file mode 100644
index fc59753..0000000
--- a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py
+++ /dev/null
@@ -1,2032 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-"""This module contains a Google Cloud Vertex AI hook."""
-
-from typing import Dict, List, Optional, Sequence, Tuple, Union
-
-from google.api_core.operation import Operation
-from google.api_core.retry import Retry
-from google.cloud.aiplatform import (
- CustomContainerTrainingJob,
- CustomPythonPackageTrainingJob,
- CustomTrainingJob,
- datasets,
-)
-from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient
-from google.cloud.aiplatform_v1.services.job_service.pagers import ListCustomJobsPager
-from google.cloud.aiplatform_v1.services.pipeline_service.pagers import (
- ListPipelineJobsPager,
- ListTrainingPipelinesPager,
-)
-from google.cloud.aiplatform_v1.types import CustomJob, Model, PipelineJob, TrainingPipeline
-
-from airflow import AirflowException
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
-
-
-class CustomJobHook(GoogleBaseHook):
- """Hook for Google Cloud Vertex AI Custom Job APIs."""
-
- def __init__(
- self,
- gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- ) -> None:
- super().__init__(
- gcp_conn_id=gcp_conn_id,
- delegate_to=delegate_to,
- impersonation_chain=impersonation_chain,
- )
- self._job = None
-
- def get_pipeline_service_client(
- self,
- region: Optional[str] = None,
- ) -> PipelineServiceClient:
- """Returns PipelineServiceClient."""
- client_options = None
- if region and region != 'global':
- client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'}
-
- return PipelineServiceClient(
- credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
- )
-
- def get_job_service_client(
- self,
- region: Optional[str] = None,
- ) -> JobServiceClient:
- """Returns JobServiceClient"""
- client_options = None
- if region and region != 'global':
- client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'}
-
- return JobServiceClient(
- credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
- )
-
- def get_custom_container_training_job(
- self,
- display_name: str,
- container_uri: str,
- command: Sequence[str] = [],
- model_serving_container_image_uri: Optional[str] = None,
- model_serving_container_predict_route: Optional[str] = None,
- model_serving_container_health_route: Optional[str] = None,
- model_serving_container_command: Optional[Sequence[str]] = None,
- model_serving_container_args: Optional[Sequence[str]] = None,
- model_serving_container_environment_variables: Optional[Dict[str, str]] = None,
- model_serving_container_ports: Optional[Sequence[int]] = None,
- model_description: Optional[str] = None,
- model_instance_schema_uri: Optional[str] = None,
- model_parameters_schema_uri: Optional[str] = None,
- model_prediction_schema_uri: Optional[str] = None,
- project: Optional[str] = None,
- location: Optional[str] = None,
- labels: Optional[Dict[str, str]] = None,
- training_encryption_spec_key_name: Optional[str] = None,
- model_encryption_spec_key_name: Optional[str] = None,
- staging_bucket: Optional[str] = None,
- ) -> CustomContainerTrainingJob:
- """Returns CustomContainerTrainingJob object"""
- return CustomContainerTrainingJob(
- display_name=display_name,
- container_uri=container_uri,
- command=command,
- model_serving_container_image_uri=model_serving_container_image_uri,
- model_serving_container_predict_route=model_serving_container_predict_route,
- model_serving_container_health_route=model_serving_container_health_route,
- model_serving_container_command=model_serving_container_command,
- model_serving_container_args=model_serving_container_args,
- model_serving_container_environment_variables=model_serving_container_environment_variables,
- model_serving_container_ports=model_serving_container_ports,
- model_description=model_description,
- model_instance_schema_uri=model_instance_schema_uri,
- model_parameters_schema_uri=model_parameters_schema_uri,
- model_prediction_schema_uri=model_prediction_schema_uri,
- project=project,
- location=location,
- credentials=self._get_credentials(),
- labels=labels,
- training_encryption_spec_key_name=training_encryption_spec_key_name,
- model_encryption_spec_key_name=model_encryption_spec_key_name,
- staging_bucket=staging_bucket,
- )
-
- def get_custom_python_package_training_job(
- self,
- display_name: str,
- python_package_gcs_uri: str,
- python_module_name: str,
- container_uri: str,
- model_serving_container_image_uri: Optional[str] = None,
- model_serving_container_predict_route: Optional[str] = None,
- model_serving_container_health_route: Optional[str] = None,
- model_serving_container_command: Optional[Sequence[str]] = None,
- model_serving_container_args: Optional[Sequence[str]] = None,
- model_serving_container_environment_variables: Optional[Dict[str, str]] = None,
- model_serving_container_ports: Optional[Sequence[int]] = None,
- model_description: Optional[str] = None,
- model_instance_schema_uri: Optional[str] = None,
- model_parameters_schema_uri: Optional[str] = None,
- model_prediction_schema_uri: Optional[str] = None,
- project: Optional[str] = None,
- location: Optional[str] = None,
- labels: Optional[Dict[str, str]] = None,
- training_encryption_spec_key_name: Optional[str] = None,
- model_encryption_spec_key_name: Optional[str] = None,
- staging_bucket: Optional[str] = None,
- ):
- """Returns CustomPythonPackageTrainingJob object"""
- return CustomPythonPackageTrainingJob(
- display_name=display_name,
- container_uri=container_uri,
- python_package_gcs_uri=python_package_gcs_uri,
- python_module_name=python_module_name,
- model_serving_container_image_uri=model_serving_container_image_uri,
- model_serving_container_predict_route=model_serving_container_predict_route,
- model_serving_container_health_route=model_serving_container_health_route,
- model_serving_container_command=model_serving_container_command,
- model_serving_container_args=model_serving_container_args,
- model_serving_container_environment_variables=model_serving_container_environment_variables,
- model_serving_container_ports=model_serving_container_ports,
- model_description=model_description,
- model_instance_schema_uri=model_instance_schema_uri,
- model_parameters_schema_uri=model_parameters_schema_uri,
- model_prediction_schema_uri=model_prediction_schema_uri,
- project=project,
- location=location,
- credentials=self._get_credentials(),
- labels=labels,
- training_encryption_spec_key_name=training_encryption_spec_key_name,
- model_encryption_spec_key_name=model_encryption_spec_key_name,
- staging_bucket=staging_bucket,
- )
-
- def get_custom_training_job(
- self,
- display_name: str,
- script_path: str,
- container_uri: str,
- requirements: Optional[Sequence[str]] = None,
- model_serving_container_image_uri: Optional[str] = None,
- model_serving_container_predict_route: Optional[str] = None,
- model_serving_container_health_route: Optional[str] = None,
- model_serving_container_command: Optional[Sequence[str]] = None,
- model_serving_container_args: Optional[Sequence[str]] = None,
- model_serving_container_environment_variables: Optional[Dict[str, str]] = None,
- model_serving_container_ports: Optional[Sequence[int]] = None,
- model_description: Optional[str] = None,
- model_instance_schema_uri: Optional[str] = None,
- model_parameters_schema_uri: Optional[str] = None,
- model_prediction_schema_uri: Optional[str] = None,
- project: Optional[str] = None,
- location: Optional[str] = None,
- labels: Optional[Dict[str, str]] = None,
- training_encryption_spec_key_name: Optional[str] = None,
- model_encryption_spec_key_name: Optional[str] = None,
- staging_bucket: Optional[str] = None,
- ):
- """Returns CustomTrainingJob object"""
- return CustomTrainingJob(
- display_name=display_name,
- script_path=script_path,
- container_uri=container_uri,
- requirements=requirements,
- model_serving_container_image_uri=model_serving_container_image_uri,
- model_serving_container_predict_route=model_serving_container_predict_route,
- model_serving_container_health_route=model_serving_container_health_route,
- model_serving_container_command=model_serving_container_command,
- model_serving_container_args=model_serving_container_args,
- model_serving_container_environment_variables=model_serving_container_environment_variables,
- model_serving_container_ports=model_serving_container_ports,
- model_description=model_description,
- model_instance_schema_uri=model_instance_schema_uri,
- model_parameters_schema_uri=model_parameters_schema_uri,
- model_prediction_schema_uri=model_prediction_schema_uri,
- project=project,
- location=location,
- credentials=self._get_credentials(),
- labels=labels,
- training_encryption_spec_key_name=training_encryption_spec_key_name,
- model_encryption_spec_key_name=model_encryption_spec_key_name,
- staging_bucket=staging_bucket,
- )
-
- @staticmethod
- def extract_model_id(obj: Dict) -> str:
- """Returns unique id of the Model."""
- return obj["name"].rpartition("/")[-1]
-
- def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None):
- """Waits for long-lasting operation to complete."""
- try:
- return operation.result(timeout=timeout)
- except Exception:
- error = operation.exception(timeout=timeout)
- raise AirflowException(error)
-
- def cancel_job(self) -> None:
- """Cancel Job for training pipeline"""
- if self._job:
- self._job.cancel()
-
- def _run_job(
- self,
- job: Union[
- CustomTrainingJob,
- CustomContainerTrainingJob,
- CustomPythonPackageTrainingJob,
- ],
- dataset: Optional[
- Union[
- datasets.ImageDataset,
- datasets.TabularDataset,
- datasets.TextDataset,
- datasets.VideoDataset,
- ]
- ] = None,
- annotation_schema_uri: Optional[str] = None,
- model_display_name: Optional[str] = None,
- model_labels: Optional[Dict[str, str]] = None,
- base_output_dir: Optional[str] = None,
- service_account: Optional[str] = None,
- network: Optional[str] = None,
- bigquery_destination: Optional[str] = None,
- args: Optional[List[Union[str, float, int]]] = None,
- environment_variables: Optional[Dict[str, str]] = None,
- replica_count: int = 1,
- machine_type: str = "n1-standard-4",
- accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
- accelerator_count: int = 0,
- boot_disk_type: str = "pd-ssd",
- boot_disk_size_gb: int = 100,
- training_fraction_split: Optional[float] = None,
- validation_fraction_split: Optional[float] = None,
- test_fraction_split: Optional[float] = None,
- training_filter_split: Optional[str] = None,
- validation_filter_split: Optional[str] = None,
- test_filter_split: Optional[str] = None,
- predefined_split_column_name: Optional[str] = None,
- timestamp_split_column_name: Optional[str] = None,
- tensorboard: Optional[str] = None,
- sync=True,
- ) -> Model:
- """Run Job for training pipeline"""
- model = job.run(
- dataset=dataset,
- annotation_schema_uri=annotation_schema_uri,
- model_display_name=model_display_name,
- model_labels=model_labels,
- base_output_dir=base_output_dir,
- service_account=service_account,
- network=network,
- bigquery_destination=bigquery_destination,
- args=args,
- environment_variables=environment_variables,
- replica_count=replica_count,
- machine_type=machine_type,
- accelerator_type=accelerator_type,
- accelerator_count=accelerator_count,
- boot_disk_type=boot_disk_type,
- boot_disk_size_gb=boot_disk_size_gb,
- training_fraction_split=training_fraction_split,
- validation_fraction_split=validation_fraction_split,
- test_fraction_split=test_fraction_split,
- training_filter_split=training_filter_split,
- validation_filter_split=validation_filter_split,
- test_filter_split=test_filter_split,
- predefined_split_column_name=predefined_split_column_name,
- timestamp_split_column_name=timestamp_split_column_name,
- tensorboard=tensorboard,
- sync=sync,
- )
- model.wait()
- return model
-
- @GoogleBaseHook.fallback_to_default_project_id
- def cancel_pipeline_job(
- self,
- project_id: str,
- region: str,
- pipeline_job: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> None:
- """
- Cancels a PipelineJob. Starts asynchronous cancellation on the PipelineJob. The server makes a best
- effort to cancel the pipeline, but success is not guaranteed. Clients can use
- [PipelineService.GetPipelineJob][google.cloud.aiplatform.v1.PipelineService.GetPipelineJob] or other
- methods to check whether the cancellation succeeded or whether the pipeline completed despite
- cancellation. On successful cancellation, the PipelineJob is not deleted; instead it becomes a
- pipeline with a [PipelineJob.error][google.cloud.aiplatform.v1.PipelineJob.error] value with a
- [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and
- [PipelineJob.state][google.cloud.aiplatform.v1.PipelineJob.state] is set to ``CANCELLED``.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param pipeline_job: The name of the PipelineJob to cancel.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_pipeline_service_client(region)
- name = client.pipeline_job_path(project_id, region, pipeline_job)
-
- client.cancel_pipeline_job(
- request={
- 'name': name,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
-
- @GoogleBaseHook.fallback_to_default_project_id
- def cancel_training_pipeline(
- self,
- project_id: str,
- region: str,
- training_pipeline: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> None:
- """
- Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes
- a best effort to cancel the pipeline, but success is not guaranteed. Clients can use
- [PipelineService.GetTrainingPipeline][google.cloud.aiplatform.v1.PipelineService.GetTrainingPipeline]
- or other methods to check whether the cancellation succeeded or whether the pipeline completed despite
- cancellation. On successful cancellation, the TrainingPipeline is not deleted; instead it becomes a
- pipeline with a [TrainingPipeline.error][google.cloud.aiplatform.v1.TrainingPipeline.error] value with
- a [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and
- [TrainingPipeline.state][google.cloud.aiplatform.v1.TrainingPipeline.state] is set to ``CANCELLED``.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param training_pipeline: Required. The name of the TrainingPipeline to cancel.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_pipeline_service_client(region)
- name = client.training_pipeline_path(project_id, region, training_pipeline)
-
- client.cancel_training_pipeline(
- request={
- 'name': name,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
-
- @GoogleBaseHook.fallback_to_default_project_id
- def cancel_custom_job(
- self,
- project_id: str,
- region: str,
- custom_job: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> None:
- """
- Request cancellation of a CustomJob.
-
- Cancellation is asynchronous: the server makes a best effort to cancel the job, but success is not
- guaranteed. Clients can use
- [JobService.GetCustomJob][google.cloud.aiplatform.v1.JobService.GetCustomJob] or other methods to
- check whether the cancellation succeeded or whether the job completed despite cancellation. On
- successful cancellation, the CustomJob is not deleted; instead it becomes a job with a
- [CustomJob.error][google.cloud.aiplatform.v1.CustomJob.error] value with a
- [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and
- [CustomJob.state][google.cloud.aiplatform.v1.CustomJob.state] is set to ``CANCELLED``.
-
- This method returns nothing; whatever ``client.cancel_custom_job`` returns is discarded.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param custom_job: Required. The name of the CustomJob to cancel. It is combined with
- ``project_id`` and ``region`` by ``JobServiceClient.custom_job_path`` to build the request
- ``name`` (presumably the job ID rather than a full resource name -- confirm against callers).
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_job_service_client(region)
- # The path helper is called on the JobServiceClient class rather than on the client instance.
- name = JobServiceClient.custom_job_path(project_id, region, custom_job)
-
- client.cancel_custom_job(
- request={
- 'name': name,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
-
- @GoogleBaseHook.fallback_to_default_project_id
- def create_pipeline_job(
- self,
- project_id: str,
- region: str,
- pipeline_job: PipelineJob,
- pipeline_job_id: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> PipelineJob:
- """
- Create a PipelineJob. A PipelineJob will run immediately when created.
-
- The request is sent through the client returned by ``self.get_pipeline_service_client(region)``;
- its ``parent`` is built from ``project_id`` and ``region`` with ``client.common_location_path``.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param pipeline_job: Required. The PipelineJob to create.
- :param pipeline_job_id: The ID to use for the PipelineJob, which will become the final component of
- the PipelineJob name. The service documentation states that an ID is generated automatically
- when none is provided; note, however, that this method's signature gives the argument no
- default, so callers must always pass a value.
-
- This value should be less than 128 characters, and valid characters are /[a-z][0-9]-/.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- :return: The value returned by ``client.create_pipeline_job`` (annotated as ``PipelineJob``).
- """
- client = self.get_pipeline_service_client(region)
- parent = client.common_location_path(project_id, region)
-
- result = client.create_pipeline_job(
- request={
- 'parent': parent,
- 'pipeline_job': pipeline_job,
- 'pipeline_job_id': pipeline_job_id,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def create_training_pipeline(
- self,
- project_id: str,
- region: str,
- training_pipeline: TrainingPipeline,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> TrainingPipeline:
- """
- Create a TrainingPipeline. The service attempts to run a TrainingPipeline as soon as it is created.
-
- The request is sent through the client returned by ``self.get_pipeline_service_client(region)``;
- its ``parent`` is built from ``project_id`` and ``region`` with ``client.common_location_path``.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param training_pipeline: Required. The TrainingPipeline to create.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- :return: The value returned by ``client.create_training_pipeline`` (annotated as
- ``TrainingPipeline``).
- """
- client = self.get_pipeline_service_client(region)
- parent = client.common_location_path(project_id, region)
-
- result = client.create_training_pipeline(
- request={
- 'parent': parent,
- 'training_pipeline': training_pipeline,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def create_custom_job(
- self,
- project_id: str,
- region: str,
- custom_job: CustomJob,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> CustomJob:
- """
- Create a CustomJob. The service attempts to run a CustomJob as soon as it is created.
-
- The request is sent through the client returned by ``self.get_job_service_client(region)``; its
- ``parent`` is built from ``project_id`` and ``region`` with
- ``JobServiceClient.common_location_path``.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param custom_job: Required. The CustomJob to create. It is sent as the ``custom_job`` field of
- the request.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- :return: The value returned by ``client.create_custom_job`` (annotated as ``CustomJob``).
- """
- client = self.get_job_service_client(region)
- # The path helper is called on the JobServiceClient class rather than on the client instance.
- parent = JobServiceClient.common_location_path(project_id, region)
-
- result = client.create_custom_job(
- request={
- 'parent': parent,
- 'custom_job': custom_job,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def create_custom_container_training_job(
- self,
- project_id: str,
- region: str,
- display_name: str,
- container_uri: str,
- command: Sequence[str] = [],
- model_serving_container_image_uri: Optional[str] = None,
- model_serving_container_predict_route: Optional[str] = None,
- model_serving_container_health_route: Optional[str] = None,
- model_serving_container_command: Optional[Sequence[str]] = None,
- model_serving_container_args: Optional[Sequence[str]] = None,
- model_serving_container_environment_variables: Optional[Dict[str, str]] = None,
- model_serving_container_ports: Optional[Sequence[int]] = None,
- model_description: Optional[str] = None,
- model_instance_schema_uri: Optional[str] = None,
- model_parameters_schema_uri: Optional[str] = None,
- model_prediction_schema_uri: Optional[str] = None,
- labels: Optional[Dict[str, str]] = None,
- training_encryption_spec_key_name: Optional[str] = None,
- model_encryption_spec_key_name: Optional[str] = None,
- staging_bucket: Optional[str] = None,
- # RUN
- dataset: Optional[
- Union[
- datasets.ImageDataset,
- datasets.TabularDataset,
- datasets.TextDataset,
- datasets.VideoDataset,
- ]
- ] = None,
- annotation_schema_uri: Optional[str] = None,
- model_display_name: Optional[str] = None,
- model_labels: Optional[Dict[str, str]] = None,
- base_output_dir: Optional[str] = None,
- service_account: Optional[str] = None,
- network: Optional[str] = None,
- bigquery_destination: Optional[str] = None,
- args: Optional[List[Union[str, float, int]]] = None,
- environment_variables: Optional[Dict[str, str]] = None,
- replica_count: int = 1,
- machine_type: str = "n1-standard-4",
- accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
- accelerator_count: int = 0,
- boot_disk_type: str = "pd-ssd",
- boot_disk_size_gb: int = 100,
- training_fraction_split: Optional[float] = None,
- validation_fraction_split: Optional[float] = None,
- test_fraction_split: Optional[float] = None,
- training_filter_split: Optional[str] = None,
- validation_filter_split: Optional[str] = None,
- test_filter_split: Optional[str] = None,
- predefined_split_column_name: Optional[str] = None,
- timestamp_split_column_name: Optional[str] = None,
- tensorboard: Optional[str] = None,
- sync=True,
- ) -> Model:
- """
- Create a custom container training job and run it.
-
- The job object is built by ``self.get_custom_container_training_job`` from the job-definition
- arguments (``display_name`` through ``staging_bucket``) and stored on ``self._job``. It is then
- passed to ``self._run_job`` together with the run arguments (``dataset`` through ``sync``), and
- the value returned by ``self._run_job`` is returned to the caller.
-
- :param display_name: Required. The user-defined name of this TrainingPipeline.
- :param command: The command to be invoked when the container is started.
- It overrides the entrypoint instruction in Dockerfile when provided.
- Note: the default is a list literal created once at definition time; this method only
- forwards it and does not modify it.
- :param container_uri: Required: Uri of the training container image in the GCR.
- :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI
- of the Model serving container suitable for serving the model produced by the
- training script.
- :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, an
- HTTP path to send prediction requests to the container, and which must be supported
- by it. If not specified a default HTTP path will be used by Vertex AI.
- :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an
- HTTP path to send health check requests to the container, and which must be supported
- by it. If not specified a standard HTTP path will be used by AI Platform.
- :param model_serving_container_command: The command with which the container is run. Not executed
- within a shell. The Docker image's ENTRYPOINT is used if this is not provided.
- Variable references $(VAR_NAME) are expanded using the container's
- environment. If a variable cannot be resolved, the reference in the
- input string will be unchanged. The $(VAR_NAME) syntax can be escaped
- with a double $$, i.e. $$(VAR_NAME). Escaped references will never be
- expanded, regardless of whether the variable exists or not.
- :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if
- this is not provided. Variable references $(VAR_NAME) are expanded using the
- container's environment. If a variable cannot be resolved, the reference
- in the input string will be unchanged. The $(VAR_NAME) syntax can be
- escaped with a double $$, i.e. $$(VAR_NAME). Escaped references will
- never be expanded, regardless of whether the variable exists or not.
- :param model_serving_container_environment_variables: The environment variables that are to be
- present in the container. Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- :param model_serving_container_ports: Declaration of ports that are exposed by the container. This
- field is primarily informational, it gives Vertex AI information about the
- network connections the container uses. Listing or not a port here has
- no impact on whether the port is actually exposed, any port listening on
- the default "0.0.0.0" address inside a container will be accessible from
- the network.
- :param model_description: The description of the Model.
- :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the format of a single instance, which
- are used in
- ``PredictRequest.instances``,
- ``ExplainRequest.instances``
- and
- ``BatchPredictionJob.input_config``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform. Note: The URI given on output will be immutable
- and probably different, including the URI scheme, than the
- one given on input. The output URI will point to a location
- where the user only has a read access.
- :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the parameters of prediction and
- explanation via
- ``PredictRequest.parameters``,
- ``ExplainRequest.parameters``
- and
- ``BatchPredictionJob.model_parameters``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform, if no parameters are supported it is set to an
- empty string. Note: The URI given on output will be
- immutable and probably different, including the URI scheme,
- than the one given on input. The output URI will point to a
- location where the user only has a read access.
- :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the format of a single prediction
- produced by this Model, which are returned via
- ``PredictResponse.predictions``,
- ``ExplainResponse.explanations``,
- and
- ``BatchPredictionJob.output_config``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform. Note: The URI given on output will be immutable
- and probably different, including the URI scheme, than the
- one given on input. The output URI will point to a location
- where the user only has a read access.
- :param project_id: Project to run training in.
- :param region: Location to run training in.
- :param labels: Optional. The labels with user-defined metadata to
- organize TrainingPipelines.
- Label keys and values can be no longer than 64
- characters, can only
- contain lowercase letters, numeric characters,
- underscores and dashes. International characters
- are allowed.
- See https://goo.gl/xmQnxf for more information
- and examples of labels.
- :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
- managed encryption key used to protect the training pipeline. Has the
- form:
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
- The key needs to be in the same region as where the compute
- resource is created.
-
- If set, this TrainingPipeline will be secured by this key.
-
- Note: Model trained by this TrainingPipeline is also secured
- by this key if ``model_to_upload`` is not set separately.
- :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
- managed encryption key used to protect the model. Has the
- form:
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
- The key needs to be in the same region as where the compute
- resource is created.
-
- If set, the trained Model will be secured by this key.
- :param staging_bucket: Bucket used to stage source and training artifacts.
- :param dataset: Vertex AI dataset to fit this training against.
- :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
- annotation schema. The schema is defined as an OpenAPI 3.0.2
- [Schema Object]
- (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
-
- Only Annotations that both match this schema and belong to
- DataItems not ignored by the split method are used in
- respectively training, validation or test role, depending on
- the role of the DataItem they are on.
-
- When used in conjunction with
- ``annotations_filter``,
- the Annotations used for training are filtered by both
- ``annotations_filter``
- and
- ``annotation_schema_uri``.
- :param model_display_name: If the script produces a managed Vertex AI Model, the display name of
- the Model. The name can be up to 128 characters long and can consist
- of any UTF-8 characters.
-
- If not provided upon creation, the job's display_name is used.
- :param model_labels: Optional. The labels with user-defined metadata to
- organize your Models.
- Label keys and values can be no longer than 64
- characters, can only
- contain lowercase letters, numeric characters,
- underscores and dashes. International characters
- are allowed.
- See https://goo.gl/xmQnxf for more information
- and examples of labels.
- :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
- staging directory will be used.
-
- Vertex AI sets the following environment variables when it runs your training code:
-
- - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
- i.e. <base_output_dir>/model/
- - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
- i.e. <base_output_dir>/checkpoints/
- - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
- logs, i.e. <base_output_dir>/logs/
-
- :param service_account: Specifies the service account for workload run-as account.
- Users submitting jobs must have act-as permission on this run-as account.
- :param network: The full name of the Compute Engine network to which the job
- should be peered.
- Private services access must already be configured for the network.
- If left unspecified, the job is not peered with any network.
- :param bigquery_destination: Provide this field if `dataset` is a BigQuery dataset.
- The BigQuery project location where the training data is to
- be written to. In the given project a new dataset is created
- with name
- ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
- where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
- training input data will be written into that dataset. In
- the dataset three tables will be created, ``training``,
- ``validation`` and ``test``.
-
- - AIP_DATA_FORMAT = "bigquery".
- - AIP_TRAINING_DATA_URI = "bigquery_destination.dataset_*.training"
- - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
- :param args: Command line arguments to be passed to the Python script.
- :param environment_variables: Environment variables to be passed to the container.
- Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- At most 10 environment variables can be specified.
- The Name of the environment variable must be unique.
- :param replica_count: The number of worker replicas. If replica count = 1 then one chief
- replica will be provisioned. If replica_count > 1 the remainder will be
- provisioned as a worker replica pool.
- :param machine_type: The type of machine to use for training.
- :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
- NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
- NVIDIA_TESLA_T4
- :param accelerator_count: The number of accelerators to attach to a worker replica.
- :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
- Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
- `pd-standard` (Persistent Disk Hard Disk Drive).
- :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
- boot disk size must be within the range of [100, 64000].
- :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
- the Model. This is ignored if Dataset is not provided.
- :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
- validate the Model. This is ignored if Dataset is not provided.
- :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
- the Model. This is ignored if Dataset is not provided.
- :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to train the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to validate the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to test the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
- columns. The value of the key (either the label's value or
- value in the column) must be one of {``training``,
- ``validation``, ``test``}, and it defines to which set the
- given piece of data is assigned. If for a piece of data the
- key is not present or has an invalid value, that piece is
- ignored by the pipeline.
-
- Supported only for tabular and time series Datasets.
- :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
- columns. The values of the key (the values in
- the column) must be in RFC 3339 `date-time` format, where
- `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
- piece of data the key is not present or has an invalid value,
- that piece is ignored by the pipeline.
-
- Supported only for tabular and time series Datasets.
- :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
- logs. Format:
- ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
- For more information on configuring your service account please visit:
- https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
- :param sync: Whether to execute the AI Platform job synchronously. If False, this method
- will be executed in concurrent Future and any downstream object will
- be immediately returned and synced when the Future has completed.
- :return: The value returned by ``self._run_job`` (annotated as ``Model``).
- """
- # Step 1: build the job definition and keep it on the hook instance as ``self._job``.
- self._job = self.get_custom_container_training_job(
- project=project_id,
- location=region,
- display_name=display_name,
- container_uri=container_uri,
- command=command,
- model_serving_container_image_uri=model_serving_container_image_uri,
- model_serving_container_predict_route=model_serving_container_predict_route,
- model_serving_container_health_route=model_serving_container_health_route,
- model_serving_container_command=model_serving_container_command,
- model_serving_container_args=model_serving_container_args,
- model_serving_container_environment_variables=model_serving_container_environment_variables,
- model_serving_container_ports=model_serving_container_ports,
- model_description=model_description,
- model_instance_schema_uri=model_instance_schema_uri,
- model_parameters_schema_uri=model_parameters_schema_uri,
- model_prediction_schema_uri=model_prediction_schema_uri,
- labels=labels,
- training_encryption_spec_key_name=training_encryption_spec_key_name,
- model_encryption_spec_key_name=model_encryption_spec_key_name,
- staging_bucket=staging_bucket,
- )
-
- # Step 2: run the stored job with the dataset, compute and data-split settings.
- model = self._run_job(
- job=self._job,
- dataset=dataset,
- annotation_schema_uri=annotation_schema_uri,
- model_display_name=model_display_name,
- model_labels=model_labels,
- base_output_dir=base_output_dir,
- service_account=service_account,
- network=network,
- bigquery_destination=bigquery_destination,
- args=args,
- environment_variables=environment_variables,
- replica_count=replica_count,
- machine_type=machine_type,
- accelerator_type=accelerator_type,
- accelerator_count=accelerator_count,
- boot_disk_type=boot_disk_type,
- boot_disk_size_gb=boot_disk_size_gb,
- training_fraction_split=training_fraction_split,
- validation_fraction_split=validation_fraction_split,
- test_fraction_split=test_fraction_split,
- training_filter_split=training_filter_split,
- validation_filter_split=validation_filter_split,
- test_filter_split=test_filter_split,
- predefined_split_column_name=predefined_split_column_name,
- timestamp_split_column_name=timestamp_split_column_name,
- tensorboard=tensorboard,
- sync=sync,
- )
-
- return model
-
- @GoogleBaseHook.fallback_to_default_project_id
- def create_custom_python_package_training_job(
- self,
- project_id: str,
- region: str,
- display_name: str,
- python_package_gcs_uri: str,
- python_module_name: str,
- container_uri: str,
- model_serving_container_image_uri: Optional[str] = None,
- model_serving_container_predict_route: Optional[str] = None,
- model_serving_container_health_route: Optional[str] = None,
- model_serving_container_command: Optional[Sequence[str]] = None,
- model_serving_container_args: Optional[Sequence[str]] = None,
- model_serving_container_environment_variables: Optional[Dict[str, str]] = None,
- model_serving_container_ports: Optional[Sequence[int]] = None,
- model_description: Optional[str] = None,
- model_instance_schema_uri: Optional[str] = None,
- model_parameters_schema_uri: Optional[str] = None,
- model_prediction_schema_uri: Optional[str] = None,
- labels: Optional[Dict[str, str]] = None,
- training_encryption_spec_key_name: Optional[str] = None,
- model_encryption_spec_key_name: Optional[str] = None,
- staging_bucket: Optional[str] = None,
- # RUN
- dataset: Optional[
- Union[
- datasets.ImageDataset,
- datasets.TabularDataset,
- datasets.TextDataset,
- datasets.VideoDataset,
- ]
- ] = None,
- annotation_schema_uri: Optional[str] = None,
- model_display_name: Optional[str] = None,
- model_labels: Optional[Dict[str, str]] = None,
- base_output_dir: Optional[str] = None,
- service_account: Optional[str] = None,
- network: Optional[str] = None,
- bigquery_destination: Optional[str] = None,
- args: Optional[List[Union[str, float, int]]] = None,
- environment_variables: Optional[Dict[str, str]] = None,
- replica_count: int = 1,
- machine_type: str = "n1-standard-4",
- accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
- accelerator_count: int = 0,
- boot_disk_type: str = "pd-ssd",
- boot_disk_size_gb: int = 100,
- training_fraction_split: Optional[float] = None,
- validation_fraction_split: Optional[float] = None,
- test_fraction_split: Optional[float] = None,
- training_filter_split: Optional[str] = None,
- validation_filter_split: Optional[str] = None,
- test_filter_split: Optional[str] = None,
- predefined_split_column_name: Optional[str] = None,
- timestamp_split_column_name: Optional[str] = None,
- tensorboard: Optional[str] = None,
- sync=True,
- ) -> Model:
- """
- Create Custom Python Package Training Job
-
- :param display_name: Required. The user-defined name of this TrainingPipeline.
- :param python_package_gcs_uri: Required: GCS location of the training python package.
- :param python_module_name: Required: The module name of the training python package.
- :param container_uri: Required: Uri of the training container image in the GCR.
- :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI
- of the Model serving container suitable for serving the model produced by the
- training script.
- :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An
- HTTP path to send prediction requests to the container, and which must be supported
- by it. If not specified a default HTTP path will be used by Vertex AI.
- :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an
- HTTP path to send health check requests to the container, and which must be supported
- by it. If not specified a standard HTTP path will be used by AI Platform.
- :param model_serving_container_command: The command with which the container is run. Not executed
- within a shell. The Docker image's ENTRYPOINT is used if this is not provided.
- Variable references $(VAR_NAME) are expanded using the container's
- environment. If a variable cannot be resolved, the reference in the
- input string will be unchanged. The $(VAR_NAME) syntax can be escaped
- with a double $$, ie: $$(VAR_NAME). Escaped references will never be
- expanded, regardless of whether the variable exists or not.
- :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if
- this is not provided. Variable references $(VAR_NAME) are expanded using the
- container's environment. If a variable cannot be resolved, the reference
- in the input string will be unchanged. The $(VAR_NAME) syntax can be
- escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
- never be expanded, regardless of whether the variable exists or not.
- :param model_serving_container_environment_variables: The environment variables that are to be
- present in the container. Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- :param model_serving_container_ports: Declaration of ports that are exposed by the container. This
- field is primarily informational, it gives Vertex AI information about the
- network connections the container uses. Listing or not a port here has
- no impact on whether the port is actually exposed, any port listening on
- the default "0.0.0.0" address inside a container will be accessible from
- the network.
- :param model_description: The description of the Model.
- :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the format of a single instance, which
- are used in
- ``PredictRequest.instances``,
- ``ExplainRequest.instances``
- and
- ``BatchPredictionJob.input_config``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform. Note: The URI given on output will be immutable
- and probably different, including the URI scheme, than the
- one given on input. The output URI will point to a location
- where the user only has a read access.
- :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the parameters of prediction and
- explanation via
- ``PredictRequest.parameters``,
- ``ExplainRequest.parameters``
- and
- ``BatchPredictionJob.model_parameters``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform, if no parameters are supported it is set to an
- empty string. Note: The URI given on output will be
- immutable and probably different, including the URI scheme,
- than the one given on input. The output URI will point to a
- location where the user only has a read access.
- :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the format of a single prediction
- produced by this Model, which are returned via
- ``PredictResponse.predictions``,
- ``ExplainResponse.explanations``,
- and
- ``BatchPredictionJob.output_config``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform. Note: The URI given on output will be immutable
- and probably different, including the URI scheme, than the
- one given on input. The output URI will point to a location
- where the user only has a read access.
- :param project_id: Project to run training in.
- :param region: Location to run training in.
- :param labels: Optional. The labels with user-defined metadata to
- organize TrainingPipelines.
- Label keys and values can be no longer than 64
- characters, can only
- contain lowercase letters, numeric characters,
- underscores and dashes. International characters
- are allowed.
- See https://goo.gl/xmQnxf for more information
- and examples of labels.
- :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
- managed encryption key used to protect the training pipeline. Has the
- form:
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
- The key needs to be in the same region as where the compute
- resource is created.
-
- If set, this TrainingPipeline will be secured by this key.
-
- Note: Model trained by this TrainingPipeline is also secured
- by this key if ``model_to_upload`` is not set separately.
- :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
- managed encryption key used to protect the model. Has the
- form:
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
- The key needs to be in the same region as where the compute
- resource is created.
-
- If set, the trained Model will be secured by this key.
- :param staging_bucket: Bucket used to stage source and training artifacts.
- :param dataset: Vertex AI to fit this training against.
- :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
- annotation schema. The schema is defined as an OpenAPI 3.0.2
- [Schema Object]
- (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
-
- Only Annotations that both match this schema and belong to
- DataItems not ignored by the split method are used in
- respectively training, validation or test role, depending on
- the role of the DataItem they are on.
-
- When used in conjunction with
- ``annotations_filter``,
- the Annotations used for training are filtered by both
- ``annotations_filter``
- and
- ``annotation_schema_uri``.
- :param model_display_name: If the script produces a managed Vertex AI Model. The display name of
- the Model. The name can be up to 128 characters long and can be consist
- of any UTF-8 characters.
-
- If not provided upon creation, the job's display_name is used.
- :param model_labels: Optional. The labels with user-defined metadata to
- organize your Models.
- Label keys and values can be no longer than 64
- characters, can only
- contain lowercase letters, numeric characters,
- underscores and dashes. International characters
- are allowed.
- See https://goo.gl/xmQnxf for more information
- and examples of labels.
- :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
- staging directory will be used.
-
- Vertex AI sets the following environment variables when it runs your training code:
-
- - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
- i.e. <base_output_dir>/model/
- - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
- i.e. <base_output_dir>/checkpoints/
- - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
- logs, i.e. <base_output_dir>/logs/
- :param service_account: Specifies the service account for workload run-as account.
- Users submitting jobs must have act-as permission on this run-as account.
- :param network: The full name of the Compute Engine network to which the job
- should be peered.
- Private services access must already be configured for the network.
- If left unspecified, the job is not peered with any network.
- :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset.
- The BigQuery project location where the training data is to
- be written to. In the given project a new dataset is created
- with name
- ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
- where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
- training input data will be written into that dataset. In
- the dataset three tables will be created, ``training``,
- ``validation`` and ``test``.
-
- - AIP_DATA_FORMAT = "bigquery".
- - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
- - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
- :param args: Command line arguments to be passed to the Python script.
- :param environment_variables: Environment variables to be passed to the container.
- Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- At most 10 environment variables can be specified.
- The Name of the environment variable must be unique.
- :param replica_count: The number of worker replicas. If replica count = 1 then one chief
- replica will be provisioned. If replica_count > 1 the remainder will be
- provisioned as a worker replica pool.
- :param machine_type: The type of machine to use for training.
- :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
- NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
- NVIDIA_TESLA_T4
- :param accelerator_count: The number of accelerators to attach to a worker replica.
- :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
- Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
- `pd-standard` (Persistent Disk Hard Disk Drive).
- :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
- boot disk size must be within the range of [100, 64000].
- :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
- the Model. This is ignored if Dataset is not provided.
- :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
- validate the Model. This is ignored if Dataset is not provided.
- :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
- the Model. This is ignored if Dataset is not provided.
- :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to train the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to validate the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to test the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
- columns. The value of the key (either the label's value or
- value in the column) must be one of {``training``,
- ``validation``, ``test``}, and it defines to which set the
- given piece of data is assigned. If for a piece of data the
- key is not present or has an invalid value, that piece is
- ignored by the pipeline.
-
- Supported only for tabular and time series Datasets.
- :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
- columns. The value of the key values of the key (the values in
- the column) must be in RFC 3339 `date-time` format, where
- `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
- piece of data the key is not present or has an invalid value,
- that piece is ignored by the pipeline.
-
- Supported only for tabular and time series Datasets.
- :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
- logs. Format:
- ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
- For more information on configuring your service account please visit:
- https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
- :param sync: Whether to execute the AI Platform job synchronously. If False, this method
- will be executed in concurrent Future and any downstream object will
- be immediately returned and synced when the Future has completed.
- """
- self._job = self.get_custom_python_package_training_job(
- project=project_id,
- location=region,
- display_name=display_name,
- python_package_gcs_uri=python_package_gcs_uri,
- python_module_name=python_module_name,
- container_uri=container_uri,
- model_serving_container_image_uri=model_serving_container_image_uri,
- model_serving_container_predict_route=model_serving_container_predict_route,
- model_serving_container_health_route=model_serving_container_health_route,
- model_serving_container_command=model_serving_container_command,
- model_serving_container_args=model_serving_container_args,
- model_serving_container_environment_variables=model_serving_container_environment_variables,
- model_serving_container_ports=model_serving_container_ports,
- model_description=model_description,
- model_instance_schema_uri=model_instance_schema_uri,
- model_parameters_schema_uri=model_parameters_schema_uri,
- model_prediction_schema_uri=model_prediction_schema_uri,
- labels=labels,
- training_encryption_spec_key_name=training_encryption_spec_key_name,
- model_encryption_spec_key_name=model_encryption_spec_key_name,
- staging_bucket=staging_bucket,
- )
-
- model = self._run_job(
- job=self._job,
- dataset=dataset,
- annotation_schema_uri=annotation_schema_uri,
- model_display_name=model_display_name,
- model_labels=model_labels,
- base_output_dir=base_output_dir,
- service_account=service_account,
- network=network,
- bigquery_destination=bigquery_destination,
- args=args,
- environment_variables=environment_variables,
- replica_count=replica_count,
- machine_type=machine_type,
- accelerator_type=accelerator_type,
- accelerator_count=accelerator_count,
- boot_disk_type=boot_disk_type,
- boot_disk_size_gb=boot_disk_size_gb,
- training_fraction_split=training_fraction_split,
- validation_fraction_split=validation_fraction_split,
- test_fraction_split=test_fraction_split,
- training_filter_split=training_filter_split,
- validation_filter_split=validation_filter_split,
- test_filter_split=test_filter_split,
- predefined_split_column_name=predefined_split_column_name,
- timestamp_split_column_name=timestamp_split_column_name,
- tensorboard=tensorboard,
- sync=sync,
- )
-
- return model
-
- @GoogleBaseHook.fallback_to_default_project_id
- def create_custom_training_job(
- self,
- project_id: str,
- region: str,
- display_name: str,
- script_path: str,
- container_uri: str,
- requirements: Optional[Sequence[str]] = None,
- model_serving_container_image_uri: Optional[str] = None,
- model_serving_container_predict_route: Optional[str] = None,
- model_serving_container_health_route: Optional[str] = None,
- model_serving_container_command: Optional[Sequence[str]] = None,
- model_serving_container_args: Optional[Sequence[str]] = None,
- model_serving_container_environment_variables: Optional[Dict[str, str]] = None,
- model_serving_container_ports: Optional[Sequence[int]] = None,
- model_description: Optional[str] = None,
- model_instance_schema_uri: Optional[str] = None,
- model_parameters_schema_uri: Optional[str] = None,
- model_prediction_schema_uri: Optional[str] = None,
- labels: Optional[Dict[str, str]] = None,
- training_encryption_spec_key_name: Optional[str] = None,
- model_encryption_spec_key_name: Optional[str] = None,
- staging_bucket: Optional[str] = None,
- # RUN
- dataset: Optional[
- Union[
- datasets.ImageDataset,
- datasets.TabularDataset,
- datasets.TextDataset,
- datasets.VideoDataset,
- ]
- ] = None,
- annotation_schema_uri: Optional[str] = None,
- model_display_name: Optional[str] = None,
- model_labels: Optional[Dict[str, str]] = None,
- base_output_dir: Optional[str] = None,
- service_account: Optional[str] = None,
- network: Optional[str] = None,
- bigquery_destination: Optional[str] = None,
- args: Optional[List[Union[str, float, int]]] = None,
- environment_variables: Optional[Dict[str, str]] = None,
- replica_count: int = 1,
- machine_type: str = "n1-standard-4",
- accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
- accelerator_count: int = 0,
- boot_disk_type: str = "pd-ssd",
- boot_disk_size_gb: int = 100,
- training_fraction_split: Optional[float] = None,
- validation_fraction_split: Optional[float] = None,
- test_fraction_split: Optional[float] = None,
- training_filter_split: Optional[str] = None,
- validation_filter_split: Optional[str] = None,
- test_filter_split: Optional[str] = None,
- predefined_split_column_name: Optional[str] = None,
- timestamp_split_column_name: Optional[str] = None,
- tensorboard: Optional[str] = None,
- sync=True,
- ) -> Model:
- """
- Create Custom Training Job
-
- :param display_name: Required. The user-defined name of this TrainingPipeline.
- :param script_path: Required. Local path to training script.
- :param container_uri: Required: Uri of the training container image in the GCR.
- :param requirements: List of python packages dependencies of script.
- :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI
- of the Model serving container suitable for serving the model produced by the
- training script.
- :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An
- HTTP path to send prediction requests to the container, and which must be supported
- by it. If not specified a default HTTP path will be used by Vertex AI.
- :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an
- HTTP path to send health check requests to the container, and which must be supported
- by it. If not specified a standard HTTP path will be used by AI Platform.
- :param model_serving_container_command: The command with which the container is run. Not executed
- within a shell. The Docker image's ENTRYPOINT is used if this is not provided.
- Variable references $(VAR_NAME) are expanded using the container's
- environment. If a variable cannot be resolved, the reference in the
- input string will be unchanged. The $(VAR_NAME) syntax can be escaped
- with a double $$, ie: $$(VAR_NAME). Escaped references will never be
- expanded, regardless of whether the variable exists or not.
- :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if
- this is not provided. Variable references $(VAR_NAME) are expanded using the
- container's environment. If a variable cannot be resolved, the reference
- in the input string will be unchanged. The $(VAR_NAME) syntax can be
- escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
- never be expanded, regardless of whether the variable exists or not.
- :param model_serving_container_environment_variables: The environment variables that are to be
- present in the container. Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- :param model_serving_container_ports: Declaration of ports that are exposed by the container. This
- field is primarily informational, it gives Vertex AI information about the
- network connections the container uses. Listing or not a port here has
- no impact on whether the port is actually exposed, any port listening on
- the default "0.0.0.0" address inside a container will be accessible from
- the network.
- :param model_description: The description of the Model.
- :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the format of a single instance, which
- are used in
- ``PredictRequest.instances``,
- ``ExplainRequest.instances``
- and
- ``BatchPredictionJob.input_config``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform. Note: The URI given on output will be immutable
- and probably different, including the URI scheme, than the
- one given on input. The output URI will point to a location
- where the user only has a read access.
- :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the parameters of prediction and
- explanation via
- ``PredictRequest.parameters``,
- ``ExplainRequest.parameters``
- and
- ``BatchPredictionJob.model_parameters``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform, if no parameters are supported it is set to an
- empty string. Note: The URI given on output will be
- immutable and probably different, including the URI scheme,
- than the one given on input. The output URI will point to a
- location where the user only has a read access.
- :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the format of a single prediction
- produced by this Model, which are returned via
- ``PredictResponse.predictions``,
- ``ExplainResponse.explanations``,
- and
- ``BatchPredictionJob.output_config``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform. Note: The URI given on output will be immutable
- and probably different, including the URI scheme, than the
- one given on input. The output URI will point to a location
- where the user only has a read access.
- :param project_id: Project to run training in.
- :param region: Location to run training in.
- :param labels: Optional. The labels with user-defined metadata to
- organize TrainingPipelines.
- Label keys and values can be no longer than 64
- characters, can only
- contain lowercase letters, numeric characters,
- underscores and dashes. International characters
- are allowed.
- See https://goo.gl/xmQnxf for more information
- and examples of labels.
- :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
- managed encryption key used to protect the training pipeline. Has the
- form:
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
- The key needs to be in the same region as where the compute
- resource is created.
-
- If set, this TrainingPipeline will be secured by this key.
-
- Note: Model trained by this TrainingPipeline is also secured
- by this key if ``model_to_upload`` is not set separately.
- :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
- managed encryption key used to protect the model. Has the
- form:
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
- The key needs to be in the same region as where the compute
- resource is created.
-
- If set, the trained Model will be secured by this key.
- :param staging_bucket: Bucket used to stage source and training artifacts.
- :param dataset: Vertex AI to fit this training against.
- :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
- annotation schema. The schema is defined as an OpenAPI 3.0.2
- [Schema Object]
- (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
-
- Only Annotations that both match this schema and belong to
- DataItems not ignored by the split method are used in
- respectively training, validation or test role, depending on
- the role of the DataItem they are on.
-
- When used in conjunction with
- ``annotations_filter``,
- the Annotations used for training are filtered by both
- ``annotations_filter``
- and
- ``annotation_schema_uri``.
- :param model_display_name: If the script produces a managed Vertex AI Model. The display name of
- the Model. The name can be up to 128 characters long and can be consist
- of any UTF-8 characters.
-
- If not provided upon creation, the job's display_name is used.
- :param model_labels: Optional. The labels with user-defined metadata to
- organize your Models.
- Label keys and values can be no longer than 64
- characters, can only
- contain lowercase letters, numeric characters,
- underscores and dashes. International characters
- are allowed.
- See https://goo.gl/xmQnxf for more information
- and examples of labels.
- :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
- staging directory will be used.
-
- Vertex AI sets the following environment variables when it runs your training code:
-
- - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
- i.e. <base_output_dir>/model/
- - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
- i.e. <base_output_dir>/checkpoints/
- - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
- logs, i.e. <base_output_dir>/logs/
- :param service_account: Specifies the service account for workload run-as account.
- Users submitting jobs must have act-as permission on this run-as account.
- :param network: The full name of the Compute Engine network to which the job
- should be peered.
- Private services access must already be configured for the network.
- If left unspecified, the job is not peered with any network.
- :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset.
- The BigQuery project location where the training data is to
- be written to. In the given project a new dataset is created
- with name
- ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
- where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
- training input data will be written into that dataset. In
- the dataset three tables will be created, ``training``,
- ``validation`` and ``test``.
-
- - AIP_DATA_FORMAT = "bigquery".
- - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
- - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
- :param args: Command line arguments to be passed to the Python script.
- :param environment_variables: Environment variables to be passed to the container.
- Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- At most 10 environment variables can be specified.
- The Name of the environment variable must be unique.
- :param replica_count: The number of worker replicas. If replica count = 1 then one chief
- replica will be provisioned. If replica_count > 1 the remainder will be
- provisioned as a worker replica pool.
- :param machine_type: The type of machine to use for training.
- :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
- NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
- NVIDIA_TESLA_T4
- :param accelerator_count: The number of accelerators to attach to a worker replica.
- :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
- Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
- `pd-standard` (Persistent Disk Hard Disk Drive).
- :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
- boot disk size must be within the range of [100, 64000].
- :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
- the Model. This is ignored if Dataset is not provided.
- :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
- validate the Model. This is ignored if Dataset is not provided.
- :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
- the Model. This is ignored if Dataset is not provided.
- :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to train the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to validate the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to test the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
- columns. The value of the key (either the label's value or
- value in the column) must be one of {``training``,
- ``validation``, ``test``}, and it defines to which set the
- given piece of data is assigned. If for a piece of data the
- key is not present or has an invalid value, that piece is
- ignored by the pipeline.
-
- Supported only for tabular and time series Datasets.
- :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
- columns. The value of the key values of the key (the values in
- the column) must be in RFC 3339 `date-time` format, where
- `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
- piece of data the key is not present or has an invalid value,
- that piece is ignored by the pipeline.
-
- Supported only for tabular and time series Datasets.
- :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
- logs. Format:
- ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
- For more information on configuring your service account please visit:
- https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
- :param sync: Whether to execute the AI Platform job synchronously. If False, this method
- will be executed in concurrent Future and any downstream object will
- be immediately returned and synced when the Future has completed.
- """
- self._job = self.get_custom_training_job(
- project=project_id,
- location=region,
- display_name=display_name,
- script_path=script_path,
- container_uri=container_uri,
- requirements=requirements,
- model_serving_container_image_uri=model_serving_container_image_uri,
- model_serving_container_predict_route=model_serving_container_predict_route,
- model_serving_container_health_route=model_serving_container_health_route,
- model_serving_container_command=model_serving_container_command,
- model_serving_container_args=model_serving_container_args,
- model_serving_container_environment_variables=model_serving_container_environment_variables,
- model_serving_container_ports=model_serving_container_ports,
- model_description=model_description,
- model_instance_schema_uri=model_instance_schema_uri,
- model_parameters_schema_uri=model_parameters_schema_uri,
- model_prediction_schema_uri=model_prediction_schema_uri,
- labels=labels,
- training_encryption_spec_key_name=training_encryption_spec_key_name,
- model_encryption_spec_key_name=model_encryption_spec_key_name,
- staging_bucket=staging_bucket,
- )
-
- model = self._run_job(
- job=self._job,
- dataset=dataset,
- annotation_schema_uri=annotation_schema_uri,
- model_display_name=model_display_name,
- model_labels=model_labels,
- base_output_dir=base_output_dir,
- service_account=service_account,
- network=network,
- bigquery_destination=bigquery_destination,
- args=args,
- environment_variables=environment_variables,
- replica_count=replica_count,
- machine_type=machine_type,
- accelerator_type=accelerator_type,
- accelerator_count=accelerator_count,
- boot_disk_type=boot_disk_type,
- boot_disk_size_gb=boot_disk_size_gb,
- training_fraction_split=training_fraction_split,
- validation_fraction_split=validation_fraction_split,
- test_fraction_split=test_fraction_split,
- training_filter_split=training_filter_split,
- validation_filter_split=validation_filter_split,
- test_filter_split=test_filter_split,
- predefined_split_column_name=predefined_split_column_name,
- timestamp_split_column_name=timestamp_split_column_name,
- tensorboard=tensorboard,
- sync=sync,
- )
-
- return model
-
- @GoogleBaseHook.fallback_to_default_project_id
- def delete_pipeline_job(
- self,
- project_id: str,
- region: str,
- pipeline_job: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> Operation:
- """
- Deletes a PipelineJob.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param pipeline_job: Required. The name of the PipelineJob resource to be deleted.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_pipeline_service_client(region)
- name = client.pipeline_job_path(project_id, region, pipeline_job)
-
- result = client.delete_pipeline_job(
- request={
- 'name': name,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def delete_training_pipeline(
- self,
- project_id: str,
- region: str,
- training_pipeline: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> Operation:
- """
- Deletes a TrainingPipeline.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param training_pipeline: Required. The name of the TrainingPipeline resource to be deleted.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_pipeline_service_client(region)
- name = client.training_pipeline_path(project_id, region, training_pipeline)
-
- result = client.delete_training_pipeline(
- request={
- 'name': name,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def delete_custom_job(
- self,
- project_id: str,
- region: str,
- custom_job: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> Operation:
- """
- Deletes a CustomJob.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param custom_job: Required. The name of the CustomJob to delete.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_job_service_client(region)
- name = client.custom_job_path(project_id, region, custom_job)
-
- result = client.delete_custom_job(
- request={
- 'name': name,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def get_pipeline_job(
- self,
- project_id: str,
- region: str,
- pipeline_job: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> PipelineJob:
- """
- Gets a PipelineJob.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param pipeline_job: Required. The name of the PipelineJob resource.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_pipeline_service_client(region)
- name = client.pipeline_job_path(project_id, region, pipeline_job)
-
- result = client.get_pipeline_job(
- request={
- 'name': name,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def get_training_pipeline(
- self,
- project_id: str,
- region: str,
- training_pipeline: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> TrainingPipeline:
- """
- Gets a TrainingPipeline.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param training_pipeline: Required. The name of the TrainingPipeline resource.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_pipeline_service_client(region)
- name = client.training_pipeline_path(project_id, region, training_pipeline)
-
- result = client.get_training_pipeline(
- request={
- 'name': name,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def get_custom_job(
- self,
- project_id: str,
- region: str,
- custom_job: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> CustomJob:
- """
- Gets a CustomJob.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param custom_job: Required. The name of the CustomJob to get.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_job_service_client(region)
- name = JobServiceClient.custom_job_path(project_id, region, custom_job)
-
- result = client.get_custom_job(
- request={
- 'name': name,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def list_pipeline_jobs(
- self,
- project_id: str,
- region: str,
- page_size: Optional[int] = None,
- page_token: Optional[str] = None,
- filter: Optional[str] = None,
- order_by: Optional[str] = None,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> ListPipelineJobsPager:
- """
- Lists PipelineJobs in a Location.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param filter: Optional. Lists the PipelineJobs that match the filter expression. The
- following fields are supported:
-
- - ``pipeline_name``: Supports ``=`` and ``!=`` comparisons.
- - ``display_name``: Supports ``=``, ``!=`` comparisons, and
- ``:`` wildcard.
- - ``pipeline_job_user_id``: Supports ``=``, ``!=``
- comparisons, and ``:`` wildcard. For example, you can check
- whether a pipeline's display_name contains *step* by using
- display_name:"*step*"
- - ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``,
- ``<=``, and ``>=`` comparisons. Values must be in RFC
- 3339 format.
- - ``update_time``: Supports ``=``, ``!=``, ``<``, ``>``,
- ``<=``, and ``>=`` comparisons. Values must be in RFC
- 3339 format.
- - ``end_time``: Supports ``=``, ``!=``, ``<``, ``>``,
- ``<=``, and ``>=`` comparisons. Values must be in RFC
- 3339 format.
- - ``labels``: Supports key-value equality and key presence.
-
- Filter expressions can be combined together using logical
- operators (``AND`` & ``OR``). For example:
- ``pipeline_name="test" AND create_time>"2020-05-18T13:30:00Z"``.
-
- The syntax to define filter expression is based on
- https://google.aip.dev/160.
- :param page_size: Optional. The standard list page size.
- :param page_token: Optional. The standard list page token. Typically obtained via
- [ListPipelineJobsResponse.next_page_token][google.cloud.aiplatform.v1.ListPipelineJobsResponse.next_page_token]
- of the previous
- [PipelineService.ListPipelineJobs][google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs]
- call.
- :param order_by: Optional. A comma-separated list of fields to order by. The default
- sort order is in ascending order. Use "desc" after a field
- name for descending. You can have multiple order_by fields
- provided, e.g. "create_time desc, end_time" or "end_time,
- start_time, update_time". For example, using "create_time
- desc, end_time" will order results by create time in
- descending order, and if multiple jobs have the same
- create time, order them by end time in ascending
- order. If order_by is not specified, results are ordered
- by create time in descending order. Supported
- fields:
-
- - ``create_time``
- - ``update_time``
- - ``end_time``
- - ``start_time``
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_pipeline_service_client(region)
- parent = client.common_location_path(project_id, region)
-
- result = client.list_pipeline_jobs(
- request={
- 'parent': parent,
- 'page_size': page_size,
- 'page_token': page_token,
- 'filter': filter,
- 'order_by': order_by,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def list_training_pipelines(
- self,
- project_id: str,
- region: str,
- page_size: Optional[int] = None,
- page_token: Optional[str] = None,
- filter: Optional[str] = None,
- read_mask: Optional[str] = None,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> ListTrainingPipelinesPager:
- """
- Lists TrainingPipelines in a Location.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param filter: Optional. The standard list filter. Supported fields:
-
- - ``display_name`` supports = and !=.
-
- - ``state`` supports = and !=.
-
- Some examples of using the filter are:
-
- - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"``
-
- - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"``
-
- - ``NOT display_name="my_pipeline"``
-
- - ``state="PIPELINE_STATE_FAILED"``
- :param page_size: Optional. The standard list page size.
- :param page_token: Optional. The standard list page token. Typically obtained via
- [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token]
- of the previous
- [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines]
- call.
- :param read_mask: Optional. Mask specifying which fields to read.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_pipeline_service_client(region)
- parent = client.common_location_path(project_id, region)
-
- result = client.list_training_pipelines(
- request={
- 'parent': parent,
- 'page_size': page_size,
- 'page_token': page_token,
- 'filter': filter,
- 'read_mask': read_mask,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def list_custom_jobs(
- self,
- project_id: str,
- region: str,
- page_size: Optional[int],
- page_token: Optional[str],
- filter: Optional[str],
- read_mask: Optional[str],
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> ListCustomJobsPager:
- """
- Lists CustomJobs in a Location.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param filter: Optional. The standard list filter. Supported fields:
-
- - ``display_name`` supports = and !=.
-
- - ``state`` supports = and !=.
-
- Some examples of using the filter are:
-
- - ``state="JOB_STATE_SUCCEEDED" AND display_name="my_job"``
-
- - ``state="JOB_STATE_RUNNING" OR display_name="my_job"``
-
- - ``NOT display_name="my_job"``
-
- - ``state="JOB_STATE_FAILED"``
- :param page_size: Optional. The standard list page size.
- :param page_token: Optional. The standard list page token. Typically obtained via
- [ListCustomJobsResponse.next_page_token][google.cloud.aiplatform.v1.ListCustomJobsResponse.next_page_token]
- of the previous
- [JobService.ListCustomJobs][google.cloud.aiplatform.v1.JobService.ListCustomJobs]
- call.
- :param read_mask: Optional. Mask specifying which fields to read.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_job_service_client(region)
- parent = JobServiceClient.common_location_path(project_id, region)
-
- result = client.list_custom_jobs(
- request={
- 'parent': parent,
- 'page_size': page_size,
- 'page_token': page_token,
- 'filter': filter,
- 'read_mask': read_mask,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py b/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py
deleted file mode 100644
index 4a68c34..0000000
--- a/airflow/providers/google/cloud/hooks/vertex_ai/dataset.py
+++ /dev/null
@@ -1,460 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-"""This module contains a Google Cloud Vertex AI hook."""
-
-from typing import Dict, Optional, Sequence, Tuple
-
-from google.api_core.operation import Operation
-from google.api_core.retry import Retry
-from google.cloud.aiplatform_v1 import DatasetServiceClient
-from google.cloud.aiplatform_v1.services.dataset_service.pagers import (
- ListAnnotationsPager,
- ListDataItemsPager,
- ListDatasetsPager,
-)
-from google.cloud.aiplatform_v1.types import AnnotationSpec, Dataset, ExportDataConfig, ImportDataConfig
-from google.protobuf.field_mask_pb2 import FieldMask
-
-from airflow import AirflowException
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
-
-
-class DatasetHook(GoogleBaseHook):
- """Hook for Google Cloud Vertex AI Dataset APIs."""
-
- def get_dataset_service_client(self, region: Optional[str] = None) -> DatasetServiceClient:
- """Returns DatasetServiceClient."""
- client_options = None
- if region and region != 'global':
- client_options = {'api_endpoint': f'{region}-aiplatform.googleapis.com:443'}
-
- return DatasetServiceClient(
- credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
- )
-
- def wait_for_operation(self, timeout: float, operation: Operation):
- """Waits for long-lasting operation to complete."""
- try:
- return operation.result(timeout=timeout)
- except Exception:
- error = operation.exception(timeout=timeout)
- raise AirflowException(error)
-
- @staticmethod
- def extract_dataset_id(obj: Dict) -> str:
- """Returns unique id of the dataset."""
- return obj["name"].rpartition("/")[-1]
-
- @GoogleBaseHook.fallback_to_default_project_id
- def create_dataset(
- self,
- project_id: str,
- region: str,
- dataset: Dataset,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> Operation:
- """
- Creates a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param dataset: Required. The Dataset to create.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_dataset_service_client(region)
- parent = client.common_location_path(project_id, region)
-
- result = client.create_dataset(
- request={
- 'parent': parent,
- 'dataset': dataset,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def delete_dataset(
- self,
- project_id: str,
- region: str,
- dataset: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> Operation:
- """
- Deletes a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param dataset: Required. The ID of the Dataset to delete.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_dataset_service_client(region)
- name = client.dataset_path(project_id, region, dataset)
-
- result = client.delete_dataset(
- request={
- 'name': name,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def export_data(
- self,
- project_id: str,
- region: str,
- dataset: str,
- export_config: ExportDataConfig,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> Operation:
- """
- Exports data from a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param dataset: Required. The ID of the Dataset to export.
- :param export_config: Required. The desired output location.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_dataset_service_client(region)
- name = client.dataset_path(project_id, region, dataset)
-
- result = client.export_data(
- request={
- 'name': name,
- 'export_config': export_config,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def get_annotation_spec(
- self,
- project_id: str,
- region: str,
- dataset: str,
- annotation_spec: str,
- read_mask: Optional[str] = None,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> AnnotationSpec:
- """
- Gets an AnnotationSpec.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param dataset: Required. The ID of the Dataset.
- :param annotation_spec: The ID of the AnnotationSpec resource.
- :param read_mask: Optional. Mask specifying which fields to read.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_dataset_service_client(region)
- name = client.annotation_spec_path(project_id, region, dataset, annotation_spec)
-
- result = client.get_annotation_spec(
- request={
- 'name': name,
- 'read_mask': read_mask,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def get_dataset(
- self,
- project_id: str,
- region: str,
- dataset: str,
- read_mask: Optional[str] = None,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> Dataset:
- """
- Gets a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param dataset: Required. The ID of the Dataset to get.
- :param read_mask: Optional. Mask specifying which fields to read.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_dataset_service_client(region)
- name = client.dataset_path(project_id, region, dataset)
-
- result = client.get_dataset(
- request={
- 'name': name,
- 'read_mask': read_mask,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def import_data(
- self,
- project_id: str,
- region: str,
- dataset: str,
- import_configs: Sequence[ImportDataConfig],
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> Operation:
- """
- Imports data into a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param dataset: Required. The ID of the Dataset to import.
- :param import_configs: Required. The desired input locations. The contents of all input locations
- will be imported in one batch.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_dataset_service_client(region)
- name = client.dataset_path(project_id, region, dataset)
-
- result = client.import_data(
- request={
- 'name': name,
- 'import_configs': import_configs,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def list_annotations(
- self,
- project_id: str,
- region: str,
- dataset: str,
- data_item: str,
- filter: Optional[str] = None,
- page_size: Optional[int] = None,
- page_token: Optional[str] = None,
- read_mask: Optional[str] = None,
- order_by: Optional[str] = None,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> ListAnnotationsPager:
- """
- Lists Annotations belonging to a data item.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param dataset: Required. The ID of the Dataset.
- :param data_item: Required. The ID of the DataItem to list Annotations from.
- :param filter: The standard list filter.
- :param page_size: The standard list page size.
- :param page_token: The standard list page token.
- :param read_mask: Mask specifying which fields to read.
- :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc"
- after a field name for descending.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_dataset_service_client(region)
- parent = client.data_item_path(project_id, region, dataset, data_item)
-
- result = client.list_annotations(
- request={
- 'parent': parent,
- 'filter': filter,
- 'page_size': page_size,
- 'page_token': page_token,
- 'read_mask': read_mask,
- 'order_by': order_by,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def list_data_items(
- self,
- project_id: str,
- region: str,
- dataset: str,
- filter: Optional[str] = None,
- page_size: Optional[int] = None,
- page_token: Optional[str] = None,
- read_mask: Optional[str] = None,
- order_by: Optional[str] = None,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> ListDataItemsPager:
- """
- Lists DataItems in a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param dataset: Required. The ID of the Dataset.
- :param filter: The standard list filter.
- :param page_size: The standard list page size.
- :param page_token: The standard list page token.
- :param read_mask: Mask specifying which fields to read.
- :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc"
- after a field name for descending.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_dataset_service_client(region)
- parent = client.dataset_path(project_id, region, dataset)
-
- result = client.list_data_items(
- request={
- 'parent': parent,
- 'filter': filter,
- 'page_size': page_size,
- 'page_token': page_token,
- 'read_mask': read_mask,
- 'order_by': order_by,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def list_datasets(
- self,
- project_id: str,
- region: str,
- filter: Optional[str] = None,
- page_size: Optional[int] = None,
- page_token: Optional[str] = None,
- read_mask: Optional[str] = None,
- order_by: Optional[str] = None,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> ListDatasetsPager:
- """
- Lists Datasets in a Location.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param filter: The standard list filter.
- :param page_size: The standard list page size.
- :param page_token: The standard list page token.
- :param read_mask: Mask specifying which fields to read.
- :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc"
- after a field name for descending.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_dataset_service_client(region)
- parent = client.common_location_path(project_id, region)
-
- result = client.list_datasets(
- request={
- 'parent': parent,
- 'filter': filter,
- 'page_size': page_size,
- 'page_token': page_token,
- 'read_mask': read_mask,
- 'order_by': order_by,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- def update_dataset(
- self,
- project_id: str,
- region: str,
- dataset_id: str,
- dataset: Dataset,
- update_mask: FieldMask,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = None,
- ) -> Dataset:
- """
- Updates a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param dataset_id: Required. The ID of the Dataset.
- :param dataset: Required. The Dataset which replaces the resource on the server.
- :param update_mask: Required. The update mask applies to the resource.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- """
- client = self.get_dataset_service_client(region)
- dataset["name"] = client.dataset_path(project_id, region, dataset_id)
-
- result = client.update_dataset(
- request={
- 'dataset': dataset,
- 'update_mask': update_mask,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/__init__.py b/airflow/providers/google/cloud/operators/vertex_ai/__init__.py
deleted file mode 100644
index 13a8339..0000000
--- a/airflow/providers/google/cloud/operators/vertex_ai/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
deleted file mode 100644
index 875186b..0000000
--- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
+++ /dev/null
@@ -1,1427 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-"""This module contains Google Vertex AI operators."""
-
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
-
-from google.api_core.exceptions import NotFound
-from google.api_core.retry import Retry
-from google.cloud.aiplatform.models import Model
-from google.cloud.aiplatform_v1.types.dataset import Dataset
-from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline
-
-from airflow.models import BaseOperator, BaseOperatorLink
-from airflow.models.taskinstance import TaskInstance
-from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook
-
-if TYPE_CHECKING:
- from airflow.utils.context import Context
-
-# Root URL that the link templates below are built on.
-VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai"
-# ``str.format`` template; placeholders: ``region``, ``model_id``, ``project_id``.
-VERTEX_AI_MODEL_LINK = "".join(
-    (VERTEX_AI_BASE_LINK, "/locations/{region}/models/{model_id}/deploy?project={project_id}")
-)
-# ``str.format`` template; placeholder: ``project_id``.
-VERTEX_AI_TRAINING_PIPELINES_LINK = "".join(
-    (VERTEX_AI_BASE_LINK, "/training/training-pipelines?project={project_id}")
-)
-
-
-class VertexAIModelLink(BaseOperatorLink):
-    """Operator extra link that resolves to a Vertex AI Model URL."""
-
-    name = "Vertex AI Model"
-
-    def get_link(self, operator, dttm):
-        """
-        Build the model URL from the ``model_conf`` XCom of the given task run.
-
-        :param operator: Operator whose ``task_id`` is used for the XCom lookup.
-        :param dttm: Execution date used to construct the task instance.
-        :return: The formatted ``VERTEX_AI_MODEL_LINK``, or ``""`` when no
-            ``model_conf`` value is available.
-        """
-        task_instance = TaskInstance(task=operator, execution_date=dttm)
-        conf = task_instance.xcom_pull(task_ids=operator.task_id, key="model_conf")
-        if not conf:
-            return ""
-        return VERTEX_AI_MODEL_LINK.format(
-            region=conf["region"],
-            model_id=conf["model_id"],
-            project_id=conf["project_id"],
-        )
-
-
-class VertexAITrainingPipelinesLink(BaseOperatorLink):
-    """Operator extra link that resolves to the Vertex AI Training Pipelines URL."""
-
-    name = "Vertex AI Training Pipelines"
-
-    def get_link(self, operator, dttm):
-        """
-        Build the training-pipelines URL from the ``project_id`` XCom of the given task run.
-
-        :param operator: Operator whose ``task_id`` is used for the XCom lookup.
-        :param dttm: Execution date used to construct the task instance.
-        :return: The formatted ``VERTEX_AI_TRAINING_PIPELINES_LINK``, or ``""`` when no
-            ``project_id`` value is available.
-        """
-        task_instance = TaskInstance(task=operator, execution_date=dttm)
-        pushed_project = task_instance.xcom_pull(task_ids=operator.task_id, key="project_id")
-        if not pushed_project:
-            return ""
-        return VERTEX_AI_TRAINING_PIPELINES_LINK.format(project_id=pushed_project)
-
-
-class CustomTrainingJobBaseOperator(BaseOperator):
-    """
-    Shared base for operators that launch Custom jobs on Vertex AI.
-
-    The constructor only records its arguments on the instance, and ``execute`` only
-    creates the ``CustomJobHook``.
-    """
-
-    def __init__(
-        self,
-        *,
-        project_id: str,
-        region: str,
-        display_name: str,
-        container_uri: str,
-        model_serving_container_image_uri: Optional[str] = None,
-        model_serving_container_predict_route: Optional[str] = None,
-        model_serving_container_health_route: Optional[str] = None,
-        model_serving_container_command: Optional[Sequence[str]] = None,
-        model_serving_container_args: Optional[Sequence[str]] = None,
-        model_serving_container_environment_variables: Optional[Dict[str, str]] = None,
-        model_serving_container_ports: Optional[Sequence[int]] = None,
-        model_description: Optional[str] = None,
-        model_instance_schema_uri: Optional[str] = None,
-        model_parameters_schema_uri: Optional[str] = None,
-        model_prediction_schema_uri: Optional[str] = None,
-        labels: Optional[Dict[str, str]] = None,
-        training_encryption_spec_key_name: Optional[str] = None,
-        model_encryption_spec_key_name: Optional[str] = None,
-        staging_bucket: Optional[str] = None,
-        # Arguments below describe a single run of the job.
-        dataset_id: Optional[str] = None,
-        annotation_schema_uri: Optional[str] = None,
-        model_display_name: Optional[str] = None,
-        model_labels: Optional[Dict[str, str]] = None,
-        base_output_dir: Optional[str] = None,
-        service_account: Optional[str] = None,
-        network: Optional[str] = None,
-        bigquery_destination: Optional[str] = None,
-        args: Optional[List[Union[str, float, int]]] = None,
-        environment_variables: Optional[Dict[str, str]] = None,
-        replica_count: int = 1,
-        machine_type: str = "n1-standard-4",
-        accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
-        accelerator_count: int = 0,
-        boot_disk_type: str = "pd-ssd",
-        boot_disk_size_gb: int = 100,
-        training_fraction_split: Optional[float] = None,
-        validation_fraction_split: Optional[float] = None,
-        test_fraction_split: Optional[float] = None,
-        training_filter_split: Optional[str] = None,
-        validation_filter_split: Optional[str] = None,
-        test_filter_split: Optional[str] = None,
-        predefined_split_column_name: Optional[str] = None,
-        timestamp_split_column_name: Optional[str] = None,
-        tensorboard: Optional[str] = None,
-        sync=True,
-        gcp_conn_id: str = "google_cloud_default",
-        delegate_to: Optional[str] = None,
-        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
-        **kwargs,
-    ) -> None:
-        """Record every argument on the instance; remaining ``kwargs`` go to ``BaseOperator``."""
-        super().__init__(**kwargs)
-        self.project_id = project_id
-        self.region = region
-        self.display_name = display_name
-
-        # --- job definition ---
-        self.container_uri = container_uri
-        self.model_serving_container_image_uri = model_serving_container_image_uri
-        self.model_serving_container_predict_route = model_serving_container_predict_route
-        self.model_serving_container_health_route = model_serving_container_health_route
-        self.model_serving_container_command = model_serving_container_command
-        self.model_serving_container_args = model_serving_container_args
-        self.model_serving_container_environment_variables = model_serving_container_environment_variables
-        self.model_serving_container_ports = model_serving_container_ports
-        self.model_description = model_description
-        self.model_instance_schema_uri = model_instance_schema_uri
-        self.model_parameters_schema_uri = model_parameters_schema_uri
-        self.model_prediction_schema_uri = model_prediction_schema_uri
-        self.labels = labels
-        self.training_encryption_spec_key_name = training_encryption_spec_key_name
-        self.model_encryption_spec_key_name = model_encryption_spec_key_name
-        self.staging_bucket = staging_bucket
-
-        # --- run parameters ---
-        # A dataset id, when given, is wrapped in a ``Dataset`` message keyed by name.
-        if dataset_id:
-            self.dataset = Dataset(name=dataset_id)
-        else:
-            self.dataset = None
-        self.annotation_schema_uri = annotation_schema_uri
-        self.model_display_name = model_display_name
-        self.model_labels = model_labels
-        self.base_output_dir = base_output_dir
-        self.service_account = service_account
-        self.network = network
-        self.bigquery_destination = bigquery_destination
-        self.args = args
-        self.environment_variables = environment_variables
-        self.replica_count = replica_count
-        self.machine_type = machine_type
-        self.accelerator_type = accelerator_type
-        self.accelerator_count = accelerator_count
-        self.boot_disk_type = boot_disk_type
-        self.boot_disk_size_gb = boot_disk_size_gb
-        self.training_fraction_split = training_fraction_split
-        self.validation_fraction_split = validation_fraction_split
-        self.test_fraction_split = test_fraction_split
-        self.training_filter_split = training_filter_split
-        self.validation_filter_split = validation_filter_split
-        self.test_filter_split = test_filter_split
-        self.predefined_split_column_name = predefined_split_column_name
-        self.timestamp_split_column_name = timestamp_split_column_name
-        self.tensorboard = tensorboard
-        self.sync = sync
-
-        # --- connection settings ---
-        self.gcp_conn_id = gcp_conn_id
-        self.delegate_to = delegate_to
-        self.impersonation_chain = impersonation_chain
-        # Stays ``None`` until ``execute`` runs; ``on_kill`` relies on that.
-        self.hook: Optional[CustomJobHook] = None
-
-    def execute(self, context: 'Context'):
-        """Create the ``CustomJobHook`` from the connection settings and keep it on ``self.hook``."""
-        self.hook = CustomJobHook(
-            gcp_conn_id=self.gcp_conn_id,
-            delegate_to=self.delegate_to,
-            impersonation_chain=self.impersonation_chain,
-        )
-
-    def on_kill(self) -> None:
-        """
-        Callback invoked when the operator is killed.
-
-        Calls ``cancel_job`` on the hook if one has been created; otherwise does nothing.
-        """
-        if not self.hook:
-            return
-        self.hook.cancel_job()
-
-
-class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
-    """Create Custom Container Training job
-
-    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
-        Project to run training in.
-    :param region: Required. The ID of the Google Cloud region that the service belongs to.
-        Location to run training in.
-    :param display_name: Required. The user-defined name of this TrainingPipeline.
-    :param command: The command to be invoked when the container is started.
-        It overrides the entrypoint instruction in Dockerfile when provided
-    :param container_uri: Required: Uri of the training container image in the GCR.
-    :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI
-        of the Model serving container suitable for serving the model produced by the
-        training script.
-    :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An
-        HTTP path to send prediction requests to the container, and which must be supported
-        by it. If not specified a default HTTP path will be used by Vertex AI.
-    :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an
-        HTTP path to send health check requests to the container, and which must be supported
-        by it. If not specified a standard HTTP path will be used by AI Platform.
-    :param model_serving_container_command: The command with which the container is run. Not executed
-        within a shell. The Docker image's ENTRYPOINT is used if this is not provided.
-        Variable references $(VAR_NAME) are expanded using the container's
-        environment. If a variable cannot be resolved, the reference in the
-        input string will be unchanged. The $(VAR_NAME) syntax can be escaped
-        with a double $$, ie: $$(VAR_NAME). Escaped references will never be
-        expanded, regardless of whether the variable exists or not.
-    :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if
-        this is not provided. Variable references $(VAR_NAME) are expanded using the
-        container's environment. If a variable cannot be resolved, the reference
-        in the input string will be unchanged. The $(VAR_NAME) syntax can be
-        escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
-        never be expanded, regardless of whether the variable exists or not.
-    :param model_serving_container_environment_variables: The environment variables that are to be
-        present in the container. Should be a dictionary where keys are environment variable names
-        and values are environment variable values for those names.
-    :param model_serving_container_ports: Declaration of ports that are exposed by the container. This
-        field is primarily informational, it gives Vertex AI information about the
-        network connections the container uses. Listing or not a port here has
-        no impact on whether the port is actually exposed, any port listening on
-        the default "0.0.0.0" address inside a container will be accessible from
-        the network.
-    :param model_description: The description of the Model.
-    :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
-        Storage describing the format of a single instance, which
-        are used in
-        ``PredictRequest.instances``,
-        ``ExplainRequest.instances``
-        and
-        ``BatchPredictionJob.input_config``.
-        The schema is defined as an OpenAPI 3.0.2 `Schema
-        Object <https://tinyurl.com/y538mdwt#schema-object>`__.
-        AutoML Models always have this field populated by AI
-        Platform. Note: The URI given on output will be immutable
-        and probably different, including the URI scheme, than the
-        one given on input. The output URI will point to a location
-        where the user only has a read access.
-    :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
-        Storage describing the parameters of prediction and
-        explanation via
-        ``PredictRequest.parameters``,
-        ``ExplainRequest.parameters``
-        and
-        ``BatchPredictionJob.model_parameters``.
-        The schema is defined as an OpenAPI 3.0.2 `Schema
-        Object <https://tinyurl.com/y538mdwt#schema-object>`__.
-        AutoML Models always have this field populated by AI
-        Platform, if no parameters are supported it is set to an
-        empty string. Note: The URI given on output will be
-        immutable and probably different, including the URI scheme,
-        than the one given on input. The output URI will point to a
-        location where the user only has a read access.
-    :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
-        Storage describing the format of a single prediction
-        produced by this Model, which are returned via
-        ``PredictResponse.predictions``,
-        ``ExplainResponse.explanations``,
-        and
-        ``BatchPredictionJob.output_config``.
-        The schema is defined as an OpenAPI 3.0.2 `Schema
-        Object <https://tinyurl.com/y538mdwt#schema-object>`__.
-        AutoML Models always have this field populated by AI
-        Platform. Note: The URI given on output will be immutable
-        and probably different, including the URI scheme, than the
-        one given on input. The output URI will point to a location
-        where the user only has a read access.
-    :param labels: Optional. The labels with user-defined metadata to
-        organize TrainingPipelines.
-        Label keys and values can be no longer than 64
-        characters, can only
-        contain lowercase letters, numeric characters,
-        underscores and dashes. International characters
-        are allowed.
-        See https://goo.gl/xmQnxf for more information
-        and examples of labels.
-    :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
-        managed encryption key used to protect the training pipeline. Has the
-        form:
-        ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
-        The key needs to be in the same region as where the compute
-        resource is created.
-
-        If set, this TrainingPipeline will be secured by this key.
-
-        Note: Model trained by this TrainingPipeline is also secured
-        by this key if ``model_to_upload`` is not set separately.
-    :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
-        managed encryption key used to protect the model. Has the
-        form:
-        ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
-        The key needs to be in the same region as where the compute
-        resource is created.
-
-        If set, the trained Model will be secured by this key.
-    :param staging_bucket: Bucket used to stage source and training artifacts.
-    :param dataset_id: Vertex AI Dataset to fit this training against. When set, it is
-        passed to the hook as ``Dataset(name=dataset_id)``.
-    :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
-        annotation schema. The schema is defined as an OpenAPI 3.0.2
-        [Schema Object]
-        (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
-
-        Only Annotations that both match this schema and belong to
-        DataItems not ignored by the split method are used in
-        respectively training, validation or test role, depending on
-        the role of the DataItem they are on.
-
-        When used in conjunction with
-        ``annotations_filter``,
-        the Annotations used for training are filtered by both
-        ``annotations_filter``
-        and
-        ``annotation_schema_uri``.
-    :param model_display_name: If the script produces a managed Vertex AI Model. The display name of
-        the Model. The name can be up to 128 characters long and can consist
-        of any UTF-8 characters.
-
-        If not provided upon creation, the job's display_name is used.
-    :param model_labels: Optional. The labels with user-defined metadata to
-        organize your Models.
-        Label keys and values can be no longer than 64
-        characters, can only
-        contain lowercase letters, numeric characters,
-        underscores and dashes. International characters
-        are allowed.
-        See https://goo.gl/xmQnxf for more information
-        and examples of labels.
-    :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
-        staging directory will be used.
-
-        Vertex AI sets the following environment variables when it runs your training code:
-
-        - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
-          i.e. <base_output_dir>/model/
-        - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
-          i.e. <base_output_dir>/checkpoints/
-        - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
-          logs, i.e. <base_output_dir>/logs/
-    :param service_account: Specifies the service account for workload run-as account.
-        Users submitting jobs must have act-as permission on this run-as account.
-    :param network: The full name of the Compute Engine network to which the job
-        should be peered.
-        Private services access must already be configured for the network.
-        If left unspecified, the job is not peered with any network.
-    :param bigquery_destination: Provide this field if `dataset` is a BigQuery dataset.
-        The BigQuery project location where the training data is to
-        be written to. In the given project a new dataset is created
-        with name
-        ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
-        where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
-        training input data will be written into that dataset. In
-        the dataset three tables will be created, ``training``,
-        ``validation`` and ``test``.
-
-        - AIP_DATA_FORMAT = "bigquery".
-        - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
-        - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
-        - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
-    :param args: Command line arguments to be passed to the Python script.
-    :param environment_variables: Environment variables to be passed to the container.
-        Should be a dictionary where keys are environment variable names
-        and values are environment variable values for those names.
-        At most 10 environment variables can be specified.
-        The Name of the environment variable must be unique.
-    :param replica_count: The number of worker replicas. If replica count = 1 then one chief
-        replica will be provisioned. If replica_count > 1 the remainder will be
-        provisioned as a worker replica pool.
-    :param machine_type: The type of machine to use for training.
-    :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
-        NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
-        NVIDIA_TESLA_T4
-    :param accelerator_count: The number of accelerators to attach to a worker replica.
-    :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
-        Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
-        `pd-standard` (Persistent Disk Hard Disk Drive).
-    :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
-        boot disk size must be within the range of [100, 64000].
-    :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
-        the Model. This is ignored if Dataset is not provided.
-    :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
-        validate the Model. This is ignored if Dataset is not provided.
-    :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
-        the Model. This is ignored if Dataset is not provided.
-    :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
-        this filter are used to train the Model. A filter with same syntax
-        as the one used in DatasetService.ListDataItems may be used. If a
-        single DataItem is matched by more than one of the FilterSplit filters,
-        then it is assigned to the first set that applies to it in the training,
-        validation, test order. This is ignored if Dataset is not provided.
-    :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
-        this filter are used to validate the Model. A filter with same syntax
-        as the one used in DatasetService.ListDataItems may be used. If a
-        single DataItem is matched by more than one of the FilterSplit filters,
-        then it is assigned to the first set that applies to it in the training,
-        validation, test order. This is ignored if Dataset is not provided.
-    :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
-        this filter are used to test the Model. A filter with same syntax
-        as the one used in DatasetService.ListDataItems may be used. If a
-        single DataItem is matched by more than one of the FilterSplit filters,
-        then it is assigned to the first set that applies to it in the training,
-        validation, test order. This is ignored if Dataset is not provided.
-    :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
-        columns. The value of the key (either the label's value or
-        value in the column) must be one of {``training``,
-        ``validation``, ``test``}, and it defines to which set the
-        given piece of data is assigned. If for a piece of data the
-        key is not present or has an invalid value, that piece is
-        ignored by the pipeline.
-
-        Supported only for tabular and time series Datasets.
-    :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
-        columns. The values of the key (the values in
-        the column) must be in RFC 3339 `date-time` format, where
-        `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
-        piece of data the key is not present or has an invalid value,
-        that piece is ignored by the pipeline.
-
-        Supported only for tabular and time series Datasets.
-    :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
-        logs. Format:
-        ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
-        For more information on configuring your service account please visit:
-        https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
-    :param sync: Whether to execute the AI Platform job synchronously. If False, this method
-        will be executed in concurrent Future and any downstream object will
-        be immediately returned and synced when the Future has completed.
-    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
-        if any. For this to work, the service account making the request must have
-        domain-wide delegation enabled.
-    :param impersonation_chain: Optional service account to impersonate using short-term
-        credentials, or chained list of accounts required to get the access_token
-        of the last account in the list, which will be impersonated in the request.
-        If set as a string, the account must grant the originating account
-        the Service Account Token Creator IAM role.
-        If set as a sequence, the identities from the list must grant
-        Service Account Token Creator IAM role to the directly preceding identity, with first
-        account from the list granting this role to the originating account (templated).
-    """
-
-    template_fields = [
-        'region',
-        'command',
-        'impersonation_chain',
-    ]
-    operator_extra_links = (VertexAIModelLink(),)
-
-    def __init__(
-        self,
-        *,
-        command: Optional[Sequence[str]] = None,
-        **kwargs,
-    ) -> None:
-        """
-        Store ``command`` and forward every other argument to the base operator.
-
-        :param command: Command for the training container; an empty list is used when omitted.
-        """
-        super().__init__(**kwargs)
-        # ``None`` sentinel instead of a ``[]`` default, so instances never share one list object.
-        self.command = [] if command is None else command
-
-    def execute(self, context: 'Context'):
-        """
-        Run the custom container training job through the hook.
-
-        Pushes a ``model_conf`` XCom (``model_id``, ``region``, ``project_id``), which
-        ``VertexAIModelLink`` reads to build its URL.
-
-        :param context: Airflow task context, used for the XCom push.
-        :return: The ``Model.to_dict`` representation of the object returned by the hook.
-        """
-        super().execute(context)
-        model = self.hook.create_custom_container_training_job(
-            project_id=self.project_id,
-            region=self.region,
-            display_name=self.display_name,
-            container_uri=self.container_uri,
-            command=self.command,
-            model_serving_container_image_uri=self.model_serving_container_image_uri,
-            model_serving_container_predict_route=self.model_serving_container_predict_route,
-            model_serving_container_health_route=self.model_serving_container_health_route,
-            model_serving_container_command=self.model_serving_container_command,
-            model_serving_container_args=self.model_serving_container_args,
-            model_serving_container_environment_variables=self.model_serving_container_environment_variables,
-            model_serving_container_ports=self.model_serving_container_ports,
-            model_description=self.model_description,
-            model_instance_schema_uri=self.model_instance_schema_uri,
-            model_parameters_schema_uri=self.model_parameters_schema_uri,
-            model_prediction_schema_uri=self.model_prediction_schema_uri,
-            labels=self.labels,
-            training_encryption_spec_key_name=self.training_encryption_spec_key_name,
-            model_encryption_spec_key_name=self.model_encryption_spec_key_name,
-            staging_bucket=self.staging_bucket,
-            # RUN
-            dataset=self.dataset,
-            annotation_schema_uri=self.annotation_schema_uri,
-            model_display_name=self.model_display_name,
-            model_labels=self.model_labels,
-            base_output_dir=self.base_output_dir,
-            service_account=self.service_account,
-            network=self.network,
-            bigquery_destination=self.bigquery_destination,
-            args=self.args,
-            environment_variables=self.environment_variables,
-            replica_count=self.replica_count,
-            machine_type=self.machine_type,
-            accelerator_type=self.accelerator_type,
-            accelerator_count=self.accelerator_count,
-            boot_disk_type=self.boot_disk_type,
-            boot_disk_size_gb=self.boot_disk_size_gb,
-            training_fraction_split=self.training_fraction_split,
-            validation_fraction_split=self.validation_fraction_split,
-            test_fraction_split=self.test_fraction_split,
-            training_filter_split=self.training_filter_split,
-            validation_filter_split=self.validation_filter_split,
-            test_filter_split=self.test_filter_split,
-            predefined_split_column_name=self.predefined_split_column_name,
-            timestamp_split_column_name=self.timestamp_split_column_name,
-            tensorboard=self.tensorboard,
-            # NOTE(review): ``sync`` is fixed to True here, so the operator's ``sync`` argument
-            # is stored but not forwarded. The result is converted right below, which
-            # presumably needs a finished job -- confirm before forwarding ``self.sync``.
-            sync=True,
-        )
-
-        result = Model.to_dict(model)
-        model_id = self.hook.extract_model_id(result)
-        self.xcom_push(
-            context,
-            key="model_conf",
-            value={
-                "model_id": model_id,
-                "region": self.region,
-                "project_id": self.project_id,
-            },
-        )
-        return result
-
-
-class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator):
- """Create Custom Python Package Training job
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param display_name: Required. The user-defined name of this TrainingPipeline.
- :param python_package_gcs_uri: Required: GCS location of the training python package.
- :param python_module_name: Required: The module name of the training python package.
- :param container_uri: Required: Uri of the training container image in the GCR.
- :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI
- of the Model serving container suitable for serving the model produced by the
- training script.
- :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An
- HTTP path to send prediction requests to the container, and which must be supported
- by it. If not specified a default HTTP path will be used by Vertex AI.
- :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an
- HTTP path to send health check requests to the container, and which must be supported
- by it. If not specified a standard HTTP path will be used by AI Platform.
- :param model_serving_container_command: The command with which the container is run. Not executed
- within a shell. The Docker image's ENTRYPOINT is used if this is not provided.
- Variable references $(VAR_NAME) are expanded using the container's
- environment. If a variable cannot be resolved, the reference in the
- input string will be unchanged. The $(VAR_NAME) syntax can be escaped
- with a double $$, ie: $$(VAR_NAME). Escaped references will never be
- expanded, regardless of whether the variable exists or not.
- :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if
- this is not provided. Variable references $(VAR_NAME) are expanded using the
- container's environment. If a variable cannot be resolved, the reference
- in the input string will be unchanged. The $(VAR_NAME) syntax can be
- escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
- never be expanded, regardless of whether the variable exists or not.
- :param model_serving_container_environment_variables: The environment variables that are to be
- present in the container. Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- :param model_serving_container_ports: Declaration of ports that are exposed by the container. This
- field is primarily informational, it gives Vertex AI information about the
- network connections the container uses. Listing or not a port here has
- no impact on whether the port is actually exposed, any port listening on
- the default "0.0.0.0" address inside a container will be accessible from
- the network.
- :param model_description: The description of the Model.
- :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the format of a single instance, which
- are used in
- ``PredictRequest.instances``,
- ``ExplainRequest.instances``
- and
- ``BatchPredictionJob.input_config``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform. Note: The URI given on output will be immutable
- and probably different, including the URI scheme, than the
- one given on input. The output URI will point to a location
- where the user only has a read access.
- :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the parameters of prediction and
- explanation via
- ``PredictRequest.parameters``,
- ``ExplainRequest.parameters``
- and
- ``BatchPredictionJob.model_parameters``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform, if no parameters are supported it is set to an
- empty string. Note: The URI given on output will be
- immutable and probably different, including the URI scheme,
- than the one given on input. The output URI will point to a
- location where the user only has a read access.
- :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the format of a single prediction
- produced by this Model, which are returned via
- ``PredictResponse.predictions``,
- ``ExplainResponse.explanations``,
- and
- ``BatchPredictionJob.output_config``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform. Note: The URI given on output will be immutable
- and probably different, including the URI scheme, than the
- one given on input. The output URI will point to a location
- where the user only has a read access.
- :param project_id: Project to run training in.
- :param region: Location to run training in.
- :param labels: Optional. The labels with user-defined metadata to
- organize TrainingPipelines.
- Label keys and values can be no longer than 64
- characters, can only
- contain lowercase letters, numeric characters,
- underscores and dashes. International characters
- are allowed.
- See https://goo.gl/xmQnxf for more information
- and examples of labels.
- :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
- managed encryption key used to protect the training pipeline. Has the
- form:
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
- The key needs to be in the same region as where the compute
- resource is created.
-
- If set, this TrainingPipeline will be secured by this key.
-
- Note: Model trained by this TrainingPipeline is also secured
- by this key if ``model_to_upload`` is not set separately.
- :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
- managed encryption key used to protect the model. Has the
- form:
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
- The key needs to be in the same region as where the compute
- resource is created.
-
- If set, the trained Model will be secured by this key.
- :param staging_bucket: Bucket used to stage source and training artifacts.
- :param dataset: Vertex AI to fit this training against.
- :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
- annotation schema. The schema is defined as an OpenAPI 3.0.2
- [Schema Object]
- (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
-
- Only Annotations that both match this schema and belong to
- DataItems not ignored by the split method are used in
- respectively training, validation or test role, depending on
- the role of the DataItem they are on.
-
- When used in conjunction with
- ``annotations_filter``,
- the Annotations used for training are filtered by both
- ``annotations_filter``
- and
- ``annotation_schema_uri``.
- :param model_display_name: If the script produces a managed Vertex AI Model. The display name of
- the Model. The name can be up to 128 characters long and can be consist
- of any UTF-8 characters.
-
- If not provided upon creation, the job's display_name is used.
- :param model_labels: Optional. The labels with user-defined metadata to
- organize your Models.
- Label keys and values can be no longer than 64
- characters, can only
- contain lowercase letters, numeric characters,
- underscores and dashes. International characters
- are allowed.
- See https://goo.gl/xmQnxf for more information
- and examples of labels.
- :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
- staging directory will be used.
-
- Vertex AI sets the following environment variables when it runs your training code:
-
- - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
- i.e. <base_output_dir>/model/
- - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
- i.e. <base_output_dir>/checkpoints/
- - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
- logs, i.e. <base_output_dir>/logs/
- :param service_account: Specifies the service account for workload run-as account.
- Users submitting jobs must have act-as permission on this run-as account.
- :param network: The full name of the Compute Engine network to which the job
- should be peered.
- Private services access must already be configured for the network.
- If left unspecified, the job is not peered with any network.
- :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset.
- The BigQuery project location where the training data is to
- be written to. In the given project a new dataset is created
- with name
- ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
- where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
- training input data will be written into that dataset. In
- the dataset three tables will be created, ``training``,
- ``validation`` and ``test``.
-
- - AIP_DATA_FORMAT = "bigquery".
- - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
- - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
- :param args: Command line arguments to be passed to the Python script.
- :param environment_variables: Environment variables to be passed to the container.
- Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- At most 10 environment variables can be specified.
- The Name of the environment variable must be unique.
- :param replica_count: The number of worker replicas. If replica count = 1 then one chief
- replica will be provisioned. If replica_count > 1 the remainder will be
- provisioned as a worker replica pool.
- :param machine_type: The type of machine to use for training.
- :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
- NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
- NVIDIA_TESLA_T4
- :param accelerator_count: The number of accelerators to attach to a worker replica.
- :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
- Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
- `pd-standard` (Persistent Disk Hard Disk Drive).
- :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
- boot disk size must be within the range of [100, 64000].
- :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
- the Model. This is ignored if Dataset is not provided.
- :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
- validate the Model. This is ignored if Dataset is not provided.
- :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
- the Model. This is ignored if Dataset is not provided.
- :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to train the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to validate the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to test the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
- columns. The value of the key (either the label's value or
- value in the column) must be one of {``training``,
- ``validation``, ``test``}, and it defines to which set the
- given piece of data is assigned. If for a piece of data the
- key is not present or has an invalid value, that piece is
- ignored by the pipeline.
-
- Supported only for tabular and time series Datasets.
- :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
- columns. The value of the key values of the key (the values in
- the column) must be in RFC 3339 `date-time` format, where
- `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
- piece of data the key is not present or has an invalid value,
- that piece is ignored by the pipeline.
-
- Supported only for tabular and time series Datasets.
- :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
- logs. Format:
- ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
- For more information on configuring your service account please visit:
- https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
- :param sync: Whether to execute the AI Platform job synchronously. If False, this method
- will be executed in concurrent Future and any downstream object will
- be immediately returned and synced when the Future has completed.
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
- :param delegate_to: The account to impersonate using domain-wide delegation of authority,
- if any. For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields = [
- 'region',
- 'impersonation_chain',
- ]
- operator_extra_links = (VertexAIModelLink(),)
-
- def __init__(
- self,
- *,
- python_package_gcs_uri: str,
- python_module_name: str,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.python_package_gcs_uri = python_package_gcs_uri
- self.python_module_name = python_module_name
-
- def execute(self, context: 'Context'):
- super().execute(context)
- model = self.hook.create_custom_python_package_training_job(
- project_id=self.project_id,
- region=self.region,
- display_name=self.display_name,
- python_package_gcs_uri=self.python_package_gcs_uri,
- python_module_name=self.python_module_name,
- container_uri=self.container_uri,
- model_serving_container_image_uri=self.model_serving_container_image_uri,
- model_serving_container_predict_route=self.model_serving_container_predict_route,
- model_serving_container_health_route=self.model_serving_container_health_route,
- model_serving_container_command=self.model_serving_container_command,
- model_serving_container_args=self.model_serving_container_args,
- model_serving_container_environment_variables=self.model_serving_container_environment_variables,
- model_serving_container_ports=self.model_serving_container_ports,
- model_description=self.model_description,
- model_instance_schema_uri=self.model_instance_schema_uri,
- model_parameters_schema_uri=self.model_parameters_schema_uri,
- model_prediction_schema_uri=self.model_prediction_schema_uri,
- labels=self.labels,
- training_encryption_spec_key_name=self.training_encryption_spec_key_name,
- model_encryption_spec_key_name=self.model_encryption_spec_key_name,
- staging_bucket=self.staging_bucket,
- # RUN
- dataset=self.dataset,
- annotation_schema_uri=self.annotation_schema_uri,
- model_display_name=self.model_display_name,
- model_labels=self.model_labels,
- base_output_dir=self.base_output_dir,
- service_account=self.service_account,
- network=self.network,
- bigquery_destination=self.bigquery_destination,
- args=self.args,
- environment_variables=self.environment_variables,
- replica_count=self.replica_count,
- machine_type=self.machine_type,
- accelerator_type=self.accelerator_type,
- accelerator_count=self.accelerator_count,
- boot_disk_type=self.boot_disk_type,
- boot_disk_size_gb=self.boot_disk_size_gb,
- training_fraction_split=self.training_fraction_split,
- validation_fraction_split=self.validation_fraction_split,
- test_fraction_split=self.test_fraction_split,
- training_filter_split=self.training_filter_split,
- validation_filter_split=self.validation_filter_split,
- test_filter_split=self.test_filter_split,
- predefined_split_column_name=self.predefined_split_column_name,
- timestamp_split_column_name=self.timestamp_split_column_name,
- tensorboard=self.tensorboard,
- sync=True,
- )
-
- result = Model.to_dict(model)
- model_id = self.hook.extract_model_id(result)
- self.xcom_push(
- context,
- key="model_conf",
- value={
- "model_id": model_id,
- "region": self.region,
- "project_id": self.project_id,
- },
- )
- return result
-
-
-class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
- """Create Custom Training job
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param display_name: Required. The user-defined name of this TrainingPipeline.
- :param script_path: Required. Local path to training script.
- :param container_uri: Required: Uri of the training container image in the GCR.
- :param requirements: List of python packages dependencies of script.
- :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI
- of the Model serving container suitable for serving the model produced by the
- training script.
- :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An
- HTTP path to send prediction requests to the container, and which must be supported
- by it. If not specified a default HTTP path will be used by Vertex AI.
- :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an
- HTTP path to send health check requests to the container, and which must be supported
- by it. If not specified a standard HTTP path will be used by AI Platform.
- :param model_serving_container_command: The command with which the container is run. Not executed
- within a shell. The Docker image's ENTRYPOINT is used if this is not provided.
- Variable references $(VAR_NAME) are expanded using the container's
- environment. If a variable cannot be resolved, the reference in the
- input string will be unchanged. The $(VAR_NAME) syntax can be escaped
- with a double $$, ie: $$(VAR_NAME). Escaped references will never be
- expanded, regardless of whether the variable exists or not.
- :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if
- this is not provided. Variable references $(VAR_NAME) are expanded using the
- container's environment. If a variable cannot be resolved, the reference
- in the input string will be unchanged. The $(VAR_NAME) syntax can be
- escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
- never be expanded, regardless of whether the variable exists or not.
- :param model_serving_container_environment_variables: The environment variables that are to be
- present in the container. Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- :param model_serving_container_ports: Declaration of ports that are exposed by the container. This
- field is primarily informational, it gives Vertex AI information about the
- network connections the container uses. Listing or not a port here has
- no impact on whether the port is actually exposed, any port listening on
- the default "0.0.0.0" address inside a container will be accessible from
- the network.
- :param model_description: The description of the Model.
- :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the format of a single instance, which
- are used in
- ``PredictRequest.instances``,
- ``ExplainRequest.instances``
- and
- ``BatchPredictionJob.input_config``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform. Note: The URI given on output will be immutable
- and probably different, including the URI scheme, than the
- one given on input. The output URI will point to a location
- where the user only has a read access.
- :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the parameters of prediction and
- explanation via
- ``PredictRequest.parameters``,
- ``ExplainRequest.parameters``
- and
- ``BatchPredictionJob.model_parameters``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform, if no parameters are supported it is set to an
- empty string. Note: The URI given on output will be
- immutable and probably different, including the URI scheme,
- than the one given on input. The output URI will point to a
- location where the user only has a read access.
- :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud
- Storage describing the format of a single prediction
- produced by this Model, which are returned via
- ``PredictResponse.predictions``,
- ``ExplainResponse.explanations``,
- and
- ``BatchPredictionJob.output_config``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object <https://tinyurl.com/y538mdwt#schema-object>`__.
- AutoML Models always have this field populated by AI
- Platform. Note: The URI given on output will be immutable
- and probably different, including the URI scheme, than the
- one given on input. The output URI will point to a location
- where the user only has a read access.
- :param project_id: Project to run training in.
- :param region: Location to run training in.
- :param labels: Optional. The labels with user-defined metadata to
- organize TrainingPipelines.
- Label keys and values can be no longer than 64
- characters, can only
- contain lowercase letters, numeric characters,
- underscores and dashes. International characters
- are allowed.
- See https://goo.gl/xmQnxf for more information
- and examples of labels.
- :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
- managed encryption key used to protect the training pipeline. Has the
- form:
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
- The key needs to be in the same region as where the compute
- resource is created.
-
- If set, this TrainingPipeline will be secured by this key.
-
- Note: Model trained by this TrainingPipeline is also secured
- by this key if ``model_to_upload`` is not set separately.
- :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
- managed encryption key used to protect the model. Has the
- form:
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
- The key needs to be in the same region as where the compute
- resource is created.
-
- If set, the trained Model will be secured by this key.
- :param staging_bucket: Bucket used to stage source and training artifacts.
- :param dataset: Vertex AI to fit this training against.
- :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing
- annotation schema. The schema is defined as an OpenAPI 3.0.2
- [Schema Object]
- (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object)
-
- Only Annotations that both match this schema and belong to
- DataItems not ignored by the split method are used in
- respectively training, validation or test role, depending on
- the role of the DataItem they are on.
-
- When used in conjunction with
- ``annotations_filter``,
- the Annotations used for training are filtered by both
- ``annotations_filter``
- and
- ``annotation_schema_uri``.
- :param model_display_name: If the script produces a managed Vertex AI Model. The display name of
- the Model. The name can be up to 128 characters long and can be consist
- of any UTF-8 characters.
-
- If not provided upon creation, the job's display_name is used.
- :param model_labels: Optional. The labels with user-defined metadata to
- organize your Models.
- Label keys and values can be no longer than 64
- characters, can only
- contain lowercase letters, numeric characters,
- underscores and dashes. International characters
- are allowed.
- See https://goo.gl/xmQnxf for more information
- and examples of labels.
- :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the
- staging directory will be used.
-
- Vertex AI sets the following environment variables when it runs your training code:
-
- - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts,
- i.e. <base_output_dir>/model/
- - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints,
- i.e. <base_output_dir>/checkpoints/
- - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard
- logs, i.e. <base_output_dir>/logs/
- :param service_account: Specifies the service account for workload run-as account.
- Users submitting jobs must have act-as permission on this run-as account.
- :param network: The full name of the Compute Engine network to which the job
- should be peered.
- Private services access must already be configured for the network.
- If left unspecified, the job is not peered with any network.
- :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset.
- The BigQuery project location where the training data is to
- be written to. In the given project a new dataset is created
- with name
- ``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
- where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
- training input data will be written into that dataset. In
- the dataset three tables will be created, ``training``,
- ``validation`` and ``test``.
-
- - AIP_DATA_FORMAT = "bigquery".
- - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
- - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
- :param args: Command line arguments to be passed to the Python script.
- :param environment_variables: Environment variables to be passed to the container.
- Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- At most 10 environment variables can be specified.
- The Name of the environment variable must be unique.
- :param replica_count: The number of worker replicas. If replica count = 1 then one chief
- replica will be provisioned. If replica_count > 1 the remainder will be
- provisioned as a worker replica pool.
- :param machine_type: The type of machine to use for training.
- :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
- NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
- NVIDIA_TESLA_T4
- :param accelerator_count: The number of accelerators to attach to a worker replica.
- :param boot_disk_type: Type of the boot disk, default is `pd-ssd`.
- Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
- `pd-standard` (Persistent Disk Hard Disk Drive).
- :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB.
- boot disk size must be within the range of [100, 64000].
- :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
- the Model. This is ignored if Dataset is not provided.
- :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
- validate the Model. This is ignored if Dataset is not provided.
- :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
- the Model. This is ignored if Dataset is not provided.
- :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to train the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to validate the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
- this filter are used to test the Model. A filter with same syntax
- as the one used in DatasetService.ListDataItems may be used. If a
- single DataItem is matched by more than one of the FilterSplit filters,
- then it is assigned to the first set that applies to it in the training,
- validation, test order. This is ignored if Dataset is not provided.
- :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data
- columns. The value of the key (either the label's value or
- value in the column) must be one of {``training``,
- ``validation``, ``test``}, and it defines to which set the
- given piece of data is assigned. If for a piece of data the
- key is not present or has an invalid value, that piece is
- ignored by the pipeline.
-
- Supported only for tabular and time series Datasets.
- :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data
- columns. The value of the key values of the key (the values in
- the column) must be in RFC 3339 `date-time` format, where
- `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
- piece of data the key is not present or has an invalid value,
- that piece is ignored by the pipeline.
-
- Supported only for tabular and time series Datasets.
- :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload
- logs. Format:
- ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
- For more information on configuring your service account please visit:
- https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
- :param sync: Whether to execute the AI Platform job synchronously. If False, this method
- will be executed in concurrent Future and any downstream object will
- be immediately returned and synced when the Future has completed.
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
- :param delegate_to: The account to impersonate using domain-wide delegation of authority,
- if any. For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields = [
- 'region',
- 'script_path',
- 'requirements',
- 'impersonation_chain',
- ]
- operator_extra_links = (VertexAIModelLink(),)
-
- def __init__(
- self,
- *,
- script_path: str,
- requirements: Optional[Sequence[str]] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.requirements = requirements
- self.script_path = script_path
-
- def execute(self, context: 'Context'):
- super().execute(context)
- model = self.hook.create_custom_training_job(
- project_id=self.project_id,
- region=self.region,
- display_name=self.display_name,
- script_path=self.script_path,
- container_uri=self.container_uri,
- requirements=self.requirements,
- model_serving_container_image_uri=self.model_serving_container_image_uri,
- model_serving_container_predict_route=self.model_serving_container_predict_route,
- model_serving_container_health_route=self.model_serving_container_health_route,
- model_serving_container_command=self.model_serving_container_command,
- model_serving_container_args=self.model_serving_container_args,
- model_serving_container_environment_variables=self.model_serving_container_environment_variables,
- model_serving_container_ports=self.model_serving_container_ports,
- model_description=self.model_description,
- model_instance_schema_uri=self.model_instance_schema_uri,
- model_parameters_schema_uri=self.model_parameters_schema_uri,
- model_prediction_schema_uri=self.model_prediction_schema_uri,
- labels=self.labels,
- training_encryption_spec_key_name=self.training_encryption_spec_key_name,
- model_encryption_spec_key_name=self.model_encryption_spec_key_name,
- staging_bucket=self.staging_bucket,
- # RUN
- dataset=self.dataset,
- annotation_schema_uri=self.annotation_schema_uri,
- model_display_name=self.model_display_name,
- model_labels=self.model_labels,
- base_output_dir=self.base_output_dir,
- service_account=self.service_account,
- network=self.network,
- bigquery_destination=self.bigquery_destination,
- args=self.args,
- environment_variables=self.environment_variables,
- replica_count=self.replica_count,
- machine_type=self.machine_type,
- accelerator_type=self.accelerator_type,
- accelerator_count=self.accelerator_count,
- boot_disk_type=self.boot_disk_type,
- boot_disk_size_gb=self.boot_disk_size_gb,
- training_fraction_split=self.training_fraction_split,
- validation_fraction_split=self.validation_fraction_split,
- test_fraction_split=self.test_fraction_split,
- training_filter_split=self.training_filter_split,
- validation_filter_split=self.validation_filter_split,
- test_filter_split=self.test_filter_split,
- predefined_split_column_name=self.predefined_split_column_name,
- timestamp_split_column_name=self.timestamp_split_column_name,
- tensorboard=self.tensorboard,
- sync=True,
- )
-
- result = Model.to_dict(model)
- model_id = self.hook.extract_model_id(result)
- self.xcom_push(
- context,
- key="model_conf",
- value={
- "model_id": model_id,
- "region": self.region,
- "project_id": self.project_id,
- },
- )
- return result
-
-
-class DeleteCustomTrainingJobOperator(BaseOperator):
- """Deletes a CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob.
-
- :param training_pipeline_id: Required. The name of the TrainingPipeline resource to be deleted.
- :param custom_job_id: Required. The name of the CustomJob to delete.
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
- :param delegate_to: The account to impersonate using domain-wide delegation of authority,
- if any. For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields = ("region", "project_id", "impersonation_chain")
-
- def __init__(
- self,
- *,
- training_pipeline_id: str,
- custom_job_id: str,
- region: str,
- project_id: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = "",
- gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.training_pipeline = training_pipeline_id
- self.custom_job = custom_job_id
- self.region = region
- self.project_id = project_id
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.gcp_conn_id = gcp_conn_id
- self.delegate_to = delegate_to
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: 'Context'):
- hook = CustomJobHook(
- gcp_conn_id=self.gcp_conn_id,
- delegate_to=self.delegate_to,
- impersonation_chain=self.impersonation_chain,
- )
- try:
- self.log.info("Deleting custom training pipeline: %s", self.training_pipeline)
- training_pipeline_operation = hook.delete_training_pipeline(
- training_pipeline=self.training_pipeline,
- region=self.region,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- hook.wait_for_operation(timeout=self.timeout, operation=training_pipeline_operation)
- self.log.info("Training pipeline was deleted.")
- except NotFound:
- self.log.info("The Training Pipeline ID %s does not exist.", self.training_pipeline)
- try:
- self.log.info("Deleting custom job: %s", self.custom_job)
- custom_job_operation = hook.delete_custom_job(
- custom_job=self.custom_job,
- region=self.region,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- hook.wait_for_operation(timeout=self.timeout, operation=custom_job_operation)
- self.log.info("Custom job was deleted.")
- except NotFound:
- self.log.info("The Custom Job ID %s does not exist.", self.custom_job)
-
-
-class ListCustomTrainingJobOperator(BaseOperator):
- """Lists CustomTrainingJob, CustomPythonTrainingJob, or CustomContainerTrainingJob in a Location.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param filter: Optional. The standard list filter. Supported fields:
-
- - ``display_name`` supports = and !=.
-
- - ``state`` supports = and !=.
-
- Some examples of using the filter are:
-
- - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"``
-
- - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"``
-
- - ``NOT display_name="my_pipeline"``
-
- - ``state="PIPELINE_STATE_FAILED"``
- :param page_size: Optional. The standard list page size.
- :param page_token: Optional. The standard list page token. Typically obtained via
- [ListTrainingPipelinesResponse.next_page_token][google.cloud.aiplatform.v1.ListTrainingPipelinesResponse.next_page_token]
- of the previous
- [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1.PipelineService.ListTrainingPipelines]
- call.
- :param read_mask: Optional. Mask specifying which fields to read.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
- :param delegate_to: The account to impersonate using domain-wide delegation of authority,
- if any. For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields = [
- "region",
- "project_id",
- "impersonation_chain",
- ]
- operator_extra_links = [
- VertexAITrainingPipelinesLink(),
- ]
-
- def __init__(
- self,
- *,
- region: str,
- project_id: str,
- page_size: Optional[int] = None,
- page_token: Optional[str] = None,
- filter: Optional[str] = None,
- read_mask: Optional[str] = None,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = "",
- gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.region = region
- self.project_id = project_id
- self.page_size = page_size
- self.page_token = page_token
- self.filter = filter
- self.read_mask = read_mask
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.gcp_conn_id = gcp_conn_id
- self.delegate_to = delegate_to
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: 'Context'):
- hook = CustomJobHook(
- gcp_conn_id=self.gcp_conn_id,
- delegate_to=self.delegate_to,
- impersonation_chain=self.impersonation_chain,
- )
- results = hook.list_training_pipelines(
- region=self.region,
- project_id=self.project_id,
- page_size=self.page_size,
- page_token=self.page_token,
- filter=self.filter,
- read_mask=self.read_mask,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- self.xcom_push(context, key="project_id", value=self.project_id)
- return [TrainingPipeline.to_dict(result) for result in results]
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
deleted file mode 100644
index 1def925..0000000
--- a/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
+++ /dev/null
@@ -1,644 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-"""This module contains Google Vertex AI operators."""
-
-from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
-
-from google.api_core.exceptions import NotFound
-from google.api_core.retry import Retry
-from google.cloud.aiplatform_v1.types import Dataset, ExportDataConfig, ImportDataConfig
-from google.protobuf.field_mask_pb2 import FieldMask
-
-from airflow.models import BaseOperator, BaseOperatorLink
-from airflow.models.taskinstance import TaskInstance
-from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook
-
-if TYPE_CHECKING:
- from airflow.utils.context import Context
-
-VERTEX_AI_BASE_LINK = "https://console.cloud.google.com/vertex-ai"
-VERTEX_AI_DATASET_LINK = (
- VERTEX_AI_BASE_LINK + "/locations/{region}/datasets/{dataset_id}/analyze?project={project_id}"
-)
-VERTEX_AI_DATASET_LIST_LINK = VERTEX_AI_BASE_LINK + "/datasets?project={project_id}"
-
-
-class VertexAIDatasetLink(BaseOperatorLink):
- """Helper class for constructing Vertex AI Dataset link"""
-
- name = "Dataset"
-
- def get_link(self, operator, dttm):
- ti = TaskInstance(task=operator, execution_date=dttm)
- dataset_conf = ti.xcom_pull(task_ids=operator.task_id, key="dataset_conf")
- return (
- VERTEX_AI_DATASET_LINK.format(
- region=dataset_conf["region"],
- dataset_id=dataset_conf["dataset_id"],
- project_id=dataset_conf["project_id"],
- )
- if dataset_conf
- else ""
- )
-
-
-class VertexAIDatasetListLink(BaseOperatorLink):
- """Helper class for constructing Vertex AI Datasets Link"""
-
- name = "Dataset List"
-
- def get_link(self, operator, dttm):
- ti = TaskInstance(task=operator, execution_date=dttm)
- project_id = ti.xcom_pull(task_ids=operator.task_id, key="project_id")
- return (
- VERTEX_AI_DATASET_LIST_LINK.format(
- project_id=project_id,
- )
- if project_id
- else ""
- )
-
-
-class CreateDatasetOperator(BaseOperator):
- """
- Creates a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
- :param region: Required. The Cloud Dataproc region in which to handle the request.
- :param dataset: Required. The Dataset to create. This corresponds to the ``dataset`` field on the
- ``request`` instance; if ``request`` is provided, this should not be set.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
- :param delegate_to: The account to impersonate using domain-wide delegation of authority,
- if any. For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields = ("region", "project_id", "impersonation_chain")
- operator_extra_links = (VertexAIDatasetLink(),)
-
- def __init__(
- self,
- *,
- region: str,
- project_id: str,
- dataset: Dataset,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = "",
- gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.region = region
- self.project_id = project_id
- self.dataset = dataset
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.gcp_conn_id = gcp_conn_id
- self.delegate_to = delegate_to
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: 'Context'):
- hook = DatasetHook(
- gcp_conn_id=self.gcp_conn_id,
- delegate_to=self.delegate_to,
- impersonation_chain=self.impersonation_chain,
- )
-
- self.log.info("Creating dataset")
- operation = hook.create_dataset(
- project_id=self.project_id,
- region=self.region,
- dataset=self.dataset,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- result = hook.wait_for_operation(self.timeout, operation)
-
- dataset = Dataset.to_dict(result)
- dataset_id = hook.extract_dataset_id(dataset)
- self.log.info("Dataset was created. Dataset id: %s", dataset_id)
-
- self.xcom_push(context, key="dataset_id", value=dataset_id)
- self.xcom_push(
- context,
- key="dataset_conf",
- value={
- "dataset_id": dataset_id,
- "region": self.region,
- "project_id": self.project_id,
- },
- )
- return dataset
-
-
-class GetDatasetOperator(BaseOperator):
- """
- Get a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
- :param region: Required. The Cloud Dataproc region in which to handle the request.
- :param dataset_id: Required. The ID of the Dataset to get.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
- :param delegate_to: The account to impersonate using domain-wide delegation of authority,
- if any. For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
- operator_extra_links = (VertexAIDatasetLink(),)
-
- def __init__(
- self,
- *,
- region: str,
- project_id: str,
- dataset_id: str,
- read_mask: Optional[str] = None,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = "",
- gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- **kwargs,
- ) -> Dataset:
- super().__init__(**kwargs)
- self.region = region
- self.project_id = project_id
- self.dataset_id = dataset_id
- self.read_mask = read_mask
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.gcp_conn_id = gcp_conn_id
- self.delegate_to = delegate_to
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: 'Context'):
- hook = DatasetHook(
- gcp_conn_id=self.gcp_conn_id,
- delegate_to=self.delegate_to,
- impersonation_chain=self.impersonation_chain,
- )
-
- try:
- self.log.info("Get dataset: %s", self.dataset_id)
- dataset_obj = hook.get_dataset(
- project_id=self.project_id,
- region=self.region,
- dataset=self.dataset_id,
- read_mask=self.read_mask,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- self.xcom_push(
- context,
- key="dataset_conf",
- value={
- "dataset_id": self.dataset_id,
- "project_id": self.project_id,
- "region": self.region,
- },
- )
- self.log.info("Dataset was gotten.")
- return Dataset.to_dict(dataset_obj)
- except NotFound:
- self.log.info("The Dataset ID %s does not exist.", self.dataset_id)
-
-
-class DeleteDatasetOperator(BaseOperator):
- """
- Deletes a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
- :param region: Required. The Cloud Dataproc region in which to handle the request.
- :param dataset_id: Required. The ID of the Dataset to delete.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
- :param delegate_to: The account to impersonate using domain-wide delegation of authority,
- if any. For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
-
- def __init__(
- self,
- *,
- region: str,
- project_id: str,
- dataset_id: str,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = "",
- gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.region = region
- self.project_id = project_id
- self.dataset_id = dataset_id
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.gcp_conn_id = gcp_conn_id
- self.delegate_to = delegate_to
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: 'Context'):
- hook = DatasetHook(
- gcp_conn_id=self.gcp_conn_id,
- delegate_to=self.delegate_to,
- impersonation_chain=self.impersonation_chain,
- )
-
- try:
- self.log.info("Deleting dataset: %s", self.dataset_id)
- operation = hook.delete_dataset(
- project_id=self.project_id,
- region=self.region,
- dataset=self.dataset_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- hook.wait_for_operation(timeout=self.timeout, operation=operation)
- self.log.info("Dataset was deleted.")
- except NotFound:
- self.log.info("The Dataset ID %s does not exist.", self.dataset_id)
-
-
-class ExportDataOperator(BaseOperator):
- """
- Exports data from a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
- :param region: Required. The Cloud Dataproc region in which to handle the request.
- :param dataset_id: Required. The ID of the Dataset to delete.
- :param export_config: Required. The desired output location.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
- :param delegate_to: The account to impersonate using domain-wide delegation of authority,
- if any. For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
-
- def __init__(
- self,
- *,
- region: str,
- project_id: str,
- dataset_id: str,
- export_config: ExportDataConfig,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = "",
- gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.region = region
- self.project_id = project_id
- self.dataset_id = dataset_id
- self.export_config = export_config
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.gcp_conn_id = gcp_conn_id
- self.delegate_to = delegate_to
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: 'Context'):
- hook = DatasetHook(
- gcp_conn_id=self.gcp_conn_id,
- delegate_to=self.delegate_to,
- impersonation_chain=self.impersonation_chain,
- )
-
- self.log.info("Exporting data: %s", self.dataset_id)
- operation = hook.export_data(
- project_id=self.project_id,
- region=self.region,
- dataset=self.dataset_id,
- export_config=self.export_config,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- hook.wait_for_operation(timeout=self.timeout, operation=operation)
- self.log.info("Export was done successfully")
-
-
-class ImportDataOperator(BaseOperator):
- """
- Imports data into a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
- :param region: Required. The Cloud Dataproc region in which to handle the request.
- :param dataset_id: Required. The ID of the Dataset to delete.
- :param import_configs: Required. The desired input locations. The contents of all input locations will be
- imported in one batch.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
- :param delegate_to: The account to impersonate using domain-wide delegation of authority,
- if any. For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
-
- def __init__(
- self,
- *,
- region: str,
- project_id: str,
- dataset_id: str,
- import_configs: Sequence[ImportDataConfig],
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = "",
- gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.region = region
- self.project_id = project_id
- self.dataset_id = dataset_id
- self.import_configs = import_configs
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.gcp_conn_id = gcp_conn_id
- self.delegate_to = delegate_to
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: 'Context'):
- hook = DatasetHook(
- gcp_conn_id=self.gcp_conn_id,
- delegate_to=self.delegate_to,
- impersonation_chain=self.impersonation_chain,
- )
-
- self.log.info("Importing data: %s", self.dataset_id)
- operation = hook.import_data(
- project_id=self.project_id,
- region=self.region,
- dataset=self.dataset_id,
- import_configs=self.import_configs,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- hook.wait_for_operation(timeout=self.timeout, operation=operation)
- self.log.info("Import was done successfully")
-
-
-class ListDatasetsOperator(BaseOperator):
- """
- Lists Datasets in a Location.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param filter: The standard list filter.
- :param page_size: The standard list page size.
- :param page_token: The standard list page token.
- :param read_mask: Mask specifying which fields to read.
- :param order_by: A comma-separated list of fields to order by, sorted in ascending order. Use "desc"
- after a field name for descending.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
- :param delegate_to: The account to impersonate using domain-wide delegation of authority,
- if any. For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields = ("region", "project_id", "impersonation_chain")
- operator_extra_links = (VertexAIDatasetListLink(),)
-
- def __init__(
- self,
- *,
- region: str,
- project_id: str,
- filter: Optional[str] = None,
- page_size: Optional[int] = None,
- page_token: Optional[str] = None,
- read_mask: Optional[str] = None,
- order_by: Optional[str] = None,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = "",
- gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.region = region
- self.project_id = project_id
- self.filter = filter
- self.page_size = page_size
- self.page_token = page_token
- self.read_mask = read_mask
- self.order_by = order_by
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.gcp_conn_id = gcp_conn_id
- self.delegate_to = delegate_to
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: 'Context'):
- hook = DatasetHook(
- gcp_conn_id=self.gcp_conn_id,
- delegate_to=self.delegate_to,
- impersonation_chain=self.impersonation_chain,
- )
- results = hook.list_datasets(
- project_id=self.project_id,
- region=self.region,
- filter=self.filter,
- page_size=self.page_size,
- page_token=self.page_token,
- read_mask=self.read_mask,
- order_by=self.order_by,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- self.xcom_push(
- context,
- key="project_id",
- value=self.project_id,
- )
- return [Dataset.to_dict(result) for result in results]
-
-
-class UpdateDatasetOperator(BaseOperator):
- """
- Updates a Dataset.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param region: Required. The ID of the Google Cloud region that the service belongs to.
- :param dataset_id: Required. The ID of the Dataset to update.
- :param dataset: Required. The Dataset which replaces the resource on the server.
- :param update_mask: Required. The update mask applies to the resource.
- :param retry: Designation of what errors, if any, should be retried.
- :param timeout: The timeout for this request.
- :param metadata: Strings which should be sent along with the request as metadata.
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
- :param delegate_to: The account to impersonate using domain-wide delegation of authority,
- if any. For this to work, the service account making the request must have
- domain-wide delegation enabled.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields = ("region", "dataset_id", "project_id", "impersonation_chain")
-
- def __init__(
- self,
- *,
- project_id: str,
- region: str,
- dataset_id: str,
- dataset: Dataset,
- update_mask: FieldMask,
- retry: Optional[Retry] = None,
- timeout: Optional[float] = None,
- metadata: Optional[Sequence[Tuple[str, str]]] = "",
- gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.project_id = project_id
- self.region = region
- self.dataset_id = dataset_id
- self.dataset = dataset
- self.update_mask = update_mask
- self.retry = retry
- self.timeout = timeout
- self.metadata = metadata
- self.gcp_conn_id = gcp_conn_id
- self.delegate_to = delegate_to
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: 'Context'):
- hook = DatasetHook(
- gcp_conn_id=self.gcp_conn_id,
- delegate_to=self.delegate_to,
- impersonation_chain=self.impersonation_chain,
- )
- self.log.info("Updating dataset: %s", self.dataset_id)
- result = hook.update_dataset(
- project_id=self.project_id,
- region=self.region,
- dataset_id=self.dataset_id,
- dataset=self.dataset,
- update_mask=self.update_mask,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- self.log.info("Dataset was updated")
- return Dataset.to_dict(result)
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 8dd83a2..5e9b00d 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -337,11 +337,6 @@ integrations:
how-to-guide:
- /docs/apache-airflow-providers-google/operators/leveldb/leveldb.rst
tags: [google]
- - integration-name: Google Vertex AI
- external-doc-url: https://cloud.google.com/vertex-ai
- how-to-guide:
- - /docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
- tags: [gcp]
operators:
- integration-name: Google Ads
@@ -469,10 +464,6 @@ operators:
- integration-name: Google LevelDB
python-modules:
- airflow.providers.google.leveldb.operators.leveldb
- - integration-name: Google Vertex AI
- python-modules:
- - airflow.providers.google.cloud.operators.vertex_ai.dataset
- - airflow.providers.google.cloud.operators.vertex_ai.custom_job
sensors:
- integration-name: Google BigQuery
@@ -667,10 +658,6 @@ hooks:
- integration-name: Google LevelDB
python-modules:
- airflow.providers.google.leveldb.hooks.leveldb
- - integration-name: Google Vertex AI
- python-modules:
- - airflow.providers.google.cloud.hooks.vertex_ai.dataset
- - airflow.providers.google.cloud.hooks.vertex_ai.custom_job
transfers:
- source-integration-name: Presto
@@ -818,10 +805,6 @@ extra-links:
- airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink
- airflow.providers.google.cloud.operators.dataproc.DataprocJobLink
- airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink
- - airflow.providers.google.cloud.operators.vertex_ai.custom_job.VertexAIModelLink
- - airflow.providers.google.cloud.operators.vertex_ai.custom_job.VertexAITrainingPipelinesLink
- - airflow.providers.google.cloud.operators.vertex_ai.dataset.VertexAIDatasetLink
- - airflow.providers.google.cloud.operators.vertex_ai.dataset.VertexAIDatasetListLink
additional-extras:
apache.beam: apache-beam[gcp]
diff --git a/docs/apache-airflow-providers-google/index.rst b/docs/apache-airflow-providers-google/index.rst
index 4c6dc4e..d5a2f61 100644
--- a/docs/apache-airflow-providers-google/index.rst
+++ b/docs/apache-airflow-providers-google/index.rst
@@ -96,7 +96,6 @@ PIP package Version required
``google-api-python-client`` ``>=1.6.0,<2.0.0``
``google-auth-httplib2`` ``>=0.0.1``
``google-auth`` ``>=1.0.0,<3.0.0``
-``google-cloud-aiplatform`` ``>=1.7.1,<2.0.0``
``google-cloud-automl`` ``>=2.1.0,<3.0.0``
``google-cloud-bigquery-datatransfer`` ``>=3.0.0,<4.0.0``
``google-cloud-bigtable`` ``>=1.0.0,<2.0.0``
diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
deleted file mode 100644
index 92c22af..0000000
--- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
+++ /dev/null
@@ -1,173 +0,0 @@
- .. Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- .. http://www.apache.org/licenses/LICENSE-2.0
-
- .. Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
-
-Google Cloud VertexAI Operators
-=======================================
-
-The `Google Cloud VertexAI <https://cloud.google.com/vertex-ai/docs>`__
-brings AutoML and AI Platform together into a unified API, client library, and user
-interface. AutoML lets you train models on image, tabular, text, and video datasets
-without writing code, while training in AI Platform lets you run custom training code.
-With Vertex AI, both AutoML training and custom training are available options.
-Whichever option you choose for training, you can save models, deploy models, and
-request predictions with Vertex AI.
-
-Creating Datasets
-^^^^^^^^^^^^^^^^^
-
-To create a Google VertexAI dataset you can use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`.
-The operator returns the dataset ID in :ref:`XCom <concepts:xcom>` under the ``dataset_id`` key.
-
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_create_dataset_operator]
- :end-before: [END how_to_cloud_vertex_ai_create_dataset_operator]
-
-After creating a dataset you can use it to import some data using
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator`.
-
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_import_data_operator]
- :end-before: [END how_to_cloud_vertex_ai_import_data_operator]
-
-To export a dataset you can use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ExportDataOperator`.
-
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_export_data_operator]
- :end-before: [END how_to_cloud_vertex_ai_export_data_operator]
-
-To delete a dataset you can use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.DeleteDatasetOperator`.
-
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_delete_dataset_operator]
- :end-before: [END how_to_cloud_vertex_ai_delete_dataset_operator]
-
-To get a dataset you can use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.GetDatasetOperator`.
-
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_get_dataset_operator]
- :end-before: [END how_to_cloud_vertex_ai_get_dataset_operator]
-
-To get a dataset list you can use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ListDatasetsOperator`.
-
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_list_dataset_operator]
- :end-before: [END how_to_cloud_vertex_ai_list_dataset_operator]
-
-To update a dataset you can use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.UpdateDatasetOperator`.
-
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_update_dataset_operator]
- :end-before: [END how_to_cloud_vertex_ai_update_dataset_operator]
-
-Creating Training Jobs
-^^^^^^^^^^^^^^^^^^^^^^
-
-To create Google Vertex AI training jobs you can use one of three operators:
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomContainerTrainingJobOperator`,
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator`,
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomTrainingJobOperator`.
-Each of them waits for the operation to complete. The result of each operator is the model
-that was trained by running that operator.
-
-Preparation step
-
-For each operator you must first prepare and create a dataset, then pass its ID to the operator's ``dataset_id`` parameter.
-
-How to run Container Training Job
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomContainerTrainingJobOperator`
-
-Before running this job you should create a Docker image with the training script inside. See
-https://cloud.google.com/vertex-ai/docs/training/create-custom-container for how to create the image.
-After that, put the link to the image in the ``container_uri`` parameter. You can also specify the command to
-execute in the container created from this image using the ``command`` parameter.
-
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator]
- :end-before: [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator]
-
-How to run Python Package Training Job
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator`
-
-Before running this job you should create a Python package with the training script inside. See
-https://cloud.google.com/vertex-ai/docs/training/create-python-pre-built-container for how to create it.
-Next, put the link to the package in the ``python_package_gcs_uri`` parameter; the ``python_module_name``
-parameter should contain the name of the script that will run your training task.
-
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]
- :end-before: [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]
-
-How to run Training Job
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomTrainingJobOperator`.
-
-For this job you should put the path to your local training script in the ``script_path`` parameter.
-
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_create_custom_training_job_operator]
- :end-before: [END how_to_cloud_vertex_ai_create_custom_training_job_operator]
-
-You can get a list of Training Jobs using
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.ListCustomTrainingJobOperator`.
-
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_list_custom_training_job_operator]
- :end-before: [END how_to_cloud_vertex_ai_list_custom_training_job_operator]
-
-If you wish to delete a Custom Training Job you can use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.DeleteCustomTrainingJobOperator`.
-
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_vertex_ai.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_delete_custom_training_job_operator]
- :end-before: [END how_to_cloud_vertex_ai_delete_custom_training_job_operator]
-
-Reference
-^^^^^^^^^
-
-For further information, look at:
-
-* `Client Library Documentation <https://googleapis.dev/python/aiplatform/latest/index.html>`__
-* `Product Documentation <https://cloud.google.com/ai-platform/docs>`__
diff --git a/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py b/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py
index 11d9eaf..ec614a5 100755
--- a/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py
+++ b/scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py
@@ -151,10 +151,7 @@ def parse_module_data(provider_data, resource_type, yaml_file_path):
package_dir = ROOT_DIR + "/" + os.path.dirname(yaml_file_path)
provider_package = os.path.dirname(yaml_file_path).replace(os.sep, ".")
py_files = chain(
- glob(f"{package_dir}/**/{resource_type}/*.py"),
- glob(f"{package_dir}/{resource_type}/*.py"),
- glob(f"{package_dir}/**/{resource_type}/**/*.py"),
- glob(f"{package_dir}/{resource_type}/**/*.py"),
+ glob(f"{package_dir}/**/{resource_type}/*.py"), glob(f"{package_dir}/{resource_type}/*.py")
)
expected_modules = {_filepath_to_module(f) for f in py_files if not f.endswith("/__init__.py")}
resource_data = provider_data.get(resource_type, [])
diff --git a/setup.py b/setup.py
index a884b33..ae8b7e7 100644
--- a/setup.py
+++ b/setup.py
@@ -307,7 +307,6 @@ google = [
# https://github.com/googleapis/google-cloud-python/issues/10566
'google-auth>=1.0.0,<3.0.0',
'google-auth-httplib2>=0.0.1',
- 'google-cloud-aiplatform>=1.7.1,<2.0.0',
'google-cloud-automl>=2.1.0,<3.0.0',
'google-cloud-bigquery-datatransfer>=3.0.0,<4.0.0',
'google-cloud-bigtable>=1.0.0,<2.0.0',
diff --git a/tests/providers/google/cloud/hooks/vertex_ai/__init__.py b/tests/providers/google/cloud/hooks/vertex_ai/__init__.py
deleted file mode 100644
index 13a8339..0000000
--- a/tests/providers/google/cloud/hooks/vertex_ai/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py b/tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py
deleted file mode 100644
index 2b12fbf..0000000
--- a/tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py
+++ /dev/null
@@ -1,457 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-from unittest import TestCase, mock
-
-from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook
-from tests.providers.google.cloud.utils.base_gcp_mock import (
- mock_base_gcp_hook_default_project_id,
- mock_base_gcp_hook_no_default_project_id,
-)
-
-TEST_GCP_CONN_ID: str = "test-gcp-conn-id"
-TEST_REGION: str = "test-region"
-TEST_PROJECT_ID: str = "test-project-id"
-TEST_PIPELINE_JOB: dict = {}
-TEST_PIPELINE_JOB_ID: str = "test-pipeline-job-id"
-TEST_TRAINING_PIPELINE: dict = {}
-TEST_TRAINING_PIPELINE_NAME: str = "test-training-pipeline"
-
-BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
-CUSTOM_JOB_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.custom_job.{}"
-
-
-class TestCustomJobWithDefaultProjectIdHook(TestCase):
- def setUp(self):
- with mock.patch(
- BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id
- ):
- self.hook = CustomJobHook(gcp_conn_id=TEST_GCP_CONN_ID)
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_cancel_pipeline_job(self, mock_client) -> None:
- self.hook.cancel_pipeline_job(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- pipeline_job=TEST_PIPELINE_JOB_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.cancel_pipeline_job.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.pipeline_job_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.pipeline_job_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID
- )
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_cancel_training_pipeline(self, mock_client) -> None:
- self.hook.cancel_training_pipeline(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- training_pipeline=TEST_TRAINING_PIPELINE_NAME,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.cancel_training_pipeline.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.training_pipeline_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.training_pipeline_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME
- )
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_create_pipeline_job(self, mock_client) -> None:
- self.hook.create_pipeline_job(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- pipeline_job=TEST_PIPELINE_JOB,
- pipeline_job_id=TEST_PIPELINE_JOB_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.create_pipeline_job.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.common_location_path.return_value,
- pipeline_job=TEST_PIPELINE_JOB,
- pipeline_job_id=TEST_PIPELINE_JOB_ID,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_create_training_pipeline(self, mock_client) -> None:
- self.hook.create_training_pipeline(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- training_pipeline=TEST_TRAINING_PIPELINE,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.create_training_pipeline.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.common_location_path.return_value,
- training_pipeline=TEST_TRAINING_PIPELINE,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_delete_pipeline_job(self, mock_client) -> None:
- self.hook.delete_pipeline_job(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- pipeline_job=TEST_PIPELINE_JOB_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.delete_pipeline_job.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.pipeline_job_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.pipeline_job_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID
- )
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_delete_training_pipeline(self, mock_client) -> None:
- self.hook.delete_training_pipeline(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- training_pipeline=TEST_TRAINING_PIPELINE_NAME,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.delete_training_pipeline.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.training_pipeline_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.training_pipeline_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME
- )
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_get_pipeline_job(self, mock_client) -> None:
- self.hook.get_pipeline_job(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- pipeline_job=TEST_PIPELINE_JOB_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.get_pipeline_job.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.pipeline_job_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.pipeline_job_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID
- )
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_get_training_pipeline(self, mock_client) -> None:
- self.hook.get_training_pipeline(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- training_pipeline=TEST_TRAINING_PIPELINE_NAME,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.get_training_pipeline.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.training_pipeline_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.training_pipeline_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME
- )
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_list_pipeline_jobs(self, mock_client) -> None:
- self.hook.list_pipeline_jobs(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.list_pipeline_jobs.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.common_location_path.return_value,
- page_size=None,
- page_token=None,
- filter=None,
- order_by=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_list_training_pipelines(self, mock_client) -> None:
- self.hook.list_training_pipelines(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.list_training_pipelines.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.common_location_path.return_value,
- page_size=None,
- page_token=None,
- filter=None,
- read_mask=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
-
-
-class TestCustomJobWithoutDefaultProjectIdHook(TestCase):
- def setUp(self):
- with mock.patch(
- BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_no_default_project_id
- ):
- self.hook = CustomJobHook(gcp_conn_id=TEST_GCP_CONN_ID)
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_cancel_pipeline_job(self, mock_client) -> None:
- self.hook.cancel_pipeline_job(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- pipeline_job=TEST_PIPELINE_JOB_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.cancel_pipeline_job.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.pipeline_job_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.pipeline_job_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID
- )
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_cancel_training_pipeline(self, mock_client) -> None:
- self.hook.cancel_training_pipeline(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- training_pipeline=TEST_TRAINING_PIPELINE_NAME,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.cancel_training_pipeline.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.training_pipeline_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.training_pipeline_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME
- )
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_create_pipeline_job(self, mock_client) -> None:
- self.hook.create_pipeline_job(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- pipeline_job=TEST_PIPELINE_JOB,
- pipeline_job_id=TEST_PIPELINE_JOB_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.create_pipeline_job.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.common_location_path.return_value,
- pipeline_job=TEST_PIPELINE_JOB,
- pipeline_job_id=TEST_PIPELINE_JOB_ID,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_create_training_pipeline(self, mock_client) -> None:
- self.hook.create_training_pipeline(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- training_pipeline=TEST_TRAINING_PIPELINE,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.create_training_pipeline.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.common_location_path.return_value,
- training_pipeline=TEST_TRAINING_PIPELINE,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_delete_pipeline_job(self, mock_client) -> None:
- self.hook.delete_pipeline_job(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- pipeline_job=TEST_PIPELINE_JOB_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.delete_pipeline_job.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.pipeline_job_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.pipeline_job_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID
- )
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_delete_training_pipeline(self, mock_client) -> None:
- self.hook.delete_training_pipeline(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- training_pipeline=TEST_TRAINING_PIPELINE_NAME,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.delete_training_pipeline.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.training_pipeline_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.training_pipeline_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME
- )
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_get_pipeline_job(self, mock_client) -> None:
- self.hook.get_pipeline_job(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- pipeline_job=TEST_PIPELINE_JOB_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.get_pipeline_job.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.pipeline_job_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.pipeline_job_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_PIPELINE_JOB_ID
- )
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_get_training_pipeline(self, mock_client) -> None:
- self.hook.get_training_pipeline(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- training_pipeline=TEST_TRAINING_PIPELINE_NAME,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.get_training_pipeline.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.training_pipeline_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.training_pipeline_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_TRAINING_PIPELINE_NAME
- )
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_list_pipeline_jobs(self, mock_client) -> None:
- self.hook.list_pipeline_jobs(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.list_pipeline_jobs.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.common_location_path.return_value,
- page_size=None,
- page_token=None,
- filter=None,
- order_by=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
-
- @mock.patch(CUSTOM_JOB_STRING.format("CustomJobHook.get_pipeline_service_client"))
- def test_list_training_pipelines(self, mock_client) -> None:
- self.hook.list_training_pipelines(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.list_training_pipelines.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.common_location_path.return_value,
- page_size=None,
- page_token=None,
- filter=None,
- read_mask=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py b/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py
deleted file mode 100644
index 5a29087..0000000
--- a/tests/providers/google/cloud/hooks/vertex_ai/test_dataset.py
+++ /dev/null
@@ -1,504 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-from unittest import TestCase, mock
-
-from airflow.providers.google.cloud.hooks.vertex_ai.dataset import DatasetHook
-from tests.providers.google.cloud.utils.base_gcp_mock import (
- mock_base_gcp_hook_default_project_id,
- mock_base_gcp_hook_no_default_project_id,
-)
-
-TEST_GCP_CONN_ID: str = "test-gcp-conn-id"
-TEST_REGION: str = "test-region"
-TEST_PROJECT_ID: str = "test-project-id"
-TEST_PIPELINE_JOB: dict = {}
-TEST_PIPELINE_JOB_ID: str = "test-pipeline-job-id"
-TEST_TRAINING_PIPELINE: dict = {}
-TEST_TRAINING_PIPELINE_NAME: str = "test-training-pipeline"
-TEST_DATASET: dict = {}
-TEST_DATASET_ID: str = "test-dataset-id"
-TEST_EXPORT_CONFIG: dict = {}
-TEST_ANNOTATION_SPEC: str = "test-annotation-spec"
-TEST_IMPORT_CONFIGS: dict = {}
-TEST_DATA_ITEM: str = "test-data-item"
-TEST_UPDATE_MASK: dict = {}
-
-BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
-DATASET_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.dataset.{}"
-
-
-class TestVertexAIWithDefaultProjectIdHook(TestCase):
- def setUp(self):
- with mock.patch(
- BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id
- ):
- self.hook = DatasetHook(gcp_conn_id=TEST_GCP_CONN_ID)
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_create_dataset(self, mock_client) -> None:
- self.hook.create_dataset(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.create_dataset.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.common_location_path.return_value,
- dataset=TEST_DATASET,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_delete_dataset(self, mock_client) -> None:
- self.hook.delete_dataset(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.delete_dataset.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.dataset_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.dataset_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_export_data(self, mock_client) -> None:
- self.hook.export_data(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- export_config=TEST_EXPORT_CONFIG,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.export_data.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.dataset_path.return_value,
- export_config=TEST_EXPORT_CONFIG,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.dataset_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_get_annotation_spec(self, mock_client) -> None:
- self.hook.get_annotation_spec(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- annotation_spec=TEST_ANNOTATION_SPEC,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.get_annotation_spec.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.annotation_spec_path.return_value,
- read_mask=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.annotation_spec_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID, TEST_ANNOTATION_SPEC
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_get_dataset(self, mock_client) -> None:
- self.hook.get_dataset(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.get_dataset.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.dataset_path.return_value,
- read_mask=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.dataset_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_import_data(self, mock_client) -> None:
- self.hook.import_data(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- import_configs=TEST_IMPORT_CONFIGS,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.import_data.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.dataset_path.return_value,
- import_configs=TEST_IMPORT_CONFIGS,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.dataset_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_list_annotations(self, mock_client) -> None:
- self.hook.list_annotations(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- data_item=TEST_DATA_ITEM,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.list_annotations.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.data_item_path.return_value,
- filter=None,
- page_size=None,
- page_token=None,
- read_mask=None,
- order_by=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.data_item_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID, TEST_DATA_ITEM
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_list_data_items(self, mock_client) -> None:
- self.hook.list_data_items(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.list_data_items.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.dataset_path.return_value,
- filter=None,
- page_size=None,
- page_token=None,
- read_mask=None,
- order_by=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.dataset_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_list_datasets(self, mock_client) -> None:
- self.hook.list_datasets(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.list_datasets.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.common_location_path.return_value,
- filter=None,
- page_size=None,
- page_token=None,
- read_mask=None,
- order_by=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_update_dataset(self, mock_client) -> None:
- self.hook.update_dataset(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset_id=TEST_DATASET_ID,
- dataset=TEST_DATASET,
- update_mask=TEST_UPDATE_MASK,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.update_dataset.assert_called_once_with(
- request=dict(
- dataset=TEST_DATASET,
- update_mask=TEST_UPDATE_MASK,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.dataset_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID
- )
-
-
-class TestVertexAIWithoutDefaultProjectIdHook(TestCase):
- def setUp(self):
- with mock.patch(
- BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_no_default_project_id
- ):
- self.hook = DatasetHook(gcp_conn_id=TEST_GCP_CONN_ID)
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_create_dataset(self, mock_client) -> None:
- self.hook.create_dataset(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.create_dataset.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.common_location_path.return_value,
- dataset=TEST_DATASET,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_delete_dataset(self, mock_client) -> None:
- self.hook.delete_dataset(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.delete_dataset.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.dataset_path.return_value,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.dataset_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_export_data(self, mock_client) -> None:
- self.hook.export_data(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- export_config=TEST_EXPORT_CONFIG,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.export_data.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.dataset_path.return_value,
- export_config=TEST_EXPORT_CONFIG,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.dataset_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_get_annotation_spec(self, mock_client) -> None:
- self.hook.get_annotation_spec(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- annotation_spec=TEST_ANNOTATION_SPEC,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.get_annotation_spec.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.annotation_spec_path.return_value,
- read_mask=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.annotation_spec_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID, TEST_ANNOTATION_SPEC
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_get_dataset(self, mock_client) -> None:
- self.hook.get_dataset(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.get_dataset.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.dataset_path.return_value,
- read_mask=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.dataset_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_import_data(self, mock_client) -> None:
- self.hook.import_data(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- import_configs=TEST_IMPORT_CONFIGS,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.import_data.assert_called_once_with(
- request=dict(
- name=mock_client.return_value.dataset_path.return_value,
- import_configs=TEST_IMPORT_CONFIGS,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.dataset_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_list_annotations(self, mock_client) -> None:
- self.hook.list_annotations(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- data_item=TEST_DATA_ITEM,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.list_annotations.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.data_item_path.return_value,
- filter=None,
- page_size=None,
- page_token=None,
- read_mask=None,
- order_by=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.data_item_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID, TEST_DATA_ITEM
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_list_data_items(self, mock_client) -> None:
- self.hook.list_data_items(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset=TEST_DATASET_ID,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.list_data_items.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.dataset_path.return_value,
- filter=None,
- page_size=None,
- page_token=None,
- read_mask=None,
- order_by=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.dataset_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID
- )
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_list_datasets(self, mock_client) -> None:
- self.hook.list_datasets(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.list_datasets.assert_called_once_with(
- request=dict(
- parent=mock_client.return_value.common_location_path.return_value,
- filter=None,
- page_size=None,
- page_token=None,
- read_mask=None,
- order_by=None,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
-
- @mock.patch(DATASET_STRING.format("DatasetHook.get_dataset_service_client"))
- def test_update_dataset(self, mock_client) -> None:
- self.hook.update_dataset(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- dataset_id=TEST_DATASET_ID,
- dataset=TEST_DATASET,
- update_mask=TEST_UPDATE_MASK,
- )
- mock_client.assert_called_once_with(TEST_REGION)
- mock_client.return_value.update_dataset.assert_called_once_with(
- request=dict(
- dataset=TEST_DATASET,
- update_mask=TEST_UPDATE_MASK,
- ),
- metadata=None,
- retry=None,
- timeout=None,
- )
- mock_client.return_value.dataset_path.assert_called_once_with(
- TEST_PROJECT_ID, TEST_REGION, TEST_DATASET_ID
- )
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
deleted file mode 100644
index ec5a63d..0000000
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ /dev/null
@@ -1,613 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from unittest import mock
-
-from google.api_core.retry import Retry
-
-from airflow.providers.google.cloud.operators.vertex_ai.custom_job import (
- CreateCustomContainerTrainingJobOperator,
- CreateCustomPythonPackageTrainingJobOperator,
- CreateCustomTrainingJobOperator,
- DeleteCustomTrainingJobOperator,
- ListCustomTrainingJobOperator,
-)
-from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
- CreateDatasetOperator,
- DeleteDatasetOperator,
- ExportDataOperator,
- ImportDataOperator,
- ListDatasetsOperator,
- UpdateDatasetOperator,
-)
-
-VERTEX_AI_PATH = "airflow.providers.google.cloud.operators.vertex_ai.{}"
-TIMEOUT = 120
-RETRY = mock.MagicMock(Retry)
-METADATA = [("key", "value")]
-
-TASK_ID = "test_task_id"
-GCP_PROJECT = "test-project"
-GCP_LOCATION = "test-location"
-GCP_CONN_ID = "test-conn"
-DELEGATE_TO = "test-delegate-to"
-IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
-STAGING_BUCKET = "gs://test-vertex-ai-bucket"
-DISPLAY_NAME = "display_name_1" # Create random display name
-DISPLAY_NAME_2 = "display_nmae_2"
-ARGS = ["--tfds", "tf_flowers:3.*.*"]
-CONTAINER_URI = "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest"
-REPLICA_COUNT = 1
-MACHINE_TYPE = "n1-standard-4"
-ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED"
-ACCELERATOR_COUNT = 0
-TRAINING_FRACTION_SPLIT = 0.7
-TEST_FRACTION_SPLIT = 0.15
-VALIDATION_FRACTION_SPLIT = 0.15
-COMMAND_2 = ['echo', 'Hello World']
-
-TEST_API_ENDPOINT: str = "test-api-endpoint"
-TEST_PIPELINE_JOB: str = "test-pipeline-job"
-TEST_TRAINING_PIPELINE: str = "test-training-pipeline"
-TEST_PIPELINE_JOB_ID: str = "test-pipeline-job-id"
-
-PYTHON_PACKAGE = "/files/trainer-0.1.tar.gz"
-PYTHON_PACKAGE_CMDARGS = "test-python-cmd"
-PYTHON_PACKAGE_GCS_URI = "gs://test-vertex-ai-bucket/trainer-0.1.tar.gz"
-PYTHON_MODULE_NAME = "trainer.task"
-
-TRAINING_PIPELINE_ID = "test-training-pipeline-id"
-CUSTOM_JOB_ID = "test-custom-job-id"
-
-TEST_DATASET = {
- "display_name": "test-dataset-name",
- "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml",
- "metadata": "test-image-dataset",
-}
-TEST_DATASET_ID = "test-dataset-id"
-TEST_EXPORT_CONFIG = {
- "annotationsFilter": "test-filter",
- "gcs_destination": {"output_uri_prefix": "airflow-system-tests-data"},
-}
-TEST_IMPORT_CONFIG = [
- {
- "data_item_labels": {
- "test-labels-name": "test-labels-value",
- },
- "import_schema_uri": "test-shema-uri",
- "gcs_source": {"uris": ['test-string']},
- },
- {},
-]
-TEST_UPDATE_MASK = "test-update-mask"
-
-
-class TestVertexAICreateCustomContainerTrainingJobOperator:
- @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
- def test_execute(self, mock_hook):
- op = CreateCustomContainerTrainingJobOperator(
- task_id=TASK_ID,
- gcp_conn_id=GCP_CONN_ID,
- delegate_to=DELEGATE_TO,
- impersonation_chain=IMPERSONATION_CHAIN,
- staging_bucket=STAGING_BUCKET,
- display_name=DISPLAY_NAME,
- args=ARGS,
- container_uri=CONTAINER_URI,
- model_serving_container_image_uri=CONTAINER_URI,
- command=COMMAND_2,
- model_display_name=DISPLAY_NAME_2,
- replica_count=REPLICA_COUNT,
- machine_type=MACHINE_TYPE,
- accelerator_type=ACCELERATOR_TYPE,
- accelerator_count=ACCELERATOR_COUNT,
- training_fraction_split=TRAINING_FRACTION_SPLIT,
- validation_fraction_split=VALIDATION_FRACTION_SPLIT,
- test_fraction_split=TEST_FRACTION_SPLIT,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- )
- op.execute(context={'ti': mock.MagicMock()})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
- )
- mock_hook.return_value.create_custom_container_training_job.assert_called_once_with(
- staging_bucket=STAGING_BUCKET,
- display_name=DISPLAY_NAME,
- args=ARGS,
- container_uri=CONTAINER_URI,
- model_serving_container_image_uri=CONTAINER_URI,
- command=COMMAND_2,
- dataset=None,
- model_display_name=DISPLAY_NAME_2,
- replica_count=REPLICA_COUNT,
- machine_type=MACHINE_TYPE,
- accelerator_type=ACCELERATOR_TYPE,
- accelerator_count=ACCELERATOR_COUNT,
- training_fraction_split=TRAINING_FRACTION_SPLIT,
- validation_fraction_split=VALIDATION_FRACTION_SPLIT,
- test_fraction_split=TEST_FRACTION_SPLIT,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- model_serving_container_predict_route=None,
- model_serving_container_health_route=None,
- model_serving_container_command=None,
- model_serving_container_args=None,
- model_serving_container_environment_variables=None,
- model_serving_container_ports=None,
- model_description=None,
- model_instance_schema_uri=None,
- model_parameters_schema_uri=None,
- model_prediction_schema_uri=None,
- labels=None,
- training_encryption_spec_key_name=None,
- model_encryption_spec_key_name=None,
- # RUN
- annotation_schema_uri=None,
- model_labels=None,
- base_output_dir=None,
- service_account=None,
- network=None,
- bigquery_destination=None,
- environment_variables=None,
- boot_disk_type='pd-ssd',
- boot_disk_size_gb=100,
- training_filter_split=None,
- validation_filter_split=None,
- test_filter_split=None,
- predefined_split_column_name=None,
- timestamp_split_column_name=None,
- tensorboard=None,
- sync=True,
- )
-
-
-class TestVertexAICreateCustomPythonPackageTrainingJobOperator:
- @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
- def test_execute(self, mock_hook):
- op = CreateCustomPythonPackageTrainingJobOperator(
- task_id=TASK_ID,
- gcp_conn_id=GCP_CONN_ID,
- delegate_to=DELEGATE_TO,
- impersonation_chain=IMPERSONATION_CHAIN,
- staging_bucket=STAGING_BUCKET,
- display_name=DISPLAY_NAME,
- python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI,
- python_module_name=PYTHON_MODULE_NAME,
- container_uri=CONTAINER_URI,
- args=ARGS,
- model_serving_container_image_uri=CONTAINER_URI,
- model_display_name=DISPLAY_NAME_2,
- replica_count=REPLICA_COUNT,
- machine_type=MACHINE_TYPE,
- accelerator_type=ACCELERATOR_TYPE,
- accelerator_count=ACCELERATOR_COUNT,
- training_fraction_split=TRAINING_FRACTION_SPLIT,
- validation_fraction_split=VALIDATION_FRACTION_SPLIT,
- test_fraction_split=TEST_FRACTION_SPLIT,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- )
- op.execute(context={'ti': mock.MagicMock()})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
- )
- mock_hook.return_value.create_custom_python_package_training_job.assert_called_once_with(
- staging_bucket=STAGING_BUCKET,
- display_name=DISPLAY_NAME,
- args=ARGS,
- container_uri=CONTAINER_URI,
- model_serving_container_image_uri=CONTAINER_URI,
- python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI,
- python_module_name=PYTHON_MODULE_NAME,
- dataset=None,
- model_display_name=DISPLAY_NAME_2,
- replica_count=REPLICA_COUNT,
- machine_type=MACHINE_TYPE,
- accelerator_type=ACCELERATOR_TYPE,
- accelerator_count=ACCELERATOR_COUNT,
- training_fraction_split=TRAINING_FRACTION_SPLIT,
- validation_fraction_split=VALIDATION_FRACTION_SPLIT,
- test_fraction_split=TEST_FRACTION_SPLIT,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- model_serving_container_predict_route=None,
- model_serving_container_health_route=None,
- model_serving_container_command=None,
- model_serving_container_args=None,
- model_serving_container_environment_variables=None,
- model_serving_container_ports=None,
- model_description=None,
- model_instance_schema_uri=None,
- model_parameters_schema_uri=None,
- model_prediction_schema_uri=None,
- labels=None,
- training_encryption_spec_key_name=None,
- model_encryption_spec_key_name=None,
- # RUN
- annotation_schema_uri=None,
- model_labels=None,
- base_output_dir=None,
- service_account=None,
- network=None,
- bigquery_destination=None,
- environment_variables=None,
- boot_disk_type='pd-ssd',
- boot_disk_size_gb=100,
- training_filter_split=None,
- validation_filter_split=None,
- test_filter_split=None,
- predefined_split_column_name=None,
- timestamp_split_column_name=None,
- tensorboard=None,
- sync=True,
- )
-
-
-class TestVertexAICreateCustomTrainingJobOperator:
- @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
- def test_execute(self, mock_hook):
- op = CreateCustomTrainingJobOperator(
- task_id=TASK_ID,
- gcp_conn_id=GCP_CONN_ID,
- delegate_to=DELEGATE_TO,
- impersonation_chain=IMPERSONATION_CHAIN,
- staging_bucket=STAGING_BUCKET,
- display_name=DISPLAY_NAME,
- script_path=PYTHON_PACKAGE,
- args=PYTHON_PACKAGE_CMDARGS,
- container_uri=CONTAINER_URI,
- model_serving_container_image_uri=CONTAINER_URI,
- requirements=[],
- replica_count=1,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- )
- op.execute(context={'ti': mock.MagicMock()})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
- )
- mock_hook.return_value.create_custom_training_job.assert_called_once_with(
- staging_bucket=STAGING_BUCKET,
- display_name=DISPLAY_NAME,
- args=PYTHON_PACKAGE_CMDARGS,
- container_uri=CONTAINER_URI,
- model_serving_container_image_uri=CONTAINER_URI,
- script_path=PYTHON_PACKAGE,
- requirements=[],
- dataset=None,
- model_display_name=None,
- replica_count=REPLICA_COUNT,
- machine_type=MACHINE_TYPE,
- accelerator_type=ACCELERATOR_TYPE,
- accelerator_count=ACCELERATOR_COUNT,
- training_fraction_split=None,
- validation_fraction_split=None,
- test_fraction_split=None,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- model_serving_container_predict_route=None,
- model_serving_container_health_route=None,
- model_serving_container_command=None,
- model_serving_container_args=None,
- model_serving_container_environment_variables=None,
- model_serving_container_ports=None,
- model_description=None,
- model_instance_schema_uri=None,
- model_parameters_schema_uri=None,
- model_prediction_schema_uri=None,
- labels=None,
- training_encryption_spec_key_name=None,
- model_encryption_spec_key_name=None,
- # RUN
- annotation_schema_uri=None,
- model_labels=None,
- base_output_dir=None,
- service_account=None,
- network=None,
- bigquery_destination=None,
- environment_variables=None,
- boot_disk_type='pd-ssd',
- boot_disk_size_gb=100,
- training_filter_split=None,
- validation_filter_split=None,
- test_filter_split=None,
- predefined_split_column_name=None,
- timestamp_split_column_name=None,
- tensorboard=None,
- sync=True,
- )
-
-
-class TestVertexAIDeleteCustomTrainingJobOperator:
- @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
- def test_execute(self, mock_hook):
- op = DeleteCustomTrainingJobOperator(
- task_id=TASK_ID,
- training_pipeline_id=TRAINING_PIPELINE_ID,
- custom_job_id=CUSTOM_JOB_ID,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- gcp_conn_id=GCP_CONN_ID,
- delegate_to=DELEGATE_TO,
- impersonation_chain=IMPERSONATION_CHAIN,
- )
- op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
- )
- mock_hook.return_value.delete_training_pipeline.assert_called_once_with(
- training_pipeline=TRAINING_PIPELINE_ID,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
- mock_hook.return_value.delete_custom_job.assert_called_once_with(
- custom_job=CUSTOM_JOB_ID,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
-
-
-class TestVertexAIListCustomTrainingJobOperator:
- @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
- def test_execute(self, mock_hook):
- page_token = "page_token"
- page_size = 42
- filter = "filter"
- read_mask = "read_mask"
-
- op = ListCustomTrainingJobOperator(
- task_id=TASK_ID,
- gcp_conn_id=GCP_CONN_ID,
- delegate_to=DELEGATE_TO,
- impersonation_chain=IMPERSONATION_CHAIN,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- page_size=page_size,
- page_token=page_token,
- filter=filter,
- read_mask=read_mask,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
- op.execute(context={'ti': mock.MagicMock()})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
- )
- mock_hook.return_value.list_training_pipelines.assert_called_once_with(
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- page_size=page_size,
- page_token=page_token,
- filter=filter,
- read_mask=read_mask,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
-
-
-class TestVertexAICreateDatasetOperator:
- @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict"))
- @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook"))
- def test_execute(self, mock_hook, to_dict_mock):
- op = CreateDatasetOperator(
- task_id=TASK_ID,
- gcp_conn_id=GCP_CONN_ID,
- delegate_to=DELEGATE_TO,
- impersonation_chain=IMPERSONATION_CHAIN,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- dataset=TEST_DATASET,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
- op.execute(context={'ti': mock.MagicMock()})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
- )
- mock_hook.return_value.create_dataset.assert_called_once_with(
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- dataset=TEST_DATASET,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
-
-
-class TestVertexAIDeleteDatasetOperator:
- @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict"))
- @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook"))
- def test_execute(self, mock_hook, to_dict_mock):
- op = DeleteDatasetOperator(
- task_id=TASK_ID,
- gcp_conn_id=GCP_CONN_ID,
- delegate_to=DELEGATE_TO,
- impersonation_chain=IMPERSONATION_CHAIN,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- dataset_id=TEST_DATASET_ID,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
- op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
- )
- mock_hook.return_value.delete_dataset.assert_called_once_with(
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- dataset=TEST_DATASET_ID,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
-
-
-class TestVertexAIExportDataOperator:
- @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict"))
- @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook"))
- def test_execute(self, mock_hook, to_dict_mock):
- op = ExportDataOperator(
- task_id=TASK_ID,
- gcp_conn_id=GCP_CONN_ID,
- delegate_to=DELEGATE_TO,
- impersonation_chain=IMPERSONATION_CHAIN,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- dataset_id=TEST_DATASET_ID,
- export_config=TEST_EXPORT_CONFIG,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
- op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
- )
- mock_hook.return_value.export_data.assert_called_once_with(
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- dataset=TEST_DATASET_ID,
- export_config=TEST_EXPORT_CONFIG,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
-
-
-class TestVertexAIImportDataOperator:
- @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict"))
- @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook"))
- def test_execute(self, mock_hook, to_dict_mock):
- op = ImportDataOperator(
- task_id=TASK_ID,
- gcp_conn_id=GCP_CONN_ID,
- delegate_to=DELEGATE_TO,
- impersonation_chain=IMPERSONATION_CHAIN,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- dataset_id=TEST_DATASET_ID,
- import_configs=TEST_IMPORT_CONFIG,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
- op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
- )
- mock_hook.return_value.import_data.assert_called_once_with(
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- dataset=TEST_DATASET_ID,
- import_configs=TEST_IMPORT_CONFIG,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
-
-
-class TestVertexAIListDatasetsOperator:
- @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict"))
- @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook"))
- def test_execute(self, mock_hook, to_dict_mock):
- page_token = "page_token"
- page_size = 42
- filter = "filter"
- read_mask = "read_mask"
- order_by = "order_by"
-
- op = ListDatasetsOperator(
- task_id=TASK_ID,
- gcp_conn_id=GCP_CONN_ID,
- delegate_to=DELEGATE_TO,
- impersonation_chain=IMPERSONATION_CHAIN,
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- filter=filter,
- page_size=page_size,
- page_token=page_token,
- read_mask=read_mask,
- order_by=order_by,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
- op.execute(context={'ti': mock.MagicMock()})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
- )
- mock_hook.return_value.list_datasets.assert_called_once_with(
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- filter=filter,
- page_size=page_size,
- page_token=page_token,
- read_mask=read_mask,
- order_by=order_by,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
-
-
-class TestVertexAIUpdateDatasetOperator:
- @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict"))
- @mock.patch(VERTEX_AI_PATH.format("dataset.DatasetHook"))
- def test_execute(self, mock_hook, to_dict_mock):
- op = UpdateDatasetOperator(
- task_id=TASK_ID,
- gcp_conn_id=GCP_CONN_ID,
- delegate_to=DELEGATE_TO,
- impersonation_chain=IMPERSONATION_CHAIN,
- project_id=GCP_PROJECT,
- region=GCP_LOCATION,
- dataset_id=TEST_DATASET_ID,
- dataset=TEST_DATASET,
- update_mask=TEST_UPDATE_MASK,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
- op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, delegate_to=DELEGATE_TO, impersonation_chain=IMPERSONATION_CHAIN
- )
- mock_hook.return_value.update_dataset.assert_called_once_with(
- project_id=GCP_PROJECT,
- region=GCP_LOCATION,
- dataset_id=TEST_DATASET_ID,
- dataset=TEST_DATASET,
- update_mask=TEST_UPDATE_MASK,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai_system.py b/tests/providers/google/cloud/operators/test_vertex_ai_system.py
deleted file mode 100644
index 84b84c3..0000000
--- a/tests/providers/google/cloud/operators/test_vertex_ai_system.py
+++ /dev/null
@@ -1,41 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import pytest
-
-from tests.providers.google.cloud.utils.gcp_authenticator import GCP_VERTEX_AI_KEY
-from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
-
-
-@pytest.mark.backend("mysql", "postgres")
-@pytest.mark.credential_file(GCP_VERTEX_AI_KEY)
-class VertexAIExampleDagsTest(GoogleSystemTest):
- @provide_gcp_context(GCP_VERTEX_AI_KEY)
- def setUp(self):
- super().setUp()
-
- @provide_gcp_context(GCP_VERTEX_AI_KEY)
- def tearDown(self):
- super().tearDown()
-
- @provide_gcp_context(GCP_VERTEX_AI_KEY)
- def test_run_custom_jobs_example_dag(self):
- self.run_dag(dag_id="example_gcp_vertex_ai_custom_jobs", dag_folder=CLOUD_DAG_FOLDER)
-
- @provide_gcp_context(GCP_VERTEX_AI_KEY)
- def test_run_dataset_example_dag(self):
- self.run_dag(dag_id="example_gcp_vertex_ai_dataset", dag_folder=CLOUD_DAG_FOLDER)
diff --git a/tests/providers/google/cloud/utils/gcp_authenticator.py b/tests/providers/google/cloud/utils/gcp_authenticator.py
index f6236b2..1a58b0d 100644
--- a/tests/providers/google/cloud/utils/gcp_authenticator.py
+++ b/tests/providers/google/cloud/utils/gcp_authenticator.py
@@ -54,7 +54,6 @@ GCP_SECRET_MANAGER_KEY = 'gcp_secret_manager.json'
GCP_SPANNER_KEY = 'gcp_spanner.json'
GCP_STACKDRIVER = 'gcp_stackdriver.json'
GCP_TASKS_KEY = 'gcp_tasks.json'
-GCP_VERTEX_AI_KEY = 'gcp_vertex_ai.json'
GCP_WORKFLOWS_KEY = "gcp_workflows.json"
GMP_KEY = 'gmp.json'
G_FIREBASE_KEY = 'g_firebase.json'