Author: cito
Date: Sat Jul 23 09:21:57 2016
New Revision: 880
Log:
Enable garbage collection after deleting DB instance.
Needed to add a destructor, and weak references to break circular
references, since destructors are not called promptly (and before
Python 3.4 not at all) for objects that are part of reference cycles.
Modified:
trunk/pg.py
trunk/tests/test_classic_dbwrapper.py
Modified: trunk/pg.py
==============================================================================
--- trunk/pg.py Sat Jul 23 08:06:19 2016 (r879)
+++ trunk/pg.py Sat Jul 23 09:21:57 2016 (r880)
@@ -35,6 +35,7 @@
import select
import warnings
+import weakref
from datetime import date, time, datetime, timedelta, tzinfo
from decimal import Decimal
@@ -321,11 +322,7 @@
_re_array_escape = _re_record_escape = regex(r'(["\\])')
def __init__(self, db):
- self.db = db
- self.encode_json = db.encode_json
- db = db.db
- self.escape_bytea = db.escape_bytea
- self.escape_string = db.escape_string
+ self.db = weakref.proxy(db)
@classmethod
def _adapt_bool(cls, v):
@@ -356,7 +353,7 @@
def _adapt_bytea(self, v):
"""Adapt a bytea parameter."""
- return self.escape_bytea(v)
+ return self.db.escape_bytea(v)
def _adapt_json(self, v):
"""Adapt a json parameter."""
@@ -364,7 +361,7 @@
return None
if isinstance(v, basestring):
return v
- return self.encode_json(v)
+ return self.db.encode_json(v)
@classmethod
def _adapt_text_array(cls, v):
@@ -417,7 +414,7 @@
self._adapt_bytea_array(v) for v in v) + b'}'
if v is None:
return b'null'
- return self.escape_bytea(v).replace(b'\\', b'\\\\')
+ return self.db.escape_bytea(v).replace(b'\\', b'\\\\')
def _adapt_json_array(self, v):
"""Adapt a json array parameter."""
@@ -427,7 +424,7 @@
if not v:
return 'null'
if not isinstance(v, basestring):
- v = self.encode_json(v)
+ v = self.db.encode_json(v)
if self._re_array_quote.search(v):
v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
return v
@@ -549,17 +546,17 @@
if isinstance(value, Literal):
return value
if isinstance(value, Bytea):
- value = self.escape_bytea(value)
+ value = self.db.escape_bytea(value)
if bytes is not str: # Python >= 3.0
value = value.decode('ascii')
elif isinstance(value, Json):
if value.encode:
return value.encode()
- value = self.encode_json(value)
+ value = self.db.encode_json(value)
elif isinstance(value, (datetime, date, time, timedelta)):
value = str(value)
if isinstance(value, basestring):
- value = self.escape_string(value)
+ value = self.db.escape_string(value)
return "'%s'" % value
if isinstance(value, bool):
return 'true' if value else 'false'
@@ -1085,14 +1082,11 @@
def __init__(self, db):
"""Initialize type cache for connection."""
super(DbTypes, self).__init__()
+ self._db = weakref.proxy(db)
self._regtypes = False
- self._get_attnames = db.get_attnames
self._typecasts = Typecasts()
self._typecasts.get_attnames = self.get_attnames
- self._typecasts.connection = db
- db = db.db
- self.query = db.query
- self.escape_string = db.escape_string
+ self._typecasts.connection = self._db
if db.server_version < 80400:
# older remote databases (not officially supported)
self._query_pg_type = (
@@ -1127,7 +1121,7 @@
"""Get the type info from the database if it is not cached."""
try:
q = self._query_pg_type % (_quote_if_unqualified('$1', key),)
- res = self.query(q, (key,)).getresult()
+ res = self._db.query(q, (key,)).getresult()
except ProgrammingError:
res = None
if not res:
@@ -1152,7 +1146,7 @@
return None
if not typ.relid:
return None
- return self._get_attnames(typ.relid, with_oid=False)
+ return self._db.get_attnames(typ.relid, with_oid=False)
def get_typecast(self, typ):
"""Get the typecast function for the given database type."""
@@ -1451,6 +1445,16 @@
else:
self.rollback()
+ def __del__(self):
+ try:
+ db = self.db
+ except AttributeError:
+ db = None
+ if db:
+ db.set_cast_hook(None)
+ if self._closeable:
+ db.close()
+
# Auxiliary methods
def _do_debug(self, *args):
Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py Sat Jul 23 08:06:19 2016 (r879)
+++ trunk/tests/test_classic_dbwrapper.py Sat Jul 23 09:21:57 2016 (r880)
@@ -4466,10 +4466,8 @@
db = DB()
db.query("select $1::int as r", 42).dictresult()
db.close()
- del db
self.getLeaks(fut)
- @unittest.skip("this still needs to be resolved")
def testLeaksWithoutClose(self):
def fut():
db = DB()
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql