On 23.4.2012 18:47, Petr Viktorin wrote:
On 04/23/2012 04:33 PM, Jan Cholasta wrote:
Hi,

this patch replaces _call_exc_callbacks with a function wrapper, which
will automatically call exception callbacks when an exception is raised
from the function. This removes the need to specify the function and its
arguments twice (once in the function call itself and once in
_call_exc_callbacks).

Code like this:

try:
# original call
ret = func(arg, kwarg=0)
except ExecutionError, e:
try:
# the function and its arguments need to be specified again!
ret = self._call_exc_callbacks(args, options, e, func, arg,
kwarg=0)
except ExecutionErrorSubclass, e:
handle_error(e)

becomes this:

try:
ret = self._exc_wrapper(args, options, func)(arg, kwarg=0)
except ExecutionErrorSubclass, e:
handle_error(e)

As you can see, the resulting code is shorter and you don't have to
remember to make changes to the arguments in two places.

Honza

Please add a test, too. I've attached one you can use.

OK, thanks.


See also some style nitpicks below.

--
Jan Cholasta

freeipa-jcholast-76-refactor-exc-callback.patch


From 8e070f571472ed5a27339bcc980b67ecca41b337 Mon Sep 17 00:00:00 2001
From: Jan Cholasta<jchol...@redhat.com>
Date: Thu, 19 Apr 2012 08:06:32 -0400
Subject: [PATCH] Refactor exc_callback invocation.

Replace _call_exc_callbacks with a function wrapper, which will automatically
call exception callbacks when an exception is raised from the function. This
removes the need to specify the function and its arguments twice (once in the
function call itself and once in _call_exc_callbacks).
---
ipalib/plugins/baseldap.py | 227 ++++++++++++++----------------------------
ipalib/plugins/entitle.py | 19 ++--
ipalib/plugins/group.py | 7 +-
ipalib/plugins/permission.py | 27 +++---
ipalib/plugins/pwpolicy.py | 11 +-
5 files changed, 109 insertions(+), 182 deletions(-)

diff --git a/ipalib/plugins/baseldap.py b/ipalib/plugins/baseldap.py
index f185977..f7a3bbc 100644
--- a/ipalib/plugins/baseldap.py
+++ b/ipalib/plugins/baseldap.py
@@ -744,26 +744,24 @@ class CallbackInterface(Method):
else:
klass.INTERACTIVE_PROMPT_CALLBACKS.append(callback)

- def _call_exc_callbacks(self, args, options, exc, call_func,
*call_args, **call_kwargs):
- rv = None
- for i in xrange(len(getattr(self, 'EXC_CALLBACKS', []))):
- callback = self.EXC_CALLBACKS[i]
- try:
- if hasattr(callback, 'im_self'):
- rv = callback(
- args, options, exc, call_func, *call_args, **call_kwargs
- )
- else:
- rv = callback(
- self, args, options, exc, call_func, *call_args,
- **call_kwargs
- )
- except errors.ExecutionError, e:
- if (i + 1)< len(self.EXC_CALLBACKS):
- exc = e
- continue
- raise e
- return rv
+ def _exc_wrapper(self, keys, options, call_func):

Consider adding a docstring, e.g.
"""Function wrapper that automatically calls exception callbacks"""

Added.


+ def wrapped(*call_args, **call_kwargs):
+ func = call_func
+ callbacks = list(getattr(self, 'EXC_CALLBACKS', []))
+ while True:
+ try:

You have some clever code here, rebinding `func` like you do.
It'd be nice if there was a comment warning that you're redefining a
function, in case someone who's not a Python expert looks at this.
Consider:
# `func` is either the original function, or the current error callback

Changed the code a bit so that it is more readable.


+ return func(*call_args, **call_kwargs)
+ except errors.ExecutionError, e:
+ if len(callbacks) == 0:

Just use `if not callbacks`, as per PEP 8.

OK.


+ raise
+ callback = callbacks.pop(0)
+ if hasattr(callback, 'im_self'):
+ def func(*args, **kwargs): #pylint: disable=E0102
+ return callback(keys, options, e, call_func, *args, **kwargs)
+ else:
+ def func(*args, **kwargs): #pylint: disable=E0102
+ return callback(self, keys, options, e, call_func, *args, **kwargs)
+ return wrapped


