This is an automated email from the ASF dual-hosted git repository.

baodi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pulsar-client-python.git


The following commit(s) were added to refs/heads/main by this push:
     new d2fac8f  Fetch writer schema to decode Avro messages (#119)
d2fac8f is described below

commit d2fac8fb8bbaaaa7d134abc35e7f1f8f89f615be
Author: Yunze Xu <[email protected]>
AuthorDate: Thu May 25 09:43:49 2023 +0800

    Fetch writer schema to decode Avro messages (#119)
    
    Fixes https://github.com/apache/pulsar-client-python/issues/108
    
    ### Motivation
    
    Currently the Python client uses the reader schema, which is the schema
    of the consumer, to decode Avro messages. However, when the writer
    schema is different from the reader schema, decoding will fail.
    
    ### Modifications
    
    Add `attach_client` method to `Schema` and call it when creating
    consumers and readers. This method stores a reference to a
    `_pulsar.Client` instance, which leverages the C++ APIs added in
    https://github.com/apache/pulsar-client-cpp/pull/257 to fetch schema
    info. The `AvroSchema` class fetches and caches the writer schema if it
    is not cached, then uses both the writer schema and reader schema to
    decode messages.
    
    Add `test_schema_evolve` to test that consumers and readers can decode
    any message whose writer schema is different from the reader schema.
---
 pulsar/__init__.py           |  4 ++-
 pulsar/schema/schema.py      |  6 +++++
 pulsar/schema/schema_avro.py | 45 ++++++++++++++++++++++++++++++++-
 src/client.cc                |  7 ++++++
 src/message.cc               |  1 +
 tests/schema_test.py         | 59 ++++++++++++++++++++++++++++++++++++++++++++
 6 files changed, 120 insertions(+), 2 deletions(-)

diff --git a/pulsar/__init__.py b/pulsar/__init__.py
index c85c6e3..843274b 100644
--- a/pulsar/__init__.py
+++ b/pulsar/__init__.py
@@ -127,7 +127,7 @@ class Message:
         """
         Returns object with the de-serialized version of the message content
         """
-        return self._schema.decode(self._message.data())
+        return self._schema.decode_message(self._message)
 
     def properties(self):
         """
@@ -841,6 +841,7 @@ class Client:
 
         c._client = self
         c._schema = schema
+        c._schema.attach_client(self._client)
         self._consumers.append(c)
         return c
 
@@ -942,6 +943,7 @@ class Client:
         c._reader = self._client.create_reader(topic, start_message_id, conf)
         c._client = self
         c._schema = schema
+        c._schema.attach_client(self._client)
         self._consumers.append(c)
         return c
 
diff --git a/pulsar/schema/schema.py b/pulsar/schema/schema.py
index f062c2e..b50a1fe 100644
--- a/pulsar/schema/schema.py
+++ b/pulsar/schema/schema.py
@@ -38,9 +38,15 @@ class Schema(object):
     def decode(self, data):
         pass
 
+    def decode_message(self, msg: _pulsar.Message):
+        return self.decode(msg.data())
+
     def schema_info(self):
         return self._schema_info
 
+    def attach_client(self, client: _pulsar.Client):
+        self._client = client
+
     def _validate_object_type(self, obj):
         if not isinstance(obj, self._record_cls):
             raise TypeError('Invalid record obj of type ' + str(type(obj))
diff --git a/pulsar/schema/schema_avro.py b/pulsar/schema/schema_avro.py
index 3e629fb..70fda98 100644
--- a/pulsar/schema/schema_avro.py
+++ b/pulsar/schema/schema_avro.py
@@ -19,6 +19,8 @@
 
 import _pulsar
 import io
+import json
+import logging
 import enum
 
 from . import Record
@@ -40,6 +42,8 @@ if HAS_AVRO:
                 self._schema = record_cls.schema()
             else:
                 self._schema = schema_definition
+            self._writer_schemas = dict()
+            self._logger = logging.getLogger()
             super(AvroSchema, self).__init__(record_cls, 
_pulsar.SchemaType.AVRO, self._schema, 'AVRO')
 
         def _get_serialized_value(self, x):
@@ -76,8 +80,47 @@ if HAS_AVRO:
             return obj
 
         def decode(self, data):
+            return self._decode_bytes(data, self._schema)
+
+        def decode_message(self, msg: _pulsar.Message):
+            if self._client is None:
+                return self.decode(msg.data())
+            topic = msg.topic_name()
+            version = msg.int_schema_version()
+            try:
+                writer_schema = self._get_writer_schema(topic, version)
+                return self._decode_bytes(msg.data(), writer_schema)
+            except Exception as e:
+                self._logger.error('Failed to get schema info of {topic} 
version {version}: {e}')
+                return self._decode_bytes(msg.data(), self._schema)
+
+        def _get_writer_schema(self, topic: str, version: int) -> 'dict':
+            if self._writer_schemas.get(topic) is None:
+                self._writer_schemas[topic] = dict()
+            writer_schema = self._writer_schemas[topic].get(version)
+            if writer_schema is not None:
+                return writer_schema
+            if self._client is None:
+                return self._schema
+
+            self._logger.info('Downloading schema of %s version %d...', topic, 
version)
+            info = self._client.get_schema_info(topic, version)
+            self._logger.info('Downloaded schema of %s version %d', topic, 
version)
+            if info.schema_type() != _pulsar.SchemaType.AVRO:
+                raise RuntimeError(f'The schema type of topic "{topic}" and 
version {version}'
+                                   f' is {info.schema_type()}')
+            writer_schema = json.loads(info.schema())
+            self._writer_schemas[topic][version] = writer_schema
+            return writer_schema
+
+        def _decode_bytes(self, data: bytes, writer_schema: dict):
             buffer = io.BytesIO(data)
-            d = fastavro.schemaless_reader(buffer, self._schema)
+            # If the record names are different between the writer schema and 
the reader schema,
+            # schemaless_reader will fail with 
fastavro._read_common.SchemaResolutionError.
+            # So we make the record name fields consistent here.
+            reader_schema: dict = self._schema
+            writer_schema['name'] = reader_schema['name']
+            d = fastavro.schemaless_reader(buffer, writer_schema, 
reader_schema)
             if self._record_cls is not None:
                 return self._record_cls(**d)
             else:
diff --git a/src/client.cc b/src/client.cc
index 0103309..626ff9f 100644
--- a/src/client.cc
+++ b/src/client.cc
@@ -58,6 +58,12 @@ std::vector<std::string> Client_getTopicPartitions(Client& 
client, const std::st
         [&](GetPartitionsCallback callback) { 
client.getPartitionsForTopicAsync(topic, callback); });
 }
 
+SchemaInfo Client_getSchemaInfo(Client& client, const std::string& topic, 
int64_t version) {
+    return waitForAsyncValue<SchemaInfo>([&](std::function<void(Result, const 
SchemaInfo&)> callback) {
+        client.getSchemaInfoAsync(topic, version, callback);
+    });
+}
+
 void Client_close(Client& client) {
     waitForAsyncResult([&](ResultCallback callback) { 
client.closeAsync(callback); });
 }
@@ -71,6 +77,7 @@ void export_client(py::module_& m) {
         .def("subscribe_pattern", &Client_subscribe_pattern)
         .def("create_reader", &Client_createReader)
         .def("get_topic_partitions", &Client_getTopicPartitions)
+        .def("get_schema_info", &Client_getSchemaInfo)
         .def("close", &Client_close)
         .def("shutdown", &Client::shutdown);
 }
diff --git a/src/message.cc b/src/message.cc
index 6e8dd3f..895209f 100644
--- a/src/message.cc
+++ b/src/message.cc
@@ -98,6 +98,7 @@ void export_message(py::module_& m) {
              })
         .def("topic_name", &Message::getTopicName, return_value_policy::copy)
         .def("redelivery_count", &Message::getRedeliveryCount)
+        .def("int_schema_version", &Message::getLongSchemaVersion)
         .def("schema_version", &Message::getSchemaVersion, 
return_value_policy::copy);
 
     MessageBatch& (MessageBatch::*MessageBatchParseFromString)(const 
std::string& payload,
diff --git a/tests/schema_test.py b/tests/schema_test.py
index 47acc30..3e6e9c6 100755
--- a/tests/schema_test.py
+++ b/tests/schema_test.py
@@ -18,6 +18,10 @@
 # under the License.
 #
 
+import math
+import logging
+import requests
+from typing import List
 from unittest import TestCase, main
 
 import fastavro
@@ -27,6 +31,9 @@ from enum import Enum
 import json
 from fastavro.schema import load_schema
 
+logging.basicConfig(level=logging.INFO,
+                    format='%(asctime)s %(levelname)-5s %(message)s')
+
 
 class SchemaTest(TestCase):
 
@@ -1287,5 +1294,57 @@ class SchemaTest(TestCase):
         with self.assertRaises(TypeError) as e:
             SomeSchema(some_field=["not", "integer"])
         self.assertEqual(str(e.exception), "Array field some_field items 
should all be of type int")
+
+    def test_schema_evolve(self):
+        class User1(Record):
+            name = String()
+            age = Integer()
+
+        class User2(Record):
+            _sorted_fields = True
+            name = String()
+            age = Integer(required=True)
+
+        response = requests.put('http://localhost:8080/admin/v2/namespaces/'
+                                'public/default/schemaCompatibilityStrategy',
+                                data='"FORWARD"'.encode(),
+                                headers={'Content-Type': 'application/json'})
+        self.assertEqual(response.status_code, 204)
+
+        topic = 'schema-test-schema-evolve-2'
+        client = pulsar.Client(self.serviceUrl)
+        producer1 = client.create_producer(topic, schema=AvroSchema(User1))
+        consumer = client.subscribe(topic, 'sub', schema=AvroSchema(User1))
+        reader = client.create_reader(topic,
+                                      schema=AvroSchema(User1),
+                                      
start_message_id=pulsar.MessageId.earliest)
+        producer2 = client.create_producer(topic, schema=AvroSchema(User2))
+
+        num_messages = 10 * 2
+        for i in range(int(num_messages / 2)):
+            producer1.send(User1(age=i+100, name=f'User1 {i}'))
+            producer2.send(User2(age=i+200, name=f'User2 {i}'))
+
+        def verify_messages(msgs: List[pulsar.Message]):
+            for i, msg in enumerate(msgs):
+                value = msg.value()
+                index = math.floor(i / 2)
+                if i % 2 == 0:
+                    self.assertEqual(value.age, index + 100)
+                    self.assertEqual(value.name, f'User1 {index}')
+                else:
+                    self.assertEqual(value.age, index + 200)
+                    self.assertEqual(value.name, f'User2 {index}')
+
+        msgs1 = []
+        msgs2 = []
+        for i in range(num_messages):
+            msgs1.append(consumer.receive())
+            msgs2.append(reader.read_next(1000))
+        verify_messages(msgs1)
+        verify_messages(msgs2)
+
+        client.close()
+
 if __name__ == '__main__':
     main()

Reply via email to