Author: christian.heimes
Date: Thu Jan 10 02:49:44 2008
New Revision: 59886

Modified:
   python/branches/py3k-importhook/Include/import.h
   python/branches/py3k-importhook/Lib/test/test_imp.py
   python/branches/py3k-importhook/Python/import.c
Log:
Post import hook implementation
This implementation has been cleaned up a bit.

Modified: python/branches/py3k-importhook/Include/import.h
==============================================================================
--- python/branches/py3k-importhook/Include/import.h    (original)
+++ python/branches/py3k-importhook/Include/import.h    Thu Jan 10 02:49:44 2008
@@ -35,6 +35,12 @@
 PyAPI_FUNC(PyObject *)_PyImport_FindExtension(char *, char *);
 PyAPI_FUNC(PyObject *)_PyImport_FixupExtension(char *, char *);
 
+/* post import hook API */
+PyAPI_FUNC(PyObject *) PyImport_GetPostImportHooks(void);
+PyAPI_FUNC(PyObject *) PyImport_NotifyModuleLoaded(PyObject *module);
+PyAPI_FUNC(PyObject *) PyImport_RegisterPostImportHook(
+       PyObject *callable, PyObject *mod_name);
+
 struct _inittab {
     char *name;
     void (*initfunc)(void);

Modified: python/branches/py3k-importhook/Lib/test/test_imp.py
==============================================================================
--- python/branches/py3k-importhook/Lib/test/test_imp.py        (original)
+++ python/branches/py3k-importhook/Lib/test/test_imp.py        Thu Jan 10 02:49:44 2008
@@ -1,4 +1,5 @@
 import imp
+import sys
 import thread
 import unittest
 from test import test_support
@@ -68,12 +69,79 @@
         ## import sys
         ## self.assertRaises(ImportError, reload, sys)
 
+class CallBack:
+    def __init__(self):
+        self.mods = {}
+
+    def __call__(self, mod):
+        self.mods[mod.__name__] = mod
+
+class PostImportHookTests(unittest.TestCase):
+
+    def setUp(self):
+        if "telnetlib" in sys.modules:
+            del sys.modules["telnetlib"]
+        self.pihr = sys.post_import_hooks.copy()
+
+    def tearDown(self):
+        if "telnetlib" in sys.modules:
+            del sys.modules["telnetlib"]
+        sys.post_import_hooks = self.pihr
+
+    def test_registry(self):
+        reg = sys.post_import_hooks
+        self.assert_(isinstance(reg, dict))
+
+    def test_invalid_registry(self):
+        sys.post_import_hooks = []
+        self.assertRaises(TypeError, imp.register_post_import_hook,
+                          lambda mod: None, "sys")
+        sys.post_import_hooks = {}
+        imp.register_post_import_hook(lambda mod: None, "sys")
+
+        sys.post_import_hooks["telnetlib"] = lambda mod: None
+        self.assertRaises(TypeError, __import__, "telnetlib")
+        sys.post_import_hooks = self.pihr
+
+    def test_register_callback_existing(self):
+        callback = CallBack()
+        imp.register_post_import_hook(callback, "sys")
+
+        # sys is already loaded and the callback is fired immediately
+        self.assert_("sys" in callback.mods, callback.mods)
+        self.assert_(callback.mods["sys"] is sys, callback.mods)
+        self.failIf("telnetlib" in callback.mods, callback.mods)
+        regc = sys.post_import_hooks.get("sys", False)
+        self.assert_(regc is False, regc)
+
+    def test_register_callback_new(self):
+        callback = CallBack()
+        # an arbitrary module
+        if "telnetlib" in sys.modules:
+            del sys.modules["telnetlib"]
+        imp.register_post_import_hook(callback, "telnetlib")
+
+        regc = sys.post_import_hooks.get("telnetlib")
+        self.assert_(regc is not None, regc)
+        self.assert_(isinstance(regc, list), regc)
+        self.assert_(callback in regc, regc)
+
+        import telnetlib
+        self.assert_("telnetlib" in callback.mods, callback.mods)
+        self.assert_(callback.mods["telnetlib"] is telnetlib, callback.mods)
+
+    def test_post_import_notify(self):
+        imp.notify_module_loaded(sys)
+        self.failUnlessRaises(TypeError, imp.notify_module_loaded, None)
+        self.failUnlessRaises(TypeError, imp.notify_module_loaded, object())
+
 
 def test_main():
     test_support.run_unittest(
-                LockTests,
-                ImportTests,
-            )
+        LockTests,
+        ImportTests,
+        PostImportHookTests,
+    )
 
 if __name__ == "__main__":
     test_main()