class BaseLDAPCommand(CallbackInterface, Command):
[...]
diff --git a/ipalib/plugins/entitle.py b/ipalib/plugins/entitle.py
index 28d2c5d..6ade854 100644
--- a/ipalib/plugins/entitle.py
+++ b/ipalib/plugins/entitle.py
@@ -642,12 +642,12 @@ class entitle_import(LDAPUpdate):
If we are adding the first entry there are no updates so EmptyModlist
will get thrown. Ignore it.
"""
- if isinstance(exc, errors.EmptyModlist):
- if not getattr(context, 'entitle_import', False):
- raise exc
- return (call_args, {})
- else:
- raise exc
+ if call_func.func_name == 'update_entry':
+ if isinstance(exc, errors.EmptyModlist):
+ if not getattr(context, 'entitle_import', False):

You didn't mention the additional checks for 'update_entry' in the
commit message.

Fixed.


By the way, the need for these checks suggests that a per-class registry
of error callbacks might not be the best design. But that's for more
long-term thinking.

They are not strictly needed; I added them just to be on the safe side.


+ raise exc
+ return (call_args, {})
+ raise exc

def execute(self, *keys, **options):
super(entitle_import, self).execute(*keys, **options)
[...]


Updated patch attached.

Honza

--
Jan Cholasta
From cd7f8a7246c486739187795e658f372ec3c29c37 Mon Sep 17 00:00:00 2001
From: Jan Cholasta <jchol...@redhat.com>
Date: Thu, 19 Apr 2012 08:06:32 -0400
Subject: [PATCH] Refactor exc_callback invocation.

Replace _call_exc_callbacks with a function wrapper, which will automatically
call exception callbacks when an exception is raised from the function. This
removes the need to specify the function and its arguments twice (once in the
function call itself and once in _call_exc_callbacks).

Make the existing exception callbacks in entitle, group, permission and
pwpolicy handle an exception only when it was raised by update_entry.
---
 ipalib/plugins/baseldap.py                |  231 ++++++++++-------------------
 ipalib/plugins/entitle.py                 |   19 ++-
 ipalib/plugins/group.py                   |    7 +-
 ipalib/plugins/permission.py              |   27 ++--
 ipalib/plugins/pwpolicy.py                |   11 +-
 tests/test_xmlrpc/test_baseldap_plugin.py |   66 ++++++++
 6 files changed, 179 insertions(+), 182 deletions(-)
 create mode 100644 tests/test_xmlrpc/test_baseldap_plugin.py

diff --git a/ipalib/plugins/baseldap.py b/ipalib/plugins/baseldap.py
index f185977..7ee5fa1 100644
--- a/ipalib/plugins/baseldap.py
+++ b/ipalib/plugins/baseldap.py
@@ -744,26 +744,28 @@ class CallbackInterface(Method):
         else:
             klass.INTERACTIVE_PROMPT_CALLBACKS.append(callback)
 
-    def _call_exc_callbacks(self, args, options, exc, call_func, *call_args, **call_kwargs):
-        rv = None
-        for i in xrange(len(getattr(self, 'EXC_CALLBACKS', []))):
-            callback = self.EXC_CALLBACKS[i]
-            try:
-                if hasattr(callback, 'im_self'):
-                    rv = callback(
-                        args, options, exc, call_func, *call_args, **call_kwargs
-                    )
-                else:
-                    rv = callback(
-                        self, args, options, exc, call_func, *call_args,
-                        **call_kwargs
-                    )
-            except errors.ExecutionError, e:
-                if (i + 1) < len(self.EXC_CALLBACKS):
-                    exc = e
-                    continue
-                raise e
-        return rv
+    def _exc_wrapper(self, keys, options, call_func):
+        """Function wrapper that automatically calls exception callbacks"""
+        def wrapped(*call_args, **call_kwargs):
+            # call call_func first
+            func = call_func
+            callbacks = list(getattr(self, 'EXC_CALLBACKS', []))
+            while True:
+                try:
+                    return func(*call_args, **call_kwargs)
+                except errors.ExecutionError, e:
+                    if not callbacks:
+                        raise
+                    # call exc_callback in the next loop
+                    callback = callbacks.pop(0)
+                    if hasattr(callback, 'im_self'):
+                        def exc_func(*args, **kwargs):
+                            return callback(keys, options, e, call_func, *args, **kwargs)
+                    else:
+                        def exc_func(*args, **kwargs):
+                            return callback(self, keys, options, e, call_func, *args, **kwargs)
+                    func = exc_func
+        return wrapped
 
 
 class BaseLDAPCommand(CallbackInterface, Command):
@@ -883,17 +885,11 @@ last, after all sets and adds."""),
 
         if needldapattrs:
             try:
