4 new revisions:

Revision: d860b4b2250d
Author:   paul cannon <[email protected]>
Date:     Tue Sep 25 11:26:45 2012
Log:      support snappy compression, if installed
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=d860b4b2250d

Revision: 56d24cd277c2
Author:   paul cannon <[email protected]>
Date:     Tue Sep 25 13:44:33 2012
Log:      update tests; move thrift_client in test_cql
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=56d24cd277c2

Revision: 553647f4b1b9
Author:   paul cannon <[email protected]>
Date:     Tue Sep 25 13:46:01 2012
Log:      add basic tests for native protocol
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=553647f4b1b9

Revision: 96b064c3159b
Author:   paul cannon <[email protected]>
Date:     Tue Sep 25 13:45:33 2012
Log:      support for callbacks on native-proto events...
http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=96b064c3159b

==============================================================================
Revision: d860b4b2250d
Author:   paul cannon <[email protected]>
Date:     Tue Sep 25 11:26:45 2012
Log:      support snappy compression, if installed

http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=d860b4b2250d

Modified:
 /cql/connection.py
 /cql/native.py

=======================================
--- /cql/connection.py  Mon Sep 17 04:34:57 2012
+++ /cql/connection.py  Tue Sep 25 11:26:45 2012
@@ -29,8 +29,12 @@
         * user .........: username used in authentication (optional).
         * password .....: password used in authentication (optional).
         * cql_version...: CQL version to use (optional).
-        * compression...: the sort of compression to use by default;
-        *                 overrideable per Cursor object. (optional).
+ * compression...: whether to use compression. For Thrift connections,
+        *                 this can be None or the name of some supported
+        *                 compression type (like "GZIP"). For native
+        *                 connections, this is treated as a boolean, and if
+        *                 true, the connection will try to find a type of
+        *                 compression supported by both sides.
         """
         self.host = host
         self.port = port
@@ -85,12 +89,37 @@
         return curs

 # TODO: Pull connections out of a pool instead.
-def connect(host, port=9160, keyspace=None, user=None, password=None,
-            cql_version=None, native=False):
+def connect(host, port=None, keyspace=None, user=None, password=None,
+            cql_version=None, native=False, compression=None):
+    """
+    Create a connection to a Cassandra node.
+
+    @param host Hostname of Cassandra node.
+    @param port Port number to connect to (default 9160 for thrift, 8000
+                for native)
+    @param keyspace If set, authenticate to this keyspace on connection.
+    @param user If set, use this username in authentication.
+    @param password If set, use this password in authentication.
+    @param cql_version If set, try to use the given CQL version. If unset,
+                uses the default for the connection.
+    @param compression Whether to use compression. For Thrift connections,
+                this can be None or the name of some supported compression
+                type (like "GZIP"). For native connections, this is treated
+                as a boolean, and if true, the connection will try to find
+                a type of compression supported by both sides.
+
+    @returns a Connection instance of the appropriate subclass.
+    """
+
     if native:
         from native import NativeConnection
         connclass = NativeConnection
+        if port is None:
+            port = 8000
     else:
         from thrifteries import ThriftConnection
         connclass = ThriftConnection
-    return connclass(host, port, keyspace, user, password, cql_version)
+        if port is None:
+            port = 9160
+    return connclass(host, port, keyspace, user, password,
+                     cql_version=cql_version, compression=compression)
=======================================
--- /cql/native.py      Mon Sep 17 04:34:57 2012
+++ /cql/native.py      Tue Sep 25 11:26:45 2012
@@ -85,12 +85,15 @@
                                  % (self.__class__.__name__, pname))
             setattr(self, pname, pval)

-    def send(self, f, streamid, compression=False):
+    def send(self, f, streamid, compression=None):
         body = StringIO()
         self.send_body(body)
         body = body.getvalue()
         version = PROTOCOL_VERSION | HEADER_DIRECTION_FROM_CLIENT
-        flags = 0 # no compression supported yet
+        flags = 0
+        if compression is not None and len(body) > 0:
+            body = compression(body)
+            flags |= 0x1
         msglen = int32_pack(len(body))
header = '%c%c%c%c%s' % (version, flags, streamid, self.opcode, msglen)
         f.write(header)
@@ -102,7 +105,7 @@
         return '<%s(%s)>' % (self.__class__.__name__, ', '.join(paramstrs))
     __repr__ = __str__

-def read_frame(f):
+def read_frame(f, decompressor=None):
     header = f.read(8)
     version, flags, stream, opcode = map(ord, header[:4])
     body_len = int32_unpack(header[4:])
@@ -111,9 +114,14 @@
     assert version & HEADER_DIRECTION_MASK == HEADER_DIRECTION_TO_CLIENT, \
"Unexpected request from server with opcode %04x, stream id %r" % (opcode, stream)
     assert body_len >= 0, "Invalid CQL protocol body_len %r" % body_len
+    body = f.read(body_len)
+    if flags & 0x1:
+        if decompressor is None:
+ raise ProtocolException("No decompressor available for compressed frame!")
+        body = decompressor(body)
+        flags ^= 0x1
     if flags:
warn("Unknown protocol flags set: %02x. May cause problems." % flags)
-    body = f.read(body_len)
     msgclass = _message_types_by_opcode[opcode]
     msg = msgclass.recv_body(StringIO(body))
     msg.stream_id = stream
@@ -670,10 +678,10 @@
         self.rowcount = len(self.result)

     def get_compression(self):
-        return None
+        return self._connection.compression

     def set_compression(self, val):
-        if val is not None:
+        if val != self.get_compression():
raise NotImplementedError("Setting per-cursor compression is not "
                                       "supported in NativeCursor.")

@@ -702,6 +710,20 @@
     def close(self):
         pass

+locally_supported_compressions = {}
+
+try:
+    import snappy
+except ImportError:
+    pass
+else:
+    # work around apparently buggy snappy decompress
+    def decompress(byts):
+        if byts == '\x00':
+            return ''
+        return snappy.decompress(byts)
+ locally_supported_compressions['snappy'] = (snappy.compress, decompress)
+
 class NativeConnection(Connection):
     cursorclass = NativeCursor

@@ -710,6 +732,7 @@
         self.responses = {}
         self.waiting = {}
         self.conn_ready = False
+        self.compressor = self.decompressor = None
         Connection.__init__(self, *args, **kwargs)

     def establish_connection(self):
@@ -721,7 +744,7 @@
         self.open_socket = True
         supported = self.wait_for_request(OptionsMessage())
         self.supported_cql_versions = supported.cqlversions
-        self.supported_compressions = supported.options['COMPRESSION']
+ self.remote_supported_compressions = supported.options['COMPRESSION']

         if self.cql_version:
             if self.cql_version not in self.supported_cql_versions:
@@ -733,20 +756,30 @@
             self.cql_version = self.supported_cql_versions[0]

         opts = {}
+        compresstype = None
         if self.compression:
-            if self.compression not in self.supported_compressions:
-                raise ProgrammingError("Compression type %r is not supported by"
-                                       " remote. Supported compression types: %r"
-                                       % (self.compression, self.supported_compressions))
-            # XXX: Remove this once some compressions are supported
- raise NotImplementedError("CQL driver does not yet support compression")
-            opts['COMPRESSION'] = self.compression
+            overlap = set(locally_supported_compressions) \
+                    & set(self.remote_supported_compressions)
+            if len(overlap) == 0:
+ warn("No available compression types supported on both ends."
+                     " locally supported: %r. remotely supported: %r"
+                     % (locally_supported_compressions,
+                        self.remote_supported_compressions))
+            else:
+                compresstype = iter(overlap).next() # choose any
+                opts['COMPRESSION'] = compresstype
+                compr, decompr = locally_supported_compressions[compresstype]
+                # set the decompressor here, but set the compressor only after
+                # a successful Ready message
+                self.decompressor = decompr

         sm = StartupMessage(cqlversion=self.cql_version, options=opts)
         startup_response = self.wait_for_request(sm)
         while True:
             if isinstance(startup_response, ReadyMessage):
                 self.conn_ready = True
+                if compresstype:
+                    self.compressor = compr
                 break
             if isinstance(startup_response, AuthenticateMessage):
                 self.authenticator = startup_response.authenticator
@@ -779,6 +812,11 @@

         return self.wait_for_requests(msg)[0]

+    def send_msg(self, msg):
+        reqid = self.make_reqid()
+        msg.send(self.socketf, reqid, compression=self.compressor)
+        return reqid
+
     def wait_for_requests(self, *msgs):
         """
         Given any number of message objects, send them all to the server
