Author: Armin Rigo <ar...@tunes.org> Branch: Changeset: r481:a225a58c9067 Date: 2012-06-21 16:36 +0200 http://bitbucket.org/cffi/cffi/changeset/a225a58c9067/
Log: (fijal, arigo) (early sprint) cdata pointer comparison. diff --git a/c/_ffi_backend.c b/c/_ffi_backend.c --- a/c/_ffi_backend.c +++ b/c/_ffi_backend.c @@ -1148,34 +1148,52 @@ static PyObject *cdata_richcompare(PyObject *v, PyObject *w, int op) { - CDataObject *obv, *obw; - int equal; - PyObject *res; - - if (op != Py_EQ && op != Py_NE) - goto Unimplemented; + int res, full_order; + PyObject *pyres; + char *v_cdata, *w_cdata; + + full_order = (op != Py_EQ && op != Py_NE); assert(CData_Check(v)); - obv = (CDataObject *)v; + v_cdata = ((CDataObject *)v)->c_data; + if (full_order && + (((CDataObject *)v)->c_type->ct_flags & CT_PRIMITIVE_ANY)) + goto Error; if (w == Py_None) { - equal = (obv->c_data == NULL); + w_cdata = NULL; } else if (CData_Check(w)) { - obw = (CDataObject *)w; - equal = (obv->c_type == obw->c_type) && (obv->c_data == obw->c_data); + w_cdata = ((CDataObject *)w)->c_data; + if (full_order && + (((CDataObject *)w)->c_type->ct_flags & CT_PRIMITIVE_ANY)) + goto Error; } else goto Unimplemented; - res = (equal ^ (op == Py_NE)) ? Py_True : Py_False; + switch (op) { + case Py_EQ: res = (v_cdata == w_cdata); break; + case Py_NE: res = (v_cdata != w_cdata); break; + case Py_LT: res = (v_cdata < w_cdata); break; + case Py_LE: res = (v_cdata <= w_cdata); break; + case Py_GT: res = (v_cdata > w_cdata); break; + case Py_GE: res = (v_cdata >= w_cdata); break; + default: res = -1; + } + pyres = res ? Py_True : Py_False; done: - Py_INCREF(res); - return res; + Py_INCREF(pyres); + return pyres; Unimplemented: - res = Py_NotImplemented; + pyres = Py_NotImplemented; goto done; + + Error: + PyErr_SetString(PyExc_TypeError, + "cannot do comparison on a primitive cdata"); + return NULL; } static long cdata_hash(CDataObject *cd) diff --git a/cffi/backend_ctypes.py b/cffi/backend_ctypes.py --- a/cffi/backend_ctypes.py +++ b/cffi/backend_ctypes.py @@ -1,4 +1,4 @@ -import ctypes, ctypes.util +import ctypes, ctypes.util, operator from . import model class CTypesData(object): @@ -76,10 +76,42 @@ raise TypeError("cdata %r does not support iteration" % ( self._get_c_name()),) + def _make_cmp(name): + cmpfunc = getattr(operator, name) + def cmp(self, other): + if isinstance(other, CTypesData): + return cmpfunc(self._convert_to_address(None), + other._convert_to_address(None)) + elif other is None: + return cmpfunc(self._convert_to_address(None), 0) + else: + return NotImplemented + cmp.func_name = name + return cmp + + __eq__ = _make_cmp('__eq__') + __ne__ = _make_cmp('__ne__') + __lt__ = _make_cmp('__lt__') + __le__ = _make_cmp('__le__') + __gt__ = _make_cmp('__gt__') + __ge__ = _make_cmp('__ge__') + + def __hash__(self): + return hash(type(self)) ^ hash(self._convert_to_address(None)) + class CTypesGenericPrimitive(CTypesData): __slots__ = [] + def __eq__(self, other): + return self is other + + def __ne__(self, other): + return self is not other + + def __hash__(self): + return object.__hash__(self) + class CTypesGenericArray(CTypesData): __slots__ = [] @@ -119,18 +151,6 @@ def __nonzero__(self): return bool(self._address) - def __eq__(self, other): - if other is None: - return not bool(self._address) - return (type(self) is type(other) and - self._address == other._address) - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(type(self)) ^ hash(self._address) - @classmethod def _to_ctypes(cls, value): if value is None: diff --git a/testing/backend_tests.py b/testing/backend_tests.py --- a/testing/backend_tests.py +++ b/testing/backend_tests.py @@ -65,6 +65,7 @@ q = ffi.cast(c_decl, long(min - 1)) assert ffi.typeof(q) is ffi.typeof(p) and int(q) == max assert q != p + assert int(q) == int(p) assert hash(q) != hash(p) # unlikely py.test.raises(OverflowError, ffi.new, c_decl, min - 1) py.test.raises(OverflowError, ffi.new, c_decl, max + 1) @@ -818,6 +819,67 @@ assert p == s+0 assert p+1 == s+1 + def test_pointer_comparison(self): + ffi = FFI(backend=self.Backend()) + s = ffi.new("short[]", range(100)) + p = ffi.cast("short *", s) + assert (p < s) is False + assert (p <= s) is True + assert (p == s) is True + assert (p != s) is False + assert (p > s) is False + assert (p >= s) is True + assert (s < p) is False + assert (s <= p) is True + assert (s == p) is True + assert (s != p) is False + assert (s > p) is False + assert (s >= p) is True + q = p + 1 + assert (q < s) is False + assert (q <= s) is False + assert (q == s) is False + assert (q != s) is True + assert (q > s) is True + assert (q >= s) is True + assert (s < q) is True + assert (s <= q) is True + assert (s == q) is False + assert (s != q) is True + assert (s > q) is False + assert (s >= q) is False + assert (q < p) is False + assert (q <= p) is False + assert (q == p) is False + assert (q != p) is True + assert (q > p) is True + assert (q >= p) is True + assert (p < q) is True + assert (p <= q) is True + assert (p == q) is False + assert (p != q) is True + assert (p > q) is False + assert (p >= q) is False + # + assert (None == s) is False + assert (None != s) is True + assert (s == None) is False + assert (s != None) is True + assert (None == q) is False + assert (None != q) is True + assert (q == None) is False + assert (q != None) is True + + def test_no_integer_comparison(self): + ffi = FFI(backend=self.Backend()) + x = ffi.cast("int", 123) + y = ffi.cast("int", 456) + py.test.raises(TypeError, "x < y") + # + z = ffi.cast("double", 78.9) + py.test.raises(TypeError, "x < z") + py.test.raises(TypeError, "z < y") + def test_ffi_buffer_ptr(self): ffi = FFI(backend=self.Backend()) a = ffi.new("short", 100) _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit