You are viewing a plain text version of this content. The canonical link for it was a hyperlink in the original page and is not preserved in this plain-text copy.
Posted to commits@bloodhound.apache.org by gj...@apache.org on 2012/03/05 12:38:36 UTC

svn commit: r1297011 - in /incubator/bloodhound/trunk/bloodhound_multiproduct: multiproduct/model.py tests/model.py

Author: gjm
Date: Mon Mar  5 11:38:36 2012
New Revision: 1297011

URL: http://svn.apache.org/viewvc?rev=1297011&view=rev
Log:
Update SQL queries to use parameterised placeholders so that values are escaped automatically; tests fixed.

Modified:
    incubator/bloodhound/trunk/bloodhound_multiproduct/multiproduct/model.py
    incubator/bloodhound/trunk/bloodhound_multiproduct/tests/model.py

Modified: incubator/bloodhound/trunk/bloodhound_multiproduct/multiproduct/model.py
URL: http://svn.apache.org/viewvc/incubator/bloodhound/trunk/bloodhound_multiproduct/multiproduct/model.py?rev=1297011&r1=1297010&r2=1297011&view=diff
==============================================================================
--- incubator/bloodhound/trunk/bloodhound_multiproduct/multiproduct/model.py (original)
+++ incubator/bloodhound/trunk/bloodhound_multiproduct/multiproduct/model.py Mon Mar  5 11:38:36 2012
@@ -29,6 +29,23 @@ DB_VERSION = 1
 DB_SYSTEM_KEY = 'bloodhound_multi_product_version'
 PLUGIN_NAME = 'Bloodhound multi product'
 
