Author: phater
Date: Fri Oct 28 16:37:39 2016
New Revision: 73054

URL: http://svn.reactos.org/svn/reactos?rev=73054&view=rev
Log:
[MSAFD] Fix handle counting in WSPSelect and improve the TDI poll request according to MSDN.
CORE-12104

Modified:
    trunk/reactos/dll/win32/msafd/misc/dllmain.c

Modified: trunk/reactos/dll/win32/msafd/misc/dllmain.c
URL: http://svn.reactos.org/svn/reactos/trunk/reactos/dll/win32/msafd/misc/dllmain.c?rev=73054&r1=73053&r2=73054&view=diff
==============================================================================
--- trunk/reactos/dll/win32/msafd/misc/dllmain.c        [iso-8859-1] (original)
+++ trunk/reactos/dll/win32/msafd/misc/dllmain.c        [iso-8859-1] Fri Oct 28 16:37:39 2016
@@ -1031,13 +1031,14 @@
     PAFD_POLL_INFO      PollInfo;
     NTSTATUS            Status;
     ULONG               HandleCount;
-    LONG                OutCount = 0;
     ULONG               PollBufferSize;
     PVOID               PollBuffer;
     ULONG               i, j = 0, x;
     HANDLE              SockEvent;
-    BOOL                HandleCounted;
     LARGE_INTEGER       Timeout;
+    PSOCKET_INFORMATION Socket;
+    SOCKET              Handle;
+    ULONG               Events;
 
     /* Find out how many sockets we have, and how large the buffer needs
      * to be */
