This is an automated email from the ASF dual-hosted git repository. guangmingchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/brpc.git
The following commit(s) were added to refs/heads/master by this push: new a1877bc0 Support various payload of baidu-std: json, proto-json and proto-text (#2946) a1877bc0 is described below commit a1877bc07748c4027f5eadf779f496931e7c1270 Author: Bright Chen <chenguangmin...@foxmail.com> AuthorDate: Wed Apr 16 17:03:30 2025 +0800 Support various payload of baidu-std: json, proto-json and proto-text (#2946) * Support various payload of baidu-std: json, proto-json and proto-text * Refactor implementation of compress --- src/brpc/compress.cpp | 39 ++++- src/brpc/compress.h | 136 +++++++++++++++++ src/brpc/controller.cpp | 2 + src/brpc/controller.h | 19 +++ src/brpc/global.cpp | 11 +- src/brpc/memcache.cpp | 4 +- src/brpc/nonreflectable_message.h | 2 +- src/brpc/nshead_message.cpp | 2 +- src/brpc/options.proto | 7 + src/brpc/policy/baidu_rpc_meta.proto | 1 + src/brpc/policy/baidu_rpc_protocol.cpp | 253 +++++++++++++++++++++++++++---- src/brpc/policy/baidu_rpc_protocol.h | 7 + src/brpc/policy/gzip_compress.cpp | 108 ++++++++----- src/brpc/policy/snappy_compress.cpp | 51 +++++-- src/brpc/proto_base.proto | 3 + src/brpc/protocol.cpp | 7 +- src/brpc/redis.cpp | 4 +- src/brpc/serialized_request.cpp | 2 +- src/brpc/serialized_request.h | 1 - src/brpc/serialized_response.cpp | 2 +- src/brpc/serialized_response.h | 1 - src/json2pb/json_to_pb.cpp | 7 +- src/json2pb/json_to_pb.h | 6 +- src/json2pb/pb_to_json.cpp | 9 +- src/json2pb/pb_to_json.h | 13 +- src/json2pb/protobuf_type_resolver.h | 2 + test/brpc_http_rpc_protocol_unittest.cpp | 11 +- test/brpc_server_unittest.cpp | 183 +++++++++++++++------- test/brpc_socket_unittest.cpp | 43 ++++-- test/bthread_cond_unittest.cpp | 2 +- 30 files changed, 729 insertions(+), 209 deletions(-) diff --git a/src/brpc/compress.cpp b/src/brpc/compress.cpp index dd9ba825..36e55c7c 100644 --- a/src/brpc/compress.cpp +++ b/src/brpc/compress.cpp @@ -17,9 +17,10 @@ #include "butil/logging.h" +#include "json2pb/json_to_pb.h" #include "brpc/compress.h" #include "brpc/protocol.h" - +#include "brpc/proto_base.pb.h" namespace brpc { @@ -47,7 +48,7 @@ int RegisterCompressHandler(CompressType type, // Find CompressHandler by type. // Returns NULL if not found -inline const CompressHandler* FindCompressHandler(CompressType type) { +const CompressHandler* FindCompressHandler(CompressType type) { int index = type; if (index < 0 || index >= MAX_HANDLER_SIZE) { LOG(ERROR) << "CompressType=" << type << " is out of range"; @@ -83,10 +84,14 @@ bool ParseFromCompressedData(const butil::IOBuf& data, return ParsePbFromIOBuf(msg, data); } const CompressHandler* handler = FindCompressHandler(compress_type); - if (NULL != handler) { - return handler->Decompress(data, msg); + if (NULL == handler) { + return false; } - return false; + + Deserializer deserializer([msg](google::protobuf::io::ZeroCopyInputStream* input) { + return msg->ParseFromZeroCopyStream(input); + }); + return handler->Decompress(data, &deserializer); } bool SerializeAsCompressedData(const google::protobuf::Message& msg, @@ -96,10 +101,28 @@ bool SerializeAsCompressedData(const google::protobuf::Message& msg, return msg.SerializeToZeroCopyStream(&wrapper); } const CompressHandler* handler = FindCompressHandler(compress_type); - if (NULL != handler) { - return handler->Compress(msg, buf); + if (NULL == handler) { + return false; } - return false; + + Serializer serializer([&msg](google::protobuf::io::ZeroCopyOutputStream* output) { + return msg.SerializeToZeroCopyStream(output); + }); + return handler->Compress(serializer, buf); +} + +::google::protobuf::Metadata Serializer::GetMetadata() const { + ::google::protobuf::Metadata metadata{}; + metadata.descriptor = SerializerBase::descriptor(); + metadata.reflection = nullptr; + return metadata; +} + +::google::protobuf::Metadata Deserializer::GetMetadata() const { + ::google::protobuf::Metadata metadata{}; + metadata.descriptor = DeserializerBase::descriptor(); + metadata.reflection = nullptr; + return metadata; } } // namespace brpc diff --git a/src/brpc/compress.h b/src/brpc/compress.h index 0b0fcb17..45299598 100644 --- a/src/brpc/compress.h +++ b/src/brpc/compress.h @@ -21,10 +21,143 @@ #include <google/protobuf/message.h> // Message #include "butil/iobuf.h" // butil::IOBuf +#include "butil/logging.h" #include "brpc/options.pb.h" // CompressType +#include "brpc/nonreflectable_message.h" namespace brpc { +// Serializer can be used to implement custom serialization +// before compression with user callback. +class Serializer : public NonreflectableMessage<Serializer> { +public: + using Callback = std::function<bool(google::protobuf::io::ZeroCopyOutputStream*)>; + + Serializer() :Serializer(NULL) {} + + explicit Serializer(Callback callback) + :_callback(std::move(callback)) { + SharedCtor(); + } + + ~Serializer() override { + SharedDtor(); + } + + Serializer(const Serializer& from) + : NonreflectableMessage(from) { + SharedCtor(); + MergeFrom(from); + } + + Serializer& operator=(const Serializer& from) { + CopyFrom(from); + return *this; + } + + void Swap(Serializer* other) { + if (other != this) { + } + } + + void MergeFrom(const Serializer& from) override { + CHECK_NE(&from, this); + } + + // implements Message ---------------------------------------------- + void Clear() override { + _callback = nullptr; + } + size_t ByteSizeLong() const override { return 0; } + int GetCachedSize() const PB_425_OVERRIDE { return ByteSize(); } + + ::google::protobuf::Metadata GetMetadata() const PB_527_OVERRIDE; + + // Converts the data into `output' for later compression. + bool SerializeTo(google::protobuf::io::ZeroCopyOutputStream* output) const { + if (!_callback) { + LOG(WARNING) << "Serializer::SerializeTo() called without callback"; + return false; + } + return _callback(output); + } + + void SetCallback(Callback callback) { + _callback = std::move(callback); + } + +private: + void SharedCtor() {} + void SharedDtor() {} + + Callback _callback; +}; + +// Deserializer can be used to implement custom deserialization +// after decompression with user callback. +class Deserializer : public NonreflectableMessage<Deserializer> { +public: +public: + using Callback = std::function<bool(google::protobuf::io::ZeroCopyInputStream*)>; + + Deserializer() :Deserializer(NULL) {} + + explicit Deserializer(Callback callback) : _callback(std::move(callback)) { + SharedCtor(); + } + + ~Deserializer() override { + SharedDtor(); + } + + Deserializer(const Deserializer& from) + : NonreflectableMessage(from) { + SharedCtor(); + MergeFrom(from); + } + + Deserializer& operator=(const Deserializer& from) { + CopyFrom(from); + return *this; + } + + void Swap(Deserializer* other) { + if (other != this) { + _callback.swap(other->_callback); + } + } + + void MergeFrom(const Deserializer& from) override { + CHECK_NE(&from, this); + _callback = from._callback; + } + + // implements Message ---------------------------------------------- + void Clear() override { _callback = nullptr; } + size_t ByteSizeLong() const override { return 0; } + int GetCachedSize() const PB_425_OVERRIDE { return ByteSize(); } + + ::google::protobuf::Metadata GetMetadata() const PB_527_OVERRIDE; + + // Converts the decompressed `input'. + bool DeserializeFrom(google::protobuf::io::ZeroCopyInputStream* intput) const { + if (!_callback) { + LOG(WARNING) << "Deserializer::DeserializeFrom() called without callback"; + return false; + } + return _callback(intput); + } + void SetCallback(Callback callback) { + _callback = std::move(callback); + } + +private: + void SharedCtor() {} + void SharedDtor() {} + + Callback _callback; +}; + struct CompressHandler { // Compress serialized `msg' into `buf'. // Returns true on success, false otherwise @@ -42,6 +175,9 @@ struct CompressHandler { // Returns 0 on success, -1 otherwise int RegisterCompressHandler(CompressType type, CompressHandler handler); +// Returns CompressHandler pointer of `type' if registered, NULL otherwise. +const CompressHandler* FindCompressHandler(CompressType type); + // Returns the `name' of the CompressType if registered const char* CompressTypeToCStr(CompressType type); diff --git a/src/brpc/controller.cpp b/src/brpc/controller.cpp index 271ccfa4..0cb83dc5 100644 --- a/src/brpc/controller.cpp +++ b/src/brpc/controller.cpp @@ -290,6 +290,8 @@ void Controller::ResetPods() { _http_response = NULL; _request_user_fields = NULL; _response_user_fields = NULL; + _request_content_type = CONTENT_TYPE_PB; + _response_content_type = CONTENT_TYPE_PB; _request_streams.clear(); _response_streams.clear(); _remote_stream_settings = NULL; diff --git a/src/brpc/controller.h b/src/brpc/controller.h index d9799f88..000aee2f 100644 --- a/src/brpc/controller.h +++ b/src/brpc/controller.h @@ -616,6 +616,20 @@ public: void CallAfterRpcResp(const google::protobuf::Message* req, const google::protobuf::Message* res); + void set_request_content_type(ContentType type) { + _request_content_type = type; + } + ContentType request_content_type() const { + return _request_content_type; + } + + void set_response_content_type(ContentType type) { + _response_content_type = type; + } + ContentType response_content_type() const { + return _response_content_type; + } + private: struct CompletionInfo { CallId id; // call_id of the corresponding request @@ -859,6 +873,11 @@ private: butil::IOBuf _request_attachment; butil::IOBuf _response_attachment; + // Only SerializedRequest supports `_request_content_type'. + ContentType _request_content_type; + // Only SerializedResponse supports `_response_content_type'. + ContentType _response_content_type; + // Writable progressive attachment butil::intrusive_ptr<ProgressiveAttachment> _wpa; // Readable progressive attachment diff --git a/src/brpc/global.cpp b/src/brpc/global.cpp index d45c67f0..6b3310ec 100644 --- a/src/brpc/global.cpp +++ b/src/brpc/global.cpp @@ -388,25 +388,22 @@ static void GlobalInitializeOrDieImpl() { LoadBalancerExtension()->RegisterOrDie("_dynpart", &g_ext->dynpart_lb); // Compress Handlers - const CompressHandler gzip_compress = - { GzipCompress, GzipDecompress, "gzip" }; + CompressHandler gzip_compress = { GzipCompress, GzipDecompress, "gzip" }; if (RegisterCompressHandler(COMPRESS_TYPE_GZIP, gzip_compress) != 0) { exit(1); } - const CompressHandler zlib_compress = - { ZlibCompress, ZlibDecompress, "zlib" }; + CompressHandler zlib_compress = { ZlibCompress, ZlibDecompress, "zlib" }; if (RegisterCompressHandler(COMPRESS_TYPE_ZLIB, zlib_compress) != 0) { exit(1); } - const CompressHandler snappy_compress = - { SnappyCompress, SnappyDecompress, "snappy" }; + CompressHandler snappy_compress = { SnappyCompress, SnappyDecompress, "snappy" }; if (RegisterCompressHandler(COMPRESS_TYPE_SNAPPY, snappy_compress) != 0) { exit(1); } // Protocols Protocol baidu_protocol = { ParseRpcMessage, - SerializeRequestDefault, PackRpcRequest, + SerializeRpcRequest, PackRpcRequest, ProcessRpcRequest, ProcessRpcResponse, VerifyRpcRequest, NULL, NULL, CONNECTION_TYPE_ALL, "baidu_std" }; diff --git a/src/brpc/memcache.cpp b/src/brpc/memcache.cpp index c198d168..489d84db 100644 --- a/src/brpc/memcache.cpp +++ b/src/brpc/memcache.cpp @@ -32,7 +32,7 @@ MemcacheRequest::MemcacheRequest() } MemcacheRequest::MemcacheRequest(const MemcacheRequest& from) - : NonreflectableMessage<MemcacheRequest>() { + : NonreflectableMessage<MemcacheRequest>(from) { SharedCtor(); MergeFrom(from); } @@ -143,7 +143,7 @@ MemcacheResponse::MemcacheResponse() } MemcacheResponse::MemcacheResponse(const MemcacheResponse& from) - : NonreflectableMessage<MemcacheResponse>() { + : NonreflectableMessage<MemcacheResponse>(from) { SharedCtor(); MergeFrom(from); } diff --git a/src/brpc/nonreflectable_message.h b/src/brpc/nonreflectable_message.h index 54a479fc..1494cd1b 100644 --- a/src/brpc/nonreflectable_message.h +++ b/src/brpc/nonreflectable_message.h @@ -21,7 +21,7 @@ #include <google/protobuf/generated_message_reflection.h> #include <google/protobuf/message.h> -#include "pb_compat.h" +#include "brpc/pb_compat.h" namespace brpc { diff --git a/src/brpc/nshead_message.cpp b/src/brpc/nshead_message.cpp index fe9e4c96..46081c70 100644 --- a/src/brpc/nshead_message.cpp +++ b/src/brpc/nshead_message.cpp @@ -28,7 +28,7 @@ NsheadMessage::NsheadMessage() } NsheadMessage::NsheadMessage(const NsheadMessage& from) - : NonreflectableMessage<NsheadMessage>() { + : NonreflectableMessage<NsheadMessage>(from) { SharedCtor(); MergeFrom(from); } diff --git a/src/brpc/options.proto b/src/brpc/options.proto index 3e34b5f6..e334c48e 100644 --- a/src/brpc/options.proto +++ b/src/brpc/options.proto @@ -74,6 +74,13 @@ enum CompressType { COMPRESS_TYPE_LZ4 = 4; } +enum ContentType { + CONTENT_TYPE_PB = 0; + CONTENT_TYPE_JSON = 1; + CONTENT_TYPE_PROTO_JSON = 2; + CONTENT_TYPE_PROTO_TEXT = 3; +} + message ChunkInfo { required int64 stream_id = 1; required int64 chunk_id = 2; diff --git a/src/brpc/policy/baidu_rpc_meta.proto b/src/brpc/policy/baidu_rpc_meta.proto index 300564bb..59179831 100644 --- a/src/brpc/policy/baidu_rpc_meta.proto +++ b/src/brpc/policy/baidu_rpc_meta.proto @@ -33,6 +33,7 @@ message RpcMeta { optional bytes authentication_data = 7; optional StreamSettings stream_settings = 8; map<string, string> user_fields = 9; + optional ContentType content_type = 10; } message RpcRequestMeta { diff --git a/src/brpc/policy/baidu_rpc_protocol.cpp b/src/brpc/policy/baidu_rpc_protocol.cpp index 269f0645..8efff065 100644 --- a/src/brpc/policy/baidu_rpc_protocol.cpp +++ b/src/brpc/policy/baidu_rpc_protocol.cpp @@ -20,11 +20,13 @@ #include <google/protobuf/message.h> // Message #include <google/protobuf/io/zero_copy_stream_impl_lite.h> #include <google/protobuf/io/coded_stream.h> +#include <google/protobuf/text_format.h> #include "butil/logging.h" // LOG() -#include "butil/time.h" #include "butil/iobuf.h" // butil::IOBuf #include "butil/raw_pack.h" // RawPacker RawUnpacker #include "butil/memory/scope_guard.h" +#include "json2pb/json_to_pb.h" +#include "json2pb/pb_to_json.h" #include "brpc/controller.h" // Controller #include "brpc/socket.h" // Socket #include "brpc/server.h" // Server @@ -56,6 +58,8 @@ DEFINE_bool(baidu_protocol_use_fullname, true, DEFINE_bool(baidu_std_protocol_deliver_timeout_ms, false, "If this flag is true, baidu_std puts timeout_ms in requests."); +DECLARE_bool(pb_enum_as_number); + // Notes: // 1. 12-byte header [PRPC][body_size][meta_size] // 2. body_size and meta_size are in network byte order @@ -137,23 +141,95 @@ ParseResult ParseRpcMessage(butil::IOBuf* source, Socket* socket, return MakeMessage(msg); } +bool SerializeRpcMessage(const google::protobuf::Message& message, + Controller& cntl, ContentType content_type, + CompressType compress_type, butil::IOBuf* buf) { + auto serialize = [&](Serializer& serializer) -> bool { + bool ok; + if (COMPRESS_TYPE_NONE == compress_type) { + butil::IOBufAsZeroCopyOutputStream stream(buf); + ok = serializer.SerializeTo(&stream); + } else { + const CompressHandler* handler = FindCompressHandler(compress_type); + if (NULL == handler) { + return false; + } + ok = handler->Compress(serializer, buf); + } + return ok; + }; + + if (CONTENT_TYPE_PB == content_type) { + Serializer serializer([&message](google::protobuf::io::ZeroCopyOutputStream* output) -> bool { + return message.SerializeToZeroCopyStream(output); + }); + return serialize(serializer); + } else if (CONTENT_TYPE_JSON == content_type) { + Serializer serializer([&message, &cntl](google::protobuf::io::ZeroCopyOutputStream* output) -> bool { + json2pb::Pb2JsonOptions options; + options.bytes_to_base64 = cntl.has_pb_bytes_to_base64(); + options.jsonify_empty_array = cntl.has_pb_jsonify_empty_array(); + options.always_print_primitive_fields = cntl.has_always_print_primitive_fields(); + options.single_repeated_to_array = cntl.has_pb_single_repeated_to_array(); + options.enum_option = FLAGS_pb_enum_as_number + ? json2pb::OUTPUT_ENUM_BY_NUMBER + : json2pb::OUTPUT_ENUM_BY_NAME; + std::string error; + bool ok = json2pb::ProtoMessageToJson(message, output, options, &error); + if (!ok) { + LOG(INFO) << "Fail to serialize message=" + << message.GetDescriptor()->full_name() + << " to json :" << error; + } + return ok; + }); + return serialize(serializer); + } else if (CONTENT_TYPE_PROTO_JSON == content_type) { + Serializer serializer([&message, &cntl](google::protobuf::io::ZeroCopyOutputStream* output) -> bool { + json2pb::Pb2ProtoJsonOptions options; + options.always_print_enums_as_ints = FLAGS_pb_enum_as_number; + AlwaysPrintPrimitiveFields(options) = cntl.has_always_print_primitive_fields(); + std::string error; + bool ok = json2pb::ProtoMessageToProtoJson(message, output, options, &error); + if (!ok) { + LOG(INFO) << "Fail to serialize message=" + << message.GetDescriptor()->full_name() + << " to proto-json :" << error; + } + return ok; + }); + return serialize(serializer); + } else if (CONTENT_TYPE_PROTO_TEXT == content_type) { + Serializer serializer([&message](google::protobuf::io::ZeroCopyOutputStream* output) -> bool { + return google::protobuf::TextFormat::Print(message, output); + }); + return serialize(serializer); + } + return false; +} + static bool SerializeResponse(const google::protobuf::Message& res, - Controller& cntl, CompressType compress_type, - butil::IOBuf& buf) { + Controller& cntl, butil::IOBuf& buf) { if (res.GetDescriptor() == SerializedResponse::descriptor()) { buf.swap(((SerializedResponse&)res).serialized_data()); return true; } if (!res.IsInitialized()) { - cntl.SetFailed(ERESPONSE, - "Missing required fields in response: %s", + cntl.SetFailed(ERESPONSE, "Missing required fields in response: %s", res.InitializationErrorString().c_str()); return false; - } else if (!SerializeAsCompressedData(res, &buf, compress_type)) { - cntl.SetFailed(ERESPONSE, - "Fail to serialize response, CompressType=%s", - CompressTypeToCStr(compress_type)); + } + + ContentType content_type = cntl.response_content_type(); + CompressType compress_type = cntl.response_compress_type(); + if (!SerializeRpcMessage(res, cntl, content_type, compress_type, &buf)) { + cntl.SetFailed( + ERESPONSE, "Fail to serialize response=%s, " + "ContentType=%s, CompressType=%s", + res.GetDescriptor()->full_name().c_str(), + ContentTypeToCStr(content_type), + CompressTypeToCStr(compress_type)); return false; } return true; @@ -234,8 +310,7 @@ void SendRpcResponse(int64_t correlation_id, Controller* cntl, // response either CompressType compress_type = cntl->response_compress_type(); if (res != NULL && !cntl->Failed()) { - append_body = SerializeResponse( - *res, *cntl, compress_type, res_body); + append_body = SerializeResponse(*res, *cntl, res_body); } // Don't use res->ByteSize() since it may be compressed @@ -262,6 +337,7 @@ void SendRpcResponse(int64_t correlation_id, Controller* cntl, } meta.set_correlation_id(correlation_id); meta.set_compress_type(compress_type); + meta.set_content_type(cntl->response_content_type()); if (attached_size > 0) { meta.set_attachment_size(attached_size); } @@ -408,6 +484,71 @@ void EndRunningCallMethodInPool( return EndRunningUserCodeInPool(CallMethodInBackupThread, args); }; +bool DeserializeRpcMessage(const butil::IOBuf& data, Controller& cntl, + ContentType content_type, CompressType compress_type, + google::protobuf::Message* message) { + auto deserialize = [&](Deserializer& deserializer) -> bool { + bool ok; + if (COMPRESS_TYPE_NONE == compress_type) { + butil::IOBufAsZeroCopyInputStream stream(data); + ok = deserializer.DeserializeFrom(&stream); + } else { + const CompressHandler* handler = FindCompressHandler(compress_type); + if (NULL == handler) { + return false; + } + ok = handler->Decompress(data, &deserializer); + } + return ok; + }; + + if (CONTENT_TYPE_PB == content_type) { + Deserializer deserializer([message]( + google::protobuf::io::ZeroCopyInputStream* input) -> bool { + return message->ParseFromZeroCopyStream(input); + }); + return deserialize(deserializer); + } else if (CONTENT_TYPE_JSON == content_type) { + Deserializer deserializer([message, &cntl]( + google::protobuf::io::ZeroCopyInputStream* input) -> bool { + json2pb::Json2PbOptions options; + options.base64_to_bytes = cntl.has_pb_bytes_to_base64(); + options.array_to_single_repeated = cntl.has_pb_single_repeated_to_array(); + std::string error; + bool ok = json2pb::JsonToProtoMessage(input, message, options, &error); + if (!ok) { + LOG(INFO) << "Fail to parse json to " + << message->GetDescriptor()->full_name() + << ": "<< error; + } + return ok; + }); + return deserialize(deserializer); + } else if (CONTENT_TYPE_PROTO_JSON == content_type) { + Deserializer deserializer([message]( + google::protobuf::io::ZeroCopyInputStream* input) -> bool { + json2pb::ProtoJson2PbOptions options; + options.ignore_unknown_fields = true; + std::string error; + bool ok = json2pb::ProtoJsonToProtoMessage(input, message, options, &error); + if (!ok) { + LOG(INFO) << "Fail to parse proto-json to " + << message->GetDescriptor()->full_name() + << ": "<< error; + } + return ok; + }); + return deserialize(deserializer); + } else if (CONTENT_TYPE_PROTO_TEXT == content_type) { + Deserializer deserializer([message]( + google::protobuf::io::ZeroCopyInputStream* input) -> bool { + return google::protobuf::TextFormat::Parse(input, message); + }); + return deserialize(deserializer); + } + return false; +} + void ProcessRpcRequest(InputMessageBase* msg_base) { const int64_t start_parse_us = butil::cpuwide_time_us(); DestroyingPtr<MostCommonMessage> msg(static_cast<MostCommonMessage*>(msg_base)); @@ -458,6 +599,7 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { if (request_meta.has_timeout_ms()) { cntl->set_timeout_ms(request_meta.timeout_ms()); } + cntl->set_request_content_type(meta.content_type()); cntl->set_request_compress_type((CompressType)meta.compress_type()); accessor.set_server(server) .set_security_mode(security_mode) @@ -640,12 +782,17 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { cntl->request_attachment().swap(msg->payload); } - auto req_cmp_type = static_cast<CompressType>(meta.compress_type()); + ContentType content_type = meta.content_type(); + auto compress_type = static_cast<CompressType>(meta.compress_type()); messages = server->options().rpc_pb_message_factory->Get(*svc, *method); - if (!ParseFromCompressedData(req_buf, messages->Request(), req_cmp_type)) { - cntl->SetFailed(EREQUEST, "Fail to parse request message, " - "CompressType=%s, request_size=%d", - CompressTypeToCStr(req_cmp_type), req_size); + if (!DeserializeRpcMessage(req_buf, *cntl, content_type, + compress_type, messages->Request())) { + cntl->SetFailed( + EREQUEST, "Fail to parse request=%s, ContentType=%s, " + "CompressType=%s, request_size=%d", + messages->Request()->GetDescriptor()->full_name().c_str(), + ContentTypeToCStr(content_type), + CompressTypeToCStr(compress_type), req_size); break; } req_buf.clear(); @@ -654,11 +801,9 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { // `socket' will be held until response has been sent google::protobuf::Closure* done = ::brpc::NewCallback< int64_t, Controller*, RpcPBMessages*, - const Server*, MethodStatus*, int64_t>(&SendRpcResponse, - meta.correlation_id(), - cntl.get(), messages, - server, method_status, - msg->received_us()); + const Server*, MethodStatus*, int64_t>( + &SendRpcResponse, meta.correlation_id(),cntl.get(), + messages, server, method_status, msg->received_us()); // optional, just release resource ASAP msg.reset(); @@ -799,8 +944,7 @@ void ProcessRpcResponse(InputMessageBase* msg_base) { if (meta.has_attachment_size()) { if (meta.attachment_size() > res_size) { cntl->SetFailed( - ERESPONSE, - "attachment_size=%d is larger than response_size=%d", + ERESPONSE, "attachment_size=%d is larger than response_size=%d", meta.attachment_size(), res_size); break; } @@ -810,20 +954,24 @@ void ProcessRpcResponse(InputMessageBase* msg_base) { cntl->response_attachment().swap(msg->payload); } - auto res_cmp_type = (CompressType)meta.compress_type(); - cntl->set_response_compress_type(res_cmp_type); + ContentType content_type = meta.content_type(); + auto compress_type = (CompressType)meta.compress_type(); + cntl->set_response_content_type(content_type); + cntl->set_response_compress_type(compress_type); if (cntl->response()) { if (cntl->response()->GetDescriptor() == SerializedResponse::descriptor()) { ((SerializedResponse*)cntl->response())-> serialized_data().append(*res_buf_ptr); - } else if (!ParseFromCompressedData( - *res_buf_ptr, cntl->response(), res_cmp_type)) { + } else if (!DeserializeRpcMessage(*res_buf_ptr, *cntl, content_type, + compress_type, cntl->response())) { cntl->SetFailed( - ERESPONSE, "Fail to parse response message, " - "CompressType=%s, response_size=%d", - CompressTypeToCStr(res_cmp_type), res_size); + EREQUEST, "Fail to parse response=%s, ContentType=%s, " + "CompressType=%s, request_size=%d", + cntl->response()->GetDescriptor()->full_name().c_str(), + ContentTypeToCStr(content_type), + CompressTypeToCStr(compress_type), res_size); } - } // else silently ignore the response. + } // else silently ignore the response. } while (0); // Unlocks correlation_id inside. Revert controller's // error code if it version check of `cid' fails @@ -831,6 +979,33 @@ void ProcessRpcResponse(InputMessageBase* msg_base) { accessor.OnResponse(cid, saved_error); } +void SerializeRpcRequest(butil::IOBuf* request_buf, Controller* cntl, + const google::protobuf::Message* request) { + // Check sanity of request. + if (NULL == request) { + return cntl->SetFailed(EREQUEST, "`request' is NULL"); + } + if (request->GetDescriptor() == SerializedRequest::descriptor()) { + request_buf->append(((SerializedRequest*)request)->serialized_data()); + return; + } + if (!request->IsInitialized()) { + return cntl->SetFailed(EREQUEST, "Missing required fields in request: %s", + request->InitializationErrorString().c_str()); + } + + ContentType content_type = cntl->request_content_type(); + CompressType compress_type = cntl->request_compress_type(); + if (!SerializeRpcMessage(*request, *cntl, content_type, compress_type, request_buf)) { + return cntl->SetFailed( + EREQUEST, "Fail to compress request=%s, " + "ContentType=%s, CompressType=%s", + request->GetDescriptor()->full_name().c_str(), + ContentTypeToCStr(content_type), + CompressTypeToCStr(compress_type)); + } +} + void PackRpcRequest(butil::IOBuf* req_buf, SocketMessage**, uint64_t correlation_id, @@ -904,6 +1079,7 @@ void PackRpcRequest(butil::IOBuf* req_buf, request_meta->set_timeout_ms(accessor.real_timeout_ms()); } } + meta.set_content_type(cntl->request_content_type()); Span* span = accessor.span(); if (span) { @@ -919,5 +1095,20 @@ void PackRpcRequest(butil::IOBuf* req_buf, } } +const char* ContentTypeToCStr(ContentType content_type) { + switch (content_type) { + case CONTENT_TYPE_PB: + return "pb"; + case CONTENT_TYPE_JSON: + return "json"; + case CONTENT_TYPE_PROTO_JSON: + return "proto-json"; + case CONTENT_TYPE_PROTO_TEXT: + return "proto-text"; + default: + return "unknown"; + } +} + } // namespace policy } // namespace brpc diff --git a/src/brpc/policy/baidu_rpc_protocol.h b/src/brpc/policy/baidu_rpc_protocol.h index e3e4954b..77ecc780 100644 --- a/src/brpc/policy/baidu_rpc_protocol.h +++ b/src/brpc/policy/baidu_rpc_protocol.h @@ -37,6 +37,10 @@ void ProcessRpcResponse(InputMessageBase* msg); // Verify authentication information in baidu_std format bool VerifyRpcRequest(const InputMessageBase* msg); +// Serialize `request' into `buf'. +void SerializeRpcRequest(butil::IOBuf* request_buf, Controller* cntl, + const google::protobuf::Message* request); + // Pack `request' to `method' into `buf'. void PackRpcRequest(butil::IOBuf* buf, SocketMessage**, @@ -46,6 +50,9 @@ void PackRpcRequest(butil::IOBuf* buf, const butil::IOBuf& request, const Authenticator* auth); +// Returns the `name' of the 'content_type'. +const char* ContentTypeToCStr(ContentType content_type); + } // namespace policy } // namespace brpc diff --git a/src/brpc/policy/gzip_compress.cpp b/src/brpc/policy/gzip_compress.cpp index 35367e3e..e8c77a55 100644 --- a/src/brpc/policy/gzip_compress.cpp +++ b/src/brpc/policy/gzip_compress.cpp @@ -17,57 +17,89 @@ #include <google/protobuf/io/gzip_stream.h> // GzipXXXStream +#include <google/protobuf/text_format.h> #include "butil/logging.h" #include "brpc/policy/gzip_compress.h" #include "brpc/protocol.h" - +#include "brpc/compress.h" namespace brpc { namespace policy { -static void LogError(const google::protobuf::io::GzipOutputStream& gzip) { - if (gzip.ZlibErrorMessage()) { - LOG(WARNING) << "Fail to decompress: " << gzip.ZlibErrorMessage(); - } else { - LOG(WARNING) << "Fail to decompress."; +const char* Format2CStr(google::protobuf::io::GzipOutputStream::Format format) { + switch (format) { + case google::protobuf::io::GzipOutputStream::GZIP: + return "gzip"; + case google::protobuf::io::GzipOutputStream::ZLIB: + return "zlib"; + default: + return "unknown"; } } -static void LogError(const google::protobuf::io::GzipInputStream& gzip) { - if (gzip.ZlibErrorMessage()) { - LOG(WARNING) << "Fail to decompress: " << gzip.ZlibErrorMessage(); - } else { - LOG(WARNING) << "Fail to decompress."; +const char* Format2CStr(google::protobuf::io::GzipInputStream::Format format) { + switch (format) { + case google::protobuf::io::GzipInputStream::GZIP: + return "gzip"; + case google::protobuf::io::GzipInputStream::ZLIB: + return "zlib"; + default: + return "unknown"; } } -bool GzipCompress(const google::protobuf::Message& msg, butil::IOBuf* buf) { +static bool Compress(const google::protobuf::Message& msg, butil::IOBuf* buf, + google::protobuf::io::GzipOutputStream::Format format) { butil::IOBufAsZeroCopyOutputStream wrapper(buf); - google::protobuf::io::GzipOutputStream::Options gzip_opt; - gzip_opt.format = google::protobuf::io::GzipOutputStream::GZIP; - google::protobuf::io::GzipOutputStream gzip(&wrapper, gzip_opt); - if (!msg.SerializeToZeroCopyStream(&gzip)) { - LogError(gzip); - return false; + GzipCompressOptions options; + options.format = format; + google::protobuf::io::GzipOutputStream gzip(&wrapper, options); + bool ok; + if (msg.GetDescriptor() == Serializer::descriptor()) { + ok = ((const Serializer&)msg).SerializeTo(&gzip); + } else { + ok = msg.SerializeToZeroCopyStream(&gzip); + } + if (!ok) { + LOG(WARNING) << "Fail to serialize input message=" + << msg.GetDescriptor()->full_name() + << ", format=" << Format2CStr(format) << " : " + << (NULL == gzip.ZlibErrorMessage() ? "" : gzip.ZlibErrorMessage()); } - return gzip.Close(); + return ok && gzip.Close(); } -bool GzipDecompress(const butil::IOBuf& data, google::protobuf::Message* msg) { +static bool Decompress(const butil::IOBuf& data, google::protobuf::Message* msg, + google::protobuf::io::GzipInputStream::Format format) { butil::IOBufAsZeroCopyInputStream wrapper(data); - google::protobuf::io::GzipInputStream gzip( - &wrapper, google::protobuf::io::GzipInputStream::GZIP); - if (!ParsePbFromZeroCopyStream(msg, &gzip)) { - LogError(gzip); - return false; + google::protobuf::io::GzipInputStream gzip(&wrapper, format); + bool ok; + if (msg->GetDescriptor() == Deserializer::descriptor()) { + ok = ((Deserializer*)msg)->DeserializeFrom(&gzip); + } else { + ok = msg->ParseFromZeroCopyStream(&gzip); } - return true; + if (!ok) { + LOG(WARNING) << "Fail to deserialize input message=" + << msg->GetDescriptor()->full_name() + << ", format=" << Format2CStr(format) << " : " + << (NULL == gzip.ZlibErrorMessage() ? "" : gzip.ZlibErrorMessage()); + } + return ok; +} + +bool GzipCompress(const google::protobuf::Message& msg, butil::IOBuf* buf) { + return Compress(msg, buf, google::protobuf::io::GzipOutputStream::GZIP); +} + +bool GzipDecompress(const butil::IOBuf& data, google::protobuf::Message* msg) { + return Decompress(data, msg, google::protobuf::io::GzipInputStream::GZIP); } bool GzipCompress(const butil::IOBuf& msg, butil::IOBuf* buf, const GzipCompressOptions* options_in) { butil::IOBufAsZeroCopyOutputStream wrapper(buf); - google::protobuf::io::GzipOutputStream::Options gzip_opt; + GzipCompressOptions gzip_opt; if (options_in) { gzip_opt = *options_in; } @@ -93,7 +125,8 @@ bool GzipCompress(const butil::IOBuf& msg, butil::IOBuf* buf, } if (size_in != 0 || (size_t)in.ByteCount() != msg.size()) { // If any stage is not fully consumed, something went wrong. - LogError(out); + LOG(WARNING) << "Fail to compress, format=" << Format2CStr(gzip_opt.format) + << " : " << out.ZlibErrorMessage(); return false; } if (size_out != 0) { @@ -132,7 +165,8 @@ inline bool GzipDecompressBase( // If any stage is not fully consumed, something went wrong. // Here we call in.Next addtitionally to make sure that the gzip // "blackbox" does not have buffer left. - LogError(in); + LOG(WARNING) << "Fail to decompress, format=" << Format2CStr(format) + << " : " << in.ZlibErrorMessage(); return false; } if (size_out != 0) { @@ -141,19 +175,13 @@ inline bool GzipDecompressBase( return true; } -bool ZlibCompress(const google::protobuf::Message& res, butil::IOBuf* buf) { - butil::IOBufAsZeroCopyOutputStream wrapper(buf); - google::protobuf::io::GzipOutputStream::Options zlib_opt; - zlib_opt.format = google::protobuf::io::GzipOutputStream::ZLIB; - google::protobuf::io::GzipOutputStream zlib(&wrapper, zlib_opt); - return res.SerializeToZeroCopyStream(&zlib) && zlib.Close(); +bool ZlibCompress(const google::protobuf::Message& msg, butil::IOBuf* buf) { + return Compress(msg, buf, google::protobuf::io::GzipOutputStream::ZLIB); } -bool ZlibDecompress(const butil::IOBuf& data, google::protobuf::Message* req) { - butil::IOBufAsZeroCopyInputStream wrapper(data); - google::protobuf::io::GzipInputStream zlib( - &wrapper, google::protobuf::io::GzipInputStream::ZLIB); - return ParsePbFromZeroCopyStream(req, &zlib); +bool ZlibDecompress(const butil::IOBuf& data, + google::protobuf::Message* msg) { + return Decompress(data, msg, google::protobuf::io::GzipInputStream::ZLIB); } bool GzipDecompress(const butil::IOBuf& data, butil::IOBuf* msg) { diff --git a/src/brpc/policy/snappy_compress.cpp b/src/brpc/policy/snappy_compress.cpp index 77e80170..8019b97b 100644 --- a/src/brpc/policy/snappy_compress.cpp +++ b/src/brpc/policy/snappy_compress.cpp @@ -20,32 +20,53 @@ #include "butil/third_party/snappy/snappy.h" #include "brpc/policy/snappy_compress.h" #include "brpc/protocol.h" - +#include "brpc/compress.h" namespace brpc { namespace policy { -bool SnappyCompress(const google::protobuf::Message& res, butil::IOBuf* buf) { +bool SnappyCompress(const google::protobuf::Message& msg, butil::IOBuf* buf) { butil::IOBuf serialized_pb; butil::IOBufAsZeroCopyOutputStream wrapper(&serialized_pb); - if (res.SerializeToZeroCopyStream(&wrapper)) { - butil::IOBufAsSnappySource source(serialized_pb); - butil::IOBufAsSnappySink sink(*buf); - return butil::snappy::Compress(&source, &sink); + bool ok; + if (msg.GetDescriptor() == Serializer::descriptor()) { + ok = ((const Serializer&)msg).SerializeTo(&wrapper); + } else { + ok = msg.SerializeToZeroCopyStream(&wrapper); + } + if (!ok) { + LOG(WARNING) << "Fail to serialize input pb=" + << msg.GetDescriptor()->full_name(); + return false; + } + + ok = SnappyCompress(serialized_pb, buf); + if (!ok) { + LOG(WARNING) << "Fail to snappy::Compress, size=" + << serialized_pb.size(); } - LOG(WARNING) << "Fail to serialize input pb=" << &res; - return false; + return ok; } -bool SnappyDecompress(const butil::IOBuf& data, google::protobuf::Message* req) { - butil::IOBufAsSnappySource source(data); +bool SnappyDecompress(const butil::IOBuf& data, google::protobuf::Message* msg) { butil::IOBuf binary_pb; - butil::IOBufAsSnappySink sink(binary_pb); - if (butil::snappy::Uncompress(&source, &sink)) { - return ParsePbFromIOBuf(req, binary_pb); + if (!SnappyDecompress(data, &binary_pb)) { + LOG(WARNING) << "Fail to snappy::Uncompress, size=" << data.size(); + return false; + } + + bool ok; + butil::IOBufAsZeroCopyInputStream stream(binary_pb); + if (msg->GetDescriptor() == Deserializer::descriptor()) { + ok = ((Deserializer*)msg)->DeserializeFrom(&stream); + } else { + ok = msg->ParseFromZeroCopyStream(&stream); + } + if (!ok) { + LOG(WARNING) << "Fail to eserialize input message=" + << msg->GetDescriptor()->full_name(); } - LOG(WARNING) << "Fail to snappy::Uncompress, size=" << data.size(); - return false; + return ok; } bool SnappyCompress(const butil::IOBuf& in, butil::IOBuf* out) { diff --git a/src/brpc/proto_base.proto b/src/brpc/proto_base.proto index 30033d49..b278ddb6 100644 --- a/src/brpc/proto_base.proto +++ b/src/brpc/proto_base.proto @@ -33,6 +33,9 @@ message NsheadMessageBase {} message SerializedRequestBase {} message SerializedResponseBase {} +message SerializerBase {} +message DeserializerBase {} + message ThriftFramedMessageBase {} service BaiduMasterServiceBase {} diff --git a/src/brpc/protocol.cpp b/src/brpc/protocol.cpp index e0468c22..9bb1fde3 100644 --- a/src/brpc/protocol.cpp +++ b/src/brpc/protocol.cpp @@ -130,17 +130,12 @@ void ListProtocols(std::vector<std::pair<ProtocolType, Protocol> >* vec) { } } -void SerializeRequestDefault(butil::IOBuf* buf, - Controller* cntl, +void SerializeRequestDefault(butil::IOBuf* buf, Controller* cntl, const google::protobuf::Message* request) { // Check sanity of request. if (!request) { return cntl->SetFailed(EREQUEST, "`request' is NULL"); } - if (request->GetDescriptor() == SerializedRequest::descriptor()) { - buf->append(((SerializedRequest*)request)->serialized_data()); - return; - } if (!request->IsInitialized()) { return cntl->SetFailed( EREQUEST, "Missing required fields in request: %s", diff --git a/src/brpc/redis.cpp b/src/brpc/redis.cpp index f8870ae5..9af036f8 100644 --- a/src/brpc/redis.cpp +++ b/src/brpc/redis.cpp @@ -35,7 +35,7 @@ RedisRequest::RedisRequest() } RedisRequest::RedisRequest(const RedisRequest& from) - : NonreflectableMessage<RedisRequest>() { + : NonreflectableMessage<RedisRequest>(from) { SharedCtor(); MergeFrom(from); } @@ -200,7 +200,7 @@ RedisResponse::RedisResponse() SharedCtor(); } RedisResponse::RedisResponse(const RedisResponse& from) - : NonreflectableMessage<RedisResponse>() + : NonreflectableMessage<RedisResponse>(from) , _first_reply(&_arena) { SharedCtor(); MergeFrom(from); diff --git a/src/brpc/serialized_request.cpp b/src/brpc/serialized_request.cpp index f4cabad2..ac55e31e 100644 --- a/src/brpc/serialized_request.cpp +++ b/src/brpc/serialized_request.cpp @@ -28,7 +28,7 @@ SerializedRequest::SerializedRequest() } SerializedRequest::SerializedRequest(const SerializedRequest& from) - : NonreflectableMessage<SerializedRequest>() { + : NonreflectableMessage<SerializedRequest>(from) { SharedCtor(); MergeFrom(from); } diff --git a/src/brpc/serialized_request.h b/src/brpc/serialized_request.h index 00d959f3..4d69aa42 100644 --- a/src/brpc/serialized_request.h +++ b/src/brpc/serialized_request.h @@ -54,7 +54,6 @@ private: void SharedCtor(); void SharedDtor(); -private: butil::IOBuf _serialized; }; diff --git a/src/brpc/serialized_response.cpp b/src/brpc/serialized_response.cpp index 6d5d8fef..c8466451 100644 --- a/src/brpc/serialized_response.cpp +++ b/src/brpc/serialized_response.cpp @@ -28,7 +28,7 @@ SerializedResponse::SerializedResponse() } SerializedResponse::SerializedResponse(const SerializedResponse& from) - : NonreflectableMessage<SerializedResponse>() { + : NonreflectableMessage<SerializedResponse>(from) { SharedCtor(); MergeFrom(from); } diff --git a/src/brpc/serialized_response.h b/src/brpc/serialized_response.h index a724be4d..acd18a2a 100644 --- a/src/brpc/serialized_response.h +++ b/src/brpc/serialized_response.h @@ -54,7 +54,6 @@ private: void SharedCtor(); void SharedDtor(); -private: butil::IOBuf _serialized; }; diff --git a/src/json2pb/json_to_pb.cpp b/src/json2pb/json_to_pb.cpp index 53887a38..42d4772e 100644 --- a/src/json2pb/json_to_pb.cpp +++ b/src/json2pb/json_to_pb.cpp @@ -718,14 +718,11 @@ bool ProtoJsonToProtoMessage(google::protobuf::io::ZeroCopyInputStream* json, const ProtoJson2PbOptions& options, std::string* error) { TypeResolverUniqueptr type_resolver = GetTypeResolver(*message); + std::string type_url = GetTypeUrl(*message); butil::IOBuf buf; butil::IOBufAsZeroCopyOutputStream output_stream(&buf); - std::string type_url = GetTypeUrl(*message); auto st = google::protobuf::util::JsonToBinaryStream( type_resolver.get(), type_url, json, &output_stream, options); - - butil::IOBufAsZeroCopyInputStream input_stream(buf); - google::protobuf::io::CodedInputStream decoder(&input_stream); if (!st.ok()) { if (NULL != error) { *error = st.ToString(); @@ -733,6 +730,8 @@ bool ProtoJsonToProtoMessage(google::protobuf::io::ZeroCopyInputStream* json, return false; } + butil::IOBufAsZeroCopyInputStream input_stream(buf); + google::protobuf::io::CodedInputStream decoder(&input_stream); bool ok = message->ParseFromCodedStream(&decoder); if (!ok && NULL != error) { *error = "Fail to ParseFromCodedStream"; diff --git a/src/json2pb/json_to_pb.h b/src/json2pb/json_to_pb.h index 78eb15b6..3734ef31 100644 --- a/src/json2pb/json_to_pb.h +++ b/src/json2pb/json_to_pb.h @@ -92,11 +92,11 @@ using ProtoJson2PbOptions = google::protobuf::util::JsonParseOptions; // See https://protobuf.dev/programming-guides/json/ for details. bool ProtoJsonToProtoMessage(google::protobuf::io::ZeroCopyInputStream* json, google::protobuf::Message* message, - const ProtoJson2PbOptions& options, + const ProtoJson2PbOptions& options = ProtoJson2PbOptions(), std::string* error = NULL); -// Use default GoogleJson2PbOptions. bool ProtoJsonToProtoMessage(const std::string& json, google::protobuf::Message* message, - const ProtoJson2PbOptions& options, std::string* error = NULL); + const ProtoJson2PbOptions& options = ProtoJson2PbOptions(), + std::string* error = NULL); } // namespace json2pb diff --git a/src/json2pb/pb_to_json.cpp b/src/json2pb/pb_to_json.cpp index 0dc94814..c23ccdf7 100644 --- a/src/json2pb/pb_to_json.cpp +++ b/src/json2pb/pb_to_json.cpp @@ -349,16 +349,15 @@ bool ProtoMessageToJson(const google::protobuf::Message& message, } bool ProtoMessageToProtoJson(const google::protobuf::Message& message, - google::protobuf::io::ZeroCopyOutputStream* json, - const Pb2ProtoJsonOptions& options, std::string* error) { - TypeResolverUniqueptr type_resolver = GetTypeResolver(message); + google::protobuf::io::ZeroCopyOutputStream* json, + const Pb2ProtoJsonOptions& options, std::string* error) { butil::IOBuf buf; butil::IOBufAsZeroCopyOutputStream output_stream(&buf); - google::protobuf::io::CodedOutputStream coded_stream(&output_stream); - if (!message.SerializeToCodedStream(&coded_stream)) { + if (!message.SerializeToZeroCopyStream(&output_stream)) { return false; } + TypeResolverUniqueptr type_resolver = GetTypeResolver(message); butil::IOBufAsZeroCopyInputStream input_stream(buf); auto st = google::protobuf::util::BinaryToJsonStream( type_resolver.get(), GetTypeUrl(message), &input_stream, json, options); diff --git a/src/json2pb/pb_to_json.h b/src/json2pb/pb_to_json.h index 33311ffb..8de63517 100644 --- a/src/json2pb/pb_to_json.h +++ b/src/json2pb/pb_to_json.h @@ -95,14 +95,21 @@ bool ProtoMessageToJson(const google::protobuf::Message& message, // See <google/protobuf/util/json_util.h> for details. using Pb2ProtoJsonOptions = google::protobuf::util::JsonOptions; +#if GOOGLE_PROTOBUF_VERSION >= 5026002 +#define AlwaysPrintPrimitiveFields(options) options.always_print_fields_with_no_presence +#else +#define AlwaysPrintPrimitiveFields(options) options.always_print_primitive_fields +#endif + // Convert protobuf `messge' to `json' in ProtoJSON format according to `options'. // See https://protobuf.dev/programming-guides/json/ for details. bool ProtoMessageToProtoJson(const google::protobuf::Message& message, google::protobuf::io::ZeroCopyOutputStream* json, - const Pb2ProtoJsonOptions& options, std::string* error = NULL); -// Using default GooglePb2JsonOptions. + const Pb2ProtoJsonOptions& options = Pb2ProtoJsonOptions(), + std::string* error = NULL); bool ProtoMessageToProtoJson(const google::protobuf::Message& message, std::string* json, - const Pb2ProtoJsonOptions& options, std::string* error = NULL); + const Pb2ProtoJsonOptions& options = Pb2ProtoJsonOptions(), + std::string* error = NULL); } // namespace json2pb #endif // BRPC_JSON2PB_PB_TO_JSON_H diff --git a/src/json2pb/protobuf_type_resolver.h b/src/json2pb/protobuf_type_resolver.h index 18993f18..a73a4231 100644 --- a/src/json2pb/protobuf_type_resolver.h +++ b/src/json2pb/protobuf_type_resolver.h @@ -35,6 +35,8 @@ inline std::string GetTypeUrl(const google::protobuf::Message& message) { message.GetDescriptor()->full_name().c_str()); } +// unique_ptr deleter for TypeResolver only deletes the object +// when it's not from the generated pool. class TypeResolverDeleter { public: explicit TypeResolverDeleter(bool is_generated_pool) diff --git a/test/brpc_http_rpc_protocol_unittest.cpp b/test/brpc_http_rpc_protocol_unittest.cpp index 578e8f89..0eca1532 100644 --- a/test/brpc_http_rpc_protocol_unittest.cpp +++ b/test/brpc_http_rpc_protocol_unittest.cpp @@ -189,7 +189,7 @@ protected: return msg; } - brpc::policy::HttpContext* MakePostJsonStdRequestMessage(const std::string& path) { + brpc::policy::HttpContext* MakePostProtoJsonRequestMessage(const std::string& path) { brpc::policy::HttpContext* msg = new brpc::policy::HttpContext(false); msg->header().uri().set_path(path); msg->header().set_content_type("application/proto-json"); @@ -199,7 +199,8 @@ protected: req.set_message(EXP_REQUEST); butil::IOBufAsZeroCopyOutputStream req_stream(&msg->body()); json2pb::Pb2ProtoJsonOptions options; - EXPECT_TRUE(json2pb::ProtoMessageToProtoJson(req, &req_stream, options)); + std::string error; + EXPECT_TRUE(json2pb::ProtoMessageToProtoJson(req, &req_stream, options, &error)) << error; return msg; } @@ -366,7 +367,7 @@ TEST_F(HttpTest, verify_request) { } { brpc::policy::HttpContext* msg = - MakePostJsonStdRequestMessage("/EchoService/Echo"); + MakePostProtoJsonRequestMessage("/EchoService/Echo"); VerifyMessage(msg, false); msg->Destroy(); } @@ -1847,7 +1848,7 @@ TEST_F(HttpTest, proto_json_content_type) { butil::IOBufAsZeroCopyOutputStream output_stream(&cntl.request_attachment()); ASSERT_TRUE(json2pb::ProtoMessageToProtoJson(req, &output_stream, json_options)); channel.CallMethod(nullptr, &cntl, nullptr, nullptr, nullptr); - ASSERT_FALSE(cntl.Failed()); + ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText(); ASSERT_EQ("application/proto-json", cntl.http_response().content_type()); json2pb::ProtoJson2PbOptions parse_options; parse_options.ignore_unknown_fields = true; @@ -1860,7 +1861,7 @@ TEST_F(HttpTest, proto_json_content_type) { cntl.http_request().set_content_type("application/proto-json"); res.Clear(); stub.Echo(&cntl, &req, &res, nullptr); - ASSERT_FALSE(cntl.Failed()); + ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText(); ASSERT_EQ(EXP_RESPONSE, res.message()); ASSERT_EQ("application/proto-json", cntl.http_response().content_type()); } diff --git a/test/brpc_server_unittest.cpp b/test/brpc_server_unittest.cpp index 676f1662..a51b9317 100644 --- a/test/brpc_server_unittest.cpp +++ b/test/brpc_server_unittest.cpp @@ -70,6 +70,13 @@ DECLARE_bool(enable_dir_service); namespace policy { DECLARE_bool(use_http_error_code); + +extern bool SerializeRpcMessage(const google::protobuf::Message& serializer, Controller& cntl, + ContentType content_type, CompressType compress_type, + butil::IOBuf* buf); +extern bool DeserializeRpcMessage(const butil::IOBuf& deserializer, Controller& cntl, + ContentType content_type, CompressType compress_type, + google::protobuf::Message* message); } } @@ -1693,53 +1700,119 @@ public: cntl->sampled_request()->meta.service_name()); ASSERT_TRUE(cntl->sampled_request()->meta.has_method_name()); ASSERT_EQ("Echo", cntl->sampled_request()->meta.method_name()); + brpc::ContentType content_type = cntl->request_content_type(); + brpc::CompressType compress_type = cntl->request_compress_type(); + test::EchoRequest echo_request; test::EchoResponse echo_response; - brpc::CompressType type = cntl->request_compress_type(); - ASSERT_TRUE(brpc::ParseFromCompressedData( - request->serialized_data(), &echo_request, type)); + ASSERT_TRUE(brpc::policy::DeserializeRpcMessage( + request->serialized_data(), *cntl, content_type, compress_type, &echo_request)); ASSERT_EQ(EXP_REQUEST, echo_request.message()); ASSERT_EQ(EXP_REQUEST, cntl->request_attachment().to_string()); - echo_response.set_message(EXP_RESPONSE); - butil::IOBuf compressed_data; - ASSERT_TRUE(brpc::SerializeAsCompressedData( - echo_response, &response->serialized_data(), type)); - cntl->set_response_compress_type(type); + content_type = (brpc::ContentType)_content_type_index; + compress_type = (brpc::CompressType)_compress_type_index; + ++_compress_type_index; + if (_compress_type_index == brpc::COMPRESS_TYPE_LZ4) { + ++_compress_type_index; + } + if (_compress_type_index > brpc::CompressType_MAX) { + _compress_type_index = brpc::CompressType_MIN; + + ++_content_type_index; + if (_content_type_index > brpc::ContentType_MAX) { + _content_type_index = brpc::ContentType_MIN; + } + } + + cntl->set_response_content_type(content_type); + cntl->set_response_compress_type(compress_type); cntl->response_attachment().append(EXP_RESPONSE); + echo_response.set_message(EXP_RESPONSE); + ASSERT_TRUE(brpc::policy::SerializeRpcMessage( + echo_response, *cntl, content_type, compress_type, &response->serialized_data())); } +private: + int _content_type_index = brpc::ContentType_MIN; + int _compress_type_index = brpc::CompressType_MIN; }; -TEST_F(ServerTest, baidu_master_service) { - butil::EndPoint ep; - ASSERT_EQ(0, str2endpoint("127.0.0.1:8613", &ep)); - brpc::Server server; - EchoServiceImpl service; - ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE)); - brpc::ServerOptions opt; - opt.baidu_master_service = new BaiduMasterServiceImpl; - ASSERT_EQ(0, server.Start(ep, &opt)); - - brpc::Channel chan; - brpc::ChannelOptions copt; - copt.protocol = "baidu_std"; - ASSERT_EQ(0, chan.Init(ep, &copt)); +void TestBaiduMasterService(brpc::Channel& channel, brpc::CompressType compress_type) { brpc::Controller cntl; test::EchoRequest req; test::EchoResponse res; req.set_message(EXP_REQUEST); cntl.request_attachment().append(EXP_REQUEST); - cntl.set_request_compress_type(brpc::COMPRESS_TYPE_GZIP); - test::EchoService_Stub stub(&chan); + cntl.set_request_compress_type(compress_type); + test::EchoService_Stub stub(&channel); stub.Echo(&cntl, &req, &res, NULL); ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText(); ASSERT_EQ(EXP_RESPONSE, res.message()); ASSERT_EQ(EXP_RESPONSE, cntl.response_attachment().to_string()); +} + +TEST_F(ServerTest, baidu_master_service) { + butil::EndPoint ep; + ASSERT_EQ(0, str2endpoint("127.0.0.1:8613", &ep)); + brpc::Server server; + EchoServiceImpl service; + ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE)); + brpc::ServerOptions server_options; + server_options.baidu_master_service = new BaiduMasterServiceImpl; + ASSERT_EQ(0, server.Start(ep, &server_options)); + + brpc::Channel channel; + brpc::ChannelOptions channel_options; + channel_options.protocol = "baidu_std"; + ASSERT_EQ(0, channel.Init(ep, &channel_options)); + + for (int i = 0; i < 10; ++i) { + TestBaiduMasterService(channel, brpc::COMPRESS_TYPE_ZLIB); + TestBaiduMasterService(channel, brpc::COMPRESS_TYPE_GZIP); + TestBaiduMasterService(channel, brpc::COMPRESS_TYPE_SNAPPY); + TestBaiduMasterService(channel, brpc::COMPRESS_TYPE_NONE); + } ASSERT_EQ(0, server.Stop(0)); ASSERT_EQ(0, server.Join()); } +void TestGenericCall(brpc::Channel& channel, + brpc::ContentType content_type, + brpc::CompressType compress_type) { + LOG(INFO) << "TestGenericCall: content_type=" << content_type + << ", compress_type=" << compress_type; + test::EchoRequest request; + test::EchoResponse response; + request.set_message(EXP_REQUEST); + + brpc::SerializedResponse serialized_response; + brpc::SerializedRequest serialized_request; + + brpc::Controller cntl; + cntl.set_request_content_type(content_type); + cntl.set_request_compress_type(compress_type); + cntl.request_attachment().append(EXP_REQUEST); + + std::string error; + ASSERT_TRUE(brpc::policy::SerializeRpcMessage( + request, cntl, content_type, compress_type, &serialized_request.serialized_data())); + auto sampled_request = new (std::nothrow) brpc::SampledRequest(); + sampled_request->meta.set_service_name( + test::EchoService::descriptor()->full_name()); + sampled_request->meta.set_method_name( + test::EchoService::descriptor()->FindMethodByName("Echo")->name()); + cntl.reset_sampled_request(sampled_request); + + channel.CallMethod(NULL, &cntl, &serialized_request, &serialized_response, NULL); + ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText(); + + ASSERT_TRUE(brpc::policy::DeserializeRpcMessage(serialized_response.serialized_data(), + cntl, cntl.response_content_type(), + cntl.response_compress_type(), &response)); + ASSERT_EQ(EXP_RESPONSE, response.message()); + ASSERT_EQ(EXP_RESPONSE, cntl.response_attachment().to_string()); +} TEST_F(ServerTest, generic_call) { butil::EndPoint ep; @@ -1747,42 +1820,34 @@ TEST_F(ServerTest, generic_call) { brpc::Server server; EchoServiceImpl service; ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE)); - brpc::ServerOptions opt; - opt.baidu_master_service = new BaiduMasterServiceImpl; - ASSERT_EQ(0, server.Start(ep, &opt)); - - { - brpc::Channel chan; - brpc::ChannelOptions copt; - copt.protocol = "baidu_std"; - ASSERT_EQ(0, chan.Init(ep, &copt)); - brpc::Controller cntl; - test::EchoRequest req; - test::EchoResponse res; - req.set_message(EXP_REQUEST); - - brpc::SerializedResponse serialized_response; - brpc::SerializedRequest serialized_request; - brpc::CompressType type = brpc::COMPRESS_TYPE_GZIP; - ASSERT_TRUE(brpc::SerializeAsCompressedData( - req, &serialized_request.serialized_data(), type)); - cntl.request_attachment().append(EXP_REQUEST); - cntl.set_request_compress_type(type); - auto sampled_request = new (std::nothrow) brpc::SampledRequest(); - sampled_request->meta.set_service_name( - test::EchoService::descriptor()->full_name()); - sampled_request->meta.set_method_name( - test::EchoService::descriptor()->FindMethodByName("Echo")->name()); - cntl.reset_sampled_request(sampled_request); - chan.CallMethod(NULL, &cntl, &serialized_request, &serialized_response, NULL); - ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText(); + brpc::ServerOptions server_options; + server_options.baidu_master_service = new BaiduMasterServiceImpl; + ASSERT_EQ(0, server.Start(ep, &server_options)); - ASSERT_TRUE(brpc::ParseFromCompressedData(serialized_response.serialized_data(), - &res, cntl.response_compress_type())) - << serialized_response.serialized_data().size(); - ASSERT_EQ(EXP_RESPONSE, res.message()); - ASSERT_EQ(EXP_RESPONSE, cntl.response_attachment().to_string()); - } + brpc::Channel channel; + brpc::ChannelOptions channel_options; + channel_options.protocol = "baidu_std"; + ASSERT_EQ(0, channel.Init(ep, &channel_options)); + + TestGenericCall(channel, brpc::CONTENT_TYPE_PB, brpc::COMPRESS_TYPE_ZLIB); + TestGenericCall(channel, brpc::CONTENT_TYPE_PB, brpc::COMPRESS_TYPE_GZIP); + TestGenericCall(channel, brpc::CONTENT_TYPE_PB, brpc::COMPRESS_TYPE_SNAPPY); + TestGenericCall(channel, brpc::CONTENT_TYPE_PB, brpc::COMPRESS_TYPE_NONE); + + TestGenericCall(channel, brpc::CONTENT_TYPE_JSON, brpc::COMPRESS_TYPE_ZLIB); + TestGenericCall(channel, brpc::CONTENT_TYPE_JSON, brpc::COMPRESS_TYPE_GZIP); + TestGenericCall(channel, brpc::CONTENT_TYPE_JSON, brpc::COMPRESS_TYPE_SNAPPY); + TestGenericCall(channel, brpc::CONTENT_TYPE_JSON, brpc::COMPRESS_TYPE_NONE); + + TestGenericCall(channel, brpc::CONTENT_TYPE_PROTO_JSON, brpc::COMPRESS_TYPE_ZLIB); + TestGenericCall(channel, brpc::CONTENT_TYPE_PROTO_JSON, brpc::COMPRESS_TYPE_GZIP); + TestGenericCall(channel, brpc::CONTENT_TYPE_PROTO_JSON, brpc::COMPRESS_TYPE_SNAPPY); + TestGenericCall(channel, brpc::CONTENT_TYPE_PROTO_JSON, brpc::COMPRESS_TYPE_NONE); + + TestGenericCall(channel, brpc::CONTENT_TYPE_PROTO_TEXT, brpc::COMPRESS_TYPE_ZLIB); + TestGenericCall(channel, brpc::CONTENT_TYPE_PROTO_TEXT, brpc::COMPRESS_TYPE_GZIP); + TestGenericCall(channel, brpc::CONTENT_TYPE_PROTO_TEXT, brpc::COMPRESS_TYPE_SNAPPY); + TestGenericCall(channel, brpc::CONTENT_TYPE_PROTO_TEXT, brpc::COMPRESS_TYPE_NONE); ASSERT_EQ(0, server.Stop(0)); ASSERT_EQ(0, server.Join()); diff --git a/test/brpc_socket_unittest.cpp b/test/brpc_socket_unittest.cpp index 0f35863f..8e9f90e8 100644 --- a/test/brpc_socket_unittest.cpp +++ b/test/brpc_socket_unittest.cpp @@ -421,6 +421,7 @@ TEST_F(SocketTest, single_threaded_connect_and_write) { ASSERT_EQ(-1, brpc::Socket::Address(id, &ptr)); messenger->StopAccept(0); + messenger->Join(); ASSERT_EQ(-1, messenger->listened_fd()); ASSERT_EQ(-1, fcntl(listening_fd, F_GETFD)); ASSERT_EQ(EBADF, errno); @@ -789,6 +790,7 @@ TEST_F(SocketTest, health_check) { // Must stop messenger before SetFailed the id otherwise StartHealthCheck // still has chance to get reconnected and revive the id. messenger->StopAccept(0); + messenger->Join(); ASSERT_EQ(-1, messenger->listened_fd()); ASSERT_EQ(-1, fcntl(listening_fd, F_GETFD)); ASSERT_EQ(EBADF, errno); @@ -1408,6 +1410,7 @@ TEST_F(SocketTest, keepalive_input_message) { } messenger->StopAccept(0); + messenger->Join(); ASSERT_EQ(-1, messenger->listened_fd()); ASSERT_EQ(-1, fcntl(listening_fd, F_GETFD)); ASSERT_EQ(EBADF, errno); @@ -1422,11 +1425,24 @@ void CheckTCPUserTimeout(int fd, int expect_tcp_user_timeout) { } TEST_F(SocketTest, tcp_user_timeout) { + brpc::Acceptor* messenger = new brpc::Acceptor; + int listening_fd = -1; + butil::EndPoint point(butil::IP_ANY, 7878); + for (int i = 0; i < 100; ++i) { + point.port += i; + listening_fd = tcp_listen(point); + if (listening_fd >= 0) { + break; + } + } + ASSERT_GT(listening_fd, 0) << berror(); + ASSERT_EQ(0, butil::make_non_blocking(listening_fd)); + ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false)); + { - int sockfd = socket(AF_INET, SOCK_STREAM, 0); - ASSERT_GT(sockfd, 0); brpc::SocketOptions options; - options.fd = sockfd; + options.remote_side = point; + options.connect_on_create = true; brpc::SocketId id = brpc::INVALID_SOCKET_ID; ASSERT_EQ(0, brpc::Socket::Create(options, &id)); brpc::SocketUniquePtr ptr; @@ -1436,10 +1452,9 @@ TEST_F(SocketTest, tcp_user_timeout) { { int tcp_user_timeout_ms = 1000; - int sockfd = socket(AF_INET, SOCK_STREAM, 0); - ASSERT_GT(sockfd, 0); brpc::SocketOptions options; - options.fd = sockfd; + options.remote_side = point; + options.connect_on_create = true; options.tcp_user_timeout_ms = tcp_user_timeout_ms; brpc::SocketId id = brpc::INVALID_SOCKET_ID; ASSERT_EQ(0, brpc::Socket::Create(options, &id)); @@ -1450,10 +1465,9 @@ TEST_F(SocketTest, tcp_user_timeout) { brpc::FLAGS_socket_tcp_user_timeout_ms = 2000; { - int sockfd = socket(AF_INET, SOCK_STREAM, 0); - ASSERT_GT(sockfd, 0); brpc::SocketOptions options; - options.fd = sockfd; + options.remote_side = point; + options.connect_on_create = true; brpc::SocketId id = brpc::INVALID_SOCKET_ID; ASSERT_EQ(0, brpc::get_or_new_client_side_messenger()->Create(options, &id)); brpc::SocketUniquePtr ptr; @@ -1462,10 +1476,9 @@ TEST_F(SocketTest, tcp_user_timeout) { } { int tcp_user_timeout_ms = 3000; - int sockfd = socket(AF_INET, SOCK_STREAM, 0); - ASSERT_GT(sockfd, 0); brpc::SocketOptions options; - options.fd = sockfd; + options.remote_side = point; + options.connect_on_create = true; options.tcp_user_timeout_ms = tcp_user_timeout_ms; brpc::SocketId id = brpc::INVALID_SOCKET_ID; ASSERT_EQ(0, brpc::get_or_new_client_side_messenger()->Create(options, &id)); @@ -1473,6 +1486,12 @@ TEST_F(SocketTest, tcp_user_timeout) { ASSERT_EQ(0, brpc::Socket::Address(id, &ptr)) << "id=" << id; CheckTCPUserTimeout(ptr->fd(), tcp_user_timeout_ms); } + + messenger->StopAccept(0); + messenger->Join(); + ASSERT_EQ(-1, messenger->listened_fd()); + ASSERT_EQ(-1, fcntl(listening_fd, F_GETFD)); + ASSERT_EQ(EBADF, errno); } #endif diff --git a/test/bthread_cond_unittest.cpp b/test/bthread_cond_unittest.cpp index 7342cea1..d01ef69c 100644 --- a/test/bthread_cond_unittest.cpp +++ b/test/bthread_cond_unittest.cpp @@ -397,6 +397,7 @@ private: bthread_mutex_t _mutex; }; +#ifndef BUTIL_USE_ASAN volatile bool g_stop = false; bool started_wait = false; bool ended_wait = false; @@ -449,7 +450,6 @@ TEST(CondTest, too_many_bthreads_from_pthread) { launch_many_bthreads(); } -#ifndef BUTIL_USE_ASAN static void* run_launch_many_bthreads(void*) { launch_many_bthreads(); return NULL; --------------------------------------------------------------------- To unsubscribe, e-mail: dev-unsubscr...@brpc.apache.org For additional commands, e-mail: dev-h...@brpc.apache.org