You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by ca...@codespot.com on 2012/09/25 23:00:20 UTC

[cassandra-dbapi2] 4 new revisions pushed by pcannon@gmail.com on 2012-09-25 21:00 GMT

4 new revisions:

Revision: d860b4b2250d
Author:   paul cannon <pa...@thepaul.org>
Date:     Tue Sep 25 11:26:45 2012
Log:      support snappy compression, if installed
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=d860b4b2250d

Revision: 56d24cd277c2
Author:   paul cannon <pa...@thepaul.org>
Date:     Tue Sep 25 13:44:33 2012
Log:      update tests; move thrift_client in test_cql
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=56d24cd277c2

Revision: 553647f4b1b9
Author:   paul cannon <pa...@thepaul.org>
Date:     Tue Sep 25 13:46:01 2012
Log:      add basic tests for native protocol
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=553647f4b1b9

Revision: 96b064c3159b
Author:   paul cannon <pa...@thepaul.org>
Date:     Tue Sep 25 13:45:33 2012
Log:      support for callbacks on native-proto events...
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=96b064c3159b

==============================================================================
Revision: d860b4b2250d
Author:   paul cannon <pa...@thepaul.org>
Date:     Tue Sep 25 11:26:45 2012
Log:      support snappy compression, if installed

http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=d860b4b2250d

Modified:
  /cql/connection.py
  /cql/native.py

=======================================
--- /cql/connection.py	Mon Sep 17 04:34:57 2012
+++ /cql/connection.py	Tue Sep 25 11:26:45 2012
@@ -29,8 +29,12 @@
          * user .........: username used in authentication (optional).
          * password .....: password used in authentication (optional).
          * cql_version...: CQL version to use (optional).