@@ -789,9 +827,8 @@

         reqids = []
         for msg in msgs:
-            reqid = self.make_reqid()
+            reqid = self.send_msg(msg)
             reqids.append(reqid)
-            msg.send(self.socketf, reqid)
         resultdict = self.wait_for_results(*reqids)
         return [resultdict[reqid] for reqid in reqids]

@@ -813,7 +850,7 @@
                 results[r] = result
                 waiting_for.remove(r)
         while waiting_for:
-            newmsg = read_frame(self.socketf)
+ newmsg = read_frame(self.socketf, decompressor=self.decompressor)
             if newmsg.stream_id in waiting_for:
                 results[newmsg.stream_id] = newmsg
                 waiting_for.remove(newmsg.stream_id)
@@ -867,6 +904,5 @@
         it may have to wait until something else waits on a result.
         """

-        reqid = self.make_reqid()
-        msg.send(self.socketf, reqid)
+        reqid = self.send_msg(msg)
         self.callback_when(reqid, cb)

==============================================================================
Revision: 56d24cd277c2
Author:   paul cannon <[email protected]>
Date:     Tue Sep 25 13:44:33 2012
Log:      update tests; move thrift_client in test_cql

http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=56d24cd277c2

Modified:
 /cql/cqltypes.py
 /cql/cursor.py
 /cql/native.py
 /test/test_connection.py
 /test/test_cql.py
 /test/test_prepared_queries.py

=======================================
--- /cql/cqltypes.py    Wed Sep 12 13:32:53 2012
+++ /cql/cqltypes.py    Tue Sep 25 13:44:33 2012
@@ -137,7 +137,7 @@

     """

-    if isinstance(casstype, CassandraType):
+    if isinstance(casstype, (CassandraType, CassandraTypeType)):
         return casstype
     try:
         return parse_casstype_args(casstype)
=======================================
--- /cql/cursor.py      Tue Sep 11 17:31:33 2012
+++ /cql/cursor.py      Tue Sep 25 13:44:33 2012
@@ -189,5 +189,5 @@
     ###

     def __checksock(self):
-        if self._connection is None:
+        if self._connection is None or not self._connection.open_socket:
             raise cql.ProgrammingError("Cursor has been closed.")
=======================================
--- /cql/native.py      Tue Sep 25 11:26:45 2012
+++ /cql/native.py      Tue Sep 25 13:44:33 2012
@@ -388,7 +388,7 @@
         return CqlResult(column_metadata=colspecs, rows=rows)

     @classmethod
-    def recv_results_prepared(self, f):
+    def recv_results_prepared(cls, f):
         queryid = read_int(f)
         colspecs = cls.recv_results_metadata(f)
         return (queryid, colspecs)
@@ -625,7 +625,8 @@
         return self._connection.wait_for_request(QueryMessage(query=query))

     def get_response_prepared(self, prepared_query, params):
- em = ExecuteMessage(queryid=prepared_query.itemid, queryparams=params)
+        qparams = [params[pname] for pname in prepared_query.paramnames]
+ em = ExecuteMessage(queryid=prepared_query.itemid, queryparams=qparams)
         return self._connection.wait_for_request(em)

     def get_column_metadata(self, column_id):
=======================================
--- /test/test_connection.py    Thu Sep 20 10:38:41 2012
+++ /test/test_connection.py    Tue Sep 25 13:44:33 2012
@@ -29,34 +29,35 @@
 randstring = test_cql.randstring
 del test_cql

