https://git.reactos.org/?p=reactos.git;a=commitdiff;h=d12880829ffee337c46bc4cedbcc37a7a702cffb

commit d12880829ffee337c46bc4cedbcc37a7a702cffb
Author:     Mark Jansen <[email protected]>
AuthorDate: Sat Apr 15 22:21:33 2023 +0200
Commit:     Mark Jansen <[email protected]>
CommitDate: Sat Apr 22 21:23:55 2023 +0200

    [ATL] Add OBJECT_ENTRY_AUTO for simpler com object registration
    
    GCC needs a workaround (hack_for_gcc) to keep the otherwise unreferenced __pobjMap_ symbols from being dropped.
    CORE-18936
---
 sdk/lib/atl/atlbase.h | 177 ++++++++++++++++++++++++++++++++++++++++++++------
 sdk/lib/atl/atlcom.h  |  58 +++++++++++++++++
 2 files changed, 216 insertions(+), 19 deletions(-)

diff --git a/sdk/lib/atl/atlbase.h b/sdk/lib/atl/atlbase.h
index 4bc00d2e9bf..1b0b0361b6a 100644
--- a/sdk/lib/atl/atlbase.h
+++ b/sdk/lib/atl/atlbase.h
@@ -59,6 +59,15 @@ class CAtlComModule;
 __declspec(selectany) CAtlModule *_pAtlModule = NULL;
 __declspec(selectany) CComModule *_pModule = NULL;
 
+// Compile-time guard that the module base class matches the build type.
+// CAtlDllModuleT<T> derives from it with isDll == true, which requires _WINDLL or
+// _USRDLL to be defined. CAtlExeModuleT<T> derives from it with isDll == false,
+// which requires that neither is defined. T is not used by the check itself.
+template <bool isDll, typename T> struct CAtlValidateModuleConfiguration
+{
+#if !defined(_WINDLL) && !defined(_USRDLL)
+    static_assert(!isDll, "_WINDLL or _USRDLL must be defined when 'CAtlDllModuleT<T>' is used");
+#else
+    static_assert(isDll, "_WINDLL or _USRDLL must not be defined when 'CAtlExeModuleT<T>' is used");
+#endif
+};
+
 
 struct _ATL_CATMAP_ENTRY
 {
@@ -173,6 +182,46 @@ struct _ATL_WIN_MODULE70
 };
 typedef _ATL_WIN_MODULE70 _ATL_WIN_MODULE;
 
