Author: cito
Date: Sat Jul 23 09:21:57 2016
New Revision: 880

Log:
Enable garbage collection after deleting DB instance

Needed to add a destructor and to replace the back-references to the DB
instance with weak references, since destructors are not called when
there are circular references.

Modified:
   trunk/pg.py
   trunk/tests/test_classic_dbwrapper.py

Modified: trunk/pg.py
==============================================================================
--- trunk/pg.py Sat Jul 23 08:06:19 2016        (r879)
+++ trunk/pg.py Sat Jul 23 09:21:57 2016        (r880)
@@ -35,6 +35,7 @@
 
 import select
 import warnings
+import weakref
 
 from datetime import date, time, datetime, timedelta, tzinfo
 from decimal import Decimal
@@ -321,11 +322,7 @@
     _re_array_escape = _re_record_escape = regex(r'(["\\])')
 
     def __init__(self, db):
-        self.db = db
-        self.encode_json = db.encode_json
-        db = db.db
-        self.escape_bytea = db.escape_bytea
-        self.escape_string = db.escape_string
+        self.db = weakref.proxy(db)
 
     @classmethod
     def _adapt_bool(cls, v):
@@ -356,7 +353,7 @@
 
     def _adapt_bytea(self, v):
         """Adapt a bytea parameter."""
-        return self.escape_bytea(v)
+        return self.db.escape_bytea(v)
 
     def _adapt_json(self, v):
         """Adapt a json parameter."""
@@ -364,7 +361,7 @@
             return None
         if isinstance(v, basestring):
             return v
-        return self.encode_json(v)
+        return self.db.encode_json(v)
 
     @classmethod
     def _adapt_text_array(cls, v):
@@ -417,7 +414,7 @@
                 self._adapt_bytea_array(v) for v in v) + b'}'
         if v is None:
             return b'null'
-        return self.escape_bytea(v).replace(b'\\', b'\\\\')
+        return self.db.escape_bytea(v).replace(b'\\', b'\\\\')
 
     def _adapt_json_array(self, v):
         """Adapt a json array parameter."""
@@ -427,7 +424,7 @@
         if not v:
             return 'null'
         if not isinstance(v, basestring):
-            v = self.encode_json(v)
+            v = self.db.encode_json(v)
         if self._re_array_quote.search(v):
             v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
         return v
@@ -549,17 +546,17 @@
         if isinstance(value, Literal):
             return value
         if isinstance(value, Bytea):
-            value = self.escape_bytea(value)
+            value = self.db.escape_bytea(value)
             if bytes is not str:  # Python >= 3.0
                 value = value.decode('ascii')
         elif isinstance(value, Json):
             if value.encode:
                 return value.encode()
-            value = self.encode_json(value)
+            value = self.db.encode_json(value)
         elif isinstance(value, (datetime, date, time, timedelta)):
             value = str(value)
         if isinstance(value, basestring):
-            value = self.escape_string(value)
+            value = self.db.escape_string(value)
             return "'%s'" % value
         if isinstance(value, bool):
             return 'true' if value else 'false'
@@ -1085,14 +1082,11 @@
     def __init__(self, db):
         """Initialize type cache for connection."""
         super(DbTypes, self).__init__()
+        self._db = weakref.proxy(db)
         self._regtypes = False
-        self._get_attnames = db.get_attnames
         self._typecasts = Typecasts()
         self._typecasts.get_attnames = self.get_attnames
-        self._typecasts.connection = db
-        db = db.db
-        self.query = db.query
-        self.escape_string = db.escape_string
+        self._typecasts.connection = self._db
         if db.server_version < 80400:
             # older remote databases (not officially supported)
             self._query_pg_type = (
@@ -1127,7 +1121,7 @@
         """Get the type info from the database if it is not cached."""
         try:
             q = self._query_pg_type % (_quote_if_unqualified('$1', key),)
-            res = self.query(q, (key,)).getresult()
+            res = self._db.query(q, (key,)).getresult()
         except ProgrammingError:
             res = None
         if not res:
@@ -1152,7 +1146,7 @@
                 return None
         if not typ.relid:
             return None
-        return self._get_attnames(typ.relid, with_oid=False)
+        return self._db.get_attnames(typ.relid, with_oid=False)
 
     def get_typecast(self, typ):
         """Get the typecast function for the given database type."""
@@ -1451,6 +1445,16 @@
         else:
             self.rollback()
 
+    def __del__(self):
+        try:
+            db = self.db
+        except AttributeError:
+            db = None
+        if db:
+            db.set_cast_hook(None)
+            if self._closeable:
+                db.close()
+
     # Auxiliary methods
 
     def _do_debug(self, *args):

Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py       Sat Jul 23 08:06:19 2016        (r879)
+++ trunk/tests/test_classic_dbwrapper.py       Sat Jul 23 09:21:57 2016        (r880)
@@ -4466,10 +4466,8 @@
             db = DB()
             db.query("select $1::int as r", 42).dictresult()
             db.close()
-            del db
         self.getLeaks(fut)
 
-    @unittest.skip("this still needs to be resolved")
     def testLeaksWithoutClose(self):
         def fut():
             db = DB()
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to