+@contextlib.contextmanager
+def with_keyspace(randstr, cursor, cqlver):
+ ksname = randstr + '_conntest_' + cqlver.encode('ascii').replace('.', '_')
+    if cqlver.startswith('2.'):
+        cursor.execute("create keyspace '%s' with strategy_class='SimpleStrategy'"
+                       " and strategy_options:replication_factor=1;" % ksname)
+        cursor.execute("use '%s'" % ksname)
+        yield ksname
+        cursor.execute("use system;")
+        cursor.execute("drop keyspace '%s'" % ksname)
+    elif cqlver == '3.0.0-beta1': # for cassandra 1.1
+        cursor.execute("create keyspace \"%s\" with strategy_class='SimpleStrategy'"
+                       " and strategy_options:replication_factor=1;" % ksname)
+        cursor.execute('use "%s"' % ksname)
+        yield ksname
+        cursor.execute('use system;')
+        cursor.execute('drop keyspace "%s"' % ksname)
+    else:
+        cursor.execute("create keyspace \"%s\" with replication = "
+ "{'class': 'SimpleStrategy', 'replication_factor': 1};" % ksname)
+        cursor.execute('use "%s"' % ksname)
+        yield ksname
+        cursor.execute('use system;')
+        cursor.execute('drop keyspace "%s"' % ksname)
+
 class TestConnection(unittest.TestCase):
     def setUp(self):
         self.randstr = randstring()
-
-    @contextlib.contextmanager
-    def with_keyspace(self, cursor, cqlver):
-        ksname = self.randstr + '_conntest_' + cqlver.replace('.', '_')
-        if cqlver.startswith('2.'):
-            cursor.execute("create keyspace '%s' with strategy_class='SimpleStrategy'"
-                           " and strategy_options:replication_factor=1;" % ksname)
-            cursor.execute("use '%s'" % ksname)
-            yield ksname
-            cursor.execute("use system;")
-            cursor.execute("drop keyspace '%s'" % ksname)
-        elif cqlver == '3.0.0-beta1': # for cassandra 1.1
-            cursor.execute("create keyspace \"%s\" with strategy_class='SimpleStrategy'"
-                           " and strategy_options:replication_factor=1;" % ksname)
-            cursor.execute('use "%s"' % ksname)
-            yield ksname
-            cursor.execute('use system;')
-            cursor.execute('drop keyspace "%s"' % ksname)
-        else:
-            cursor.execute("create keyspace \"%s\" with replication = "
- "{'class': 'SimpleStrategy', 'replication_factor': 1};" % ksname)
-            cursor.execute('use "%s"' % ksname)
-            yield ksname
-            cursor.execute('use system;')
-            cursor.execute('drop keyspace "%s"' % ksname)
+ self.with_keyspace = lambda curs, ver: with_keyspace(self.randstr, curs, ver)

     def test_connecting_with_cql_version(self):
         conn = cql.connect(TEST_HOST, TEST_PORT, cql_version='2.0.0')
@@ -100,4 +101,4 @@
             curs.execute('create table blah (a int primary key, b int);')
             curs.execute('select * from blah;')
         conn.close()
-        self.assertRaises(TTransport.TTransportException, curs.execute, 'select * from blah;')
+        self.assertRaises(cql.ProgrammingError, curs.execute, 'select * from blah;')
=======================================
--- /test/test_cql.py   Thu Sep 20 10:58:57 2012
+++ /test/test_cql.py   Tue Sep 25 13:44:33 2012
@@ -52,7 +52,6 @@
     client.transport = transport
     client.transport.open()
     return client
-thrift_client = get_thrift_client()

 def uuid1bytes_to_millis(uuidbytes):
return (uuid.UUID(bytes=uuidbytes).get_time() / 10000) - 12219292800000L
@@ -164,6 +163,8 @@
     keyspace = None

     def setUp(self):
+        self.thrift_client = get_thrift_client()
+
# all tests in this module are against cql 2. change would be welcomed.
         dbconn = cql.connect(TEST_HOST, TEST_PORT, cql_version='2.0.0')
         self.cursor = dbconn.cursor()
@@ -186,7 +187,7 @@
         return ksname

     def get_partitioner(self):
-        return thrift_client.describe_partitioner()
+        return self.thrift_client.describe_partitioner()

     def assertIsSubclass(self, class_a, class_b):
