https://git.reactos.org/?p=reactos.git;a=commitdiff;h=4c95339da0addaa59a0d584c240a34f708079ce5

commit 4c95339da0addaa59a0d584c240a34f708079ce5
Author:     Victor Perevertkin <[email protected]>
AuthorDate: Wed Nov 25 03:18:31 2020 +0300
Commit:     Victor Perevertkin <[email protected]>
CommitDate: Mon Jan 4 16:50:33 2021 +0300

    [NTOS:IO] Refactoring of the driver initialization code (2)
    
    - Do not hold the IopDriverLoadResource while trying to reference a
      driver object (but still acquire it when we actually need to load a
      driver)
    - Change IopLoadDriver and IopInitializeDriverModule to take a registry
      handle instead of a service name string and/or a full registry path
    - Do not try to reference a driver object inside IopLoadDriver; the
      caller is now expected to do that before calling the function
---
 ntoskrnl/include/internal/io.h |   7 +-
 ntoskrnl/io/iomgr/driver.c     | 387 ++++++++++++++++++++++++-----------------
 ntoskrnl/io/pnpmgr/devaction.c |  37 +++-
 3 files changed, 267 insertions(+), 164 deletions(-)

diff --git a/ntoskrnl/include/internal/io.h b/ntoskrnl/include/internal/io.h
index 5da5950ba67..ad51265ba3f 100644
--- a/ntoskrnl/include/internal/io.h
+++ b/ntoskrnl/include/internal/io.h
@@ -1127,14 +1127,13 @@ IopLoadServiceModule(
 
 NTSTATUS
 IopLoadDriver(
-    _In_opt_ PCUNICODE_STRING RegistryPath,
-    _Inout_ PDRIVER_OBJECT *DriverObject
-);
+    _In_ HANDLE ServiceHandle,
+    _Out_ PDRIVER_OBJECT *DriverObject);
 
 NTSTATUS
 IopInitializeDriverModule(
     _In_ PLDR_DATA_TABLE_ENTRY ModuleObject,
-    _In_ PUNICODE_STRING ServiceName,
+    _In_ HANDLE ServiceHandle,
     _Out_ PDRIVER_OBJECT *DriverObject,
     _Out_ NTSTATUS *DriverEntryStatus);
 
diff --git a/ntoskrnl/io/iomgr/driver.c b/ntoskrnl/io/iomgr/driver.c
index a4b39b77f6d..994d7375af8 100644
--- a/ntoskrnl/io/iomgr/driver.c
+++ b/ntoskrnl/io/iomgr/driver.c
@@ -28,6 +28,7 @@ KSPIN_LOCK DriverBootReinitListLock;
 
 UNICODE_STRING IopHardwareDatabaseKey =
    RTL_CONSTANT_STRING(L"\\REGISTRY\\MACHINE\\HARDWARE\\DESCRIPTION\\SYSTEM");
+static const WCHAR ServicesKeyName[] = 
L"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\";
 
 POBJECT_TYPE IoDriverObjectType = NULL;
 
