Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/12/10 11:19:22 UTC

[GitHub] ashb closed pull request #2684: [AIRFLOW-1002] Add ability to remove DAG and all dependencies

URL: https://github.com/apache/incubator-airflow/pull/2684

This is a PR merged from a forked repository. As GitHub hides the
original diff on merge, it is reproduced below for the sake of
provenance:

diff --git a/airflow/api/client/api_client.py b/airflow/api/client/api_client.py
index f24d80945f..1122bbd1d2 100644
--- a/airflow/api/client/api_client.py
+++ b/airflow/api/client/api_client.py
@@ -32,6 +32,15 @@ def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
         """
         raise NotImplementedError()
 
+    def delete_dag(self, dag_id):
+        """
+        Delete the specified DAG and all of its dependent records
+
+        :param dag_id: id of the DAG to delete
+        :return: a server message describing the result
+        """
+        raise NotImplementedError()
+
     def get_pool(self, name):
         """Get pool.
 
diff --git a/airflow/api/client/json_client.py b/airflow/api/client/json_client.py
index 37e24d3c4e..789f969cbb 100644
--- a/airflow/api/client/json_client.py
+++ b/airflow/api/client/json_client.py
@@ -50,6 +50,19 @@ def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
                              })
         return data['message']
 
+    def delete_dag(self, dag_id):
+        endpoint = '/api/experimental/dags/{}'.format(dag_id)
+        url = urljoin(self._api_base_url, endpoint)
+
+        resp = requests.delete(url, auth=self._auth)
+
+        if not resp.ok:
+            # surface the server response so callers can see why it failed
+            raise IOError("{}: {}".format(resp.status_code, resp.text))
+
+        data = resp.json()
+
+        return data['message']
+
     def get_pool(self, name):
         endpoint = '/api/experimental/pools/{}'.format(name)
         url = urljoin(self._api_base_url, endpoint)
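
For illustration, the new JSON client method could be exercised with a
short sketch like this (the base URL and dag id are placeholders, not
values from the PR):

    from airflow.api.client.json_client import Client

    # hypothetical connection details for a local webserver
    client = Client(api_base_url='http://localhost:8080', auth=None)
    # issues DELETE /api/experimental/dags/my_dag_id and returns the
    # server's 'message' field
    print(client.delete_dag('my_dag_id'))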
diff --git a/airflow/api/client/local_client.py b/airflow/api/client/local_client.py
index 5bc7f76aaa..e2fae9590d 100644
--- a/airflow/api/client/local_client.py
+++ b/airflow/api/client/local_client.py
@@ -14,7 +14,7 @@
 
 from airflow.api.client import api_client
 from airflow.api.common.experimental import pool
-from airflow.api.common.experimental import trigger_dag
+from airflow.api.common.experimental import trigger_dag, delete_dag
 
 
 class Client(api_client.Client):
@@ -27,6 +27,10 @@ def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
                                      execution_date=execution_date)
         return "Created {}".format(dr)
 
+    def delete_dag(self, dag_id):
+        deleted = delete_dag.delete_dag(dag_id)  # True on success
+        return "Deleted {}".format(deleted)
+
     def get_pool(self, name):
         p = pool.get_pool(name=name)
         return p.pool, p.slots, p.description
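
The local client makes the same call in-process against the metadata
database; a minimal sketch (the dag id is again a placeholder):

    from airflow.api.client.local_client import Client

    # the local client needs no URL or auth
    client = Client(None, None)
    print(client.delete_dag('my_dag_id'))  # e.g. "Deleted True"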
diff --git a/airflow/api/common/experimental/delete_dag.py b/airflow/api/common/experimental/delete_dag.py
new file mode 100644
index 0000000000..3cad1b987a
--- /dev/null
+++ b/airflow/api/common/experimental/delete_dag.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow import AirflowException
+from airflow.models import DagBag, DAG
+
+
+def delete_dag(dag_id, dag_bag=None):
+    """
+    Delete the DAG with the given id, including all of its subdags.
+
+    :param dag_id: id of the DAG to be deleted
+    :param dag_bag: DagBag to resolve the DAG in; a fresh one is built if None
+    :return: True if the DAG was found and deleted,
+             False if no dag_id was supplied
+    """
+    if dag_id is None:
+        return False
+
+    dag_bag = dag_bag or DagBag()
+
+    if dag_id not in dag_bag.dags:
+        raise AirflowException("Dag id {} not found".format(dag_id))
+
+    dag = dag_bag.get_dag(dag_id)
+
+    all_dags = [dag]
+    all_dags.extend(dag.subdags)
+
+    for d in all_dags:
+        d.delete()
+
+    return True
+
+
+def check_delete_dag(dag_id):
+    return DAG.find_deleted_entities(dag_id)
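
A minimal sketch of using this module directly, assuming the configured
DAGs folder contains a dag named 'my_dag_id':

    from airflow.api.common.experimental.delete_dag import (
        check_delete_dag, delete_dag)
    from airflow.models import DagBag

    dag_bag = DagBag()  # parses the configured DAGs folder
    if delete_dag('my_dag_id', dag_bag):
        # the dependent records should now be gone
        models, tis, jobs, runs = check_delete_dag('my_dag_id')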
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 2675bd3167..9fd54422ac 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -2465,6 +2465,8 @@ def __init__(
             pool=None,
             *args, **kwargs):
         self.task_instance = task_instance
+        self.dag_id = self.task_instance.dag_id
+
         self.ignore_all_deps = ignore_all_deps
         self.ignore_depends_on_past = ignore_depends_on_past
         self.ignore_task_deps = ignore_task_deps
diff --git a/airflow/models.py b/airflow/models.py
index e3c52b5b72..8afcf84206 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -44,6 +44,9 @@
 import traceback
 import warnings
 import hashlib
+import multiprocessing
+import threading
+
 from urllib.parse import urlparse
 
 from sqlalchemy import (
@@ -113,6 +116,8 @@ def get_fernet():
 # Used by DAG context_managers
 _CONTEXT_MANAGER_DAG = None
 
+TEST_DAG_FOLDER = os.path.join(
+    os.path.dirname(os.path.realpath(__file__)), 'dags')
 
 def clear_task_instances(tis, session, activate_dag_runs=True, dag=None):
     """
@@ -3370,7 +3375,7 @@ def clear(
                 )
             tis = tis.filter(or_(*conditions))
         else:
-            tis = session.query(TI).filter(TI.dag_id == self.dag_id)
+            tis = session.query(TI).filter(TI.dag_id.in_([self.dag_id]))
             tis = tis.filter(TI.task_id.in_(self.task_ids))
 
         if start_date:
@@ -3696,6 +3701,178 @@ def cli(self):
         args = parser.parse_args()
         args.func(args, self)
 
+    @staticmethod
+    @provide_session
+    def find_runs(dag_id, session=None, states=None, task_id=None):
+        """
+        :param dag_id: id of dag to find
+        :param session: orm session
+        :param states: restrict to dag runs in these states, if given
+        :param task_id: if given, restrict to dag runs whose run_id matches it
+        :return: None if dag_id is None, else the matching dag runs
+        """
+        if dag_id is None:
+            return None
+
+        qry = session.query(DagRun).filter(DagRun.dag_id == dag_id)
+        if task_id is not None:
+            qry = qry.filter(DagRun.run_id == task_id)
+        if states:
+            qry = qry.filter(DagRun.state.in_(states))
+        return qry.all()
+
+    @staticmethod
+    @provide_session
+    def find_tis(dag_id, session=None, states=None, task_id=None):
+        """
+        :param dag_id: id of dag to find
+        :param session: orm session
+        :param states: restrict to task instances in these states, if given
+        :param task_id: restrict to this task id, if given
+        :return: None if dag_id is None, else the matching task instances
+        """
+        if dag_id is None:
+            return None
+
+        qry = session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id)
+        if states:
+            qry = qry.filter(TaskInstance.state.in_(states))
+        if task_id is not None:
+            qry = qry.filter(TaskInstance.task_id == task_id)
+        return qry.all()
+
+    @staticmethod
+    @provide_session
+    def find_jobs(dag_id, session=None):
+        from airflow.jobs import BaseJob as BJ
+        return session.query(BJ).filter(BJ.dag_id == dag_id).all()
+
+    @staticmethod
+    @provide_session
+    def find_models(dag_id, session=None, task_id=None):
+        qry = session.query(DagModel)
+        if task_id is None:
+            qry = qry.filter(DagModel.dag_id == dag_id)
+        else:
+            # DagModel has no task_id column; a subdag's row is keyed
+            # as "<parent_dag_id>.<task_id>"
+            qry = qry.filter(
+                DagModel.dag_id == '{}.{}'.format(dag_id, task_id))
+        return qry.all()
+
+    @staticmethod
+    @provide_session
+    def find_deleted_entities(dag_id, session=None, task_id=None):
+        return (
+            DAG.find_models(dag_id, session, task_id),
+            DAG.find_tis(dag_id, session, task_id=task_id),
+            DAG.find_jobs(dag_id, session),
+            DAG.find_runs(dag_id, session, task_id=task_id)
+        )
+
+    @provide_session
+    def delete_runs(self, dag_id, session=None, states=None, task_id=None):
+        dag_runs = self.find_runs(dag_id, session, states, task_id)
+        for dag_run in dag_runs:
+            session.delete(dag_run)
+        session.commit()
+
+    def shutdown_jobs_if_running(self, jobs):
+        for j in jobs:
+            if j.state == State.RUNNING:
+                j.state = State.SHUTDOWN
+
+    @provide_session
+    def pause(self, session=None, dag_id=None, task_id=None):
+        if dag_id is None:
+            return
+        if task_id is None:
+            dm = session.query(DagModel).filter(
+                DagModel.dag_id == dag_id).first()
+            if dm:
+                dm.is_paused = True
+                session.commit()
+                print("Dag: {}, paused by delete: {}".format(
+                    dag_id, dm.is_paused))
+
+    @provide_session
+    def delete_rest(self, dag_id, session=None):
+        from airflow.jobs import BaseJob
+        for t in [DagModel, DagStat, BaseJob, Log, SlaMiss, TaskFail, XCom]:
+            session.query(t).filter(t.dag_id == dag_id).delete()
+        session.commit()
+
+    @provide_session
+    def delete(self, session=None):
+        dag_id = self.dag_id
+        task_id = None
+
+        if self.is_subdag:
+            dag_id, task_id = self.dag_id.rsplit(".", 1)
+            print("Deleting dag with dag_id={}, task_id={}".format(
+                dag_id, task_id))
+
+        self.pause(session, dag_id, task_id)
+
+        # delete running dag runs first
+        self.delete_runs(dag_id, session, [State.RUNNING], task_id)
+        # then all remaining runs
+        self.delete_runs(dag_id, session, [], task_id)
+
+        # clear running task instances first
+        self.clear(only_running=True, include_subdags=True)
+        # then the rest
+        self.clear(include_subdags=True)
+
+        self.delete_rest(dag_id, session)
+
     @provide_session
     def create_dagrun(self,
                       run_id,
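
Taken together, the new DAG.delete() entry point pauses the dag, deletes
its dag runs and task instances, and then purges the remaining metadata
rows. A usage sketch (the dag id is a placeholder):

    from airflow.models import DagBag

    dag = DagBag().get_dag('my_dag_id')
    # removes runs and task instances, then clears DagModel, DagStat,
    # BaseJob, Log, SlaMiss, TaskFail and XCom rows for this dag
    dag.delete()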
diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py
index b5a30524cc..747d02b523 100644
--- a/airflow/www/api/experimental/endpoints.py
+++ b/airflow/www/api/experimental/endpoints.py
@@ -15,12 +15,15 @@
 
 from airflow.api.common.experimental import pool as pool_api
 from airflow.api.common.experimental import trigger_dag as trigger
+from airflow.api.common.experimental import delete_dag as delete
 from airflow.api.common.experimental.get_task import get_task
 from airflow.api.common.experimental.get_task_instance import get_task_instance
 from airflow.exceptions import AirflowException
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.www.app import csrf
 
 from flask import (
     g, Markup, Blueprint, redirect, jsonify, abort,
     request, current_app, send_file, url_for
@@ -85,6 +88,28 @@ def trigger_dag(dag_id):
     response = jsonify(message="Created {}".format(dr))
     return response
 
+
+@csrf.exempt
+@api_experimental.route('/dags/<string:dag_id>', methods=['DELETE'])
+@requires_authentication
+def delete_dag(dag_id):
+    """
+    Delete the DAG with the given dag_id, together with all of its runs,
+    task instances and other dependent records.
+    """
+    try:
+        dd = delete.delete_dag(dag_id)
+    except AirflowException as err:
+        _log.error(err)
+        response = jsonify(error="{}".format(err))
+        response.status_code = 404
+        return response
+
+    if getattr(g, 'user', None):
+        _log.info("User {} deleted {}".format(g.user, dd))
+
+    # 200 (the default) rather than 204, so json_client can read the
+    # 'message' field from the response body
+    response = jsonify(message="Deleted {}".format(dd))
+    return response
+
 
 @api_experimental.route('/test', methods=['GET'])
 @requires_authentication
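
The endpoint can also be exercised over HTTP; a sketch using the
requests library, assuming a webserver on localhost:8080 and the
default open authentication:

    import requests

    resp = requests.delete(
        'http://localhost:8080/api/experimental/dags/my_dag_id')
    print(resp.status_code, resp.json())  # 404 if the dag id is unknown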
diff --git a/tests/core.py b/tests/core.py
index 0c94137d15..6a055e42bb 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -36,8 +36,11 @@
 import sqlalchemy
 
 from airflow import configuration
+from airflow.api.common.experimental.delete_dag import check_delete_dag
 from airflow.executors import SequentialExecutor
-from airflow.models import Variable
+from airflow.models import Variable, DagBag
 from tests.test_utils.fake_datetime import FakeDatetime
 
 configuration.load_test_config()
@@ -64,6 +67,7 @@
 from jinja2 import UndefinedError
 
 import six
+from airflow.api.common.experimental import delete_dag as delete
 
 NUM_EXAMPLE_DAGS = 18
 DEV_NULL = '/dev/null'
@@ -224,6 +228,15 @@ def test_fractional_seconds(self):
         self.assertEqual(start_date, run.start_date,
                          "dag run start_date loses precision ")
 
+    def test_delete_invalid_dag(self):
+        """
+        Tests that delete_dag returns False when no dag id is supplied
+        """
+        dag_delete_status = delete.delete_dag(None)
+        self.assertFalse(dag_delete_status)
+
     def test_schedule_dag_start_end_dates(self):
         """
         Tests that an attempt to schedule a task after the Dag's end_date
diff --git a/tests/dags/test_delete_dag.py b/tests/dags/test_delete_dag.py
new file mode 100644
index 0000000000..ededb33cfa
--- /dev/null
+++ b/tests/dags/test_delete_dag.py
@@ -0,0 +1,48 @@
+"""
+Code that goes along with the Airflow tutorial located at:
+https://github.com/airbnb/airflow/blob/master/airflow/example_dags/tutorial.py
+"""
+from airflow import DAG
+from airflow.operators.bash_operator import BashOperator
+from airflow.operators.python_operator import PythonOperator
+from datetime import datetime, timedelta
+
+
+default_args = {
+    'depends_on_past': False,
+    'start_date': datetime.now(),
+    'email': ['airflow@airflow.com'],
+    'email_on_failure': False,
+    'email_on_retry': False,
+    'retries': 1,
+    'retry_delay': timedelta(minutes=5)
+    # 'queue': 'bash_queue',
+    # 'pool': 'backfill',
+    # 'priority_weight': 10,
+    # 'end_date': datetime(2016, 1, 1),
+}
+
+delta = timedelta(seconds=3)
+dag = DAG('test_delete_dag', default_args=default_args, schedule_interval=delta)
+
+# t1 and t3 are examples of tasks created by instantiating operators
+t1 = BashOperator(
+    task_id='print_date',
+    bash_command='date',
+    dag=dag)
+
+def py_callable(*args, **kwargs):
+    print("args = ", args)
+    print("kwargs = ", kwargs)
+
+t3 = PythonOperator(
+    task_id='py_callable',
+    python_callable=py_callable,
+    op_args=['dogs'],
+    op_kwargs={'cats': 20},
+    provide_context=True,
+    dag=dag)
+
+t3.set_upstream(t1)
diff --git a/tests/dags/test_delete_subdag.py b/tests/dags/test_delete_subdag.py
new file mode 100644
index 0000000000..1921191619
--- /dev/null
+++ b/tests/dags/test_delete_subdag.py
@@ -0,0 +1,85 @@
+"""
+Code that goes along with the Airflow tutorial located at:
+https://github.com/airbnb/airflow/blob/master/airflow/example_dags/tutorial.py
+"""
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.operators.subdag_operator import SubDagOperator
+
+from airflow import DAG
+from airflow.operators.bash_operator import BashOperator
+from airflow.operators.python_operator import PythonOperator
+from datetime import datetime, timedelta
+
+
+default_args = {
+    'depends_on_past': False,
+    'start_date': datetime.now(),
+    'email': ['airflow@airflow.com'],
+    'email_on_failure': False,
+    'email_on_retry': False,
+    'retries': 1,
+    'retry_delay': timedelta(minutes=5)
+    # 'queue': 'bash_queue',
+    # 'pool': 'backfill',
+    # 'priority_weight': 10,
+    # 'end_date': datetime(2016, 1, 1),
+}
+
+
+def sub_dag(parent_dag_name, child_dag_name, start_date, schedule_interval):
+    dag = DAG(
+        '%s.%s' % (parent_dag_name, child_dag_name),
+        default_args=default_args,
+        schedule_interval=schedule_interval,
+        start_date=start_date,
+    )
+
+    # the task registers itself on the dag, so no reference is kept
+    DummyOperator(
+        task_id='dummy_task',
+        dag=dag,
+    )
+
+    return dag
+
+
+PARENT_DAG_NAME = 'test_delete_subdag'
+CHILD_DAG_NAME = 'test_delete_subdag_child'
+
+main_dag = DAG(
+    dag_id=PARENT_DAG_NAME,
+    default_args=default_args,
+    schedule_interval=timedelta(seconds=3),
+    start_date=datetime(2016, 1, 1)
+)
+
+# named sub_dag_task so it does not shadow the sub_dag() factory above
+sub_dag_task = SubDagOperator(
+    subdag=sub_dag(PARENT_DAG_NAME, CHILD_DAG_NAME, main_dag.start_date,
+                   main_dag.schedule_interval),
+    task_id=CHILD_DAG_NAME,
+    dag=main_dag,
+)
+
+# t1 and t3 are examples of tasks created by instantiating operators
+t1 = BashOperator(
+    task_id='print_date',
+    bash_command='date',
+    dag=main_dag)
+
+def py_callable(*args, **kwargs):
+    print("args = ", args)
+    print("kwargs = ", kwargs)
+
+t3 = PythonOperator(
+    task_id='py_callable',
+    python_callable=py_callable,
+    op_args=['dogs'],
+    op_kwargs={'cats': 20},
+    provide_context=True,
+    dag=main_dag)
+
+t3.set_upstream(t1)
+
diff --git a/tests/jobs.py b/tests/jobs.py
index ba08fd62f8..26cfe02572 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -54,6 +54,8 @@
 
 import sqlalchemy
 
+from airflow.api.common.experimental import delete_dag as delete
+
 try:
     from unittest import mock
 except ImportError:
@@ -2609,6 +2611,76 @@ def test_dag_catchup_option(self):
         # The DR should be scheduled BEFORE now
         self.assertLess(dr.execution_date, datetime.datetime.now())
 
+    def test_delete_dag_after_schedule_dag(self):
+        """
+        Tests that deleting a scheduled dag removes all of its records
+        """
+        dagbag = DagBag(dag_folder=TEST_DAG_FOLDER,
+                        executor=SequentialExecutor,
+                        include_examples=False)
+        dag = dagbag.get_dag('test_delete_dag')
+
+        session = settings.Session
+
+        orm_dag = DagModel(dag_id=dag.dag_id)
+        session.merge(orm_dag)
+        session.commit()
+        session.close()
+
+        scheduler = SchedulerJob(num_runs=1, **self.default_scheduler_args)
+        scheduler.create_dag_run(dag)
+
+        delete.delete_dag(dag.dag_id, dagbag)
+
+        (_models, _tis, _jobs, _runs) = delete.check_delete_dag(dag.dag_id)
+
+        self.assertEqual(len(_models), 0)
+        self.assertEqual(len(_tis), 0)
+        self.assertEqual(len(_runs), 0)
+
+        for j in _jobs:
+            self.assertEqual(j.state, State.SHUTDOWN)
+
+    def test_delete_subdag_after_schedule_dag(self):
+        """
+        Tests that deleting a scheduled dag also removes its subdags
+        """
+        dagbag = DagBag(dag_folder=TEST_DAG_FOLDER,
+                        executor=SequentialExecutor,
+                        include_examples=False)
+        dag = dagbag.get_dag('test_delete_subdag')
+
+        session = settings.Session
+
+        orm_dag = DagModel(dag_id=dag.dag_id)
+        session.merge(orm_dag)
+        session.commit()
+        session.close()
+
+        scheduler = SchedulerJob(num_runs=1, **self.default_scheduler_args)
+        scheduler.create_dag_run(dag)
+
+        scheduler.run()
+
+        delete.delete_dag(dag.dag_id, dagbag)  # delete parent dag
+
+        # check the parent dag and every subdag
+        all_dag_ids = [dag.dag_id]
+        all_dag_ids.extend([sd.dag_id for sd in dag.subdags])
+
+        for dagid in all_dag_ids:
+            (_models, _tis, _jobs, _runs) = delete.check_delete_dag(dagid)
+
+            self.assertEqual(len(_models), 0)
+            self.assertEqual(len(_tis), 0)
+            self.assertEqual(len(_runs), 0)
+
+            for j in _jobs:
+                self.assertEqual(j.state, State.SHUTDOWN)
+
     def test_add_unparseable_file_before_sched_start_creates_import_error(self):
         try:
             dags_folder = mkdtemp()
diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py
index 65a6f75864..141b27fa13 100644
--- a/tests/www/api/experimental/test_endpoints.py
+++ b/tests/www/api/experimental/test_endpoints.py
@@ -129,6 +129,16 @@ def test_trigger_dag_for_date(self):
         )
         self.assertEqual(400, response.status_code)
 
+    def test_delete_nonexisting_dag(self):
+        url_template = '/api/experimental/dags/{}'
+        response = self.app.delete(
+            url_template.format('does_not_exist_dag'),
+            data=json.dumps(dict({})),
+            content_type="application/json"
+        )
+
+        self.assertEqual(404, response.status_code)
+
     def test_task_instance_info(self):
         url_template = '/api/experimental/dags/{}/dag_runs/{}/tasks/{}'
         dag_id = 'example_bash_operator'

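
A happy-path counterpart to the 404 test above might look like the
following sketch (hypothetical, assuming the example_bash_operator dag
is loaded in the test app):

    def test_delete_existing_dag(self):
        url_template = '/api/experimental/dags/{}'
        response = self.app.delete(
            url_template.format('example_bash_operator'))
        self.assertEqual(200, response.status_code)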

 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services