You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by tu...@apache.org on 2020/06/25 10:11:36 UTC
[airflow] branch master updated: Add read-only Task endpoint (#9330)
This is an automated email from the ASF dual-hosted git repository.
turbaszek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 5eb2808 Add read-only Task endpoint (#9330)
5eb2808 is described below
commit 5eb2808da1e33cb7a72fad8693d75e8a21401828
Author: Tomek Urbaszek <tu...@gmail.com>
AuthorDate: Thu Jun 25 12:11:11 2020 +0200
Add read-only Task endpoint (#9330)
Add API endpoints for tasks and DAG details
Co-authored-by: Kamil Breguła <ka...@polidea.com>
---
airflow/api_connexion/endpoints/dag_endpoint.py | 12 +-
airflow/api_connexion/endpoints/task_endpoint.py | 28 +++-
airflow/api_connexion/openapi/v1.yaml | 29 +++-
airflow/api_connexion/schemas/common_schema.py | 169 +++++++++++++++++++++
airflow/api_connexion/schemas/dag_schema.py | 93 ++++++++++++
airflow/api_connexion/schemas/task_schema.py | 80 ++++++++++
requirements/requirements-python3.6.txt | 1 +
requirements/requirements-python3.7.txt | 1 +
requirements/requirements-python3.8.txt | 3 +-
requirements/setup-3.6.md5 | 2 +-
requirements/setup-3.7.md5 | 2 +-
requirements/setup-3.8.md5 | 2 +-
setup.py | 1 +
tests/api_connexion/endpoints/test_dag_endpoint.py | 91 ++++++++++-
.../api_connexion/endpoints/test_task_endpoint.py | 149 +++++++++++++++++-
tests/api_connexion/schemas/test_common_schema.py | 145 ++++++++++++++++++
tests/api_connexion/schemas/test_dag_schema.py | 123 +++++++++++++++
tests/api_connexion/schemas/test_task_schema.py | 99 ++++++++++++
tests/cli/commands/test_dag_command.py | 7 +
tests/test_utils/db.py | 6 +
20 files changed, 1021 insertions(+), 22 deletions(-)
diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py
index 7cbb0ef..7cdeeb6 100644
--- a/airflow/api_connexion/endpoints/dag_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_endpoint.py
@@ -15,10 +15,15 @@
# specific language governing permissions and limitations
# under the License.
+from flask import current_app
+
+from airflow import DAG
+from airflow.api_connexion.exceptions import NotFound
# TODO(mik-laj): We have to implement it.
# Do you want to help? Please look at:
# * https://github.com/apache/airflow/issues/8128
# * https://github.com/apache/airflow/issues/8138
+from airflow.api_connexion.schemas.dag_schema import dag_detail_schema
def get_dag():
@@ -28,11 +33,14 @@ def get_dag():
raise NotImplementedError("Not implemented yet.")
-def get_dag_details():
+def get_dag_details(dag_id):
"""
Get details of DAG.
"""
- raise NotImplementedError("Not implemented yet.")
+ dag: DAG = current_app.dag_bag.get_dag(dag_id)
+ if not dag:
+ raise NotFound("DAG not found")
+ return dag_detail_schema.dump(dag)
def get_dags():
diff --git a/airflow/api_connexion/endpoints/task_endpoint.py b/airflow/api_connexion/endpoints/task_endpoint.py
index de7eaa4..e23483a 100644
--- a/airflow/api_connexion/endpoints/task_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_endpoint.py
@@ -14,20 +14,36 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from flask import current_app
-# TODO(mik-laj): We have to implement it.
-# Do you want to help? Please look at: https://github.com/apache/airflow/issues/8138
+from airflow import DAG
+from airflow.api_connexion.exceptions import NotFound
+from airflow.api_connexion.schemas.task_schema import TaskCollection, task_collection_schema, task_schema
+from airflow.exceptions import TaskNotFound
-def get_task():
+def get_task(dag_id, task_id):
"""
Get simplified representation of a task.
"""
- raise NotImplementedError("Not implemented yet.")
+ dag: DAG = current_app.dag_bag.get_dag(dag_id)
+ if not dag:
+ raise NotFound("DAG not found")
+ try:
+ task = dag.get_task(task_id=task_id)
+ except TaskNotFound:
+ raise NotFound("Task not found")
+ return task_schema.dump(task)
-def get_tasks():
+
+def get_tasks(dag_id):
"""
Get tasks for DAG
"""
- raise NotImplementedError("Not implemented yet.")
+ dag: DAG = current_app.dag_bag.get_dag(dag_id)
+ if not dag:
+ raise NotFound("DAG not found")
+ tasks = dag.tasks
+ task_collection = TaskCollection(tasks=tasks, total_entries=len(tasks))
+ return task_collection_schema.dump(task_collection)
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index e6ab5f6..1794b65 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1234,8 +1234,10 @@ components:
root_dag_id:
type: string
readOnly: true
+ nullable: true
is_paused:
type: boolean
+ nullable: true
is_subdag:
type: boolean
readOnly: true
@@ -1257,11 +1259,13 @@ components:
description:
type: string
readOnly: true
+ nullable: true
schedule_interval:
$ref: '#/components/schemas/ScheduleInterval'
readOnly: true
tags:
type: array
+ nullable: true
items:
$ref: '#/components/schemas/Tag'
readOnly: true
@@ -1638,6 +1642,7 @@ components:
format: 'date-time'
readOnly: true
dag_run_timeout:
+ nullable: true
$ref: '#/components/schemas/TimeDelta'
doc_md:
type: string
@@ -1685,6 +1690,7 @@ components:
type: string
format: 'date-time'
readOnly: true
+ nullable: true
trigger_rule:
$ref: '#/components/schemas/TriggerRule'
extra_links:
@@ -1715,8 +1721,10 @@ components:
readOnly: true
execution_timeout:
$ref: '#/components/schemas/TimeDelta'
+ nullable: true
retry_delay:
$ref: '#/components/schemas/TimeDelta'
+ nullable: true
retry_exponential_backoff:
type: boolean
readOnly: true
@@ -2003,17 +2011,35 @@ components:
type: object
required:
- __type
+ - days
+ - seconds
+ - microseconds
properties:
__type: {type: string}
days: {type: integer}
seconds: {type: integer}
- microsecond: {type: integer}
+ microseconds: {type: integer}
RelativeDelta:
# TODO: Why we need these fields?
type: object
required:
- __type
+ - years
+ - months
+ - days
+ - leapdays
+ - hours
+ - minutes
+ - seconds
+ - microseconds
+ - year
+ - month
+ - day
+ - hour
+ - minute
+ - second
+ - microsecond
properties:
__type: {type: string}
years: {type: integer}
@@ -2036,6 +2062,7 @@ components:
type: object
required:
- __type
+ - value
properties:
__type: {type: string}
value: {type: string}
diff --git a/airflow/api_connexion/schemas/common_schema.py b/airflow/api_connexion/schemas/common_schema.py
new file mode 100644
index 0000000..5e3afe6
--- /dev/null
+++ b/airflow/api_connexion/schemas/common_schema.py
@@ -0,0 +1,169 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import inspect
+import typing
+
+import marshmallow
+from dateutil import relativedelta
+from marshmallow import Schema, fields, validate
+from marshmallow_oneofschema import OneOfSchema
+
+from airflow.serialization.serialized_objects import SerializedBaseOperator
+from airflow.utils.weight_rule import WeightRule
+
+
+class CronExpression(typing.NamedTuple):
+ """Cron expression schema"""
+ value: str
+
+
+class TimeDeltaSchema(Schema):
+ """Time delta schema"""
+
+ objectType = fields.Constant("TimeDelta", dump_to="__type")
+ days = fields.Integer()
+ seconds = fields.Integer()
+ microseconds = fields.Integer()
+
+ @marshmallow.post_load
+ def make_time_delta(self, data, **kwargs):
+ """Create time delta based on data"""
+
+ if "objectType" in data:
+ del data["objectType"]
+ return datetime.timedelta(**data)
+
+
+class RelativeDeltaSchema(Schema):
+ """Relative delta schema"""
+
+ objectType = fields.Constant("RelativeDelta", dump_to="__type")
+ years = fields.Integer()
+ months = fields.Integer()
+ days = fields.Integer()
+ leapdays = fields.Integer()
+ hours = fields.Integer()
+ minutes = fields.Integer()
+ seconds = fields.Integer()
+ microseconds = fields.Integer()
+ year = fields.Integer()
+ month = fields.Integer()
+ day = fields.Integer()
+ hour = fields.Integer()
+ minute = fields.Integer()
+ second = fields.Integer()
+ microsecond = fields.Integer()
+
+ @marshmallow.post_load
+ def make_relative_delta(self, data, **kwargs):
+ """Create relative delta based on data"""
+
+ if "objectType" in data:
+ del data["objectType"]
+
+ return relativedelta.relativedelta(**data)
+
+
+class CronExpressionSchema(Schema):
+ """Cron expression schema"""
+
+ objectType = fields.Constant("CronExpression", dump_to="__type", required=True)
+ value = fields.String(required=True)
+
+ @marshmallow.post_load
+ def make_cron_expression(self, data, **kwargs):
+ """Create cron expression based on data"""
+ return CronExpression(data["value"])
+
+
+class ScheduleIntervalSchema(OneOfSchema):
+ """
+ Schedule interval.
+
+ It supports the following types:
+
+ * TimeDelta
+ * RelativeDelta
+ * CronExpression
+ """
+ type_field = "__type"
+ type_schemas = {
+ "TimeDelta": TimeDeltaSchema,
+ "RelativeDelta": RelativeDeltaSchema,
+ "CronExpression": CronExpressionSchema,
+ }
+
+ def _dump(self, obj, update_fields=True, **kwargs):
+ if isinstance(obj, str):
+ obj = CronExpression(obj)
+
+ return super()._dump(obj, update_fields=update_fields, **kwargs)
+
+ def get_obj_type(self, obj):
+ """Select schema based on object type"""
+ if isinstance(obj, datetime.timedelta):
+ return "TimeDelta"
+ elif isinstance(obj, relativedelta.relativedelta):
+ return "RelativeDelta"
+ elif isinstance(obj, CronExpression):
+ return "CronExpression"
+ else:
+ raise Exception("Unknown object type: {}".format(obj.__class__.__name__))
+
+
+class ColorField(fields.String):
+ """Schema for color property"""
+ def __init__(self, **metadata):
+ super().__init__(**metadata)
+ self.validators = (
+ [validate.Regexp("^#[a-fA-F0-9]{3,6}$")] + list(self.validators)
+ )
+
+
+class WeightRuleField(fields.String):
+ """Schema for WeightRule"""
+ def __init__(self, **metadata):
+ super().__init__(**metadata)
+ self.validators = (
+ [validate.OneOf(WeightRule.all_weight_rules())] + list(self.validators)
+ )
+
+
+class TimezoneField(fields.String):
+ """Schema for timezone"""
+
+
+class ClassReferenceSchema(Schema):
+ """
+ Class reference schema.
+ """
+ module_path = fields.Method("_get_module", required=True)
+ class_name = fields.Method("_get_class_name", required=True)
+
+ def _get_module(self, obj):
+ if isinstance(obj, SerializedBaseOperator):
+ return obj._task_module # pylint: disable=protected-access
+ return inspect.getmodule(obj).__name__
+
+ def _get_class_name(self, obj):
+ if isinstance(obj, SerializedBaseOperator):
+ return obj._task_type # pylint: disable=protected-access
+ if isinstance(obj, type):
+ return obj.__name__
+ return type(obj).__name__
diff --git a/airflow/api_connexion/schemas/dag_schema.py b/airflow/api_connexion/schemas/dag_schema.py
new file mode 100644
index 0000000..5104d70
--- /dev/null
+++ b/airflow/api_connexion/schemas/dag_schema.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, NamedTuple
+
+from marshmallow import Schema, fields
+from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
+
+from airflow.api_connexion.schemas.common_schema import ScheduleIntervalSchema, TimeDeltaSchema, TimezoneField
+from airflow.models.dag import DagModel, DagTag
+
+
+class DagTagSchema(SQLAlchemySchema):
+ """Dag Tag schema"""
+ class Meta:
+ """Meta"""
+
+ model = DagTag
+
+ name = auto_field()
+
+
+class DAGSchema(SQLAlchemySchema):
+ """DAG schema"""
+
+ class Meta:
+ """Meta"""
+
+ model = DagModel
+
+ dag_id = auto_field(dump_only=True)
+ root_dag_id = auto_field(dump_only=True)
+ is_paused = auto_field(dump_only=True)
+ is_subdag = auto_field(dump_only=True)
+ fileloc = auto_field(dump_only=True)
+ owners = fields.Method("get_owners", dump_only=True)
+ description = auto_field(dump_only=True)
+ schedule_interval = fields.Nested(ScheduleIntervalSchema, dump_only=True)
+ tags = fields.List(fields.Nested(DagTagSchema), dump_only=True)
+
+ @staticmethod
+ def get_owners(obj: DagModel):
+ """Convert owners attribute to DAG representation"""
+
+ if not obj.owners:
+ return []
+ return obj.owners.split(",")
+
+
+class DAGDetailSchema(DAGSchema):
+ """DAG details"""
+
+ timezone = TimezoneField(dump_only=True)
+ catchup = fields.Boolean(dump_only=True)
+ orientation = fields.String(dump_only=True)
+ concurrency = fields.Integer(dump_only=True)
+ start_date = fields.DateTime(dump_only=True)
+ dag_run_timeout = fields.Nested(TimeDeltaSchema, dump_only=True, attribute="dagrun_timeout")
+ doc_md = fields.String(dump_only=True)
+ default_view = fields.String(dump_only=True)
+
+
+class DAGCollection(NamedTuple):
+ """List of DAGs with metadata"""
+
+ dags: List[DagModel]
+ total_entries: int
+
+
+class DAGCollectionSchema(Schema):
+ """DAG Collection schema"""
+
+ dags = fields.List(fields.Nested(DAGSchema))
+ total_entries = fields.Int()
+
+
+dags_collection_schema = DAGCollectionSchema()
+dag_schema = DAGSchema()
+dag_detail_schema = DAGDetailSchema()
diff --git a/airflow/api_connexion/schemas/task_schema.py b/airflow/api_connexion/schemas/task_schema.py
new file mode 100644
index 0000000..52a6a30
--- /dev/null
+++ b/airflow/api_connexion/schemas/task_schema.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, NamedTuple
+
+from marshmallow import Schema, fields
+
+from airflow.api_connexion.schemas.common_schema import (
+ ClassReferenceSchema, ColorField, TimeDeltaSchema, WeightRuleField,
+)
+from airflow.api_connexion.schemas.dag_schema import DAGSchema
+from airflow.models.baseoperator import BaseOperator
+
+
+class TaskSchema(Schema):
+ """Task schema"""
+
+ class_ref = fields.Method("_get_class_reference", dump_only=True)
+ task_id = fields.String(dump_only=True)
+ owner = fields.String(dump_only=True)
+ start_date = fields.DateTime(dump_only=True)
+ end_date = fields.DateTime(dump_only=True)
+ trigger_rule = fields.String(dump_only=True)
+ extra_links = fields.List(
+ fields.Nested(ClassReferenceSchema),
+ dump_only=True,
+ attribute="operator_extra_links"
+ )
+ depends_on_past = fields.Boolean(dump_only=True)
+ wait_for_downstream = fields.Boolean(dump_only=True)
+ retries = fields.Number(dump_only=True)
+ queue = fields.String(dump_only=True)
+ pool = fields.String(dump_only=True)
+ pool_slots = fields.Number(dump_only=True)
+ execution_timeout = fields.Nested(TimeDeltaSchema, dump_only=True)
+ retry_delay = fields.Nested(TimeDeltaSchema, dump_only=True)
+ retry_exponential_backoff = fields.Boolean(dump_only=True)
+ priority_weight = fields.Number(dump_only=True)
+ weight_rule = WeightRuleField(dump_only=True)
+ ui_color = ColorField(dump_only=True)
+ ui_fgcolor = ColorField(dump_only=True)
+ template_fields = fields.List(fields.String(), dump_only=True)
+ sub_dag = fields.Nested(DAGSchema, dump_only=True)
+ downstream_task_ids = fields.List(fields.String(), dump_only=True)
+
+ def _get_class_reference(self, obj):
+ result = ClassReferenceSchema().dump(obj)
+ return result.data if hasattr(result, "data") else result
+
+
+class TaskCollection(NamedTuple):
+ """List of Tasks with metadata"""
+
+ tasks: List[BaseOperator]
+ total_entries: int
+
+
+class TaskCollectionSchema(Schema):
+ """Schema for TaskCollection"""
+
+ tasks = fields.List(fields.Nested(TaskSchema))
+ total_entries = fields.Int()
+
+
+task_schema = TaskSchema()
+task_collection_schema = TaskCollectionSchema()
diff --git a/requirements/requirements-python3.6.txt b/requirements/requirements-python3.6.txt
index 8e0e6ad..b6b690b 100644
--- a/requirements/requirements-python3.6.txt
+++ b/requirements/requirements-python3.6.txt
@@ -214,6 +214,7 @@ lazy-object-proxy==1.5.0
ldap3==2.7
lockfile==0.12.2
marshmallow-enum==1.5.1
+marshmallow-oneofschema==1.0.6
marshmallow-sqlalchemy==0.23.1
marshmallow==2.21.0
mccabe==0.6.1
diff --git a/requirements/requirements-python3.7.txt b/requirements/requirements-python3.7.txt
index 06d4761..e18701c 100644
--- a/requirements/requirements-python3.7.txt
+++ b/requirements/requirements-python3.7.txt
@@ -210,6 +210,7 @@ lazy-object-proxy==1.5.0
ldap3==2.7
lockfile==0.12.2
marshmallow-enum==1.5.1
+marshmallow-oneofschema==1.0.6
marshmallow-sqlalchemy==0.23.1
marshmallow==2.21.0
mccabe==0.6.1
diff --git a/requirements/requirements-python3.8.txt b/requirements/requirements-python3.8.txt
index 3885f0b..918f3ef 100644
--- a/requirements/requirements-python3.8.txt
+++ b/requirements/requirements-python3.8.txt
@@ -45,7 +45,7 @@ apispec==1.3.3
appdirs==1.4.4
argcomplete==1.11.1
asn1crypto==1.3.0
-astroid==2.4.2
+astroid==2.3.3
async-generator==1.10
async-timeout==3.0.1
atlasclient==1.0.0
@@ -210,6 +210,7 @@ lazy-object-proxy==1.5.0
ldap3==2.7
lockfile==0.12.2
marshmallow-enum==1.5.1
+marshmallow-oneofschema==1.0.6
marshmallow-sqlalchemy==0.23.1
marshmallow==2.21.0
mccabe==0.6.1
diff --git a/requirements/setup-3.6.md5 b/requirements/setup-3.6.md5
index 5b4b71f..86c4da4 100644
--- a/requirements/setup-3.6.md5
+++ b/requirements/setup-3.6.md5
@@ -1 +1 @@
-58b2fa003085a21989e7c2cc68a10461 /opt/airflow/setup.py
+cac9433ddd48ca884fa160b007be3818 /opt/airflow/setup.py
diff --git a/requirements/setup-3.7.md5 b/requirements/setup-3.7.md5
index 5b4b71f..86c4da4 100644
--- a/requirements/setup-3.7.md5
+++ b/requirements/setup-3.7.md5
@@ -1 +1 @@
-58b2fa003085a21989e7c2cc68a10461 /opt/airflow/setup.py
+cac9433ddd48ca884fa160b007be3818 /opt/airflow/setup.py
diff --git a/requirements/setup-3.8.md5 b/requirements/setup-3.8.md5
index 5b4b71f..86c4da4 100644
--- a/requirements/setup-3.8.md5
+++ b/requirements/setup-3.8.md5
@@ -1 +1 @@
-58b2fa003085a21989e7c2cc68a10461 /opt/airflow/setup.py
+cac9433ddd48ca884fa160b007be3818 /opt/airflow/setup.py
diff --git a/setup.py b/setup.py
index 4a2d451..4443514 100644
--- a/setup.py
+++ b/setup.py
@@ -710,6 +710,7 @@ INSTALL_REQUIREMENTS = [
'lockfile>=0.12.2',
'markdown>=2.5.2, <3.0',
'markupsafe>=1.1.1, <2.0',
+ 'marshmallow-oneofschema<2',
'pandas>=0.17.1, <2.0',
'pendulum~=2.0',
'pep562~=1.0;python_version<"3.7"',
diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py
index 5234d36..9261401 100644
--- a/tests/api_connexion/endpoints/test_dag_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_endpoint.py
@@ -14,35 +14,120 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import os
import unittest
+from datetime import datetime
import pytest
+from airflow import DAG
+from airflow.models import DagBag
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.operators.dummy_operator import DummyOperator
from airflow.www import app
+from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
class TestDagEndpoint(unittest.TestCase):
+ dag_id = "test_dag"
+ task_id = "op1"
+
+ @staticmethod
+ def clean_db():
+ clear_db_runs()
+ clear_db_dags()
+ clear_db_serialized_dags()
+
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.app = app.create_app(testing=True) # type:ignore
+ with DAG(cls.dag_id, start_date=datetime(2020, 6, 15), doc_md="details") as dag:
+ DummyOperator(task_id=cls.task_id)
+
+ cls.dag = dag # type:ignore
+ dag_bag = DagBag(os.devnull, include_examples=False)
+ dag_bag.dags = {dag.dag_id: dag}
+ cls.app.dag_bag = dag_bag # type:ignore
+
def setUp(self) -> None:
+ self.clean_db()
self.client = self.app.test_client() # type:ignore
+ def tearDown(self) -> None:
+ self.clean_db()
+
class TestGetDag(TestDagEndpoint):
@pytest.mark.skip(reason="Not implemented yet")
def test_should_response_200(self):
- response = self.client.get("/api/v1/dag/1/")
+ response = self.client.get("/api/v1/dags/1/")
assert response.status_code == 200
class TestGetDagDetails(TestDagEndpoint):
- @pytest.mark.skip(reason="Not implemented yet")
def test_should_response_200(self):
- response = self.client.get("/api/v1/dag/TEST_DAG_ID/details")
+ response = self.client.get(f"/api/v1/dags/{self.dag_id}/details")
+ assert response.status_code == 200
+ expected = {
+ 'catchup': True,
+ 'concurrency': 16,
+ 'dag_id': 'test_dag',
+ 'dag_run_timeout': None,
+ 'default_view': 'tree',
+ 'description': None,
+ 'doc_md': 'details',
+ 'fileloc': __file__,
+ 'is_paused': None,
+ 'is_subdag': False,
+ 'orientation': 'LR',
+ 'schedule_interval': {
+ '__type': 'TimeDelta',
+ 'days': 1,
+ 'microseconds': 0,
+ 'seconds': 0
+ },
+ 'start_date': '2020-06-15T00:00:00+00:00',
+ 'tags': None,
+ 'timezone': "Timezone('UTC')"
+ }
+ assert response.json == expected
+
+ def test_should_response_200_serialized(self):
+ # Create empty app with empty dagbag to check if DAG is read from db
+ app_serialized = app.create_app(testing=True) # type:ignore
+ dag_bag = DagBag(os.devnull, include_examples=False, store_serialized_dags=True)
+ app_serialized.dag_bag = dag_bag # type:ignore
+ client = app_serialized.test_client()
+
+ SerializedDagModel.write_dag(self.dag)
+
+ expected = {
+ 'catchup': True,
+ 'concurrency': 16,
+ 'dag_id': 'test_dag',
+ 'dag_run_timeout': None,
+ 'default_view': 'tree',
+ 'description': None,
+ 'doc_md': 'details',
+ 'fileloc': __file__,
+ 'is_paused': None,
+ 'is_subdag': False,
+ 'orientation': 'LR',
+ 'schedule_interval': {
+ '__type': 'TimeDelta',
+ 'days': 1,
+ 'microseconds': 0,
+ 'seconds': 0
+ },
+ 'start_date': '2020-06-15T00:00:00+00:00',
+ 'tags': None,
+ 'timezone': "Timezone('UTC')"
+ }
+ response = client.get(f"/api/v1/dags/{self.dag_id}/details")
assert response.status_code == 200
+ assert response.json == expected
class TestGetDags(TestDagEndpoint):
diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py
index ab6b649..92d08ef 100644
--- a/tests/api_connexion/endpoints/test_task_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_endpoint.py
@@ -14,32 +14,169 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import os
import unittest
+from datetime import datetime
-import pytest
-
+from airflow import DAG
+from airflow.models import DagBag
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.operators.dummy_operator import DummyOperator
from airflow.www import app
+from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
class TestTaskEndpoint(unittest.TestCase):
+ dag_id = "test_dag"
+ task_id = "op1"
+
+ @staticmethod
+ def clean_db():
+ clear_db_runs()
+ clear_db_dags()
+ clear_db_serialized_dags()
+
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.app = app.create_app(testing=True) # type:ignore
+ with DAG(cls.dag_id, start_date=datetime(2020, 6, 15), doc_md="details") as dag:
+ DummyOperator(task_id=cls.task_id)
+
+ cls.dag = dag # type:ignore
+ dag_bag = DagBag(os.devnull, include_examples=False)
+ dag_bag.dags = {dag.dag_id: dag}
+ cls.app.dag_bag = dag_bag # type:ignore
+
def setUp(self) -> None:
+ self.clean_db()
self.client = self.app.test_client() # type:ignore
+ def tearDown(self) -> None:
+ self.clean_db()
+
class TestGetTask(TestTaskEndpoint):
- @pytest.mark.skip(reason="Not implemented yet")
def test_should_response_200(self):
- response = self.client.get("/api/v1/dags/TEST_DAG_ID/tasks/TEST_TASK_ID")
+ expected = {
+ "class_ref": {
+ "class_name": "DummyOperator",
+ "module_path": "airflow.operators.dummy_operator",
+ },
+ "depends_on_past": False,
+ "downstream_task_ids": [],
+ "end_date": None,
+ "execution_timeout": None,
+ "extra_links": [],
+ "owner": "airflow",
+ "pool": "default_pool",
+ "pool_slots": 1.0,
+ "priority_weight": 1.0,
+ "queue": "default",
+ "retries": 0.0,
+ "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0},
+ "retry_exponential_backoff": False,
+ "start_date": "2020-06-15T00:00:00+00:00",
+ "task_id": "op1",
+ "template_fields": [],
+ "trigger_rule": "all_success",
+ "ui_color": "#e8f7e4",
+ "ui_fgcolor": "#000",
+ "wait_for_downstream": False,
+ "weight_rule": "downstream",
+ }
+ response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}")
assert response.status_code == 200
+ assert response.json == expected
+
+ @conf_vars({("core", "store_serialized_dags"): "True"})
+ def test_should_response_200_serialized(self):
+ # Create empty app with empty dagbag to check if DAG is read from db
+ app_serialized = app.create_app(testing=True) # type:ignore
+ dag_bag = DagBag(os.devnull, include_examples=False, store_serialized_dags=True)
+ app_serialized.dag_bag = dag_bag # type:ignore
+ client = app_serialized.test_client()
+
+ SerializedDagModel.write_dag(self.dag)
+
+ expected = {
+ "class_ref": {
+ "class_name": "DummyOperator",
+ "module_path": "airflow.operators.dummy_operator",
+ },
+ "depends_on_past": False,
+ "downstream_task_ids": [],
+ "end_date": None,
+ "execution_timeout": None,
+ "extra_links": [],
+ "owner": "airflow",
+ "pool": "default_pool",
+ "pool_slots": 1.0,
+ "priority_weight": 1.0,
+ "queue": "default",
+ "retries": 0.0,
+ "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0},
+ "retry_exponential_backoff": False,
+ "start_date": "2020-06-15T00:00:00+00:00",
+ "task_id": "op1",
+ "template_fields": [],
+ "trigger_rule": "all_success",
+ "ui_color": "#e8f7e4",
+ "ui_fgcolor": "#000",
+ "wait_for_downstream": False,
+ "weight_rule": "downstream",
+ }
+ response = client.get(f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}")
+ assert response.status_code == 200
+ assert response.json == expected
+
+ def test_should_response_404(self):
+ task_id = "xxxx_not_existing"
+ response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks/{task_id}")
+ assert response.status_code == 404
class TestGetTasks(TestTaskEndpoint):
- @pytest.mark.skip(reason="Not implemented yet")
def test_should_response_200(self):
- response = self.client.get("/api/v1/dags/TEST_DAG_ID/tasks")
+ expected = {
+ "tasks": [
+ {
+ "class_ref": {
+ "class_name": "DummyOperator",
+ "module_path": "airflow.operators.dummy_operator",
+ },
+ "depends_on_past": False,
+ "downstream_task_ids": [],
+ "end_date": None,
+ "execution_timeout": None,
+ "extra_links": [],
+ "owner": "airflow",
+ "pool": "default_pool",
+ "pool_slots": 1.0,
+ "priority_weight": 1.0,
+ "queue": "default",
+ "retries": 0.0,
+ "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0},
+ "retry_exponential_backoff": False,
+ "start_date": "2020-06-15T00:00:00+00:00",
+ "task_id": "op1",
+ "template_fields": [],
+ "trigger_rule": "all_success",
+ "ui_color": "#e8f7e4",
+ "ui_fgcolor": "#000",
+ "wait_for_downstream": False,
+ "weight_rule": "downstream",
+ }
+ ],
+ "total_entries": 1,
+ }
+ response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks")
assert response.status_code == 200
+ assert response.json == expected
+
+ def test_should_response_404(self):
+ dag_id = "xxxx_not_existing"
+ response = self.client.get(f"/api/v1/dags/{dag_id}/tasks")
+ assert response.status_code == 404
diff --git a/tests/api_connexion/schemas/test_common_schema.py b/tests/api_connexion/schemas/test_common_schema.py
new file mode 100644
index 0000000..d0419b0
--- /dev/null
+++ b/tests/api_connexion/schemas/test_common_schema.py
@@ -0,0 +1,145 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import unittest
+
+from dateutil import relativedelta
+
+from airflow.api_connexion.schemas.common_schema import (
+ CronExpression, CronExpressionSchema, RelativeDeltaSchema, ScheduleIntervalSchema, TimeDeltaSchema,
+)
+
+
+class TestTimeDeltaSchema(unittest.TestCase):
+ def test_should_serialize(self):
+ instance = datetime.timedelta(days=12)
+ schema_instance = TimeDeltaSchema()
+ result = schema_instance.dump(instance)
+ self.assertEqual(
+ {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0},
+ result.data
+ )
+
+ def test_should_deserialize(self):
+ instance = {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0}
+ schema_instance = TimeDeltaSchema()
+ result = schema_instance.load(instance)
+ expected_instance = datetime.timedelta(days=12)
+ self.assertEqual(expected_instance, result.data)
+
+
+class TestRelativeDeltaSchema(unittest.TestCase):
+ def test_should_serialize(self):
+ instance = relativedelta.relativedelta(days=+12)
+ schema_instance = RelativeDeltaSchema()
+ result = schema_instance.dump(instance)
+ self.assertEqual(
+ {
+ '__type': 'RelativeDelta',
+ "day": None,
+ "days": 12,
+ "hour": None,
+ "hours": 0,
+ "leapdays": 0,
+ "microsecond": None,
+ "microseconds": 0,
+ "minute": None,
+ "minutes": 0,
+ "month": None,
+ "months": 0,
+ "second": None,
+ "seconds": 0,
+ "year": None,
+ "years": 0,
+ },
+ result.data,
+ )
+
+ def test_should_deserialize(self):
+ instance = {"__type": "RelativeDelta", "days": 12, "seconds": 0}
+ schema_instance = RelativeDeltaSchema()
+ result = schema_instance.load(instance)
+ expected_instance = relativedelta.relativedelta(days=+12)
+ self.assertEqual(expected_instance, result.data)
+
+
+class TestCronExpressionSchema(unittest.TestCase):
+ def test_should_deserialize(self):
+ instance = {"__type": "CronExpression", "value": "5 4 * * *"}
+ schema_instance = CronExpressionSchema()
+ result = schema_instance.load(instance)
+ expected_instance = CronExpression("5 4 * * *")
+ self.assertEqual(expected_instance, result.data)
+
+
+class TestScheduleIntervalSchema(unittest.TestCase):
+ def test_should_serialize_timedelta(self):
+ instance = datetime.timedelta(days=12)
+ schema_instance = ScheduleIntervalSchema()
+ result = schema_instance.dump(instance)
+ self.assertEqual(
+ {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0},
+ result.data
+ )
+
+ def test_should_deserialize_timedelta(self):
+ instance = {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0}
+ schema_instance = ScheduleIntervalSchema()
+ result = schema_instance.load(instance)
+ expected_instance = datetime.timedelta(days=12)
+ self.assertEqual(expected_instance, result.data)
+
+ def test_should_serialize_relative_delta(self):
+ instance = relativedelta.relativedelta(days=+12)
+ schema_instance = ScheduleIntervalSchema()
+ result = schema_instance.dump(instance)
+ self.assertEqual(
+ {
+ "__type": "RelativeDelta",
+ "day": None,
+ "days": 12,
+ "hour": None,
+ "hours": 0,
+ "leapdays": 0,
+ "microsecond": None,
+ "microseconds": 0,
+ "minute": None,
+ "minutes": 0,
+ "month": None,
+ "months": 0,
+ "second": None,
+ "seconds": 0,
+ "year": None,
+ "years": 0,
+ },
+ result.data,
+ )
+
+ def test_should_deserialize_relative_delta(self):
+ instance = {"__type": "RelativeDelta", "days": 12, "seconds": 0}
+ schema_instance = ScheduleIntervalSchema()
+ result = schema_instance.load(instance)
+ expected_instance = relativedelta.relativedelta(days=+12)
+ self.assertEqual(expected_instance, result.data)
+
+ def test_should_serialize_cron_expression(self):
+ instance = "5 4 * * *"
+ schema_instance = ScheduleIntervalSchema()
+ result = schema_instance.dump(instance)
+ expected_instance = {"__type": "CronExpression", "value": "5 4 * * *"}
+ self.assertEqual(expected_instance, result.data)
diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py
new file mode 100644
index 0000000..327bce5
--- /dev/null
+++ b/tests/api_connexion/schemas/test_dag_schema.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from datetime import datetime
+
+from airflow import DAG
+from airflow.api_connexion.schemas.dag_schema import (
+ DAGCollection, DAGCollectionSchema, DAGDetailSchema, DAGSchema,
+)
+from airflow.models import DagModel, DagTag
+
+
+class TestDagSchema(unittest.TestCase):
+ def test_serialize(self):
+ dag_model = DagModel(
+ dag_id="test_dag_id",
+ root_dag_id="test_root_dag_id",
+ is_paused=True,
+ is_subdag=False,
+ fileloc="/root/airflow/dags/my_dag.py",
+ owners="airflow1,airflow2",
+ description="The description",
+ schedule_interval="5 4 * * *",
+ tags=[DagTag(name="tag-1"), DagTag(name="tag-2")],
+ )
+ serialized_dag = DAGSchema().dump(dag_model)
+ self.assertEqual(
+ {
+ "dag_id": "test_dag_id",
+ "description": "The description",
+ "fileloc": "/root/airflow/dags/my_dag.py",
+ "is_paused": True,
+ "is_subdag": False,
+ "owners": ["airflow1", "airflow2"],
+ "root_dag_id": "test_root_dag_id",
+ "schedule_interval": {"__type": "CronExpression", "value": "5 4 * * *"},
+ "tags": [{"name": "tag-1"}, {"name": "tag-2"}],
+ },
+ serialized_dag.data,
+ )
+
+
+class TestDAGCollectionSchema(unittest.TestCase):
+ def test_serialize(self):
+ dag_model_a = DagModel(dag_id="test_dag_id_a", fileloc="/tmp/a.py")
+ dag_model_b = DagModel(dag_id="test_dag_id_b", fileloc="/tmp/a.py")
+ schema = DAGCollectionSchema()
+ instance = DAGCollection(dags=[dag_model_a, dag_model_b], total_entries=2)
+ self.assertEqual(
+ {
+ "dags": [
+ {
+ "dag_id": "test_dag_id_a",
+ "description": None,
+ "fileloc": "/tmp/a.py",
+ "is_paused": None,
+ "is_subdag": None,
+ "owners": [],
+ "root_dag_id": None,
+ "schedule_interval": None,
+ "tags": [],
+ },
+ {
+ "dag_id": "test_dag_id_b",
+ "description": None,
+ "fileloc": "/tmp/a.py",
+ "is_paused": None,
+ "is_subdag": None,
+ "owners": [],
+ "root_dag_id": None,
+ "schedule_interval": None,
+ "tags": [],
+ },
+ ],
+ "total_entries": 2,
+ },
+ schema.dump(instance).data,
+ )
+
+
+class TestDAGDetailSchema:
+ def test_serialize(self):
+ dag = DAG(
+ dag_id="test_dag",
+ start_date=datetime(2020, 6, 19),
+ doc_md="docs",
+ orientation="LR",
+ default_view="duration",
+ )
+ schema = DAGDetailSchema()
+ expected = {
+ 'catchup': True,
+ 'concurrency': 16,
+ 'dag_id': 'test_dag',
+ 'dag_run_timeout': None,
+ 'default_view': 'duration',
+ 'description': None,
+ 'doc_md': 'docs',
+ 'fileloc': __file__,
+ 'is_paused': None,
+ 'is_subdag': False,
+ 'orientation': 'LR',
+ 'schedule_interval': {'__type': 'TimeDelta', 'days': 1, 'seconds': 0, 'microseconds': 0},
+ 'start_date': '2020-06-19T00:00:00+00:00',
+ 'tags': None,
+ 'timezone': "Timezone('UTC')"
+ }
+ assert schema.dump(dag).data == expected
diff --git a/tests/api_connexion/schemas/test_task_schema.py b/tests/api_connexion/schemas/test_task_schema.py
new file mode 100644
index 0000000..a804869
--- /dev/null
+++ b/tests/api_connexion/schemas/test_task_schema.py
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+
+from airflow.api_connexion.schemas.task_schema import TaskCollection, task_collection_schema, task_schema
+from airflow.operators.dummy_operator import DummyOperator
+
+
+class TestTaskSchema:
+ def test_serialize(self):
+ op = DummyOperator(
+ task_id="task_id",
+ start_date=datetime(2020, 6, 16),
+ end_date=datetime(2020, 6, 26),
+ )
+ result = task_schema.dump(op)
+ expected = {
+ "class_ref": {
+ "module_path": "airflow.operators.dummy_operator",
+ "class_name": "DummyOperator",
+ },
+ "depends_on_past": False,
+ "downstream_task_ids": [],
+ "end_date": "2020-06-26T00:00:00+00:00",
+ "execution_timeout": None,
+ "extra_links": [],
+ "owner": "airflow",
+ "pool": "default_pool",
+ "pool_slots": 1.0,
+ "priority_weight": 1.0,
+ "queue": "default",
+ "retries": 0.0,
+ "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0},
+ "retry_exponential_backoff": False,
+ "start_date": "2020-06-16T00:00:00+00:00",
+ "task_id": "task_id",
+ "template_fields": [],
+ "trigger_rule": "all_success",
+ "ui_color": "#e8f7e4",
+ "ui_fgcolor": "#000",
+ "wait_for_downstream": False,
+ "weight_rule": "downstream",
+ }
+ assert expected == result.data
+
+
+class TestTaskCollectionSchema:
+ def test_serialize(self):
+ tasks = [DummyOperator(task_id="task_id1")]
+ collection = TaskCollection(tasks, 1)
+ result = task_collection_schema.dump(collection)
+ expected = {
+ "tasks": [
+ {
+ "class_ref": {
+ "class_name": "DummyOperator",
+ "module_path": "airflow.operators.dummy_operator",
+ },
+ "depends_on_past": False,
+ "downstream_task_ids": [],
+ "end_date": None,
+ "execution_timeout": None,
+ "extra_links": [],
+ "owner": "airflow",
+ "pool": "default_pool",
+ "pool_slots": 1.0,
+ "priority_weight": 1.0,
+ "queue": "default",
+ "retries": 0.0,
+ "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0},
+ "retry_exponential_backoff": False,
+ "start_date": None,
+ "task_id": "task_id1",
+ "template_fields": [],
+ "trigger_rule": "all_success",
+ "ui_color": "#e8f7e4",
+ "ui_fgcolor": "#000",
+ "wait_for_downstream": False,
+ "weight_rule": "downstream",
+ }
+ ],
+ "total_entries": 1,
+ }
+ assert expected == result.data
diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
index 6ad235b..fa4a32e 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -35,6 +35,7 @@ from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.types import DagRunType
from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_dags, clear_db_runs
dag_folder_path = '/'.join(os.path.realpath(__file__).split('/')[:-1])
@@ -59,8 +60,14 @@ class TestCliDags(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dagbag = DagBag(include_examples=True)
+ cls.dagbag.sync_to_db()
cls.parser = cli_parser.get_parser()
+ @classmethod
+ def tearDownClass(cls) -> None:
+ clear_db_runs()
+ clear_db_dags()
+
@mock.patch("airflow.cli.commands.dag_command.DAG.run")
def test_backfill(self, mock_run):
dag_command.dag_backfill(self.parser.parse_args([
diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py
index e967712..6c2c297 100644
--- a/tests/test_utils/db.py
+++ b/tests/test_utils/db.py
@@ -20,6 +20,7 @@ from airflow.models import (
XCom, errors,
)
from airflow.models.dagcode import DagCode
+from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils.db import add_default_pool_if_not_exists, create_default_connections
from airflow.utils.session import create_session
@@ -36,6 +37,11 @@ def clear_db_dags():
session.query(DagModel).delete()
+def clear_db_serialized_dags():
+ with create_session() as session:
+ session.query(SerializedDagModel).delete()
+
+
def clear_db_sla_miss():
with create_session() as session:
session.query(SlaMiss).delete()