@@ -47,7 +48,7 @@ PLIST_ENTRY IopGroupTable;
 typedef struct _LOAD_UNLOAD_PARAMS
 {
     NTSTATUS Status;
-    PCUNICODE_STRING RegistryPath;
+    PUNICODE_STRING RegistryPath;
     WORK_QUEUE_ITEM WorkItem;
     KEVENT Event;
     PDRIVER_OBJECT DriverObject;
@@ -56,7 +57,7 @@ typedef struct _LOAD_UNLOAD_PARAMS
 
 NTSTATUS
 IopDoLoadUnloadDriver(
-    _In_opt_ PCUNICODE_STRING RegistryPath,
+    _In_opt_ PUNICODE_STRING RegistryPath,
     _Inout_ PDRIVER_OBJECT *DriverObject);
 
 /* PRIVATE FUNCTIONS 
**********************************************************/
@@ -135,7 +136,6 @@ IopGetDriverObject(
     DPRINT("IopGetDriverObject(%p '%wZ' %x)\n",
            DriverObject, ServiceName, FileSystem);
 
-    ASSERT(ExIsResourceAcquiredExclusiveLite(&IopDriverLoadResource));
     *DriverObject = NULL;
 
     /* Create ModuleName string */
@@ -480,8 +480,8 @@ IopLoadServiceModule(
  *     Module object representing the driver. It can be retrieved by 
IopLoadServiceModule.
  *     Freed on failure, so in a such case this should not be accessed anymore
  *
- * @param[in]  ServiceName
- *     Name of the service (as in the registry)
+ * @param[in]  ServiceHandle
+ *     Handle to a driver's CCS/Services/<ServiceName> key
  *
  * @param[out] DriverObject
  *     This contains the driver object if it was created (even with 
unsuccessfull result)
@@ -495,56 +495,58 @@ IopLoadServiceModule(
 NTSTATUS
 IopInitializeDriverModule(
     _In_ PLDR_DATA_TABLE_ENTRY ModuleObject,
-    _In_ PUNICODE_STRING ServiceName,
+    _In_ HANDLE ServiceHandle,
     _Out_ PDRIVER_OBJECT *OutDriverObject,
     _Out_ NTSTATUS *DriverEntryStatus)
 {
-    static const WCHAR ServicesKeyName[] = 
L"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\";
-    UNICODE_STRING DriverName;
-    UNICODE_STRING RegistryKey;
+    UNICODE_STRING DriverName, RegistryPath, ServiceName;
     NTSTATUS Status;
 
     PAGED_CODE();
 
-    ASSERT(ServiceName && ServiceName->Length != 0);
-
-    // Make the registry path for the driver
-    RegistryKey.Length = 0;
-    RegistryKey.MaximumLength = sizeof(ServicesKeyName) + ServiceName->Length;
-    RegistryKey.Buffer = ExAllocatePoolWithTag(PagedPool,
-                                               RegistryKey.MaximumLength,
-                                               TAG_IO);
-    if (RegistryKey.Buffer == NULL)
+    // get the ServiceName
+    PKEY_BASIC_INFORMATION basicInfo;
+    ULONG infoLength;
+    Status = ZwQueryKey(ServiceHandle, KeyBasicInformation, NULL, 0, 
&infoLength);
+    if (Status == STATUS_BUFFER_TOO_SMALL)
     {
-        return STATUS_INSUFFICIENT_RESOURCES;
-    }
-    RtlAppendUnicodeToString(&RegistryKey, ServicesKeyName);
-    RtlAppendUnicodeStringToString(&RegistryKey, ServiceName);
+        basicInfo = ExAllocatePoolWithTag(PagedPool, infoLength, TAG_IO);
+        if (!basicInfo)
+        {
+            MmUnloadSystemImage(ModuleObject);
+            return STATUS_INSUFFICIENT_RESOURCES;
+        }
 
-    // Open the registry key for this driver (it has to exist)
-    HANDLE serviceHandle;
-    PKEY_VALUE_FULL_INFORMATION kvInfo;
+        Status = ZwQueryKey(ServiceHandle, KeyBasicInformation, basicInfo, 
infoLength, &infoLength);
+        if (!NT_SUCCESS(Status))
+        {
+            ExFreePoolWithTag(basicInfo, TAG_IO);
+            MmUnloadSystemImage(ModuleObject);
+            return Status;
+        }
 
-    Status = IopOpenRegistryKeyEx(&serviceHandle, NULL, &RegistryKey, 
KEY_READ);
-    if (!NT_SUCCESS(Status))
+        ServiceName.Length = basicInfo->NameLength;
+        ServiceName.MaximumLength = basicInfo->NameLength;
+        ServiceName.Buffer = basicInfo->Name;
+    }
+    else
     {
-        RtlFreeUnicodeString(&RegistryKey);
         MmUnloadSystemImage(ModuleObject);
-        return Status;
+        return NT_SUCCESS(Status) ? STATUS_UNSUCCESSFUL : Status;
     }
 
     // Make the DriverName field of a DRIVER_OBJECT
+    PKEY_VALUE_FULL_INFORMATION kvInfo;
 
     // 1. Check the "ObjectName" field in the driver's registry key (it has 
the priority)
-    Status = IopGetRegistryValue(serviceHandle, L"ObjectName", &kvInfo);
+    Status = IopGetRegistryValue(ServiceHandle, L"ObjectName", &kvInfo);
     if (NT_SUCCESS(Status))
     {
         // we're got the ObjectName. Use it to create the DRIVER_OBJECT
         if (kvInfo->Type != REG_SZ || kvInfo->DataLength == 0)
         {
             ExFreePool(kvInfo);
-            ZwClose(serviceHandle);
-            RtlFreeUnicodeString(&RegistryKey);
+            ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName
             MmUnloadSystemImage(ModuleObject);
             return STATUS_ILL_FORMED_SERVICE_ENTRY;
         }
@@ -555,8 +557,7 @@ IopInitializeDriverModule(
         if (!DriverName.Buffer)
         {
             ExFreePool(kvInfo);
-            ZwClose(serviceHandle);
-            RtlFreeUnicodeString(&RegistryKey);
+            ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName
             MmUnloadSystemImage(ModuleObject);
             return STATUS_INSUFFICIENT_RESOURCES;
         }
@@ -571,12 +572,11 @@ IopInitializeDriverModule(
         // 2. there is no "ObjectName" - construct it ourselves. Depending on 
a driver type,
         // it will be either "\Driver\<ServiceName>" or 
"\FileSystem\<ServiceName>"
 
-        Status = IopGetRegistryValue(serviceHandle, L"Type", &kvInfo);
+        Status = IopGetRegistryValue(ServiceHandle, L"Type", &kvInfo);
         if (!NT_SUCCESS(Status) || kvInfo->Type != REG_DWORD)
         {
             ExFreePool(kvInfo);
-            ZwClose(serviceHandle);
-            RtlFreeUnicodeString(&RegistryKey);
+            ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName
             MmUnloadSystemImage(ModuleObject);
             return STATUS_ILL_FORMED_SERVICE_ENTRY;
         }
@@ -587,14 +587,13 @@ IopInitializeDriverModule(
 
         DriverName.Length = 0;
         if (driverType == SERVICE_RECOGNIZER_DRIVER || driverType == 
SERVICE_FILE_SYSTEM_DRIVER)
-            DriverName.MaximumLength = sizeof(FILESYSTEM_ROOT_NAME) + 
ServiceName->Length;
+            DriverName.MaximumLength = sizeof(FILESYSTEM_ROOT_NAME) + 
ServiceName.Length;
         else
-            DriverName.MaximumLength = sizeof(DRIVER_ROOT_NAME) + 
ServiceName->Length;
+            DriverName.MaximumLength = sizeof(DRIVER_ROOT_NAME) + 
ServiceName.Length;
         DriverName.Buffer = ExAllocatePoolWithTag(NonPagedPool, 
DriverName.MaximumLength, TAG_IO);
         if (!DriverName.Buffer)
         {
-            ZwClose(serviceHandle);
-            RtlFreeUnicodeString(&RegistryKey);
+            ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName
             MmUnloadSystemImage(ModuleObject);
             return STATUS_INSUFFICIENT_RESOURCES;
         }
@@ -604,13 +603,53 @@ IopInitializeDriverModule(
         else
             RtlAppendUnicodeToString(&DriverName, DRIVER_ROOT_NAME);
 
-        RtlAppendUnicodeStringToString(&DriverName, ServiceName);
+        RtlAppendUnicodeStringToString(&DriverName, &ServiceName);
     }
 
-    ZwClose(serviceHandle);
-
     DPRINT("Driver name: '%wZ'\n", &DriverName);
 
+    // obtain the registry path for the DriverInit routine
+    PKEY_NAME_INFORMATION nameInfo;
+    Status = ZwQueryKey(ServiceHandle, KeyNameInformation, NULL, 0, 
&infoLength);
+    if (Status == STATUS_BUFFER_TOO_SMALL)
+    {
+        nameInfo = ExAllocatePoolWithTag(NonPagedPool, infoLength, TAG_IO);
+        if (nameInfo)
+        {
+            Status = ZwQueryKey(ServiceHandle,
+                                KeyNameInformation,
+                                nameInfo,
+                                infoLength,
+                                &infoLength);
+            if (NT_SUCCESS(Status))
+            {
+                RegistryPath.Length = nameInfo->NameLength;
+                RegistryPath.MaximumLength = nameInfo->NameLength;
+                RegistryPath.Buffer = nameInfo->Name;
+            }
+            else
+            {
+                ExFreePoolWithTag(nameInfo, TAG_IO);
+            }
+        }
+        else
+        {
+            Status = STATUS_INSUFFICIENT_RESOURCES;
+        }
+    }
+    else
+    {
+        Status = NT_SUCCESS(Status) ? STATUS_UNSUCCESSFUL : Status;
+    }
+
+    if (!NT_SUCCESS(Status))
+    {
+        ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName
+        RtlFreeUnicodeString(&DriverName);
+        MmUnloadSystemImage(ModuleObject);
+        return Status;
+    }
+
     // create the driver object
     UINT32 ObjectSize = sizeof(DRIVER_OBJECT) + 
sizeof(EXTENDED_DRIVER_EXTENSION);
     OBJECT_ATTRIBUTES objAttrs;
@@ -632,7 +671,8 @@ IopInitializeDriverModule(
                             (PVOID*)&driverObject);
     if (!NT_SUCCESS(Status))
     {
-        RtlFreeUnicodeString(&RegistryKey);
+        ExFreePoolWithTag(nameInfo, TAG_IO); // container for RegistryPath
+        ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName
         RtlFreeUnicodeString(&DriverName);
         MmUnloadSystemImage(ModuleObject);
         DPRINT1("Error while creating driver object \"%wZ\" status %x\n", 
&DriverName, Status);
@@ -665,9 +705,9 @@ IopInitializeDriverModule(
     Status = ObInsertObject(driverObject, NULL, FILE_READ_DATA, 0, NULL, 
&hDriver);
     if (!NT_SUCCESS(Status))
     {
-        RtlFreeUnicodeString(&RegistryKey);
+        ExFreePoolWithTag(nameInfo, TAG_IO);
+        ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName
         RtlFreeUnicodeString(&DriverName);
-        MmUnloadSystemImage(ModuleObject); // TODO: is it needed here?
         return Status;
     }
 
@@ -684,7 +724,8 @@ IopInitializeDriverModule(
 
     if (!NT_SUCCESS(Status))
     {
-        RtlFreeUnicodeString(&RegistryKey);
+        ExFreePoolWithTag(nameInfo, TAG_IO); // container for RegistryPath
+        ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName
         RtlFreeUnicodeString(&DriverName);
         return Status;
     }
@@ -693,7 +734,7 @@ IopInitializeDriverModule(
     UNICODE_STRING serviceKeyName;
     serviceKeyName.Length = 0;
     // put a NULL character at the end for Windows compatibility
-    serviceKeyName.MaximumLength = ServiceName->MaximumLength + 
sizeof(UNICODE_NULL);
+    serviceKeyName.MaximumLength = ServiceName.MaximumLength + 
sizeof(UNICODE_NULL);
     serviceKeyName.Buffer = ExAllocatePoolWithTag(NonPagedPool,
                                                   serviceKeyName.MaximumLength,
                                                   TAG_IO);
@@ -701,13 +742,15 @@ IopInitializeDriverModule(
     {
         ObMakeTemporaryObject(driverObject);
         ObDereferenceObject(driverObject);
-        RtlFreeUnicodeString(&RegistryKey);
+        ExFreePoolWithTag(nameInfo, TAG_IO); // container for RegistryPath
+        ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName
         RtlFreeUnicodeString(&DriverName);
         return STATUS_INSUFFICIENT_RESOURCES;
     }
 
     /* Copy the name and set it in the driver extension */
-    RtlCopyUnicodeString(&serviceKeyName, ServiceName);
+    RtlCopyUnicodeString(&serviceKeyName, &ServiceName);
+    ExFreePoolWithTag(basicInfo, TAG_IO); // container for ServiceName
     driverObject->DriverExtension->ServiceKeyName = serviceKeyName;
 
     /* Make a copy of the driver name to store in the driver object */
@@ -722,7 +765,7 @@ IopInitializeDriverModule(
     {
         ObMakeTemporaryObject(driverObject);
         ObDereferenceObject(driverObject);
-        RtlFreeUnicodeString(&RegistryKey);
+        ExFreePoolWithTag(nameInfo, TAG_IO); // container for RegistryPath
         RtlFreeUnicodeString(&DriverName);
         return STATUS_INSUFFICIENT_RESOURCES;
     }
@@ -731,7 +774,7 @@ IopInitializeDriverModule(
     driverObject->DriverName = driverNamePaged;
 
     /* Finally, call its init function */
-    Status = driverObject->DriverInit(driverObject, &RegistryKey);
+    Status = driverObject->DriverInit(driverObject, &RegistryPath);
     *DriverEntryStatus = Status;
     if (!NT_SUCCESS(Status))
     {
@@ -770,7 +813,7 @@ IopInitializeDriverModule(
 
     // TODO: for legacy drivers, unload the driver if it didn't create any DO
 
-    RtlFreeUnicodeString(&RegistryKey);
+    ExFreePoolWithTag(nameInfo, TAG_IO); // container for RegistryPath
     RtlFreeUnicodeString(&DriverName);
 
     if (!NT_SUCCESS(Status))
@@ -830,37 +873,62 @@ IopAttachFilterDriversCallback(
         ServiceName.MaximumLength =
         ServiceName.Length = (USHORT)wcslen(Filters) * sizeof(WCHAR);
 
-        KeEnterCriticalRegion();
-        ExAcquireResourceExclusiveLite(&IopDriverLoadResource, TRUE);
+        UNICODE_STRING RegistryPath;
+
+        // Make the registry path for the driver
+        RegistryPath.Length = 0;
+        RegistryPath.MaximumLength = sizeof(ServicesKeyName) + 
ServiceName.Length;
+        RegistryPath.Buffer = ExAllocatePoolWithTag(PagedPool, 
RegistryPath.MaximumLength, TAG_IO);
+        if (RegistryPath.Buffer == NULL)
+        {
+            return STATUS_INSUFFICIENT_RESOURCES;
+        }
+        RtlAppendUnicodeToString(&RegistryPath, ServicesKeyName);
+        RtlAppendUnicodeStringToString(&RegistryPath, &ServiceName);
+
+        HANDLE serviceHandle;
+        Status = IopOpenRegistryKeyEx(&serviceHandle, NULL, &RegistryPath, 
KEY_READ);
+        RtlFreeUnicodeString(&RegistryPath);
+        if (!NT_SUCCESS(Status))
+        {
+            return Status;
+        }
+
         Status = IopGetDriverObject(&DriverObject,
                                     &ServiceName,
                                     FALSE);
         if (!NT_SUCCESS(Status))
         {
+            KeEnterCriticalRegion();
+            ExAcquireResourceExclusiveLite(&IopDriverLoadResource, TRUE);
+
             /* Load and initialize the filter driver */
             Status = IopLoadServiceModule(&ServiceName, &ModuleObject);
             if (!NT_SUCCESS(Status))
             {
                 ExReleaseResourceLite(&IopDriverLoadResource);
                 KeLeaveCriticalRegion();
+                ZwClose(serviceHandle);
                 return Status;
             }
 
             NTSTATUS driverEntryStatus;
             Status = IopInitializeDriverModule(ModuleObject,
-                                               &ServiceName,
+                                               serviceHandle,
                                                &DriverObject,
                                                &driverEntryStatus);
+
+            ExReleaseResourceLite(&IopDriverLoadResource);
+            KeLeaveCriticalRegion();
+
             if (!NT_SUCCESS(Status))
             {
-                ExReleaseResourceLite(&IopDriverLoadResource);
-                KeLeaveCriticalRegion();
+                ZwClose(serviceHandle);
                 return Status;
             }
         }
 
-        ExReleaseResourceLite(&IopDriverLoadResource);
-        KeLeaveCriticalRegion();
+        ZwClose(serviceHandle);
 
         Status = IopInitializeDevice(DeviceNode, DriverObject);
 
@@ -1104,6 +1172,28 @@ IopInitializeBuiltinDriver(IN PLDR_DATA_TABLE_ENTRY 
BootLdrEntry)
         FileExtension[0] = UNICODE_NULL;
     }
 
+    UNICODE_STRING RegistryPath;
+
+    // Make the registry path for the driver
+    RegistryPath.Length = 0;
+    RegistryPath.MaximumLength = sizeof(ServicesKeyName) + ServiceName.Length;
+    RegistryPath.Buffer = ExAllocatePoolWithTag(PagedPool, 
RegistryPath.MaximumLength, TAG_IO);
+    if (RegistryPath.Buffer == NULL)
+    {
+        return STATUS_INSUFFICIENT_RESOURCES;
+    }
+    RtlAppendUnicodeToString(&RegistryPath, ServicesKeyName);
+    RtlAppendUnicodeStringToString(&RegistryPath, &ServiceName);
+    RtlFreeUnicodeString(&ServiceName);
+
+    HANDLE serviceHandle;
+    Status = IopOpenRegistryKeyEx(&serviceHandle, NULL, &RegistryPath, 
KEY_READ);
+    RtlFreeUnicodeString(&RegistryPath);
+    if (!NT_SUCCESS(Status))
+    {
+        return Status;
+    }
+
     /* Lookup the new Ldr entry in PsLoadedModuleList */
     NextEntry = PsLoadedModuleList.Flink;
     while (NextEntry != &PsLoadedModuleList)
@@ -1125,10 +1215,10 @@ IopInitializeBuiltinDriver(IN PLDR_DATA_TABLE_ENTRY 
BootLdrEntry)
      */
     NTSTATUS driverEntryStatus;
     Status = IopInitializeDriverModule(LdrEntry,
-                                       &ServiceName,
+                                       serviceHandle,
                                        &DriverObject,
                                        &driverEntryStatus);
-    RtlFreeUnicodeString(&ServiceName);
+    ZwClose(serviceHandle);
 
     if (!NT_SUCCESS(Status))
     {
@@ -2050,69 +2140,47 @@ IoGetDriverObjectExtension(IN PDRIVER_OBJECT 
DriverObject,
 
 NTSTATUS
 IopLoadDriver(
-    _In_opt_ PCUNICODE_STRING RegistryPath,
-    _Inout_ PDRIVER_OBJECT *DriverObject)
+    _In_ HANDLE ServiceHandle,
+    _Out_ PDRIVER_OBJECT *DriverObject)
 {
-    RTL_QUERY_REGISTRY_TABLE QueryTable[3];
     UNICODE_STRING ImagePath;
-    UNICODE_STRING ServiceName;
     NTSTATUS Status;
-    ULONG Type;
     PLDR_DATA_TABLE_ENTRY ModuleObject;
     PVOID BaseAddress;
-    WCHAR *cur;
-
-    RtlInitUnicodeString(&ImagePath, NULL);
 
-    /*
-     * Get the service name from the registry key name.
-     */
-    ASSERT(RegistryPath->Length >= sizeof(WCHAR));
-
-    ServiceName = *RegistryPath;
-    cur = RegistryPath->Buffer + RegistryPath->Length / sizeof(WCHAR) - 1;
-    while (RegistryPath->Buffer != cur)
+    PKEY_VALUE_FULL_INFORMATION kvInfo;
+    Status = IopGetRegistryValue(ServiceHandle, L"ImagePath", &kvInfo);
+    if (NT_SUCCESS(Status))
     {
-        if (*cur == L'\\')
+        if (kvInfo->Type != REG_EXPAND_SZ || kvInfo->DataLength == 0)
         {
-            ServiceName.Buffer = cur + 1;
-            ServiceName.Length = RegistryPath->Length -
-                                 (USHORT)((ULONG_PTR)ServiceName.Buffer -
-                                          (ULONG_PTR)RegistryPath->Buffer);
-            break;
+            ExFreePool(kvInfo);
+            return STATUS_ILL_FORMED_SERVICE_ENTRY;
         }
-        cur--;
-    }
-
-    /*
-     * Get service type.
-     */
-    RtlZeroMemory(&QueryTable, sizeof(QueryTable));
-
-    RtlInitUnicodeString(&ImagePath, NULL);
-
-    QueryTable[0].Name = L"Type";
-    QueryTable[0].Flags = RTL_QUERY_REGISTRY_DIRECT | 
RTL_QUERY_REGISTRY_REQUIRED;
-    QueryTable[0].EntryContext = &Type;
 
-    QueryTable[1].Name = L"ImagePath";
-    QueryTable[1].Flags = RTL_QUERY_REGISTRY_DIRECT;
-    QueryTable[1].EntryContext = &ImagePath;
+        ImagePath.Length = kvInfo->DataLength - sizeof(UNICODE_NULL),
+        ImagePath.MaximumLength = kvInfo->DataLength,
+        ImagePath.Buffer = ExAllocatePoolWithTag(PagedPool, 
ImagePath.MaximumLength, TAG_RTLREGISTRY);
+        if (!ImagePath.Buffer)
+        {
+            ExFreePool(kvInfo);
+            return STATUS_INSUFFICIENT_RESOURCES;
+        }
 
-    Status = RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE,
-                                    RegistryPath->Buffer,
-                                    QueryTable, NULL, NULL);
-    if (!NT_SUCCESS(Status))
+        RtlMoveMemory(ImagePath.Buffer,
+                      (PVOID)((ULONG_PTR)kvInfo + kvInfo->DataOffset),
+                      ImagePath.Length);
+        ExFreePool(kvInfo);
+    }
+    else
     {
-        DPRINT("RtlQueryRegistryValues() failed (Status %lx)\n", Status);
-        if (ImagePath.Buffer) ExFreePool(ImagePath.Buffer);
         return Status;
     }
 
     /*
      * Normalize the image path for all later processing.
      */
-    Status = IopNormalizeImagePath(&ImagePath, &ServiceName);
+    Status = IopNormalizeImagePath(&ImagePath, NULL);
     if (!NT_SUCCESS(Status))
     {
         DPRINT("IopNormalizeImagePath() failed (Status %x)\n", Status);
@@ -2120,67 +2188,65 @@ IopLoadDriver(
     }
 
     DPRINT("FullImagePath: '%wZ'\n", &ImagePath);
-    DPRINT("Type: %lx\n", Type);
 
     KeEnterCriticalRegion();
     ExAcquireResourceExclusiveLite(&IopDriverLoadResource, TRUE);
+
     /*
-     * Get existing DriverObject pointer (in case the driver
-     * has already been loaded and initialized).
+     * Load the driver module
      */
-    Status = IopGetDriverObject(DriverObject,
-                                &ServiceName,
-                                (Type == SERVICE_FILE_SYSTEM_DRIVER ||
-                                 Type == SERVICE_RECOGNIZER_DRIVER));
+    DPRINT("Loading module from %wZ\n", &ImagePath);
+    Status = MmLoadSystemImage(&ImagePath, NULL, NULL, 0, 
(PVOID)&ModuleObject, &BaseAddress);
+    RtlFreeUnicodeString(&ImagePath);
 
     if (!NT_SUCCESS(Status))
     {
-        /*
-         * Load the driver module
-         */
-        DPRINT("Loading module from %wZ\n", &ImagePath);
-        Status = MmLoadSystemImage(&ImagePath, NULL, NULL, 0, 
(PVOID)&ModuleObject, &BaseAddress);
-        if (!NT_SUCCESS(Status))
-        {
-            DPRINT("MmLoadSystemImage() failed (Status %lx)\n", Status);
-            ExReleaseResourceLite(&IopDriverLoadResource);
-            KeLeaveCriticalRegion();
-            return Status;
-        }
-
-        /*
-         * Initialize the driver module if it's loaded for the first time
-         */
-        IopDisplayLoadingMessage(&ServiceName);
-
-        NTSTATUS driverEntryStatus;
-        Status = IopInitializeDriverModule(ModuleObject,
-                                           &ServiceName,
-                                           DriverObject,
-                                           &driverEntryStatus);
-        if (!NT_SUCCESS(Status))
-        {
-            DPRINT1("IopInitializeDriverModule() failed (Status %lx)\n", 
Status);
-            ExReleaseResourceLite(&IopDriverLoadResource);
-            KeLeaveCriticalRegion();
-            return Status;
-        }
-
+        DPRINT("MmLoadSystemImage() failed (Status %lx)\n", Status);
         ExReleaseResourceLite(&IopDriverLoadResource);
         KeLeaveCriticalRegion();
+        return Status;
     }
-    else
+
+    // Display the loading message
+    ULONG infoLength;
+    Status = ZwQueryKey(ServiceHandle, KeyBasicInformation, NULL, 0, 
&infoLength);
+    if (Status == STATUS_BUFFER_TOO_SMALL)
     {
-        ExReleaseResourceLite(&IopDriverLoadResource);
-        KeLeaveCriticalRegion();
+        PKEY_BASIC_INFORMATION servName = ExAllocatePoolWithTag(PagedPool, 
infoLength, TAG_IO);
+        if (servName)
+        {
+            Status = ZwQueryKey(ServiceHandle,
+                                KeyBasicInformation,
+                                servName,
+                                infoLength,
+                                &infoLength);
+            if (NT_SUCCESS(Status))
+            {
+                UNICODE_STRING serviceName = {
+                    .Length = servName->NameLength,
+                    .MaximumLength = servName->NameLength,
+                    .Buffer = servName->Name
+                };
 
-        DPRINT("DriverObject already exist in ObjectManager\n");
-        Status = STATUS_IMAGE_ALREADY_LOADED;
+                IopDisplayLoadingMessage(&serviceName);
+            }
+            ExFreePoolWithTag(servName, TAG_IO);
+        }
+    }
 
-        /* IopGetDriverObject references the DriverObject, so dereference it */
-        ObDereferenceObject(*DriverObject);
+    NTSTATUS driverEntryStatus;
+    Status = IopInitializeDriverModule(ModuleObject,
+                                       ServiceHandle,
+                                       DriverObject,
+                                       &driverEntryStatus);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("IopInitializeDriverModule() failed (Status %lx)\n", Status);
     }
 
+    ExReleaseResourceLite(&IopDriverLoadResource);
+    KeLeaveCriticalRegion();
+
     return Status;
 }
 
@@ -2203,7 +2269,18 @@ IopLoadUnloadDriverWorker(
     else
     {
         // load request
-        LoadParams->Status = IopLoadDriver(LoadParams->RegistryPath, 
&LoadParams->DriverObject);
+        HANDLE serviceHandle;
+        NTSTATUS status;
+        status = IopOpenRegistryKeyEx(&serviceHandle, NULL, 
LoadParams->RegistryPath, KEY_READ);
+        if (!NT_SUCCESS(status))
+        {
+            LoadParams->Status = status;
+        }
+        else
+        {
+            LoadParams->Status = IopLoadDriver(serviceHandle, 
&LoadParams->DriverObject);
+            ZwClose(serviceHandle);
+        }
     }
 
     if (LoadParams->SetEvent)
@@ -2223,7 +2300,7 @@ IopLoadUnloadDriverWorker(
  */
 NTSTATUS
 IopDoLoadUnloadDriver(
-    _In_opt_ PCUNICODE_STRING RegistryPath,
+    _In_opt_ PUNICODE_STRING RegistryPath,
     _Inout_ PDRIVER_OBJECT *DriverObject)
 {
     LOAD_UNLOAD_PARAMS LoadParams;
diff --git a/ntoskrnl/io/pnpmgr/devaction.c b/ntoskrnl/io/pnpmgr/devaction.c
index aebbf1a01c2..17347573216 100644
--- a/ntoskrnl/io/pnpmgr/devaction.c
+++ b/ntoskrnl/io/pnpmgr/devaction.c
@@ -47,6 +47,7 @@ WORK_QUEUE_ITEM IopDeviceActionWorkItem;
 BOOLEAN IopDeviceActionInProgress;
 KSPIN_LOCK IopDeviceActionLock;
 KEVENT PiEnumerationFinished;
+static const WCHAR ServicesKeyName[] = 
L"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\";
 
 /* TYPES *********************************************************************/
 
@@ -1062,8 +1063,6 @@ IopActionInitChildServices(PDEVICE_NODE DeviceNode,
         PLDR_DATA_TABLE_ENTRY ModuleObject;
         PDRIVER_OBJECT DriverObject;
 
-        KeEnterCriticalRegion();
-        ExAcquireResourceExclusiveLite(&IopDriverLoadResource, TRUE);
         /* Get existing DriverObject pointer (in case the driver has
            already been loaded and initialized) */
         Status = IopGetDriverObject(
@@ -1073,17 +1072,44 @@ IopActionInitChildServices(PDEVICE_NODE DeviceNode,
 
         if (!NT_SUCCESS(Status))
         {
+            KeEnterCriticalRegion();
+            ExAcquireResourceExclusiveLite(&IopDriverLoadResource, TRUE);
+
             /* Driver is not initialized, try to load it */
             Status = IopLoadServiceModule(&DeviceNode->ServiceName, 
&ModuleObject);
 
             if (NT_SUCCESS(Status) || Status == STATUS_IMAGE_ALREADY_LOADED)
             {
+                UNICODE_STRING RegistryPath;
+
+                // obtain a handle for driver's RegistryPath
+                RegistryPath.Length = 0;
+                RegistryPath.MaximumLength = sizeof(ServicesKeyName) + 
DeviceNode->ServiceName.Length;
+                RegistryPath.Buffer = ExAllocatePoolWithTag(PagedPool, 
RegistryPath.MaximumLength, TAG_IO);
+                if (RegistryPath.Buffer == NULL)
+                {
+                    Status = STATUS_INSUFFICIENT_RESOURCES;
+                    goto OpenHandleFail;
+                }
+                RtlAppendUnicodeToString(&RegistryPath, ServicesKeyName);
+                RtlAppendUnicodeStringToString(&RegistryPath, 
&DeviceNode->ServiceName);
+
+                HANDLE serviceHandle;
+                Status = IopOpenRegistryKeyEx(&serviceHandle, NULL, 
&RegistryPath, KEY_READ);
+                RtlFreeUnicodeString(&RegistryPath);
+                if (!NT_SUCCESS(Status))
+                {
+                    goto OpenHandleFail;
+                }
+
                 /* Initialize the driver */
                 NTSTATUS driverEntryStatus;
                 Status = IopInitializeDriverModule(ModuleObject,
-                                                   &DeviceNode->ServiceName,
+                                                   serviceHandle,
                                                    &DriverObject,
                                                    &driverEntryStatus);
+                ZwClose(serviceHandle);
+
                 if (!NT_SUCCESS(Status))
                     DeviceNode->Problem = CM_PROB_FAILED_DRIVER_ENTRY;
             }
@@ -1099,9 +1125,10 @@ IopActionInitChildServices(PDEVICE_NODE DeviceNode,
                 if (!BootDrivers)
                     DeviceNode->Problem = CM_PROB_DRIVER_FAILED_LOAD;
             }
+OpenHandleFail:
+            ExReleaseResourceLite(&IopDriverLoadResource);
+            KeLeaveCriticalRegion();
         }
-        ExReleaseResourceLite(&IopDriverLoadResource);
-        KeLeaveCriticalRegion();
 
         /* Driver is loaded and initialized at this point */
         if (NT_SUCCESS(Status))

Reply via email to