-        * compression...: the sort of compression to use by default;
-        *                 overrideable per Cursor object. (optional).
+        * compression...: whether to use compression. For Thrift connections,
+        *                 this can be None or the name of some supported
+        *                 compression type (like "GZIP"). For native
+        *                 connections, this is treated as a boolean, and if
+        *                 true, the connection will try to find a type of
+        *                 compression supported by both sides.
          """
          self.host = host
          self.port = port
@@ -85,12 +89,37 @@
          return curs

  # TODO: Pull connections out of a pool instead.
-def connect(host, port=9160, keyspace=None, user=None, password=None,
-            cql_version=None, native=False):
+def connect(host, port=None, keyspace=None, user=None, password=None,
+            cql_version=None, native=False, compression=None):
+    """
+    Create a connection to a Cassandra node.
+
+    @param host Hostname of Cassandra node.
+    @param port Port number to connect to (default 9160 for thrift, 8000
+                for native)
+    @param keyspace If set, authenticate to this keyspace on connection.
+    @param user If set, use this username in authentication.
+    @param password If set, use this password in authentication.
+    @param cql_version If set, try to use the given CQL version. If unset,
+                uses the default for the connection.
+    @param compression Whether to use compression. For Thrift connections,
+                this can be None or the name of some supported compression
+                type (like "GZIP"). For native connections, this is treated
+                as a boolean, and if true, the connection will try to find
+                a type of compression supported by both sides.
+
+    @returns a Connection instance of the appropriate subclass.
+    """
+
      if native:
          from native import NativeConnection
          connclass = NativeConnection
+        if port is None:
+            port = 8000
      else:
          from thrifteries import ThriftConnection
          connclass = ThriftConnection
-    return connclass(host, port, keyspace, user, password, cql_version)
+        if port is None:
+            port = 9160
+    return connclass(host, port, keyspace, user, password,
+                     cql_version=cql_version, compression=compression)
=======================================
--- /cql/native.py	Mon Sep 17 04:34:57 2012
+++ /cql/native.py	Tue Sep 25 11:26:45 2012
@@ -85,12 +85,15 @@
                                   % (self.__class__.__name__, pname))
              setattr(self, pname, pval)

-    def send(self, f, streamid, compression=False):
+    def send(self, f, streamid, compression=None):
          body = StringIO()
          self.send_body(body)
          body = body.getvalue()
          version = PROTOCOL_VERSION | HEADER_DIRECTION_FROM_CLIENT
-        flags = 0 # no compression supported yet
+        flags = 0
+        if compression is not None and len(body) > 0:
+            body = compression(body)
+            flags |= 0x1
          msglen = int32_pack(len(body))
          header = '%c%c%c%c%s' % (version, flags, streamid, self.opcode, msglen)
          f.write(header)
@@ -102,7 +105,7 @@
          return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs))
      __repr__ = __str__

-def read_frame(f):
+def read_frame(f, decompressor=None):
      header = f.read(8)
      version, flags, stream, opcode = map(ord, header[:4])
      body_len = int32_unpack(header[4:])
@@ -111,9 +114,14 @@
      assert version & HEADER_DIRECTION_MASK == HEADER_DIRECTION_TO_CLIENT, \
              "Unexpected request from server with opcode %04x, stream id %r" % (opcode, stream)
      assert body_len >= 0, "Invalid CQL protocol body_len %r" % body_len
+    body = f.read(body_len)
+    if flags & 0x1:
+        if decompressor is None:
+            raise ProtocolException("No decompressor available for compressed frame!")
+        body = decompressor(body)
+        flags ^= 0x1
      if flags:
          warn("Unknown protocol flags set: %02x. May cause problems." % flags)
-    body = f.read(body_len)
      msgclass = _message_types_by_opcode[opcode]
      msg = msgclass.recv_body(StringIO(body))
      msg.stream_id = stream
@@ -670,10 +678,10 @@
          self.rowcount = len(self.result)

      def get_compression(self):
-        return None
+        return self._connection.compression

      def set_compression(self, val):
-        if val is not None:
+        if val != self.get_compression():
              raise NotImplementedError("Setting per-cursor compression is not "
                                        "supported in NativeCursor.")

@@ -702,6 +710,20 @@
      def close(self):
          pass

+locally_supported_compressions = {}
+
+try:
+    import snappy
+except ImportError:
+    pass
+else:
+    # work around apparently buggy snappy decompress
+    def decompress(byts):
+        if byts == '\x00':
+            return ''
+        return snappy.decompress(byts)
+    locally_supported_compressions['snappy'] = (snappy.compress, decompress)
+
  class NativeConnection(Connection):
      cursorclass = NativeCursor

@@ -710,6 +732,7 @@
          self.responses = {}
          self.waiting = {}
          self.conn_ready = False
+        self.compressor = self.decompressor = None
          Connection.__init__(self, *args, **kwargs)

      def establish_connection(self):
@@ -721,7 +744,7 @@
          self.open_socket = True
          supported = self.wait_for_request(OptionsMessage())
          self.supported_cql_versions = supported.cqlversions
-        self.supported_compressions = supported.options['COMPRESSION']
+        self.remote_supported_compressions = supported.options['COMPRESSION']

          if self.cql_version:
              if self.cql_version not in self.supported_cql_versions:
@@ -733,20 +756,30 @@
              self.cql_version = self.supported_cql_versions[0]

          opts = {}
+        compresstype = None
          if self.compression:
-            if self.compression not in self.supported_compressions:
-                raise ProgrammingError("Compression type %r is not supported by"
-                                       " remote. Supported compression types: %r"
-                                       % (self.compression, self.supported_compressions))
-            # XXX: Remove this once some compressions are supported
-            raise NotImplementedError("CQL driver does not yet support compression")
-            opts['COMPRESSION'] = self.compression
+            overlap = set(locally_supported_compressions) \
+                    & set(self.remote_supported_compressions)
+            if len(overlap) == 0:
+                warn("No available compression types supported on both ends."
+                     " locally supported: %r. remotely supported: %r"
+                     % (locally_supported_compressions,
+                        self.remote_supported_compressions))
+            else:
+                compresstype = iter(overlap).next() # choose any
+                opts['COMPRESSION'] = compresstype
+                compr, decompr = locally_supported_compressions[compresstype]
+                # set the decompressor here, but set the compressor only after
+                # a successful Ready message
+                self.decompressor = decompr

          sm = StartupMessage(cqlversion=self.cql_version, options=opts)
          startup_response = self.wait_for_request(sm)
          while True:
              if isinstance(startup_response, ReadyMessage):
                  self.conn_ready = True
+                if compresstype:
+                    self.compressor = compr
                  break
              if isinstance(startup_response, AuthenticateMessage):
                  self.authenticator = startup_response.authenticator
@@ -779,6 +812,11 @@

          return self.wait_for_requests(msg)[0]

+    def send_msg(self, msg):
+        reqid = self.make_reqid()
+        msg.send(self.socketf, reqid, compression=self.compressor)
+        return reqid
+
      def wait_for_requests(self, *msgs):
          """
          Given any number of message objects, send them all to the server
