This is an automated email from the ASF dual-hosted git repository.
aaronai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/rocketmq-clients.git
The following commit(s) were added to refs/heads/master by this push:
new 60115cf8 General message engineering in Producer (#559)
60115cf8 is described below
commit 60115cf878148ac6671163789d9968dc271de45e
Author: Yan Chao Mei <[email protected]>
AuthorDate: Tue Jul 25 10:02:28 2023 +0800
General message engineering in Producer (#559)
* General message engineering in Producer
* remove user
* finish retry & isolation
* add comments&exception handler
* add license
* use snake case naming
* fix name & finish telemetry rebuild
* fix style issues
* finish retry&isolation test
* init delay&fifo message
* finish delay & fifo & transaction message & its tests
---
python/rocketmq/client.py | 306 ++++++++++++++--
python/rocketmq/client_config.py | 30 +-
python/rocketmq/client_id_encoder.py | 19 +-
python/rocketmq/definition.py | 104 ++++++
.../rocketmq/exponential_backoff_retry_policy.py | 100 ++++++
python/rocketmq/producer.py | 387 ++++++++++++++++++++-
python/rocketmq/publish_settings.py | 7 +-
python/rocketmq/publishing_message.py | 86 +++++
python/rocketmq/rpc_client.py | 2 +-
python/rocketmq/send_receipt.py | 27 +-
python/rocketmq/session.py | 40 ++-
python/rocketmq/status_checker.py | 212 +++++++++++
python/rocketmq/utils.py | 5 +
python/tests/test_foo.py | 3 +-
14 files changed, 1270 insertions(+), 58 deletions(-)
diff --git a/python/rocketmq/client.py b/python/rocketmq/client.py
index 0ef32e60..509d991d 100644
--- a/python/rocketmq/client.py
+++ b/python/rocketmq/client.py
@@ -13,84 +13,248 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import asyncio
import threading
from typing import Set
-from protocol import service_pb2
-from protocol.service_pb2 import QueryRouteRequest
+from protocol import definition_pb2, service_pb2
+from protocol.definition_pb2 import Code as ProtoCode
+from protocol.service_pb2 import HeartbeatRequest, QueryRouteRequest
from rocketmq.client_config import ClientConfig
from rocketmq.client_id_encoder import ClientIdEncoder
-from rocketmq.definition import TopicRouteData
+from rocketmq.definition import Resource, TopicRouteData
+from rocketmq.log import logger
from rocketmq.rpc_client import Endpoints, RpcClient
from rocketmq.session import Session
from rocketmq.signature import Signature
+class ScheduleWithFixedDelay:
+    """Run an async ``action`` repeatedly with a fixed delay between runs.
+
+    The first run starts ``delay`` seconds after :meth:`schedule`; each later
+    run starts ``period`` seconds after the previous run finished.
+    """
+
+    def __init__(self, action, delay, period):
+        """
+        :param action: Zero-argument coroutine function to run.
+        :param delay: Seconds to wait before the first run.
+        :param period: Seconds to wait after a run before starting the next.
+        """
+        self.action = action
+        self.delay = delay
+        self.period = period
+        #: The asyncio task driving the loop; None until schedule() is called.
+        self.task = None
+
+    async def start(self):
+        """Loop forever: run the action, log any failure, then sleep ``period``."""
+        await asyncio.sleep(self.delay)
+        while True:
+            try:
+                await self.action()
+            except asyncio.CancelledError:
+                # Let cancellation through (CancelledError is an Exception
+                # subclass before Python 3.8).
+                raise
+            except Exception as e:
+                # logger.error(msg, *args) treats extra args as %-format
+                # arguments, so put the exception into the message itself.
+                logger.error(f"Failed to execute scheduled task: {e}")
+            # Sleep outside ``finally`` so a cancellation raised during the
+            # action is not held up by a full period.
+            await asyncio.sleep(self.period)
+
+    def schedule(self):
+        """Start the loop as a task on the currently running event loop.
+
+        Must be called while an event loop is running in this thread (for
+        example from a coroutine); create_task always binds the task to that
+        running loop, so no new loop is created or installed here.
+        """
+        self.task = asyncio.create_task(self.start())
+
+    def cancel(self):
+        """Cancel the scheduled task if schedule() has been called."""
+        if self.task:
+            self.task.cancel()
+
+
class Client:
+ """
+ Main client class which handles interaction with the server.
+ """
def __init__(self, client_config: ClientConfig, topics: Set[str]):
+ """
+ Initialization method for the Client class.
+
+ :param client_config: Client configuration (endpoints, credentials, SSL
+ flag, request timeout).
+ :param topics: Set of topics whose routes are fetched by start().
+ """
self.client_config = client_config
self.client_id = ClientIdEncoder.generate()
self.endpoints = client_config.endpoints
self.topics = topics
+ #: A cache to store topic routes: topic name -> TopicRouteData.
self.topic_route_cache = {}
+ #: A table to store session information: endpoints -> Session.
self.sessions_table = {}
+ #: Guards sessions_table while get_session looks up or creates a session.
self.sessionsLock = threading.Lock()
self.client_manager = ClientManager(self)
+ #: A dictionary of isolated endpoints; heartbeat() removes an entry once
+ #: a heartbeat to that endpoint returns an OK status.
+ self.isolated = dict()
+
async def start(self):
+        """
+        Fetch the routes of all configured topics, then schedule the periodic
+        heartbeat and settings-sync tasks on the running event loop.
+
+        The schedulers are kept on the instance so their tasks stay strongly
+        referenced (the event loop only keeps weak references to tasks) and
+        so shutdown() can cancel them.
+        """
# get topic route
+        logger.debug(f"Begin to start the rocketmq client, client_id={self.client_id}")
for topic in self.topics:
self.topic_route_cache[topic] = await self.fetch_topic_route(topic)
+        # First run 3 seconds after scheduling, then 12 seconds after each run ends.
+        self.heartbeat_scheduler = ScheduleWithFixedDelay(self.heartbeat, 3, 12)
+        self.sync_settings_scheduler = ScheduleWithFixedDelay(self.sync_settings, 3, 12)
+        self.heartbeat_scheduler.schedule()
+        self.sync_settings_scheduler.schedule()
+        logger.debug(f"Start the rocketmq client successfully, client_id={self.client_id}")
+
+    async def shutdown(self):
+        """
+        Cancel the periodic heartbeat and settings-sync tasks scheduled by
+        start(), if start() has run.
+        """
+        logger.debug(f"Begin to shutdown rocketmq client, client_id={self.client_id}")
- def GetTotalRouteEndpoints(self):
+        # getattr: shutdown() may be called without a prior start().
+        for name in ("heartbeat_scheduler", "sync_settings_scheduler"):
+            scheduler = getattr(self, name, None)
+            if scheduler is not None:
+                scheduler.cancel()
+        logger.debug(f"Shutdown the rocketmq client successfully, client_id={self.client_id}")
+
+    async def heartbeat(self):
+        """
+        Send a heartbeat to every endpoint found in the cached topic routes.
+
+        An endpoint that answers with an OK status is removed from
+        ``self.isolated``. A failure for one endpoint is logged and does not
+        stop heartbeats to the remaining endpoints.
+        """
+        try:
+            endpoints = self.get_total_route_endpoints()
+            request = HeartbeatRequest()
+            # NOTE(review): always reports PRODUCER although Client is a
+            # generic base class — confirm once other client types exist.
+            request.client_type = definition_pb2.PRODUCER
+            for item in endpoints:
+                try:
+                    response = await self.client_manager.heartbeat(item, request, self.client_config.request_timeout)
+                    code = response.status.code
+                    if code == ProtoCode.OK:
+                        logger.info(f"Send heartbeat successfully, endpoints={item}, client_id={self.client_id}")
+                        if item in self.isolated:
+                            self.isolated.pop(item)
+                            logger.info(f"Rejoin endpoints which was isolated before, endpoints={item}, "
+                                        + f"client_id={self.client_id}")
+                        # Continue with the next endpoint; returning here would
+                        # skip the heartbeat for every remaining endpoint.
+                        continue
+                    status_message = response.status.message
+                    logger.info(f"Failed to send heartbeat, endpoints={item}, code={code}, "
+                                + f"status_message={status_message}, client_id={self.client_id}")
+                except Exception as e:
+                    logger.error(f"Failed to send heartbeat, endpoints={item}, client_id={self.client_id}, "
+                                 + f"Exception: {str(e)}")
+        except Exception as e:
+            logger.error(f"[Bug] unexpected exception raised during heartbeat, client_id={self.client_id}, Exception: {str(e)}")
+
+ def get_total_route_endpoints(self):
+ """
+ Return the set of broker endpoints referenced by the message queues of
+ every route currently held in topic_route_cache.
+ """
endpoints = set()
for item in self.topic_route_cache.items():
for endpoint in [mq.broker.endpoints for mq in
item[1].message_queues]:
endpoints.add(endpoint)
return endpoints
+ async def get_route_data(self, topic):
+ """
+ Return the route data of a topic, fetching it on a cache miss.
+
+ A cached route is returned as-is; otherwise fetch_topic_route is awaited,
+ which also stores the fetched route in topic_route_cache.
+
+ :param topic: The topic to fetch route data for.
+ :return: The TopicRouteData of the topic.
+ """
+ if topic in self.topic_route_cache:
+ return self.topic_route_cache[topic]
+ topic_route_data = await self.fetch_topic_route(topic=topic)
+ return topic_route_data
+
def get_client_config(self):
+ """
+ Return the ClientConfig this client was constructed with.
+ """
return self.client_config
- async def OnTopicRouteDataFetched(self, topic, topicRouteData):
+    async def sync_settings(self):
+        """
+        Sync the client settings over the session of every route endpoint.
+
+        A failure for one endpoint is logged and does not prevent syncing to
+        the remaining endpoints.
+        """
+        for endpoints in self.get_total_route_endpoints():
+            try:
+                # get_session creates the session on first use; whether it was
+                # newly created does not matter here.
+                _, session = await self.get_session(endpoints)
+                await session.sync_settings(True)
+                logger.info(f"Sync settings to remote, endpoints={endpoints}")
+            except Exception as e:
+                logger.error(f"Failed to sync settings to remote, endpoints={endpoints}, "
+                             + f"client_id={self.client_id}, Exception: {str(e)}")
+
+    def stats(self):
+        """Placeholder for client statistics; currently does nothing."""
+        # TODO: stats implement
+        pass
+
+    async def notify_client_termination(self):
+        """Placeholder; currently does nothing."""
+        pass
+
+    async def on_recover_orphaned_transaction_command(self, endpoints, command):
+        """Handle a recover-orphaned-transaction command; a no-op here.
+
+        NOTE(review): presumably meant to be overridden by clients that use
+        transactions — confirm.
+        """
+        pass
+
+    async def on_verify_message_command(self, endpoints, command):
+        """Log and ignore a verify-message command, which this client does not expect."""
+        # ``warn`` is a deprecated alias of ``warning`` on logging.Logger.
+        logger.warning(f"Ignore verify message command from remote, which is not expected, clientId={self.client_id}, "
+                       + f"endpoints={endpoints}, command={command}")
+
+    async def on_print_thread_stack_trace_command(self, endpoints, command):
+        """Handle a print-thread-stack-trace command; a no-op here."""
+        pass
+
+    async def on_settings_command(self, endpoints, settings):
+        """Handle a settings command; a no-op here."""
+        pass
+
+ async def on_topic_route_data_fetched(self, topic, topic_route_data):
+ """
+ Open sessions for endpoints that first appear in a fetched route, then
+ cache the route.
+
+ For every broker endpoint of the route that no cached route references
+ yet, get_session is awaited; when it created a new session, the client
+ settings are synced over that session.
+
+ :param topic: The topic for which the route data is fetched.
+ :param topic_route_data: The fetched topic route data.
+ """
route_endpoints = set()
- for mq in topicRouteData.message_queues:
+ for mq in topic_route_data.message_queues:
route_endpoints.add(mq.broker.endpoints)
- existed_route_endpoints = self.GetTotalRouteEndpoints()
+ existed_route_endpoints = self.get_total_route_endpoints()
new_endpoints = route_endpoints.difference(existed_route_endpoints)
for endpoints in new_endpoints:
- created, session = await self.GetSession(endpoints)
+ created, session = await self.get_session(endpoints)
+ # Only sessions created by this call are synced here.
if not created:
continue
-
+ logger.info(f"Begin to establish session for
endpoints={endpoints}, client_id={self.client_id}")
await session.sync_settings(True)
+ logger.info(f"Establish session for endpoints={endpoints}
successfully, client_id={self.client_id}")
- self.topic_route_cache[topic] = topicRouteData
- # self.OnTopicRouteDataUpdated0(topic, topicRouteData)
+ # Cached last, so the difference above was taken against the routes
+ # cached before this update.
+ self.topic_route_cache[topic] = topic_route_data
async def fetch_topic_route0(self, topic):
- req = QueryRouteRequest()
- req.topic.name = topic
- address = req.endpoints.addresses.add()
- address.host = self.endpoints.Addresses[0].host
- address.port = self.endpoints.Addresses[0].port
- req.endpoints.scheme =
self.endpoints.scheme.to_protobuf(self.endpoints.scheme)
- response = await self.client_manager.query_route(self.endpoints, req,
10)
-
- message_queues = response.message_queues
- return TopicRouteData(message_queues)
-
- # return topic data
+        """
+        Query the route of ``topic`` from ``self.endpoints``.
+
+        :param topic: The topic to fetch the route for.
+        :return: TopicRouteData built from the response's message queues.
+        :raises RuntimeError: If the server answers with a non-OK status.
+        """
+        try:
+            req = QueryRouteRequest()
+            req.topic.name = topic
+            address = req.endpoints.addresses.add()
+            address.host = self.endpoints.Addresses[0].host
+            address.port = self.endpoints.Addresses[0].port
+            req.endpoints.scheme = self.endpoints.scheme.to_protobuf(self.endpoints.scheme)
+            response = await self.client_manager.query_route(self.endpoints, req,
+                                                             self.client_config.request_timeout)
+            code = response.status.code
+            if code != ProtoCode.OK:
+                # Fail loudly: returning the (empty) route of an error response
+                # would get it cached and silently hide the problem.
+                raise RuntimeError(f"Failed to fetch topic route, client_id={self.client_id}, topic={topic}, "
+                                   + f"code={code}, statusMessage={response.status.message}")
+            return TopicRouteData(response.message_queues)
+        except Exception as e:
+            # logger.error(msg, *args) treats extra args as %-format arguments,
+            # so the exception goes into the message itself.
+            logger.error(f"Failed to fetch topic route, client_id={self.client_id}, topic={topic}, "
+                         + f"Exception: {str(e)}")
+            raise
+
async def fetch_topic_route(self, topic):
+ """
+ Fetch the route of a topic and pass it to on_topic_route_data_fetched,
+ which opens sessions for new endpoints and caches the route.
+
+ :param topic: The topic to fetch the route for.
+ :return: The fetched TopicRouteData.
+ """
topic_route_data = await self.fetch_topic_route0(topic)
- await self.OnTopicRouteDataFetched(topic, topic_route_data)
+ await self.on_topic_route_data_fetched(topic, topic_route_data)
+ logger.info(f"Fetch topic route successfully,
client_id={self.client_id}, topic={topic}, topicRouteData={topic_route_data}")
return topic_route_data
- async def GetSession(self, endpoints):
+ async def get_session(self, endpoints):
+ """
+ Asynchronous method that gets the session for a given endpoint.
+
+ :param endpoints: The endpoints to get the session for.
+ """
self.sessionsLock.acquire()
try:
# Session exists, return in advance.
@@ -105,21 +269,41 @@ class Client:
if endpoints in self.sessions_table:
return (False, self.sessions_table[endpoints])
- stream = self.client_manager.telemetry(endpoints, 10)
+ stream = self.client_manager.telemetry(endpoints, 10000000)
created = Session(endpoints, stream, self)
self.sessions_table[endpoints] = created
return (True, created)
finally:
self.sessionsLock.release()
+ def get_client_id(self):
+ """Return the id produced by ClientIdEncoder.generate() when this client was created."""
+ return self.client_id
+
class ClientManager:
+ """Manager class for RPC Clients in a thread-safe manner.
+ Each instance is created by a specific client and can manage
+ multiple RPC clients.
+ """
+
def __init__(self, client: Client):
+ """
+ :param client: The owning client; its client_config and client_id are
+ used when creating RPC clients and signing requests.
+ """
+ #: The client that instantiated this manager.
self.__client = client
+
+ #: A dictionary that maps endpoints to the corresponding RPC clients.
self.__rpc_clients = {}
+
+ #: Guards __rpc_clients while __get_rpc_client looks up or creates a client.
self.__rpc_clients_lock = threading.Lock()
def __get_rpc_client(self, endpoints: Endpoints, ssl_enabled: bool):
+ """Retrieve the RPC client corresponding to the given endpoints.
+ If not present, a new RPC client is created and stored in
__rpc_clients.
+
+ :param endpoints: The endpoints associated with the RPC client.
+ :param ssl_enabled: A flag indicating whether SSL is enabled.
+ :return: The RPC client associated with the given endpoints.
+ """
with self.__rpc_clients_lock:
rpc_client = self.__rpc_clients.get(endpoints)
if rpc_client:
@@ -134,10 +318,16 @@ class ClientManager:
request: service_pb2.QueryRouteRequest,
timeout_seconds: int,
):
+ """Query the routing information.
+
+ :param endpoints: The endpoints to query.
+ :param request: The request containing the details of the query.
+ :param timeout_seconds: The maximum time to wait for a response.
+ :return: The result of the query.
+ """
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
-
metadata = Signature.sign(self.__client.client_config,
self.__client.client_id)
return await rpc_client.query_route(request, metadata, timeout_seconds)
@@ -147,6 +337,13 @@ class ClientManager:
request: service_pb2.HeartbeatRequest,
timeout_seconds: int,
):
+ """Send a heartbeat to the server to indicate that the client is still
alive.
+
+ :param endpoints: The endpoints to send the heartbeat to.
+ :param request: The request containing the details of the heartbeat.
+ :param timeout_seconds: The maximum time to wait for a response.
+ :return: The result of the heartbeat.
+ """
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
@@ -159,6 +356,13 @@ class ClientManager:
request: service_pb2.SendMessageRequest,
timeout_seconds: int,
):
+ """Send a message to the server.
+
+ :param endpoints: The endpoints to send the message to.
+ :param request: The request containing the details of the message.
+ :param timeout_seconds: The maximum time to wait for a response.
+ :return: The result of the message sending operation.
+ """
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
@@ -171,6 +375,13 @@ class ClientManager:
request: service_pb2.QueryAssignmentRequest,
timeout_seconds: int,
):
+ """Query the assignment information.
+
+ :param endpoints: The endpoints to query.
+ :param request: The request containing the details of the query.
+ :param timeout_seconds: The maximum time to wait for a response.
+ :return: The result of the query.
+ """
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
@@ -183,6 +394,13 @@ class ClientManager:
request: service_pb2.AckMessageRequest,
timeout_seconds: int,
):
+ """Send an acknowledgment for a message to the server.
+
+ :param endpoints: The endpoints to send the acknowledgment to.
+ :param request: The request containing the details of the
acknowledgment.
+ :param timeout_seconds: The maximum time to wait for a response.
+ :return: The result of the acknowledgment.
+ """
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
@@ -195,6 +413,13 @@ class ClientManager:
request: service_pb2.ForwardMessageToDeadLetterQueueRequest,
timeout_seconds: int,
):
+ """Forward a message to the dead letter queue.
+
+ :param endpoints: The endpoints to send the request to.
+ :param request: The request containing the details of the message to
forward.
+ :param timeout_seconds: The maximum time to wait for a response.
+ :return: The result of the forward operation.
+ """
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
@@ -209,6 +434,13 @@ class ClientManager:
request: service_pb2.EndTransactionRequest,
timeout_seconds: int,
):
+ """Ends a transaction.
+
+ :param endpoints: The endpoints to send the request to.
+ :param request: The request to end the transaction.
+ :param timeout_seconds: The maximum time to wait for a response.
+ :return: The result of the end transaction operation.
+ """
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
@@ -221,6 +453,13 @@ class ClientManager:
request: service_pb2.NotifyClientTerminationRequest,
timeout_seconds: int,
):
+ """Notify server about client termination.
+
+ :param endpoints: The endpoints to send the notification to.
+ :param request: The request containing the details of the termination.
+ :param timeout_seconds: The maximum time to wait for a response.
+ :return: The result of the notification operation.
+ """
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
@@ -235,6 +474,13 @@ class ClientManager:
request: service_pb2.ChangeInvisibleDurationRequest,
timeout_seconds: int,
):
+ """Change the invisible duration of a message.
+
+ :param endpoints: The endpoints to send the request to.
+ :param request: The request containing the new invisible duration.
+ :param timeout_seconds: The maximum time to wait for a response.
+ :return: The result of the change operation.
+ """
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
@@ -248,6 +494,12 @@ class ClientManager:
endpoints: Endpoints,
timeout_seconds: int,
):
+ """Fetch telemetry information.
+
+ :param endpoints: The endpoints to send the request to.
+ :param timeout_seconds: The maximum time to wait for a response.
+ :return: The telemetry information.
+ """
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
diff --git a/python/rocketmq/client_config.py b/python/rocketmq/client_config.py
index 41e691c4..1ccd5e0b 100644
--- a/python/rocketmq/client_config.py
+++ b/python/rocketmq/client_config.py
@@ -18,25 +18,49 @@ from rocketmq.session_credentials import
SessionCredentialsProvider
class ClientConfig:
+ """Client configuration: endpoints, session credentials provider, SSL flag
+ and request timeout.
+
+ The first three are fixed at construction and exposed as read-only
+ properties; request_timeout is a plain, writable attribute.
+ """
+
def __init__(
self,
endpoints: Endpoints,
session_credentials_provider: SessionCredentialsProvider,
ssl_enabled: bool,
):
+ #: The endpoints for the client to connect to.
self.__endpoints = endpoints
+
+ #: The session credentials provider to authenticate the client.
self.__session_credentials_provider = session_credentials_provider
+
+ #: A flag indicating if SSL is enabled for the client.
self.__ssl_enabled = ssl_enabled
+
+ #: The request timeout for the client in seconds (defaults to 10).
self.request_timeout = 10
@property
- def session_credentials_provider(self):
+ def session_credentials_provider(self) -> SessionCredentialsProvider:
+ """The session credentials provider for the client.
+
+ :return: the session credentials provider
+ """
return self.__session_credentials_provider
@property
- def endpoints(self):
+ def endpoints(self) -> Endpoints:
+ """The endpoints for the client to connect to.
+
+ :return: the endpoints
+ """
return self.__endpoints
@property
- def ssl_enabled(self):
+ def ssl_enabled(self) -> bool:
+ """A flag indicating if SSL is enabled for the client.
+
+ :return: True if SSL is enabled, False otherwise
+ """
return self.__ssl_enabled
diff --git a/python/rocketmq/client_id_encoder.py
b/python/rocketmq/client_id_encoder.py
index 138b05f0..b6f3ca3f 100644
--- a/python/rocketmq/client_id_encoder.py
+++ b/python/rocketmq/client_id_encoder.py
@@ -22,12 +22,25 @@ import rocketmq.utils
class ClientIdEncoder:
+ """This class generates a unique client ID for each client based on
+ hostname, process id, index and the monotonic clock time.
+ """
+
+ #: The current index for client id generation.
__INDEX = 0
+
+ #: The lock used for thread-safe incrementing of the index.
__INDEX_LOCK = threading.Lock()
+
+ #: The separator used in the client id string.
__CLIENT_ID_SEPARATOR = "@"
@staticmethod
- def __get_and_increment_sequence():
+ def __get_and_increment_sequence() -> int:
+ """Increment and return the current index in a thread-safe manner.
+
+ :return: the current index after incrementing it.
+ """
with ClientIdEncoder.__INDEX_LOCK:
temp = ClientIdEncoder.__INDEX
ClientIdEncoder.__INDEX += 1
@@ -35,6 +48,10 @@ class ClientIdEncoder:
@staticmethod
def generate() -> str:
+ """Generate a unique client ID.
+
+ :return: the generated client id
+ """
index = ClientIdEncoder.__get_and_increment_sequence()
return (
socket.gethostname()
diff --git a/python/rocketmq/definition.py b/python/rocketmq/definition.py
index b2115b60..3d63748c 100644
--- a/python/rocketmq/definition.py
+++ b/python/rocketmq/definition.py
@@ -17,6 +17,7 @@ from enum import Enum
from typing import List
from protocol.definition_pb2 import Broker as ProtoBroker
+from protocol.definition_pb2 import Encoding as ProtoEncoding
from protocol.definition_pb2 import MessageQueue as ProtoMessageQueue
from protocol.definition_pb2 import MessageType as ProtoMessageType
from protocol.definition_pb2 import Permission as ProtoPermission
@@ -25,20 +26,55 @@ from rocketmq.protocol import definition_pb2
from rocketmq.rpc_client import Endpoints
+class Encoding(Enum):
+ """Enumeration of supported encoding types.
+
+ EncodingHelper.to_protobuf maps each member to the protobuf Encoding enum.
+ """
+ IDENTITY = 0
+ GZIP = 1
+
+
+class EncodingHelper:
+ """Helper class for converting Encoding members to the protobuf Encoding enum."""
+
+ @staticmethod
+ def to_protobuf(mq_encoding):
+ """Convert an Encoding member to the protobuf Encoding enum value.
+
+ :param mq_encoding: The Encoding member to be converted.
+ :return: ProtoEncoding.IDENTITY or ProtoEncoding.GZIP; None (implicit
+ fall-through) for any other input.
+ """
+ if mq_encoding == Encoding.IDENTITY:
+ return ProtoEncoding.IDENTITY
+ elif mq_encoding == Encoding.GZIP:
+ return ProtoEncoding.GZIP
+
+
class Broker:
+ """Represent a broker entity."""
+
def __init__(self, broker):
+ """
+ :param broker: Source broker object; its name, id and endpoints are
+ read, and endpoints is wrapped in Endpoints.
+ """
self.name = broker.name
self.id = broker.id
self.endpoints = Endpoints(broker.endpoints)
def to_protobuf(self):
+ """Convert the broker to its protobuf representation.
+
+ :return: The protobuf representation of the broker.
+ """
+ # NOTE(review): generated Python protobuf constructors take the .proto
+ # field names as keywords; __init__ reads lowercase name/id/endpoints,
+ # so Name=/Id=/Endpoints= likely raise ValueError — confirm.
return ProtoBroker(
Name=self.name, Id=self.id, Endpoints=self.endpoints.to_protobuf()
)
class Resource:
+ """Represent a resource entity."""
+
def __init__(self, name=None, resource=None):
+ """Initialize a resource.
+
+ :param name: The name of the resource.
+ :param resource: The resource object.
+ """
if resource is not None:
self.namespace = resource.ResourceNamespace
self.name = resource.Name
@@ -47,6 +83,10 @@ class Resource:
self.name = name
def to_protobuf(self):
+ """Convert the resource to its protobuf representation.
+
+ :return: The protobuf representation of the resource.
+ """
return ProtoResource(ResourceNamespace=self.namespace, Name=self.name)
def __str__(self):
@@ -54,6 +94,7 @@ class Resource:
class Permission(Enum):
+ """Enumeration of supported permission types."""
NONE = 0
READ = 1
WRITE = 2
@@ -61,8 +102,15 @@ class Permission(Enum):
class PermissionHelper:
+ """Helper class for converting permission types to protobuf and vice
versa."""
+
@staticmethod
def from_protobuf(permission):
+ """Convert protobuf permission to Permission enum.
+
+ :param permission: The protobuf permission to be converted.
+ :return: The corresponding Permission enum.
+ """
if permission == ProtoPermission.READ:
return Permission.READ
elif permission == ProtoPermission.WRITE:
@@ -76,6 +124,11 @@ class PermissionHelper:
@staticmethod
def to_protobuf(permission):
+ """Convert Permission enum to protobuf permission.
+
+ :param permission: The Permission enum to be converted.
+ :return: The corresponding protobuf permission.
+ """
if permission == Permission.READ:
return ProtoPermission.READ
elif permission == Permission.WRITE:
@@ -87,6 +140,11 @@ class PermissionHelper:
@staticmethod
def is_writable(permission):
+ """Check if the permission is writable.
+
+ :param permission: The Permission enum to be checked.
+ :return: True if the permission is writable, False otherwise.
+ """
if permission in [Permission.WRITE, Permission.READ_WRITE]:
return True
else:
@@ -94,6 +152,11 @@ class PermissionHelper:
@staticmethod
def is_readable(permission):
+ """Check if the permission is readable.
+
+ :param permission: The Permission enum to be checked.
+ :return: True if the permission is readable, False otherwise.
+ """
if permission in [Permission.READ, Permission.READ_WRITE]:
return True
else:
@@ -101,6 +164,7 @@ class PermissionHelper:
class MessageType(Enum):
+ """Enumeration of supported message types."""
NORMAL = 0
FIFO = 1
DELAY = 2
@@ -108,8 +172,15 @@ class MessageType(Enum):
class MessageTypeHelper:
+ """Helper class for converting message types to protobuf and vice versa."""
+
@staticmethod
def from_protobuf(message_type):
+ """Convert protobuf message type to MessageType enum.
+
+ :param message_type: The protobuf message type to be converted.
+ :return: The corresponding MessageType enum.
+ """
if message_type == ProtoMessageType.NORMAL:
return MessageType.NORMAL
elif message_type == ProtoMessageType.FIFO:
@@ -123,6 +194,11 @@ class MessageTypeHelper:
@staticmethod
def to_protobuf(message_type):
+ """Convert MessageType enum to protobuf message type.
+
+ :param message_type: The MessageType enum to be converted.
+ :return: The corresponding protobuf message type.
+ """
if message_type == MessageType.NORMAL:
return ProtoMessageType.NORMAL
elif message_type == MessageType.FIFO:
@@ -136,7 +212,13 @@ class MessageTypeHelper:
class MessageQueue:
+ """A class that encapsulates a message queue entity."""
+
def __init__(self, message_queue):
+ """Initialize a MessageQueue instance.
+
+ :param message_queue: The initial message queue to be encapsulated.
+ """
self._topic_resource = Resource(message_queue.topic)
self.queue_id = message_queue.id
self.permission =
PermissionHelper.from_protobuf(message_queue.permission)
@@ -148,12 +230,24 @@ class MessageQueue:
@property
def topic(self):
+ """The topic resource name.
+
+ :return: The name of the topic resource.
+ """
return self._topic_resource.name
def __str__(self):
+ """Get a string representation of the MessageQueue instance.
+
+ :return: A string that represents the MessageQueue instance.
+ """
return f"{self.broker.name}.{self._topic_resource}.{self.queue_id}"
def to_protobuf(self):
+ """Convert the MessageQueue instance to protobuf message queue.
+
+ :return: A protobuf message queue that represents the MessageQueue
instance.
+ """
message_types = [
MessageTypeHelper.to_protobuf(mt) for mt in
self.accept_message_types
]
@@ -167,7 +261,13 @@ class MessageQueue:
class TopicRouteData:
+ """A class that encapsulates a list of message queues."""
+
def __init__(self, message_queues: List[definition_pb2.MessageQueue]):
+ """Initialize a TopicRouteData instance.
+
+ :param message_queues: The initial list of message queues to be
encapsulated.
+ """
message_queue_list = []
for mq in message_queues:
message_queue_list.append(MessageQueue(mq))
@@ -175,4 +275,8 @@ class TopicRouteData:
@property
def message_queues(self) -> List[MessageQueue]:
+ """The list of MessageQueue instances.
+
+ :return: The list of MessageQueue instances that the TopicRouteData
instance encapsulates.
+ """
return self.__message_queue_list
diff --git a/python/rocketmq/exponential_backoff_retry_policy.py
b/python/rocketmq/exponential_backoff_retry_policy.py
new file mode 100644
index 00000000..4dc87cd8
--- /dev/null
+++ b/python/rocketmq/exponential_backoff_retry_policy.py
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from datetime import timedelta
+
+from google.protobuf.duration_pb2 import Duration
+
+
+class ExponentialBackoffRetryPolicy:
+    """A retry policy whose delay grows exponentially with each attempt.
+
+    The delay before attempt ``n`` is
+    ``min(initial_backoff * backoff_multiplier ** (n - 1), max_backoff)``,
+    clamped to zero when negative.
+    """
+
+    def __init__(self, max_attempts, initial_backoff, max_backoff, backoff_multiplier):
+        """Initialize an ExponentialBackoffRetryPolicy instance.
+
+        :param max_attempts: Maximum number of attempts.
+        :param initial_backoff: ``timedelta`` delay before the first retry.
+        :param max_backoff: ``timedelta`` upper bound for any delay.
+        :param backoff_multiplier: Factor applied to the delay per attempt.
+        """
+        self._max_attempts = max_attempts
+        self.initial_backoff = initial_backoff
+        self.max_backoff = max_backoff
+        self.backoff_multiplier = backoff_multiplier
+
+    def get_max_attempts(self):
+        """Get maximum number of attempts.
+
+        :return: Maximum number of attempts.
+        """
+        return self._max_attempts
+
+    def inherit_backoff(self, retry_policy):
+        """Return a new policy using the backoff settings of ``retry_policy``.
+
+        The maximum number of attempts is kept from this policy.
+
+        :param retry_policy: A protobuf ``RetryPolicy`` message.
+        :return: An instance of ExponentialBackoffRetryPolicy with inherited parameters.
+        :raise ValueError: If ``retry_policy`` does not carry an exponential backoff.
+        """
+        # Python protobuf messages have no ``strategy_case`` attribute (that is
+        # the C# API); HasField tells whether the oneof member is the one set.
+        if not retry_policy.HasField("exponential_backoff"):
+            raise ValueError("Strategy must be exponential backoff")
+        return self._inherit_backoff(retry_policy.exponential_backoff)
+
+    def _inherit_backoff(self, retry_policy):
+        """Build a policy from a protobuf ``ExponentialBackoff`` message.
+
+        :param retry_policy: Message providing ``initial``/``max`` Durations and ``multiplier``.
+        :return: A new ExponentialBackoffRetryPolicy keeping this policy's max attempts.
+        """
+        return ExponentialBackoffRetryPolicy(self._max_attempts,
+                                             retry_policy.initial.ToTimedelta(),
+                                             retry_policy.max.ToTimedelta(),
+                                             retry_policy.multiplier)
+
+    def get_next_attempt_delay(self, attempt):
+        """Calculate the delay before the given attempt.
+
+        :param attempt: The number of the current attempt (1-based).
+        :return: The delay as a non-negative ``timedelta``.
+        """
+        delay_seconds = min(
+            self.initial_backoff.total_seconds() * math.pow(self.backoff_multiplier, 1.0 * (attempt - 1)),
+            self.max_backoff.total_seconds())
+        return timedelta(seconds=delay_seconds) if delay_seconds >= 0 else timedelta(seconds=0)
+
+    @staticmethod
+    def immediately_retry_policy(max_attempts):
+        """Create a retry policy that retries without any delay.
+
+        :param max_attempts: Maximum number of attempts.
+        :return: An instance of ExponentialBackoffRetryPolicy with no delay between retries.
+        """
+        return ExponentialBackoffRetryPolicy(max_attempts, timedelta(seconds=0), timedelta(seconds=0), 1)
+
+    @staticmethod
+    def _to_duration(value):
+        """Convert a ``timedelta`` to a protobuf ``Duration``."""
+        # FromTimedelta is an instance method that fills in ``self``; calling
+        # it on the class would pass the timedelta as ``self`` and fail.
+        duration = Duration()
+        duration.FromTimedelta(value)
+        return duration
+
+    def to_protobuf(self):
+        """Describe this policy as a dictionary.
+
+        NOTE(review): the keys are PascalCase and do not match the Python
+        protobuf field names (max_attempts, exponential_backoff, ...); kept
+        unchanged because callers are not visible here — confirm how it is used.
+
+        :return: A dictionary describing the ExponentialBackoffRetryPolicy instance.
+        """
+        exponential_backoff = {
+            'Multiplier': self.backoff_multiplier,
+            'Max': self._to_duration(self.max_backoff),
+            'Initial': self._to_duration(self.initial_backoff)
+        }
+        return {
+            'MaxAttempts': self._max_attempts,
+            'ExponentialBackoff': exponential_backoff
+        }
diff --git a/python/rocketmq/producer.py b/python/rocketmq/producer.py
index 3d604a25..9e10a3db 100644
--- a/python/rocketmq/producer.py
+++ b/python/rocketmq/producer.py
@@ -15,31 +15,103 @@
import asyncio
import threading
+import time
+# from status_checker import StatusChecker
+from datetime import datetime, timedelta
+from threading import RLock
from typing import Set
+from unittest.mock import MagicMock, patch
import rocketmq
+from publishing_message import MessageType
from rocketmq.client import Client
from rocketmq.client_config import ClientConfig
-from rocketmq.definition import TopicRouteData
+from rocketmq.definition import PermissionHelper, TopicRouteData
+from rocketmq.exponential_backoff_retry_policy import \
+ ExponentialBackoffRetryPolicy
from rocketmq.log import logger
+from rocketmq.message import Message
from rocketmq.message_id_codec import MessageIdCodec
from rocketmq.protocol.definition_pb2 import Message as ProtoMessage
-from rocketmq.protocol.definition_pb2 import Resource, SystemProperties
-from rocketmq.protocol.service_pb2 import SendMessageRequest
+from rocketmq.protocol.definition_pb2 import Resource
+from rocketmq.protocol.definition_pb2 import Resource as ProtoResource
+from rocketmq.protocol.definition_pb2 import SystemProperties
+from rocketmq.protocol.definition_pb2 import \
+ TransactionResolution as ProtoTransactionResolution
+from rocketmq.protocol.service_pb2 import (EndTransactionRequest,
+ SendMessageRequest)
from rocketmq.publish_settings import PublishingSettings
+from rocketmq.publishing_message import PublishingMessage
from rocketmq.rpc_client import Endpoints
+from rocketmq.send_receipt import SendReceipt
from rocketmq.session_credentials import (SessionCredentials,
SessionCredentialsProvider)
+from status_checker import TooManyRequestsException
+from utils import get_positive_mod
+
+
+class Transaction:
+    """Client-side bookkeeping for one transaction.
+
+    Records the PublishingMessage created for each message and the
+    SendReceipt returned for it. On commit/rollback it calls
+    ``producer.end_transaction`` once per recorded receipt.
+    """
+
+    #: Maximum number of messages a single transaction may hold.
+    MAX_MESSAGE_NUM = 1
+
+    def __init__(self, producer):
+        self.producer = producer
+        #: PublishingMessage objects added to this transaction.
+        self.messages = set()
+        #: Guards ``messages`` and ``message_send_receipt_dict``.
+        self.messages_lock = RLock()
+        #: Maps each sent PublishingMessage to its SendReceipt.
+        self.message_send_receipt_dict = {}
+
+    def try_add_message(self, message):
+        """Wrap *message* as a transactional PublishingMessage and record it.
+
+        :param message: The message to add.
+        :return: The PublishingMessage that must be sent for this transaction.
+        :raises ValueError: if the transaction already holds MAX_MESSAGE_NUM messages.
+        """
+        with self.messages_lock:
+            # `>=`, not `>`: with `>` the set could grow to MAX_MESSAGE_NUM + 1.
+            if len(self.messages) >= self.MAX_MESSAGE_NUM:
+                raise ValueError(f"Message in transaction has exceed the threshold: {self.MAX_MESSAGE_NUM}")
+
+            publishing_message = PublishingMessage(message, self.producer.publish_settings, True)
+            self.messages.add(publishing_message)
+            return publishing_message
+
+    def try_add_receipt(self, publishing_message, send_receipt):
+        """Record the receipt of a message previously returned by try_add_message.
+
+        :raises ValueError: if *publishing_message* is not part of this transaction.
+        """
+        with self.messages_lock:
+            if publishing_message not in self.messages:
+                raise ValueError("Message is not in the transaction")
+
+            self.message_send_receipt_dict[publishing_message] = send_receipt
+
+    async def commit(self):
+        """Commit every sent message of this transaction.
+
+        :raises ValueError: if no message has been sent yet.
+        """
+        await self._end("Commit")
+
+    async def rollback(self):
+        """Roll back every sent message of this transaction.
+
+        :raises ValueError: if no message has been sent yet.
+        """
+        await self._end("Rollback")
+
+    async def _end(self, resolution):
+        """Call producer.end_transaction with *resolution* for each recorded receipt."""
+        # TODO: reject the call when the producer is not running.
+        with self.messages_lock:
+            # Snapshot under the lock so the lock is never held across an await.
+            entries = list(self.message_send_receipt_dict.items())
+        if not entries:
+            raise ValueError("Transactional message has not been sent yet")
+
+        for publishing_message, send_receipt in entries:
+            await self.producer.end_transaction(
+                send_receipt.endpoints, publishing_message.message.topic,
+                send_receipt.message_id, send_receipt.transaction_id, resolution)
class PublishingLoadBalancer:
+ """This class serves as a load balancer for message publishing.
+ It keeps track of a rotating index to help distribute the load evenly.
+ """
+
def __init__(self, topic_route_data: TopicRouteData, index: int = 0):
+ #: current index for message queue selection
self.__index = index
+ #: thread lock to ensure atomic update to the index
self.__index_lock = threading.Lock()
+
+ #: filter the message queues which are writable and from the master
broker
message_queues = []
for mq in topic_route_data.message_queues:
if (
- not mq.permission.is_writable()
+ not PermissionHelper().is_writable(mq.permission)
or mq.broker.id is not rocketmq.utils.master_broker_id
):
continue
@@ -48,15 +120,21 @@ class PublishingLoadBalancer:
@property
def index(self):
+        """Return the rotation index that the next get_and_increment_index() call will hand out."""
return self.__index
def get_and_increment_index(self):
+        """Return the current rotation index and advance it by one.
+
+        The read and the increment both happen while holding ``__index_lock``,
+        so concurrent callers each receive a distinct value.
+        """
with self.__index_lock:
temp = self.__index
self.__index += 1
return temp
def take_message_queues(self, excluded: Set[Endpoints], count: int):
+ """Fetch a specified number of message queues, excluding the ones
provided.
+ It will first try to fetch from non-excluded brokers and if
insufficient,
+ it will select from the excluded ones.
+ """
next_index = self.get_and_increment_index()
candidates = []
candidate_broker_name = set()
@@ -85,41 +163,309 @@ class PublishingLoadBalancer:
return candidates
return candidates
+    def take_message_queue_by_message_group(self, message_group):
+        """Return the queue that all messages of *message_group* are sent to.
+
+        The group is hashed with CRC-32 over its UTF-8 bytes instead of the
+        built-in ``hash()``. String hashing is salted per interpreter process
+        (PYTHONHASHSEED), so ``hash()`` would route the same group to
+        different queues after a restart or from another producer process.
+        That breaks FIFO ordering. Assumes *message_group* is a ``str``.
+
+        :raises ValueError: if the balancer holds no message queue.
+        """
+        import zlib  # local import: only this method needs a stable hash
+
+        queues = self.__message_queues
+        if not queues:
+            raise ValueError("No message queue available for message group")
+        stable_hash = zlib.crc32(message_group.encode("utf-8"))
+        return queues[get_positive_mod(stable_hash, len(queues))]
+
class Producer(Client):
+    """Publish messages to RocketMQ topics.
+
+    Extends Client with per-topic queue selection (PublishingLoadBalancer),
+    retries with endpoint isolation, and transactional sends.
+    """
+
def __init__(self, client_config: ClientConfig, topics: Set[str]):
+        """Create a new Producer.
+
+        :param client_config: The configuration for the client.
+        :param topics: The set of topics to which the producer can send messages.
+        """
super().__init__(client_config, topics)
+        # Up to 10 attempts with no back-off between them.
+        retry_policy = ExponentialBackoffRetryPolicy.immediately_retry_policy(10)
+        #: Publishing settings: client id, endpoints, retry policy, request timeout, topics.
self.publish_settings = PublishingSettings(
- self.client_id, self.endpoints, None, 10, topics
+            self.client_id, self.endpoints, retry_policy, 10, topics
)
+        #: Per-topic cache of PublishingLoadBalancer instances.
+        self.publish_routedata_cache = {}
async def __aenter__(self):
+        """Start the producer on entering ``async with`` and return it."""
await self.start()
+        # Return the producer so ``async with Producer(...) as producer`` binds
+        # the producer rather than None.
+        return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Shut the producer down on leaving ``async with``."""
await self.shutdown()
async def start(self):
+        """Start the RocketMQ producer and log the operation."""
logger.info(f"Begin to start the rocketmq producer, client_id={self.client_id}")
await super().start()
logger.info(f"The rocketmq producer starts successfully, client_id={self.client_id}")
async def shutdown(self):
+        """Shut down the RocketMQ producer and log the operation."""
logger.info(f"Begin to shutdown the rocketmq producer, client_id={self.client_id}")
+        await super().shutdown()
logger.info(f"Shutdown the rocketmq producer successfully, client_id={self.client_id}")
- async def send_message(self, message):
+    @staticmethod
+    def wrap_send_message_request(message, message_queue):
+        """Wrap *message* into a SendMessageRequest for *message_queue*.
+
+        :param message: The PublishingMessage to be sent.
+        :param message_queue: The target queue; its ``queue_id`` is written into
+            the protobuf message.
+        :return: The SendMessageRequest carrying the message.
+        """
req = SendMessageRequest()
- req.messages.extend([message])
- topic_data = self.topic_route_cache["normal_topic"]
- endpoints = topic_data.message_queues[2].broker.endpoints
- return await self.client_manager.send_message(endpoints, req, 10)
+        req.messages.extend([message.to_protobuf(message_queue.queue_id)])
+        return req
+
+    async def send(self, message, transaction: Transaction = None):
+        """Send *message*, optionally as part of *transaction*.
+
+        :param message: The message to be sent.
+        :param transaction: A transaction from begin_transaction(), or None for
+            a non-transactional send.
+        :return: The SendReceipt of the sent message.
+        """
+        if transaction is None:
+            return await self.send_message(message)
+        logger.debug("Transaction send")
+        publishing_message = transaction.try_add_message(message)
+        # Send the PublishingMessage recorded in the transaction. Building a
+        # second one would assign a different message id than the recorded one.
+        send_receipt = await self.send_message(message, True, publishing_message)
+        transaction.try_add_receipt(publishing_message, send_receipt)
+        return send_receipt
+
+    async def send_message(self, message, tx_enabled=False, publishing_message=None):
+        """Send a message, retrying according to the producer's retry policy.
+
+        How the target queue is chosen:
+
+        - Messages with a non-empty ``message_group`` (FIFO) always go to the
+          queue picked for their group.
+        - Other messages rotate over queues from the load balancer, with
+          isolated endpoints passed as the exclusion set.
+
+        :param message: The message to be sent.
+        :param tx_enabled: Whether the message is transactional.
+        :param publishing_message: A PublishingMessage already built for
+            *message* (used by transactions); built here when omitted.
+        :return: The SendReceipt of the successful attempt.
+        :raises ValueError: if no queue is available, or the message type is not
+            accepted by the selected queue.
+        :raises Exception: the last attempt's error once attempts are exhausted,
+            or the first error for a transactional message.
+        """
+        publish_load_balancer = await self.get_publish_load_balancer(message.topic)
+        if publishing_message is None:
+            publishing_message = PublishingMessage(message, self.publish_settings, tx_enabled)
+        retry_policy = self.get_retry_policy()
+        max_attempts = retry_policy.get_max_attempts()
+
+        # Same truthiness test PublishingMessage uses to classify FIFO messages.
+        message_group = publishing_message.message.message_group
+        if message_group:
+            candidates = [publish_load_balancer.take_message_queue_by_message_group(message_group)]
+        else:
+            candidates = publish_load_balancer.take_message_queues(set(self.isolated.keys()), max_attempts)
+        if not candidates:
+            raise ValueError(f"No message queue available to send message, topic={message.topic}")
+
+        for attempt in range(1, max_attempts + 1):
+            start_time = time.monotonic()
+            mq = candidates[(attempt - 1) % len(candidates)]
+            accept_message_types = mq.accept_message_types
+            # Accept the message if ANY of the queue's accepted types matches,
+            # not only the first one.
+            if self.publish_settings.is_validate_message_type() and all(
+                    accepted.value != publishing_message.message_type.value
+                    for accepted in accept_message_types):
+                raise ValueError(
+                    "Current message type does not match with the accept message types,"
+                    + f" topic={message.topic}, actualMessageType={publishing_message.message_type}"
+                    + f" acceptMessageType={','.join(str(accepted) for accepted in accept_message_types)}")
+
+            send_message_request = self.wrap_send_message_request(publishing_message, mq)
+            endpoints = mq.broker.endpoints
+            try:
+                invocation = await self.client_manager.send_message(
+                    endpoints, send_message_request, self.client_config.request_timeout)
+                logger.debug(invocation)
+                send_receipt = SendReceipt.process_send_message_response(mq, invocation)[0]
+                if attempt > 1:
+                    logger.info(
+                        f"Re-send message successfully, topic={message.topic},"
+                        + f" max_attempts={max_attempts}, endpoints={str(endpoints)}, clientId={self.client_id}")
+                return send_receipt
+            except Exception as e:
+                # Isolated endpoints are excluded first by take_message_queues
+                # on later sends.
+                self.isolated[endpoints] = True
+                if attempt >= max_attempts:
+                    logger.error("Failed to send message finally, run out of attempt times, "
+                                 + f"topic={message.topic}, maxAttempt={max_attempts}, attempt={attempt}, "
+                                 + f"endpoints={endpoints}, messageId={publishing_message.message_id}, "
+                                 + f"clientId={self.client_id}")
+                    raise
+                if publishing_message.message_type == MessageType.TRANSACTION:
+                    # Transactional messages are never re-sent.
+                    logger.error("Failed to send transaction message, run out of attempt times, "
+                                 + f"topic={message.topic}, maxAttempt=1, attempt={attempt}, "
+                                 + f"endpoints={endpoints}, messageId={publishing_message.message_id}, "
+                                 + f"clientId={self.client_id}")
+                    raise
+                if not isinstance(e, TooManyRequestsException):
+                    logger.error(f"Failed to send message, topic={message.topic}, max_attempts={max_attempts}, "
+                                 + f"attempt={attempt}, endpoints={endpoints}, messageId={publishing_message.message_id},"
+                                 + f" clientId={self.client_id}")
+                    continue
+                # Throttled by the server: warn first, then back off before retrying.
+                delay = retry_policy.get_next_attempt_delay(attempt + 1)
+                logger.warning(
+                    f"Failed to send message due to too many requests, would attempt to resend after {delay},"
+                    + f" topic={message.topic}, max_attempts={max_attempts}, attempt={attempt}, endpoints={endpoints},"
+                    + f" message_id={publishing_message.message_id}, client_id={self.client_id}")
+                await asyncio.sleep(delay.total_seconds())
+            finally:
+                # monotonic() is immune to wall-clock adjustments.
+                logger.debug(f"send time: {time.monotonic() - start_time}")
+
+    def update_publish_load_balancer(self, topic, topic_route_data):
+        """Create or refresh the load balancer for *topic* from new route data.
+
+        An existing balancer is replaced by one built from *topic_route_data*.
+        The new balancer starts at the old balancer's index, so the refreshed
+        routes are used and queue rotation continues where it left off.
+
+        :param topic: The topic for which to update the load balancer.
+        :param topic_route_data: The new route data for the topic.
+        :return: The updated load balancer.
+        """
+        cached = self.publish_routedata_cache.get(topic)
+        if cached is None:
+            publishing_load_balancer = PublishingLoadBalancer(topic_route_data)
+        else:
+            publishing_load_balancer = PublishingLoadBalancer(topic_route_data, cached.index)
+        self.publish_routedata_cache[topic] = publishing_load_balancer
+        return publishing_load_balancer
+
+    async def get_publish_load_balancer(self, topic):
+        """Return the cached load balancer for *topic*, fetching route data on a miss.
+
+        :param topic: The topic for which to get the load balancer.
+        :return: The load balancer for the topic.
+        """
+        if topic in self.publish_routedata_cache:
+            return self.publish_routedata_cache[topic]
+        topic_route_data = await self.get_route_data(topic)
+        return self.update_publish_load_balancer(topic, topic_route_data)
def get_settings(self):
+        """Return the publishing settings of this producer."""
return self.publish_settings
+    def get_retry_policy(self):
+        """Return the retry policy held by the publishing settings."""
+        return self.publish_settings.GetRetryPolicy()
+
+    def begin_transaction(self):
+        """Start a new transaction bound to this producer."""
+        return Transaction(self)
+
+    async def end_transaction(self, endpoints, topic, message_id, transaction_id, resolution):
+        """End a transaction on *endpoints*.
+
+        :param resolution: ``"Commit"`` commits; any other value rolls back.
+        """
+        topic_resource = ProtoResource(name=topic)
+        request = EndTransactionRequest(
+            transaction_id=transaction_id,
+            message_id=message_id,
+            topic=topic_resource,
+            resolution=ProtoTransactionResolution.COMMIT if resolution == "Commit" else ProtoTransactionResolution.ROLLBACK
+        )
+        # TODO: check the response status (e.g. with StatusChecker) instead of ignoring it.
+        await self.client_manager.end_transaction(endpoints, request, self.client_config.request_timeout)
+
async def test():
+    """Manual smoke test: send one normal message to ``normal_topic``."""
+    credentials = SessionCredentials("username", "password")
+    credentials_provider = SessionCredentialsProvider(credentials)
+    client_config = ClientConfig(
+        endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"),
+        session_credentials_provider=credentials_provider,
+        ssl_enabled=True,
+    )
+    producer = Producer(client_config, topics={"normal_topic"})
+    message = Message("normal_topic", b"My Normal Message Body")
+    # `async with` starts the producer and always shuts it down, even if send() raises.
+    async with producer:
+        await asyncio.sleep(10)
+        send_receipt = await producer.send(message)
+        logger.info(f"Send message successfully, {send_receipt}")
+
+
+async def test_delay_message():
+    """Manual smoke test: send a message to ``delay_topic`` delivered ~10 s later."""
+    from datetime import timezone  # only this demo needs an explicit UTC clock
+
+    credentials = SessionCredentials("username", "password")
+    credentials_provider = SessionCredentialsProvider(credentials)
+    client_config = ClientConfig(
+        endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"),
+        session_credentials_provider=credentials_provider,
+        ssl_enabled=True,
+    )
+    producer = Producer(client_config, topics={"delay_topic"})
+    # PublishingMessage feeds delivery_timestamp to Timestamp.FromDatetime, which
+    # treats naive datetimes as UTC. Build a naive *UTC* time: a naive local
+    # time would shift the delay by the machine's UTC offset. (Assumes Message
+    # stores delivery_timestamp unchanged.)
+    delivery_time = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(seconds=10)
+    message = Message("delay_topic", b"My Delay Message Body", delivery_timestamp=delivery_time)
+    async with producer:
+        await asyncio.sleep(10)
+        send_receipt = await producer.send(message)
+        logger.info(f"Send message successfully, {send_receipt}")
+
+
+async def test_fifo_message():
+    """Manual smoke test: send a FIFO message (with a message group) to ``fifo_topic``."""
+    credentials = SessionCredentials("username", "password")
+    credentials_provider = SessionCredentialsProvider(credentials)
+    client_config = ClientConfig(
+        endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"),
+        session_credentials_provider=credentials_provider,
+        ssl_enabled=True,
+    )
+    producer = Producer(client_config, topics={"fifo_topic"})
+    message = Message("fifo_topic", b"My FIFO Message Body", message_group="yourMessageGroup")
+    # `async with` guarantees shutdown() even if send() raises.
+    async with producer:
+        await asyncio.sleep(10)
+        send_receipt = await producer.send(message)
+        logger.info(f"Send message successfully, {send_receipt}")
+
+
+async def test_transaction_message():
+    """Manual smoke test: send one message in a transaction and commit it."""
+    credentials = SessionCredentials("username", "password")
+    credentials_provider = SessionCredentialsProvider(credentials)
+    client_config = ClientConfig(
+        endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"),
+        session_credentials_provider=credentials_provider,
+        ssl_enabled=True,
+    )
+    producer = Producer(client_config, topics={"transaction_topic"})
+    message = Message("transaction_topic", b"My Transaction Message Body")
+    # `async with` guarantees shutdown() even if send() or commit() raises.
+    async with producer:
+        transaction = producer.begin_transaction()
+        send_receipt = await producer.send(message, transaction)
+        logger.info(f"Send message successfully, {send_receipt}")
+        await transaction.commit()
+
+
+async def test_retry_and_isolation():
credentials = SessionCredentials("username", "password")
credentials_provider = SessionCredentialsProvider(credentials)
client_config = ClientConfig(
@@ -137,10 +483,25 @@ async def test():
msg.system_properties.CopyFrom(sysperf)
logger.info(f"{msg}")
producer = Producer(client_config, topics={"normal_topic"})
- await producer.start()
- result = await producer.send_message(msg)
- print(result)
+ message = Message(topic.name, msg.body)
+ with patch.object(producer.client_manager, 'send_message',
new_callable=MagicMock) as mock_send:
+ mock_send.side_effect = Exception("Forced Exception for Testing")
+ await producer.start()
+
+ try:
+ await producer.send(message)
+ except Exception:
+ logger.info("Exception occurred as expected")
+
+ assert mock_send.call_count ==
producer.get_retry_policy().get_max_attempts(), "Number of attempts should
equal max_attempts."
+ logger.debug(producer.isolated)
+ assert producer.isolated, "Endpoint should be marked as isolated after
an error."
+ logger.info("Test completed successfully.")
if __name__ == "__main__":
+    # Run every manual scenario in sequence; each asyncio.run() call creates its own event loop.
asyncio.run(test())
+    asyncio.run(test_delay_message())
+    asyncio.run(test_fifo_message())
+    asyncio.run(test_transaction_message())
+    asyncio.run(test_retry_and_isolation())
diff --git a/python/rocketmq/publish_settings.py
b/python/rocketmq/publish_settings.py
index c629d514..4f09cb7c 100644
--- a/python/rocketmq/publish_settings.py
+++ b/python/rocketmq/publish_settings.py
@@ -17,13 +17,14 @@ import platform
import socket
from typing import Dict
+from rocketmq.exponential_backoff_retry_policy import \
+ ExponentialBackoffRetryPolicy
from rocketmq.protocol.definition_pb2 import UA
from rocketmq.protocol.definition_pb2 import Publishing as ProtoPublishing
from rocketmq.protocol.definition_pb2 import Resource as ProtoResource
from rocketmq.protocol.definition_pb2 import Settings as ProtoSettings
from rocketmq.rpc_client import Endpoints
-from rocketmq.settings import (ClientType, ClientTypeHelper, IRetryPolicy,
- Settings)
+from rocketmq.settings import ClientType, ClientTypeHelper, Settings
from rocketmq.signature import Signature
@@ -44,7 +45,7 @@ class PublishingSettings(Settings):
self,
client_id: str,
endpoints: Endpoints,
- retry_policy: IRetryPolicy,
+ retry_policy: ExponentialBackoffRetryPolicy,
request_timeout: int,
topics: Dict[str, bool],
):
diff --git a/python/rocketmq/publishing_message.py
b/python/rocketmq/publishing_message.py
new file mode 100644
index 00000000..195399cb
--- /dev/null
+++ b/python/rocketmq/publishing_message.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import socket
+
+from definition import Encoding, EncodingHelper, MessageType, MessageTypeHelper
+from google.protobuf.timestamp_pb2 import Timestamp
+from message import Message
+from message_id_codec import MessageIdCodec
+from protocol.definition_pb2 import Message as ProtoMessage
+from protocol.definition_pb2 import Resource, SystemProperties
+from rocketmq.log import logger
+
+
+class PublishingMessage(Message):
+    """A message prepared for publishing.
+
+    Checks the body size, assigns a fresh message id and classifies the
+    message as NORMAL, FIFO, DELAY or TRANSACTION.
+
+    NOTE(review): Message.__init__ is not called; the user message is kept
+    in ``self.message`` instead.
+    """
+
+    def __init__(self, message, publishing_settings, tx_enabled=False):
+        """Validate and classify *message*.
+
+        :param message: The user message to publish.
+        :param publishing_settings: Settings providing ``get_max_body_size_bytes()``.
+        :param tx_enabled: Whether the message is sent inside a transaction.
+        :raises IOError: if the body exceeds the maximum body size.
+        :raises ValueError: if a transactional message also sets a message group
+            or a delivery timestamp.
+        """
+        self.message = message
+        self.publishing_settings = publishing_settings
+        self.tx_enabled = tx_enabled
+        self.message_type = None
+
+        max_body_size_bytes = publishing_settings.get_max_body_size_bytes()
+        if len(message.body) > max_body_size_bytes:
+            raise IOError(f"Message body size exceed the threshold, max size={max_body_size_bytes} bytes")
+
+        self.message_id = MessageIdCodec.next_message_id()
+
+        if not tx_enabled:
+            # A message group takes precedence over a delivery timestamp.
+            if message.message_group:
+                self.message_type = MessageType.FIFO
+            elif message.delivery_timestamp:
+                self.message_type = MessageType.DELAY
+            else:
+                self.message_type = MessageType.NORMAL
+            return
+
+        # A transactional message cannot also be FIFO or delayed. Reject it
+        # instead of silently sending it as a plain TRANSACTION message.
+        if message.message_group or message.delivery_timestamp:
+            raise ValueError("Transactional message should not set messageGroup or deliveryTimestamp")
+
+        self.message_type = MessageType.TRANSACTION
+        logger.debug(self.message_type)
+
+    def to_protobuf(self, queue_id):
+        """Build the protobuf Message to send to the queue *queue_id*."""
+        # GetCurrentTime() sets the Timestamp to the current UTC time.
+        born_timestamp = Timestamp()
+        born_timestamp.GetCurrentTime()
+        system_properties = SystemProperties(
+            keys=self.message.keys,
+            message_id=self.message_id,
+            born_timestamp=born_timestamp,
+            born_host=socket.gethostname(),
+            body_encoding=EncodingHelper.to_protobuf(Encoding.IDENTITY),
+            queue_id=queue_id,
+            message_type=MessageTypeHelper.to_protobuf(self.message_type)
+        )
+        if self.message.tag:
+            system_properties.tag = self.message.tag
+
+        if self.message.delivery_timestamp:
+            # FromDatetime treats a naive datetime as UTC.
+            timestamp = Timestamp()
+            timestamp.FromDatetime(self.message.delivery_timestamp)
+            system_properties.delivery_timestamp.CopyFrom(timestamp)
+
+        if self.message.message_group:
+            system_properties.message_group = self.message.message_group
+
+        topic_resource = Resource(name=self.message.topic)
+
+        return ProtoMessage(
+            topic=topic_resource,
+            body=self.message.body,
+            system_properties=system_properties,
+            user_properties=self.message.properties
+        )
diff --git a/python/rocketmq/rpc_client.py b/python/rocketmq/rpc_client.py
index c1f129c4..6c1107ab 100644
--- a/python/rocketmq/rpc_client.py
+++ b/python/rocketmq/rpc_client.py
@@ -132,7 +132,7 @@ class Endpoints:
def __str__(self):
for address in self.Addresses:
- return None
+            # "host:port" of the first address. The separator was missing
+            # ("localhost8080").
+            return f"{address.host}:{address.port}"
+        # No addresses: __str__ must return a str; returning None makes str() raise TypeError.
+        return ""
def grpc_target(self, sslEnabled):
for address in self.Addresses:
diff --git a/python/rocketmq/send_receipt.py b/python/rocketmq/send_receipt.py
index 3aaa5978..8e742da2 100644
--- a/python/rocketmq/send_receipt.py
+++ b/python/rocketmq/send_receipt.py
@@ -13,13 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# from rocketmq.status_checker import StatusChecker
+from rocketmq.log import logger
from rocketmq.message_id import MessageId
+from rocketmq.protocol.definition_pb2 import Code as ProtoCode
class SendReceipt:
- def __init__(self, message_id: MessageId):
- self.__message_id = message_id
+    """Result of sending one message: its id, transaction id and target queue."""
+
+    def __init__(self, message_id: MessageId, transaction_id, message_queue):
+        self.message_id = message_id
+        self.transaction_id = transaction_id
+        self.message_queue = message_queue
@property
- def message_id(self):
- return self.__message_id
+    def endpoints(self):
+        """Endpoints of the broker that owns ``message_queue``."""
+        return self.message_queue.broker.endpoints
+
+    def __str__(self):
+        return f'MessageId: {self.message_id}'
+
+    @staticmethod
+    def process_send_message_response(mq, invocation):
+        """Turn a SendMessage response into one SendReceipt per result entry.
+
+        The effective status starts as the response status. Any entry whose own
+        status is not OK overrides it (the last such entry wins), so a failed
+        entry is not hidden behind an OK overall status.
+
+        :param mq: The message queue the request was sent to.
+        :param invocation: The SendMessage response (reads ``status`` and ``entries``).
+        :return: A list of SendReceipt, one per entry.
+        """
+        status = invocation.status
+        for entry in invocation.entries:
+            if entry.status.code != ProtoCode.OK:
+                status = entry.status
+        logger.debug(status)
+        # TODO: raise on a non-OK status, e.g. StatusChecker.check(status, request, request_id),
+        # once the request and its id are available here.
+        return [SendReceipt(entry.message_id, entry.transaction_id, mq) for entry in invocation.entries]
diff --git a/python/rocketmq/session.py b/python/rocketmq/session.py
index 50c11d98..c3ea9d32 100644
--- a/python/rocketmq/session.py
+++ b/python/rocketmq/session.py
@@ -14,8 +14,8 @@
# limitations under the License.
import asyncio
-from threading import Event
+from rocketmq.log import logger
from rocketmq.protocol.service_pb2 import \
TelemetryCommand as ProtoTelemetryCommand
@@ -26,12 +26,25 @@ class Session:
self._semaphore = asyncio.Semaphore(1)
self._streaming_call = streaming_call
self._client = client
- self._event = Event()
+ asyncio.create_task(self.loop())
+
+    async def loop(self):
+        """Read and discard telemetry responses until the stream raises InvalidStateError."""
+        try:
+            while True:
+                # NOTE(review): if read() signals end-of-stream by returning a
+                # sentinel rather than raising, this loop never exits — confirm
+                # against the streaming-call implementation.
+                await self._streaming_call.read()
+        except asyncio.exceptions.InvalidStateError as e:
+            # Interpolate the error. logger.error('Error:', e) passes `e` as a
+            # format argument with no placeholder, which stdlib logging reports
+            # as a formatting error instead of logging the message.
+            logger.error(f'Error: {e}')
async def write_async(self, telemetry_command: ProtoTelemetryCommand):
- await self._streaming_call.write(telemetry_command)
- response = await self._streaming_call.read()
- print(response)
+        # Pause for one second before every write.
+        # NOTE(review): the reason for this fixed pause is not documented — confirm.
+        await asyncio.sleep(1)
+        # TODO handle read operation exceed the time limit, e.g.
+        # await asyncio.wait_for(self._streaming_call.read(), timeout=5)
+        try:
+            await self._streaming_call.write(telemetry_command)
+        except asyncio.exceptions.InvalidStateError as err:
+            # The stream cannot be written to; on_error tries to rebuild it.
+            self.on_error(err)
+        except asyncio.TimeoutError:
+            # Kept for the timed read described in the TODO above.
+            logger.error('Timeout: The read operation exceeded the time limit')
async def sync_settings(self, await_resp):
await self._semaphore.acquire()
@@ -42,3 +55,20 @@ class Session:
await self.write_async(telemetry_command)
finally:
self._semaphore.release()
+
+    def rebuild_telemetry(self):
+        """Replace the current streaming call with a newly opened telemetry stream."""
+        logger.info("Try to rebuild telemetry")
+        self._streaming_call = self._client.client_manager.telemetry(self._endpoints, 10)
+
+    def on_error(self, exception):
+        """Log a stream failure and try to rebuild the telemetry stream.
+
+        Makes up to ``max_retry`` rebuild attempts, logging each failure and
+        stopping at the first success.
+
+        :param exception: The error raised by the stream.
+        """
+        client_id = self._client.get_client_id()
+        logger.error("Caught InvalidStateError: RPC already finished.")
+        # Interpolate the exception. Passing it as an extra positional argument
+        # with no placeholder makes stdlib logging emit a formatting error instead.
+        logger.error(f"Exception raised from stream, clientId={client_id}, endpoints={self._endpoints}, "
+                     f"exception={exception!r}")
+        max_retry = 3
+        for i in range(max_retry):
+            try:
+                self.rebuild_telemetry()
+                break
+            except Exception as e:
+                logger.error(f"An error occurred during rebuilding telemetry: {e}, attempt {i + 1} of {max_retry}")
diff --git a/python/rocketmq/status_checker.py
b/python/rocketmq/status_checker.py
new file mode 100644
index 00000000..ae2d1913
--- /dev/null
+++ b/python/rocketmq/status_checker.py
@@ -0,0 +1,212 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from rocketmq.log import logger
+from rocketmq.message import Message
+from rocketmq.protocol.definition_pb2 import Code as ProtoCode
+from rocketmq.protocol.definition_pb2 import Message as ProtoMessage
+from rocketmq.protocol.definition_pb2 import Status as ProtoStatus
+from rocketmq.protocol.service_pb2 import \
+ ReceiveMessageRequest as ProtoReceiveMessageRequest
+
+
+class RocketMQException(Exception):
+    """Base error for a non-success status returned by the server.
+
+    :param status_code: The protobuf ``Code`` of the status.
+    :param request_id: ID of the failed request.
+    :param status_message: Message text of the status.
+    """
+
+    def __init__(self, status_code, request_id, status_message):
+        # Populate Exception.args. Without it, args is empty, and pickle/copy
+        # re-create the exception with no arguments, which raises TypeError.
+        super().__init__(status_code, request_id, status_message)
+        self.status_code = status_code
+        self.request_id = request_id
+        self.status_message = status_message
+
+    def __str__(self):
+        return f"{self.__class__.__name__}: code={self.status_code}, requestId={self.request_id}, message={self.status_message}"
+
+
+class BadRequestException(RocketMQException):
+    """Raised for BAD_REQUEST and related validation codes (e.g. ILLEGAL_TOPIC, INVALID_RECEIPT_HANDLE)."""
+
+
+class UnauthorizedException(RocketMQException):
+    """Raised for UNAUTHORIZED."""
+
+
+class PaymentRequiredException(RocketMQException):
+    """Raised for PAYMENT_REQUIRED."""
+
+
+class ForbiddenException(RocketMQException):
+    """Raised for FORBIDDEN."""
+
+
+class NotFoundException(RocketMQException):
+    """Raised for NOT_FOUND, TOPIC_NOT_FOUND, CONSUMER_GROUP_NOT_FOUND, and MESSAGE_NOT_FOUND outside receive requests."""
+
+
+class PayloadTooLargeException(RocketMQException):
+    """Raised for PAYLOAD_TOO_LARGE and MESSAGE_BODY_TOO_LARGE."""
+
+
+class TooManyRequestsException(RocketMQException):
+    """Raised for TOO_MANY_REQUESTS."""
+
+
+class RequestHeaderFieldsTooLargeException(RocketMQException):
+    """Raised for REQUEST_HEADER_FIELDS_TOO_LARGE and MESSAGE_PROPERTIES_TOO_LARGE."""
+
+
+class InternalErrorException(RocketMQException):
+    """Raised for INTERNAL_ERROR, INTERNAL_SERVER_ERROR and HA_NOT_AVAILABLE."""
+
+
+class ProxyTimeoutException(RocketMQException):
+    """Raised for PROXY_TIMEOUT and the master/slave persistence timeouts."""
+
+
+class UnsupportedException(RocketMQException):
+    """Raised for UNSUPPORTED-family codes and for any unrecognized status code."""
+
+
+class StatusChecker:
+    @staticmethod
+    def check(status: ProtoStatus, request: Message, request_id: str):
+        """Raise the RocketMQException subclass that matches a non-success status.
+
+        Handling by status code:
+
+        - OK and MULTIPLE_RESULTS are success.
+        - MESSAGE_NOT_FOUND is ignored for a ReceiveMessageRequest and
+          otherwise treated as NOT_FOUND.
+        - Any code missing from the table below is logged and raised as
+          UnsupportedException.
+
+        :param status: A ProtoStatus carrying the status code and message.
+        :param request: The request the status answers.
+        :param request_id: The ID of the request.
+        :raise RocketMQException: the subclass mapped to the status code.
+        """
+        code = status.code
+        message = status.message
+
+        if code in (ProtoCode.OK, ProtoCode.MULTIPLE_RESULTS):
+            return
+        if code == ProtoCode.MESSAGE_NOT_FOUND:
+            if isinstance(request, ProtoReceiveMessageRequest):
+                return
+            code = ProtoCode.NOT_FOUND
+
+        # (status codes, exception raised for them)
+        error_table = (
+            ((ProtoCode.BAD_REQUEST,
+              ProtoCode.ILLEGAL_ACCESS_POINT,
+              ProtoCode.ILLEGAL_TOPIC,
+              ProtoCode.ILLEGAL_CONSUMER_GROUP,
+              ProtoCode.ILLEGAL_MESSAGE_TAG,
+              ProtoCode.ILLEGAL_MESSAGE_KEY,
+              ProtoCode.ILLEGAL_MESSAGE_GROUP,
+              ProtoCode.ILLEGAL_MESSAGE_PROPERTY_KEY,
+              ProtoCode.INVALID_TRANSACTION_ID,
+              ProtoCode.ILLEGAL_MESSAGE_ID,
+              ProtoCode.ILLEGAL_FILTER_EXPRESSION,
+              ProtoCode.ILLEGAL_INVISIBLE_TIME,
+              ProtoCode.ILLEGAL_DELIVERY_TIME,
+              ProtoCode.INVALID_RECEIPT_HANDLE,
+              ProtoCode.MESSAGE_PROPERTY_CONFLICT_WITH_TYPE,
+              ProtoCode.UNRECOGNIZED_CLIENT_TYPE,
+              ProtoCode.MESSAGE_CORRUPTED,
+              ProtoCode.CLIENT_ID_REQUIRED,
+              ProtoCode.ILLEGAL_POLLING_TIME), BadRequestException),
+            ((ProtoCode.UNAUTHORIZED,), UnauthorizedException),
+            ((ProtoCode.PAYMENT_REQUIRED,), PaymentRequiredException),
+            ((ProtoCode.FORBIDDEN,), ForbiddenException),
+            ((ProtoCode.NOT_FOUND,
+              ProtoCode.TOPIC_NOT_FOUND,
+              ProtoCode.CONSUMER_GROUP_NOT_FOUND), NotFoundException),
+            ((ProtoCode.PAYLOAD_TOO_LARGE,
+              ProtoCode.MESSAGE_BODY_TOO_LARGE), PayloadTooLargeException),
+            ((ProtoCode.TOO_MANY_REQUESTS,), TooManyRequestsException),
+            ((ProtoCode.REQUEST_HEADER_FIELDS_TOO_LARGE,
+              ProtoCode.MESSAGE_PROPERTIES_TOO_LARGE), RequestHeaderFieldsTooLargeException),
+            ((ProtoCode.INTERNAL_ERROR,
+              ProtoCode.INTERNAL_SERVER_ERROR,
+              ProtoCode.HA_NOT_AVAILABLE), InternalErrorException),
+            ((ProtoCode.PROXY_TIMEOUT,
+              ProtoCode.MASTER_PERSISTENCE_TIMEOUT,
+              ProtoCode.SLAVE_PERSISTENCE_TIMEOUT), ProxyTimeoutException),
+            ((ProtoCode.UNSUPPORTED,
+              ProtoCode.VERSION_UNSUPPORTED,
+              ProtoCode.VERIFY_FIFO_MESSAGE_UNSUPPORTED), UnsupportedException),
+        )
+        for codes, error_type in error_table:
+            if code in codes:
+                raise error_type(code, request_id, message)
+
+        logger.warning(f"Unrecognized status code={code}, requestId={request_id}, statusMessage={message}")
+        raise UnsupportedException(code, request_id, message)
+
+
+def main():
+    """Manually exercise StatusChecker with OK, BAD_REQUEST and UNAUTHORIZED statuses."""
+    def make_status(code, text):
+        # Build a ProtoStatus with the given code and message.
+        built = ProtoStatus()
+        built.code = code
+        built.message = text
+        return built
+
+    status_ok = make_status(ProtoCode.OK, "Everything is OK")
+    status_bad_request = make_status(ProtoCode.BAD_REQUEST, "Bad request")
+    status_unauthorized = make_status(ProtoCode.UNAUTHORIZED, "Unauthorized")
+    request = ProtoMessage()
+
+    # An OK status must not raise.
+    StatusChecker.check(status_ok, request, "request1")
+
+    # Each failing status must raise its mapped exception type.
+    expected_failures = (
+        (status_bad_request, "request2", BadRequestException),
+        (status_unauthorized, "request3", UnauthorizedException),
+    )
+    for failing_status, failing_id, expected_error in expected_failures:
+        try:
+            StatusChecker.check(failing_status, request, failing_id)
+        except expected_error as e:
+            logger.error(f"Caught expected exception: {e}")
+
+
+if __name__ == "__main__":
+    # Run the manual StatusChecker demo when this module is executed directly.
+    main()
diff --git a/python/rocketmq/utils.py b/python/rocketmq/utils.py
index dd2a4a98..e5cf3b27 100644
--- a/python/rocketmq/utils.py
+++ b/python/rocketmq/utils.py
@@ -39,3 +39,8 @@ def sign(access_secret: str, datetime: str) -> str:
hashlib.sha1,
)
return digester.hexdigest().upper()
+
+
+def get_positive_mod(k: int, n: int) -> int:
+    """Return ``k`` modulo ``|n|``, always in the range ``[0, |n|)``.
+
+    Python's ``%`` already gives a non-negative result for a positive divisor,
+    but takes the sign of a negative one. The old ``result + n`` adjustment was
+    dead code for ``n > 0`` and made the result more negative for ``n < 0``.
+    Reducing by ``abs(n)`` keeps it non-negative for either sign.
+
+    :raises ZeroDivisionError: if ``n`` is 0.
+    """
+    return k % abs(n)
diff --git a/python/tests/test_foo.py b/python/tests/test_foo.py
index 70b00f6a..89b9ea31 100644
--- a/python/tests/test_foo.py
+++ b/python/tests/test_foo.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from rocketmq import foo, logger
+from rocketmq import foo
+from rocketmq.log import logger
def test_passing():