@@ -1119,28 +1120,68 @@
     if (readfds != NULL) {
         for (i = 0; i < readfds->fd_count; i++, j++)
         {
+            Socket = GetSocketStructure(readfds->fd_array[i]);
+            if (!Socket)
+            {
+                ERR("Invalid socket handle provided in readfds %d\n", 
readfds->fd_array[i]);
+                if (lpErrno) *lpErrno = WSAENOTSOCK;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
             PollInfo->Handles[j].Handle = readfds->fd_array[i];
             PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
                                           AFD_EVENT_DISCONNECT |
                                           AFD_EVENT_ABORT |
                                           AFD_EVENT_CLOSE |
                                           AFD_EVENT_ACCEPT;
+            if (Socket->SharedData->OobInline != 0)
+                PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
         }
     }
     if (writefds != NULL)
     {
         for (i = 0; i < writefds->fd_count; i++, j++)
         {
+            Socket = GetSocketStructure(writefds->fd_array[i]);
+            if (!Socket)
+            {
+                ERR("Invalid socket handle provided in writefds %d\n", 
writefds->fd_array[i]);
+                if (lpErrno) *lpErrno = WSAENOTSOCK;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
             PollInfo->Handles[j].Handle = writefds->fd_array[i];
-            PollInfo->Handles[j].Events = AFD_EVENT_SEND | AFD_EVENT_CONNECT;
+            PollInfo->Handles[j].Events = AFD_EVENT_SEND;
+            if (Socket->SharedData->NonBlocking != 0)
+                PollInfo->Handles[j].Events |= AFD_EVENT_CONNECT;
         }
     }
     if (exceptfds != NULL)
     {
         for (i = 0; i < exceptfds->fd_count; i++, j++)
         {
+            Socket = GetSocketStructure(exceptfds->fd_array[i]);
+            if (!Socket)
+            {
+                TRACE("Invalid socket handle provided in exceptfds %d\n", 
exceptfds->fd_array[i]);
+                if (lpErrno) *lpErrno = WSAENOTSOCK;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
             PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
-            PollInfo->Handles[j].Events = AFD_EVENT_OOB_RECEIVE | 
AFD_EVENT_CONNECT_FAIL;
+            PollInfo->Handles[j].Events = 0;
+            if (Socket->SharedData->OobInline == 0)
+                PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
+            if (Socket->SharedData->NonBlocking != 0)
+                PollInfo->Handles[j].Events |= AFD_EVENT_CONNECT_FAIL;
+            if (PollInfo->Handles[j].Events == 0)
+            {
+                TRACE("No events can be checked for exceptfds %d. It is 
nonblocking and OOB line is disabled. Skipping it.", exceptfds->fd_array[i]);
+                j--;
+            }
         }
     }
 
@@ -1182,10 +1223,20 @@
     /* Return in FDSET Format */
     for (i = 0; i < HandleCount; i++)
     {
-        HandleCounted = FALSE;
+        Events = PollInfo->Handles[i].Events;
+        Handle = PollInfo->Handles[i].Handle;
         for(x = 1; x; x<<=1)
         {
-            switch (PollInfo->Handles[i].Events & x)
+            Socket = GetSocketStructure(Handle);
+            if (!Socket)
+            {
+                TRACE("Invalid socket handle found %d\n", Handle);
+                if (lpErrno) *lpErrno = WSAENOTSOCK;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
+            switch (Events & x)
             {
                 case AFD_EVENT_RECEIVE:
                 case AFD_EVENT_DISCONNECT:
@@ -1193,41 +1244,40 @@
                 case AFD_EVENT_ACCEPT:
                 case AFD_EVENT_CLOSE:
                     TRACE("Event %x on handle %x\n",
-                        PollInfo->Handles[i].Events,
-                        PollInfo->Handles[i].Handle);
-                    if (! HandleCounted)
-                    {
-                        OutCount++;
-                        HandleCounted = TRUE;
-                    }
+                        Events,
+                        Handle);
                     if( readfds )
-                        FD_SET(PollInfo->Handles[i].Handle, readfds);
+                        FD_SET(Handle, readfds);
                     break;
                 case AFD_EVENT_SEND:
+                    TRACE("Event %x on handle %x\n",
+                        Events,
+                        Handle);
+                    if (writefds)
+                        FD_SET(Handle, writefds);
+                    break;
                 case AFD_EVENT_CONNECT:
                     TRACE("Event %x on handle %x\n",
-                        PollInfo->Handles[i].Events,
-                        PollInfo->Handles[i].Handle);
-                    if (! HandleCounted)
-                    {
-                        OutCount++;
-                        HandleCounted = TRUE;
-                    }
-                    if( writefds )
-                        FD_SET(PollInfo->Handles[i].Handle, writefds);
+                        Events,
+                        Handle);
+                    if( writefds && Socket->SharedData->NonBlocking != 0 )
+                        FD_SET(Handle, writefds);
                     break;
                 case AFD_EVENT_OOB_RECEIVE:
+                    TRACE("Event %x on handle %x\n",
+                        Events,
+                        Handle);
+                    if( readfds && Socket->SharedData->OobInline != 0 )
+                        FD_SET(Handle, readfds);
+                    if( exceptfds && Socket->SharedData->OobInline == 0 )
+                        FD_SET(Handle, exceptfds);
+                    break;
                 case AFD_EVENT_CONNECT_FAIL:
                     TRACE("Event %x on handle %x\n",
-                        PollInfo->Handles[i].Events,
-                        PollInfo->Handles[i].Handle);
-                    if (! HandleCounted)
-                    {
-                        OutCount++;
-                        HandleCounted = TRUE;
-                    }
-                    if( exceptfds )
-                        FD_SET(PollInfo->Handles[i].Handle, exceptfds);
+                        Events,
+                        Handle);
+                    if( exceptfds && Socket->SharedData->NonBlocking != 0 )
+                        FD_SET(Handle, exceptfds);
                     break;
             }
         }
@@ -1251,9 +1301,13 @@
         TRACE("*lpErrno = %x\n", *lpErrno);
     }
 
-    TRACE("%d events\n", OutCount);
-
-    return OutCount;
+    HandleCount = (readfds ? readfds->fd_count : 0) +
+                  (writefds && writefds != readfds ? writefds->fd_count : 0) +
+                  (exceptfds && exceptfds != readfds && exceptfds != writefds 
? exceptfds->fd_count : 0);
+
+    TRACE("%d events\n", HandleCount);
+
+    return HandleCount;
 }
 
 SOCKET


Reply via email to