@@ -789,9 +827,8 @@

          reqids = []
          for msg in msgs:
-            reqid = self.make_reqid()
+            reqid = self.send_msg(msg)
              reqids.append(reqid)
-            msg.send(self.socketf, reqid)
          resultdict = self.wait_for_results(*reqids)
          return [resultdict[reqid] for reqid in reqids]

@@ -813,7 +850,7 @@
                  results[r] = result
                  waiting_for.remove(r)
          while waiting_for:
-            newmsg = read_frame(self.socketf)
+            newmsg = read_frame(self.socketf, decompressor=self.decompressor)
              if newmsg.stream_id in waiting_for:
                  results[newmsg.stream_id] = newmsg
                  waiting_for.remove(newmsg.stream_id)
@@ -867,6 +904,5 @@
          it may have to wait until something else waits on a result.
          """

-        reqid = self.make_reqid()
-        msg.send(self.socketf, reqid)
+        reqid = self.send_msg(msg)
          self.callback_when(reqid, cb)

==============================================================================
Revision: 56d24cd277c2
Author:   paul cannon <pa...@thepaul.org>
Date:     Tue Sep 25 13:44:33 2012
Log:      update tests; move thrift_client in test_cql

http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=56d24cd277c2

Modified:
  /cql/cqltypes.py
  /cql/cursor.py
  /cql/native.py
  /test/test_connection.py
  /test/test_cql.py
  /test/test_prepared_queries.py

=======================================
--- /cql/cqltypes.py	Wed Sep 12 13:32:53 2012
+++ /cql/cqltypes.py	Tue Sep 25 13:44:33 2012
@@ -137,7 +137,7 @@

      """

-    if isinstance(casstype, CassandraType):
+    if isinstance(casstype, (CassandraType, CassandraTypeType)):
          return casstype
      try:
          return parse_casstype_args(casstype)
=======================================
--- /cql/cursor.py	Tue Sep 11 17:31:33 2012
+++ /cql/cursor.py	Tue Sep 25 13:44:33 2012
@@ -189,5 +189,5 @@
      ###

      def __checksock(self):
-        if self._connection is None:
+        if self._connection is None or not self._connection.open_socket:
              raise cql.ProgrammingError("Cursor has been closed.")
=======================================
--- /cql/native.py	Tue Sep 25 11:26:45 2012
+++ /cql/native.py	Tue Sep 25 13:44:33 2012
@@ -388,7 +388,7 @@
          return CqlResult(column_metadata=colspecs, rows=rows)

      @classmethod
-    def recv_results_prepared(self, f):
+    def recv_results_prepared(cls, f):
          queryid = read_int(f)
          colspecs = cls.recv_results_metadata(f)
          return (queryid, colspecs)
@@ -625,7 +625,8 @@
          return self._connection.wait_for_request(QueryMessage(query=query))

      def get_response_prepared(self, prepared_query, params):
-        em = ExecuteMessage(queryid=prepared_query.itemid, queryparams=params)
+        qparams = [params[pname] for pname in prepared_query.paramnames]
+        em = ExecuteMessage(queryid=prepared_query.itemid, queryparams=qparams)
          return self._connection.wait_for_request(em)

      def get_column_metadata(self, column_id):
=======================================
--- /test/test_connection.py	Thu Sep 20 10:38:41 2012
+++ /test/test_connection.py	Tue Sep 25 13:44:33 2012
@@ -29,34 +29,35 @@
  randstring = test_cql.randstring
  del test_cql