+
+// Auto object map
+//
+// OBJECT_ENTRY_AUTO stores one _ATL_OBJMAP_ENTRY pointer per class in the
+// "ATL$__m" section. The sentinels __pobjMapEntryFirst ("ATL$__a") and
+// __pobjMapEntryLast ("ATL$__z") bracket those pointers. This relies on the PE
+// linker merging "ATL$..." sections in alphabetical order of their suffix.
+// CAtlComModule walks [&__pobjMapEntryFirst + 1, &__pobjMapEntryLast) and
+// skips NULL slots.
+
+#if defined(_MSC_VER)
+#pragma section("ATL$__a", read, write)
+#pragma section("ATL$__z", read, write)
+#pragma section("ATL$__m", read, write)
+#define _ATLALLOC(x) __declspec(allocate(x))
+
+// Nothing references __pobjMap_<class> by name, so /include forces the linker to
+// keep it. x86 decorates extern "C" symbols with an extra leading underscore;
+// the other targets listed here do not.
+#if defined(_M_IX86)
+#define OBJECT_ENTRY_PRAGMA(class) __pragma(comment(linker, "/include:___pobjMap_" #class));
+#elif defined(_M_IA64) || defined(_M_AMD64) || defined(_M_ARM) || defined(_M_ARM64)
+#define OBJECT_ENTRY_PRAGMA(class) __pragma(comment(linker, "/include:__pobjMap_" #class));
+#else
+#error Your platform is not supported.
+#endif
+
+#elif defined(__GNUC__)
+
+// GCC does not keep the __pobjMap_ pointer when it is marked __attribute__((unused)).
+// Instead, the pointer is passed to a function that is not allowed to be optimized
+// (optimize("O0")). The per-class static initializer created by OBJECT_ENTRY_PRAGMA
+// keeps the pointer referenced.
+// NOTE(review): __attribute__((unused)) only suppresses the unused-variable warning.
+// __attribute__((used)) is the attribute GCC documents for forcing a static-storage
+// variable to be emitted. Consider switching to it once verified with the ReactOS
+// toolchain.
+static int __attribute__((optimize("O0"), unused)) hack_for_gcc(const _ATL_OBJMAP_ENTRY * const *)
+{
+    return 1;
+}
+
+#define _ATLALLOC(x) __attribute__((section(x)))
+#define OBJECT_ENTRY_PRAGMA(class) static int __pobjMap_hack_##class = hack_for_gcc(&__pobjMap_##class);
+
+#else
+#error Your compiler is not supported.
+#endif
+
+
+// Sentinels bracketing the auto object map; they are not entries themselves.
+extern "C"
+{
+    __declspec(selectany) _ATLALLOC("ATL$__a") _ATL_OBJMAP_ENTRY *__pobjMapEntryFirst = NULL;
+    __declspec(selectany) _ATLALLOC("ATL$__z") _ATL_OBJMAP_ENTRY *__pobjMapEntryLast = NULL;
+}
+
+
 struct _ATL_REGMAP_ENTRY
 {
     LPCOLESTR szKey;
@@ -551,8 +600,9 @@ public:
     CAtlComModule()
     {
         GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, 
(LPCWSTR)this, &m_hInstTypeLib);
-        m_ppAutoObjMapFirst = NULL;
-        m_ppAutoObjMapLast = NULL;
+
+        m_ppAutoObjMapFirst = &__pobjMapEntryFirst + 1;
+        m_ppAutoObjMapLast = &__pobjMapEntryLast;
         if (FAILED(m_csObjMap.Init()))
         {
             ATLASSERT(0);
@@ -577,17 +627,37 @@ public:
         return AtlComModuleUnregisterServer(this, bUnRegTypeLib, pCLSID);
     }
 
-
     void Term()
     {
+        // Tear down only once: cbSize is reset to 0 at the end of this branch.
         if (cbSize != 0)
         {
-            ATLASSERT(m_ppAutoObjMapFirst == NULL);
-            ATLASSERT(m_ppAutoObjMapLast == NULL);
+            // Release the class factories cached in the auto object map entries
+            // (pCF, filled in by AtlComModuleGetClassObject) and clear the cache.
+            for (_ATL_OBJMAP_ENTRY **iter = m_ppAutoObjMapFirst; iter < m_ppAutoObjMapLast; iter++)
+            {
+                _ATL_OBJMAP_ENTRY *ptr = *iter;
+                if (!ptr)
+                    continue;
+
+                if (!ptr->pCF)
+                    continue;
+
+                ptr->pCF->Release();
+                ptr->pCF = NULL;
+            }
             m_csObjMap.Term();
             cbSize = 0;
         }
     }
+
+    // Calls pfnObjectMain(bStarting) on each non-NULL entry of the auto object
+    // map [m_ppAutoObjMapFirst, m_ppAutoObjMapLast). The DLL and EXE module
+    // constructors pass true; their destructors pass false.
+    void ExecuteObjectMain(bool bStarting)
+    {
+        for (_ATL_OBJMAP_ENTRY **iter = m_ppAutoObjMapFirst; iter < m_ppAutoObjMapLast; iter++)
+        {
+            // NULL slots are skipped, as in the other auto-map loops.
+            if (!*iter)
+                continue;
+
+            (*iter)->pfnObjectMain(bStarting);
+        }
+    }
 };
 
 __declspec(selectany) CAtlComModule _AtlComModule;
@@ -606,11 +676,20 @@ HRESULT CAtlModuleT<T>::UnregisterServer(BOOL 
bUnRegTypeLib, const CLSID *pCLSID
 }
 
 template <class T>
