You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@bloodhound.apache.org by ju...@apache.org on 2012/12/18 11:48:20 UTC

svn commit: r1423401 - /incubator/bloodhound/branches/bep_0003_multiproduct/trac/trac/bloodhound/db.py

Author: jure
Date: Tue Dec 18 10:48:19 2012
New Revision: 1423401

URL: http://svn.apache.org/viewvc?rev=1423401&view=rev
Log:
Towards #288: DDL (CREATE TABLE/INDEX, ALTER TABLE, DROP TABLE) support for third-party (plugin) tables


Modified:
    incubator/bloodhound/branches/bep_0003_multiproduct/trac/trac/bloodhound/db.py

Modified: incubator/bloodhound/branches/bep_0003_multiproduct/trac/trac/bloodhound/db.py
URL: http://svn.apache.org/viewvc/incubator/bloodhound/branches/bep_0003_multiproduct/trac/trac/bloodhound/db.py?rev=1423401&r1=1423400&r2=1423401&view=diff
==============================================================================
--- incubator/bloodhound/branches/bep_0003_multiproduct/trac/trac/bloodhound/db.py (original)
+++ incubator/bloodhound/branches/bep_0003_multiproduct/trac/trac/bloodhound/db.py Tue Dec 18 10:48:19 2012
@@ -76,7 +76,7 @@ class BloodhoundProductSQLTranslate(obje
     _join_statements = ['LEFT JOIN', 'LEFT OUTER JOIN',
                         'RIGHT JOIN', 'RIGHT OUTER JOIN',
                         'JOIN', 'INNER JOIN']
-    _from_end_words = ['WHERE', 'GROUP', 'HAVING', 'ORDER', 'UNION']
+    _from_end_words = ['WHERE', 'GROUP', 'HAVING', 'ORDER', 'UNION', 'LIMIT']
 
     def __init__(self, skip_tables, translate_tables, product_column, product_prefix):
         self._skip_tables = skip_tables
@@ -95,11 +95,11 @@ class BloodhoundProductSQLTranslate(obje
             sql += ' AS %s' % alias
         return sql
 
-    def _prefixed_table_name(self, tablename):
+    def _prefixed_table_entity_name(self, tablename):
         return "%s_%s" % (self._product_prefix, tablename)
 
     def _prefixed_table_view_sql(self, name, alias):
-        return '(SELECT * FROM %s) AS %s' % (self._prefixed_table_name(name),
+        return '(SELECT * FROM %s) AS %s' % (self._prefixed_table_entity_name(name),
                                              alias)
 
     def _token_first(self, parent):
@@ -201,14 +201,16 @@ class BloodhoundProductSQLTranslate(obje
                 parent.tokens[self._token_idx(parent, token)] = sqlparse.parse(self._prefixed_table_view_sql(name,
                                                                                                              alias))[0]
 
-        def process_table_name_tokens(token, nametokens):
+        def process_table_name_tokens(nametokens):
             if nametokens:
                 l = self._select_table_name_alias(nametokens)
                 if not l:
                     raise Exception("Invalid FROM table name")
                 name, alias = l[0], None
-                if len(l) > 1:
-                    alias = l[1]
+                alias = l[1] if len(l) > 1 else name
+                token = nametokens[0]
+                for t in nametokens[1:]:
+                    del parent.tokens[self._token_idx(parent, t)]
                 inject_table_view(token, name, alias)
             return list()
 
@@ -234,17 +236,20 @@ class BloodhoundProductSQLTranslate(obje
                 else:
                     tablename = current_token.value.strip()
                     tablealias = current_token.get_name().strip()
-                    inject_table_view(current_token, tablename, tablealias)
+                    if tablename == tablealias:
+                        table_name_tokens.append(current_token)
+                    else:
+                        inject_table_view(current_token, tablename, tablealias)
             elif current_token.ttype == Tokens.Punctuation:
                 if table_name_tokens:
                     next_token = self._token_next(parent, current_token)
-                    table_name_tokens = process_table_name_tokens(current_token,
-                                                                  table_name_tokens)
+                    table_name_tokens = process_table_name_tokens(table_name_tokens)
             elif current_token.match(Tokens.Keyword, ['JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER'] + self._join_statements):
                 join_tokens.append(current_token.value.strip().upper())
                 join = ' '.join(join_tokens)
                 if join in self._join_statements:
                     join_tokens = list()
+                    table_name_tokens = process_table_name_tokens(table_name_tokens)
                     next_token = self._select_join(parent,
                                                    current_token,
                                                    ['JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER']
@@ -257,9 +262,8 @@ class BloodhoundProductSQLTranslate(obje
                 raise Exception("Failed to parse FROM table name")
             current_token = next_token
 
-        if last_token and table_name_tokens:
-            process_table_name_tokens(last_token,
-                                      table_name_tokens)
+        if last_token:
+            process_table_name_tokens(table_name_tokens)
         return current_token
 
     def _select(self, parent, start_token, insert_table=None):
@@ -282,7 +286,7 @@ class BloodhoundProductSQLTranslate(obje
             return None
         while current_token:
             if isinstance(current_token, Types.Where) or \
-               current_token.match(Tokens.Keyword, ['GROUP', 'HAVING', 'ORDER']):
+               current_token.match(Tokens.Keyword, ['GROUP', 'HAVING', 'ORDER', 'LIMIT']):
                 if isinstance(current_token, Types.Where):
                     self._where(parent, current_token)
                 start_token = self._token_next(parent, current_token)
@@ -301,13 +305,22 @@ class BloodhoundProductSQLTranslate(obje
             current_token = next_token
         return current_token
 
-    def _replace_table_name(self, parent, token, table_name):
+    def _replace_table_entity_name(self, parent, token, table_name, entity_name=None):
+        if not entity_name:
+            entity_name = table_name
         next_token = self._token_next(parent, token)
-        if table_name in self._skip_tables + self._translate_tables:
-            pass
-        else:
-            parent.tokens[self._token_idx(parent, token)] = Types.Token(Tokens.Keyword,
-                                                                        self._prefixed_table_name(table_name))
+        if not table_name in self._skip_tables + self._translate_tables:
+            token_to_replace = parent.tokens[self._token_idx(parent, token)]
+            if isinstance(token_to_replace, Types.Function):
+                t = self._token_first(token_to_replace)
+                if isinstance(t, Types.Identifier):
+                    token_to_replace.tokens[self._token_idx(token_to_replace, t)] = Types.Token(Tokens.Keyword,
+                                                                                                self._prefixed_table_entity_name(entity_name))
+            elif isinstance(token_to_replace, Types.Identifier) or isinstance(token_to_replace, Types.Token):
+                parent.tokens[self._token_idx(parent, token_to_replace)] = Types.Token(Tokens.Keyword,
+                                                                                       self._prefixed_table_entity_name(entity_name))
+            else:
+                raise Exception("Internal error, invalid table entity token type")
         return next_token
 
     def _insert(self, parent, start_token):
@@ -334,15 +347,16 @@ class BloodhoundProductSQLTranslate(obje
             token = self._token_first(table_name_token)
             if isinstance(token, Types.Identifier):
                 tablename = token.get_name()
-                columns_token = self._replace_table_name(table_name_token, token, tablename)
+                columns_token = self._replace_table_entity_name(table_name_token, token, tablename)
                 insert_extra_column(tablename, columns_token)
                 token = self._token_next(parent, table_name_token)
         else:
             tablename = table_name_token.value
-            columns_token = self._replace_table_name(parent, table_name_token, tablename)
+            columns_token = self._replace_table_entity_name(parent, table_name_token, tablename)
             insert_extra_column(tablename, columns_token)
             token = self._token_next(parent, columns_token)
         if token.match(Tokens.Keyword, 'VALUES'):
+            separators = [',', '(', ')']
             token = self._token_next(parent, token)
             while token:
                 if isinstance(token, Types.Parenthesis):
@@ -351,14 +365,14 @@ class BloodhoundProductSQLTranslate(obje
                         raise Exception("Invalid INSERT statement")
                     insert_extra_column_value(tablename, token, ptoken)
                     while ptoken:
-                        if not ptoken.match(Tokens.Punctuation, [',', '(', ')']) and \
-                           not ptoken.match(Tokens.Keyword, [',', '(', ')']) and \
+                        if not ptoken.match(Tokens.Punctuation, separators) and \
+                           not ptoken.match(Tokens.Keyword, separators) and \
                            not ptoken.is_whitespace():
                             ptoken = self._expression_token_unwind_hack(token, ptoken, self._token_prev(token, ptoken))
                             self._eval_expression_value(token, ptoken)
                         ptoken = self._token_next(token, ptoken)
-                elif not token.match(Tokens.Punctuation, [',', '(', ')']) and\
-                     not token.match(Tokens.Keyword, [',', '(', ')']) and\
+                elif not token.match(Tokens.Punctuation, separators) and\
+                     not token.match(Tokens.Keyword, separators) and\
                      not token.is_whitespace():
                     raise Exception("Invalid INSERT statement, unable to parse VALUES section")
                 token = self._token_next(parent, token)
@@ -404,15 +418,24 @@ class BloodhoundProductSQLTranslate(obje
                     self._token_insert_after(parent, last_token, Types.Token(Tokens.Keyword, keyword))
         return
 
+    def _get_entity_name_from_token(self, parent, token):
+        tablename = None
+        if isinstance(token, Types.Identifier):
+            tablename = token.get_name()
+        elif isinstance(token, Types.Function):
+            token = self._token_first(token)
+            if isinstance(token, Types.Identifier):
+                tablename = token.get_name()
+        elif isinstance(token, Types.Token):
+            tablename = token.value
+        return tablename
+
     def _update(self, parent, start_token):
         table_name_token = self._token_next(parent, start_token)
-        if isinstance(table_name_token, Types.Identifier):
-            tablename = table_name_token.get_name()
-        elif isinstance(table_name_token, Types.Token):
-            tablename = table_name_token.value
-        else:
+        tablename = self._get_entity_name_from_token(parent, table_name_token)
+        if not tablename:
             raise Exception("Invalid UPDATE statement, expected table name")
-        token = self._replace_table_name(parent, table_name_token, tablename)
+        token = self._replace_table_entity_name(parent, table_name_token, tablename)
         set_token = self._token_next_match(parent, token, Tokens.Keyword, 'SET')
         if set_token:
             token = set_token
@@ -442,30 +465,94 @@ class BloodhoundProductSQLTranslate(obje
         if not token.match(Tokens.Keyword, 'FROM'):
             raise Exception("Invalid DELETE statement")
         table_name_token = self._token_next(parent, token)
-        if isinstance(table_name_token, Types.Identifier):
-            tablename = table_name_token.get_name()
-        elif isinstance(table_name_token, Types.Token):
-            tablename = table_name_token.value
-        else:
+        tablename = self._get_entity_name_from_token(parent, table_name_token)
+        if not tablename:
             raise Exception("Invalid DELETE statement, expected table name")
-        start_token = self._replace_table_name(parent, table_name_token, tablename)
+        start_token = self._replace_table_entity_name(parent, table_name_token, tablename)
         self._update_delete_where_limit(tablename, parent, start_token)
         return
 
+    def _create(self, parent, start_token):
+        token = self._token_next(parent, start_token)
+        if token.match(Tokens.Keyword, 'TABLE'):
+            token = self._token_next(parent, token)
+            while token.match(Tokens.Keyword, ['IF', 'NOT', 'EXIST']) or \
+                  token.is_whitespace():
+                token = self._token_next(parent, token)
+            table_name = self._get_entity_name_from_token(parent, token)
+            if not table_name:
+                raise Exception("Invalid CREATE TABLE statement, expected table name")
+            self._replace_table_entity_name(parent, token, table_name)
+        elif token.match(Tokens.Keyword, ['UNIQUE', 'INDEX']):
+            if token.match(Tokens.Keyword, 'UNIQUE'):
+                token = self._token_next(parent, token)
+            if token.match(Tokens.Keyword, 'INDEX'):
+                index_token = self._token_next(parent, token)
+                index_name = self._get_entity_name_from_token(parent, index_token)
+                if not index_name:
+                    raise Exception("Invalid CREATE INDEX statement, expected index name")
+                on_token = self._token_next_match(parent, index_token, Tokens.Keyword, 'ON')
+                if not on_token:
+                    raise Exception("Invalid CREATE INDEX statement, expected ON specifier")
+                table_name_token = self._token_next(parent, on_token)
+                table_name = self._get_entity_name_from_token(parent, table_name_token)
+                if not table_name:
+                    raise Exception("Invalid CREATE INDEX statement, expected table name")
+                self._replace_table_entity_name(parent, table_name_token, table_name)
+                self._replace_table_entity_name(parent, index_token, table_name, entity_name=index_name)
+        return
+
+    def _alter(self, parent, start_token):
+        token = self._token_next(parent, start_token)
+        if token.match(Tokens.Keyword, 'TABLE'):
+            token = self._token_next(parent, token)
+            table_name = self._get_entity_name_from_token(parent, token)
+            if not table_name:
+                raise Exception("Invalid CREATE TABLE statement, expected table name")
+            token = self._replace_table_entity_name(parent, token, table_name)
+            if token.match(Tokens.Keyword.DDL, ['ADD', 'DROP']) or\
+               token.match(Tokens.Keyword, ['ADD', 'DROP']):
+                token = self._token_next(parent, token)
+                if token.match(Tokens.Keyword, 'CONSTRAINT'):
+                    token = self._token_next(parent, token)
+                    constraint_name = self._get_entity_name_from_token(parent, token)
+                    if not constraint_name:
+                        raise Exception("Invalid ALTER TABLE statement, expected constraint name")
+                    self._replace_table_entity_name(parent, token, table_name, constraint_name)
+        return
+
+    def _drop(self, parent, start_token):
+        token = self._token_next(parent, start_token)
+        if token.match(Tokens.Keyword, 'TABLE'):
+            token = self._token_next(parent, token)
+            while token.match(Tokens.Keyword, ['IF', 'EXIST']) or\
+                  token.is_whitespace():
+                token = self._token_next(parent, token)
+            table_name = self._get_entity_name_from_token(parent, token)
+            if not table_name:
+                raise Exception("Invalid DROP TABLE statement, expected table name")
+            self._replace_table_entity_name(parent, token, table_name)
+        return
+
     def translate(self, sql):
         dml_handlers = {'SELECT': self._select,
                         'INSERT': self._insert,
                         'UPDATE': self._update,
                         'DELETE': self._delete,
                         }
+        ddl_handlers = {'CREATE': self._create,
+                        'ALTER': self._alter,
+                        'DROP': self._drop,
+                        }
         try:
             sql_statement = sqlparse.parse(sql)[0]
             t = sql_statement.token_first()
-            if not t.match(Tokens.DML, dml_handlers.keys()):
-                return sql
-            dml_handlers[t.value](sql_statement, t)
-            translated_sql = sqlparse.format(sql_statement.to_unicode(), reindent=True)
+            if t.match(Tokens.DML, dml_handlers.keys()):
+                dml_handlers[t.value](sql_statement, t)
+                sql = sqlparse.format(sql_statement.to_unicode(), reindent=True)
+            elif t.match(Tokens.DDL, ddl_handlers.keys()):
+                ddl_handlers[t.value](sql_statement, t)
+                sql = sqlparse.format(sql_statement.to_unicode(), reindent=True)
         except Exception:
             raise Exception("Failed to translate SQL '%s'" % sql)
-        return translated_sql
-
+        return sql