You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/03/22 19:16:24 UTC
[airflow] 13/31: Disable default_pool delete on web ui (#21658)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 7d1abed68f95fdc0de1dea837533e9ca790cf3ae
Author: Chenglong Yan <al...@gmail.com>
AuthorDate: Wed Mar 16 16:31:12 2022 +0800
Disable default_pool delete on web ui (#21658)
(cherry picked from commit df6058c862a910a99fbb86858502d9d93fdbe1e5)
---
airflow/api/common/experimental/pool.py | 2 +-
airflow/models/pool.py | 19 ++++++++++++++++++-
airflow/www/views.py | 13 ++++++++++++-
tests/models/test_pool.py | 6 ++++++
4 files changed, 37 insertions(+), 3 deletions(-)
diff --git a/airflow/api/common/experimental/pool.py b/airflow/api/common/experimental/pool.py
index fe4f161..b1ca9f0 100644
--- a/airflow/api/common/experimental/pool.py
+++ b/airflow/api/common/experimental/pool.py
@@ -83,7 +83,7 @@ def delete_pool(name, session=None):
raise AirflowBadRequest("Pool name shouldn't be empty")
if name == Pool.DEFAULT_POOL_NAME:
- raise AirflowBadRequest("default_pool cannot be deleted")
+ raise AirflowBadRequest(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")
pool = session.query(Pool).filter_by(pool=name).first()
if pool is None:
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index 8ae88aa..7d092f7 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -86,6 +86,23 @@ class Pool(Base):
@staticmethod
@provide_session
+ def is_default_pool(id: int, session: Session = NEW_SESSION) -> bool:
+ """
+ Check if id is the default_pool.
+
+ :param id: pool id
+ :param session: SQLAlchemy ORM Session
+ :return: True if id is default_pool, otherwise False
+ """
+ return (
+ session.query(func.count(Pool.id))
+ .filter(Pool.id == id, Pool.pool == Pool.DEFAULT_POOL_NAME)
+ .scalar()
+ > 0
+ )
+
+ @staticmethod
+ @provide_session
def create_or_update_pool(name: str, slots: int, description: str, session: Session = NEW_SESSION):
"""Create a pool with given parameters or update it if it already exists."""
if not name:
@@ -107,7 +124,7 @@ class Pool(Base):
def delete_pool(name: str, session: Session = NEW_SESSION):
"""Delete pool by a given name."""
if name == Pool.DEFAULT_POOL_NAME:
- raise AirflowException("default_pool cannot be deleted")
+ raise AirflowException(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")
pool = session.query(Pool).filter_by(pool=name).first()
if pool is None:
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 9fc61b5..d6bc40b 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -3727,13 +3727,24 @@ class PoolModelView(AirflowModelView):
def action_muldelete(self, items):
"""Multiple delete."""
if any(item.pool == models.Pool.DEFAULT_POOL_NAME for item in items):
- flash("default_pool cannot be deleted", 'error')
+ flash(f"{models.Pool.DEFAULT_POOL_NAME} cannot be deleted", 'error')
self.update_redirect()
return redirect(self.get_redirect())
self.datamodel.delete_all(items)
self.update_redirect()
return redirect(self.get_redirect())
+ @expose("/delete/<pk>", methods=["GET", "POST"])
+ @has_access
+ def delete(self, pk):
+ """Single delete."""
+ if models.Pool.is_default_pool(pk):
+ flash(f"{models.Pool.DEFAULT_POOL_NAME} cannot be deleted", 'error')
+ self.update_redirect()
+ return redirect(self.get_redirect())
+
+ return super().delete(pk)
+
def pool_link(self):
"""Pool link rendering."""
pool_id = self.get('pool')
diff --git a/tests/models/test_pool.py b/tests/models/test_pool.py
index 95e585e..1c5bbe1 100644
--- a/tests/models/test_pool.py
+++ b/tests/models/test_pool.py
@@ -220,3 +220,9 @@ class TestPool:
def test_delete_default_pool_not_allowed(self):
with pytest.raises(AirflowException, match="^default_pool cannot be deleted$"):
Pool.delete_pool(Pool.DEFAULT_POOL_NAME)
+
+ def test_is_default_pool(self):
+ pool = Pool.create_or_update_pool(name="not_default_pool", slots=1, description="test")
+ default_pool = Pool.get_default_pool()
+ assert not Pool.is_default_pool(id=pool.id)
+ assert Pool.is_default_pool(str(default_pool.id))