Author: eevans
Date: Fri Dec 31 17:24:57 2010
New Revision: 1054142

URL: http://svn.apache.org/viewvc?rev=1054142&view=rev
Log:
port python driver, avro -> thrift

Patch by eevans; reviewed by jbellis for CASSANDRA-1913

Added:
    cassandra/trunk/drivers/py/test/
    cassandra/trunk/drivers/py/test/test_query_compression.py
Modified:
    cassandra/trunk/drivers/py/cql/__init__.py
    cassandra/trunk/test/system/test_cql.py

Modified: cassandra/trunk/drivers/py/cql/__init__.py
URL: http://svn.apache.org/viewvc/cassandra/trunk/drivers/py/cql/__init__.py?rev=1054142&r1=1054141&r2=1054142&view=diff
==============================================================================
--- cassandra/trunk/drivers/py/cql/__init__.py (original)
+++ cassandra/trunk/drivers/py/cql/__init__.py Fri Dec 31 17:24:57 2010
@@ -1,85 +1,79 @@
 
-from avro.ipc  import HTTPTransceiver, Requestor, AvroRemoteException
-import avro.protocol, zlib, socket
-from os.path   import exists, abspath, dirname, join
-
-def _load_protocol():
-    # By default, look for the proto schema in the same dir as this file.
-    avpr = join(abspath(dirname(__file__)), 'cassandra.avpr')
-    if exists(avpr):
-        return avro.protocol.parse(open(avpr).read())
-
-    # Fall back to ../../interface/avro/cassandra.avpr (dev environ).
-    avpr = join(abspath(dirname(__file__)),
-                '..',
-                '..',
-                '..',
-                'interface',
-                'avro',
-                'cassandra.avpr')
-    if exists(avpr):
-        return avro.protocol.parse(open(avpr).read())
-
-    raise Exception("Unable to locate an avro protocol schema!")
-
+from os.path import exists, abspath, dirname, join
+from thrift.transport import TTransport, TSocket
+from thrift.protocol import TBinaryProtocol
+from thrift.Thrift import TApplicationException
+import zlib
+
+try:
+    from cassandra import Cassandra
+    from cassandra.ttypes import Compression, InvalidRequestException, \
+                                 CqlResultType
+except ImportError:
+    # Hack to run from a source tree
+    import sys
+    sys.path.append(join(abspath(dirname(__file__)),
+                         '..',
+                         '..',
+                         '..',
+                         'interface',
+                         'thrift',
+                         'gen-py'))
+    from cassandra import Cassandra
+    from cassandra.ttypes import Compression, InvalidRequestException, \
+                          CqlResultType
+    
 COMPRESSION_SCHEMES = ['GZIP']
 DEFAULT_COMPRESSION = 'GZIP'
 
-
 class Connection(object):
     def __init__(self, keyspace, host, port=9160):
-        client = HTTPTransceiver(host, port)
-        # disabled nagle
-        client.conn.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
-        self.requestor = Requestor(_load_protocol(), client)
+        socket = TSocket.TSocket(host, port)
+        self.transport = TTransport.TFramedTransport(socket)
+        protocol = TBinaryProtocol.TBinaryProtocolAccelerated(self.transport)
+        self.client = Cassandra.Client(protocol)
+        socket.open()
+
         if keyspace:
             self.execute('USE %s' % keyspace)
 
     def execute(self, query, compression=None):
         compress = compression is None and DEFAULT_COMPRESSION \
                 or compression.upper()
-        if not compress in COMPRESSION_SCHEMES:
-            raise InvalidCompressionScheme(compress)
     
         compressed_query = Connection.compress_query(query, compress)
-        request_params = dict(query=compressed_query, compression=compress)
+        request_compression = getattr(Compression, compress)
 
         try:
