Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/29 09:56:44 UTC
[airflow] branch main updated: YandexCloud provider: Support new Yandex SDK features for DataProc (#25158)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a61e0c1df7 YandexCloud provider: Support new Yandex SDK features for DataProc (#25158)
a61e0c1df7 is described below
commit a61e0c1df7cd8a25ac67fdfc778350e148510743
Author: Peter Reznikov <re...@gmail.com>
AuthorDate: Fri Jul 29 12:56:36 2022 +0300
YandexCloud provider: Support new Yandex SDK features for DataProc (#25158)
---
airflow/providers/yandex/hooks/yandex.py | 4 +-
.../yandex/operators/yandexcloud_dataproc.py | 190 +++++++++++---------
airflow/providers/yandex/provider.yaml | 2 +-
generated/README.md | 2 +
generated/provider_dependencies.json | 2 +-
tests/providers/yandex/hooks/test_yandex.py | 18 +-
.../yandex/operators/test_yandexcloud_dataproc.py | 5 +
.../system/providers/yandex/example_yandexcloud.py | 197 +++++++++++++++++++++
.../example_yandexcloud_dataproc_lightweight.py | 80 +++++++++
9 files changed, 412 insertions(+), 88 deletions(-)
diff --git a/airflow/providers/yandex/hooks/yandex.py b/airflow/providers/yandex/hooks/yandex.py
index a337954496..deeac7b0fc 100644
--- a/airflow/providers/yandex/hooks/yandex.py
+++ b/airflow/providers/yandex/hooks/yandex.py
@@ -17,7 +17,7 @@
import json
import warnings
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
import yandexcloud
@@ -107,7 +107,7 @@ class YandexCloudBaseHook(BaseHook):
# Connection id is deprecated. Use yandex_conn_id instead
connection_id: Optional[str] = None,
yandex_conn_id: Optional[str] = None,
- default_folder_id: Union[dict, bool, None] = None,
+ default_folder_id: Optional[str] = None,
default_public_ssh_key: Optional[str] = None,
) -> None:
super().__init__()
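A minimal sketch of constructing the hook against the narrowed signature above, where default_folder_id is now a plain Optional[str]; the folder id and SSH key values below are placeholders, not real identifiers:

    from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook

    hook = YandexCloudBaseHook(
        yandex_conn_id=None,  # None falls back to the default Yandex.Cloud connection
        default_folder_id='b1g-example-folder-id',                # placeholder folder id
        default_public_ssh_key='ssh-rsa AAAA_example user@host',  # placeholder key
    )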
diff --git a/airflow/providers/yandex/operators/yandexcloud_dataproc.py b/airflow/providers/yandex/operators/yandexcloud_dataproc.py
index 1a9dd1acf0..ec6d8d6849 100644
--- a/airflow/providers/yandex/operators/yandexcloud_dataproc.py
+++ b/airflow/providers/yandex/operators/yandexcloud_dataproc.py
@@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+import warnings
+from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Iterable, Optional, Sequence, Union
from airflow.models import BaseOperator
@@ -24,6 +25,15 @@ if TYPE_CHECKING:
from airflow.utils.context import Context
+@dataclass
+class InitializationAction:
+ """Data for initialization action to be run at start of DataProc cluster."""
+
+ uri: str # Uri of the executable file
+ args: Sequence[str] # Arguments to the initialization action
+ timeout: int # Execution timeout
+
+
class DataprocCreateClusterOperator(BaseOperator):
"""Creates Yandex.Cloud Data Proc cluster.
@@ -69,9 +79,20 @@ class DataprocCreateClusterOperator(BaseOperator):
in percents. 10-100.
By default is not set and default autoscaling strategy is used.
:param computenode_decommission_timeout: Timeout to gracefully decommission nodes during downscaling.
- In seconds.
+ In seconds
+ :param properties: Properties passed to main node software.
+ Docs: https://cloud.yandex.com/docs/data-proc/concepts/settings-list
+ :param enable_ui_proxy: Enable UI Proxy feature for forwarding Hadoop components web interfaces
+ Docs: https://cloud.yandex.com/docs/data-proc/concepts/ui-proxy
+ :param host_group_ids: Dedicated host groups to place VMs of cluster on.
+ Docs: https://cloud.yandex.com/docs/compute/concepts/dedicated-host
+ :param security_group_ids: User security groups.
+ Docs: https://cloud.yandex.com/docs/data-proc/concepts/network#security-groups
:param log_group_id: Id of log group to write logs. By default logs will be sent to default log group.
To disable cloud log sending set cluster property dataproc:disable_cloud_logging = true
+ Docs: https://cloud.yandex.com/docs/data-proc/concepts/logs
+ :param initialization_actions: Set of init-actions to run when cluster starts.
+ Docs: https://cloud.yandex.com/docs/data-proc/concepts/init-action
"""
def __init__(
@@ -106,7 +127,12 @@ class DataprocCreateClusterOperator(BaseOperator):
computenode_cpu_utilization_target: Optional[int] = None,
computenode_decommission_timeout: Optional[int] = None,
connection_id: Optional[str] = None,
+ properties: Optional[Dict[str, str]] = None,
+ enable_ui_proxy: bool = False,
+ host_group_ids: Optional[Iterable[str]] = None,
+ security_group_ids: Optional[Iterable[str]] = None,
log_group_id: Optional[str] = None,
+ initialization_actions: Optional[Iterable[InitializationAction]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -139,11 +165,16 @@ class DataprocCreateClusterOperator(BaseOperator):
self.computenode_preemptible = computenode_preemptible
self.computenode_cpu_utilization_target = computenode_cpu_utilization_target
self.computenode_decommission_timeout = computenode_decommission_timeout
+ self.properties = properties
+ self.enable_ui_proxy = enable_ui_proxy
+ self.host_group_ids = host_group_ids
+ self.security_group_ids = security_group_ids
self.log_group_id = log_group_id
+ self.initialization_actions = initialization_actions
self.hook: Optional[DataprocHook] = None
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: 'Context') -> dict:
self.hook = DataprocHook(
yandex_conn_id=self.yandex_conn_id,
)
@@ -176,14 +207,35 @@ class DataprocCreateClusterOperator(BaseOperator):
computenode_preemptible=self.computenode_preemptible,
computenode_cpu_utilization_target=self.computenode_cpu_utilization_target,
computenode_decommission_timeout=self.computenode_decommission_timeout,
+ properties=self.properties,
+ enable_ui_proxy=self.enable_ui_proxy,
+ host_group_ids=self.host_group_ids,
+ security_group_ids=self.security_group_ids,
log_group_id=self.log_group_id,
+ initialization_actions=self.initialization_actions
+ and [
+ self.hook.sdk.wrappers.InitializationAction(
+ uri=init_action.uri,
+ args=init_action.args,
+ timeout=init_action.timeout,
+ )
+ for init_action in self.initialization_actions
+ ],
)
- context['task_instance'].xcom_push(key='cluster_id', value=operation_result.response.id)
+ cluster_id = operation_result.response.id
+
+ context['task_instance'].xcom_push(key='cluster_id', value=cluster_id)
+ # Deprecated
context['task_instance'].xcom_push(key='yandexcloud_connection_id', value=self.yandex_conn_id)
+ return cluster_id
+ @property
+ def cluster_id(self):
+ return self.output
-class DataprocDeleteClusterOperator(BaseOperator):
- """Deletes Yandex.Cloud Data Proc cluster.
+
+class DataprocBaseOperator(BaseOperator):
+ """Base class for DataProc operators working with given cluster.
:param connection_id: ID of the Yandex.Cloud Airflow connection.
:param cluster_id: ID of the cluster to remove. (templated)
@@ -192,25 +244,45 @@ class DataprocDeleteClusterOperator(BaseOperator):
template_fields: Sequence[str] = ('cluster_id',)
def __init__(
- self, *, connection_id: Optional[str] = None, cluster_id: Optional[str] = None, **kwargs
+ self, *, yandex_conn_id: Optional[str] = None, cluster_id: Optional[str] = None, **kwargs
) -> None:
super().__init__(**kwargs)
- self.yandex_conn_id = connection_id
self.cluster_id = cluster_id
- self.hook: Optional[DataprocHook] = None
+ self.yandex_conn_id = yandex_conn_id
+
+ def _setup(self, context: 'Context') -> DataprocHook:
+ if self.cluster_id is None:
+ self.cluster_id = context['task_instance'].xcom_pull(key='cluster_id')
+ if self.yandex_conn_id is None:
+ xcom_yandex_conn_id = context['task_instance'].xcom_pull(key='yandexcloud_connection_id')
+ if xcom_yandex_conn_id:
+ warnings.warn('Implicit pass of `yandex_conn_id` is deprecated, please pass it explicitly')
+ self.yandex_conn_id = xcom_yandex_conn_id
+
+ return DataprocHook(yandex_conn_id=self.yandex_conn_id)
+
+ def execute(self, context: 'Context'):
+ raise NotImplementedError()
+
+
+class DataprocDeleteClusterOperator(DataprocBaseOperator):
+ """Deletes Yandex.Cloud Data Proc cluster.
+
+ :param connection_id: ID of the Yandex.Cloud Airflow connection.
+ :param cluster_id: ID of the cluster to remove. (templated)
+ """
+
+ def __init__(
+ self, *, connection_id: Optional[str] = None, cluster_id: Optional[str] = None, **kwargs
+ ) -> None:
+ super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs)
def execute(self, context: 'Context') -> None:
- cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id')
- yandex_conn_id = self.yandex_conn_id or context['task_instance'].xcom_pull(
- key='yandexcloud_connection_id'
- )
- self.hook = DataprocHook(
- yandex_conn_id=yandex_conn_id,
- )
- self.hook.client.delete_cluster(cluster_id)
+ hook = self._setup(context)
+ hook.client.delete_cluster(self.cluster_id)
-class DataprocCreateHiveJobOperator(BaseOperator):
+class DataprocCreateHiveJobOperator(DataprocBaseOperator):
"""Runs Hive job in Data Proc cluster.
:param query: Hive query.
@@ -224,8 +296,6 @@ class DataprocCreateHiveJobOperator(BaseOperator):
:param connection_id: ID of the Yandex.Cloud Airflow connection.
"""
- template_fields: Sequence[str] = ('cluster_id',)
-
def __init__(
self,
*,
@@ -239,37 +309,28 @@ class DataprocCreateHiveJobOperator(BaseOperator):
connection_id: Optional[str] = None,
**kwargs,
) -> None:
- super().__init__(**kwargs)
+ super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs)
self.query = query
self.query_file_uri = query_file_uri
self.script_variables = script_variables
self.continue_on_failure = continue_on_failure
self.properties = properties
self.name = name
- self.cluster_id = cluster_id
- self.connection_id = connection_id
- self.hook: Optional[DataprocHook] = None
def execute(self, context: 'Context') -> None:
- cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id')
- yandex_conn_id = self.connection_id or context['task_instance'].xcom_pull(
- key='yandexcloud_connection_id'
- )
- self.hook = DataprocHook(
- yandex_conn_id=yandex_conn_id,
- )
- self.hook.client.create_hive_job(
+ hook = self._setup(context)
+ hook.client.create_hive_job(
query=self.query,
query_file_uri=self.query_file_uri,
script_variables=self.script_variables,
continue_on_failure=self.continue_on_failure,
properties=self.properties,
name=self.name,
- cluster_id=cluster_id,
+ cluster_id=self.cluster_id,
)
-class DataprocCreateMapReduceJobOperator(BaseOperator):
+class DataprocCreateMapReduceJobOperator(DataprocBaseOperator):
"""Runs Mapreduce job in Data Proc cluster.
:param main_jar_file_uri: URI of jar file with job.
@@ -286,8 +347,6 @@ class DataprocCreateMapReduceJobOperator(BaseOperator):
:param connection_id: ID of the Yandex.Cloud Airflow connection.
"""
- template_fields: Sequence[str] = ('cluster_id',)
-
def __init__(
self,
*,
@@ -303,7 +362,7 @@ class DataprocCreateMapReduceJobOperator(BaseOperator):
connection_id: Optional[str] = None,
**kwargs,
) -> None:
- super().__init__(**kwargs)
+ super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs)
self.main_class = main_class
self.main_jar_file_uri = main_jar_file_uri
self.jar_file_uris = jar_file_uris
@@ -312,19 +371,10 @@ class DataprocCreateMapReduceJobOperator(BaseOperator):
self.args = args
self.properties = properties
self.name = name
- self.cluster_id = cluster_id
- self.connection_id = connection_id
- self.hook: Optional[DataprocHook] = None
def execute(self, context: 'Context') -> None:
- cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id')
- yandex_conn_id = self.connection_id or context['task_instance'].xcom_pull(
- key='yandexcloud_connection_id'
- )
- self.hook = DataprocHook(
- yandex_conn_id=yandex_conn_id,
- )
- self.hook.client.create_mapreduce_job(
+ hook = self._setup(context)
+ hook.client.create_mapreduce_job(
main_class=self.main_class,
main_jar_file_uri=self.main_jar_file_uri,
jar_file_uris=self.jar_file_uris,
@@ -333,11 +383,11 @@ class DataprocCreateMapReduceJobOperator(BaseOperator):
args=self.args,
properties=self.properties,
name=self.name,
- cluster_id=cluster_id,
+ cluster_id=self.cluster_id,
)
-class DataprocCreateSparkJobOperator(BaseOperator):
+class DataprocCreateSparkJobOperator(DataprocBaseOperator):
"""Runs Spark job in Data Proc cluster.
:param main_jar_file_uri: URI of jar file with job. Can be placed in HDFS or S3.
@@ -358,8 +408,6 @@ class DataprocCreateSparkJobOperator(BaseOperator):
provided in --packages to avoid dependency conflicts.
"""
- template_fields: Sequence[str] = ('cluster_id',)
-
def __init__(
self,
*,
@@ -378,7 +426,7 @@ class DataprocCreateSparkJobOperator(BaseOperator):
exclude_packages: Optional[Iterable[str]] = None,
**kwargs,
) -> None:
- super().__init__(**kwargs)
+ super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs)
self.main_class = main_class
self.main_jar_file_uri = main_jar_file_uri
self.jar_file_uris = jar_file_uris
@@ -387,22 +435,13 @@ class DataprocCreateSparkJobOperator(BaseOperator):
self.args = args
self.properties = properties
self.name = name
- self.cluster_id = cluster_id
- self.connection_id = connection_id
self.packages = packages
self.repositories = repositories
self.exclude_packages = exclude_packages
- self.hook: Optional[DataprocHook] = None
def execute(self, context: 'Context') -> None:
- cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id')
- yandex_conn_id = self.connection_id or context['task_instance'].xcom_pull(
- key='yandexcloud_connection_id'
- )
- self.hook = DataprocHook(
- yandex_conn_id=yandex_conn_id,
- )
- self.hook.client.create_spark_job(
+ hook = self._setup(context)
+ hook.client.create_spark_job(
main_class=self.main_class,
main_jar_file_uri=self.main_jar_file_uri,
jar_file_uris=self.jar_file_uris,
@@ -414,11 +453,11 @@ class DataprocCreateSparkJobOperator(BaseOperator):
repositories=self.repositories,
exclude_packages=self.exclude_packages,
name=self.name,
- cluster_id=cluster_id,
+ cluster_id=self.cluster_id,
)
-class DataprocCreatePysparkJobOperator(BaseOperator):
+class DataprocCreatePysparkJobOperator(DataprocBaseOperator):
"""Runs Pyspark job in Data Proc cluster.
:param main_python_file_uri: URI of python file with job. Can be placed in HDFS or S3.
@@ -439,8 +478,6 @@ class DataprocCreatePysparkJobOperator(BaseOperator):
provided in --packages to avoid dependency conflicts.
"""
- template_fields: Sequence[str] = ('cluster_id',)
-
def __init__(
self,
*,
@@ -459,7 +496,7 @@ class DataprocCreatePysparkJobOperator(BaseOperator):
exclude_packages: Optional[Iterable[str]] = None,
**kwargs,
) -> None:
- super().__init__(**kwargs)
+ super().__init__(yandex_conn_id=connection_id, cluster_id=cluster_id, **kwargs)
self.main_python_file_uri = main_python_file_uri
self.python_file_uris = python_file_uris
self.jar_file_uris = jar_file_uris
@@ -468,22 +505,13 @@ class DataprocCreatePysparkJobOperator(BaseOperator):
self.args = args
self.properties = properties
self.name = name
- self.cluster_id = cluster_id
- self.connection_id = connection_id
self.packages = packages
self.repositories = repositories
self.exclude_packages = exclude_packages
- self.hook: Optional[DataprocHook] = None
def execute(self, context: 'Context') -> None:
- cluster_id = self.cluster_id or context['task_instance'].xcom_pull(key='cluster_id')
- yandex_conn_id = self.connection_id or context['task_instance'].xcom_pull(
- key='yandexcloud_connection_id'
- )
- self.hook = DataprocHook(
- yandex_conn_id=yandex_conn_id,
- )
- self.hook.client.create_pyspark_job(
+ hook = self._setup(context)
+ hook.client.create_pyspark_job(
main_python_file_uri=self.main_python_file_uri,
python_file_uris=self.python_file_uris,
jar_file_uris=self.jar_file_uris,
@@ -495,5 +523,5 @@ class DataprocCreatePysparkJobOperator(BaseOperator):
repositories=self.repositories,
exclude_packages=self.exclude_packages,
name=self.name,
- cluster_id=cluster_id,
+ cluster_id=self.cluster_id,
)
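A minimal sketch of how the new cluster-creation parameters and the cluster_id output added in the diff above might be wired together in a DAG; the DAG id, bucket name, host group id, security group id, property value and init-action URI are placeholder assumptions, not values from this commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.yandex.operators.yandexcloud_dataproc import (
        DataprocCreateClusterOperator,
        DataprocCreateSparkJobOperator,
        DataprocDeleteClusterOperator,
        InitializationAction,
    )

    with DAG(
        dag_id='example_dataproc_new_features',  # hypothetical DAG id
        schedule_interval=None,
        start_date=datetime(2021, 1, 1),
    ) as dag:
        create_cluster = DataprocCreateClusterOperator(
            task_id='create_cluster',
            zone='ru-central1-c',
            s3_bucket='my-dataproc-bucket',  # placeholder bucket
            services=('SPARK', 'YARN'),
            # New in this commit: node software properties, UI Proxy,
            # dedicated host groups, security groups and init-actions.
            properties={'spark:spark.executor.memory': '2g'},  # placeholder property
            enable_ui_proxy=True,
            host_group_ids=['my-host-group-id'],          # placeholder id
            security_group_ids=['my-security-group-id'],  # placeholder id
            initialization_actions=[
                InitializationAction(
                    uri='s3a://my-dataproc-bucket/init.sh',  # placeholder URI
                    args=['arg1'],
                    timeout=300,
                ),
            ],
        )

        # execute() now returns the cluster id, so downstream operators can pick it
        # up through the new cluster_id property instead of an implicit XCom pull.
        spark_job = DataprocCreateSparkJobOperator(
            task_id='spark_job',
            cluster_id=create_cluster.cluster_id,
            main_jar_file_uri='file:///usr/lib/spark/examples/jars/spark-examples.jar',
            main_class='org.apache.spark.examples.SparkPi',
            args=['1000'],
        )

        delete_cluster = DataprocDeleteClusterOperator(
            task_id='delete_cluster',
            cluster_id=create_cluster.cluster_id,
            trigger_rule='all_done',
        )

        spark_job >> delete_cluster

Keeping trigger_rule='all_done' on the delete task tears the cluster down even if the Spark job fails, mirroring the lightweight system test added later in this commit.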
diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml
index c066a2f8ed..90f6cc1c9a 100644
--- a/airflow/providers/yandex/provider.yaml
+++ b/airflow/providers/yandex/provider.yaml
@@ -34,7 +34,7 @@ versions:
dependencies:
- apache-airflow>=2.2.0
- - yandexcloud>=0.146.0
+ - yandexcloud>=0.173.0
integrations:
- integration-name: Yandex.Cloud
diff --git a/generated/README.md b/generated/README.md
index f87a767da4..d1dcc1f783 100644
--- a/generated/README.md
+++ b/generated/README.md
@@ -20,6 +20,8 @@
NOTE! The files in this folder are generated by pre-commit based on airflow sources. They are not
supposed to be manually modified.
+You can read more about pre-commit hooks [here](../STATIC_CODE_CHECKS.rst#pre-commit-hooks).
+
* `provider_dependencies.json` - is generated based on `provider.yaml` files in `airflow/providers` and
based on the imports in the provider code. If you want to add new dependency to a provider, you
need to modify the corresponding `provider.yaml` file
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 743a73d0da..51e3ea26f8 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -721,7 +721,7 @@
"yandex": {
"deps": [
"apache-airflow>=2.2.0",
- "yandexcloud>=0.146.0"
+ "yandexcloud>=0.173.0"
],
"cross-providers-deps": []
},
diff --git a/tests/providers/yandex/hooks/test_yandex.py b/tests/providers/yandex/hooks/test_yandex.py
index b4ddf0e121..a1ada7aefa 100644
--- a/tests/providers/yandex/hooks/test_yandex.py
+++ b/tests/providers/yandex/hooks/test_yandex.py
@@ -43,7 +43,11 @@ class TestYandexHook(unittest.TestCase):
)
get_credentials_mock.return_value = {"token": 122323}
- hook = YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key)
+ hook = YandexCloudBaseHook(
+ yandex_conn_id=None,
+ default_folder_id=default_folder_id,
+ default_public_ssh_key=default_public_ssh_key,
+ )
assert hook.client is not None
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@@ -63,7 +67,11 @@ class TestYandexHook(unittest.TestCase):
)
with pytest.raises(AirflowException):
- YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key)
+ YandexCloudBaseHook(
+ yandex_conn_id=None,
+ default_folder_id=default_folder_id,
+ default_public_ssh_key=default_public_ssh_key,
+ )
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
@mock.patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials')
@@ -80,6 +88,10 @@ class TestYandexHook(unittest.TestCase):
)
get_credentials_mock.return_value = {"token": 122323}
- hook = YandexCloudBaseHook(None, default_folder_id, default_public_ssh_key)
+ hook = YandexCloudBaseHook(
+ yandex_conn_id=None,
+ default_folder_id=default_folder_id,
+ default_public_ssh_key=default_public_ssh_key,
+ )
assert hook._get_field('one') == 'value_one'
diff --git a/tests/providers/yandex/operators/test_yandexcloud_dataproc.py b/tests/providers/yandex/operators/test_yandexcloud_dataproc.py
index f54087c742..23cda00e4a 100644
--- a/tests/providers/yandex/operators/test_yandexcloud_dataproc.py
+++ b/tests/providers/yandex/operators/test_yandexcloud_dataproc.py
@@ -127,6 +127,11 @@ class DataprocClusterCreateOperatorTest(TestCase):
subnet_id='my_subnet_id',
zone='ru-central1-c',
log_group_id=LOG_GROUP_ID,
+ properties=None,
+ enable_ui_proxy=False,
+ host_group_ids=None,
+ security_group_ids=None,
+ initialization_actions=None,
)
context['task_instance'].xcom_push.assert_has_calls(
[
diff --git a/tests/system/providers/yandex/example_yandexcloud.py b/tests/system/providers/yandex/example_yandexcloud.py
new file mode 100644
index 0000000000..708a6049a7
--- /dev/null
+++ b/tests/system/providers/yandex/example_yandexcloud.py
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+from datetime import datetime
+from typing import Optional
+
+import yandex.cloud.dataproc.v1.cluster_pb2 as cluster_pb
+import yandex.cloud.dataproc.v1.cluster_service_pb2 as cluster_service_pb
+import yandex.cloud.dataproc.v1.cluster_service_pb2_grpc as cluster_service_grpc_pb
+import yandex.cloud.dataproc.v1.common_pb2 as common_pb
+import yandex.cloud.dataproc.v1.job_pb2 as job_pb
+import yandex.cloud.dataproc.v1.job_service_pb2 as job_service_pb
+import yandex.cloud.dataproc.v1.job_service_pb2_grpc as job_service_grpc_pb
+import yandex.cloud.dataproc.v1.subcluster_pb2 as subcluster_pb
+from google.protobuf.json_format import MessageToDict
+
+from airflow import DAG
+from airflow.decorators import task
+from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = 'example_yandexcloud_hook'
+
+# Fill it with your identifiers
+YC_S3_BUCKET_NAME = '' # Fill to use S3 instead of HDFS
+YC_FOLDER_ID = None # Fill to override default YC folder from connection data
+YC_ZONE_NAME = 'ru-central1-b'
+YC_SUBNET_ID = None # Fill if you have more than one VPC subnet in given folder and zone
+YC_SERVICE_ACCOUNT_ID = None # Fill if you have more than one YC service account in given folder
+
+
+def create_cluster_request(
+ folder_id: str,
+ cluster_name: str,
+ cluster_desc: str,
+ zone: str,
+ subnet_id: str,
+ service_account_id: str,
+ ssh_public_key: str,
+ resources: common_pb.Resources,
+):
+ return cluster_service_pb.CreateClusterRequest(
+ folder_id=folder_id,
+ name=cluster_name,
+ description=cluster_desc,
+ bucket=YC_S3_BUCKET_NAME,
+ config_spec=cluster_service_pb.CreateClusterConfigSpec(
+ hadoop=cluster_pb.HadoopConfig(
+ services=('SPARK', 'YARN'),
+ ssh_public_keys=[ssh_public_key],
+ ),
+ subclusters_spec=[
+ cluster_service_pb.CreateSubclusterConfigSpec(
+ name='master',
+ role=subcluster_pb.Role.MASTERNODE,
+ resources=resources,
+ subnet_id=subnet_id,
+ hosts_count=1,
+ ),
+ cluster_service_pb.CreateSubclusterConfigSpec(
+ name='compute',
+ role=subcluster_pb.Role.COMPUTENODE,
+ resources=resources,
+ subnet_id=subnet_id,
+ hosts_count=1,
+ ),
+ ],
+ ),
+ zone_id=zone,
+ service_account_id=service_account_id,
+ )
+
+
+@task
+def create_cluster(
+ yandex_conn_id: Optional[str] = None,
+ folder_id: Optional[str] = None,
+ network_id: Optional[str] = None,
+ subnet_id: Optional[str] = None,
+ zone: str = YC_ZONE_NAME,
+ service_account_id: Optional[str] = None,
+ ssh_public_key: Optional[str] = None,
+ *,
+ dag: Optional[DAG] = None,
+ ts_nodash: Optional[str] = None,
+) -> str:
+ hook = YandexCloudBaseHook(yandex_conn_id=yandex_conn_id)
+ folder_id = folder_id or hook.default_folder_id
+ if subnet_id is None:
+ network_id = network_id or hook.sdk.helpers.find_network_id(folder_id)
+ subnet_id = hook.sdk.helpers.find_subnet_id(folder_id=folder_id, zone_id=zone, network_id=network_id)
+ service_account_id = service_account_id or hook.sdk.helpers.find_service_account_id()
+ ssh_public_key = ssh_public_key or hook.default_public_ssh_key
+
+ dag_id = dag and dag.dag_id or 'dag'
+
+ request = create_cluster_request(
+ folder_id=folder_id,
+ subnet_id=subnet_id,
+ zone=zone,
+ cluster_name=f'airflow_{dag_id}_{ts_nodash}'[:62],
+ cluster_desc='Created via Airflow custom hook task',
+ service_account_id=service_account_id,
+ ssh_public_key=ssh_public_key,
+ resources=common_pb.Resources(
+ resource_preset_id='s2.micro',
+ disk_type_id='network-ssd',
+ ),
+ )
+ operation = hook.sdk.client(cluster_service_grpc_pb.ClusterServiceStub).Create(request)
+ operation_result = hook.sdk.wait_operation_and_get_result(
+ operation, response_type=cluster_pb.Cluster, meta_type=cluster_service_pb.CreateClusterMetadata
+ )
+ return operation_result.response.id
+
+
+@task
+def run_spark_job(
+ cluster_id: str,
+ yandex_conn_id: Optional[str] = None,
+):
+ hook = YandexCloudBaseHook(yandex_conn_id=yandex_conn_id)
+
+ request = job_service_pb.CreateJobRequest(
+ cluster_id=cluster_id,
+ name='Spark job: Find total urban population in distribution by country',
+ spark_job=job_pb.SparkJob(
+ main_jar_file_uri='file:///usr/lib/spark/examples/jars/spark-examples.jar',
+ main_class='org.apache.spark.examples.SparkPi',
+ args=['1000'],
+ ),
+ )
+ operation = hook.sdk.client(job_service_grpc_pb.JobServiceStub).Create(request)
+ operation_result = hook.sdk.wait_operation_and_get_result(
+ operation, response_type=job_pb.Job, meta_type=job_service_pb.CreateJobMetadata
+ )
+ return MessageToDict(operation_result.response)
+
+
+@task(trigger_rule='all_done')
+def delete_cluster(
+ cluster_id: str,
+ yandex_conn_id: Optional[str] = None,
+):
+ hook = YandexCloudBaseHook(yandex_conn_id=yandex_conn_id)
+
+ operation = hook.sdk.client(cluster_service_grpc_pb.ClusterServiceStub).Delete(
+ cluster_service_pb.DeleteClusterRequest(cluster_id=cluster_id)
+ )
+ hook.sdk.wait_operation_and_get_result(
+ operation,
+ meta_type=cluster_service_pb.DeleteClusterMetadata,
+ )
+
+
+with DAG(
+ dag_id=DAG_ID,
+ schedule_interval=None,
+ start_date=datetime(2021, 1, 1),
+ tags=['example'],
+) as dag:
+ cluster_id = create_cluster(
+ folder_id=YC_FOLDER_ID,
+ subnet_id=YC_SUBNET_ID,
+ zone=YC_ZONE_NAME,
+ service_account_id=YC_SERVICE_ACCOUNT_ID,
+ )
+ spark_job = run_spark_job(cluster_id=cluster_id)
+ delete_task = delete_cluster(cluster_id=cluster_id)
+
+ spark_job >> delete_task
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "teardown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/yandex/example_yandexcloud_dataproc_lightweight.py b/tests/system/providers/yandex/example_yandexcloud_dataproc_lightweight.py
new file mode 100644
index 0000000000..d5faa0865e
--- /dev/null
+++ b/tests/system/providers/yandex/example_yandexcloud_dataproc_lightweight.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+from datetime import datetime
+
+from airflow import DAG
+from airflow.providers.yandex.operators.yandexcloud_dataproc import (
+ DataprocCreateClusterOperator,
+ DataprocCreateSparkJobOperator,
+ DataprocDeleteClusterOperator,
+)
+
+from airflow.utils.trigger_rule import TriggerRule
+
+# Should be filled with appropriate ids
+
+# Name of the datacenter where Dataproc cluster will be created
+AVAILABILITY_ZONE_ID = 'ru-central1-c'
+
+# Dataproc cluster will use this bucket as distributed storage
+S3_BUCKET_NAME = ''
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = 'example_yandexcloud_dataproc_lightweight'
+
+with DAG(
+ DAG_ID,
+ schedule_interval=None,
+ start_date=datetime(2021, 1, 1),
+ tags=['example'],
+) as dag:
+ create_cluster = DataprocCreateClusterOperator(
+ task_id='create_cluster',
+ zone=AVAILABILITY_ZONE_ID,
+ s3_bucket=S3_BUCKET_NAME,
+ computenode_count=1,
+ datanode_count=0,
+ services=('SPARK', 'YARN'),
+ )
+
+ create_spark_job = DataprocCreateSparkJobOperator(
+ cluster_id=create_cluster.cluster_id,
+ task_id='create_spark_job',
+ main_jar_file_uri='file:///usr/lib/spark/examples/jars/spark-examples.jar',
+ main_class='org.apache.spark.examples.SparkPi',
+ args=['1000'],
+ )
+
+ delete_cluster = DataprocDeleteClusterOperator(
+ cluster_id=create_cluster.cluster_id,
+ task_id='delete_cluster',
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+ create_spark_job >> delete_cluster
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "teardown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)