You are viewing a plain-text version of this content; the canonical link to it was not preserved in this plain-text copy.
Posted to commits@cassandra.apache.org by ca...@codespot.com on 2012/09/25 23:00:20 UTC
[cassandra-dbapi2] 4 new revisions pushed by pcannon@gmail.com on 2012-09-25 21:00 GMT
4 new revisions:
Revision: d860b4b2250d
Author: paul cannon <pa...@thepaul.org>
Date: Tue Sep 25 11:26:45 2012
Log: support snappy compression, if installed
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=d860b4b2250d
Revision: 56d24cd277c2
Author: paul cannon <pa...@thepaul.org>
Date: Tue Sep 25 13:44:33 2012
Log: update tests; move thrift_client in test_cql
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=56d24cd277c2
Revision: 553647f4b1b9
Author: paul cannon <pa...@thepaul.org>
Date: Tue Sep 25 13:46:01 2012
Log: add basic tests for native protocol
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=553647f4b1b9
Revision: 96b064c3159b
Author: paul cannon <pa...@thepaul.org>
Date: Tue Sep 25 13:45:33 2012
Log: support for callbacks on native-proto events (i.e., STATUS_CHANGE, TOPOLOGY_CHANGE)
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=96b064c3159b
==============================================================================
Revision: d860b4b2250d
Author: paul cannon <pa...@thepaul.org>
Date: Tue Sep 25 11:26:45 2012
Log: support snappy compression, if installed
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=d860b4b2250d
Modified:
/cql/connection.py
/cql/native.py
=======================================
--- /cql/connection.py Mon Sep 17 04:34:57 2012
+++ /cql/connection.py Tue Sep 25 11:26:45 2012
@@ -29,8 +29,12 @@
* user .........: username used in authentication (optional).
* password .....: password used in authentication (optional).
* cql_version...: CQL version to use (optional).
- * compression...: the sort of compression to use by default;
- * overrideable per Cursor object. (optional).
+ * compression...: whether to use compression. For Thrift connections,
+ * this can be None or the name of some supported
+ * compression type (like "GZIP"). For native
+ * connections, this is treated as a boolean, and if
+ * true, the connection will try to find a type of
+ * compression supported by both sides.
"""
self.host = host
self.port = port
@@ -85,12 +89,37 @@
return curs
# TODO: Pull connections out of a pool instead.
-def connect(host, port=9160, keyspace=None, user=None, password=None,
- cql_version=None, native=False):
+def connect(host, port=None, keyspace=None, user=None, password=None,
+ cql_version=None, native=False, compression=None):
+ """
+ Create a connection to a Cassandra node.
+
+ @param host Hostname of Cassandra node.
+ @param port Port number to connect to (default 9160 for thrift, 8000
+ for native)
+ @param keyspace If set, authenticate to this keyspace on connection.
+ @param user If set, use this username in authentication.
+ @param password If set, use this password in authentication.
+ @param cql_version If set, try to use the given CQL version. If unset,
+ uses the default for the connection.
+ @param compression Whether to use compression. For Thrift connections,
+ this can be None or the name of some supported compression
+ type (like "GZIP"). For native connections, this is treated
+ as a boolean, and if true, the connection will try to find
+ a type of compression supported by both sides.
+
+ @returns a Connection instance of the appropriate subclass.
+ """
+
if native:
from native import NativeConnection
connclass = NativeConnection
+ if port is None:
+ port = 8000
else:
from thrifteries import ThriftConnection
connclass = ThriftConnection
- return connclass(host, port, keyspace, user, password, cql_version)
+ if port is None:
+ port = 9160
+ return connclass(host, port, keyspace, user, password,
+ cql_version=cql_version, compression=compression)
=======================================
--- /cql/native.py Mon Sep 17 04:34:57 2012
+++ /cql/native.py Tue Sep 25 11:26:45 2012
@@ -85,12 +85,15 @@
% (self.__class__.__name__, pname))
setattr(self, pname, pval)
- def send(self, f, streamid, compression=False):
+ def send(self, f, streamid, compression=None):
body = StringIO()
self.send_body(body)
body = body.getvalue()
version = PROTOCOL_VERSION | HEADER_DIRECTION_FROM_CLIENT
- flags = 0 # no compression supported yet
+ flags = 0
+ if compression is not None and len(body) > 0:
+ body = compression(body)
+ flags |= 0x1
msglen = int32_pack(len(body))
header = '%c%c%c%c%s' % (version, flags, streamid, self.opcode,
msglen)
f.write(header)
@@ -102,7 +105,7 @@
return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs))
__repr__ = __str__
-def read_frame(f):
+def read_frame(f, decompressor=None):
header = f.read(8)
version, flags, stream, opcode = map(ord, header[:4])
body_len = int32_unpack(header[4:])
@@ -111,9 +114,14 @@
assert version & HEADER_DIRECTION_MASK == HEADER_DIRECTION_TO_CLIENT, \
"Unexpected request from server with opcode %04x, stream id %r" % (opcode, stream)
assert body_len >= 0, "Invalid CQL protocol body_len %r" % body_len
+ body = f.read(body_len)
+ if flags & 0x1:
+ if decompressor is None:
+ raise ProtocolException("No decompressor available for compressed frame!")
+ body = decompressor(body)
+ flags ^= 0x1
if flags:
warn("Unknown protocol flags set: %02x. May cause problems." %
flags)
- body = f.read(body_len)
msgclass = _message_types_by_opcode[opcode]
msg = msgclass.recv_body(StringIO(body))
msg.stream_id = stream
@@ -670,10 +678,10 @@
self.rowcount = len(self.result)
def get_compression(self):
- return None
+ return self._connection.compression
def set_compression(self, val):
- if val is not None:
+ if val != self.get_compression():
raise NotImplementedError("Setting per-cursor compression is
not "
"supported in NativeCursor.")
@@ -702,6 +710,20 @@
def close(self):
pass
+locally_supported_compressions = {}
+
+try:
+ import snappy
+except ImportError:
+ pass
+else:
+ # work around apparently buggy snappy decompress
+ def decompress(byts):
+ if byts == '\x00':
+ return ''
+ return snappy.decompress(byts)
+ locally_supported_compressions['snappy'] = (snappy.compress, decompress)
+
class NativeConnection(Connection):
cursorclass = NativeCursor
@@ -710,6 +732,7 @@
self.responses = {}
self.waiting = {}
self.conn_ready = False
+ self.compressor = self.decompressor = None
Connection.__init__(self, *args, **kwargs)
def establish_connection(self):
@@ -721,7 +744,7 @@
self.open_socket = True
supported = self.wait_for_request(OptionsMessage())
self.supported_cql_versions = supported.cqlversions
- self.supported_compressions = supported.options['COMPRESSION']
+ self.remote_supported_compressions = supported.options['COMPRESSION']
if self.cql_version:
if self.cql_version not in self.supported_cql_versions:
@@ -733,20 +756,30 @@
self.cql_version = self.supported_cql_versions[0]
opts = {}
+ compresstype = None
if self.compression:
- if self.compression not in self.supported_compressions:
- raise ProgrammingError("Compression type %r is not
supported by"
- " remote. Supported compression
types: %r"
- % (self.compression,
self.supported_compressions))
- # XXX: Remove this once some compressions are supported
- raise NotImplementedError("CQL driver does not yet support
compression")
- opts['COMPRESSION'] = self.compression
+ overlap = set(locally_supported_compressions) \
+ & set(self.remote_supported_compressions)
+ if len(overlap) == 0:
+ warn("No available compression types supported on both ends."
+ " locally supported: %r. remotely supported: %r"
+ % (locally_supported_compressions,
+ self.remote_supported_compressions))
+ else:
+ compresstype = iter(overlap).next() # choose any
+ opts['COMPRESSION'] = compresstype
+ compr, decompr = locally_supported_compressions[compresstype]
+ # set the decompressor here, but set the compressor only after
+ # a successful Ready message
+ self.decompressor = decompr
sm = StartupMessage(cqlversion=self.cql_version, options=opts)
startup_response = self.wait_for_request(sm)
while True:
if isinstance(startup_response, ReadyMessage):
self.conn_ready = True
+ if compresstype:
+ self.compressor = compr
break
if isinstance(startup_response, AuthenticateMessage):
self.authenticator = startup_response.authenticator
@@ -779,6 +812,11 @@
return self.wait_for_requests(msg)[0]
+ def send_msg(self, msg):
+ reqid = self.make_reqid()
+ msg.send(self.socketf, reqid, compression=self.compressor)
+ return reqid
+
def wait_for_requests(self, *msgs):
"""
Given any number of message objects, send them all to the server
@@ -789,9 +827,8 @@
reqids = []
for msg in msgs:
- reqid = self.make_reqid()
+ reqid = self.send_msg(msg)
reqids.append(reqid)
- msg.send(self.socketf, reqid)
resultdict = self.wait_for_results(*reqids)
return [resultdict[reqid] for reqid in reqids]
@@ -813,7 +850,7 @@
results[r] = result
waiting_for.remove(r)
while waiting_for:
- newmsg = read_frame(self.socketf)
+ newmsg = read_frame(self.socketf, decompressor=self.decompressor)
if newmsg.stream_id in waiting_for:
results[newmsg.stream_id] = newmsg
waiting_for.remove(newmsg.stream_id)
@@ -867,6 +904,5 @@
it may have to wait until something else waits on a result.
"""
- reqid = self.make_reqid()
- msg.send(self.socketf, reqid)
+ reqid = self.send_msg(msg)
self.callback_when(reqid, cb)
==============================================================================
Revision: 56d24cd277c2
Author: paul cannon <pa...@thepaul.org>
Date: Tue Sep 25 13:44:33 2012
Log: update tests; move thrift_client in test_cql
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=56d24cd277c2
Modified:
/cql/cqltypes.py
/cql/cursor.py
/cql/native.py
/test/test_connection.py
/test/test_cql.py
/test/test_prepared_queries.py
=======================================
--- /cql/cqltypes.py Wed Sep 12 13:32:53 2012
+++ /cql/cqltypes.py Tue Sep 25 13:44:33 2012
@@ -137,7 +137,7 @@
"""
- if isinstance(casstype, CassandraType):
+ if isinstance(casstype, (CassandraType, CassandraTypeType)):
return casstype
try:
return parse_casstype_args(casstype)
=======================================
--- /cql/cursor.py Tue Sep 11 17:31:33 2012
+++ /cql/cursor.py Tue Sep 25 13:44:33 2012
@@ -189,5 +189,5 @@
###
def __checksock(self):
- if self._connection is None:
+ if self._connection is None or not self._connection.open_socket:
raise cql.ProgrammingError("Cursor has been closed.")
=======================================
--- /cql/native.py Tue Sep 25 11:26:45 2012
+++ /cql/native.py Tue Sep 25 13:44:33 2012
@@ -388,7 +388,7 @@
return CqlResult(column_metadata=colspecs, rows=rows)
@classmethod
- def recv_results_prepared(self, f):
+ def recv_results_prepared(cls, f):
queryid = read_int(f)
colspecs = cls.recv_results_metadata(f)
return (queryid, colspecs)
@@ -625,7 +625,8 @@
return self._connection.wait_for_request(QueryMessage(query=query))
def get_response_prepared(self, prepared_query, params):
- em = ExecuteMessage(queryid=prepared_query.itemid,
queryparams=params)
+ qparams = [params[pname] for pname in prepared_query.paramnames]
+ em = ExecuteMessage(queryid=prepared_query.itemid,
queryparams=qparams)
return self._connection.wait_for_request(em)
def get_column_metadata(self, column_id):
=======================================
--- /test/test_connection.py Thu Sep 20 10:38:41 2012
+++ /test/test_connection.py Tue Sep 25 13:44:33 2012
@@ -29,34 +29,35 @@
randstring = test_cql.randstring
del test_cql
+@contextlib.contextmanager
+def with_keyspace(randstr, cursor, cqlver):
+ ksname = randstr + '_conntest_' +
cqlver.encode('ascii').replace('.', '_')
+ if cqlver.startswith('2.'):
+ cursor.execute("create keyspace '%s' with
strategy_class='SimpleStrategy'"
+ " and strategy_options:replication_factor=1;" %
ksname)
+ cursor.execute("use '%s'" % ksname)
+ yield ksname
+ cursor.execute("use system;")
+ cursor.execute("drop keyspace '%s'" % ksname)
+ elif cqlver == '3.0.0-beta1': # for cassandra 1.1
+ cursor.execute("create keyspace \"%s\" with
strategy_class='SimpleStrategy'"
+ " and strategy_options:replication_factor=1;" %
ksname)
+ cursor.execute('use "%s"' % ksname)
+ yield ksname
+ cursor.execute('use system;')
+ cursor.execute('drop keyspace "%s"' % ksname)
+ else:
+ cursor.execute("create keyspace \"%s\" with replication = "
+ "{'class': 'SimpleStrategy', 'replication_factor':
1};" % ksname)
+ cursor.execute('use "%s"' % ksname)
+ yield ksname
+ cursor.execute('use system;')
+ cursor.execute('drop keyspace "%s"' % ksname)
+
class TestConnection(unittest.TestCase):
def setUp(self):
self.randstr = randstring()
-
- @contextlib.contextmanager
- def with_keyspace(self, cursor, cqlver):
- ksname = self.randstr + '_conntest_' + cqlver.replace('.', '_')
- if cqlver.startswith('2.'):
- cursor.execute("create keyspace '%s' with
strategy_class='SimpleStrategy'"
- " and strategy_options:replication_factor=1;" %
ksname)
- cursor.execute("use '%s'" % ksname)
- yield ksname
- cursor.execute("use system;")
- cursor.execute("drop keyspace '%s'" % ksname)
- elif cqlver == '3.0.0-beta1': # for cassandra 1.1
- cursor.execute("create keyspace \"%s\" with
strategy_class='SimpleStrategy'"
- " and strategy_options:replication_factor=1;" %
ksname)
- cursor.execute('use "%s"' % ksname)
- yield ksname
- cursor.execute('use system;')
- cursor.execute('drop keyspace "%s"' % ksname)
- else:
- cursor.execute("create keyspace \"%s\" with replication = "
- "{'class': 'SimpleStrategy', 'replication_factor':
1};" %
ksname)
- cursor.execute('use "%s"' % ksname)
- yield ksname
- cursor.execute('use system;')
- cursor.execute('drop keyspace "%s"' % ksname)
+ self.with_keyspace = lambda curs, ver: with_keyspace(self.randstr,
curs, ver)
def test_connecting_with_cql_version(self):
conn = cql.connect(TEST_HOST, TEST_PORT, cql_version='2.0.0')
@@ -100,4 +101,4 @@
curs.execute('create table blah (a int primary key, b int);')
curs.execute('select * from blah;')
conn.close()
- self.assertRaises(TTransport.TTransportException,
curs.execute, 'select * from blah;')
+ self.assertRaises(cql.ProgrammingError, curs.execute, 'select *
from blah;')
=======================================
--- /test/test_cql.py Thu Sep 20 10:58:57 2012
+++ /test/test_cql.py Tue Sep 25 13:44:33 2012
@@ -52,7 +52,6 @@
client.transport = transport
client.transport.open()
return client
-thrift_client = get_thrift_client()
def uuid1bytes_to_millis(uuidbytes):
return (uuid.UUID(bytes=uuidbytes).get_time() / 10000) -
12219292800000L
@@ -164,6 +163,8 @@
keyspace = None
def setUp(self):
+ self.thrift_client = get_thrift_client()
+
# all tests in this module are against cql 2. change would be
welcomed.
dbconn = cql.connect(TEST_HOST, TEST_PORT, cql_version='2.0.0')
self.cursor = dbconn.cursor()
@@ -186,7 +187,7 @@
return ksname
def get_partitioner(self):
- return thrift_client.describe_partitioner()
+ return self.thrift_client.describe_partitioner()
def assertIsSubclass(self, class_a, class_b):
assert issubclass(class_a, class_b), '%r is not a subclass
of %r' % (class_a, class_b)
@@ -519,13 +520,13 @@
""", {'ks': ksname2})
# TODO: temporary (until this can be done with CQL).
- ksdef = thrift_client.describe_keyspace(ksname1)
+ ksdef = self.thrift_client.describe_keyspace(ksname1)
strategy_class
= "org.apache.cassandra.locator.NetworkTopologyStrategy"
self.assertEqual(ksdef.strategy_class, strategy_class)
self.assertEqual(ksdef.strategy_options['DC1'], "1")
- ksdef = thrift_client.describe_keyspace(ksname2)
+ ksdef = self.thrift_client.describe_keyspace(ksname2)
strategy_class
= "org.apache.cassandra.locator.NetworkTopologyStrategy"
self.assertEqual(ksdef.strategy_class, strategy_class)
@@ -542,14 +543,14 @@
""", {'ks': ksname})
# TODO: temporary (until this can be done with CQL).
- thrift_client.describe_keyspace(ksname)
+ self.thrift_client.describe_keyspace(ksname)
cursor.execute('DROP SCHEMA :ks;', {'ks': ksname})
# Technically this should throw a ttypes.NotFound(), but this is
# temporary and so not worth requiring it on PYTHONPATH.
self.assertRaises(Exception,
- thrift_client.describe_keyspace,
+ self.thrift_client.describe_keyspace,
ksname)
def test_create_column_family(self):
@@ -573,7 +574,7 @@
""")
# TODO: temporary (until this can be done with CQL).
- ksdef = thrift_client.describe_keyspace(ksname)
+ ksdef = self.thrift_client.describe_keyspace(ksname)
self.assertEqual(len(ksdef.cf_defs), 1)
cfam= ksdef.cf_defs[0]
self.assertEqual(len(cfam.column_metadata), 4)
@@ -597,7 +598,7 @@
# No column defs
cursor.execute("""CREATE COLUMNFAMILY NewCf3
(KEY varint PRIMARY KEY) WITH comparator =
bigint""")
- ksdef = thrift_client.describe_keyspace(ksname)
+ ksdef = self.thrift_client.describe_keyspace(ksname)
self.assertEqual(len(ksdef.cf_defs), 2)
cfam = [i for i in ksdef.cf_defs if i.name == "NewCf3"][0]
self.assertEqual(cfam.comparator_type, "org.apache.cassandra.db.marshal.LongType")
@@ -606,7 +607,7 @@
cursor.execute("""CREATE COLUMNFAMILY NewCf4
(KEY varint PRIMARY KEY, 'a' varint, 'b'
varint)
WITH comparator = text;""")
- ksdef = thrift_client.describe_keyspace(ksname)
+ ksdef = self.thrift_client.describe_keyspace(ksname)
self.assertEqual(len(ksdef.cf_defs), 3)
cfam = [i for i in ksdef.cf_defs if i.name == "NewCf4"][0]
self.assertEqual(len(cfam.column_metadata), 2)
@@ -626,12 +627,12 @@
cursor.execute('CREATE COLUMNFAMILY CF4Drop (KEY varint PRIMARY
KEY);')
# TODO: temporary (until this can be done with CQL).
- ksdef = thrift_client.describe_keyspace(ksname)
+ ksdef = self.thrift_client.describe_keyspace(ksname)
assert len(ksdef.cf_defs), "Column family not created!"
cursor.execute('DROP COLUMNFAMILY CF4Drop;')
- ksdef = thrift_client.describe_keyspace(ksname)
+ ksdef = self.thrift_client.describe_keyspace(ksname)
assert not len(ksdef.cf_defs), "Column family not deleted!"
def test_create_indexs(self):
@@ -643,7 +644,7 @@
cursor.execute("CREATE INDEX ON CreateIndex1 (stuff)")
# TODO: temporary (until this can be done with CQL).
- ksdef = thrift_client.describe_keyspace(self.keyspace)
+ ksdef = self.thrift_client.describe_keyspace(self.keyspace)
cfam = [i for i in ksdef.cf_defs if i.name == "CreateIndex1"][0]
items = [i for i in cfam.column_metadata if i.name == "items"][0]
stuff = [i for i in cfam.column_metadata if i.name == "stuff"][0]
@@ -667,7 +668,7 @@
cursor.execute("CREATE COLUMNFAMILY IndexedCF (KEY text PRIMARY
KEY, n text)")
cursor.execute("CREATE INDEX namedIndex ON IndexedCF (n)")
- ksdef = thrift_client.describe_keyspace(ksname)
+ ksdef = self.thrift_client.describe_keyspace(ksname)
columns = ksdef.cf_defs[0].column_metadata
self.assertEqual(columns[0].index_name, "namedIndex")
@@ -676,7 +677,7 @@
# testing "DROP INDEX <INDEX_NAME>"
cursor.execute("DROP INDEX namedIndex")
- ksdef = thrift_client.describe_keyspace(ksname)
+ ksdef = self.thrift_client.describe_keyspace(ksname)
columns = ksdef.cf_defs[0].column_metadata
self.assertEqual(columns[0].index_type, None)
@@ -1243,7 +1244,7 @@
""")
# TODO: temporary (until this can be done with CQL).
- ksdef = thrift_client.describe_keyspace(ksname)
+ ksdef = self.thrift_client.describe_keyspace(ksname)
self.assertEqual(len(ksdef.cf_defs), 1)
cfam = ksdef.cf_defs[0]
@@ -1252,7 +1253,7 @@
# testing "add a new column"
cursor.execute("ALTER COLUMNFAMILY NewCf1 ADD name varchar")
- ksdef = thrift_client.describe_keyspace(ksname)
+ ksdef = self.thrift_client.describe_keyspace(ksname)
self.assertEqual(len(ksdef.cf_defs), 1)
columns = ksdef.cf_defs[0].column_metadata
@@ -1263,7 +1264,7 @@
# testing "alter a column type"
cursor.execute("ALTER COLUMNFAMILY NewCf1 ALTER name TYPE ascii")
- ksdef = thrift_client.describe_keyspace(ksname)
+ ksdef = self.thrift_client.describe_keyspace(ksname)
self.assertEqual(len(ksdef.cf_defs), 1)
columns = ksdef.cf_defs[0].column_metadata
@@ -1279,7 +1280,7 @@
# testing 'drop an existing column'
cursor.execute("ALTER COLUMNFAMILY NewCf1 DROP name")
- ksdef = thrift_client.describe_keyspace(ksname)
+ ksdef = self.thrift_client.describe_keyspace(ksname)
self.assertEqual(len(ksdef.cf_defs), 1)
columns = ksdef.cf_defs[0].column_metadata
@@ -1396,7 +1397,7 @@
""")
# TODO: temporary (until this can be done with CQL).
- ksdef = thrift_client.describe_keyspace(ksname)
+ ksdef = self.thrift_client.describe_keyspace(ksname)
cfdef = ksdef.cf_defs[0]
self.assertEqual(len(ksdef.cf_defs), 1)
=======================================
--- /test/test_prepared_queries.py Tue Sep 11 16:26:59 2012
+++ /test/test_prepared_queries.py Tue Sep 25 13:44:33 2012
@@ -25,7 +25,7 @@
TEST_HOST = os.environ.get('CQL_TEST_HOST', 'localhost')
TEST_PORT = int(os.environ.get('CQL_TEST_PORT', 9170))
-TEST_CQL_VERSION = '3.0.0-beta1'
+TEST_CQL_VERSION = os.environ.get('CQL_TEST_VERSION', '3.0.0-beta1')
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@@ -40,7 +40,7 @@
def setUp(self):
try:
self.dbconn = cql.connect(TEST_HOST, TEST_PORT,
cql_version=TEST_CQL_VERSION)
- except cql.cursor.TApplicationException:
+ except cql.thrifteries.TApplicationException:
# set_cql_version (and thus, cql3) not supported; skip all of
these
self.cursor = None
return
==============================================================================
Revision: 553647f4b1b9
Author: paul cannon <pa...@thepaul.org>
Date: Tue Sep 25 13:46:01 2012
Log: add basic tests for native protocol
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=553647f4b1b9
Added:
/test/test_native_connection.py
=======================================
--- /dev/null
+++ /test/test_native_connection.py Tue Sep 25 13:46:01 2012
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# to configure behavior, define $CQL_TEST_HOST to the destination address
+# for native connections, and $CQL_TEST_NATIVE_PORT to the associated port.
+
+import os
+import unittest
+import contextlib
+from thrift.transport import TTransport
+import test_cql
+from test_prepared_queries import MIN_THRIFT_FOR_CQL_3_0_0_FINAL
+from test_connection import with_keyspace, TEST_HOST, randstring, cql
+
+TEST_NATIVE_PORT = int(os.environ.get('CQL_TEST_NATIVE_PORT', '8000'))
+
+class TestNativeConnection(unittest.TestCase):
+ def setUp(self):
+ self.randstr = randstring()
+ self.with_keyspace = lambda curs, ver: with_keyspace(self.randstr,
curs, ver)
+
+ def test_connecting_with_cql_version(self):
+ # 2.0.0 won't be supported by binary protocol
+ self.assertRaises(cql.ProgrammingError,
+ cql.connect, TEST_HOST, TEST_NATIVE_PORT,
+ native=True, cql_version='2.0.0')
+
+ def test_connecting_with_keyspace(self):
+ # this conn is just for creating the keyspace
+ conn = cql.connect(TEST_HOST, TEST_NATIVE_PORT, native=True)
+ curs = conn.cursor()
+ with self.with_keyspace(curs, conn.cql_version) as ksname:
+ curs.execute('create table blah1_%s (a int primary key, b
int);' % self.randstr)
+ conn2 = cql.connect(TEST_HOST, TEST_NATIVE_PORT,
keyspace=ksname,
+ native=True, cql_version=conn.cql_version)
+ curs2 = conn2.cursor()
+ curs2.execute('select * from blah1_%s;' % self.randstr)
+ conn2.close()
+
+ def test_execution_fails_after_close(self):
+ conn = cql.connect(TEST_HOST, TEST_NATIVE_PORT, native=True)
+ curs = conn.cursor()
+ with self.with_keyspace(curs, conn.cql_version) as ksname:
+ curs.execute('create table blah (a int primary key, b int);')
+ curs.execute('select * from blah;')
+ conn.close()
+ self.assertRaises(cql.ProgrammingError, curs.execute, 'select *
from blah;')
+
+ def try_basic_stuff(self, conn):
+ curs = conn.cursor()
+ with self.with_keyspace(curs, conn.cql_version) as ksname:
+ curs.execute('create table moo (a text primary key, b int, c
float);')
+ curs.execute("insert into moo (a, b, c) values (:d, :e, :f);",
+ {'d': 'hi', 'e': 1234, 'f': 1.234});
+ qprep = curs.prepare_query("select * from moo where a
= :fish;")
+ curs.execute_prepared(qprep, {'fish': 'hi'})
+ res = curs.fetchall()
+ self.assertEqual(len(res), 1)
+ self.assertEqual(res[0][0], 'hi')
+ self.assertEqual(res[0][1], 1234)
+ self.assertAlmostEqual(res[0][2], 1.234)
+
+ def test_connecting_without_compression(self):
+ conn = cql.connect(TEST_HOST, TEST_NATIVE_PORT, native=True,
compression=False)
+ self.assertEqual(conn.compressor, None)
+ self.try_basic_stuff(conn)
+
+ def test_connecting_with_compression(self):
+ try:
+ import snappy
+ except ImportError:
+ if hasattr(unittest, 'skipTest'):
+ unittest.skipTest('Snappy compression not available')
+ else:
+ return
+ conn = cql.connect(TEST_HOST, TEST_NATIVE_PORT, native=True,
compression=True)
+ self.assertEqual(conn.compressor, snappy.compress)
+ self.try_basic_stuff(conn)
==============================================================================
Revision: 96b064c3159b
Author: paul cannon <pa...@thepaul.org>
Date: Tue Sep 25 13:45:33 2012
Log: support for callbacks on native-proto events
i.e., STATUS_CHANGE, TOPOLOGY_CHANGE
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=96b064c3159b
Modified:
/cql/native.py
=======================================
--- /cql/native.py Tue Sep 25 13:44:33 2012
+++ /cql/native.py Tue Sep 25 13:45:33 2012
@@ -15,7 +15,8 @@
# limitations under the License.
import cql
-from cql.marshal import int32_pack, int32_unpack, uint16_pack,
uint16_unpack
+from cql.marshal import (int32_pack, int32_unpack, uint16_pack,
uint16_unpack,
+ int8_pack, int8_unpack)
from cql.cqltypes import lookup_cqltype
from cql.connection import Connection
from cql.cursor import Cursor, _VOID_DESCRIPTION, _COUNT_DESCRIPTION
@@ -32,8 +33,6 @@
PROTOCOL_VERSION = 0x01
PROTOCOL_VERSION_MASK = 0x7f
-# XXX: should these be called request/response instead? unclear which one will
-# apply if/when the server initiates streams in the other direction.
HEADER_DIRECTION_FROM_CLIENT = 0x00
HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80
@@ -95,7 +94,8 @@
body = compression(body)
flags |= 0x1
msglen = int32_pack(len(body))
- header = '%c%c%c%c%s' % (version, flags, streamid, self.opcode,
msglen)
+ header = ''.join(map(int8_pack, (version, flags, streamid,
self.opcode))) \
+ + msglen
f.write(header)
if len(body) > 0:
f.write(body)
@@ -107,7 +107,7 @@
def read_frame(f, decompressor=None):
header = f.read(8)
- version, flags, stream, opcode = map(ord, header[:4])
+ version, flags, stream, opcode = map(int8_unpack, header[:4])
body_len = int32_unpack(header[4:])
assert version & PROTOCOL_VERSION_MASK == PROTOCOL_VERSION, \
"Unsupported CQL protocol version %d" % version
@@ -496,10 +496,10 @@
def read_byte(f):
- return ord(f.read(1))
+ return int8_unpack(f.read(1))
def write_byte(f, b):
- f.write(chr(b))
+ f.write(int8_pack(b))
def read_int(f):
return int32_unpack(f.read(4))
@@ -734,6 +734,7 @@
self.waiting = {}
self.conn_ready = False
self.compressor = self.decompressor = None
+ self.event_watchers = {}
Connection.__init__(self, *args, **kwargs)
def establish_connection(self):
@@ -838,6 +839,10 @@
Given any number of stream-ids, wait until responses have arrived
for
each one, and return a dictionary mapping the stream-ids to the
appropriate results.
+
+ For internal use, None may be passed in place of a reqid, which
will
+ be considered satisfied when a message of any kind is received
(and, if
+ appropriate, handled).
"""
waiting_for = set(reqids)
@@ -857,6 +862,9 @@
waiting_for.remove(newmsg.stream_id)
else:
self.handle_incoming(newmsg)
+ if None in waiting_for:
+ results[None] = newmsg
+ waiting_for.remove(None)
return results
def wait_for_result(self, reqid):
@@ -907,3 +915,79 @@
reqid = self.send_msg(msg)
self.callback_when(reqid, cb)
+
+ def handle_pushed(self, msg):
+ """
+ Process an incoming message originated by the server.
+ """
+ watchers = self.event_watchers.get(msg.eventtype, ())
+ for cb in watchers:
+ cb(msg.eventargs)
+
+ def register_watcher(self, eventtype, cb):
+ """
+ Request that any events of the given type be passed to the given
+ callback when they arrive. Note that the callback may not be called
+ immediately upon the arrival of the event packet; it may have to wait
+ until something else waits on a result, or until wait_for_even() is
+ called.
+
+ If the event type has not been registered for already, this may
+ block while a new REGISTER message is sent to the server.
+
+ The available event types are in the cql.native.known_event_types
+ list.
+
+ When an event arrives, a dictionary will be passed to the callback
+ with the info about the event. Some example result dictionaries:
+
+ (For STATUS_CHANGE events:)
+
+ {'changetype': u'DOWN', 'address': ('12.114.19.76', 8000)}
+
+ (For TOPOLOGY_CHANGE events:)
+
+ {'changetype': u'NEW_NODE', 'address': ('19.10.122.13', 8000)}
+ """
+
+ if isinstance(eventtype, str):
+ eventtype = eventtype.decode('utf8')
+ try:
+ watchers = self.event_watchers[eventtype]
+ except KeyError:
+ ans = self.wait_for_request(RegisterMessage(eventlist=(eventtype,)))
+ if isinstance(ans, ErrorMessage):
+ raise cql.ProgrammingError("Server did not accept
registration"
+ " for %s events: %s"
+ % (eventtype, ans.summarymsg()))
+ watchers = self.event_watchers.setdefault(eventtype, [])
+ watchers.append(cb)
+
+ def unregister_watcher(self, eventtype, cb):
+ """
+ Given an eventtype and a callback previously registered with
+ register_watcher(), remove that callback from the list of watchers
for
+ the given event type.
+ """
+
+ if isinstance(eventtype, str):
+ eventtype = eventtype.decode('utf8')
+ self.event_watchers[eventtype].remove(cb)
+
+ def wait_for_event(self):
+ """
+ Wait for any sort of event to arrive, and handle it via the
+ registered callbacks. It is recommended that some event watchers
+ be registered before calling this; otherwise, no events will be
+ sent by the server.
+ """
+ eventsseen = []
+ def i_saw_an_event(ev):
+ eventsseen.append(ev)
+ wlists = self.event_watchers.values()
+ for wlist in wlists:
+ wlist.append(i_saw_an_event)
+ while not eventsseen:
+ self.wait_for_result(None)
+ for wlist in wlists:
+ wlist.remove(i_saw_an_event)