-            response = self.requestor.request('execute_cql_query', request_params)
-        except AvroRemoteException, are:
-            raise CQLException(are)
-
-        if response['type'] == 'ROWS':
-            return response['rows']
-        if response['type'] == 'INT':
-            return response['num']
+            response = self.client.execute_cql_query(compressed_query,
+                                                     request_compression)
+        except InvalidRequestException, ire:
+            raise CQLException("Bad Request: %s" % ire.why)
+        except TApplicationException, tapp:
+            raise CQLException("Internal application error")
+        except Exception, exc:
+            raise CQLException(exc)
+
+        if response.type == CqlResultType.ROWS:
+            return response.rows
+        if response.type == CqlResultType.INT:
+            return response.num
 
         return None
 
+    def close(self):
+        self.transport.close()
+
     @classmethod
     def compress_query(cls, query, compression):
+        if not compression in COMPRESSION_SCHEMES:
+            raise InvalidCompressionScheme(compression)
+
         if compression == 'GZIP':
             return zlib.compress(query)
 
 
 class InvalidCompressionScheme(Exception): pass
-
-class CQLException(Exception):
-    def __init__(self, arg):
-        if isinstance(arg, AvroRemoteException):
-            if arg.args and isinstance(arg.args[0], dict) and arg.args[0].has_key('why'):
-                message = arg.args[0]['why']
-            else:
-                message = str(arg)
-            Exception.__init__(self, message)
-        else:
-            Exception.__init__(self, arg)
-
-if __name__ == '__main__':
-    dbconn = Connection('localhost', 9160)
-    query = 'USE Keyspace1;'
-    dbconn.execute(query, 'GZIP') 
-    query = 'UPDATE Standard2 WITH ROW("k", COL("c", "v"));'
-    dbconn.execute(query, 'GZIP') 
+class CQLException(Exception): pass
 
 # vi: ai ts=4 tw=0 sw=4 et

Added: cassandra/trunk/drivers/py/test/test_query_compression.py
URL: http://svn.apache.org/viewvc/cassandra/trunk/drivers/py/test/test_query_compression.py?rev=1054142&view=auto
==============================================================================
--- cassandra/trunk/drivers/py/test/test_query_compression.py (added)
+++ cassandra/trunk/drivers/py/test/test_query_compression.py Fri Dec 31 17:24:57 2010
@@ -0,0 +1,16 @@
+
+from os.path import abspath, exists, join, dirname
+
+if exists(join(abspath(dirname(__file__)), '..', 'cql')):
+    import sys; sys.path.append(join(abspath(dirname(__file__)), '..'))
+
+import unittest, zlib
+from cql import Connection
+
+class TestCompression(unittest.TestCase):
+    def test_gzip(self):
+        "compressing a string w/ gzip"
+        query = "SELECT \"foo\" FROM Standard1 WHERE KEY = \"bar\";"
+        compressed = Connection.compress_query(query, 'GZIP')
+        decompressed = zlib.decompress(compressed)
+        assert query == decompressed, "Decompressed query did not match"

Modified: cassandra/trunk/test/system/test_cql.py
URL: http://svn.apache.org/viewvc/cassandra/trunk/test/system/test_cql.py?rev=1054142&r1=1054141&r2=1054142&view=diff
==============================================================================
--- cassandra/trunk/test/system/test_cql.py (original)
+++ cassandra/trunk/test/system/test_cql.py Fri Dec 31 17:24:57 2010
@@ -5,7 +5,7 @@ import sys
 sys.path.append(join(abspath(dirname(__file__)), '../../drivers/py'))
 
 from cql import Connection, CQLException
-from . import AvroTester
+from . import ThriftTester
 from avro_utils import assert_raises
 
 def load_sample(dbconn):
@@ -49,33 +49,33 @@ def init(keyspace="Keyspace1"):
     load_sample(dbconn)
     return dbconn
 
-class TestCql(AvroTester):
+class TestCql(ThriftTester):
     def test_select_simple(self):
         "retrieve a column"
         conn = init()
         r = conn.execute('SELECT "ca1" FROM Standard1 WHERE KEY="ka"')
-        assert r[0]['key'] == 'ka'
-        assert r[0]['columns'][0]['name'] == 'ca1'
-        assert r[0]['columns'][0]['value'] == 'va1'
+        assert r[0].key == 'ka'
+        assert r[0].columns[0].name == 'ca1'
+        assert r[0].columns[0].value == 'va1'
 
     def test_select_columns(self):
         "retrieve multiple columns"
         conn = init()
         r = conn.execute('SELECT "cd1", "col" FROM Standard1 WHERE KEY = "kd"')
-        assert "cd1" in [i['name'] for i in r[0]['columns']]
-        assert "col" in [i['name'] for i in r[0]['columns']]
+        assert "cd1" in [i.name for i in r[0].columns]
+        assert "col" in [i.name for i in r[0].columns]
 
     def test_select_row_range(self):
         "retrieve a range of rows with columns"
         conn = init()
         r = conn.execute('SELECT 4L FROM StandardLong1 WHERE KEY > "ad" AND KEY < "ag";')
         assert len(r) == 3
-        assert r[0]['key'] == "ad"
-        assert r[1]['key'] == "ae"
-        assert r[2]['key'] == "af"
-        assert len(r[0]['columns']) == 1
-        assert len(r[1]['columns']) == 1
-        assert len(r[2]['columns']) == 1
+        assert r[0].key == "ad"
+        assert r[1].key == "ae"
+        assert r[2].key == "af"
+        assert len(r[0].columns) == 1
+        assert len(r[1].columns) == 1
+        assert len(r[2].columns) == 1
 
     def test_select_row_range_with_limit(self):
         "retrieve a limited range of rows with columns"
@@ -90,26 +90,26 @@ class TestCql(AvroTester):
         conn = init()
         r = conn.execute('SELECT 1L..3L FROM StandardLong1 WHERE KEY = "aa";')
         assert len(r) == 1
-        assert r[0]['columns'][0]['value'] == "1"
-        assert r[0]['columns'][1]['value'] == "2"
-        assert r[0]['columns'][2]['value'] == "3"
+        assert r[0].columns[0].value == "1"
+        assert r[0].columns[1].value == "2"
+        assert r[0].columns[2].value == "3"
 
     def test_select_columns_slice_with_limit(self):
         "range of columns (slice) by row with limit"
         conn = init()
         r = conn.execute('SELECT FIRST 1 1L..3L FROM StandardLong1 WHERE KEY = "aa";')
         assert len(r) == 1
-        assert len(r[0]['columns']) == 1
-        assert r[0]['columns'][0]['value'] == "1"
+        assert len(r[0].columns) == 1
+        assert r[0].columns[0].value == "1"
 
     def test_select_columns_slice_reversed(self):
         "range of columns (slice) by row reversed"
         conn = init()
         r = conn.execute('SELECT FIRST 2 REVERSED 3L..1L FROM StandardLong1 WHERE KEY = "aa";')
         assert len(r) == 1, "%d != 1" % len(r)
-        assert len(r[0]['columns']) == 2
-        assert r[0]['columns'][0]['value'] == "3"
-        assert r[0]['columns'][1]['value'] == "2"
+        assert len(r[0].columns) == 2
+        assert r[0].columns[0].value == "3"
+        assert r[0].columns[1].value == "2"
 
     def test_error_on_multiple_key_by(self):
         "ensure multiple key-bys in where clause excepts"
@@ -122,10 +122,10 @@ class TestCql(AvroTester):
         conn = init()
         r = conn.execute('SELECT "birthdate" FROM Indexed1 WHERE "birthdate" = 100L')
         assert len(r) == 2
-        assert r[0]['key'] == "asmith"
-        assert r[1]['key'] == "dozer"
-        assert len(r[0]['columns']) == 1
-        assert len(r[1]['columns']) == 1
+        assert r[0].key == "asmith"
+        assert r[1].key == "dozer"
+        assert len(r[0].columns) == 1
+        assert len(r[1].columns) == 1
 
     def test_index_scan_greater_than(self):
         "indexed scan where a column is greater than a value"
