Posted to commits@airflow.apache.org by ep...@apache.org on 2022/08/01 12:01:31 UTC

[airflow] branch main updated: Create new databases from the ORM (#24156)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5588c3fe6e Create new databases from the ORM (#24156)
5588c3fe6e is described below

commit 5588c3fe6e5641e651d20d08fa43fe508e265d2f
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Mon Aug 1 13:01:07 2022 +0100

    Create new databases from the ORM (#24156)
    
    This PR makes it possible to create a new database directly from the ORM models instead of replaying all of the migration files.
    
    `airflow db init` now creates the new database this way.
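
    Roughly, the new path boils down to the following (a condensed sketch
    of the code added to airflow/utils/db.py below, not the complete
    implementation):

        from alembic import command

        from airflow import settings
        from airflow.models import Base
        from airflow.utils.db import _get_alembic_config

        # build every table straight from the SQLAlchemy models, then
        # stamp the schema as already at the latest Alembic revision
        Base.metadata.create_all(settings.engine)
        command.stamp(_get_alembic_config(), "head")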
    
    Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
---
 .../0080_2_0_2_change_default_pool_slots_to_1.py   |  5 +++
 airflow/models/base.py                             | 23 ++++++++--
 airflow/models/dag.py                              |  8 ++--
 airflow/models/renderedtifields.py                 | 10 ++++-
 airflow/models/taskinstance.py                     |  5 ++-
 airflow/models/taskreschedule.py                   | 14 +++++-
 airflow/models/xcom.py                             | 14 +++++-
 airflow/utils/db.py                                | 52 ++++++++++++++++++----
 tests/conftest.py                                  | 10 +++++
 9 files changed, 121 insertions(+), 20 deletions(-)

diff --git a/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py b/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py
index ef819468ef..70337993aa 100644
--- a/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py
+++ b/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py
@@ -44,5 +44,10 @@ def upgrade():
 
 def downgrade():
     """Unapply Change default ``pool_slots`` to ``1``"""
+    conn = op.get_bind()
+    if conn.dialect.name == 'mssql':
+        # A DB created from the ORM doesn't set a server_default here, and MSSQL
+        # fails when trying to drop the nonexistent server_default, so skip it for MSSQL.
+        return
     with op.batch_alter_table("task_instance", schema=None) as batch_op:
         batch_op.alter_column("pool_slots", existing_type=sa.Integer, nullable=True, server_default=None)
diff --git a/airflow/models/base.py b/airflow/models/base.py
index 478bd904eb..a0c944cfe3 100644
--- a/airflow/models/base.py
+++ b/airflow/models/base.py
@@ -25,9 +25,26 @@ from airflow.configuration import conf
 
 SQL_ALCHEMY_SCHEMA = conf.get("database", "SQL_ALCHEMY_SCHEMA")
 
-metadata = (
-    None if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace() else MetaData(schema=SQL_ALCHEMY_SCHEMA)
-)
+# For more information about what the tokens in the naming convention
+# below mean, see:
+# https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.MetaData.params.naming_convention
+naming_convention = {
+    "ix": "idx_%(column_0_N_label)s",
+    "uq": "%(table_name)s_%(column_0_N_name)s_uq",
+    "ck": "ck_%(table_name)s_%(constraint_name)s",
+    "fk": "%(table_name)s_%(column_0_name)s_fkey",
+    "pk": "%(table_name)s_pkey",
+}
+
+
+def _get_schema():
+    if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace():
+        return None
+    return SQL_ALCHEMY_SCHEMA
+
+
+metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention)
+
 Base: Any = declarative_base(metadata=metadata)
 
 ID_LEN = 250
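
As an aside on the naming convention above: when a constraint is created
without an explicit name, SQLAlchemy fills one in from these templates. A
minimal illustration (the "example" table and its columns are hypothetical):

    from sqlalchemy import Column, Integer, MetaData, String, Table, UniqueConstraint

    from airflow.models.base import naming_convention

    md = MetaData(naming_convention=naming_convention)
    t = Table(
        "example",
        md,
        Column("a", String(10)),
        Column("b", Integer),
        # left unnamed, so the "uq" template yields "example_a_b_uq"
        UniqueConstraint("a", "b"),
    )
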
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 596f231721..f7bddb6676 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -66,7 +66,7 @@ from airflow.compat.functools import cached_property
 from airflow.configuration import conf
 from airflow.exceptions import AirflowDagInconsistent, AirflowException, DuplicateTaskIdFound, TaskNotFound
 from airflow.models.abstractoperator import AbstractOperator
-from airflow.models.base import ID_LEN, Base
+from airflow.models.base import Base, StringID
 from airflow.models.dagbag import DagBag
 from airflow.models.dagcode import DagCode
 from airflow.models.dagpickle import DagPickle
@@ -2788,7 +2788,7 @@ class DagTag(Base):
     __tablename__ = "dag_tag"
     name = Column(String(TAG_MAX_LEN), primary_key=True)
     dag_id = Column(
-        String(ID_LEN),
+        StringID(),
         ForeignKey('dag.dag_id', name='dag_tag_dag_id_fkey', ondelete='CASCADE'),
         primary_key=True,
     )
@@ -2804,8 +2804,8 @@ class DagModel(Base):
     """
     These items are stored in the database for state related information
     """
-    dag_id = Column(String(ID_LEN), primary_key=True)
-    root_dag_id = Column(String(ID_LEN))
+    dag_id = Column(StringID(), primary_key=True)
+    root_dag_id = Column(StringID())
     # A DAG can be paused from the UI / DB
     # Set this default value of is_paused based on a configuration value!
     is_paused_at_creation = conf.getboolean('core', 'dags_are_paused_at_creation')
diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py
index f1b826c1bc..a98081ea16 100644
--- a/airflow/models/renderedtifields.py
+++ b/airflow/models/renderedtifields.py
@@ -20,7 +20,7 @@ import os
 from typing import Optional
 
 import sqlalchemy_jsonfield
-from sqlalchemy import Column, ForeignKeyConstraint, Integer, and_, not_, text, tuple_
+from sqlalchemy import Column, ForeignKeyConstraint, Integer, PrimaryKeyConstraint, and_, not_, text, tuple_
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import Session, relationship
 
@@ -46,6 +46,14 @@ class RenderedTaskInstanceFields(Base):
     k8s_pod_yaml = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
 
     __table_args__ = (
+        PrimaryKeyConstraint(
+            "dag_id",
+            "task_id",
+            "run_id",
+            "map_index",
+            name='rendered_task_instance_fields_pkey',
+            mssql_clustered=True,
+        ),
         ForeignKeyConstraint(
             [dag_id, task_id, run_id, map_index],
             [
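
A note on the pattern used here and in the other models below: an explicit,
named PrimaryKeyConstraint (rather than primary_key=True on the columns
alone) gives the ORM-generated DDL the same constraint name the migrations
use, and mssql_clustered=True controls clustering on SQL Server. A minimal
sketch with a hypothetical table:

    from sqlalchemy import Column, MetaData, PrimaryKeyConstraint, String, Table

    md = MetaData()
    t = Table(
        "example",
        md,
        Column("dag_id", String(250)),
        Column("task_id", String(250)),
        # named PK; mssql_clustered=True renders it as a clustered index
        # in the CREATE TABLE emitted for SQL Server
        PrimaryKeyConstraint("dag_id", "task_id", name="example_pkey", mssql_clustered=True),
    )
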
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 7a92febe17..d83ad11b04 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -59,6 +59,7 @@ from sqlalchemy import (
     ForeignKeyConstraint,
     Index,
     Integer,
+    PrimaryKeyConstraint,
     String,
     and_,
     false,
@@ -428,7 +429,6 @@ class TaskInstance(Base, LoggingMixin):
     """
 
     __tablename__ = "task_instance"
-
     task_id = Column(StringID(), primary_key=True, nullable=False)
     dag_id = Column(StringID(), primary_key=True, nullable=False)
     run_id = Column(StringID(), primary_key=True, nullable=False)
@@ -480,6 +480,9 @@ class TaskInstance(Base, LoggingMixin):
         Index('ti_pool', pool, state, priority_weight),
         Index('ti_job_id', job_id),
         Index('ti_trigger_id', trigger_id),
+        PrimaryKeyConstraint(
+            "dag_id", "task_id", "run_id", "map_index", name='task_instance_pkey', mssql_clustered=True
+        ),
         ForeignKeyConstraint(
             [trigger_id],
             ['trigger.id'],
diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py
index 132554d8d1..48e9178345 100644
--- a/airflow/models/taskreschedule.py
+++ b/airflow/models/taskreschedule.py
@@ -20,7 +20,7 @@
 import datetime
 from typing import TYPE_CHECKING
 
-from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, text
+from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, event, text
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import relationship
 
@@ -134,3 +134,15 @@ class TaskReschedule(Base):
         return TaskReschedule.query_for_task_instance(
             task_instance, session=session, try_number=try_number
         ).all()
+
+
+@event.listens_for(TaskReschedule.__table__, "before_create")
+def add_ondelete_for_mssql(table, conn, **kw):
+    if conn.dialect.name != "mssql":
+        return
+
+    for constraint in table.constraints:
+        if constraint.name != "task_reschedule_dr_fkey":
+            continue
+        constraint.ondelete = 'NO ACTION'
+        return
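
For context on when this hook fires: SQLAlchemy emits "before_create" for a
table just before its CREATE TABLE DDL runs, so the listener can inspect the
target dialect and mutate the Table first. A self-contained illustration
(table and listener are hypothetical):

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine, event

    md = MetaData()
    tbl = Table("example", md, Column("id", Integer, primary_key=True))

    @event.listens_for(tbl, "before_create")
    def _tweak_for_dialect(table, conn, **kw):
        # conn.dialect.name identifies the target database ("sqlite",
        # "postgresql", "mssql", ...), as in add_ondelete_for_mssql above
        print(f"about to create {table.name} on {conn.dialect.name}")

    md.create_all(create_engine("sqlite://"))  # -> about to create example on sqlite
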
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 2d5528ca0b..9970e3c1b8 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -26,7 +26,16 @@ from functools import wraps
 from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type, Union, cast, overload
 
 import pendulum
-from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, LargeBinary, String, text
+from sqlalchemy import (
+    Column,
+    ForeignKeyConstraint,
+    Index,
+    Integer,
+    LargeBinary,
+    PrimaryKeyConstraint,
+    String,
+    text,
+)
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import Query, Session, reconstructor, relationship
 from sqlalchemy.orm.exc import NoResultFound
@@ -73,6 +82,9 @@ class BaseXCom(Base, LoggingMixin):
         # separately, and enforce uniqueness with DagRun.id instead.
         Index("idx_xcom_key", key),
         Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index),
+        PrimaryKeyConstraint(
+            "dag_run_id", "task_id", "map_index", "key", name="xcom_pkey", mssql_clustered=True
+        ),
         ForeignKeyConstraint(
             [dag_id, task_id, run_id, map_index],
             [
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index d4b961bff5..da10e9bd4b 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -661,19 +661,45 @@ def create_default_connections(session: Session = NEW_SESSION):
     )
 
 
-@provide_session
-def initdb(session: Session = NEW_SESSION):
-    """Initialize Airflow database."""
-    upgradedb(session=session)
+def _create_db_from_orm(session):
+    from alembic import command
+    from flask import Flask
+    from flask_sqlalchemy import SQLAlchemy
 
-    if conf.getboolean('database', 'LOAD_DEFAULT_CONNECTIONS'):
-        create_default_connections(session=session)
+    from airflow.models import Base
+    from airflow.www.fab_security.sqla.models import Model
+    from airflow.www.session import AirflowDatabaseSessionInterface
+
+    def _create_flask_session_tbl():
+        flask_app = Flask(__name__)
+        flask_app.config['SQLALCHEMY_DATABASE_URI'] = conf.get('database', 'SQL_ALCHEMY_CONN')
+        db = SQLAlchemy(flask_app)
+        AirflowDatabaseSessionInterface(app=flask_app, db=db, table='session', key_prefix='')
+        db.create_all()
 
     with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
+        Base.metadata.create_all(settings.engine)
+        Model.metadata.create_all(settings.engine)
+        _create_flask_session_tbl()
+        # stamp the migration head
+        config = _get_alembic_config()
+        command.stamp(config, "head")
 
-        from flask_appbuilder.models.sqla import Base
 
-        Base.metadata.create_all(settings.engine)
+@provide_session
+def initdb(session: Session = NEW_SESSION, load_connections: bool = True):
+    """Initialize Airflow database."""
+    db_exists = _get_current_revision(session)
+    if db_exists:
+        upgradedb(session=session)
+    else:
+        _create_db_from_orm(session=session)
+    # Load default connections
+    if conf.getboolean('database', 'LOAD_DEFAULT_CONNECTIONS') and load_connections:
+        create_default_connections(session=session)
+    # Add default pool & sync log_template
+    add_default_pool_if_not_exists()
+    synchronize_log_template()
 
 
 def _get_alembic_config():
@@ -1487,6 +1513,11 @@ def upgradedb(
     if errors_seen:
         exit(1)
 
+    if not to_revision and not _get_current_revision(session=session):
+        # New DB; initialize it from the ORM and exit.
+        # Default connections are deliberately not loaded here.
+        initdb(session=session, load_connections=False)
+        return
     with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
         log.info("Creating tables")
         command.upgrade(config, revision=to_revision or 'heads')
@@ -1699,7 +1730,10 @@ def compare_type(context, inspected_column, metadata_column, inspected_type, met
 
         if isinstance(inspected_type, mysql.VARCHAR) and isinstance(metadata_type, String):
             # This is a hack to get around MySQL VARCHAR collation
-            # not being possible to change from utf8_bin to utf8mb3_bin
+            # not being possible to change from utf8_bin to utf8mb3_bin.
+            # We only check that the lengths match.
+            if inspected_type.length != metadata_type.length:
+                return True
             return False
     return None
 
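With these changes the fresh-vs-existing decision lives in initdb() itself;
a usage sketch (assumes a configured Airflow installation):

    from airflow.utils import db

    # On a database with no Alembic revision, this builds the schema from
    # the ORM and stamps it "head"; on an existing database it runs the
    # migrations via upgradedb() instead.
    db.initdb()
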
diff --git a/tests/conftest.py b/tests/conftest.py
index b153c213d5..e6329e1991 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -193,10 +193,20 @@ def pytest_addoption(parser):
 
 
 def initial_db_init():
+    from flask import Flask
+
+    from airflow.configuration import conf
     from airflow.utils import db
+    from airflow.www.app import sync_appbuilder_roles
+    from airflow.www.extensions.init_appbuilder import init_appbuilder
 
     db.resetdb()
     db.bootstrap_dagbag()
+    # minimal Flask app, just enough to create and sync the FAB security roles
+    flask_app = Flask(__name__)
+    flask_app.config['SQLALCHEMY_DATABASE_URI'] = conf.get('database', 'SQL_ALCHEMY_CONN')
+    init_appbuilder(flask_app)
+    sync_appbuilder_roles(flask_app)
 
 
 @pytest.fixture(autouse=True, scope="session")