Modified: python/branches/py3k-importhook/Python/import.c
==============================================================================
--- python/branches/py3k-importhook/Python/import.c     (original)
+++ python/branches/py3k-importhook/Python/import.c     Thu Jan 10 02:49:44 2008
@@ -163,7 +163,7 @@
 void
 _PyImportHooks_Init(void)
 {
-       PyObject *v, *path_hooks = NULL, *zimpimport;
+       PyObject *v, *path_hooks = NULL, *zimpimport, *pihr;
        int err = 0;
 
        /* adding sys.path_hooks and sys.path_importer_cache, setting up
@@ -200,6 +200,14 @@
                              );
        }
 
+       pihr = PyDict_New();
+       if (pihr == NULL ||
+            PySys_SetObject("post_import_hooks", pihr) != 0) {
+               PyErr_Print();
+               Py_FatalError("initialization of post import hook registry "
+                             "failed");
+       }
+
        zimpimport = PyImport_ImportModule("zipimport");
        if (zimpimport == NULL) {
                PyErr_Clear(); /* No zip import module -- okay */
@@ -371,6 +379,7 @@
        "path", "argv", "ps1", "ps2",
        "last_type", "last_value", "last_traceback",
        "path_hooks", "path_importer_cache", "meta_path",
+       "post_import_hooks",
        NULL
 };
 
@@ -625,6 +634,214 @@
                              "sys.modules failed");
 }
 
