https://git.reactos.org/?p=reactos.git;a=commitdiff;h=141378cfc8c7964516c1f7383110e440aac1f84d

commit 141378cfc8c7964516c1f7383110e440aac1f84d
Author:     Hermès Bélusca-Maïto <[email protected]>
AuthorDate: Tue Jul 28 01:10:58 2020 +0200
Commit:     Hermès Bélusca-Maïto <[email protected]>
CommitDate: Wed Sep 23 00:22:47 2020 +0200

    [CMD] ASSOC: Simplify the code and make it more robust; fix returned ERRORLEVEL values.
    
    - Make sure that non-administrator users can list associations, and
      display appropriate error messages when, for example, they lack
      sufficient privileges to perform an operation.
    
    - Make the helper functions all return Win32 error codes, which are used
      as the ERRORLEVEL, except when displaying a specific extension's
      association fails, in which case the ERRORLEVEL is normalized to 1.
    
    - Since 'param' is a modifiable string (that can be modified by the
      command, independently of the way it's called), just use it to isolate
      the extension by zeroing out the equals-sign separator.
---
 base/shell/cmd/assoc.c | 291 ++++++++++++++++++++++++++++---------------------
 1 file changed, 167 insertions(+), 124 deletions(-)

diff --git a/base/shell/cmd/assoc.c b/base/shell/cmd/assoc.c
index 3e5611e6b1f..6441d3d369e 100644
--- a/base/shell/cmd/assoc.c
+++ b/base/shell/cmd/assoc.c
@@ -13,187 +13,232 @@
  *
  * TODO:
  * - PrintAllAssociations could be optimized to not fetch all registry subkeys 
under 'Classes', just the ones that start with '.'
- * - Make sure that non-administrator users can list associations, and get 
appropriate error messages when they don't have sufficient
- *   privileges to perform an operation.
  */
 
 #include "precomp.h"
 
 #ifdef INCLUDE_CMD_ASSOC
 
-static INT
-PrintAssociation(
-    IN LPCTSTR extension)
+static LONG
+PrintAssociationEx(
+    IN HKEY hKeyClasses,
+    IN PCTSTR pszExtension)
 {
-    DWORD lRet;
-    HKEY hKey = NULL, hSubKey = NULL;
-    DWORD fileTypeLength = 0;
-    LPTSTR fileType = NULL;
-
-    lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, 
KEY_READ, &hKey);
-    if (lRet != ERROR_SUCCESS)
-        return -1;
-
-    lRet = RegOpenKeyEx(hKey, extension, 0, KEY_READ, &hSubKey);
-    RegCloseKey(hKey);
+    LONG lRet;
+    HKEY hKey;
+    DWORD dwFileTypeLen = 0;
+    PTSTR pszFileType;
 
+    lRet = RegOpenKeyEx(hKeyClasses, pszExtension, 0, KEY_QUERY_VALUE, &hKey);
     if (lRet != ERROR_SUCCESS)
-        return 0;
+    {
+        if (lRet != ERROR_FILE_NOT_FOUND)
+            ErrorMessage(lRet, NULL);
+        return lRet;
+    }
 
-    /* Obtain string length */
-    lRet = RegQueryValueEx(hSubKey, NULL, NULL, NULL, NULL, &fileTypeLength);
+    /* Obtain the string length */
+    lRet = RegQueryValueEx(hKey, NULL, NULL, NULL, NULL, &dwFileTypeLen);
 
-    /* If there is no default value, don't display */
+    /* If there is no default value, don't display it */
     if (lRet == ERROR_FILE_NOT_FOUND)
     {
-        RegCloseKey(hSubKey);
-        return 0;
+        RegCloseKey(hKey);
+        return lRet;
     }
     if (lRet != ERROR_SUCCESS)
     {
-        RegCloseKey(hSubKey);
-        return -2;
+        ErrorMessage(lRet, NULL);
+        RegCloseKey(hKey);
+        return lRet;
     }
 
-    fileType = cmd_alloc(fileTypeLength * sizeof(TCHAR));
-    if (!fileType)
+    ++dwFileTypeLen;
+    pszFileType = cmd_alloc(dwFileTypeLen * sizeof(TCHAR));
+    if (!pszFileType)
     {
-        WARN("Cannot allocate memory for fileType!\n");
-        RegCloseKey(hSubKey);
-        return -2;
+        WARN("Cannot allocate memory for pszFileType!\n");
+        RegCloseKey(hKey);
+        return ERROR_NOT_ENOUGH_MEMORY;
     }
 
-    /* Obtain actual file type */
-    lRet = RegQueryValueEx(hSubKey, NULL, NULL, NULL, (LPBYTE)fileType, 
&fileTypeLength);
-    RegCloseKey(hSubKey);
+    /* Obtain the actual file type */
+    lRet = RegQueryValueEx(hKey, NULL, NULL, NULL, (LPBYTE)pszFileType, 
&dwFileTypeLen);
+    RegCloseKey(hKey);
 
     if (lRet != ERROR_SUCCESS)
     {
-        cmd_free(fileType);
-        return -2;
+        ErrorMessage(lRet, NULL);
+        cmd_free(pszFileType);
+        return lRet;
     }
 
-    /* If there is a default key, display relevant information */
-    if (fileTypeLength != 0)
+    /* If there is a default key, display the relevant information */
+    if (dwFileTypeLen != 0)
     {
-        ConOutPrintf(_T("%s=%s\n"), extension, fileType);
+        ConOutPrintf(_T("%s=%s\n"), pszExtension, pszFileType);
     }
 
-    cmd_free(fileType);
-    return 1;
+    cmd_free(pszFileType);
+    return ERROR_SUCCESS;
 }
 
-static INT
-PrintAllAssociations(VOID)
+static LONG
+PrintAssociation(
+    IN PCTSTR pszExtension)
 {
-    DWORD lRet = 0;
-    HKEY hKey = NULL;
-    DWORD numKeys = 0;
+    LONG lRet;
+    HKEY hKeyClasses;
+
+    lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0,
+                        KEY_ENUMERATE_SUB_KEYS, &hKeyClasses);
+    if (lRet != ERROR_SUCCESS)
+    {
+        ErrorMessage(lRet, NULL);
+        return lRet;
+    }
+
+    lRet = PrintAssociationEx(hKeyClasses, pszExtension);
 
-    DWORD extLength = 0;
-    LPTSTR extName = NULL;
-    DWORD keyCtr = 0;
+    RegCloseKey(hKeyClasses);
+    return lRet;
+}
 
-    lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, 
KEY_READ, &hKey);
+static LONG
+PrintAllAssociations(VOID)
+{
+    LONG lRet;
+    HKEY hKeyClasses;
+    DWORD dwKeyCtr;
+    DWORD dwNumKeys = 0;
+    DWORD dwExtLen = 0;
+    PTSTR pszExtName;
+
+    lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0,
+                        KEY_QUERY_VALUE | KEY_ENUMERATE_SUB_KEYS, 
&hKeyClasses);
     if (lRet != ERROR_SUCCESS)
-        return -1;
+    {
+        ErrorMessage(lRet, NULL);
+        return lRet;
+    }
 
-    lRet = RegQueryInfoKey(hKey, NULL, NULL, NULL, &numKeys, &extLength,
+    lRet = RegQueryInfoKey(hKeyClasses, NULL, NULL, NULL, &dwNumKeys, 
&dwExtLen,
                            NULL, NULL, NULL, NULL, NULL, NULL);
     if (lRet != ERROR_SUCCESS)
     {
-        RegCloseKey(hKey);
-        return -2;
+        ErrorMessage(lRet, NULL);
+        RegCloseKey(hKeyClasses);
+        return lRet;
     }
 
-    extLength++;
-    extName = cmd_alloc(extLength * sizeof(TCHAR));
-    if (!extName)
+    ++dwExtLen;
+    pszExtName = cmd_alloc(dwExtLen * sizeof(TCHAR));
+    if (!pszExtName)
     {
-        WARN("Cannot allocate memory for extName!\n");
-        RegCloseKey(hKey);
-        return -2;
+        WARN("Cannot allocate memory for pszExtName!\n");
+        RegCloseKey(hKeyClasses);
+        return ERROR_NOT_ENOUGH_MEMORY;
     }
 
-    for (keyCtr = 0; keyCtr < numKeys; ++keyCtr)
+    for (dwKeyCtr = 0; dwKeyCtr < dwNumKeys; ++dwKeyCtr)
     {
-        DWORD dwBufSize = extLength;
-        lRet = RegEnumKeyEx(hKey, keyCtr, extName, &dwBufSize,
+        DWORD dwBufSize = dwExtLen;
+        lRet = RegEnumKeyEx(hKeyClasses, dwKeyCtr, pszExtName, &dwBufSize,
                             NULL, NULL, NULL, NULL);
 
         if (lRet == ERROR_SUCCESS || lRet == ERROR_MORE_DATA)
         {
-            if (*extName == _T('.'))
-                PrintAssociation(extName);
+            /* Name starts with '.': this is an extension */
+            if (*pszExtName == _T('.'))
+                PrintAssociationEx(hKeyClasses, pszExtName);
         }
         else
         {
-            cmd_free(extName);
-            RegCloseKey(hKey);
-            return -1;
+            ErrorMessage(lRet, NULL);
+            cmd_free(pszExtName);
+            RegCloseKey(hKeyClasses);
+            return lRet;
         }
     }
 
-    RegCloseKey(hKey);
+    RegCloseKey(hKeyClasses);
 
-    cmd_free(extName);
-    return numKeys;
+    cmd_free(pszExtName);
+    return ERROR_SUCCESS;
 }
 
-static INT
+static LONG
 AddAssociation(
-    IN LPCTSTR extension,
-    IN LPCTSTR type)
+    IN PCTSTR pszExtension,
+    IN PCTSTR pszType)
 {
-    DWORD lRet;
-    HKEY hKey = NULL, hSubKey = NULL;
+    LONG lRet;
+    HKEY hKeyClasses, hKey;
 
-    lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, 
KEY_ALL_ACCESS, &hKey);
+    lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0,
+                        KEY_CREATE_SUB_KEY, &hKeyClasses);
     if (lRet != ERROR_SUCCESS)
-        return -1;
+    {
+        ErrorMessage(lRet, NULL);
+        return lRet;
+    }
 
-    lRet = RegCreateKeyEx(hKey, extension, 0, NULL, REG_OPTION_NON_VOLATILE,
-                          KEY_ALL_ACCESS, NULL, &hSubKey, NULL);
-    RegCloseKey(hKey);
+    lRet = RegCreateKeyEx(hKeyClasses, pszExtension, 0, NULL, 
REG_OPTION_NON_VOLATILE,
+                          KEY_SET_VALUE, NULL, &hKey, NULL);
+    RegCloseKey(hKeyClasses);
 
     if (lRet != ERROR_SUCCESS)
-        return -1;
+    {
+        ErrorMessage(lRet, NULL);
+        return lRet;
+    }
 
-    lRet = RegSetValueEx(hSubKey, NULL, 0, REG_SZ,
-                         (LPBYTE)type, (_tcslen(type) + 1) * sizeof(TCHAR));
-    RegCloseKey(hSubKey);
+    lRet = RegSetValueEx(hKey, NULL, 0, REG_SZ,
+                         (LPBYTE)pszType, (DWORD)(_tcslen(pszType) + 1) * 
sizeof(TCHAR));
+    RegCloseKey(hKey);
 
     if (lRet != ERROR_SUCCESS)
-        return -2;
+    {
+        ErrorMessage(lRet, NULL);
+        return lRet;
+    }
 
-    return 0;
+    return ERROR_SUCCESS;
 }
 