+def dict_to_kv_str(data=None, sep=' AND '):
+    """Converts a dictionary into a string and a list suitable for using as part
+    of an SQL where clause like:
+        ('key0=%s AND key1=%s', ['value0','value1'])
+    The sep argument allows ' AND ' to be changed for ',' for UPDATE purposes
+    """
+    if data is None:
+        return ('', [])
+    return (sep.join(['%s=%%s' % k for k in data.keys()]), data.values())
+
+def fields_to_kv_str(fields, data, sep=' AND '):
+    """Converts a list of fields and a dictionary containing those fields into a
+    string and a list suitable for using as part of an SQL where clause like:
+        ('key0=%s,key1=%s', ['value0','value1'])
+    """
+    return dict_to_kv_str(dict([(f, data[f]) for f in fields]), sep)
+
 class ModelBase(object):
     """Base class for the models to factor out common features
     Derived classes should provide a meta dictionary to describe the table like:
@@ -94,8 +111,7 @@ class ModelBase(object):
     def _get_row(self, keys):
         """queries the database and stores the result in the model"""
         row = None
-        key_fields = self._meta['key_fields']
-        where = ','.join(['%s="%s"' % (k, keys[k]) for k in key_fields])
+        where, values = fields_to_kv_str(self._meta['key_fields'], keys)
         fields = ','.join(self._meta['key_fields']+self._meta['non_key_fields'])
         sdata = {'fields':fields,
                  'where':where}
@@ -104,7 +120,7 @@ class ModelBase(object):
         sql = """SELECT %(fields)s FROM %(table_name)s
                  WHERE %(where)s""" % sdata
         with self._env.db_query as db:
-            for row in db(sql):
+            for row in db(sql, values):
                 self._update_from_row(row)
                 break
             else:
@@ -115,13 +131,13 @@ class ModelBase(object):
         """Deletes the matching record from the database"""
         if not self._exists:
             raise TracError('%(object_name)s does not exist' % self._meta)
-        sdata = {'where':','.join(['%s="%s"' % (k, self._data[k])
-                                   for k in self._meta['key_fields']])}
+        where, values = fields_to_kv_str(self._meta['key_fields'], self._data)
+        sdata = {'where': where}
         sdata.update(self._meta)
         sql = """DELETE FROM %(table_name)s
                  WHERE %(where)s""" % sdata
         with self._env.db_transaction as db:
-            db(sql)
+            db(sql, values)
             self._exists = False
             self._data = dict([(k, None) for k in self._data.keys()])
             self._old_data.update(self._data)
@@ -131,7 +147,7 @@ class ModelBase(object):
         if self._exists or len(self.select(self._env, where =
                                 dict([(k,self._data[k])
                                       for k in self._meta['key_fields']]))):
-            sdata = {'keys':','.join(['%s="%s"' % (k, self._data[k])
+            sdata = {'keys':','.join(["%s='%s'" % (k, self._data[k])
                                      for k in self._meta['key_fields']])}
             sdata.update(self._meta)
             raise TracError('%(object_name)s %(keys)s already exists' %
@@ -145,13 +161,13 @@ class ModelBase(object):
                                 sdata)
         fields = self._meta['key_fields']+self._meta['non_key_fields']
         sdata = {'fields':','.join(fields),
-                 'values':','.join(['"%s"' % self._data[f] for f in fields])}
+                 'values':','.join(['%s'] * len(fields))}
         sdata.update(self._meta)
         
         sql = """INSERT INTO %(table_name)s (%(fields)s)
                  VALUES (%(values)s)""" % sdata
         with self._env.db_transaction as db:
-            db(sql)
+            db(sql, [self._data[f] for f in fields])
             self._exists = True
             self._old_data.update(self._data)
 
@@ -163,15 +179,17 @@ class ModelBase(object):
             if self._data[key] != self._old_data[key]:
                 raise TracError('%s cannot be changed' % key)
         
-        sdata = {'where':','.join(['%s="%s"' % (k, self._data[k])
-                                   for k in self._meta['key_fields']]),
-                 'values':','.join(['%s="%s"' % (k, self._data[k]) 
-                                    for k in self._meta['non_key_fields']])}
+        setsql, setvalues = fields_to_kv_str(self._meta['non_key_fields'],
+                                             self._data, sep=',')
+        where, values = fields_to_kv_str(self._meta['key_fields'], self._data)
+        
+        sdata = {'where': where,
+                 'values': setsql}
         sdata.update(self._meta)
         sql = """UPDATE %(table_name)s SET %(values)s
                  WHERE %(where)s""" % sdata
         with self._env.db_transaction as db:
-            db(sql)
+            db(sql, setvalues + values)
             self._old_data.update(self._data)
     
     @classmethod
@@ -182,12 +200,11 @@ class ModelBase(object):
         
         sdata = {'fields':','.join(fields),}
         sdata.update(cls._meta)
-        sql = 'SELECT %(fields)s FROM %(table_name)s' % sdata
-        wherestr = ''
-        if where is not None:
-            wherestr = ' WHERE ' + ','.join(['%s="%s"' % (k, v) 
-                                             for k, v in where.iteritems()])
-        for row in env.db_query(sql+wherestr):
+        sql = r'SELECT %(fields)s FROM %(table_name)s' % sdata
+        wherestr, values = dict_to_kv_str(where)
+        if wherestr:
+            wherestr = ' WHERE ' + wherestr
+        for row in env.db_query(sql + wherestr, values):
             # we won't know which class we need until called
             model = cls.__new__(cls)
             data = dict([(fields[i], row[i]) for i in range(len(fields))])
@@ -269,8 +286,8 @@ class MultiProductEnvironmentProvider(Co
     def get_version(self):
         """Finds the current version of the bloodhound database schema"""
         rows = self.env.db_query("""
-            SELECT value FROM system WHERE name = '%s'
-            """ % DB_SYSTEM_KEY)
+            SELECT value FROM system WHERE name = %s
+            """, (DB_SYSTEM_KEY,))
         return int(rows[0][0]) if rows else -1
     
     # IEnvironmentSetupParticipant methods

Modified: incubator/bloodhound/trunk/bloodhound_multiproduct/tests/model.py
URL: http://svn.apache.org/viewvc/incubator/bloodhound/trunk/bloodhound_multiproduct/tests/model.py?rev=1297011&r1=1297010&r2=1297011&view=diff
==============================================================================
--- incubator/bloodhound/trunk/bloodhound_multiproduct/tests/model.py (original)
+++ incubator/bloodhound/trunk/bloodhound_multiproduct/tests/model.py Mon Mar  5 11:38:36 2012
@@ -18,7 +18,6 @@
 
 """Tests for multiproduct/model.py"""
 import unittest