+@contextlib.contextmanager
+def with_keyspace(randstr, cursor, cqlver):
+    ksname = randstr + '_conntest_' + cqlver.encode('ascii').replace('.', '_')
+    if cqlver.startswith('2.'):
+        cursor.execute("create keyspace '%s' with strategy_class='SimpleStrategy'"
+                       " and strategy_options:replication_factor=1;" % ksname)
+        cursor.execute("use '%s'" % ksname)
+        yield ksname
+        cursor.execute("use system;")
+        cursor.execute("drop keyspace '%s'" % ksname)
+    elif cqlver == '3.0.0-beta1': # for cassandra 1.1
+        cursor.execute("create keyspace \"%s\" with strategy_class='SimpleStrategy'"
+                       " and strategy_options:replication_factor=1;" % ksname)
+        cursor.execute('use "%s"' % ksname)
+        yield ksname
+        cursor.execute('use system;')
+        cursor.execute('drop keyspace "%s"' % ksname)
+    else:
+        cursor.execute("create keyspace \"%s\" with replication = "
+                       "{'class': 'SimpleStrategy', 'replication_factor': 1};" % ksname)
+        cursor.execute('use "%s"' % ksname)
+        yield ksname
+        cursor.execute('use system;')
+        cursor.execute('drop keyspace "%s"' % ksname)
+
  class TestConnection(unittest.TestCase):
      def setUp(self):
          self.randstr = randstring()
-
-    @contextlib.contextmanager
-    def with_keyspace(self, cursor, cqlver):
-        ksname = self.randstr + '_conntest_' + cqlver.replace('.', '_')
-        if cqlver.startswith('2.'):
-            cursor.execute("create keyspace '%s' with strategy_class='SimpleStrategy'"
-                           " and strategy_options:replication_factor=1;" % ksname)
-            cursor.execute("use '%s'" % ksname)
-            yield ksname
-            cursor.execute("use system;")
-            cursor.execute("drop keyspace '%s'" % ksname)
-        elif cqlver == '3.0.0-beta1': # for cassandra 1.1
-            cursor.execute("create keyspace \"%s\" with strategy_class='SimpleStrategy'"
-                           " and strategy_options:replication_factor=1;" % ksname)
-            cursor.execute('use "%s"' % ksname)
-            yield ksname
-            cursor.execute('use system;')
-            cursor.execute('drop keyspace "%s"' % ksname)
-        else:
-            cursor.execute("create keyspace \"%s\" with replication = "
-                           "{'class': 'SimpleStrategy', 'replication_factor': 1};" % ksname)
-            cursor.execute('use "%s"' % ksname)
-            yield ksname
-            cursor.execute('use system;')
-            cursor.execute('drop keyspace "%s"' % ksname)
+        self.with_keyspace = lambda curs, ver: with_keyspace(self.randstr, curs, ver)

      def test_connecting_with_cql_version(self):
          conn = cql.connect(TEST_HOST, TEST_PORT, cql_version='2.0.0')
@@ -100,4 +101,4 @@
              curs.execute('create table blah (a int primary key, b int);')
              curs.execute('select * from blah;')
          conn.close()
-        self.assertRaises(TTransport.TTransportException, curs.execute, 'select * from blah;')
+        self.assertRaises(cql.ProgrammingError, curs.execute, 'select * from blah;')
=======================================
--- /test/test_cql.py	Thu Sep 20 10:58:57 2012
+++ /test/test_cql.py	Tue Sep 25 13:44:33 2012
@@ -52,7 +52,6 @@
      client.transport = transport
      client.transport.open()
      return client
-thrift_client = get_thrift_client()

  def uuid1bytes_to_millis(uuidbytes):
      return (uuid.UUID(bytes=uuidbytes).get_time() / 10000) - 12219292800000L
@@ -164,6 +163,8 @@
      keyspace = None

      def setUp(self):
+        self.thrift_client = get_thrift_client()
+
          # all tests in this module are against cql 2. change would be welcomed.
          dbconn = cql.connect(TEST_HOST, TEST_PORT, cql_version='2.0.0')
          self.cursor = dbconn.cursor()
@@ -186,7 +187,7 @@
          return ksname

      def get_partitioner(self):
-        return thrift_client.describe_partitioner()
+        return self.thrift_client.describe_partitioner()

      def assertIsSubclass(self, class_a, class_b):
          assert issubclass(class_a, class_b), '%r is not a subclass of %r' % (class_a, class_b)
