This is an automated email from the ASF dual-hosted git repository.

xyz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pulsar-client-cpp.git


The following commit(s) were added to refs/heads/main by this push:
     new 889a04b  Support getting encryption context on a message (#526)
889a04b is described below

commit 889a04bc71eb207ec064a8cec37709dd3b7bccdb
Author: Yunze Xu <[email protected]>
AuthorDate: Tue Dec 9 10:07:41 2025 +0800

    Support getting encryption context on a message (#526)
---
 include/pulsar/EncryptionContext.h | 113 +++++++++++++++++++++++++++++
 include/pulsar/Message.h           |   8 ++
 lib/Commands.cc                    |   1 +
 lib/ConsumerImpl.cc                |  48 ++++++------
 lib/ConsumerImpl.h                 |  11 ++-
 lib/EncryptionContext.cc           |  48 ++++++++++++
 lib/Message.cc                     |   7 ++
 lib/MessageCrypto.cc               |  35 ++++-----
 lib/MessageCrypto.h                |  14 ++--
 lib/MessageImpl.h                  |   4 +
 tests/BasicEndToEndTest.cc         |   4 +
 tests/EncryptionTest.cc            | 145 +++++++++++++++++++++++++++++++++++++
 win-examples/CMakeLists.txt        |   1 +
 13 files changed, 390 insertions(+), 49 deletions(-)

diff --git a/include/pulsar/EncryptionContext.h b/include/pulsar/EncryptionContext.h
new file mode 100644
index 0000000..ac7ebf9
--- /dev/null
+++ b/include/pulsar/EncryptionContext.h
@@ -0,0 +1,113 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#pragma once
+
+#include <cstdint>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "CompressionType.h"
+#include "defines.h"
+
+namespace pulsar {
+
+namespace proto {
+class MessageMetadata;
+}
+
+struct PULSAR_PUBLIC EncryptionKey {
+    std::string key;
+    std::string value;
+    std::unordered_map<std::string, std::string> metadata;
+
+    EncryptionKey(const std::string& key, const std::string& value,
+                  const decltype(EncryptionKey::metadata)& metadata)
+        : key(key), value(value), metadata(metadata) {}
+};
+
+/**
+ * It contains encryption and compression information in it using which application can decrypt consumed
+ * message with encrypted-payload.
+ */
+class PULSAR_PUBLIC EncryptionContext {
+   public:
+    using KeysType = std::vector<EncryptionKey>;
+
+    /**
+     * @return the map of encryption keys used for the message
+     */
+    const KeysType& keys() const noexcept { return keys_; }
+
+    /**
+     * @return the encryption parameter used for the message
+     */
+    const std::string& param() const noexcept { return param_; }
+
+    /**
+     * @return the encryption algorithm used for the message
+     */
+    const std::string& algorithm() const noexcept { return algorithm_; }
+
+    /**
+     * @return the compression type used for the message
+     */
+    CompressionType compressionType() const noexcept { return compressionType_; }
+
+    /**
+     * @return the uncompressed message size if the message is compressed, 0 otherwise
+     */
+    uint32_t uncompressedMessageSize() const noexcept { return uncompressedMessageSize_; }
+
+    /**
+     * @return the batch size if the message is part of a batch, -1 otherwise
+     */
+    int32_t batchSize() const noexcept { return batchSize_; }
+
+    /**
+     * When the `ConsumerConfiguration#getCryptoFailureAction` is set to `CONSUME`, the message will still be
+     * returned even if the decryption failed. This method is provided to let users know whether the
+     * decryption failed.
+     *
+     * @return whether the decryption failed
+     */
+    bool isDecryptionFailed() const noexcept { return isDecryptionFailed_; }
+
+    /**
+     * This constructor is public to allow in-place construction via std::optional
+     * (e.g., `std::optional<EncryptionContext>(std::in_place, metadata, false)`),
+     * but should not be used directly in application code.
+     */
+    EncryptionContext(const proto::MessageMetadata&, bool);
+
+   private:
+    KeysType keys_;
+    std::string param_;
+    std::string algorithm_;
+    CompressionType compressionType_{CompressionNone};
+    uint32_t uncompressedMessageSize_{0};
+    int32_t batchSize_{-1};
+    bool isDecryptionFailed_{false};
+
+    void setDecryptionFailed(bool failed) noexcept { isDecryptionFailed_ = failed; }
+
+    friend class ConsumerImpl;
+};
+
+}  // namespace pulsar
diff --git a/include/pulsar/Message.h b/include/pulsar/Message.h
index ea4c4ab..f52879e 100644
--- a/include/pulsar/Message.h
+++ b/include/pulsar/Message.h
@@ -19,10 +19,12 @@
 #ifndef MESSAGE_HPP_
 #define MESSAGE_HPP_
 
+#include <pulsar/EncryptionContext.h>
 #include <pulsar/defines.h>
 
 #include <map>
 #include <memory>
+#include <optional>
 #include <string>
 
 #include "KeyValue.h"
@@ -202,6 +204,12 @@ class PULSAR_PUBLIC Message {
      */
     const std::string& getProducerName() const noexcept;
 
+    /**
+     * @return the optional encryption context that is present when the message is encrypted, the pointer is
+     * valid as the Message instance is alive
+     */
+    std::optional<const EncryptionContext*> getEncryptionContext() const;
+
     bool operator==(const Message& msg) const;
 
    protected:
diff --git a/lib/Commands.cc b/lib/Commands.cc
index 3c687c0..30f5bf1 100644
--- a/lib/Commands.cc
+++ b/lib/Commands.cc
@@ -930,6 +930,7 @@ Message Commands::deSerializeSingleMessageInBatch(Message& batchedMessage, int32
                           batchedMessage.impl_->metadata, payload, metadata,
                           batchedMessage.impl_->topicName_);
     singleMessage.impl_->cnx_ = batchedMessage.impl_->cnx_;
+    singleMessage.impl_->encryptionContext_ = batchedMessage.impl_->encryptionContext_;
 
     return singleMessage;
 }
diff --git a/lib/ConsumerImpl.cc b/lib/ConsumerImpl.cc
index 4781e96..430b851 100644
--- a/lib/ConsumerImpl.cc
+++ b/lib/ConsumerImpl.cc
@@ -19,6 +19,7 @@
 #include "ConsumerImpl.h"
 
 #include <pulsar/DeadLetterPolicyBuilder.h>
+#include <pulsar/EncryptionContext.h>
 #include <pulsar/MessageIdBuilder.h>
 
 #include <algorithm>
@@ -549,24 +550,27 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto::
                                    proto::MessageMetadata& metadata, SharedBuffer& payload) {
     LOG_DEBUG(getName() << "Received Message -- Size: " << payload.readableBytes());
 
-    if (!decryptMessageIfNeeded(cnx, msg, metadata, payload)) {
-        // Message was discarded or not consumed due to decryption failure
-        return;
-    }
-
     if (!isChecksumValid) {
         // Message discarded for checksum error
         discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_ChecksumMismatch);
         return;
     }
 
-    auto redeliveryCount = msg.redelivery_count();
-    const bool isMessageUndecryptable =
-        metadata.encryption_keys_size() > 0 && !config_.getCryptoKeyReader().get() &&
-        config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME;
+    auto encryptionContext = metadata.encryption_keys_size() > 0
+                                 ? optional<EncryptionContext>(std::in_place, metadata, false)
+                                 : std::nullopt;
+    const auto decryptionResult = decryptMessageIfNeeded(cnx, msg, encryptionContext, payload);
+    if (decryptionResult == DecryptionResult::FAILED) {
+        // Message was discarded or not consumed due to decryption failure
+        return;
+    } else if (decryptionResult == DecryptionResult::CONSUME_ENCRYPTED && encryptionContext.has_value()) {
+        // Message is encrypted, but we let the application consume it as-is
+        encryptionContext->setDecryptionFailed(true);
+    }
 
+    auto redeliveryCount = msg.redelivery_count();
     const bool isChunkedMessage = metadata.num_chunks_from_msg() > 1;
-    if (!isMessageUndecryptable && !isChunkedMessage) {
+    if (decryptionResult == DecryptionResult::SUCCESS && !isChunkedMessage) {
         if (!uncompressMessageIfNeeded(cnx, msg.message_id(), metadata, payload, true)) {
             // Message was discarded on decompression error
             return;
@@ -590,6 +594,7 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto::
     m.impl_->cnx_ = cnx.get();
     m.impl_->setTopicName(getTopicPtr());
     m.impl_->setRedeliveryCount(msg.redelivery_count());
+    m.impl_->encryptionContext_ = std::move(encryptionContext);
 
     if (metadata.has_schema_version()) {
         m.impl_->setSchemaVersion(metadata.schema_version());
@@ -610,7 +615,7 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto::
         return;
     }
 
-    if (metadata.has_num_messages_in_batch()) {
+    if (metadata.has_num_messages_in_batch() && decryptionResult == DecryptionResult::SUCCESS) {
         BitSet::Data words(msg.ack_set_size());
         for (int i = 0; i < words.size(); i++) {
             words[i] = msg.ack_set(i);
@@ -812,17 +817,18 @@ uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnection
     return batchSize - skippedMessages;
 }
 
-bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg,
-                                          const proto::MessageMetadata& metadata, SharedBuffer& payload) {
-    if (!metadata.encryption_keys_size()) {
-        return true;
+auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg,
+                                          const optional<EncryptionContext>& context, SharedBuffer& payload)
+    -> DecryptionResult {
+    if (!context.has_value()) {
+        return DecryptionResult::SUCCESS;
     }
 
     // If KeyReader is not configured throw exception based on config param
     if (!config_.isEncryptionEnabled()) {
         if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) {
             LOG_WARN(getName() << "CryptoKeyReader is not implemented. Consuming encrypted message.");
-            return true;
+            return DecryptionResult::CONSUME_ENCRYPTED;
         } else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) {
             LOG_WARN(getName() << "Skipping decryption since CryptoKeyReader is not implemented and config "
                                   "is set to discard");
@@ -833,20 +839,20 @@ bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const
             auto messageId = MessageIdBuilder::from(msg.message_id()).build();
             unAckedMessageTrackerPtr_->add(messageId);
         }
-        return false;
+        return DecryptionResult::FAILED;
     }
 
     SharedBuffer decryptedPayload;
-    if (msgCrypto_->decrypt(metadata, payload, config_.getCryptoKeyReader(), decryptedPayload)) {
+    if (msgCrypto_->decrypt(*context, payload, config_.getCryptoKeyReader(), decryptedPayload)) {
         payload = decryptedPayload;
-        return true;
+        return DecryptionResult::SUCCESS;
     }
 
     if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) {
         // Note, batch message will fail to consume even if config is set to consume
         LOG_WARN(
             getName() << "Decryption failed. Consuming encrypted message since config is set to consume.");
-        return true;
+        return DecryptionResult::CONSUME_ENCRYPTED;
     } else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) {
         LOG_WARN(getName() << "Discarding message since decryption failed and config is set to discard");
         discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_DecryptionError);
@@ -855,7 +861,7 @@ bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const
         auto messageId = MessageIdBuilder::from(msg.message_id()).build();
         unAckedMessageTrackerPtr_->add(messageId);
     }
-    return false;
+    return DecryptionResult::FAILED;
 }
 
 bool ConsumerImpl::uncompressMessageIfNeeded(const ClientConnectionPtr& cnx,
diff --git a/lib/ConsumerImpl.h b/lib/ConsumerImpl.h
index c1df080..63eb51d 100644
--- a/lib/ConsumerImpl.h
+++ b/lib/ConsumerImpl.h
@@ -195,8 +195,15 @@ class ConsumerImpl : public ConsumerImplBase {
     bool isPriorEntryIndex(int64_t idx);
     void brokerConsumerStatsListener(Result, BrokerConsumerStatsImpl, const BrokerConsumerStatsCallback&);
 
-    bool decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg,
-                                const proto::MessageMetadata& metadata, SharedBuffer& payload);
+    enum class DecryptionResult : uint8_t
+    {
+        SUCCESS,
+        CONSUME_ENCRYPTED,
+        FAILED
+    };
+    DecryptionResult decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg,
+                                            const optional<EncryptionContext>& context,
+                                            SharedBuffer& payload);
 
     // TODO - Convert these functions to lambda when we move to C++11
     Result receiveHelper(Message& msg);
diff --git a/lib/EncryptionContext.cc b/lib/EncryptionContext.cc
new file mode 100644
index 0000000..5376f06
--- /dev/null
+++ b/lib/EncryptionContext.cc
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <pulsar/EncryptionContext.h>
+
+#include "PulsarApi.pb.h"
+
+namespace pulsar {
+
+static EncryptionContext::KeysType encryptedKeysFromMetadata(const proto::MessageMetadata& msgMetadata) {
+    EncryptionContext::KeysType keys;
+    for (auto&& key : msgMetadata.encryption_keys()) {
+        decltype(EncryptionKey::metadata) metadata;
+        for (int i = 0; i < key.metadata_size(); i++) {
+            const auto& entry = key.metadata(i);
+            metadata[entry.key()] = entry.value();
+        }
+        keys.emplace_back(key.key(), key.value(), std::move(metadata));
+    }
+    return keys;
+}
+
+EncryptionContext::EncryptionContext(const proto::MessageMetadata& msgMetadata, bool isDecryptionFailed)
+
+    : keys_(encryptedKeysFromMetadata(msgMetadata)),
+      param_(msgMetadata.encryption_param()),
+      algorithm_(msgMetadata.encryption_algo()),
+      compressionType_(static_cast<CompressionType>(msgMetadata.compression())),
+      uncompressedMessageSize_(msgMetadata.uncompressed_size()),
+      batchSize_(msgMetadata.has_num_messages_in_batch() ? msgMetadata.num_messages_in_batch() : -1),
+      isDecryptionFailed_(isDecryptionFailed) {}
+
+}  // namespace pulsar
diff --git a/lib/Message.cc b/lib/Message.cc
index 1e26b52..9505565 100644
--- a/lib/Message.cc
+++ b/lib/Message.cc
@@ -220,6 +220,13 @@ const std::string& Message::getProducerName() const noexcept {
     return impl_->metadata.producer_name();
 }
 
+std::optional<const EncryptionContext*> Message::getEncryptionContext() const {
+    if (!impl_ || !impl_->encryptionContext_.has_value()) {
+        return std::nullopt;
+    }
+    return &impl_->encryptionContext_.value();
+}
+
 bool Message::operator==(const Message& msg) const { return getMessageId() == msg.getMessageId(); }
 
 KeyValue Message::getKeyValueData() const { return KeyValue(impl_->keyValuePtr); }
diff --git a/lib/MessageCrypto.cc b/lib/MessageCrypto.cc
index b06ff65..daa492e 100644
--- a/lib/MessageCrypto.cc
+++ b/lib/MessageCrypto.cc
@@ -394,13 +394,13 @@ bool MessageCrypto::encrypt(const std::set<std::string>& encKeys, const CryptoKe
     return true;
 }
 
-bool MessageCrypto::decryptDataKey(const proto::EncryptionKeys& encKeys, const CryptoKeyReader& keyReader) {
-    const auto& keyName = encKeys.key();
-    const auto& encryptedDataKey = encKeys.value();
-    const auto& encKeyMeta = encKeys.metadata();
+bool MessageCrypto::decryptDataKey(const EncryptionKey& encKeys, const CryptoKeyReader& keyReader) {
+    const auto& keyName = encKeys.key;
+    const auto& encryptedDataKey = encKeys.value;
+    const auto& encKeyMeta = encKeys.metadata;
     StringMap keyMeta;
     for (auto iter = encKeyMeta.begin(); iter != encKeyMeta.end(); iter++) {
-        keyMeta[iter->key()] = iter->value();
+        keyMeta[iter->first] = iter->second;
     }
 
     // Read the private key info using callback
@@ -451,11 +451,10 @@ bool MessageCrypto::decryptDataKey(const proto::EncryptionKeys& encKeys, const C
     return true;
 }
 
-bool MessageCrypto::decryptData(const std::string& dataKeySecret, const proto::MessageMetadata& msgMetadata,
+bool MessageCrypto::decryptData(const std::string& dataKeySecret, const EncryptionContext& context,
                                 SharedBuffer& payload, SharedBuffer& decryptedPayload) {
     // unpack iv and encrypted data
-    msgMetadata.encryption_param().copy(reinterpret_cast<char*>(iv_.get()),
-                                        msgMetadata.encryption_param().size());
+    context.param().copy(reinterpret_cast<char*>(iv_.get()), context.param().size());
 
     EVP_CIPHER_CTX* cipherCtx = NULL;
     decryptedPayload = SharedBuffer::allocate(payload.readableBytes() + EVP_MAX_BLOCK_LENGTH + tagLen_);
@@ -518,15 +517,14 @@ bool MessageCrypto::decryptData(const std::string& dataKeySecret, const proto::M
     return true;
 }
 
-bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload,
+bool MessageCrypto::getKeyAndDecryptData(const EncryptionContext& context, SharedBuffer& payload,
                                          SharedBuffer& decryptedPayload) {
     SharedBuffer decryptedData;
     bool dataDecrypted = false;
 
-    for (auto iter = msgMetadata.encryption_keys().begin(); iter != msgMetadata.encryption_keys().end();
-         iter++) {
-        const std::string& keyName = iter->key();
-        const std::string& encDataKey = iter->value();
+    for (auto&& kv : context.keys()) {
+        const std::string& keyName = kv.key;
+        const std::string& encDataKey = kv.value;
         unsigned char keyDigest[EVP_MAX_MD_SIZE];
         unsigned int digestLen = 0;
         getDigest(keyName, encDataKey.c_str(), encDataKey.size(), keyDigest, digestLen);
@@ -539,7 +537,7 @@ bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetada
             // retruns a different key, decryption fails. At this point, we would
             // call decryptDataKey to refresh the cache and come here again to decrypt.
             auto dataKeyEntry = dataKeyCacheIter->second;
-            if (decryptData(dataKeyEntry.first, msgMetadata, payload, decryptedPayload)) {
+            if (decryptData(dataKeyEntry.first, context, payload, decryptedPayload)) {
                 dataDecrypted = true;
                 break;
             }
@@ -552,17 +550,16 @@ bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetada
     return dataDecrypted;
 }
 
-bool MessageCrypto::decrypt(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload,
+bool MessageCrypto::decrypt(const EncryptionContext& context, SharedBuffer& payload,
                             const CryptoKeyReaderPtr& keyReader, SharedBuffer& decryptedPayload) {
     // Attempt to decrypt using the existing key
-    if (getKeyAndDecryptData(msgMetadata, payload, decryptedPayload)) {
+    if (getKeyAndDecryptData(context, payload, decryptedPayload)) {
         return true;
     }
 
     // Either first time, or decryption failed. Attempt to regenerate data key
     bool isDataKeyDecrypted = false;
-    for (int index = 0; index < msgMetadata.encryption_keys_size(); index++) {
-        const proto::EncryptionKeys& encKeys = msgMetadata.encryption_keys(index);
+    for (auto&& encKeys : context.keys()) {
         if (decryptDataKey(encKeys, *keyReader)) {
             isDataKeyDecrypted = true;
             break;
@@ -574,7 +571,7 @@ bool MessageCrypto::decrypt(const proto::MessageMetadata& msgMetadata, SharedBuf
         return false;
     }
 
-    return getKeyAndDecryptData(msgMetadata, payload, decryptedPayload);
+    return getKeyAndDecryptData(context, payload, decryptedPayload);
 }
 
 } /* namespace pulsar */
diff --git a/lib/MessageCrypto.h b/lib/MessageCrypto.h
index cd07bf5..4052066 100644
--- a/lib/MessageCrypto.h
+++ b/lib/MessageCrypto.h
@@ -26,10 +26,10 @@
 #include <openssl/rsa.h>
 #include <openssl/ssl.h>
 #include <pulsar/CryptoKeyReader.h>
+#include <pulsar/EncryptionContext.h>
 
 #include <boost/date_time/posix_time/ptime.hpp>
 #include <boost/scoped_array.hpp>
-#include <iostream>
 #include <map>
 #include <mutex>
 #include <set>
@@ -90,15 +90,15 @@ class MessageCrypto {
     /*
      * Decrypt the payload using the data key. Keys used to encrypt data key can be retrieved from msgMetadata
      *
-     * @param msgMetadata Message Metadata
+     * @param context the encryption context
      * @param payload Message which needs to be decrypted
      * @param keyReader KeyReader implementation to retrieve key value
      * @param decryptedPayload Contains decrypted payload if success
      *
      * @return true if success
      */
-    bool decrypt(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload,
-                 const CryptoKeyReaderPtr& keyReader, SharedBuffer& decryptedPayload);
+    bool decrypt(const EncryptionContext& context, SharedBuffer& payload, const CryptoKeyReaderPtr& keyReader,
+                 SharedBuffer& decryptedPayload);
 
    private:
     typedef std::unique_lock<std::mutex> Lock;
@@ -137,10 +137,10 @@ class MessageCrypto {
 
     Result addPublicKeyCipher(const std::string& keyName, const CryptoKeyReaderPtr& keyReader);
 
-    bool decryptDataKey(const proto::EncryptionKeys& encKeys, const CryptoKeyReader& keyReader);
-    bool decryptData(const std::string& dataKeySecret, const proto::MessageMetadata& msgMetadata,
+    bool decryptDataKey(const EncryptionKey& encKeys, const CryptoKeyReader& keyReader);
+    bool decryptData(const std::string& dataKeySecret, const EncryptionContext& context,
                      SharedBuffer& payload, SharedBuffer& decPayload);
-    bool getKeyAndDecryptData(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload,
+    bool getKeyAndDecryptData(const EncryptionContext& context, SharedBuffer& payload,
                               SharedBuffer& decryptedPayload);
     std::string stringToHex(const std::string& inputStr, size_t len);
     std::string stringToHex(const char* inputStr, size_t len);
diff --git a/lib/MessageImpl.h b/lib/MessageImpl.h
index 6467b35..a234ca4 100644
--- a/lib/MessageImpl.h
+++ b/lib/MessageImpl.h
@@ -22,9 +22,12 @@
 #include <pulsar/Message.h>
 #include <pulsar/MessageId.h>
 
+#include <optional>
+
 #include "KeyValueImpl.h"
 #include "PulsarApi.pb.h"
 #include "SharedBuffer.h"
+#include "pulsar/EncryptionContext.h"
 
 using namespace pulsar;
 namespace pulsar {
@@ -48,6 +51,7 @@ class MessageImpl {
     bool hasSchemaVersion_;
     const std::string* schemaVersion_;
     std::weak_ptr<class ConsumerImpl> consumerPtr_;
+    std::optional<EncryptionContext> encryptionContext_;
 
     const std::string& getPartitionKey() const;
     bool hasPartitionKey() const;
diff --git a/tests/BasicEndToEndTest.cc b/tests/BasicEndToEndTest.cc
index c9a8faa..9a02df0 100644
--- a/tests/BasicEndToEndTest.cc
+++ b/tests/BasicEndToEndTest.cc
@@ -1465,6 +1465,10 @@ TEST(BasicEndToEndTest, testRSAEncryption) {
             expected << msgContent << msgNum;
             ASSERT_EQ(expected.str(), msgReceived.getDataAsString());
             ASSERT_EQ(ResultOk, consumer.acknowledge(msgReceived));
+            auto context = msgReceived.getEncryptionContext();
+            ASSERT_TRUE(context.has_value());
+            ASSERT_EQ(context.value()->keys().size(), 1);
+            ASSERT_EQ(context.value()->keys()[0].key, "client-rsa.pem");
         }
 
         ASSERT_EQ(ResultOk, consumer.unsubscribe());
diff --git a/tests/EncryptionTest.cc b/tests/EncryptionTest.cc
new file mode 100644
index 0000000..ff5cb98
--- /dev/null
+++ b/tests/EncryptionTest.cc
@@ -0,0 +1,145 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <pulsar/Client.h>
+#include <pulsar/ConsumerCryptoFailureAction.h>
+#include <pulsar/MessageBatch.h>
+
+#include <optional>
+#include <stdexcept>
+
+#include "lib/CompressionCodec.h"
+#include "lib/MessageCrypto.h"
+#include "lib/SharedBuffer.h"
+
+static std::string lookupUrl = "pulsar://localhost:6650";
+
+using namespace pulsar;
+
+static CryptoKeyReaderPtr getDefaultCryptoKeyReader() {
+    return std::make_shared<DefaultCryptoKeyReader>(TEST_CONF_DIR "/public-key.client-rsa.pem",
+                                                    TEST_CONF_DIR "/private-key.client-rsa.pem");
+}
+
+static std::vector<std::string> decryptValue(const char* data, size_t length,
+                                             std::optional<const EncryptionContext*> context) {
+    if (!context.has_value()) {
+        return {std::string(data, length)};
+    }
+    if (!context.value()->isDecryptionFailed()) {
+        return {std::string(data, length)};
+    }
+
+    MessageCrypto crypto{"test", false};
+    SharedBuffer decryptedPayload;
+    auto originalPayload = SharedBuffer::copy(data, length);
+    if (!crypto.decrypt(*context.value(), originalPayload, getDefaultCryptoKeyReader(), decryptedPayload)) {
+        throw std::runtime_error("Decryption failed");
+    }
+
+    SharedBuffer uncompressedPayload;
+    if (!CompressionCodecProvider::getCodec(context.value()->compressionType())
+             .decode(decryptedPayload, context.value()->uncompressedMessageSize(), uncompressedPayload)) {
+        throw std::runtime_error("Decompression failed");
+    }
+
+    std::vector<std::string> values;
+    if (auto batchSize = context.value()->batchSize(); batchSize > 0) {
+        MessageBatch batch;
+        for (auto&& msg : batch.parseFrom(uncompressedPayload, batchSize).messages()) {
+            values.emplace_back(msg.getDataAsString());
+        }
+    } else {
+        // non-batched message
+        values.emplace_back(uncompressedPayload.data(), uncompressedPayload.readableBytes());
+    }
+    return values;
+}
+
+static void testDecryption(Client& client, const std::string& topic, bool withDecryption,
+                           int numMessageReceived) {
+    ProducerConfiguration producerConf;
+    producerConf.setCompressionType(CompressionLZ4);
+    producerConf.addEncryptionKey("client-rsa.pem");
+    producerConf.setCryptoKeyReader(getDefaultCryptoKeyReader());
+
+    Producer producer;
+    ASSERT_EQ(ResultOk, client.createProducer(topic, producerConf, producer));
+
+    std::vector<std::string> sentValues;
+    auto send = [&producer, &sentValues](const std::string& value) {
+        Message msg = MessageBuilder().setContent(value).build();
+        producer.sendAsync(msg, nullptr);
+        sentValues.emplace_back(value);
+    };
+
+    for (int i = 0; i < 5; i++) {
+        send("msg-" + std::to_string(i));
+    }
+    producer.flush();
+    send("last-msg");
+    producer.flush();
+
+    ASSERT_EQ(ResultOk, client.createProducer(topic, producer));
+    send("unencrypted-msg");
+    producer.flush();
+    producer.close();
+
+    ConsumerConfiguration consumerConf;
+    consumerConf.setSubscriptionInitialPosition(InitialPositionEarliest);
+    if (withDecryption) {
+        consumerConf.setCryptoKeyReader(getDefaultCryptoKeyReader());
+    } else {
+        consumerConf.setCryptoFailureAction(ConsumerCryptoFailureAction::CONSUME);
+    }
+    Consumer consumer;
+    ASSERT_EQ(ResultOk, client.subscribe(topic, "sub", consumerConf, consumer));
+
+    std::vector<std::string> values;
+    for (int i = 0; i < numMessageReceived; i++) {
+        Message msg;
+        ASSERT_EQ(ResultOk, consumer.receive(msg, 3000));
+        if (i < numMessageReceived - 1) {
+            ASSERT_TRUE(msg.getEncryptionContext().has_value());
+        } else {
+            ASSERT_FALSE(msg.getEncryptionContext().has_value());
+        }
+        for (auto&& value : decryptValue(static_cast<const char*>(msg.getData()), msg.getLength(),
+                                         msg.getEncryptionContext())) {
+            values.emplace_back(value);
+        }
+    }
+    ASSERT_EQ(values, sentValues);
+    consumer.close();
+}
+
+TEST(EncryptionTests, testDecryptionSuccess) {
+    Client client{lookupUrl};
+    std::string topic = "test-decryption-success-" + std::to_string(time(nullptr));
+    testDecryption(client, topic, true, 7);
+    client.close();
+}
+
+TEST(EncryptionTests, testDecryptionFailure) {
+    Client client{lookupUrl};
+    std::string topic = "test-decryption-failure-" + std::to_string(time(nullptr));
+    // The 1st batch that has 5 messages cannot be decrypted, so they can be received only once
+    testDecryption(client, topic, false, 3);
+    client.close();
+}
diff --git a/win-examples/CMakeLists.txt b/win-examples/CMakeLists.txt
index 3998c43..c8d74b6 100644
--- a/win-examples/CMakeLists.txt
+++ b/win-examples/CMakeLists.txt
@@ -20,6 +20,7 @@
 cmake_minimum_required(VERSION 3.4)
 project(pulsar-cpp-win-examples)
 
+set(CMAKE_CXX_STANDARD 17)
 find_path(PULSAR_INCLUDES NAMES "pulsar/Client.h")
 if (PULSAR_INCLUDES)
     message(STATUS "PULSAR_INCLUDES: " ${PULSAR_INCLUDES})

Reply via email to