-class CAtlDllModuleT : public CAtlModuleT<T>
+class CAtlDllModuleT
+    : public CAtlModuleT<T>
+    , private CAtlValidateModuleConfiguration<true, T>
+
 {
 public:
     CAtlDllModuleT()
     {
+        // Run ObjectMain(true) for every class in the auto object map.
+        _AtlComModule.ExecuteObjectMain(true);
+    }
+
+    // Counterpart of the constructor: run ObjectMain(false) for every class in
+    // the auto object map.
+    ~CAtlDllModuleT()
+    {
+        _AtlComModule.ExecuteObjectMain(false);
     }
 
     HRESULT DllCanUnloadNow()
@@ -659,7 +738,9 @@ public:
 
 
 template <class T>
-class CAtlExeModuleT : public CAtlModuleT<T>
+class CAtlExeModuleT
+    : public CAtlModuleT<T>
+    , private CAtlValidateModuleConfiguration<false, T>
 {
 public:
     DWORD m_dwMainThreadID;
@@ -670,10 +751,12 @@ public:
     CAtlExeModuleT()
         :m_dwMainThreadID(::GetCurrentThreadId())
     {
+        // Run ObjectMain(true) for every class in the auto object map.
+        _AtlComModule.ExecuteObjectMain(true);
     }
 
     ~CAtlExeModuleT()
     {
+        // Run ObjectMain(false) for every class in the auto object map.
+        _AtlComModule.ExecuteObjectMain(false);
     }
 
     int WinMain(int nShowCmd)
@@ -815,12 +898,19 @@ public:
                 }
             }
         }
+
+        for (_ATL_OBJMAP_ENTRY **iter = _AtlComModule.m_ppAutoObjMapFirst; 
iter < _AtlComModule.m_ppAutoObjMapLast; iter++)
+        {
+            if (*iter != NULL)
+                (*iter)->pfnObjectMain(true);
+        }
+
         return S_OK;
     }
 
     void Term()
     {
-        _ATL_OBJMAP_ENTRY                    *objectMapEntry;
+        _ATL_OBJMAP_ENTRY *objectMapEntry;
 
         if (m_pObjMap != NULL)
         {
@@ -834,12 +924,19 @@ public:
                 objectMapEntry++;
             }
         }
+
+        for (_ATL_OBJMAP_ENTRY **iter = _AtlComModule.m_ppAutoObjMapFirst; 
iter < _AtlComModule.m_ppAutoObjMapLast; iter++)
+        {
+            if (*iter != NULL)
+                (*iter)->pfnObjectMain(false);
+        }
+
     }
 
     HRESULT GetClassObject(REFCLSID rclsid, REFIID riid, LPVOID *ppv)
     {
-        _ATL_OBJMAP_ENTRY                    *objectMapEntry;
-        HRESULT                                hResult;
+        _ATL_OBJMAP_ENTRY *objectMapEntry;
+        HRESULT hResult;
 
         ATLASSERT(ppv != NULL);
         if (ppv == NULL)
@@ -869,8 +966,7 @@ public:
         }
         if (hResult == S_OK && *ppv == NULL)
         {
-            // FIXME: call AtlComModuleGetClassObject
-            hResult = CLASS_E_CLASSNOTAVAILABLE;
+            hResult = AtlComModuleGetClassObject(&_AtlComModule, rclsid, riid, 
ppv);
         }
         return hResult;
     }
