Author: Wim Lavrijsen <[email protected]>
Branch: reflex-support
Changeset: r62603:9d5aae59d25e
Date: 2013-03-20 18:24 -0700
http://bitbucket.org/pypy/pypy/changeset/9d5aae59d25e/
Log: allow comparisons to None (and have them respond sanely)
diff --git a/pypy/module/cppyy/interp_cppyy.py b/pypy/module/cppyy/interp_cppyy.py
--- a/pypy/module/cppyy/interp_cppyy.py
+++ b/pypy/module/cppyy/interp_cppyy.py
@@ -937,6 +937,10 @@
return None
def instance__eq__(self, w_other):
+ # special case: if other is None, compare pointer-style
+ if self.space.is_w(w_other, self.space.w_None):
+ return self.space.wrap(not self._rawobject)
+
# get here if no class-specific overloaded operator is available, try to
# find a global overload in gbl, in __gnu_cxx (for iterators), or in the
# scopes of the argument classes (TODO: implement that last option)
diff --git a/pypy/module/cppyy/pythonify.py b/pypy/module/cppyy/pythonify.py
--- a/pypy/module/cppyy/pythonify.py
+++ b/pypy/module/cppyy/pythonify.py
@@ -301,6 +301,27 @@
# general note: use 'in pyclass.__dict__' rather than 'hasattr' to prevent
# adding pythonizations multiple times in derived classes
+ # map __eq__/__ne__ through a comparison to None
+ if '__eq__' in pyclass.__dict__:
+ def __eq__(self, other):
+ if other is None: return not self
+ if type(self) is not type(other): return False
+ # nulls compare pointer-style; never pass a null to C++ operator==
+ if not self or not other: return not self and not other
+ return self._cxx_eq(other)
+ pyclass._cxx_eq = pyclass.__dict__['__eq__']
+ pyclass.__eq__ = __eq__
+
+ if '__ne__' in pyclass.__dict__:
+ def __ne__(self, other):
+ if other is None: return not not self
+ if type(self) is not type(other): return True
+ # mirror __eq__: nulls compare pointer-style, not via C++ operator!=
+ if not self or not other: return not (not self and not other)
+ return self._cxx_ne(other)
+ pyclass._cxx_ne = pyclass.__dict__['__ne__']
+ pyclass.__ne__ = __ne__
+
# map size -> __len__ (generally true for STL)
if 'size' in pyclass.__dict__ and not '__len__' in pyclass.__dict__ \
and callable(pyclass.size):
diff --git a/pypy/module/cppyy/test/test_cint.py b/pypy/module/cppyy/test/test_cint.py
--- a/pypy/module/cppyy/test/test_cint.py
+++ b/pypy/module/cppyy/test/test_cint.py
@@ -539,3 +539,40 @@
#assert s == cppyy.bind_object(cobj, "TString")
assert s == cppyy.bind_object(addr, s.__class__)
assert s == cppyy.bind_object(addr, "TString")
+
+ def test09_object_and_pointer_comparisons(self):
+ """Verify object and pointer comparisons"""
+
+ import cppyy
+ gbl = cppyy.gbl
+
+ c1 = cppyy.bind_object(0, gbl.TCanvas)
+ assert c1 == None
+ assert None == c1
+
+ c2 = cppyy.bind_object(0, gbl.TCanvas)
+ assert c1 == c2
+ assert c2 == c1
+
+ # TLorentzVector overrides operator==
+ l1 = cppyy.bind_object(0, gbl.TLorentzVector)
+ assert l1 == None
+ assert None == l1
+
+ assert c1 != l1
+ assert l1 != c1
+
+ l2 = cppyy.bind_object(0, gbl.TLorentzVector)
+ assert l1 == l2
+ assert l2 == l1
+
+ l3 = gbl.TLorentzVector(1, 2, 3, 4)
+ l4 = gbl.TLorentzVector(1, 2, 3, 4)
+ l5 = gbl.TLorentzVector(4, 3, 2, 1)
+ assert l3 == l4
+ assert l4 == l3
+
+ assert l3 != None # like this to ensure __ne__ is called
+ assert None != l3 # id.
+ assert l3 != l5
+ assert l5 != l3
diff --git a/pypy/module/cppyy/test/test_datatypes.py b/pypy/module/cppyy/test/test_datatypes.py
--- a/pypy/module/cppyy/test/test_datatypes.py
+++ b/pypy/module/cppyy/test/test_datatypes.py
@@ -583,7 +583,44 @@
c.destruct()
- def test16_object_validity(self):
+ def test16_object_and_pointer_comparisons(self):
+ """Verify object and pointer comparisons"""
+
+ import cppyy
+ gbl = cppyy.gbl
+
+ c1 = cppyy.bind_object(0, gbl.cppyy_test_data)
+ assert c1 == None
+ assert None == c1
+
+ c2 = cppyy.bind_object(0, gbl.cppyy_test_data)
+ assert c1 == c2
+ assert c2 == c1
+
+ # four_vector overrides operator==
+ l1 = cppyy.bind_object(0, gbl.four_vector)
+ assert l1 == None
+ assert None == l1
+
+ assert c1 != l1
+ assert l1 != c1
+
+ l2 = cppyy.bind_object(0, gbl.four_vector)
+ assert l1 == l2
+ assert l2 == l1
+
+ l3 = gbl.four_vector(1, 2, 3, 4)
+ l4 = gbl.four_vector(1, 2, 3, 4)
+ l5 = gbl.four_vector(4, 3, 2, 1)
+ assert l3 == l4
+ assert l4 == l3
+
+ assert l3 != None # like this to ensure __ne__ is called
+ assert None != l3 # id.
+ assert l3 != l5
+ assert l5 != l3
+
+ def test17_object_validity(self):
"""Test object validity checking"""
from cppyy import gbl
@@ -597,7 +634,7 @@
assert not d2
- def test17_buffer_reshaping(self):
+ def test18_buffer_reshaping(self):
"""Test usage of buffer sizing"""
import cppyy
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit