bneradt commented on code in PR #13220:
URL: https://github.com/apache/trafficserver/pull/13220#discussion_r3337526497
##########
tests/gold_tests/tunnel/dumb_proxy.py:
##########
@@ -40,84 +37,82 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()
-def initialize_thread_local_data():
- thread_local_data.client_to_server_bytes = 0
- thread_local_data.server_to_client_bytes = 0
-
-
-def forward(source, destination, is_client_to_server):
+def forward(source, destination, is_client_to_server, byte_counts):
"""Forward traffic from source to destination.
:param source: socket to read from.
:param destination: socket to write to.
:param is_client_to_server: True if forwarding from client to server.
+ :param byte_counts: A dictionary to record bytes sent in each direction.
"""
- # Initialize thread-local data.
- initialize_thread_local_data()
+ bytes_transferred = 0
- while True:
- try:
+ try:
+ while True:
data = source.recv(4096)
if not data:
break
destination.sendall(data)
- except Exception as e:
- # Catching all exceptions.
- break
+ bytes_transferred += len(data)
+ except OSError:
+ pass
+ finally:
+ try:
+ destination.shutdown(socket.SHUT_WR)
+ except OSError:
+ pass
+ # Forwarding done. Print the number of bytes transferred in the direction.
+ if bytes_transferred > 0:
if is_client_to_server:
- thread_local_data.client_to_server_bytes += len(data)
+ byte_counts["client-to-server"] = bytes_transferred
+ print(f"client-to-server: {bytes_transferred}", flush=True)
else:
- thread_local_data.server_to_client_bytes += len(data)
- # Forwarding done. Print the number of bytes transferred in the direction.
- if thread_local_data.client_to_server_bytes > 0:
- print(f"client-to-server: {thread_local_data.client_to_server_bytes}")
- elif thread_local_data.server_to_client_bytes > 0:
- print(f"server-to-client: {thread_local_data.server_to_client_bytes}")
+ byte_counts["server-to-client"] = bytes_transferred
+ print(f"server-to-client: {bytes_transferred}", flush=True)
def start_bidirectional_forwarding(client_socket, forwarding_port):
"""Start forwarding traffic between client and server.
:param client_socket: socket connected to the client.
:param forwarding_port: server port to forward to.
+ :return: The number of bytes forwarded in both directions.
"""
CLIENT_TO_SERVER = True
SERVER_TO_CLIENT = False
+ byte_counts = {}
with client_socket, socket.socket(socket.AF_INET, socket.SOCK_STREAM) as
server_socket:
- client_socket.settimeout(TIMEOUT)
- server_socket.settimeout(TIMEOUT)
server_socket.connect((LOCAL_HOST, forwarding_port))
# Spawn a thread to forward traffic from client to server.
Review Comment:
Good catch. I added an OSError guard around the connect to the forwarding port,
so readiness-probe connections are treated as zero-byte traffic and the proxy
keeps listening for the real test connection.
##########
tests/gold_tests/h2/grpc/grpc_server.py:
##########
@@ -32,45 +32,62 @@
class Talker(simple_pb2_grpc.TalkerServicer):
"""A gRPC servicer."""
- async def MakeRequest(self, request: simple_pb2.SimpleRequest, context:
grpc.aio.ServicerContext):
- """An example gRPC method."""
+ def __init__(self, num_expected_messages: int, done_event: asyncio.Event):
+ self._num_expected_messages = num_expected_messages
+ self._done_event = done_event
+
+ def _record_message(self) -> None:
global global_message_counter
+
global_message_counter += 1
+ if global_message_counter >= self._num_expected_messages:
+ asyncio.get_running_loop().call_soon(self._done_event.set)
+
+ async def MakeRequest(self, request: simple_pb2.SimpleRequest, context:
grpc.aio.ServicerContext):
+ """An example gRPC method."""
+ self._record_message()
print(f'Received request: {request.message}')
response = simple_pb2.SimpleResponse(message=f"Echo:
{request.message}")
return response
async def MakeAnotherRequest(self, request: simple_pb2.SimpleRequest,
context: grpc.aio.ServicerContext):
"""An example gRPC method."""
- global global_message_counter
- global_message_counter += 1
+ self._record_message()
print(f'Received another request: {request.message}')
response = simple_pb2.SimpleResponse(message=f"Another echo:
{request.message}")
return response
-async def run_grpc_server(port: int, server_cert: str, server_key: str) -> int:
+async def run_grpc_server(port: int, server_cert: str, server_key: str,
num_expected_messages: int) -> int:
"""Run the gRPC server.
:param port: The port on which to listen.
:param server_cert: The public TLS certificate to use.
:param server_key: The private TLS key to use.
+ :param num_expected_messages: The number of messages expected from clients.
:return: The exit code.
"""
credentials = grpc.ssl_server_credentials([(server_key, server_cert)])
server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10))
- simple_pb2_grpc.add_TalkerServicer_to_server(Talker(), server)
+ done_event = asyncio.Event()
+ simple_pb2_grpc.add_TalkerServicer_to_server(Talker(num_expected_messages,
done_event), server)
server_endpoint = f'127.0.0.1:{port}'
server.add_secure_port(server_endpoint, credentials)
print(f'Listening on: {server_endpoint}')
try:
await server.start()
- await server.wait_for_termination()
+ await done_event.wait()
except asyncio.exceptions.CancelledError:
print('Shutting down the server.')
finally:
- await server.stop(0)
- return 0
+ await server.stop(5)
+
Review Comment:
I wrapped the wait for the expected messages in a bounded timeout and left the
existing message-count check in place, so the server returns a non-zero exit code if the client path fails.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]