ehlo,

The Python wrapper for retrieving netgroups was pushed too early: it does not
resolve nested netgroups. The attached patch fixes that.

LS
>From 92b7a3adf688d68b37e2a545cdbe8405186859e3 Mon Sep 17 00:00:00 2001
From: Lukas Slebodnik <lsleb...@redhat.com>
Date: Wed, 10 Aug 2016 20:05:52 +0200
Subject: [PATCH] sssd_netgroup.py: Resolve nested netgroups

---
 src/tests/intg/sssd_netgroup.py | 149 ++++++++++++++++++++++++++--------------
 1 file changed, 99 insertions(+), 50 deletions(-)

diff --git a/src/tests/intg/sssd_netgroup.py b/src/tests/intg/sssd_netgroup.py
index 3525261cb28707db9031ee1dfeb144ae4c362833..7c28d93835339d20446c8c3b83a5e59138df000c 100644
--- a/src/tests/intg/sssd_netgroup.py
+++ b/src/tests/intg/sssd_netgroup.py
@@ -71,49 +71,118 @@ class Netgrent(Structure):
                 ("nip", c_void_p)]
 
 
-def call_sssd_setnetgrent(netgroup):
-    libnss_sss_path = config.NSS_MODULE_DIR + "/libnss_sss.so.2"
-    libnss_sss = cdll.LoadLibrary(libnss_sss_path)
+class NetgroupRetriever(object):
+    def __init__(self, name):
+        self.name = name
+        self.needed_groups = []
+        self.known_groups = [name]  # never re-queue the top netgroup
+        self.netgroups = []
 
-    func = libnss_sss._nss_sss_setnetgrent
-    func.restype = c_int
-    func.argtypes = [c_char_p, POINTER(Netgrent)]
+    def _setnetgrent(self, netgroup):
+        libnss_sss_path = config.NSS_MODULE_DIR + "/libnss_sss.so.2"
+        libnss_sss = cdll.LoadLibrary(libnss_sss_path)
 
-    result = Netgrent()
-    result_p = POINTER(Netgrent)(result)
+        func = libnss_sss._nss_sss_setnetgrent
+        func.restype = c_int
+        func.argtypes = [c_char_p, POINTER(Netgrent)]
 
-    res = func(c_char_p(netgroup), result_p)
+        result = Netgrent()
+        result_p = POINTER(Netgrent)(result)
 
-    return (int(res), result_p)
+        res = func(c_char_p(netgroup), result_p)
 
+        return (int(res), result_p)
 
-def call_sssd_getnetgrent_r(result_p, buff, buff_len):
-    libnss_sss_path = config.NSS_MODULE_DIR + "/libnss_sss.so.2"
-    libnss_sss = cdll.LoadLibrary(libnss_sss_path)
+    def _getnetgrent_r(self, result_p, buff, buff_len):
+        libnss_sss_path = config.NSS_MODULE_DIR + "/libnss_sss.so.2"
+        libnss_sss = cdll.LoadLibrary(libnss_sss_path)
 
-    func = libnss_sss._nss_sss_getnetgrent_r
-    func.restype = c_int
-    func.argtypes = [POINTER(Netgrent), POINTER(c_char), c_size_t,
-                     POINTER(c_int)]
+        func = libnss_sss._nss_sss_getnetgrent_r
+        func.restype = c_int
+        func.argtypes = [POINTER(Netgrent), POINTER(c_char), c_size_t,
+                         POINTER(c_int)]
 
-    errno = POINTER(c_int)(c_int(0))
+        errno = POINTER(c_int)(c_int(0))
 
-    res = func(result_p, buff, buff_len, errno)
+        res = func(result_p, buff, buff_len, errno)
 
-    return (int(res), int(errno[0]), result_p)
+        return (int(res), int(errno[0]), result_p)
 
+    def _endnetgrent(self, result_p):
+        libnss_sss_path = config.NSS_MODULE_DIR + "/libnss_sss.so.2"
+        libnss_sss = cdll.LoadLibrary(libnss_sss_path)
 
-def call_sssd_endnetgrent(result_p):
-    libnss_sss_path = config.NSS_MODULE_DIR + "/libnss_sss.so.2"
-    libnss_sss = cdll.LoadLibrary(libnss_sss_path)
+        func = libnss_sss._nss_sss_endnetgrent
+        func.restype = c_int
+        func.argtypes = [POINTER(Netgrent)]
 