@@ -519,13 +520,13 @@
          """, {'ks': ksname2})

          # TODO: temporary (until this can be done with CQL).
-        ksdef = thrift_client.describe_keyspace(ksname1)
+        ksdef = self.thrift_client.describe_keyspace(ksname1)

          strategy_class = "org.apache.cassandra.locator.NetworkTopologyStrategy"
          self.assertEqual(ksdef.strategy_class, strategy_class)
          self.assertEqual(ksdef.strategy_options['DC1'], "1")

-        ksdef = thrift_client.describe_keyspace(ksname2)
+        ksdef = self.thrift_client.describe_keyspace(ksname2)

          strategy_class = "org.apache.cassandra.locator.NetworkTopologyStrategy"
          self.assertEqual(ksdef.strategy_class, strategy_class)
@@ -542,14 +543,14 @@
          """, {'ks': ksname})

          # TODO: temporary (until this can be done with CQL).
-        thrift_client.describe_keyspace(ksname)
+        self.thrift_client.describe_keyspace(ksname)

          cursor.execute('DROP SCHEMA :ks;', {'ks': ksname})

          # Technically this should throw a ttypes.NotFound(), but this is
          # temporary and so not worth requiring it on PYTHONPATH.
          self.assertRaises(Exception,
-                          thrift_client.describe_keyspace,
+                          self.thrift_client.describe_keyspace,
                            ksname)

      def test_create_column_family(self):
@@ -573,7 +574,7 @@
          """)

          # TODO: temporary (until this can be done with CQL).
-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
          self.assertEqual(len(ksdef.cf_defs), 1)
          cfam= ksdef.cf_defs[0]
          self.assertEqual(len(cfam.column_metadata), 4)
@@ -597,7 +598,7 @@
          # No column defs
          cursor.execute("""CREATE COLUMNFAMILY NewCf3
                              (KEY varint PRIMARY KEY) WITH comparator = bigint""")
-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
          self.assertEqual(len(ksdef.cf_defs), 2)
          cfam = [i for i in ksdef.cf_defs if i.name == "NewCf3"][0]
           
self.assertEqual(cfam.comparator_type, "org.apache.cassandra.db.marshal.LongType")
@@ -606,7 +607,7 @@
          cursor.execute("""CREATE COLUMNFAMILY NewCf4
                              (KEY varint PRIMARY KEY, 'a' varint, 'b' varint)
                              WITH comparator = text;""")
-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
          self.assertEqual(len(ksdef.cf_defs), 3)
          cfam = [i for i in ksdef.cf_defs if i.name == "NewCf4"][0]
          self.assertEqual(len(cfam.column_metadata), 2)
@@ -626,12 +627,12 @@
          cursor.execute('CREATE COLUMNFAMILY CF4Drop (KEY varint PRIMARY KEY);')

          # TODO: temporary (until this can be done with CQL).
-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
          assert len(ksdef.cf_defs), "Column family not created!"

          cursor.execute('DROP COLUMNFAMILY CF4Drop;')

-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
          assert not len(ksdef.cf_defs), "Column family not deleted!"

      def test_create_indexs(self):
@@ -643,7 +644,7 @@
          cursor.execute("CREATE INDEX ON CreateIndex1 (stuff)")

          # TODO: temporary (until this can be done with CQL).
-        ksdef = thrift_client.describe_keyspace(self.keyspace)
+        ksdef = self.thrift_client.describe_keyspace(self.keyspace)
          cfam = [i for i in ksdef.cf_defs if i.name == "CreateIndex1"][0]
          items = [i for i in cfam.column_metadata if i.name == "items"][0]
          stuff = [i for i in cfam.column_metadata if i.name == "stuff"][0]
@@ -667,7 +668,7 @@
          cursor.execute("CREATE COLUMNFAMILY IndexedCF (KEY text PRIMARY KEY, n text)")
          cursor.execute("CREATE INDEX namedIndex ON IndexedCF (n)")

-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
          columns = ksdef.cf_defs[0].column_metadata

          self.assertEqual(columns[0].index_name, "namedIndex")
@@ -676,7 +677,7 @@
          # testing "DROP INDEX <INDEX_NAME>"
          cursor.execute("DROP INDEX namedIndex")

-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
          columns = ksdef.cf_defs[0].column_metadata

          self.assertEqual(columns[0].index_type, None)
@@ -1243,7 +1244,7 @@
          """)

          # TODO: temporary (until this can be done with CQL).