assert issubclass(class_a, class_b), '%r is not a subclass of %r' % (class_a, class_b)
@@ -519,13 +520,13 @@
         """, {'ks': ksname2})

         # TODO: temporary (until this can be done with CQL).
-        ksdef = thrift_client.describe_keyspace(ksname1)
+        ksdef = self.thrift_client.describe_keyspace(ksname1)

strategy_class = "org.apache.cassandra.locator.NetworkTopologyStrategy"
         self.assertEqual(ksdef.strategy_class, strategy_class)
         self.assertEqual(ksdef.strategy_options['DC1'], "1")

-        ksdef = thrift_client.describe_keyspace(ksname2)
+        ksdef = self.thrift_client.describe_keyspace(ksname2)

strategy_class = "org.apache.cassandra.locator.NetworkTopologyStrategy"
         self.assertEqual(ksdef.strategy_class, strategy_class)
@@ -542,14 +543,14 @@
         """, {'ks': ksname})

         # TODO: temporary (until this can be done with CQL).
-        thrift_client.describe_keyspace(ksname)
+        self.thrift_client.describe_keyspace(ksname)

         cursor.execute('DROP SCHEMA :ks;', {'ks': ksname})

         # Technically this should throw a ttypes.NotFound(), but this is
         # temporary and so not worth requiring it on PYTHONPATH.
         self.assertRaises(Exception,
-                          thrift_client.describe_keyspace,
+                          self.thrift_client.describe_keyspace,
                           ksname)

     def test_create_column_family(self):
@@ -573,7 +574,7 @@
         """)

         # TODO: temporary (until this can be done with CQL).
-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
         self.assertEqual(len(ksdef.cf_defs), 1)
         cfam= ksdef.cf_defs[0]
         self.assertEqual(len(cfam.column_metadata), 4)
@@ -597,7 +598,7 @@
         # No column defs
         cursor.execute("""CREATE COLUMNFAMILY NewCf3
(KEY varint PRIMARY KEY) WITH comparator = bigint""")
-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
         self.assertEqual(len(ksdef.cf_defs), 2)
         cfam = [i for i in ksdef.cf_defs if i.name == "NewCf3"][0]
self.assertEqual(cfam.comparator_type, "org.apache.cassandra.db.marshal.LongType")
@@ -606,7 +607,7 @@
         cursor.execute("""CREATE COLUMNFAMILY NewCf4
(KEY varint PRIMARY KEY, 'a' varint, 'b' varint)
                             WITH comparator = text;""")
-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
         self.assertEqual(len(ksdef.cf_defs), 3)
         cfam = [i for i in ksdef.cf_defs if i.name == "NewCf4"][0]
         self.assertEqual(len(cfam.column_metadata), 2)
@@ -626,12 +627,12 @@
cursor.execute('CREATE COLUMNFAMILY CF4Drop (KEY varint PRIMARY KEY);')

         # TODO: temporary (until this can be done with CQL).
-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
         assert len(ksdef.cf_defs), "Column family not created!"

         cursor.execute('DROP COLUMNFAMILY CF4Drop;')

-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
         assert not len(ksdef.cf_defs), "Column family not deleted!"

     def test_create_indexs(self):
@@ -643,7 +644,7 @@
         cursor.execute("CREATE INDEX ON CreateIndex1 (stuff)")

         # TODO: temporary (until this can be done with CQL).
-        ksdef = thrift_client.describe_keyspace(self.keyspace)
+        ksdef = self.thrift_client.describe_keyspace(self.keyspace)
         cfam = [i for i in ksdef.cf_defs if i.name == "CreateIndex1"][0]
         items = [i for i in cfam.column_metadata if i.name == "items"][0]
         stuff = [i for i in cfam.column_metadata if i.name == "stuff"][0]
@@ -667,7 +668,7 @@
cursor.execute("CREATE COLUMNFAMILY IndexedCF (KEY text PRIMARY KEY, n text)")
         cursor.execute("CREATE INDEX namedIndex ON IndexedCF (n)")

-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
         columns = ksdef.cf_defs[0].column_metadata

         self.assertEqual(columns[0].index_name, "namedIndex")
@@ -676,7 +677,7 @@
         # testing "DROP INDEX <INDEX_NAME>"
         cursor.execute("DROP INDEX namedIndex")

-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
         columns = ksdef.cf_defs[0].column_metadata

         self.assertEqual(columns[0].index_type, None)
@@ -1243,7 +1244,7 @@
         """)

         # TODO: temporary (until this can be done with CQL).
-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
         self.assertEqual(len(ksdef.cf_defs), 1)
         cfam = ksdef.cf_defs[0]