+/* post import hook API */
+PyObject *
+PyImport_GetPostImportHooks(void)
+{
+       PyObject *pihr;
+
+       pihr = PySys_GetObject("post_import_hooks");
+       /* This should only happen during initialization */
+       if (pihr == NULL) {
+               PyErr_Clear();
+               return NULL;
+       }
+
+       if (!PyDict_Check(pihr)) {
+               PyErr_SetString(PyExc_TypeError,
+                               "post import registry is not a dict");
+               return NULL;
+       }
+       return pihr;
+}
+
+PyObject *
+PyImport_NotifyModuleLoaded(PyObject *module)
+{
+       static PyObject *name = NULL;
+       PyObject *mod_name = NULL, *registry = NULL, *o;
+       PyObject *hooks = NULL, *hook, *it = NULL;
+       int status = -1;
+
+       if (name == NULL) {
+               name = PyUnicode_InternFromString("__name__");
+               if (name == NULL) {
+                       return NULL;
+               }
+       }
+
+       if (module == NULL) {
+               return NULL;
+       }
+
+       /* Should I allow all kinds of objects ? */
+       if (!PyModule_Check(module)) {
+               PyErr_Format(PyExc_TypeError,
+                            "A module object was expected, got '%.200s'",
+                            Py_TYPE(module)->tp_name);
+               goto error;
+       }
+
+       /* XXX check if module is in sys.modules ? */
+       registry = PyImport_GetPostImportHooks();
+       if (registry == NULL) {
+               /* warn about invalid registry? */
+               PyErr_Clear();
+               return module;
+       }
+
+       mod_name = PyObject_GetAttr(module, name);
+       if (mod_name == NULL) {
+               goto error;
+       }
+       if (!PyUnicode_Check(mod_name)) {
+               PyObject *repr;
+               char *name;
+
+               repr = PyObject_Repr(module);
+               name = repr ? PyUnicode_AsString(repr) : "<unknown>";
+               PyErr_Format(PyExc_TypeError,
+                            "Module __name__ attribute of '%.200s' is not "
+                            "string", name);
+               Py_XDECREF(repr);
+               goto error;
+       }
+
+       hooks = PyDict_GetItem(registry, mod_name);
+       if (hooks == NULL) {
+               /* Either no hooks are defined or they are already fired */
+               PyErr_Clear();
+               goto end;
+       }
+       if (!PyList_Check(hooks)) {
+               PyErr_Format(PyExc_TypeError,
+                            "expected None or list of hooks, got '%.200s'",
+                            Py_TYPE(hooks)->tp_name);
+               goto error;
+       }
+
+       /* fire hooks */
+       it = PyObject_GetIter(hooks);
+       if (it == NULL) {
+               goto error;
+       }
+       while ((hook = PyIter_Next(it)) != NULL) {
+               o = PyObject_CallFunctionObjArgs(hook, module, NULL);
+               Py_DECREF(hook);
+               if (o == NULL) {
+                       goto error;
+               }
+               Py_DECREF(o);
+       }
+
+       /* Mark hooks as fired */
+       if (PyDict_DelItem(registry, mod_name) < 0) {
+               goto error;
+       }
+
+    end:
+       status = 0;
+    error:
+       Py_XDECREF(mod_name);
+       Py_XDECREF(it);
+       if (status < 0) {
+               Py_XDECREF(module);
+               return NULL;
+       }
+       else {
+               return module;
+       }
+}
+
+PyObject *
+PyImport_RegisterPostImportHook(PyObject *callable, PyObject *mod_name)
+{
+       PyObject *registry = NULL, *hooks = NULL;
+       int status = -1, locked = 0;
+
+       if (!PyCallable_Check(callable)) {
+               PyErr_SetString(PyExc_TypeError, "expected callable");
+               goto error;
+       }
+       if (!PyUnicode_Check(mod_name)) {
+               PyErr_SetString(PyExc_TypeError, "expected string");
+               goto error;
+       }
+
+       registry = PyImport_GetPostImportHooks();
+       if (registry == NULL) {
+               goto error;
+       }
+
+       lock_import();
+       locked = 1;
+
+       hooks = PyDict_GetItem(registry, mod_name);
+       /* module may be already loaded, get the module object from sys */
+       if (hooks == NULL) {
+               PyObject *o, *modules;
+               PyObject *module = NULL;
+
+               modules = PyImport_GetModuleDict();
+               if (modules == NULL) {
+                       goto error;
+               }
+               module = PyDict_GetItem(modules, mod_name);
+               if (module != NULL) {
+                       /* module is already loaded, fire hook immediately */
+                       o = PyObject_CallFunctionObjArgs(callable, module, NULL);
+                       if (o == NULL) {
+                               goto error;
+                       }
+                       Py_DECREF(o);
+                       goto end;
+               }
+       }
+       /* no hook registered so far */
+       if (hooks == NULL) {
+               PyErr_Clear();
+               hooks = PyList_New(0);
+               if (hooks == NULL) {
+                       goto error;
+               }
+               if (PyDict_SetItem(registry, mod_name, hooks) < 0) {
+                       goto error;
+               }
+       }
+       else {
+               if (!PyList_Check(hooks)) {
+                       PyErr_Format(PyExc_TypeError,
+                                    "expected list of hooks, got '%.200s'",
+                                    Py_TYPE(hooks)->tp_name);
+                       goto error;
+               }
+       }
+       /* append a new callable */
+       if (PyList_Append(hooks, callable) < 0) {
+               goto error;
+       }
+
+    end:
+       status = 0;
+    error:
+       Py_XDECREF(callable);
+       Py_XDECREF(hooks);
+       Py_XDECREF(mod_name);
+       if (locked) {
+               if (unlock_import() < 0) {
+                       PyErr_SetString(PyExc_RuntimeError,
+                                       "not holding the import lock");
+                       return NULL;
+               }
+       }
+       if (status < 0) {
+               return NULL;
+       }
+       else {
+               Py_RETURN_NONE;
+       }
+}
+
 static PyObject * get_sourcefile(const char *file);
 
 /* Execute a code object in a module and return the module object
@@ -2066,6 +2283,7 @@
        PyObject *result;
        lock_import();
        result = import_module_level(name, globals, locals, fromlist, level);
+       result = PyImport_NotifyModuleLoaded(result);
        if (unlock_import() < 0) {
                Py_XDECREF(result);
                PyErr_SetString(PyExc_RuntimeError,
@@ -2979,6 +3197,31 @@
 }
 
 static PyObject *
+imp_register_post_import_hook(PyObject *self, PyObject *args)
+{
+       PyObject *callable, *mod_name;
+       
+       if (!PyArg_ParseTuple(args, "OO:register_post_import_hook",
+                             &callable, &mod_name))
+               return NULL;
+       Py_INCREF(callable);
+       Py_INCREF(mod_name);
+       return PyImport_RegisterPostImportHook(callable, mod_name);
+}
+
+static PyObject *
+imp_notify_module_loaded(PyObject *self, PyObject *args)
+{
+       PyObject *mod;
+
+        if (!PyArg_ParseTuple(args, "O:notify_module_loaded", &mod))
+                return NULL;
+
+       Py_INCREF(mod);
+       return PyImport_NotifyModuleLoaded(mod);
+}
+
+static PyObject *
 imp_reload(PyObject *self, PyObject *v)
 {
         return PyImport_ReloadModule(v);
@@ -3038,6 +3281,13 @@
 Release the interpreter's import lock.\n\
 On platforms without threads, this function does nothing.");
 
+PyDoc_STRVAR(doc_register_post_import_hook,
+"register_post_import_hook(callable, module_name) -> None");
+
+PyDoc_STRVAR(doc_notify_module_loaded,
+"notify_module_loaded(module) -> module");
+
+
 static PyMethodDef imp_methods[] = {
        {"find_module",  imp_find_module,  METH_VARARGS, doc_find_module},
        {"get_magic",    imp_get_magic,    METH_NOARGS,  doc_get_magic},
@@ -3047,6 +3297,10 @@
        {"lock_held",    imp_lock_held,    METH_NOARGS,  doc_lock_held},
        {"acquire_lock", imp_acquire_lock, METH_NOARGS,  doc_acquire_lock},
        {"release_lock", imp_release_lock, METH_NOARGS,  doc_release_lock},
+       {"register_post_import_hook",   imp_register_post_import_hook,
+               METH_VARARGS, doc_register_post_import_hook},
+       {"notify_module_loaded", imp_notify_module_loaded, METH_VARARGS,
+               doc_notify_module_loaded},
        {"reload",       imp_reload,       METH_O,       doc_reload},
        /* The rest are obsolete */
        {"get_frozen_object",   imp_get_frozen_object,  METH_VARARGS},
_______________________________________________
Python-3000-checkins mailing list
[email protected]
http://mail.python.org/mailman/listinfo/python-3000-checkins

Reply via email to