-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
          self.assertEqual(len(ksdef.cf_defs), 1)
          cfam = ksdef.cf_defs[0]

@@ -1252,7 +1253,7 @@
          # testing "add a new column"
          cursor.execute("ALTER COLUMNFAMILY NewCf1 ADD name varchar")

-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
          self.assertEqual(len(ksdef.cf_defs), 1)
          columns = ksdef.cf_defs[0].column_metadata

@@ -1263,7 +1264,7 @@
          # testing "alter a column type"
          cursor.execute("ALTER COLUMNFAMILY NewCf1 ALTER name TYPE ascii")

-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
          self.assertEqual(len(ksdef.cf_defs), 1)
          columns = ksdef.cf_defs[0].column_metadata

@@ -1279,7 +1280,7 @@
          # testing 'drop an existing column'
          cursor.execute("ALTER COLUMNFAMILY NewCf1 DROP name")

-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
          self.assertEqual(len(ksdef.cf_defs), 1)
          columns = ksdef.cf_defs[0].column_metadata

@@ -1396,7 +1397,7 @@
          """)

          # TODO: temporary (until this can be done with CQL).
-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
          cfdef = ksdef.cf_defs[0]

          self.assertEqual(len(ksdef.cf_defs), 1)
=======================================
--- /test/test_prepared_queries.py	Tue Sep 11 16:26:59 2012
+++ /test/test_prepared_queries.py	Tue Sep 25 13:44:33 2012
@@ -25,7 +25,7 @@

  TEST_HOST = os.environ.get('CQL_TEST_HOST', 'localhost')
  TEST_PORT = int(os.environ.get('CQL_TEST_PORT', 9170))
-TEST_CQL_VERSION = '3.0.0-beta1'
+TEST_CQL_VERSION = os.environ.get('CQL_TEST_VERSION', '3.0.0-beta1')

  sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

@@ -40,7 +40,7 @@
      def setUp(self):
          try:
              self.dbconn = cql.connect(TEST_HOST, TEST_PORT, cql_version=TEST_CQL_VERSION)
-        except cql.cursor.TApplicationException:
+        except cql.thrifteries.TApplicationException:
              # set_cql_version (and thus, cql3) not supported; skip all of these
              self.cursor = None
              return

==============================================================================
Revision: 553647f4b1b9
Author:   paul cannon <pa...@thepaul.org>
Date:     Tue Sep 25 13:46:01 2012
Log:      add basic tests for native protocol

http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=553647f4b1b9

Added:
  /test/test_native_connection.py

=======================================
--- /dev/null
+++ /test/test_native_connection.py	Tue Sep 25 13:46:01 2012
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# to configure behavior, define $CQL_TEST_HOST to the destination address
+# for native connections, and $CQL_TEST_NATIVE_PORT to the associated port.
+
+import os
+import unittest
+import contextlib
+from thrift.transport import TTransport
+import test_cql
+from test_prepared_queries import MIN_THRIFT_FOR_CQL_3_0_0_FINAL
+from test_connection import with_keyspace, TEST_HOST, randstring, cql
+
+TEST_NATIVE_PORT = int(os.environ.get('CQL_TEST_NATIVE_PORT', '8000'))
+
+class TestNativeConnection(unittest.TestCase):
+    def setUp(self):
+        self.randstr = randstring()
+        self.with_keyspace = lambda curs, ver: with_keyspace(self.randstr, curs, ver)
+
+    def test_connecting_with_cql_version(self):
+        # 2.0.0 won't be supported by binary protocol
+        self.assertRaises(cql.ProgrammingError,
+                          cql.connect, TEST_HOST, TEST_NATIVE_PORT,
+                          native=True, cql_version='2.0.0')
+
+    def test_connecting_with_keyspace(self):
+        # this conn is just for creating the keyspace
+        conn = cql.connect(TEST_HOST, TEST_NATIVE_PORT, native=True)
+        curs = conn.cursor()
+        with self.with_keyspace(curs, conn.cql_version) as ksname:
+            curs.execute('create table blah1_%s (a int primary key, b int);' % self.randstr)
+            conn2 = cql.connect(TEST_HOST, TEST_NATIVE_PORT, keyspace=ksname,
+                                native=True, cql_version=conn.cql_version)
+            curs2 = conn2.cursor()
+            curs2.execute('select * from blah1_%s;' % self.randstr)
+            conn2.close()
+
+    def test_execution_fails_after_close(self):
+        conn = cql.connect(TEST_HOST, TEST_NATIVE_PORT, native=True)
+        curs = conn.cursor()
+        with self.with_keyspace(curs, conn.cql_version) as ksname:
+            curs.execute('create table blah (a int primary key, b int);')
+            curs.execute('select * from blah;')
+        conn.close()
+        self.assertRaises(cql.ProgrammingError, curs.execute, 'select * from blah;')
+
+    def try_basic_stuff(self, conn):
+        curs = conn.cursor()
+        with self.with_keyspace(curs, conn.cql_version) as ksname:
+            curs.execute('create table moo (a text primary key, b int, c float);')
+            curs.execute("insert into moo (a, b, c) values (:d, :e, :f);",
+                         {'d': 'hi', 'e': 1234, 'f': 1.234});
+            qprep = curs.prepare_query("select * from moo where a = :fish;")
+            curs.execute_prepared(qprep, {'fish': 'hi'})
+            res = curs.fetchall()
+            self.assertEqual(len(res), 1)
+            self.assertEqual(res[0][0], 'hi')
+            self.assertEqual(res[0][1], 1234)
+            self.assertAlmostEqual(res[0][2], 1.234)
+
+    def test_connecting_without_compression(self):
+        conn = cql.connect(TEST_HOST, TEST_NATIVE_PORT, native=True, compression=False)
+        self.assertEqual(conn.compressor, None)
+        self.try_basic_stuff(conn)
+
+    def test_connecting_with_compression(self):
+        try:
+            import snappy
+        except ImportError:
+            if hasattr(unittest, 'skipTest'):
+                unittest.skipTest('Snappy compression not available')
+            else:
+                return
+        conn = cql.connect(TEST_HOST, TEST_NATIVE_PORT, native=True, compression=True)
+        self.assertEqual(conn.compressor, snappy.compress)
+        self.try_basic_stuff(conn)

==============================================================================
Revision: 96b064c3159b
Author:   paul cannon <pa...@thepaul.org>
Date:     Tue Sep 25 13:45:33 2012
Log:      support for callbacks on native-proto events

i.e., STATUS_CHANGE, TOPOLOGY_CHANGE

http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=96b064c3159b

Modified:
  /cql/native.py

=======================================
--- /cql/native.py	Tue Sep 25 13:44:33 2012
+++ /cql/native.py	Tue Sep 25 13:45:33 2012
@@ -15,7 +15,8 @@
  # limitations under the License.

  import cql
-from cql.marshal import int32_pack, int32_unpack, uint16_pack, uint16_unpack
+from cql.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
+                         int8_pack, int8_unpack)
  from cql.cqltypes import lookup_cqltype
  from cql.connection import Connection
  from cql.cursor import Cursor, _VOID_DESCRIPTION, _COUNT_DESCRIPTION
@@ -32,8 +33,6 @@
  PROTOCOL_VERSION             = 0x01
  PROTOCOL_VERSION_MASK        = 0x7f

-# XXX: should these be called request/response instead? unclear which one will
-# apply if/when the server initiates streams in the other direction.
  HEADER_DIRECTION_FROM_CLIENT = 0x00
  HEADER_DIRECTION_TO_CLIENT   = 0x80
  HEADER_DIRECTION_MASK        = 0x80
@@ -95,7 +94,8 @@
              body = compression(body)
              flags |= 0x1
          msglen = int32_pack(len(body))
-        header = '%c%c%c%c%s' % (version, flags, streamid, self.opcode, msglen)
+        header = ''.join(map(int8_pack, (version, flags, streamid, self.opcode))) \
+                 + msglen
          f.write(header)
          if len(body) > 0:
              f.write(body)
@@ -107,7 +107,7 @@

  def read_frame(f, decompressor=None):
      header = f.read(8)
-    version, flags, stream, opcode = map(ord, header[:4])
+    version, flags, stream, opcode = map(int8_unpack, header[:4])
      body_len = int32_unpack(header[4:])
      assert version & PROTOCOL_VERSION_MASK == PROTOCOL_VERSION, \
              "Unsupported CQL protocol version %d" % version
@@ -496,10 +496,10 @@


  def read_byte(f):
-    return ord(f.read(1))
+    return int8_unpack(f.read(1))

  def write_byte(f, b):
-    f.write(chr(b))
+    f.write(int8_pack(b))

  def read_int(f):
      return int32_unpack(f.read(4))
@@ -734,6 +734,7 @@
          self.waiting = {}
          self.conn_ready = False
          self.compressor = self.decompressor = None
+        self.event_watchers = {}
          Connection.__init__(self, *args, **kwargs)

      def establish_connection(self):
@@ -838,6 +839,10 @@
          Given any number of stream-ids, wait until responses have arrived  
for
          each one, and return a dictionary mapping the stream-ids to the
          appropriate results.
+
+        For internal use, None may be passed in place of a reqid, which  
will
+        be considered satisfied when a message of any kind is received  
(and, if
+        appropriate, handled).
          """

          waiting_for = set(reqids)