@@ -1252,7 +1253,7 @@
         # testing "add a new column"
         cursor.execute("ALTER COLUMNFAMILY NewCf1 ADD name varchar")

-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
         self.assertEqual(len(ksdef.cf_defs), 1)
         columns = ksdef.cf_defs[0].column_metadata

@@ -1263,7 +1264,7 @@
         # testing "alter a column type"
         cursor.execute("ALTER COLUMNFAMILY NewCf1 ALTER name TYPE ascii")

-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
         self.assertEqual(len(ksdef.cf_defs), 1)
         columns = ksdef.cf_defs[0].column_metadata

@@ -1279,7 +1280,7 @@
         # testing 'drop an existing column'
         cursor.execute("ALTER COLUMNFAMILY NewCf1 DROP name")

-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
         self.assertEqual(len(ksdef.cf_defs), 1)
         columns = ksdef.cf_defs[0].column_metadata

@@ -1396,7 +1397,7 @@
         """)

         # TODO: temporary (until this can be done with CQL).
-        ksdef = thrift_client.describe_keyspace(ksname)
+        ksdef = self.thrift_client.describe_keyspace(ksname)
         cfdef = ksdef.cf_defs[0]

         self.assertEqual(len(ksdef.cf_defs), 1)
=======================================
--- /test/test_prepared_queries.py      Tue Sep 11 16:26:59 2012
+++ /test/test_prepared_queries.py      Tue Sep 25 13:44:33 2012
@@ -25,7 +25,7 @@

 TEST_HOST = os.environ.get('CQL_TEST_HOST', 'localhost')
 TEST_PORT = int(os.environ.get('CQL_TEST_PORT', 9170))
-TEST_CQL_VERSION = '3.0.0-beta1'
+TEST_CQL_VERSION = os.environ.get('CQL_TEST_VERSION', '3.0.0-beta1')

 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

@@ -40,7 +40,7 @@
     def setUp(self):
         try:
self.dbconn = cql.connect(TEST_HOST, TEST_PORT, cql_version=TEST_CQL_VERSION)
-        except cql.cursor.TApplicationException:
+        except cql.thrifteries.TApplicationException:
# set_cql_version (and thus, cql3) not supported; skip all of these
             self.cursor = None
             return

==============================================================================
Revision: 553647f4b1b9
Author:   paul cannon <[email protected]>
Date:     Tue Sep 25 13:46:01 2012
Log:      add basic tests for native protocol

http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=553647f4b1b9

Added:
 /test/test_native_connection.py

=======================================
--- /dev/null
+++ /test/test_native_connection.py     Tue Sep 25 13:46:01 2012
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# to configure behavior, define $CQL_TEST_HOST to the destination address
+# for native connections, and $CQL_TEST_NATIVE_PORT to the associated port.
+
+import os
+import unittest
+import contextlib
+from thrift.transport import TTransport
+import test_cql
+from test_prepared_queries import MIN_THRIFT_FOR_CQL_3_0_0_FINAL
+from test_connection import with_keyspace, TEST_HOST, randstring, cql
+
+TEST_NATIVE_PORT = int(os.environ.get('CQL_TEST_NATIVE_PORT', '8000'))
+
+class TestNativeConnection(unittest.TestCase):
+    def setUp(self):
+        self.randstr = randstring()
+ self.with_keyspace = lambda curs, ver: with_keyspace(self.randstr, curs, ver)
+
+    def test_connecting_with_cql_version(self):
+        # 2.0.0 won't be supported by binary protocol
+        self.assertRaises(cql.ProgrammingError,
+                          cql.connect, TEST_HOST, TEST_NATIVE_PORT,
+                          native=True, cql_version='2.0.0')
+
+    def test_connecting_with_keyspace(self):
+        # this conn is just for creating the keyspace
+        conn = cql.connect(TEST_HOST, TEST_NATIVE_PORT, native=True)
+        curs = conn.cursor()
+        with self.with_keyspace(curs, conn.cql_version) as ksname:
+            curs.execute('create table blah1_%s (a int primary key, b int);' % self.randstr)
+            conn2 = cql.connect(TEST_HOST, TEST_NATIVE_PORT, keyspace=ksname,
+                                native=True, cql_version=conn.cql_version)
+            curs2 = conn2.cursor()
+            curs2.execute('select * from blah1_%s;' % self.randstr)
+            conn2.close()
+
+    def test_execution_fails_after_close(self):
+        conn = cql.connect(TEST_HOST, TEST_NATIVE_PORT, native=True)
+        curs = conn.cursor()
+        with self.with_keyspace(curs, conn.cql_version) as ksname:
+            curs.execute('create table blah (a int primary key, b int);')
+            curs.execute('select * from blah;')
+        conn.close()
+ self.assertRaises(cql.ProgrammingError, curs.execute, 'select * from blah;')
+
+    def try_basic_stuff(self, conn):
+        curs = conn.cursor()
+        with self.with_keyspace(curs, conn.cql_version) as ksname:
+ curs.execute('create table moo (a text primary key, b int, c float);')
+            curs.execute("insert into moo (a, b, c) values (:d, :e, :f);",
+                         {'d': 'hi', 'e': 1234, 'f': 1.234});
+ qprep = curs.prepare_query("select * from moo where a = :fish;")
+            curs.execute_prepared(qprep, {'fish': 'hi'})
+            res = curs.fetchall()
+            self.assertEqual(len(res), 1)
+            self.assertEqual(res[0][0], 'hi')
+            self.assertEqual(res[0][1], 1234)
+            self.assertAlmostEqual(res[0][2], 1.234)
+
+    def test_connecting_without_compression(self):
+ conn = cql.connect(TEST_HOST, TEST_NATIVE_PORT, native=True, compression=False)
+        self.assertEqual(conn.compressor, None)
+        self.try_basic_stuff(conn)
+
+    def test_connecting_with_compression(self):
+        try:
+            import snappy
+        except ImportError:
+            if hasattr(unittest, 'skipTest'):
+                unittest.skipTest('Snappy compression not available')
+            else:
+                return
+ conn = cql.connect(TEST_HOST, TEST_NATIVE_PORT, native=True, compression=True)
+        self.assertEqual(conn.compressor, snappy.compress)
+        self.try_basic_stuff(conn)

==============================================================================
Revision: 96b064c3159b
Author:   paul cannon <[email protected]>
Date:     Tue Sep 25 13:45:33 2012
Log:      support for callbacks on native-proto events

i.e., STATUS_CHANGE, TOPOLOGY_CHANGE

http://code.google.com/a/apache-extras.org/p/cassandra-dbapi2/source/detail?r=96b064c3159b

Modified:
 /cql/native.py

=======================================
--- /cql/native.py      Tue Sep 25 13:44:33 2012
+++ /cql/native.py      Tue Sep 25 13:45:33 2012
@@ -15,7 +15,8 @@
 # limitations under the License.

 import cql
-from cql.marshal import int32_pack, int32_unpack, uint16_pack, uint16_unpack
+from cql.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
+                         int8_pack, int8_unpack)
 from cql.cqltypes import lookup_cqltype
 from cql.connection import Connection
 from cql.cursor import Cursor, _VOID_DESCRIPTION, _COUNT_DESCRIPTION
@@ -32,8 +33,6 @@
 PROTOCOL_VERSION             = 0x01
 PROTOCOL_VERSION_MASK        = 0x7f

-# XXX: should these be called request/response instead? unclear which one will
-# apply if/when the server initiates streams in the other direction.
 HEADER_DIRECTION_FROM_CLIENT = 0x00
 HEADER_DIRECTION_TO_CLIENT   = 0x80
 HEADER_DIRECTION_MASK        = 0x80
@@ -95,7 +94,8 @@
             body = compression(body)
             flags |= 0x1
         msglen = int32_pack(len(body))
-        header = '%c%c%c%c%s' % (version, flags, streamid, self.opcode, msglen)
+        header = ''.join(map(int8_pack, (version, flags, streamid, self.opcode))) \
+                 + msglen
         f.write(header)
         if len(body) > 0:
             f.write(body)
@@ -107,7 +107,7 @@

 def read_frame(f, decompressor=None):
     header = f.read(8)
-    version, flags, stream, opcode = map(ord, header[:4])
+    version, flags, stream, opcode = map(int8_unpack, header[:4])
     body_len = int32_unpack(header[4:])
     assert version & PROTOCOL_VERSION_MASK == PROTOCOL_VERSION, \
             "Unsupported CQL protocol version %d" % version
@@ -496,10 +496,10 @@


 def read_byte(f):
-    return ord(f.read(1))
+    return int8_unpack(f.read(1))

 def write_byte(f, b):
-    f.write(chr(b))
+    f.write(int8_pack(b))

 def read_int(f):
     return int32_unpack(f.read(4))
@@ -734,6 +734,7 @@
         self.waiting = {}
         self.conn_ready = False
         self.compressor = self.decompressor = None
+        self.event_watchers = {}
         Connection.__init__(self, *args, **kwargs)

     def establish_connection(self):
@@ -838,6 +839,10 @@
Given any number of stream-ids, wait until responses have arrived for
         each one, and return a dictionary mapping the stream-ids to the
         appropriate results.
+
+        For internal use, None may be passed in place of a reqid, which will
+        be considered satisfied when a message of any kind is received (and, if
+        appropriate, handled).
         """

         waiting_for = set(reqids)