@@ -1480,9 +1576,9 @@ inline HRESULT __stdcall AtlAdvise(IUnknown *pUnkCP, 
IUnknown *pUnk, const IID &
 
 inline HRESULT __stdcall AtlUnadvise(IUnknown *pUnkCP, const IID &iid, DWORD 
dw)
 {
-    CComPtr<IConnectionPointContainer>        container;
-    CComPtr<IConnectionPoint>                connectionPoint;
-    HRESULT                                    hResult;
+    CComPtr<IConnectionPointContainer> container;
+    CComPtr<IConnectionPoint> connectionPoint;
+    HRESULT hResult;
 
     if (pUnkCP == NULL)
         return E_INVALIDARG;
@@ -1809,14 +1905,18 @@ inline HRESULT WINAPI 
AtlComModuleRegisterClassObjects(_ATL_COM_MODULE *module,
 
     for (iter = module->m_ppAutoObjMapFirst; iter < 
module->m_ppAutoObjMapLast; iter++)
     {
-        if (!(*iter)->pfnGetClassObject)
+        _ATL_OBJMAP_ENTRY *ptr = *iter;
+        if (!ptr)
             continue;
 
-        hr = (*iter)->pfnGetClassObject((void*)(*iter)->pfnCreateInstance, 
IID_IUnknown, (void**)&unk);
+        if (!ptr->pfnGetClassObject)
+            continue;
+
+        hr = ptr->pfnGetClassObject((void*)ptr->pfnCreateInstance, 
IID_IUnknown, (void**)&unk);
         if (FAILED(hr))
             return hr;
 
-        hr = CoRegisterClassObject(*(*iter)->pclsid, unk, context, flags, 
&(*iter)->dwRegister);
+        hr = CoRegisterClassObject(*ptr->pclsid, unk, context, flags, 
&ptr->dwRegister);
         unk->Release();
         if (FAILED(hr))
             return hr;
@@ -1837,7 +1937,11 @@ inline HRESULT WINAPI 
AtlComModuleRevokeClassObjects(_ATL_COM_MODULE *module)
 
     for (iter = module->m_ppAutoObjMapFirst; iter < 
module->m_ppAutoObjMapLast; iter++)
     {
-        hr = CoRevokeClassObject((*iter)->dwRegister);
+        _ATL_OBJMAP_ENTRY *ptr = *iter;
+        if (!ptr)
+            continue;
+
+        hr = CoRevokeClassObject(ptr->dwRegister);
         if (FAILED(hr))
             return hr;
     }
@@ -1845,6 +1949,41 @@ inline HRESULT WINAPI 
AtlComModuleRevokeClassObjects(_ATL_COM_MODULE *module)
     return S_OK;
 }
 
+// Adapted from dll/win32/atl/atl.c
+//
+// Looks up rclsid in pm's auto object map (the entries between
+// m_ppAutoObjMapFirst and m_ppAutoObjMapLast). On first use it creates the
+// entry's class factory and caches it in pCF. It then queries the cached
+// factory for riid.
+//
+// Returns:
+//   E_INVALIDARG              if pm is NULL
+//   E_POINTER                 if ppv is NULL
+//   CLASS_E_CLASSNOTAVAILABLE if no entry with a pfnGetClassObject matches
+//   otherwise                 the result of creating the factory or of
+//                             QueryInterface on it
+// *ppv is set to NULL before the lookup, so it is NULL on every failure path.
+inline HRESULT WINAPI
+AtlComModuleGetClassObject(_ATL_COM_MODULE *pm, REFCLSID rclsid, REFIID riid, void **ppv)
+{
+    if (!pm)
+        return E_INVALIDARG;
+    if (!ppv)
+        return E_POINTER;
+    *ppv = NULL;
+
+    for (_ATL_OBJMAP_ENTRY **iter = pm->m_ppAutoObjMapFirst; iter < pm->m_ppAutoObjMapLast; iter++)
+    {
+        _ATL_OBJMAP_ENTRY *ptr = *iter;
+        if (!ptr)
+            continue;
+
+        if (IsEqualCLSID(*ptr->pclsid, rclsid) && ptr->pfnGetClassObject)
+        {
+            HRESULT hr = CLASS_E_CLASSNOTAVAILABLE;
+
+            if (!ptr->pCF)
+            {
+                // Serialize factory creation on the lock of the module being queried
+                // (not the global _AtlComModule's), and re-check pCF once it is held.
+                CComCritSecLock<CComCriticalSection> lock(pm->m_csObjMap, true);
+                if (!ptr->pCF)
+                {
+                    hr = ptr->pfnGetClassObject((void *)ptr->pfnCreateInstance, IID_IUnknown, (void **)&ptr->pCF);
+                }
+            }
+            if (ptr->pCF)
+                hr = ptr->pCF->QueryInterface(riid, ppv);
+            return hr;
+        }
+    }
+
+    return CLASS_E_CLASSNOTAVAILABLE;
+}
+
+
 }; // namespace ATL
 
 #ifndef _ATL_NO_AUTOMATIC_NAMESPACE
diff --git a/sdk/lib/atl/atlcom.h b/sdk/lib/atl/atlcom.h
index c9d92edbd9c..024d85a0bef 100644
--- a/sdk/lib/atl/atlcom.h
+++ b/sdk/lib/atl/atlcom.h
@@ -30,7 +30,12 @@ namespace ATL
 template <class Base, const IID *piid, class T, class Copy, class ThreadModel 
= CComObjectThreadModel>
 class CComEnum;
 
+// DLL servers wrap the class factory in CComObjectCached. EXE servers use
+// CComObjectNoLock so the class factory does not change the module lock count.
+#if defined(_WINDLL) || defined(_USRDLL)
 #define DECLARE_CLASSFACTORY_EX(cf) typedef ATL::CComCreator<ATL::CComObjectCached<cf> > _ClassFactoryCreatorClass;
+#else
+// Class factory should not change lock count
+#define DECLARE_CLASSFACTORY_EX(cf) typedef ATL::CComCreator<ATL::CComObjectNoLock<cf> > _ClassFactoryCreatorClass;
+#endif
 #define DECLARE_CLASSFACTORY() DECLARE_CLASSFACTORY_EX(ATL::CComClassFactory)
 #define DECLARE_CLASSFACTORY_SINGLETON(obj) 
DECLARE_CLASSFACTORY_EX(ATL::CComClassFactorySingleton<obj>)
 
@@ -539,6 +544,40 @@ public:
     }
 };
 