-import os
 import tempfile
 import shutil
 
@@ -72,14 +71,17 @@ class ProductTestCase(unittest.TestCase)
         product2.insert()
         product3.insert()
         
-        products = list(Product.select(self.env, {'prefix':'tp'}))
+        products = list(Product.select(self.env, where={'prefix':'tp'}))
         self.assertEqual(1, len(products))
-        products = list(Product.select(self.env, {'name':'test project'}))
+        products = list(Product.select(self.env, where={'name':'test project'}))
         self.assertEqual(3, len(products))
+        products = list(Product.select(self.env, where={'prefix':'tp3',
+                                                        'name':'test project'}))
+        self.assertEqual(1, len(products))
     
     def test_update(self):
         """tests that we can use update to push data to the database"""
-        product = list(Product.select(self.env, {'prefix':'tp'}))[0]
+        product = list(Product.select(self.env, where={'prefix':'tp'}))[0]
         self.assertEqual('test project', product._data['name'])
         
         new_data = {'prefix':'tp', 
@@ -88,7 +90,7 @@ class ProductTestCase(unittest.TestCase)
         product._data.update(new_data)
         product.update()
         
-        comp_product = list(Product.select(self.env, {'prefix':'tp'}))[0]
+        comp_product = list(Product.select(self.env, where={'prefix':'tp'}))[0]
         self.assertEqual('updated', comp_product._data['name'])
     
     def test_update_key_change(self):
@@ -96,7 +98,7 @@ class ProductTestCase(unittest.TestCase)
         bad_data = {'prefix':'tp0', 
                     'name':'update', 
                     'description':'nothing'}
-        product = list(Product.select(self.env, {'prefix':'tp'}))[0]
+        product = list(Product.select(self.env, where={'prefix':'tp'}))[0]
         product._data.update(bad_data)
         self.assertRaises(TracError, product.update)
     
@@ -107,7 +109,7 @@ class ProductTestCase(unittest.TestCase)
         product._data.update(data)
         product.insert()
         
-        check_products = list(Product.select(self.env, {'prefix':'new'}))
+        check_products = list(Product.select(self.env, where={'prefix':'new'}))
         
         self.assertEqual(product._data['prefix'],
                          check_products[0]._data['prefix'])
@@ -124,15 +126,15 @@ class ProductTestCase(unittest.TestCase)
     
     def test_delete(self):
         """test that we are able to delete Products"""
-        product = list(Product.select(self.env, {'prefix':'tp'}))[0]
+        product = list(Product.select(self.env, where={'prefix':'tp'}))[0]
         product.delete()
         
-        post = list(Product.select(self.env, {'prefix':'tp'}))
+        post = list(Product.select(self.env, where={'prefix':'tp'}))
         self.assertEqual(0, len(post))
         
     def test_delete_twice(self):
         """test that we error when deleting twice on the same key"""
-        product = list(Product.select(self.env, {'prefix':'tp'}))[0]
+        product = list(Product.select(self.env, where={'prefix':'tp'}))[0]
         product.delete()
         
         self.assertRaises(TracError, product.delete)
@@ -142,7 +144,7 @@ class ProductTestCase(unittest.TestCase)
         prefix = self.default_data['prefix']
         name = self.default_data['name']
         description = self.default_data['description']
-        product = list(Product.select(self.env, {'prefix':prefix}))[0]
+        product = list(Product.select(self.env, where={'prefix':prefix}))[0]
         self.assertEqual(prefix, product.prefix)
         self.assertEqual(name, product.name)
         self.assertEqual(description, product.description)
@@ -150,7 +152,7 @@ class ProductTestCase(unittest.TestCase)
     def test_field_set(self):
         """tests that we can use table.field = something to set field data"""
         prefix = self.default_data['prefix']
-        product = list(Product.select(self.env, {'prefix':prefix}))[0]
+        product = list(Product.select(self.env, where={'prefix':prefix}))[0]
         
         new_description = 'test change of description'
         product.description = new_description