@@ -857,6 +862,9 @@
                  waiting_for.remove(newmsg.stream_id)
              else:
                  self.handle_incoming(newmsg)
+            if None in waiting_for:
+                results[None] = newmsg
+                waiting_for.remove(None)
          return results

      def wait_for_result(self, reqid):
@@ -907,3 +915,79 @@

          reqid = self.send_msg(msg)
          self.callback_when(reqid, cb)
+
+    def handle_pushed(self, msg):
+        """
+        Process an incoming message originated by the server.
+        """
+        watchers = self.event_watchers.get(msg.eventtype, ())
+        for cb in watchers:
+            cb(msg.eventargs)
+
+    def register_watcher(self, eventtype, cb):
+        """
+        Request that any events of the given type be passed to the given
+        callback when they arrive. Note that the callback may not be called
+        immediately upon the arrival of the event packet; it may have to  
wait
+        until something else waits on a result, or until wait_for_event() is
+        called.
+
+        If the event type has not been registered for already, this may
+        block while a new REGISTER message is sent to the server.
+
+        The available event types are in the cql.native.known_event_types
+        list.
+
+        When an event arrives, a dictionary will be passed to the callback
+        with the info about the event. Some example result dictionaries:
+
+        (For STATUS_CHANGE events:)
+
+          {'changetype': u'DOWN', 'address': ('12.114.19.76', 8000)}
+
+        (For TOPOLOGY_CHANGE events:)
+
+          {'changetype': u'NEW_NODE', 'address': ('19.10.122.13', 8000)}
+        """
+
+        if isinstance(eventtype, str):
+            eventtype = eventtype.decode('utf8')
+        try:
+            watchers = self.event_watchers[eventtype]
+        except KeyError:
+            ans =  
self.wait_for_request(RegisterMessage(eventlist=(eventtype,)))
+            if isinstance(ans, ErrorMessage):
+                raise cql.ProgrammingError("Server did not accept  
registration"
+                                           " for %s events: %s"
+                                           % (eventtype, ans.summarymsg()))
+            watchers = self.event_watchers.setdefault(eventtype, [])
+        watchers.append(cb)
+
+    def unregister_watcher(self, eventtype, cb):
+        """
+        Given an eventtype and a callback previously registered with
+        register_watcher(), remove that callback from the list of watchers  
for
+        the given event type.
+        """
+
+        if isinstance(eventtype, str):
+            eventtype = eventtype.decode('utf8')
+        self.event_watchers[eventtype].remove(cb)
+
+    def wait_for_event(self):
+        """
+        Wait for any sort of event to arrive, and handle it via the
+        registered callbacks. It is recommended that some event watchers
+        be registered before calling this; otherwise, no events will be
+        sent by the server.
+        """
+        eventsseen = []
+        def i_saw_an_event(ev):
+            eventsseen.append(ev)
+        wlists = self.event_watchers.values()
+        for wlist in wlists:
+            wlist.append(i_saw_an_event)
+        while not eventsseen:
+            self.wait_for_result(None)
+        for wlist in wlists:
+            wlist.remove(i_saw_an_event)