Posted to commits@airflow.apache.org by ep...@apache.org on 2022/08/01 12:01:31 UTC
[airflow] branch main updated: Create new databases from the ORM (#24156)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5588c3fe6e Create new databases from the ORM (#24156)
5588c3fe6e is described below
commit 5588c3fe6e5641e651d20d08fa43fe508e265d2f
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Mon Aug 1 13:01:07 2022 +0100
Create new databases from the ORM (#24156)
This PR makes it possible to create new databases directly from the ORM instead of going through the migration files.
`airflow db init` now creates the new database this way.
Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
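In outline, the new `airflow db init` code path looks roughly like the sketch below. This is a simplified view of the logic added to airflow/utils/db.py further down in this diff (the helper functions `_get_current_revision`, `upgradedb` and `_get_alembic_config` are the ones defined in that module); the real implementation also creates the FAB security model tables and the Flask session table, and holds the migrations lock while doing so.

    from alembic import command

    from airflow import settings
    from airflow.models import Base


    def initdb(session):
        if _get_current_revision(session):
            # Existing database: apply any pending migrations.
            upgradedb(session=session)
        else:
            # Fresh database: create all tables straight from the ORM models,
            # then stamp the alembic revision table at "head" so later
            # `airflow db upgrade` runs start from the latest migration.
            Base.metadata.create_all(settings.engine)
            command.stamp(_get_alembic_config(), "head")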
---
.../0080_2_0_2_change_default_pool_slots_to_1.py | 5 +++
airflow/models/base.py | 23 ++++++++--
airflow/models/dag.py | 8 ++--
airflow/models/renderedtifields.py | 10 ++++-
airflow/models/taskinstance.py | 5 ++-
airflow/models/taskreschedule.py | 14 +++++-
airflow/models/xcom.py | 14 +++++-
airflow/utils/db.py | 52 ++++++++++++++++++----
tests/conftest.py | 10 +++++
9 files changed, 121 insertions(+), 20 deletions(-)
diff --git a/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py b/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py
index ef819468ef..70337993aa 100644
--- a/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py
+++ b/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py
@@ -44,5 +44,10 @@ def upgrade():
def downgrade():
"""Unapply Change default ``pool_slots`` to ``1``"""
+ conn = op.get_bind()
+ if conn.dialect.name == 'mssql':
+ # DB created from ORM doesn't set a server_default here and MSSQL fails while trying to drop
+ # the non-existent server_default. We ignore it for MSSQL
+ return
with op.batch_alter_table("task_instance", schema=None) as batch_op:
batch_op.alter_column("pool_slots", existing_type=sa.Integer, nullable=True, server_default=None)
diff --git a/airflow/models/base.py b/airflow/models/base.py
index 478bd904eb..a0c944cfe3 100644
--- a/airflow/models/base.py
+++ b/airflow/models/base.py
@@ -25,9 +25,26 @@ from airflow.configuration import conf
SQL_ALCHEMY_SCHEMA = conf.get("database", "SQL_ALCHEMY_SCHEMA")
-metadata = (
- None if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace() else MetaData(schema=SQL_ALCHEMY_SCHEMA)
-)
+# For more information about what the tokens in the naming convention
+# below mean, see:
+# https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.MetaData.params.naming_convention
+naming_convention = {
+ "ix": "idx_%(column_0_N_label)s",
+ "uq": "%(table_name)s_%(column_0_N_name)s_uq",
+ "ck": "ck_%(table_name)s_%(constraint_name)s",
+ "fk": "%(table_name)s_%(column_0_name)s_fkey",
+ "pk": "%(table_name)s_pkey",
+}
+
+
+def _get_schema():
+ if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace():
+ return None
+ return SQL_ALCHEMY_SCHEMA
+
+
+metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention)
+
Base: Any = declarative_base(metadata=metadata)
ID_LEN = 250
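For context on the base.py change: giving the MetaData a naming convention makes SQLAlchemy derive constraint and index names deterministically, so tables created straight from the ORM end up with the same names the migration files would have produced. A minimal sketch of the effect, assuming only a hypothetical `example` table (the table and its columns are made up for illustration; the convention entries are taken from the hunk above):

    from sqlalchemy import Column, Integer, MetaData, String, Table, UniqueConstraint

    naming_convention = {
        "uq": "%(table_name)s_%(column_0_N_name)s_uq",
        "pk": "%(table_name)s_pkey",
    }
    metadata = MetaData(naming_convention=naming_convention)

    example = Table(
        "example",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("dag_id", String(250)),
        Column("run_id", String(250)),
        UniqueConstraint("dag_id", "run_id"),
    )
    # In the emitted CREATE TABLE DDL the primary key constraint is named
    # "example_pkey" and the unique constraint "example_dag_id_run_id_uq",
    # without either name being spelled out on the model.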
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 596f231721..f7bddb6676 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -66,7 +66,7 @@ from airflow.compat.functools import cached_property
from airflow.configuration import conf
from airflow.exceptions import AirflowDagInconsistent, AirflowException, DuplicateTaskIdFound, TaskNotFound
from airflow.models.abstractoperator import AbstractOperator
-from airflow.models.base import ID_LEN, Base
+from airflow.models.base import Base, StringID
from airflow.models.dagbag import DagBag
from airflow.models.dagcode import DagCode
from airflow.models.dagpickle import DagPickle
@@ -2788,7 +2788,7 @@ class DagTag(Base):
__tablename__ = "dag_tag"
name = Column(String(TAG_MAX_LEN), primary_key=True)
dag_id = Column(
- String(ID_LEN),
+ StringID(),
ForeignKey('dag.dag_id', name='dag_tag_dag_id_fkey', ondelete='CASCADE'),
primary_key=True,
)
@@ -2804,8 +2804,8 @@ class DagModel(Base):
"""
These items are stored in the database for state related information
"""
- dag_id = Column(String(ID_LEN), primary_key=True)
- root_dag_id = Column(String(ID_LEN))
+ dag_id = Column(StringID(), primary_key=True)
+ root_dag_id = Column(StringID())
# A DAG can be paused from the UI / DB
# Set this default value of is_paused based on a configuration value!
is_paused_at_creation = conf.getboolean('core', 'dags_are_paused_at_creation')
diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py
index f1b826c1bc..a98081ea16 100644
--- a/airflow/models/renderedtifields.py
+++ b/airflow/models/renderedtifields.py
@@ -20,7 +20,7 @@ import os
from typing import Optional
import sqlalchemy_jsonfield
-from sqlalchemy import Column, ForeignKeyConstraint, Integer, and_, not_, text, tuple_
+from sqlalchemy import Column, ForeignKeyConstraint, Integer, PrimaryKeyConstraint, and_, not_, text, tuple_
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Session, relationship
@@ -46,6 +46,14 @@ class RenderedTaskInstanceFields(Base):
k8s_pod_yaml = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
__table_args__ = (
+ PrimaryKeyConstraint(
+ "dag_id",
+ "task_id",
+ "run_id",
+ "map_index",
+ name='rendered_task_instance_fields_pkey',
+ mssql_clustered=True,
+ ),
ForeignKeyConstraint(
[dag_id, task_id, run_id, map_index],
[
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 7a92febe17..d83ad11b04 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -59,6 +59,7 @@ from sqlalchemy import (
ForeignKeyConstraint,
Index,
Integer,
+ PrimaryKeyConstraint,
String,
and_,
false,
@@ -428,7 +429,6 @@ class TaskInstance(Base, LoggingMixin):
"""
__tablename__ = "task_instance"
-
task_id = Column(StringID(), primary_key=True, nullable=False)
dag_id = Column(StringID(), primary_key=True, nullable=False)
run_id = Column(StringID(), primary_key=True, nullable=False)
@@ -480,6 +480,9 @@ class TaskInstance(Base, LoggingMixin):
Index('ti_pool', pool, state, priority_weight),
Index('ti_job_id', job_id),
Index('ti_trigger_id', trigger_id),
+ PrimaryKeyConstraint(
+ "dag_id", "task_id", "run_id", "map_index", name='task_instance_pkey', mssql_clustered=True
+ ),
ForeignKeyConstraint(
[trigger_id],
['trigger.id'],
diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py
index 132554d8d1..48e9178345 100644
--- a/airflow/models/taskreschedule.py
+++ b/airflow/models/taskreschedule.py
@@ -20,7 +20,7 @@
import datetime
from typing import TYPE_CHECKING
-from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, text
+from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, event, text
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import relationship
@@ -134,3 +134,15 @@ class TaskReschedule(Base):
return TaskReschedule.query_for_task_instance(
task_instance, session=session, try_number=try_number
).all()
+
+
+@event.listens_for(TaskReschedule.__table__, "before_create")
+def add_ondelete_for_mssql(table, conn, **kw):
+ if conn.dialect.name != "mssql":
+ return
+
+ for constraint in table.constraints:
+ if constraint.name != "task_reschedule_dr_fkey":
+ continue
+ constraint.ondelete = 'NO ACTION'
+ return
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 2d5528ca0b..9970e3c1b8 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -26,7 +26,16 @@ from functools import wraps
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type, Union, cast, overload
import pendulum
-from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, LargeBinary, String, text
+from sqlalchemy import (
+ Column,
+ ForeignKeyConstraint,
+ Index,
+ Integer,
+ LargeBinary,
+ PrimaryKeyConstraint,
+ String,
+ text,
+)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, Session, reconstructor, relationship
from sqlalchemy.orm.exc import NoResultFound
@@ -73,6 +82,9 @@ class BaseXCom(Base, LoggingMixin):
# separately, and enforce uniqueness with DagRun.id instead.
Index("idx_xcom_key", key),
Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index),
+ PrimaryKeyConstraint(
+ "dag_run_id", "task_id", "map_index", "key", name="xcom_pkey", mssql_clustered=True
+ ),
ForeignKeyConstraint(
[dag_id, task_id, run_id, map_index],
[
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index d4b961bff5..da10e9bd4b 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -661,19 +661,45 @@ def create_default_connections(session: Session = NEW_SESSION):
)
-@provide_session
-def initdb(session: Session = NEW_SESSION):
- """Initialize Airflow database."""
- upgradedb(session=session)
+def _create_db_from_orm(session):
+ from alembic import command
+ from flask import Flask
+ from flask_sqlalchemy import SQLAlchemy
- if conf.getboolean('database', 'LOAD_DEFAULT_CONNECTIONS'):
- create_default_connections(session=session)
+ from airflow.models import Base
+ from airflow.www.fab_security.sqla.models import Model
+ from airflow.www.session import AirflowDatabaseSessionInterface
+
+ def _create_flask_session_tbl():
+ flask_app = Flask(__name__)
+ flask_app.config['SQLALCHEMY_DATABASE_URI'] = conf.get('database', 'SQL_ALCHEMY_CONN')
+ db = SQLAlchemy(flask_app)
+ AirflowDatabaseSessionInterface(app=flask_app, db=db, table='session', key_prefix='')
+ db.create_all()
with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
+ Base.metadata.create_all(settings.engine)
+ Model.metadata.create_all(settings.engine)
+ _create_flask_session_tbl()
+ # stamp the migration head
+ config = _get_alembic_config()
+ command.stamp(config, "head")
- from flask_appbuilder.models.sqla import Base
- Base.metadata.create_all(settings.engine)
+@provide_session
+def initdb(session: Session = NEW_SESSION, load_connections: bool = True):
+ """Initialize Airflow database."""
+ db_exists = _get_current_revision(session)
+ if db_exists:
+ upgradedb(session=session)
+ else:
+ _create_db_from_orm(session=session)
+ # Load default connections
+ if conf.getboolean('database', 'LOAD_DEFAULT_CONNECTIONS') and load_connections:
+ create_default_connections(session=session)
+ # Add default pool & sync log_template
+ add_default_pool_if_not_exists()
+ synchronize_log_template()
def _get_alembic_config():
@@ -1487,6 +1513,11 @@ def upgradedb(
if errors_seen:
exit(1)
+ if not to_revision and not _get_current_revision(session=session):
+ # Don't load default connections
+ # New DB; initialize and exit
+ initdb(session=session, load_connections=False)
+ return
with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
log.info("Creating tables")
command.upgrade(config, revision=to_revision or 'heads')
@@ -1699,7 +1730,10 @@ def compare_type(context, inspected_column, metadata_column, inspected_type, met
if isinstance(inspected_type, mysql.VARCHAR) and isinstance(metadata_type, String):
# This is a hack to get around MySQL VARCHAR collation
- # not being possible to change from utf8_bin to utf8mb3_bin
+ # not being possible to change from utf8_bin to utf8mb3_bin.
+ # We only check that the lengths are the same
+ if inspected_type.length != metadata_type.length:
+ return True
return False
return None
diff --git a/tests/conftest.py b/tests/conftest.py
index b153c213d5..e6329e1991 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -193,10 +193,20 @@ def pytest_addoption(parser):
def initial_db_init():
+ from flask import Flask
+
+ from airflow.configuration import conf
from airflow.utils import db
+ from airflow.www.app import sync_appbuilder_roles
+ from airflow.www.extensions.init_appbuilder import init_appbuilder
db.resetdb()
db.bootstrap_dagbag()
+ # minimal app to add roles
+ flask_app = Flask(__name__)
+ flask_app.config['SQLALCHEMY_DATABASE_URI'] = conf.get('database', 'SQL_ALCHEMY_CONN')
+ init_appbuilder(flask_app)
+ sync_appbuilder_roles(flask_app)
@pytest.fixture(autouse=True, scope="session")