You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/10/31 05:04:54 UTC
[airflow] branch main updated: Rewrite system tests for ML Engine service (#26915)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 528ecbbc00 Rewrite system tests for ML Engine service (#26915)
528ecbbc00 is described below
commit 528ecbbc005566e13f7a6a1cafb4962733c6efb0
Author: George <pa...@gmail.com>
AuthorDate: Mon Oct 31 06:04:36 2022 +0100
Rewrite system tests for ML Engine service (#26915)
---
.../cloud/utils/mlengine_prediction_summary.py | 3 +-
.../operators/cloud/mlengine.rst | 30 ++--
.../google/cloud/operators/test_mlengine_system.py | 56 ------
.../providers/google/cloud/ml_engine/__init__.py | 16 ++
.../google/cloud/ml_engine}/example_mlengine.py | 192 ++++++++++++---------
5 files changed, 146 insertions(+), 151 deletions(-)
diff --git a/airflow/providers/google/cloud/utils/mlengine_prediction_summary.py b/airflow/providers/google/cloud/utils/mlengine_prediction_summary.py
index 4734a20a0a..0b3b7c5633 100644
--- a/airflow/providers/google/cloud/utils/mlengine_prediction_summary.py
+++ b/airflow/providers/google/cloud/utils/mlengine_prediction_summary.py
@@ -116,9 +116,10 @@ import os
import apache_beam as beam
import dill
+from apache_beam.coders.coders import Coder
-class JsonCoder:
+class JsonCoder(Coder):
"""JSON encoder/decoder."""
@staticmethod
diff --git a/docs/apache-airflow-providers-google/operators/cloud/mlengine.rst b/docs/apache-airflow-providers-google/operators/cloud/mlengine.rst
index fceeeee83c..8dec546b3e 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/mlengine.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/mlengine.rst
@@ -40,7 +40,7 @@ This creates a virtual machine that can run code specified in the trainer file,
contains the main application code. A job can be initiated with the
:class:`~airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator`.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_training]
@@ -55,7 +55,7 @@ A model is a container that can hold multiple model versions. A new model can be
The ``model`` field should be defined with a dictionary containing the information about the model.
``name`` is a required field in this dictionary.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_create_model]
@@ -69,7 +69,7 @@ The :class:`~airflow.providers.google.cloud.operators.mlengine.MLEngineGetModelO
can be used to obtain a model previously created. To obtain the correct model, ``model_name``
must be defined in the operator.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_get_model]
@@ -80,7 +80,7 @@ fields to dynamically determine their values. The result are saved to :ref:`XCom
allowing them to be used by other operators. In this case, the
:class:`~airflow.operators.bash.BashOperator` is used to print the model information.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_print_model]
@@ -96,7 +96,7 @@ The model must be specified by ``model_name``, and the ``version`` parameter sho
all the information about the version. Within the ``version`` parameter's dictionary, the ``name`` field is
required.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_create_version1]
@@ -105,7 +105,7 @@ required.
The :class:`~airflow.providers.google.cloud.operators.mlengine.MLEngineCreateVersionOperator`
can also be used to create more versions with varying parameters.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_create_version2]
@@ -120,7 +120,7 @@ By default, the model code will run using the default model version. You can set
:class:`~airflow.providers.google.cloud.operators.mlengine.MLEngineSetDefaultVersionOperator`
by specifying the ``model_name`` and ``version_name`` parameters.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_default_version]
@@ -130,7 +130,7 @@ To list the model versions available, use the
:class:`~airflow.providers.google.cloud.operators.mlengine.MLEngineListVersionsOperator`
while specifying the ``model_name`` parameter.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_list_versions]
@@ -141,7 +141,7 @@ fields to dynamically determine their values. The result are saved to :ref:`XCom
allowing them to be used by other operators. In this case, the
:class:`~airflow.operators.bash.BashOperator` is used to print the version information.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_print_versions]
@@ -156,7 +156,7 @@ A Google Cloud AI Platform prediction job can be started with the
For specifying the model origin, you need to provide either the ``model_name``, ``uri``, or ``model_name`` and
``version_name``. If you do not provide the ``version_name``, the operator will use the default model version.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_get_prediction]
@@ -171,7 +171,7 @@ A model version can be deleted with the
:class:`~airflow.providers.google.cloud.operators.mlengine.MLEngineDeleteVersionOperator` by
the ``version_name`` and ``model_name`` parameters.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_delete_version]
@@ -181,7 +181,7 @@ You can also delete a model with the
:class:`~airflow.providers.google.cloud.operators.mlengine.MLEngineDeleteModelOperator`
by providing the ``model_name`` parameter.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_delete_model]
@@ -193,7 +193,7 @@ To evaluate a prediction and model, specify a metric function to generate a summ
the evaluation of the model. This function receives a dictionary derived from a json in the batch
prediction result, then returns a tuple of metrics.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_get_metric]
@@ -203,7 +203,7 @@ To evaluate a prediction and model, it's useful to have a function to validate t
This function receives a dictionary of the averaged metrics the function above generated. It then
raises an exception if a task fails or should not proceed.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_validate_error]
@@ -214,7 +214,7 @@ Prediction results and a model summary can be generated through a function such
It makes predictions using the specified inputs and then summarizes and validates the result. The
functions created above should be passed in through the ``metric_fn_and_keys`` and ``validate_fn`` fields.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_mlengine.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/ml_engine/example_mlengine.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_mlengine_evaluate]
diff --git a/tests/providers/google/cloud/operators/test_mlengine_system.py b/tests/providers/google/cloud/operators/test_mlengine_system.py
deleted file mode 100644
index cd50544cfd..0000000000
--- a/tests/providers/google/cloud/operators/test_mlengine_system.py
+++ /dev/null
@@ -1,56 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from urllib.parse import urlparse
-
-import pytest
-
-from airflow.providers.google.cloud.example_dags.example_mlengine import (
- JOB_DIR,
- PREDICTION_OUTPUT,
- SAVED_MODEL_PATH,
- SUMMARY_STAGING,
- SUMMARY_TMP,
-)
-from tests.providers.google.cloud.utils.gcp_authenticator import GCP_AI_KEY
-from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
-
-BUCKETS = {
- urlparse(bucket_url).netloc
- for bucket_url in {SAVED_MODEL_PATH, JOB_DIR, PREDICTION_OUTPUT, SUMMARY_TMP, SUMMARY_STAGING}
-}
-
-
-@pytest.mark.credential_file(GCP_AI_KEY)
-class MlEngineExampleDagTest(GoogleSystemTest):
- @provide_gcp_context(GCP_AI_KEY)
- def setUp(self):
- super().setUp()
- for bucket in BUCKETS:
- self.create_gcs_bucket(bucket)
-
- @provide_gcp_context(GCP_AI_KEY)
- def tearDown(self):
- for bucket in BUCKETS:
- self.delete_gcs_bucket(bucket)
- super().tearDown()
-
- @provide_gcp_context(GCP_AI_KEY)
- def test_run_example_dag(self):
- self.run_dag("example_gcp_mlengine", CLOUD_DAG_FOLDER)
diff --git a/tests/system/providers/google/cloud/ml_engine/__init__.py b/tests/system/providers/google/cloud/ml_engine/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/system/providers/google/cloud/ml_engine/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/google/cloud/example_dags/example_mlengine.py b/tests/system/providers/google/cloud/ml_engine/example_mlengine.py
similarity index 61%
rename from airflow/providers/google/cloud/example_dags/example_mlengine.py
rename to tests/system/providers/google/cloud/ml_engine/example_mlengine.py
index de824ab9ea..999aa618d3 100644
--- a/airflow/providers/google/cloud/example_dags/example_mlengine.py
+++ b/tests/system/providers/google/cloud/ml_engine/example_mlengine.py
@@ -21,11 +21,14 @@ Example Airflow DAG for Google ML Engine service.
from __future__ import annotations
import os
+import pathlib
from datetime import datetime
-from typing import Any
+from math import ceil
from airflow import models
+from airflow.decorators import task
from airflow.operators.bash import BashOperator
+from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
from airflow.providers.google.cloud.operators.mlengine import (
MLEngineCreateModelOperator,
MLEngineCreateVersionOperator,
@@ -37,70 +40,64 @@ from airflow.providers.google.cloud.operators.mlengine import (
MLEngineStartBatchPredictionJobOperator,
MLEngineStartTrainingJobOperator,
)
+from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator
from airflow.providers.google.cloud.utils import mlengine_operator_utils
+from airflow.utils.trigger_rule import TriggerRule
-PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
+PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "default")
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
-MODEL_NAME = os.environ.get("GCP_MLENGINE_MODEL_NAME", "model_name")
+DAG_ID = "example_gcp_mlengine"
+PREDICT_FILE_NAME = "predict.json"
+MODEL_NAME = f"example_mlengine_model_{ENV_ID}"
+BUCKET_NAME = f"example_mlengine_bucket_{ENV_ID}"
+BUCKET_PATH = f"gs://{BUCKET_NAME}"
+JOB_DIR = f"{BUCKET_PATH}/job-dir"
+SAVED_MODEL_PATH = f"{JOB_DIR}/"
+PREDICTION_INPUT = f"{BUCKET_PATH}/{PREDICT_FILE_NAME}"
+PREDICTION_OUTPUT = f"{BUCKET_PATH}/prediction_output/"
+TRAINER_URI = "gs://system-tests-resources/example_gcp_mlengine/trainer-0.1.tar.gz"
+TRAINER_PY_MODULE = "trainer.task"
+SUMMARY_TMP = f"{BUCKET_PATH}/tmp/"
+SUMMARY_STAGING = f"{BUCKET_PATH}/staging/"
-SAVED_MODEL_PATH = os.environ.get("GCP_MLENGINE_SAVED_MODEL_PATH", "gs://INVALID BUCKET NAME/saved-model/")
-JOB_DIR = os.environ.get("GCP_MLENGINE_JOB_DIR", "gs://INVALID BUCKET NAME/keras-job-dir")
-PREDICTION_INPUT = os.environ.get(
- "GCP_MLENGINE_PREDICTION_INPUT", "gs://INVALID BUCKET NAME/prediction_input.json"
-)
-PREDICTION_OUTPUT = os.environ.get(
- "GCP_MLENGINE_PREDICTION_OUTPUT", "gs://INVALID BUCKET NAME/prediction_output"
-)
-TRAINER_URI = os.environ.get("GCP_MLENGINE_TRAINER_URI", "gs://INVALID BUCKET NAME/trainer.tar.gz")
-TRAINER_PY_MODULE = os.environ.get("GCP_MLENGINE_TRAINER_TRAINER_PY_MODULE", "trainer.task")
+BASE_DIR = pathlib.Path(__file__).parent.resolve()
+PATH_TO_PREDICT_FILE = BASE_DIR / PREDICT_FILE_NAME
-SUMMARY_TMP = os.environ.get("GCP_MLENGINE_DATAFLOW_TMP", "gs://INVALID BUCKET NAME/tmp/")
-SUMMARY_STAGING = os.environ.get("GCP_MLENGINE_DATAFLOW_STAGING", "gs://INVALID BUCKET NAME/staging/")
+
+def generate_model_predict_input_data() -> list[int]:
+ return [i for i in range(0, 201, 10)]
with models.DAG(
- "example_gcp_mlengine",
+ dag_id=DAG_ID,
+ schedule="@once",
start_date=datetime(2021, 1, 1),
catchup=False,
- tags=["example"],
+ tags=["example", "ml_engine"],
params={"model_name": MODEL_NAME},
) as dag:
- hyperparams: dict[str, Any] = {
- "goal": "MAXIMIZE",
- "hyperparameterMetricTag": "metric1",
- "maxTrials": 30,
- "maxParallelTrials": 1,
- "enableTrialEarlyStopping": True,
- "params": [],
- }
-
- hyperparams["params"].append(
- {
- "parameterName": "hidden1",
- "type": "INTEGER",
- "minValue": 40,
- "maxValue": 400,
- "scaleType": "UNIT_LINEAR_SCALE",
- }
+ create_bucket = GCSCreateBucketOperator(
+ task_id="create-bucket",
+ bucket_name=BUCKET_NAME,
)
- hyperparams["params"].append(
- {"parameterName": "numRnnCells", "type": "DISCRETE", "discreteValues": [1, 2, 3, 4]}
- )
+ @task(task_id="write-predict-data-file")
+ def write_predict_file(path_to_file: str):
+ predict_data = generate_model_predict_input_data()
+ with open(path_to_file, "w") as file:
+ for predict_value in predict_data:
+ file.write(f'{{"input_layer": [{predict_value}]}}\n')
- hyperparams["params"].append(
- {
- "parameterName": "rnnCellType",
- "type": "CATEGORICAL",
- "categoricalValues": [
- "BasicLSTMCell",
- "BasicRNNCell",
- "GRUCell",
- "LSTMCell",
- "LayerNormBasicLSTMCell",
- ],
- }
+ write_data = write_predict_file(path_to_file=PATH_TO_PREDICT_FILE)
+
+ upload_file = LocalFilesystemToGCSOperator(
+ task_id="upload-predict-file",
+ src=[PATH_TO_PREDICT_FILE],
+ dst=PREDICT_FILE_NAME,
+ bucket=BUCKET_NAME,
)
+
# [START howto_operator_gcp_mlengine_training]
training = MLEngineStartTrainingJobOperator(
task_id="training",
@@ -114,7 +111,6 @@ with models.DAG(
training_python_module=TRAINER_PY_MODULE,
training_args=[],
labels={"job_type": "training"},
- hyperparameters=hyperparams,
)
# [END howto_operator_gcp_mlengine_training]
@@ -144,14 +140,14 @@ with models.DAG(
# [END howto_operator_gcp_mlengine_print_model]
# [START howto_operator_gcp_mlengine_create_version1]
- create_version = MLEngineCreateVersionOperator(
- task_id="create-version",
+ create_version_v1 = MLEngineCreateVersionOperator(
+ task_id="create-version-v1",
project_id=PROJECT_ID,
model_name=MODEL_NAME,
version={
"name": "v1",
"description": "First-version",
- "deployment_uri": f"{JOB_DIR}/keras_export/",
+ "deployment_uri": JOB_DIR,
"runtime_version": "1.15",
"machineType": "mls1-c1-m2",
"framework": "TENSORFLOW",
@@ -161,14 +157,14 @@ with models.DAG(
# [END howto_operator_gcp_mlengine_create_version1]
# [START howto_operator_gcp_mlengine_create_version2]
- create_version_2 = MLEngineCreateVersionOperator(
- task_id="create-version-2",
+ create_version_v2 = MLEngineCreateVersionOperator(
+ task_id="create-version-v2",
project_id=PROJECT_ID,
model_name=MODEL_NAME,
version={
"name": "v2",
"description": "Second version",
- "deployment_uri": SAVED_MODEL_PATH,
+ "deployment_uri": JOB_DIR,
"runtime_version": "1.15",
"machineType": "mls1-c1-m2",
"framework": "TENSORFLOW",
@@ -216,28 +212,38 @@ with models.DAG(
# [END howto_operator_gcp_mlengine_get_prediction]
# [START howto_operator_gcp_mlengine_delete_version]
- delete_version = MLEngineDeleteVersionOperator(
- task_id="delete-version", project_id=PROJECT_ID, model_name=MODEL_NAME, version_name="v1"
+ delete_version_v1 = MLEngineDeleteVersionOperator(
+ task_id="delete-version-v1",
+ project_id=PROJECT_ID,
+ model_name=MODEL_NAME,
+ version_name="v1",
+ trigger_rule=TriggerRule.ALL_DONE,
)
# [END howto_operator_gcp_mlengine_delete_version]
+ delete_version_v2 = MLEngineDeleteVersionOperator(
+ task_id="delete-version-v2",
+ project_id=PROJECT_ID,
+ model_name=MODEL_NAME,
+ version_name="v2",
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
# [START howto_operator_gcp_mlengine_delete_model]
delete_model = MLEngineDeleteModelOperator(
- task_id="delete-model", project_id=PROJECT_ID, model_name=MODEL_NAME, delete_contents=True
+ task_id="delete-model",
+ project_id=PROJECT_ID,
+ model_name=MODEL_NAME,
+ delete_contents=True,
+ trigger_rule=TriggerRule.ALL_DONE,
)
# [END howto_operator_gcp_mlengine_delete_model]
- training >> create_version
- training >> create_version_2
- create_model >> get_model >> [get_model_result, delete_model]
- create_model >> get_model >> delete_model
- create_model >> create_version >> create_version_2 >> set_defaults_version >> list_version
- create_version >> prediction
- create_version_2 >> prediction
- prediction >> delete_version
- list_version >> list_version_result
- list_version >> delete_version
- delete_version >> delete_model
+ delete_bucket = GCSDeleteBucketOperator(
+ task_id="delete-bucket",
+ bucket_name=BUCKET_NAME,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
# [START howto_operator_gcp_mlengine_get_metric]
def get_metric_fn_and_keys():
@@ -246,7 +252,7 @@ with models.DAG(
"""
def normalize_value(inst: dict):
- val = float(inst["dense_4"][0])
+ val = float(inst["output_layer"][0])
return tuple([val]) # returns a tuple.
return normalize_value, ["val"] # key order must match.
@@ -258,12 +264,13 @@ with models.DAG(
"""
Validate summary result
"""
- if summary["val"] > 1:
- raise ValueError(f"Too high val>1; summary={summary}")
- if summary["val"] < 0:
- raise ValueError(f"Too low val<0; summary={summary}")
- if summary["count"] != 20:
- raise ValueError(f"Invalid value val != 20; summary={summary}")
+ summary = summary.get("val", 0)
+ initial_values = generate_model_predict_input_data()
+ initial_summary = sum(initial_values) / len(initial_values)
+
+ multiplier = ceil(summary / initial_summary)
+ if multiplier != 2:
+ raise ValueError(f"Multiplier is not equal 2; multiplier: {multiplier}")
return summary
# [END howto_operator_gcp_mlengine_validate_error]
@@ -290,5 +297,32 @@ with models.DAG(
)
# [END howto_operator_gcp_mlengine_evaluate]
- create_model >> create_version >> evaluate_prediction
- evaluate_validation >> delete_version
+ # TEST SETUP
+ create_bucket >> write_data >> upload_file
+ upload_file >> [prediction, evaluate_prediction]
+ create_bucket >> training >> create_version_v1
+
+ # TEST BODY
+ create_model >> get_model >> [get_model_result, delete_model]
+ create_model >> create_version_v1 >> create_version_v2 >> set_defaults_version >> list_version
+
+ create_version_v1 >> prediction
+ create_version_v1 >> evaluate_prediction
+ create_version_v2 >> prediction
+
+ list_version >> [list_version_result, delete_version_v1]
+ prediction >> delete_version_v1
+
+ # TEST TEARDOWN
+ evaluate_validation >> delete_version_v1 >> delete_version_v2 >> delete_model >> delete_bucket
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)