@@ -857,6 +862,9 @@
                 waiting_for.remove(newmsg.stream_id)
             else:
                 self.handle_incoming(newmsg)
+            if None in waiting_for:
+                results[None] = newmsg
+                waiting_for.remove(None)
         return results

     def wait_for_result(self, reqid):
@@ -907,3 +915,79 @@

         reqid = self.send_msg(msg)
         self.callback_when(reqid, cb)
+
+    def handle_pushed(self, msg):
+        """
+        Process an incoming message originated by the server.
+        """
+        watchers = self.event_watchers.get(msg.eventtype, ())
+        for cb in watchers:
+            cb(msg.eventargs)
+
+    def register_watcher(self, eventtype, cb):
+        """
+        Request that any events of the given type be passed to the given
+        callback when they arrive. Note that the callback may not be called
+ immediately upon the arrival of the event packet; it may have to wait
+        until something else waits on a result, or until wait_for_even() is
+        called.
+
+        If the event type has not been registered for already, this may
+        block while a new REGISTER message is sent to the server.
+
+        The available event types are in the cql.native.known_event_types
+        list.
+
+        When an event arrives, a dictionary will be passed to the callback
+        with the info about the event. Some example result dictionaries:
+
+        (For STATUS_CHANGE events:)
+
+          {'changetype': u'DOWN', 'address': ('12.114.19.76', 8000)}
+
+        (For TOPOLOGY_CHANGE events:)
+
+          {'changetype': u'NEW_NODE', 'address': ('19.10.122.13', 8000)}
+        """
+
+        if isinstance(eventtype, str):
+            eventtype = eventtype.decode('utf8')
+        try:
+            watchers = self.event_watchers[eventtype]
+        except KeyError:
+ ans = self.wait_for_request(RegisterMessage(eventlist=(eventtype,)))
+            if isinstance(ans, ErrorMessage):
+ raise cql.ProgrammingError("Server did not accept registration"
+                                           " for %s events: %s"
+                                           % (eventtype, ans.summarymsg()))
+            watchers = self.event_watchers.setdefault(eventtype, [])
+        watchers.append(cb)
+
+    def unregister_watcher(self, eventtype, cb):
+        """
+        Given an eventtype and a callback previously registered with
+ register_watcher(), remove that callback from the list of watchers for
+        the given event type.
+        """
+
+        if isinstance(eventtype, str):
+            eventtype = eventtype.decode('utf8')
+        self.event_watchers[eventtype].remove(cb)
+
+    def wait_for_event(self):
+        """
+        Wait for any sort of event to arrive, and handle it via the
+        registered callbacks. It is recommended that some event watchers
+        be registered before calling this; otherwise, no events will be
+        sent by the server.
+        """
+        eventsseen = []
+        def i_saw_an_event(ev):
+            eventsseen.append(ev)
+        wlists = self.event_watchers.values()
+        for wlist in wlists:
+            wlist.append(i_saw_an_event)
+        while not eventsseen:
+            self.wait_for_result(None)
+        for wlist in wlists:
+            wlist.remove(i_saw_an_event)

Reply via email to