-    func = libnss_sss._nss_sss_endnetgrent
-    func.restype = c_int
-    func.argtypes = [POINTER(Netgrent)]
+        res = func(result_p)
 
-    res = func(result_p)
+        return int(res)
 
-    return int(res)
+    def get_netgroups(self):
+        res, errno, result = self._flat_fetch_netgroups(self.name)
+        if res != NssReturnCode.SUCCESS:
+            return (res, errno, self.netgroups)
+
+        self.netgroups += result
+
+        while len(self.needed_groups) > 0:
+            name = self.needed_groups.pop(0)
+
+            nest_res, nest_errno, result = self._flat_fetch_netgroups(name)
+            # do not fail for missing nested netgroup
+            if nest_res not in (NssReturnCode.SUCCESS, NssReturnCode.NOTFOUND):
+                return (nest_res, nest_errno, self.netgroups)
+
+            self.netgroups = result + self.netgroups
+
+        return (res, errno, self.netgroups)
+
+    def _flat_fetch_netgroups(self, name):
+        """
+        Fetch the triplets of a single netgroup directly from sssd.
+        Nested netgroups are not resolved here; names not seen before are
+        queued in self.needed_groups for get_netgroups() to fetch.
+
+        @param string name name of netgroup
+
+        @return (int, int, List[(string, string, string)])
+            (err, errno, netgroups); if err is NssReturnCode.SUCCESS,
+            netgroups is a list of (host, user, domain) tuples whose
+            elements are either string or None.
+        """
+        buff_len = 1024 * 1024
+        buff = create_string_buffer(buff_len)
+
+        result = []
+
+        res, result_p = self._setnetgrent(name)
+        if res != NssReturnCode.SUCCESS:
+            return (res, get_errno(), result)
+
+        res, errno, result_p = self._getnetgrent_r(result_p, buff, buff_len)
+        while res == NssReturnCode.SUCCESS:
+            if result_p[0].type == NetgroupType.GROUP_VAL:
+                nested_netgroup = result_p[0].val.group
+                if nested_netgroup not in self.known_groups:
+                    self.needed_groups.append(nested_netgroup)
+                    self.known_groups.append(nested_netgroup)
+
+            if result_p[0].type == NetgroupType.TRIPLE_VAL:
+                result.append((result_p[0].val.triple.host,
+                               result_p[0].val.triple.user,
+                               result_p[0].val.triple.domain))
+
+            res, errno, result_p = self._getnetgrent_r(result_p, buff,
+                                                       buff_len)
+
+        # release the sssd netgroup context even if getnetgrent_r failed
+        end_res = self._endnetgrent(result_p)
+        if res != NssReturnCode.RETURN:
+            return (res, errno, result)
+
+        return (end_res, errno, result)
 
 
 def get_sssd_netgroups(name):
@@ -129,27 +198,7 @@ def get_sssd_netgroups(name):
         Each touple will consist of 3 elemets either string or None
         (host, user, domain).
     """
-    buff_len = 1024 * 1024
-    buff = create_string_buffer(buff_len)
 
-    result = []
+    retriever = NetgroupRetriever(name)
 
-    res, result_p = call_sssd_setnetgrent(name)
-    if res != NssReturnCode.SUCCESS:
-        return (res, get_errno(), result)
-
-    res, errno, result_p = call_sssd_getnetgrent_r(result_p, buff, buff_len)
-    while res == NssReturnCode.SUCCESS:
-        assert result_p[0].type == NetgroupType.TRIPLE_VAL
-        result.append((result_p[0].val.triple.host,
-                       result_p[0].val.triple.user,
-                       result_p[0].val.triple.domain))
-        res, errno, result_p = call_sssd_getnetgrent_r(result_p, buff,
-                                                       buff_len)
-
-    if res != NssReturnCode.RETURN:
-        return (res, errno, result)
-
-    res = call_sssd_endnetgrent(result_p)
-
-    return (res, errno, result)
+    return retriever.get_netgroups()
-- 
2.9.3

_______________________________________________
sssd-devel mailing list
sssd-devel@lists.fedorahosted.org
https://lists.fedorahosted.org/admin/lists/sssd-devel@lists.fedorahosted.org

Reply via email to