-static INT
+static LONG
 RemoveAssociation(
-    IN LPCTSTR extension)
+    IN PCTSTR pszExtension)
 {
-    DWORD lRet;
-    HKEY hKey;
+    LONG lRet;
+    HKEY hKeyClasses;
 
-    lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, 
KEY_ALL_ACCESS, &hKey);
+    lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0,
+                        KEY_QUERY_VALUE, &hKeyClasses);
     if (lRet != ERROR_SUCCESS)
-        return -1;
+    {
+        ErrorMessage(lRet, NULL);
+        return lRet;
+    }
 
-    lRet = RegDeleteKey(hKey, extension);
-    RegCloseKey(hKey);
+    lRet = RegDeleteKey(hKeyClasses, pszExtension);
+    RegCloseKey(hKeyClasses);
 
     if (lRet != ERROR_SUCCESS)
-        return -2;
+    {
+        if (lRet != ERROR_FILE_NOT_FOUND)
+            ErrorMessage(lRet, NULL);
+        return lRet;
+    }
 
-    return 0;
+    return ERROR_SUCCESS;
 }
 
 
 INT CommandAssoc(LPTSTR param)
 {
     INT retval = 0;
-    LPTSTR lpEqualSign;
+    PTCHAR pEqualSign;
 
     /* Print help */
     if (!_tcsncmp(param, _T("/?"), 2))
@@ -202,53 +247,51 @@ INT CommandAssoc(LPTSTR param)
         return 0;
     }
 
-    if (_tcslen(param) == 0)
+    /* Print all associations if no parameter has been specified */
+    if (!*param)
     {
         PrintAllAssociations();
         goto Quit;
     }
 
-    lpEqualSign = _tcschr(param, _T('='));
-    if (lpEqualSign != NULL)
+    pEqualSign = _tcschr(param, _T('='));
+    if (pEqualSign != NULL)
     {
-        LPTSTR fileType = lpEqualSign + 1;
-        LPTSTR extension = cmd_alloc((lpEqualSign - param + 1) * 
sizeof(TCHAR));
-        if (!extension)
-        {
-            WARN("Cannot allocate memory for extension!\n");
-            error_out_of_memory();
-            retval = 1;
-            goto Quit;
-        }
+        PTSTR pszFileType = pEqualSign + 1;
 
-        _tcsncpy(extension, param, lpEqualSign - param);
-        extension[lpEqualSign - param] = _T('\0');
+        /* NULL-terminate at the equals sign */
+        *pEqualSign = 0;
 
-        /* If the equal sign is the last character
-         * in the string, then delete the key. */
-        if (_tcslen(fileType) == 0)
+        /* If the equals sign is the last character
+         * in the string, delete the association. */
+        if (*pszFileType == 0)
         {
-            retval = RemoveAssociation(extension);
+            retval = RemoveAssociation(param);
         }
         else
-        /* Otherwise, add the key and print out the association */
+        /* Otherwise, add the association and print it out */
         {
-            retval = AddAssociation(extension, fileType);
-            PrintAssociation(extension);
+            retval = AddAssociation(param, pszFileType);
+            PrintAssociation(param);
         }
 
-        cmd_free(extension);
-
-        if (retval)
-            retval = 1; /* Fixup the error value */
+        if (retval != ERROR_SUCCESS)
+        {
+            if (retval != ERROR_FILE_NOT_FOUND)
+            {
+                // FIXME: Localize
+                ConErrPrintf(_T("Error occurred while processing: %s.\n"), 
param);
+            }
+            // retval = 1; /* Fixup the error value */
+        }
     }
     else
     {
-        /* No equal sign, print all associations */
+        /* No equals sign, print the association */
         retval = PrintAssociation(param);
-        if (retval == 0)    /* If nothing printed out */
+        if (retval != ERROR_SUCCESS)
         {
-            ConOutResPrintf(STRING_ASSOC_ERROR, param);
+            ConErrResPrintf(STRING_ASSOC_ERROR, param);
             retval = 1; /* Fixup the error value */
         }
     }

Reply via email to