+
+// COM object wrapper around Base. It implements AddRef/Release/QueryInterface
+// on top of Base's Internal* helpers and deletes itself when the reference count
+// drops to zero. It never calls the module's Lock/Unlock, which is why
+// DECLARE_CLASSFACTORY_EX uses it for class factories in non-DLL builds.
+template <class Base>
+class CComObjectNoLock : public Base
+{
+  public:
+    // The void* argument is accepted and ignored.
+    CComObjectNoLock(void* = NULL)
+    {
+    }
+
+    // Runs Base's FinalRelease() before the object is destroyed.
+    virtual ~CComObjectNoLock()
+    {
+        this->FinalRelease();
+    }
+
+    STDMETHOD_(ULONG, AddRef)()
+    {
+        return this->InternalAddRef();
+    }
+
+    // Deletes the object when the count reaches zero. Returns the new count,
+    // held in a local so it is not read from the deleted object.
+    STDMETHOD_(ULONG, Release)()
+    {
+        ULONG newRefCount = this->InternalRelease();
+        if (newRefCount == 0)
+            delete this;
+        return newRefCount;
+    }
+
+    // Delegates to Base::_InternalQueryInterface.
+    STDMETHOD(QueryInterface)(REFIID iid, void **ppvObject)
+    {
+        return this->_InternalQueryInterface(iid, ppvObject);
+    }
+};
+
+
 #define BEGIN_COM_MAP(x)                                                       
 \
 public:                                                                        
    \
     typedef x _ComMapClass;                                                    
    \
@@ -663,6 +702,24 @@ public:
     class::GetCategoryMap,                                                     
   \
     class::ObjectMain },
 
+
+
+// OBJECT_ENTRY_AUTO(clsid, class) adds `class` to the auto object map:
+//  - defines the entry __objMap_<class> from the class's static members;
+//  - stores a pointer to it, __pobjMap_<class>, in the "ATL$__m" section;
+//  - expands OBJECT_ENTRY_PRAGMA so that otherwise unreferenced pointer is kept.
+// Because it contains an extern "C" definition, it must be used at namespace
+// scope, and only once per class.
+#define OBJECT_ENTRY_AUTO(clsid, class)                                                                  \
+    ATL::_ATL_OBJMAP_ENTRY __objMap_##class = {                                                         \
+        &clsid,                                                                                         \
+        class::UpdateRegistry,                                                                          \
+        class::_ClassFactoryCreatorClass::CreateInstance,                                               \
+        class::_CreatorClass::CreateInstance,                                                           \
+        NULL,                                                                                           \
+        0,                                                                                              \
+        class::GetObjectDescription,                                                                    \
+        class::GetCategoryMap,                                                                          \
+        class::ObjectMain};                                                                             \
+    extern "C" _ATLALLOC("ATL$__m") ATL::_ATL_OBJMAP_ENTRY *const __pobjMap_##class = &__objMap_##class; \
+    OBJECT_ENTRY_PRAGMA(class)
+
+
+
 class CComClassFactory :
     public IClassFactory,
     public CComObjectRootEx<CComGlobalsThreadModel>
@@ -772,6 +829,7 @@ class CComCoClass
 {
 public:
     DECLARE_CLASSFACTORY()
+    //DECLARE_AGGREGATABLE(T)   // This should be here, but gcc...
 
     static LPCTSTR WINAPI GetObjectDescription()
     {

Reply via email to