@@ -134,7 +134,7 @@ class TestCql(AvroTester):
             SELECT "birthdate" FROM Indexed1 WHERE "birthdate" = 100L AND "unindexed" > 200L
         """)
         assert len(r) == 1
-        assert r[0]['key'] == "asmith"
+        assert r[0].key == "asmith"
 
     def test_index_scan_with_start_key(self):
         "indexed scan with a starting key"
@@ -143,22 +143,22 @@ class TestCql(AvroTester):
             SELECT "birthdate" FROM Indexed1 WHERE "birthdate" = 100L AND KEY > "asmithZ"
         """)
         assert len(r) == 1
-        assert r[0]['key'] == "dozer"
+        assert r[0].key == "dozer"
 
     def test_no_where_clause(self):
         "empty where clause (range query w/o start key)"
         conn = init()
         r = conn.execute('SELECT "col" FROM Standard1 LIMIT 3')
         assert len(r) == 3
-        assert r[0]['key'] == "ka"
-        assert r[1]['key'] == "kb"
-        assert r[2]['key'] == "kc"
+        assert r[0].key == "ka"
+        assert r[1].key == "kb"
+        assert r[2].key == "kc"
 
     def test_column_count(self):
         "getting a result count instead of results"
         conn = init()
         r = conn.execute('SELECT COUNT(1L..4L) FROM StandardLong1 WHERE KEY = "aa";')
-        assert r == 4
+        assert r == 4, "expected 4 results, got %d" % (r and r or 0)
 
     def test_truncate_columnfamily(self):
         "truncating a column family"
@@ -171,33 +171,33 @@ class TestCql(AvroTester):
         "delete columns from a row"
         conn = init()
         r = conn.execute('SELECT "cd1", "col" FROM Standard1 WHERE KEY = "kd"')
-        assert "cd1" in [i['name'] for i in r[0]['columns']]
-        assert "col" in [i['name'] for i in r[0]['columns']]
+        assert "cd1" in [i.name for i in r[0].columns]
+        assert "col" in [i.name for i in r[0].columns]
         conn.execute('DELETE "cd1", "col" FROM Standard1 WHERE KEY = "kd"')
         r = conn.execute('SELECT "cd1", "col" FROM Standard1 WHERE KEY = "kd"')
-        assert len(r[0]['columns']) == 0
+        assert len(r[0].columns) == 0
 
     def test_delete_columns_multi_rows(self):
         "delete columns from multiple rows"
         conn = init()
         r = conn.execute('SELECT "col" FROM Standard1 WHERE KEY = "kc"')
-        assert len(r[0]['columns']) == 1
+        assert len(r[0].columns) == 1
         r = conn.execute('SELECT "col" FROM Standard1 WHERE KEY = "kd"')
-        assert len(r[0]['columns']) == 1
+        assert len(r[0].columns) == 1
 
         conn.execute('DELETE "col" FROM Standard1 WHERE KEY IN ("kc", "kd")')
         r = conn.execute('SELECT "col" FROM Standard1 WHERE KEY = "kc"')
-        assert len(r[0]['columns']) == 0
+        assert len(r[0].columns) == 0
         r = conn.execute('SELECT "col" FROM Standard1 WHERE KEY = "kd"')
-        assert len(r[0]['columns']) == 0
+        assert len(r[0].columns) == 0
 
     def test_delete_rows(self):
         "delete entire rows"
         conn = init()
         r = conn.execute('SELECT "cd1", "col" FROM Standard1 WHERE KEY = "kd"')
-        assert "cd1" in [i['name'] for i in r[0]['columns']]
-        assert "col" in [i['name'] for i in r[0]['columns']]
+        assert "cd1" in [i.name for i in r[0].columns]
+        assert "col" in [i.name for i in r[0].columns]
         conn.execute('DELETE FROM Standard1 WHERE KEY = "kd"')
         r = conn.execute('SELECT "cd1", "col" FROM Standard1 WHERE KEY = "kd"')
-        assert len(r[0]['columns']) == 0
+        assert len(r[0].columns) == 0
 


Reply via email to