Repository: thrift Updated Branches: refs/heads/master 262cfb418 -> cfaadcc4a
THRIFT-3231 CPP: Limit recursion depth to 64 Client: cpp Patch: Ben Craig <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/thrift/repo Commit: http://git-wip-us.apache.org/repos/asf/thrift/commit/cfaadcc4 Tree: http://git-wip-us.apache.org/repos/asf/thrift/tree/cfaadcc4 Diff: http://git-wip-us.apache.org/repos/asf/thrift/diff/cfaadcc4 Branch: refs/heads/master Commit: cfaadcc4adcfde2a8232c62ec89870b73ef40df1 Parents: 262cfb4 Author: Ben Craig <[email protected]> Authored: Wed Jul 8 20:50:33 2015 -0500 Committer: Ben Craig <[email protected]> Committed: Wed Jul 8 20:50:33 2015 -0500 ---------------------------------------------------------------------- compiler/cpp/src/generate/t_cpp_generator.cc | 18 +- lib/cpp/CMakeLists.txt | 3 +- lib/cpp/Makefile.am | 1 + lib/cpp/src/thrift/protocol/TProtocol.cpp | 33 ++++ lib/cpp/src/thrift/protocol/TProtocol.h | 223 ++++++++++++---------- 5 files changed, 166 insertions(+), 112 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/thrift/blob/cfaadcc4/compiler/cpp/src/generate/t_cpp_generator.cc ---------------------------------------------------------------------- diff --git a/compiler/cpp/src/generate/t_cpp_generator.cc b/compiler/cpp/src/generate/t_cpp_generator.cc index 426434f..aed3935 100644 --- a/compiler/cpp/src/generate/t_cpp_generator.cc +++ b/compiler/cpp/src/generate/t_cpp_generator.cc @@ -1367,10 +1367,16 @@ void t_cpp_generator::generate_struct_reader(ofstream& out, t_struct* tstruct, b vector<t_field*>::const_iterator f_iter; // Declare stack tmp variables - out << endl << indent() << "uint32_t xfer = 0;" << endl << indent() << "std::string fname;" - << endl << indent() << "::apache::thrift::protocol::TType ftype;" << endl << indent() - << "int16_t fid;" << endl << endl << indent() << "xfer += iprot->readStructBegin(fname);" - << endl << endl << indent() << "using ::apache::thrift::protocol::TProtocolException;" << endl + out << endl + << indent() << "apache::thrift::protocol::TRecursionTracker tracker(*iprot);" << endl + << indent() << "uint32_t xfer = 0;" << endl + << indent() << "std::string fname;" << endl + << indent() << "::apache::thrift::protocol::TType ftype;" << endl + << indent() << "int16_t fid;" << endl + << endl + << indent() << "xfer += iprot->readStructBegin(fname);" << endl + << endl + << indent() << "using ::apache::thrift::protocol::TProtocolException;" << endl << endl; // Required variables aren't in __isset, so we need tmp vars to check them. @@ -1486,7 +1492,7 @@ void t_cpp_generator::generate_struct_writer(ofstream& out, t_struct* tstruct, b out << indent() << "uint32_t xfer = 0;" << endl; - indent(out) << "oprot->incrementRecursionDepth();" << endl; + indent(out) << "apache::thrift::protocol::TRecursionTracker tracker(*oprot);" << endl; indent(out) << "xfer += oprot->writeStructBegin(\"" << name << "\");" << endl; for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { @@ -1522,7 +1528,7 @@ void t_cpp_generator::generate_struct_writer(ofstream& out, t_struct* tstruct, b // Write the struct map out << indent() << "xfer += oprot->writeFieldStop();" << endl << indent() << "xfer += oprot->writeStructEnd();" << endl << indent() - << "oprot->decrementRecursionDepth();" << endl << indent() << "return xfer;" << endl; + << "return xfer;" << endl; indent_down(); indent(out) << "}" << endl << endl; http://git-wip-us.apache.org/repos/asf/thrift/blob/cfaadcc4/lib/cpp/CMakeLists.txt ---------------------------------------------------------------------- diff --git a/lib/cpp/CMakeLists.txt b/lib/cpp/CMakeLists.txt index b97e356..bab2e84 100755 --- a/lib/cpp/CMakeLists.txt +++ b/lib/cpp/CMakeLists.txt @@ -39,11 +39,12 @@ set( thriftcpp_SOURCES src/thrift/concurrency/TimerManager.cpp src/thrift/concurrency/Util.cpp src/thrift/processor/PeekProcessor.cpp + src/thrift/protocol/TBase64Utils.cpp src/thrift/protocol/TDebugProtocol.cpp src/thrift/protocol/TDenseProtocol.cpp src/thrift/protocol/TJSONProtocol.cpp - src/thrift/protocol/TBase64Utils.cpp src/thrift/protocol/TMultiplexedProtocol.cpp + src/thrift/protocol/TProtocol.cpp src/thrift/transport/TTransportException.cpp src/thrift/transport/TFDTransport.cpp src/thrift/transport/TSimpleFileTransport.cpp http://git-wip-us.apache.org/repos/asf/thrift/blob/cfaadcc4/lib/cpp/Makefile.am ---------------------------------------------------------------------- diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am index 9156577..0ecbeee 100755 --- a/lib/cpp/Makefile.am +++ b/lib/cpp/Makefile.am @@ -75,6 +75,7 @@ libthrift_la_SOURCES = src/thrift/TApplicationException.cpp \ src/thrift/protocol/TJSONProtocol.cpp \ src/thrift/protocol/TBase64Utils.cpp \ src/thrift/protocol/TMultiplexedProtocol.cpp \ + src/thrift/protocol/TProtocol.cpp \ src/thrift/transport/TTransportException.cpp \ src/thrift/transport/TFDTransport.cpp \ src/thrift/transport/TFileTransport.cpp \ http://git-wip-us.apache.org/repos/asf/thrift/blob/cfaadcc4/lib/cpp/src/thrift/protocol/TProtocol.cpp ---------------------------------------------------------------------- diff --git a/lib/cpp/src/thrift/protocol/TProtocol.cpp b/lib/cpp/src/thrift/protocol/TProtocol.cpp new file mode 100644 index 0000000..c378aca --- /dev/null +++ b/lib/cpp/src/thrift/protocol/TProtocol.cpp @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <thrift/protocol/TProtocol.h> + +namespace apache { +namespace thrift { +namespace protocol { + +TProtocol::~TProtocol() {} +uint32_t TProtocol::skip_virt(TType type) { + return ::apache::thrift::protocol::skip(*this, type); +} + +TProtocolFactory::~TProtocolFactory() {} + +}}} // apache::thrift::protocol http://git-wip-us.apache.org/repos/asf/thrift/blob/cfaadcc4/lib/cpp/src/thrift/protocol/TProtocol.h ---------------------------------------------------------------------- diff --git a/lib/cpp/src/thrift/protocol/TProtocol.h b/lib/cpp/src/thrift/protocol/TProtocol.h index 9eec1ee..1aa2122 100644 --- a/lib/cpp/src/thrift/protocol/TProtocol.h +++ b/lib/cpp/src/thrift/protocol/TProtocol.h @@ -33,6 +33,7 @@ #include <string> #include <map> #include <vector> +#include <climits> // Use this to get around strict aliasing rules. // For example, uint64_t i = bitwise_cast<uint64_t>(returns_double()); @@ -199,105 +200,6 @@ enum TMessageType { T_ONEWAY = 4 }; - -/** - * Helper template for implementing TProtocol::skip(). - * - * Templatized to avoid having to make virtual function calls. - */ -template <class Protocol_> -uint32_t skip(Protocol_& prot, TType type) { - switch (type) { - case T_BOOL: { - bool boolv; - return prot.readBool(boolv); - } - case T_BYTE: { - int8_t bytev; - return prot.readByte(bytev); - } - case T_I16: { - int16_t i16; - return prot.readI16(i16); - } - case T_I32: { - int32_t i32; - return prot.readI32(i32); - } - case T_I64: { - int64_t i64; - return prot.readI64(i64); - } - case T_DOUBLE: { - double dub; - return prot.readDouble(dub); - } - case T_STRING: { - std::string str; - return prot.readBinary(str); - } - case T_STRUCT: { - uint32_t result = 0; - std::string name; - int16_t fid; - TType ftype; - result += prot.readStructBegin(name); - while (true) { - result += prot.readFieldBegin(name, ftype, fid); - if (ftype == T_STOP) { - break; - } - result += skip(prot, ftype); - result += prot.readFieldEnd(); - } - result += prot.readStructEnd(); - return result; - } - case T_MAP: { - uint32_t result = 0; - TType keyType; - TType valType; - uint32_t i, size; - result += prot.readMapBegin(keyType, valType, size); - for (i = 0; i < size; i++) { - result += skip(prot, keyType); - result += skip(prot, valType); - } - result += prot.readMapEnd(); - return result; - } - case T_SET: { - uint32_t result = 0; - TType elemType; - uint32_t i, size; - result += prot.readSetBegin(elemType, size); - for (i = 0; i < size; i++) { - result += skip(prot, elemType); - } - result += prot.readSetEnd(); - return result; - } - case T_LIST: { - uint32_t result = 0; - TType elemType; - uint32_t i, size; - result += prot.readListBegin(elemType, size); - for (i = 0; i < size; i++) { - result += skip(prot, elemType); - } - result += prot.readListEnd(); - return result; - } - case T_STOP: - case T_VOID: - case T_U64: - case T_UTF8: - case T_UTF16: - break; - } - return 0; -} - static const uint32_t DEFAULT_RECURSION_LIMIT = 64; /** @@ -316,7 +218,7 @@ static const uint32_t DEFAULT_RECURSION_LIMIT = 64; */ class TProtocol { public: - virtual ~TProtocol() {} + virtual ~TProtocol(); /** * Writing functions. @@ -641,7 +543,7 @@ public: T_VIRTUAL_CALL(); return skip_virt(type); } - virtual uint32_t skip_virt(TType type) { return ::apache::thrift::protocol::skip(*this, type); } + virtual uint32_t skip_virt(TType type); inline boost::shared_ptr<TTransport> getTransport() { return ptrans_; } @@ -657,10 +559,13 @@ public: } void decrementRecursionDepth() { --recursion_depth_; } + uint32_t getRecursionLimit() const {return recursion_limit_;} + void setRecurisionLimit(uint32_t depth) {recursion_limit_ = depth;} protected: TProtocol(boost::shared_ptr<TTransport> ptrans) - : ptrans_(ptrans), recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT) {} + : ptrans_(ptrans), recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT) + {} boost::shared_ptr<TTransport> ptrans_; @@ -677,7 +582,7 @@ class TProtocolFactory { public: TProtocolFactory() {} - virtual ~TProtocolFactory() {} + virtual ~TProtocolFactory(); virtual boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) = 0; }; @@ -712,8 +617,116 @@ struct TNetworkLittleEndian static uint64_t fromWire64(uint64_t x) {return letohll(x);} }; +struct TRecursionTracker { + TProtocol &prot_; + TRecursionTracker(TProtocol &prot) : prot_(prot) { + prot_.incrementRecursionDepth(); + } + ~TRecursionTracker() { + prot_.decrementRecursionDepth(); + } +}; + +/** + * Helper template for implementing TProtocol::skip(). + * + * Templatized to avoid having to make virtual function calls. + */ +template <class Protocol_> +uint32_t skip(Protocol_& prot, TType type) { + TRecursionTracker tracker(prot); + + switch (type) { + case T_BOOL: { + bool boolv; + return prot.readBool(boolv); + } + case T_BYTE: { + int8_t bytev; + return prot.readByte(bytev); + } + case T_I16: { + int16_t i16; + return prot.readI16(i16); + } + case T_I32: { + int32_t i32; + return prot.readI32(i32); + } + case T_I64: { + int64_t i64; + return prot.readI64(i64); + } + case T_DOUBLE: { + double dub; + return prot.readDouble(dub); + } + case T_STRING: { + std::string str; + return prot.readBinary(str); + } + case T_STRUCT: { + uint32_t result = 0; + std::string name; + int16_t fid; + TType ftype; + result += prot.readStructBegin(name); + while (true) { + result += prot.readFieldBegin(name, ftype, fid); + if (ftype == T_STOP) { + break; + } + result += skip(prot, ftype); + result += prot.readFieldEnd(); + } + result += prot.readStructEnd(); + return result; + } + case T_MAP: { + uint32_t result = 0; + TType keyType; + TType valType; + uint32_t i, size; + result += prot.readMapBegin(keyType, valType, size); + for (i = 0; i < size; i++) { + result += skip(prot, keyType); + result += skip(prot, valType); + } + result += prot.readMapEnd(); + return result; + } + case T_SET: { + uint32_t result = 0; + TType elemType; + uint32_t i, size; + result += prot.readSetBegin(elemType, size); + for (i = 0; i < size; i++) { + result += skip(prot, elemType); + } + result += prot.readSetEnd(); + return result; + } + case T_LIST: { + uint32_t result = 0; + TType elemType; + uint32_t i, size; + result += prot.readListBegin(elemType, size); + for (i = 0; i < size; i++) { + result += skip(prot, elemType); + } + result += prot.readListEnd(); + return result; + } + case T_STOP: + case T_VOID: + case T_U64: + case T_UTF8: + case T_UTF16: + break; + } + return 0; } -} -} // apache::thrift::protocol + +}}} // apache::thrift::protocol #endif // #define _THRIFT_PROTOCOL_TPROTOCOL_H_ 1
