You are viewing a plain-text version of this content; the hyperlink to the canonical version is not preserved in this copy.
Posted to commits@superset.apache.org by ma...@apache.org on 2017/11/20 04:09:24 UTC

[incubator-superset] branch master updated: [3541] Augmenting datasources uniqueness constraints (#3583)

This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 3c72e1f  [3541] Augmenting datasources uniqueness constraints (#3583)
3c72e1f is described below

commit 3c72e1f8fbd842ff4aff88ade125c2e5d6cb185a
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Sun Nov 19 20:09:18 2017 -0800

    [3541] Augmenting datasources uniqueness constraints (#3583)
---
 superset/connectors/druid/models.py           |  47 +++---
 superset/migrations/versions/4736ec66ce19_.py | 201 ++++++++++++++++++++++++++
 superset/utils.py                             |  29 +++-
 tests/import_export_tests.py                  |  11 +-
 4 files changed, 256 insertions(+), 32 deletions(-)

diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index d19a8f0..90b4dc0 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -22,7 +22,7 @@ import requests
 from six import string_types
 import sqlalchemy as sa
 from sqlalchemy import (
-    Boolean, Column, DateTime, ForeignKey, Integer, or_, String, Text,
+    Boolean, Column, DateTime, ForeignKey, Integer, or_, String, Text, UniqueConstraint,
 )
 from sqlalchemy.orm import backref, relationship
 
@@ -169,7 +169,7 @@ class DruidCluster(Model, AuditMixinNullable):
             if cols:
                 col_objs_list = (
                     session.query(DruidColumn)
-                    .filter(DruidColumn.datasource_name == datasource.datasource_name)
+                    .filter(DruidColumn.datasource_id == datasource.id)
                     .filter(or_(DruidColumn.column_name == col for col in cols))
                 )
                 col_objs = {col.column_name: col for col in col_objs_list}
@@ -179,7 +179,7 @@ class DruidCluster(Model, AuditMixinNullable):
                     col_obj = col_objs.get(col, None)
                     if not col_obj:
                         col_obj = DruidColumn(
-                            datasource_name=datasource.datasource_name,
+                            datasource_id=datasource.id,
                             column_name=col)
                         with session.no_autoflush:
                             session.add(col_obj)
@@ -220,9 +220,9 @@ class DruidColumn(Model, BaseColumn):
 
     __tablename__ = 'columns'
 
-    datasource_name = Column(
-        String(255),
-        ForeignKey('datasources.datasource_name'))
+    datasource_id = Column(
+        Integer,
+        ForeignKey('datasources.id'))
     # Setting enable_typechecks=False disables polymorphic inheritance.
     datasource = relationship(
         'DruidDatasource',
@@ -231,7 +231,7 @@ class DruidColumn(Model, BaseColumn):
     dimension_spec_json = Column(Text)
 
     export_fields = (
-        'datasource_name', 'column_name', 'is_active', 'type', 'groupby',
+        'datasource_id', 'column_name', 'is_active', 'type', 'groupby',
         'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable',
         'description', 'dimension_spec_json',
     )
@@ -334,15 +334,14 @@ class DruidColumn(Model, BaseColumn):
         metrics = self.get_metrics()
         dbmetrics = (
             db.session.query(DruidMetric)
-            .filter(DruidCluster.cluster_name == self.datasource.cluster_name)
-            .filter(DruidMetric.datasource_name == self.datasource_name)
+            .filter(DruidMetric.datasource_id == self.datasource_id)
             .filter(or_(
                 DruidMetric.metric_name == m for m in metrics
             ))
         )
         dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
         for metric in metrics.values():
-            metric.datasource_name = self.datasource_name
+            metric.datasource_id = self.datasource_id
             if not dbmetrics.get(metric.metric_name, None):
                 db.session.add(metric)
 
@@ -350,7 +349,7 @@ class DruidColumn(Model, BaseColumn):
     def import_obj(cls, i_column):
         def lookup_obj(lookup_column):
             return db.session.query(DruidColumn).filter(
-                DruidColumn.datasource_name == lookup_column.datasource_name,
+                DruidColumn.datasource_id == lookup_column.datasource_id,
                 DruidColumn.column_name == lookup_column.column_name).first()
 
         return import_util.import_simple_obj(db.session, i_column, lookup_obj)
@@ -361,9 +360,9 @@ class DruidMetric(Model, BaseMetric):
     """ORM object referencing Druid metrics for a datasource"""
 
     __tablename__ = 'metrics'
-    datasource_name = Column(
-        String(255),
-        ForeignKey('datasources.datasource_name'))
+    datasource_id = Column(
+        Integer,
+        ForeignKey('datasources.id'))
     # Setting enable_typechecks=False disables polymorphic inheritance.
     datasource = relationship(
         'DruidDatasource',
@@ -372,7 +371,7 @@ class DruidMetric(Model, BaseMetric):
     json = Column(Text)
 
     export_fields = (
-        'metric_name', 'verbose_name', 'metric_type', 'datasource_name',
+        'metric_name', 'verbose_name', 'metric_type', 'datasource_id',
         'json', 'description', 'is_restricted', 'd3format',
     )
 
@@ -400,7 +399,7 @@ class DruidMetric(Model, BaseMetric):
     def import_obj(cls, i_metric):
         def lookup_obj(lookup_metric):
             return db.session.query(DruidMetric).filter(
-                DruidMetric.datasource_name == lookup_metric.datasource_name,
+                DruidMetric.datasource_id == lookup_metric.datasource_id,
                 DruidMetric.metric_name == lookup_metric.metric_name).first()
         return import_util.import_simple_obj(db.session, i_metric, lookup_obj)
 
@@ -420,7 +419,7 @@ class DruidDatasource(Model, BaseDatasource):
     baselink = 'druiddatasourcemodelview'
 
     # Columns
-    datasource_name = Column(String(255), unique=True)
+    datasource_name = Column(String(255))
     is_hidden = Column(Boolean, default=False)
     fetch_values_from = Column(String(100))
     cluster_name = Column(
@@ -432,6 +431,7 @@ class DruidDatasource(Model, BaseDatasource):
         sm.user_model,
         backref=backref('datasources', cascade='all, delete-orphan'),
         foreign_keys=[user_id])
+    UniqueConstraint('cluster_name', 'datasource_name')
 
     export_fields = (
         'datasource_name', 'is_hidden', 'description', 'default_endpoint',
@@ -519,7 +519,7 @@ class DruidDatasource(Model, BaseDatasource):
          superset instances. Audit metadata isn't copies over.
         """
         def lookup_datasource(d):
-            return db.session.query(DruidDatasource).join(DruidCluster).filter(
+            return db.session.query(DruidDatasource).filter(
                 DruidDatasource.datasource_name == d.datasource_name,
                 DruidCluster.cluster_name == d.cluster_name,
             ).first()
@@ -620,13 +620,12 @@ class DruidDatasource(Model, BaseDatasource):
             metrics.update(col.get_metrics())
         dbmetrics = (
             db.session.query(DruidMetric)
-            .filter(DruidCluster.cluster_name == self.cluster_name)
-            .filter(DruidMetric.datasource_name == self.datasource_name)
+            .filter(DruidMetric.datasource_id == self.id)
             .filter(or_(DruidMetric.metric_name == m for m in metrics))
         )
         dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
         for metric in metrics.values():
-            metric.datasource_name = self.datasource_name
+            metric.datasource_id = self.id
             if not dbmetrics.get(metric.metric_name, None):
                 with db.session.no_autoflush:
                     db.session.add(metric)
@@ -661,7 +660,7 @@ class DruidDatasource(Model, BaseDatasource):
         dimensions = druid_config['dimensions']
         col_objs = (
             session.query(DruidColumn)
-            .filter(DruidColumn.datasource_name == druid_config['name'])
+            .filter(DruidColumn.datasource_id == datasource.id)
             .filter(or_(DruidColumn.column_name == dim for dim in dimensions))
         )
         col_objs = {col.column_name: col for col in col_objs}
@@ -669,7 +668,7 @@ class DruidDatasource(Model, BaseDatasource):
             col_obj = col_objs.get(dim, None)
             if not col_obj:
                 col_obj = DruidColumn(
-                    datasource_name=druid_config['name'],
+                    datasource_id=datasource.id,
                     column_name=dim,
                     groupby=True,
                     filterable=True,
@@ -681,7 +680,7 @@ class DruidDatasource(Model, BaseDatasource):
         # Import Druid metrics
         metric_objs = (
             session.query(DruidMetric)
-            .filter(DruidMetric.datasource_name == druid_config['name'])
+            .filter(DruidMetric.datasource_id == datasource.id)
             .filter(or_(DruidMetric.metric_name == spec['name']
                     for spec in druid_config['metrics_spec']))
         )
diff --git a/superset/migrations/versions/4736ec66ce19_.py b/superset/migrations/versions/4736ec66ce19_.py
new file mode 100644
index 0000000..2d560d5
--- /dev/null
+++ b/superset/migrations/versions/4736ec66ce19_.py
@@ -0,0 +1,201 @@
+"""empty message
+
+Revision ID: 4736ec66ce19
+Revises: f959a6652acd
+Create Date: 2017-10-03 14:37:01.376578
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = '4736ec66ce19'
+down_revision = 'f959a6652acd'
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.exc import OperationalError
+
+from superset.utils import (
+    generic_find_fk_constraint_name,
+    generic_find_fk_constraint_names,
+    generic_find_uq_constraint_name,
+)
+
+
+conv = {
+    'fk': 'fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s',
+    'uq': 'uq_%(table_name)s_%(column_0_name)s',
+}
+
+# Helper table for database migrations using minimal schema.
+datasources = sa.Table(
+    'datasources',
+    sa.MetaData(),
+    sa.Column('id', sa.Integer, primary_key=True),
+    sa.Column('datasource_name', sa.String(255)),
+)
+
+bind = op.get_bind()
+insp = sa.engine.reflection.Inspector.from_engine(bind)
+
+
+def upgrade():
+
+    # Add the new less restrictive uniqueness constraint.
+    with op.batch_alter_table('datasources', naming_convention=conv) as batch_op:
+        batch_op.create_unique_constraint(
+            'uq_datasources_cluster_name',
+            ['cluster_name', 'datasource_name'],
+        )
+
+    # Augment the tables which have a foreign key constraint related to the
+    # datasources.datasource_name column.
+    for foreign in ['columns', 'metrics']:
+        with op.batch_alter_table(foreign, naming_convention=conv) as batch_op:
+
+            # Add the datasource_id column with the relevant constraints.
+            batch_op.add_column(sa.Column('datasource_id', sa.Integer))
+
+            batch_op.create_foreign_key(
+                'fk_{}_datasource_id_datasources'.format(foreign),
+                'datasources',
+                ['datasource_id'],
+                ['id'],
+            )
+
+        # Helper table for database migration using minimal schema.
+        table = sa.Table(
+            foreign,
+            sa.MetaData(),
+            sa.Column('id', sa.Integer, primary_key=True),
+            sa.Column('datasource_name', sa.String(255)),
+            sa.Column('datasource_id', sa.Integer),
+        )
+
+        # Migrate the existing data.
+        for datasource in bind.execute(datasources.select()):
+            bind.execute(
+                table.update().where(
+                    table.c.datasource_name == datasource.datasource_name,
+                ).values(
+                    datasource_id=datasource.id,
+                ),
+            )
+
+        with op.batch_alter_table(foreign, naming_convention=conv) as batch_op:
+
+            # Drop the datasource_name column and associated constraints. Note
+            # due to prior revisions (1226819ee0e3, 3b626e2a6783) there may
+            # incorectly be multiple duplicate constraints.
+            names = generic_find_fk_constraint_names(
+                foreign,
+                {'datasource_name'},
+                'datasources',
+                insp,
+            )
+
+            for name in names:
+                batch_op.drop_constraint(
+                    name or 'fk_{}_datasource_name_datasources'.format(foreign),
+                    type_='foreignkey',
+                )
+
+            batch_op.drop_column('datasource_name')
+
+    # Drop the old more restrictive uniqueness constraint.
+    with op.batch_alter_table('datasources', naming_convention=conv) as batch_op:
+        batch_op.drop_constraint(
+            generic_find_uq_constraint_name(
+                'datasources',
+                {'datasource_name'},
+                insp,
+            ) or 'uq_datasources_datasource_name',
+            type_='unique',
+        )
+
+
+def downgrade():
+
+    # Add the new more restrictive uniqueness constraint which is required by
+    # the foreign key constraints. Note this operation will fail if the
+    # datasources.datasource_name column is no longer unique.
+    with op.batch_alter_table('datasources', naming_convention=conv) as batch_op:
+        batch_op.create_unique_constraint(
+            'uq_datasources_datasource_name',
+            ['datasource_name'],
+        )
+
+    # Augment the tables which have a foreign key constraint related to the
+    # datasources.datasource_id column.
+    for foreign in ['columns', 'metrics']:
+        with op.batch_alter_table(foreign, naming_convention=conv) as batch_op:
+
+            # Add the datasource_name column with the relevant constraints.
+            batch_op.add_column(sa.Column('datasource_name', sa.String(255)))
+
+            batch_op.create_foreign_key(
+                'fk_{}_datasource_name_datasources'.format(foreign),
+                'datasources',
+                ['datasource_name'],
+                ['datasource_name'],
+            )
+
+        # Helper table for database migration using minimal schema.
+        table = sa.Table(
+            foreign,
+            sa.MetaData(),
+            sa.Column('id', sa.Integer, primary_key=True),
+            sa.Column('datasource_name', sa.String(255)),
+            sa.Column('datasource_id', sa.Integer),
+        )
+
+        # Migrate the existing data.
+        for datasource in bind.execute(datasources.select()):
+            bind.execute(
+                table.update().where(
+                    table.c.datasource_id == datasource.id,
+                ).values(
+                    datasource_name=datasource.datasource_name,
+                ),
+            )
+
+        with op.batch_alter_table(foreign, naming_convention=conv) as batch_op:
+
+            # Drop the datasource_id column and associated constraint.
+            batch_op.drop_constraint(
+                'fk_{}_datasource_id_datasources'.format(foreign),
+                type_='foreignkey',
+            )
+
+            batch_op.drop_column('datasource_id')
+
+    with op.batch_alter_table('datasources', naming_convention=conv) as batch_op:
+
+        # Prior to dropping the uniqueness constraint, the foreign key
+        # associated with the cluster_name column needs to be dropped.
+        batch_op.drop_constraint(
+            generic_find_fk_constraint_name(
+                'datasources',
+                {'cluster_name'},
+                'clusters',
+                insp,
+            ) or 'fk_datasources_cluster_name_clusters',
+            type_='foreignkey',
+        )
+
+        # Drop the old less restrictive uniqueness constraint.
+        batch_op.drop_constraint(
+            generic_find_uq_constraint_name(
+                'datasources',
+                {'cluster_name', 'datasource_name'},
+                insp,
+            ) or 'uq_datasources_cluster_name',
+            type_='unique',
+        )
+
+        # Re-create the foreign key associated with the cluster_name column.
+        batch_op.create_foreign_key(
+                'fk_{}_datasource_id_datasources'.format(foreign),
+                'clusters',
+                ['cluster_name'],
+                ['cluster_name'],
+            )
diff --git a/superset/utils.py b/superset/utils.py
index 469bbc2..bae330b 100644
--- a/superset/utils.py
+++ b/superset/utils.py
@@ -377,11 +377,36 @@ def generic_find_constraint_name(table, columns, referenced, db):
     t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
 
     for fk in t.foreign_key_constraints:
-        if (fk.referred_table.name == referenced and
-                set(fk.column_keys) == columns):
+        if fk.referred_table.name == referenced and set(fk.column_keys) == columns:
             return fk.name
 
 
+def generic_find_fk_constraint_name(table, columns, referenced, insp):
+    """Utility to find a foreign-key constraint name in alembic migrations"""
+    for fk in insp.get_foreign_keys(table):
+        if fk['referred_table'] == referenced and set(fk['referred_columns']) == columns:
+            return fk['name']
+
+
+def generic_find_fk_constraint_names(table, columns, referenced, insp):
+    """Utility to find foreign-key constraint names in alembic migrations"""
+    names = set()
+
+    for fk in insp.get_foreign_keys(table):
+        if fk['referred_table'] == referenced and set(fk['referred_columns']) == columns:
+            names.add(fk['name'])
+
+    return names
+
+
+def generic_find_uq_constraint_name(table, columns, insp):
+    """Utility to find a unique constraint name in alembic migrations"""
+
+    for uq in insp.get_unique_constraints(table):
+        if columns == set(uq['column_names']):
+            return uq['name']
+
+
 def get_datasource_full_name(database_name, datasource_name, schema=None):
     if not schema:
         return '[{}].[{}]'.format(database_name, datasource_name)
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
index e945630..0710cac 100644
--- a/tests/import_export_tests.py
+++ b/tests/import_export_tests.py
@@ -485,13 +485,12 @@ class ImportExportTests(SupersetTestCase):
 
     def test_import_druid_override(self):
         datasource = self.create_druid_datasource(
-            'druid_override', id=10003, cols_names=['col1'],
+            'druid_override', id=10004, cols_names=['col1'],
             metric_names=['m1'])
         imported_id = DruidDatasource.import_obj(
             datasource, import_time=1991)
-
         table_over = self.create_druid_datasource(
-            'druid_override', id=10003,
+            'druid_override', id=10004,
             cols_names=['new_col1', 'col2', 'col3'],
             metric_names=['new_metric1'])
         imported_over_id = DruidDatasource.import_obj(
@@ -500,19 +499,19 @@ class ImportExportTests(SupersetTestCase):
         imported_over = self.get_datasource(imported_over_id)
         self.assertEquals(imported_id, imported_over.id)
         expected_datasource = self.create_druid_datasource(
-            'druid_override', id=10003, metric_names=['new_metric1', 'm1'],
+            'druid_override', id=10004, metric_names=['new_metric1', 'm1'],
             cols_names=['col1', 'new_col1', 'col2', 'col3'])
         self.assert_datasource_equals(expected_datasource, imported_over)
 
     def test_import_druid_override_idential(self):
         datasource = self.create_druid_datasource(
-            'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
+            'copy_cat', id=10005, cols_names=['new_col1', 'col2', 'col3'],
             metric_names=['new_metric1'])
         imported_id = DruidDatasource.import_obj(
             datasource, import_time=1993)
 
         copy_datasource = self.create_druid_datasource(
-            'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
+            'copy_cat', id=10005, cols_names=['new_col1', 'col2', 'col3'],
             metric_names=['new_metric1'])
         imported_id_copy = DruidDatasource.import_obj(
             copy_datasource, import_time=1994)

-- 
To stop receiving notification emails like this one, please contact
commits@superset.apache.org.