-                (dn, old_entry) = ldap.get_entry(
+                (dn, old_entry) = self._exc_wrapper(keys, options, ldap.get_entry)(
                     dn, needldapattrs, normalize=self.obj.normalize_dn
                 )
-            except errors.ExecutionError, e:
-                try:
-                    (dn, old_entry) = self._call_exc_callbacks(
-                        keys, options, e, ldap.get_entry, dn, [],
-                        normalize=self.obj.normalize_dn
-                    )
-                except errors.NotFound:
-                    self.obj.handle_not_found(*keys)
+            except errors.NotFound:
+                self.obj.handle_not_found(*keys)
             for attr in needldapattrs:
                 entry_attrs[attr] = old_entry.get(attr, [])
 
@@ -1019,29 +1015,23 @@ class LDAPCreate(BaseLDAPCommand, crud.Create):
         _check_limit_object_class(self.api.Backend.ldap2.schema.attribute_types(self.obj.disallow_object_classes), entry_attrs.keys(), allow_only=False)
 
         try:
-            ldap.add_entry(dn, entry_attrs, normalize=self.obj.normalize_dn)
-        except errors.ExecutionError, e:
-            try:
-                self._call_exc_callbacks(
-                    keys, options, e, ldap.add_entry, dn, entry_attrs,
-                    normalize=self.obj.normalize_dn
-                )
-            except errors.NotFound:
-                parent = self.obj.parent_object
-                if parent:
-                    raise errors.NotFound(
-                        reason=self.obj.parent_not_found_msg % {
-                            'parent': keys[-2],
-                            'oname': self.api.Object[parent].object_name,
-                        }
-                    )
+            self._exc_wrapper(keys, options, ldap.add_entry)(dn, entry_attrs, normalize=self.obj.normalize_dn)
+        except errors.NotFound:
+            parent = self.obj.parent_object
+            if parent:
                 raise errors.NotFound(
-                    reason=self.obj.container_not_found_msg % {
-                        'container': self.obj.container_dn,
+                    reason=self.obj.parent_not_found_msg % {
+                        'parent': keys[-2],
+                        'oname': self.api.Object[parent].object_name,
                     }
                 )
-            except errors.DuplicateEntry:
-                self.obj.handle_duplicate_entry(*keys)
+            raise errors.NotFound(
+                reason=self.obj.container_not_found_msg % {
+                    'container': self.obj.container_dn,
+                }
+            )
+        except errors.DuplicateEntry:
+            self.obj.handle_duplicate_entry(*keys)
 
         try:
             if self.obj.rdn_attribute:
@@ -1050,22 +1040,16 @@ class LDAPCreate(BaseLDAPCommand, crud.Create):
                     object_class = self.obj.object_class
                 else:
                     object_class = None
-                (dn, entry_attrs) = ldap.find_entry_by_attr(
+                (dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.find_entry_by_attr)(
                     self.obj.primary_key.name, keys[-1], object_class, attrs_list,
                     self.obj.container_dn
                 )
             else:
-                (dn, entry_attrs) = ldap.get_entry(
+                (dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
                     dn, attrs_list, normalize=self.obj.normalize_dn
                 )
-        except errors.ExecutionError, e:
-            try:
-                (dn, entry_attrs) = self._call_exc_callbacks(
-                    keys, options, e, ldap.get_entry, dn, attrs_list,
-                    normalize=self.obj.normalize_dn
-                )
-            except errors.NotFound:
-                self.obj.handle_not_found(*keys)
+        except errors.NotFound:
+            self.obj.handle_not_found(*keys)
 
         for callback in self.POST_CALLBACKS:
             if hasattr(callback, 'im_self'):
@@ -1181,17 +1165,11 @@ class LDAPRetrieve(LDAPQuery):
                 dn = callback(self, ldap, dn, attrs_list, *keys, **options)
 
         try:
-            (dn, entry_attrs) = ldap.get_entry(
+            (dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
                 dn, attrs_list, normalize=self.obj.normalize_dn
             )
-        except errors.ExecutionError, e:
-            try:
-                (dn, entry_attrs) = self._call_exc_callbacks(
-                    keys, options, e, ldap.get_entry, dn, attrs_list,
-                    normalize=self.obj.normalize_dn
-                )
-            except errors.NotFound:
-                self.obj.handle_not_found(*keys)
+        except errors.NotFound:
+            self.obj.handle_not_found(*keys)
 
         if options.get('rights', False) and options.get('all', False):
             entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn)
@@ -1297,7 +1275,7 @@ class LDAPUpdate(LDAPQuery, crud.Update):
 
             if self.obj.rdn_is_primary_key and self.obj.primary_key.name in entry_attrs:
                 # RDN change
-                ldap.update_entry_rdn(dn,
+                self._exc_wrapper(keys, options, ldap.update_entry_rdn)(dn,
                     unicode('%s=%s' % (self.obj.primary_key.name,
                     entry_attrs[self.obj.primary_key.name])))
                 rdnkeys = keys[:-1] + (entry_attrs[self.obj.primary_key.name], )
@@ -1306,37 +1284,25 @@ class LDAPUpdate(LDAPQuery, crud.Update):
                 options['rdnupdate'] = True
                 rdnupdate = True
 
-            ldap.update_entry(dn, entry_attrs, normalize=self.obj.normalize_dn)
-        except errors.ExecutionError, e:
             # Exception callbacks will need to test for options['rdnupdate']
             # to decide what to do. An EmptyModlist in this context doesn't
             # mean an error occurred, just that there were no other updates to
             # perform.
-            try:
-                self._call_exc_callbacks(
-                    keys, options, e, ldap.update_entry, dn, entry_attrs,
-                    normalize=self.obj.normalize_dn
-                )
-            except errors.EmptyModlist, e:
-                if not rdnupdate:
-                    raise e
-            except errors.NotFound:
-                self.obj.handle_not_found(*keys)
+            self._exc_wrapper(keys, options, ldap.update_entry)(dn, entry_attrs, normalize=self.obj.normalize_dn)
+        except errors.EmptyModlist, e:
+            if not rdnupdate:
+                raise e
+        except errors.NotFound:
+            self.obj.handle_not_found(*keys)
 
         try:
-            (dn, entry_attrs) = ldap.get_entry(
+            (dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
                 dn, attrs_list, normalize=self.obj.normalize_dn
             )
-        except errors.ExecutionError, e:
-            try:
-                (dn, entry_attrs) = self._call_exc_callbacks(
-                    keys, options, e, ldap.get_entry, dn, attrs_list,
-                    normalize=self.obj.normalize_dn
-                )
-            except errors.NotFound:
-                raise errors.MidairCollision(
-                    format=_('the entry was deleted while being modified')
-                )
+        except errors.NotFound:
+            raise errors.MidairCollision(
+                format=_('the entry was deleted while being modified')
+            )
 
         if options.get('rights', False) and options.get('all', False):
             entry_attrs['attributelevelrights'] = get_effective_rights(ldap, dn)
@@ -1399,15 +1365,9 @@ class LDAPDelete(LDAPMultiQuery):
                         for (dn_, entry_attrs) in subentries:
                             delete_subtree(dn_)
                 try:
-                    ldap.delete_entry(base_dn, normalize=self.obj.normalize_dn)
-                except errors.ExecutionError, e:
-                    try:
-                        self._call_exc_callbacks(
-                            nkeys, options, e, ldap.delete_entry, base_dn,
-                            normalize=self.obj.normalize_dn
-                        )
-                    except errors.NotFound:
-                        self.obj.handle_not_found(*nkeys)
+                    self._exc_wrapper(nkeys, options, ldap.delete_entry)(base_dn, normalize=self.obj.normalize_dn)
+                except errors.NotFound:
+                    self.obj.handle_not_found(*nkeys)
 
             delete_subtree(dn)
 
@@ -1560,17 +1520,11 @@ class LDAPAddMember(LDAPModMember):
             )
 
         try:
-            (dn, entry_attrs) = ldap.get_entry(
+            (dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
                 dn, attrs_list, normalize=self.obj.normalize_dn
             )
-        except errors.ExecutionError, e:
-            try:
-                (dn, entry_attrs) = self._call_exc_callbacks(
-                    keys, options, e, ldap.get_entry, dn, attrs_list,
-                    normalize=self.obj.normalize_dn
-                )
-            except errors.NotFound:
-                self.obj.handle_not_found(*keys)
+        except errors.NotFound:
+            self.obj.handle_not_found(*keys)
 
         for callback in self.POST_CALLBACKS:
             if hasattr(callback, 'im_self'):
@@ -1668,17 +1622,11 @@ class LDAPRemoveMember(LDAPModMember):
         time.sleep(.3)
 
         try:
-            (dn, entry_attrs) = ldap.get_entry(
+            (dn, entry_attrs) = self._exc_wrapper(keys, options, ldap.get_entry)(
                 dn, attrs_list, normalize=self.obj.normalize_dn
             )
-        except errors.ExecutionError, e:
-            try:
-                (dn, entry_attrs) = self._call_exc_callbacks(
-                    keys, options, e, ldap.get_entry, dn, attrs_list,
-                    normalize=self.obj.normalize_dn
-                )
-            except errors.NotFound:
-                self.obj.handle_not_found(*keys)
+        except errors.NotFound:
+            self.obj.handle_not_found(*keys)
 
         for callback in self.POST_CALLBACKS:
             if hasattr(callback, 'im_self'):
@@ -1884,20 +1832,13 @@ class LDAPSearch(BaseLDAPCommand, crud.Search):
                 )
 
         try:
-            (entries, truncated) = ldap.find_entries(
+            (entries, truncated) = self._exc_wrapper(args, options, ldap.find_entries)(
                 filter, attrs_list, base_dn, scope,
                 time_limit=options.get('timelimit', None),
                 size_limit=options.get('sizelimit', None)
             )
-        except errors.ExecutionError, e:
-            try:
-                (entries, truncated) = self._call_exc_callbacks(
-                    args, options, e, ldap.find_entries, filter, attrs_list,
-                    base_dn, scope=ldap.SCOPE_ONELEVEL,
-                    normalize=self.obj.normalize_dn
-                )
-            except errors.NotFound:
-                (entries, truncated) = ([], False)
+        except errors.NotFound:
+            (entries, truncated) = ([], False)
 
         for callback in self.POST_CALLBACKS:
             if hasattr(callback, 'im_self'):
@@ -2030,21 +1971,15 @@ class LDAPAddReverseMember(LDAPModReverseMember):
             try:
                 options = {'%s' % self.member_attr: keys[-1]}
                 try:
-                    result = self.api.Command[self.member_command](attr, **options)
+                    result = self._exc_wrapper(keys, options, self.api.Command[self.member_command])(attr, **options)
                     if result['completed'] == 1:
                         completed = completed + 1
                     else:
                         failed['member'][self.reverse_attr].append((attr, result['failed']['member'][self.member_attr][0][1]))
-                except errors.ExecutionError, e:
-                    try:
-                        (dn, entry_attrs) = self._call_exc_callbacks(
-                            keys, options, e, self.member_command, dn, attrs_list,
-                            normalize=self.obj.normalize_dn
-                        )
-                    except errors.NotFound, e:
-                        msg = str(e)
-                        (attr, msg) = msg.split(':', 1)
-                        failed['member'][self.reverse_attr].append((attr, unicode(msg.strip())))
+                except errors.NotFound, e:
+                    msg = str(e)
+                    (attr, msg) = msg.split(':', 1)
+                    failed['member'][self.reverse_attr].append((attr, unicode(msg.strip())))
 
             except errors.PublicError, e:
                 failed['member'][self.reverse_attr].append((attr, unicode(msg)))
@@ -2143,21 +2078,15 @@ class LDAPRemoveReverseMember(LDAPModReverseMember):
             try:
                 options = {'%s' % self.member_attr: keys[-1]}
                 try:
-                    result = self.api.Command[self.member_command](attr, **options)
+                    result = self._exc_wrapper(keys, options, self.api.Command[self.member_command])(attr, **options)
                     if result['completed'] == 1:
                         completed = completed + 1
                     else:
                         failed['member'][self.reverse_attr].append((attr, result['failed']['member'][self.member_attr][0][1]))
-                except errors.ExecutionError, e:
-                    try:
-                        (dn, entry_attrs) = self._call_exc_callbacks(
-                            keys, options, e, self.member_command, dn, attrs_list,
-                            normalize=self.obj.normalize_dn
-                        )
-                    except errors.NotFound, e:
-                        msg = str(e)
-                        (attr, msg) = msg.split(':', 1)
-                        failed['member'][self.reverse_attr].append((attr, unicode(msg.strip())))
+                except errors.NotFound, e:
+                    msg = str(e)
+                    (attr, msg) = msg.split(':', 1)
+                    failed['member'][self.reverse_attr].append((attr, unicode(msg.strip())))
 
             except errors.PublicError, e:
                 failed['member'][self.reverse_attr].append((attr, unicode(msg)))
diff --git a/ipalib/plugins/entitle.py b/ipalib/plugins/entitle.py
index 28d2c5d..6ade854 100644
--- a/ipalib/plugins/entitle.py
+++ b/ipalib/plugins/entitle.py
@@ -642,12 +642,12 @@ class entitle_import(LDAPUpdate):
         If we are adding the first entry there are no updates so EmptyModlist
         will get thrown. Ignore it.
         """
-        if isinstance(exc, errors.EmptyModlist):
-            if not getattr(context, 'entitle_import', False):
-                raise exc
-            return (call_args, {})
-        else:
-            raise exc
+        if call_func.func_name == 'update_entry':
+            if isinstance(exc, errors.EmptyModlist):
+                if not getattr(context, 'entitle_import', False):
+                    raise exc
+                return (call_args, {})
+        raise exc
 
     def execute(self, *keys, **options):
         super(entitle_import, self).execute(*keys, **options)
@@ -729,9 +729,10 @@ class entitle_sync(LDAPUpdate):
         return dn
 
     def exc_callback(self, keys, options, exc, call_func, *call_args, **call_kwargs):
-        if isinstance(exc, errors.EmptyModlist):
-            # If there is nothing to change we are already synchronized.
-            return
+        if call_func.func_name == 'update_entry':
+            if isinstance(exc, errors.EmptyModlist):
+                # If there is nothing to change we are already synchronized.
+                return
         raise exc
 
 api.register(entitle_sync)
diff --git a/ipalib/plugins/group.py b/ipalib/plugins/group.py
index 096cb9e..1320854 100644
--- a/ipalib/plugins/group.py
+++ b/ipalib/plugins/group.py
@@ -211,9 +211,10 @@ class group_mod(LDAPUpdate):
     def exc_callback(self, keys, options, exc, call_func, *call_args, **call_kwargs):
         # Check again for GID requirement in case someone tried to clear it
         # using --setattr.
-        if isinstance(exc, errors.ObjectclassViolation):
-            if 'gidNumber' in exc.message and 'posixGroup' in exc.message:
-                raise errors.RequirementError(name='gid')
+        if call_func.func_name == 'update_entry':
+            if isinstance(exc, errors.ObjectclassViolation):
+                if 'gidNumber' in exc.message and 'posixGroup' in exc.message:
+                    raise errors.RequirementError(name='gid')
         raise exc
 
 api.register(group_mod)
diff --git a/ipalib/plugins/permission.py b/ipalib/plugins/permission.py
index 92203f1..2d300e2 100644
--- a/ipalib/plugins/permission.py
+++ b/ipalib/plugins/permission.py
@@ -374,20 +374,19 @@ class permission_mod(LDAPUpdate):
         return dn
 
     def exc_callback(self, keys, options, exc, call_func, *call_args, **call_kwargs):
-        if isinstance(exc, errors.EmptyModlist):
-            aciupdate = getattr(context, 'aciupdate')
-            opts = copy.copy(options)
-            # Clear the aci attributes out of the permission entry
-            for o in self.obj.aci_attributes + ['all', 'raw', 'rights']:
-                try:
-                    del opts[o]
-                except:
-                    pass
-
-            if len(opts) > 0 and not aciupdate:
-                raise exc
-        else:
-            raise exc
+        if call_func.func_name == 'update_entry':
+            if isinstance(exc, errors.EmptyModlist):
+                aciupdate = getattr(context, 'aciupdate')
+                opts = copy.copy(options)
+                # Clear the aci attributes out of the permission entry
+                for o in self.obj.aci_attributes + ['all', 'raw', 'rights']:
+                    try:
+                        del opts[o]
+                    except:
+                        pass
+                if len(opts) == 0 or aciupdate:
+                    return
+        raise exc
 
     def post_callback(self, ldap, dn, entry_attrs, *keys, **options):
         # rename the underlying ACI after the change to permission
diff --git a/ipalib/plugins/pwpolicy.py b/ipalib/plugins/pwpolicy.py
index 2b586ec..90a2ea1 100644
--- a/ipalib/plugins/pwpolicy.py
+++ b/ipalib/plugins/pwpolicy.py
@@ -414,11 +414,12 @@ class pwpolicy_mod(LDAPUpdate):
         return dn
 
     def exc_callback(self, keys, options, exc, call_func, *call_args, **call_kwargs):
-        if isinstance(exc, errors.EmptyModlist):
-            entry_attrs = call_args[1]
-            cosupdate = getattr(context, 'cosupdate')
-            if not entry_attrs or cosupdate:
-                return
+        if call_func.func_name == 'update_entry':
+            if isinstance(exc, errors.EmptyModlist):
+                entry_attrs = call_args[1]
+                cosupdate = getattr(context, 'cosupdate')
+                if not entry_attrs or cosupdate:
+                    return
         raise exc
 
 api.register(pwpolicy_mod)
diff --git a/tests/test_xmlrpc/test_baseldap_plugin.py b/tests/test_xmlrpc/test_baseldap_plugin.py
new file mode 100644
index 0000000..0800a5d
--- /dev/null
+++ b/tests/test_xmlrpc/test_baseldap_plugin.py
@@ -0,0 +1,66 @@
+# Authors:
+#   Petr Viktorin <pvikt...@redhat.com>
+#
+# Copyright (C) 2012  Red Hat
+# see file 'COPYING' for use and warranty information
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+"""
+Test the `ipalib.plugins.baseldap` module.
+"""
+
+from ipalib import errors
+from ipalib.plugins import baseldap
+
+def test_exc_wrapper():
+    """Test the CallbackInterface._exc_wrapper helper method"""
+    handled_exceptions = []
+
+    class test_callback(baseldap.CallbackInterface):
+        """Fake IPA method"""
+        def test_fail(self):
+            self._exc_wrapper([], {}, self.fail)(1, 2, a=1, b=2)
+
+        def fail(self, *args, **kwargs):
+            assert args == (1, 2)
+            assert kwargs == dict(a=1, b=2)
+            raise errors.ExecutionError('failure')
+
+    instance = test_callback()
+
+    # Test with one callback first
+
+    @test_callback.register_exc_callback
+    def handle_exception(self, keys, options, e, call_func, *args, **kwargs):
+        assert args == (1, 2)
+        assert kwargs == dict(a=1, b=2)
+        handled_exceptions.append(type(e))
+
+    instance.test_fail()
+    assert handled_exceptions == [errors.ExecutionError]
+
+    # Test with another callback added
+
+    handled_exceptions = []
+
+    def dont_handle(self, keys, options, e, call_func, *args, **kwargs):
+        assert args == (1, 2)
+        assert kwargs == dict(a=1, b=2)
+        handled_exceptions.append(None)
+        raise e
+    test_callback.register_exc_callback(dont_handle, first=True)
+
+    instance.test_fail()
+    assert handled_exceptions == [None, errors.ExecutionError]
-- 
1.7.7.6

_______________________________________________
Freeipa-devel mailing list
Freeipa-devel@redhat.com
https://www.redhat.com/mailman/listinfo/freeipa-devel

Reply via email to