THRIFT-4176: Implement a threaded server for Rust
Client: rs

* Create a TIoChannel construct
* Separate TTransport into TReadTransport and TWriteTransport
* Restructure types to avoid shared ownership
* Remove user-visible boxing and ref-counting
* Replace TSimpleServer with a thread-pool-based TServer
This closes #1255 Project: http://git-wip-us.apache.org/repos/asf/thrift/repo Commit: http://git-wip-us.apache.org/repos/asf/thrift/commit/0e22c362 Tree: http://git-wip-us.apache.org/repos/asf/thrift/tree/0e22c362 Diff: http://git-wip-us.apache.org/repos/asf/thrift/diff/0e22c362 Branch: refs/heads/master Commit: 0e22c362b967bd3765ee3da349faa789904a0707 Parents: 9db23b7 Author: Allen George <[email protected]> Authored: Mon Jan 30 07:15:00 2017 -0500 Committer: James E. King, III <[email protected]> Committed: Thu Apr 27 08:46:02 2017 -0400 ---------------------------------------------------------------------- .rustfmt.toml | 7 + .../cpp/src/thrift/generate/t_rs_generator.cc | 98 +- lib/rs/Cargo.toml | 3 +- lib/rs/src/autogen.rs | 2 +- lib/rs/src/errors.rs | 150 ++- lib/rs/src/lib.rs | 6 +- lib/rs/src/protocol/binary.rs | 440 ++++--- lib/rs/src/protocol/compact.rs | 1153 +++++++++++------- lib/rs/src/protocol/mod.rs | 532 ++++++-- lib/rs/src/protocol/multiplexed.rs | 96 +- lib/rs/src/protocol/stored.rs | 45 +- lib/rs/src/server/mod.rs | 18 +- lib/rs/src/server/multiplexed.rs | 70 +- lib/rs/src/server/simple.rs | 189 --- lib/rs/src/server/threaded.rs | 239 ++++ lib/rs/src/transport/buffered.rs | 302 +++-- lib/rs/src/transport/framed.rs | 267 ++-- lib/rs/src/transport/mem.rs | 213 ++-- lib/rs/src/transport/mod.rs | 273 ++++- lib/rs/src/transport/passthru.rs | 73 -- lib/rs/src/transport/socket.rs | 106 +- lib/rs/test/src/bin/kitchen_sink_client.rs | 72 +- lib/rs/test/src/bin/kitchen_sink_server.rs | 119 +- lib/rs/test/src/lib.rs | 5 +- test/rs/src/bin/test_client.rs | 216 ++-- test/rs/src/bin/test_server.rs | 225 ++-- tutorial/rs/README.md | 43 +- tutorial/rs/src/bin/tutorial_client.rs | 51 +- tutorial/rs/src/bin/tutorial_server.rs | 70 +- 29 files changed, 3209 insertions(+), 1874 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/.rustfmt.toml 
---------------------------------------------------------------------- diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..2962d47 --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,7 @@ +max_width = 100 +fn_args_layout = "Block" +array_layout = "Block" +where_style = "Rfc" +generics_indent = "Block" +fn_call_style = "Block" +reorder_imported_names = true http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/compiler/cpp/src/thrift/generate/t_rs_generator.cc ---------------------------------------------------------------------- diff --git a/compiler/cpp/src/thrift/generate/t_rs_generator.cc b/compiler/cpp/src/thrift/generate/t_rs_generator.cc index c34ed17..30f46f2 100644 --- a/compiler/cpp/src/thrift/generate/t_rs_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_rs_generator.cc @@ -31,10 +31,9 @@ using std::string; using std::vector; using std::set; -static const string endl = "\n"; // avoid ostream << std::endl flushes - -static const string SERVICE_RESULT_VARIABLE = "result_value"; -static const string RESULT_STRUCT_SUFFIX = "Result"; +static const string endl("\n"); // avoid ostream << std::endl flushes +static const string SERVICE_RESULT_VARIABLE("result_value"); +static const string RESULT_STRUCT_SUFFIX("Result"); static const string RUST_RESERVED_WORDS[] = { "abstract", "alignof", "as", "become", "box", "break", "const", "continue", @@ -55,6 +54,9 @@ const set<string> RUST_RESERVED_WORDS_SET( RUST_RESERVED_WORDS + sizeof(RUST_RESERVED_WORDS)/sizeof(RUST_RESERVED_WORDS[0]) ); +static const string SYNC_CLIENT_GENERIC_BOUND_VARS("<IP, OP>"); +static const string SYNC_CLIENT_GENERIC_BOUNDS("where IP: TInputProtocol, OP: TOutputProtocol"); + // FIXME: extract common TMessageIdentifier function // FIXME: have to_rust_type deal with Option @@ -364,9 +366,10 @@ private: ); // Return a string containing all the unpacked service call args given a service call function - // `t_function`. 
Prepends the args with `&mut self` and includes the arg types in the returned string, - // for example: `fn foo(&mut self, field_0: String)`. - string rust_sync_service_call_declaration(t_function* tfunc); + // `t_function`. Prepends the args with either `&mut self` or `&self` and includes the arg types + // in the returned string, for example: + // `fn foo(&mut self, field_0: String)`. + string rust_sync_service_call_declaration(t_function* tfunc, bool self_is_mutable); // Return a string containing all the unpacked service call args given a service call function // `t_function`. Only includes the arg names, each of which is prefixed with the optional prefix @@ -512,6 +515,9 @@ void t_rs_generator::render_attributes_and_includes() { // constructors take *all* struct parameters, which can trigger the "too many arguments" warning // some auto-gen'd types can be deeply nested. clippy recommends factoring them out which is hard to autogen f_gen_ << "#![cfg_attr(feature = \"cargo-clippy\", allow(too_many_arguments, type_complexity))]" << endl; + // prevent rustfmt from running against this file + // lines are too long, code is (thankfully!) not visual-indented, etc. 
+ f_gen_ << "#![cfg_attr(rustfmt, rustfmt_skip)]" << endl; f_gen_ << endl; // add standard includes @@ -2050,7 +2056,7 @@ void t_rs_generator::render_sync_client_trait(t_service *tservice) { for(func_iter = functions.begin(); func_iter != functions.end(); ++func_iter) { t_function* tfunc = (*func_iter); string func_name = service_call_client_function_name(tfunc); - string func_args = rust_sync_service_call_declaration(tfunc); + string func_args = rust_sync_service_call_declaration(tfunc, true); string func_return = to_rust_type(tfunc->get_returntype()); render_rustdoc((t_doc*) tfunc); f_gen_ << indent() << "fn " << func_name << func_args << " -> thrift::Result<" << func_return << ">;" << endl; @@ -2069,8 +2075,14 @@ void t_rs_generator::render_sync_client_marker_trait(t_service *tservice) { void t_rs_generator::render_sync_client_marker_trait_impls(t_service *tservice, const string &impl_struct_name) { f_gen_ << indent() - << "impl " << rust_namespace(tservice) << rust_sync_client_marker_trait_name(tservice) - << " for " << impl_struct_name + << "impl " + << SYNC_CLIENT_GENERIC_BOUND_VARS + << " " + << rust_namespace(tservice) << rust_sync_client_marker_trait_name(tservice) + << " for " + << impl_struct_name << SYNC_CLIENT_GENERIC_BOUND_VARS + << " " + << SYNC_CLIENT_GENERIC_BOUNDS << " {}" << endl; @@ -2081,11 +2093,19 @@ void t_rs_generator::render_sync_client_marker_trait_impls(t_service *tservice, } void t_rs_generator::render_sync_client_definition_and_impl(const string& client_impl_name) { + // render the definition for the client struct - f_gen_ << "pub struct " << client_impl_name << " {" << endl; + f_gen_ + << "pub struct " + << client_impl_name + << SYNC_CLIENT_GENERIC_BOUND_VARS + << " " + << SYNC_CLIENT_GENERIC_BOUNDS + << " {" + << endl; indent_up(); - f_gen_ << indent() << "_i_prot: Box<TInputProtocol>," << endl; - f_gen_ << indent() << "_o_prot: Box<TOutputProtocol>," << endl; + f_gen_ << indent() << "_i_prot: IP," << endl; + f_gen_ << indent() << 
"_o_prot: OP," << endl; f_gen_ << indent() << "_sequence_number: i32," << endl; indent_down(); f_gen_ << "}" << endl; @@ -2093,7 +2113,16 @@ void t_rs_generator::render_sync_client_definition_and_impl(const string& client // render the struct implementation // this includes the new() function as well as the helper send/recv methods for each service call - f_gen_ << "impl " << client_impl_name << " {" << endl; + f_gen_ + << "impl " + << SYNC_CLIENT_GENERIC_BOUND_VARS + << " " + << client_impl_name + << SYNC_CLIENT_GENERIC_BOUND_VARS + << " " + << SYNC_CLIENT_GENERIC_BOUNDS + << " {" + << endl; indent_up(); render_sync_client_lifecycle_functions(client_impl_name); indent_down(); @@ -2104,8 +2133,9 @@ void t_rs_generator::render_sync_client_definition_and_impl(const string& client void t_rs_generator::render_sync_client_lifecycle_functions(const string& client_struct) { f_gen_ << indent() - << "pub fn new(input_protocol: Box<TInputProtocol>, output_protocol: Box<TOutputProtocol>) -> " + << "pub fn new(input_protocol: IP, output_protocol: OP) -> " << client_struct + << SYNC_CLIENT_GENERIC_BOUND_VARS << " {" << endl; indent_up(); @@ -2121,11 +2151,20 @@ void t_rs_generator::render_sync_client_lifecycle_functions(const string& client } void t_rs_generator::render_sync_client_tthriftclient_impl(const string &client_impl_name) { - f_gen_ << indent() << "impl TThriftClient for " << client_impl_name << " {" << endl; + f_gen_ + << indent() + << "impl " + << SYNC_CLIENT_GENERIC_BOUND_VARS + << " TThriftClient for " + << client_impl_name + << SYNC_CLIENT_GENERIC_BOUND_VARS + << " " + << SYNC_CLIENT_GENERIC_BOUNDS + << " {" << endl; indent_up(); - f_gen_ << indent() << "fn i_prot_mut(&mut self) -> &mut TInputProtocol { &mut *self._i_prot }" << endl; - f_gen_ << indent() << "fn o_prot_mut(&mut self) -> &mut TOutputProtocol { &mut *self._o_prot }" << endl; + f_gen_ << indent() << "fn i_prot_mut(&mut self) -> &mut TInputProtocol { &mut self._i_prot }" << endl; + f_gen_ << indent() 
<< "fn o_prot_mut(&mut self) -> &mut TOutputProtocol { &mut self._o_prot }" << endl; f_gen_ << indent() << "fn sequence_number(&self) -> i32 { self._sequence_number }" << endl; f_gen_ << indent() @@ -2172,7 +2211,7 @@ string t_rs_generator::sync_client_marker_traits_for_extension(t_service *tservi void t_rs_generator::render_sync_send_recv_wrapper(t_function* tfunc) { string func_name = service_call_client_function_name(tfunc); - string func_decl_args = rust_sync_service_call_declaration(tfunc); + string func_decl_args = rust_sync_service_call_declaration(tfunc, true); string func_call_args = rust_sync_service_call_invocation(tfunc); string func_return = to_rust_type(tfunc->get_returntype()); @@ -2268,12 +2307,17 @@ void t_rs_generator::render_sync_recv(t_function* tfunc) { f_gen_ << indent() << "}" << endl; } -string t_rs_generator::rust_sync_service_call_declaration(t_function* tfunc) { +string t_rs_generator::rust_sync_service_call_declaration(t_function* tfunc, bool self_is_mutable) { ostringstream func_args; - func_args << "(&mut self"; + + if (self_is_mutable) { + func_args << "(&mut self"; + } else { + func_args << "(&self"; + } if (has_args(tfunc)) { - func_args << ", "; // put comma after "&mut self" + func_args << ", "; // put comma after "self" func_args << struct_to_declaration(tfunc->get_arglist(), T_ARGS); } @@ -2388,7 +2432,7 @@ void t_rs_generator::render_sync_handler_trait(t_service *tservice) { for(func_iter = functions.begin(); func_iter != functions.end(); ++func_iter) { t_function* tfunc = (*func_iter); string func_name = service_call_handler_function_name(tfunc); - string func_args = rust_sync_service_call_declaration(tfunc); + string func_args = rust_sync_service_call_declaration(tfunc, false); string func_return = to_rust_type(tfunc->get_returntype()); render_rustdoc((t_doc*) tfunc); f_gen_ @@ -2472,7 +2516,7 @@ void t_rs_generator::render_sync_processor_definition_and_impl(t_service *tservi f_gen_ << indent() - << "fn process(&mut self, 
i_prot: &mut TInputProtocol, o_prot: &mut TOutputProtocol) -> thrift::Result<()> {" + << "fn process(&self, i_prot: &mut TInputProtocol, o_prot: &mut TOutputProtocol) -> thrift::Result<()> {" << endl; indent_up(); f_gen_ << indent() << "let message_ident = i_prot.read_message_begin()?;" << endl; @@ -2511,7 +2555,7 @@ void t_rs_generator::render_sync_process_delegation_functions(t_service *tservic f_gen_ << indent() << "fn " << function_name - << "(&mut self, " + << "(&self, " << "incoming_sequence_number: i32, " << "i_prot: &mut TInputProtocol, " << "o_prot: &mut TOutputProtocol) " @@ -2524,7 +2568,7 @@ void t_rs_generator::render_sync_process_delegation_functions(t_service *tservic << actual_processor << "::" << function_name << "(" - << "&mut self.handler, " + << "&self.handler, " << "incoming_sequence_number, " << "i_prot, " << "o_prot" @@ -2576,7 +2620,7 @@ void t_rs_generator::render_sync_process_function(t_function *tfunc, const strin << indent() << "pub fn process_" << rust_snake_case(tfunc->get_name()) << "<H: " << handler_type << ">" - << "(handler: &mut H, " + << "(handler: &H, " << sequence_number_param << ": i32, " << "i_prot: &mut TInputProtocol, " << output_protocol_param << ": &mut TOutputProtocol) " http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/Cargo.toml ---------------------------------------------------------------------- diff --git a/lib/rs/Cargo.toml b/lib/rs/Cargo.toml index 07c5e67..be34785 100644 --- a/lib/rs/Cargo.toml +++ b/lib/rs/Cargo.toml @@ -11,8 +11,9 @@ exclude = ["Makefile*", "test/**"] keywords = ["thrift"] [dependencies] +byteorder = "0.5.3" integer-encoding = "1.0.3" log = "~0.3.6" -byteorder = "0.5.3" +threadpool = "1.0" try_from = "0.2.0" http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/autogen.rs ---------------------------------------------------------------------- diff --git a/lib/rs/src/autogen.rs b/lib/rs/src/autogen.rs index 289c7be..54d4080 100644 --- a/lib/rs/src/autogen.rs +++ 
b/lib/rs/src/autogen.rs @@ -22,7 +22,7 @@ //! to implement required functionality. Users should never have to use code //! in this module directly. -use ::protocol::{TInputProtocol, TOutputProtocol}; +use protocol::{TInputProtocol, TOutputProtocol}; /// Specifies the minimum functionality an auto-generated client should provide /// to communicate with a Thrift server. http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/errors.rs ---------------------------------------------------------------------- diff --git a/lib/rs/src/errors.rs b/lib/rs/src/errors.rs index a6049d5..e36cb3b 100644 --- a/lib/rs/src/errors.rs +++ b/lib/rs/src/errors.rs @@ -21,7 +21,7 @@ use std::fmt::{Debug, Display, Formatter}; use std::{error, fmt, io, string}; use try_from::TryFrom; -use ::protocol::{TFieldIdentifier, TInputProtocol, TOutputProtocol, TStructIdentifier, TType}; +use protocol::{TFieldIdentifier, TInputProtocol, TOutputProtocol, TStructIdentifier, TType}; // FIXME: should all my error structs impl error::Error as well? // FIXME: should all fields in TransportError, ProtocolError and ApplicationError be optional? @@ -198,8 +198,8 @@ impl Error { /// Create an `ApplicationError` from its wire representation. /// /// Application code **should never** call this method directly. 
- pub fn read_application_error_from_in_protocol(i: &mut TInputProtocol) - -> ::Result<ApplicationError> { + pub fn read_application_error_from_in_protocol(i: &mut TInputProtocol,) + -> ::Result<ApplicationError> { let mut message = "general remote error".to_owned(); let mut kind = ApplicationErrorKind::Unknown; @@ -212,7 +212,9 @@ impl Error { break; } - let id = field_ident.id.expect("sender should always specify id for non-STOP field"); + let id = field_ident + .id + .expect("sender should always specify id for non-STOP field"); match id { 1 => { @@ -222,8 +224,9 @@ impl Error { } 2 => { let remote_type_as_int = i.read_i32()?; - let remote_kind: ApplicationErrorKind = TryFrom::try_from(remote_type_as_int) - .unwrap_or(ApplicationErrorKind::Unknown); + let remote_kind: ApplicationErrorKind = + TryFrom::try_from(remote_type_as_int) + .unwrap_or(ApplicationErrorKind::Unknown); i.read_field_end()?; kind = remote_kind; } @@ -235,20 +238,23 @@ impl Error { i.read_struct_end()?; - Ok(ApplicationError { - kind: kind, - message: message, - }) + Ok( + ApplicationError { + kind: kind, + message: message, + }, + ) } /// Convert an `ApplicationError` into its wire representation and write /// it to the remote. /// /// Application code **should never** call this method directly. 
- pub fn write_application_error_to_out_protocol(e: &ApplicationError, - o: &mut TOutputProtocol) - -> ::Result<()> { - o.write_struct_begin(&TStructIdentifier { name: "TApplicationException".to_owned() })?; + pub fn write_application_error_to_out_protocol( + e: &ApplicationError, + o: &mut TOutputProtocol, + ) -> ::Result<()> { + o.write_struct_begin(&TStructIdentifier { name: "TApplicationException".to_owned() },)?; let message_field = TFieldIdentifier::new("message", TType::String, 1); let type_field = TFieldIdentifier::new("type", TType::I32, 2); @@ -303,19 +309,23 @@ impl Display for Error { impl From<String> for Error { fn from(s: String) -> Self { - Error::Application(ApplicationError { - kind: ApplicationErrorKind::Unknown, - message: s, - }) + Error::Application( + ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: s, + }, + ) } } impl<'a> From<&'a str> for Error { fn from(s: &'a str) -> Self { - Error::Application(ApplicationError { - kind: ApplicationErrorKind::Unknown, - message: String::from(s), - }) + Error::Application( + ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: String::from(s), + }, + ) } } @@ -418,10 +428,14 @@ impl TryFrom<i32> for TransportErrorKind { 5 => Ok(TransportErrorKind::NegativeSize), 6 => Ok(TransportErrorKind::SizeLimit), _ => { - Err(Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::Unknown, - message: format!("cannot convert {} to TransportErrorKind", from), - })) + Err( + Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::Unknown, + message: format!("cannot convert {} to TransportErrorKind", from), + }, + ), + ) } } } @@ -433,34 +447,44 @@ impl From<io::Error> for Error { io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionRefused | io::ErrorKind::NotConnected => { - Error::Transport(TransportError { - kind: TransportErrorKind::NotOpen, - message: err.description().to_owned(), - }) + Error::Transport( + TransportError { + kind: TransportErrorKind::NotOpen, + 
message: err.description().to_owned(), + }, + ) } io::ErrorKind::AlreadyExists => { - Error::Transport(TransportError { - kind: TransportErrorKind::AlreadyOpen, - message: err.description().to_owned(), - }) + Error::Transport( + TransportError { + kind: TransportErrorKind::AlreadyOpen, + message: err.description().to_owned(), + }, + ) } io::ErrorKind::TimedOut => { - Error::Transport(TransportError { - kind: TransportErrorKind::TimedOut, - message: err.description().to_owned(), - }) + Error::Transport( + TransportError { + kind: TransportErrorKind::TimedOut, + message: err.description().to_owned(), + }, + ) } io::ErrorKind::UnexpectedEof => { - Error::Transport(TransportError { - kind: TransportErrorKind::EndOfFile, - message: err.description().to_owned(), - }) + Error::Transport( + TransportError { + kind: TransportErrorKind::EndOfFile, + message: err.description().to_owned(), + }, + ) } _ => { - Error::Transport(TransportError { - kind: TransportErrorKind::Unknown, - message: err.description().to_owned(), // FIXME: use io error's debug string - }) + Error::Transport( + TransportError { + kind: TransportErrorKind::Unknown, + message: err.description().to_owned(), // FIXME: use io error's debug string + }, + ) } } } @@ -468,10 +492,12 @@ impl From<io::Error> for Error { impl From<string::FromUtf8Error> for Error { fn from(err: string::FromUtf8Error) -> Self { - Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::InvalidData, - message: err.description().to_owned(), // FIXME: use fmt::Error's debug string - }) + Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::InvalidData, + message: err.description().to_owned(), // FIXME: use fmt::Error's debug string + }, + ) } } @@ -558,10 +584,14 @@ impl TryFrom<i32> for ProtocolErrorKind { 5 => Ok(ProtocolErrorKind::NotImplemented), 6 => Ok(ProtocolErrorKind::DepthLimit), _ => { - Err(Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::Unknown, - message: format!("cannot convert {} to 
ProtocolErrorKind", from), - })) + Err( + Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::Unknown, + message: format!("cannot convert {} to ProtocolErrorKind", from), + }, + ), + ) } } } @@ -668,10 +698,14 @@ impl TryFrom<i32> for ApplicationErrorKind { 9 => Ok(ApplicationErrorKind::InvalidProtocol), 10 => Ok(ApplicationErrorKind::UnsupportedClientType), _ => { - Err(Error::Application(ApplicationError { - kind: ApplicationErrorKind::Unknown, - message: format!("cannot convert {} to ApplicationErrorKind", from), - })) + Err( + Error::Application( + ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: format!("cannot convert {} to ApplicationErrorKind", from), + }, + ), + ) } } } http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/lib.rs ---------------------------------------------------------------------- diff --git a/lib/rs/src/lib.rs b/lib/rs/src/lib.rs index ad18721..7ebb10c 100644 --- a/lib/rs/src/lib.rs +++ b/lib/rs/src/lib.rs @@ -26,11 +26,12 @@ //! 4. server //! 5. autogen //! -//! The modules are layered as shown in the diagram below. The `generated` +//! The modules are layered as shown in the diagram below. The `autogen'd` //! layer is generated by the Thrift compiler's Rust plugin. It uses the //! types and functions defined in this crate to serialize and deserialize //! messages and implement RPC. Users interact with these types and services -//! by writing their own code on top. +//! by writing their own code that uses the auto-generated clients and +//! servers. //! //! ```text //! 
+-----------+ @@ -49,6 +50,7 @@ extern crate byteorder; extern crate integer_encoding; +extern crate threadpool; extern crate try_from; #[macro_use] http://git-wip-us.apache.org/repos/asf/thrift/blob/0e22c362/lib/rs/src/protocol/binary.rs ---------------------------------------------------------------------- diff --git a/lib/rs/src/protocol/binary.rs b/lib/rs/src/protocol/binary.rs index 54613a5..e03ec94 100644 --- a/lib/rs/src/protocol/binary.rs +++ b/lib/rs/src/protocol/binary.rs @@ -16,14 +16,11 @@ // under the License. use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt}; -use std::cell::RefCell; use std::convert::From; -use std::io::{Read, Write}; -use std::rc::Rc; use try_from::TryFrom; -use ::{ProtocolError, ProtocolErrorKind}; -use ::transport::TTransport; +use {ProtocolError, ProtocolErrorKind}; +use transport::{TReadTransport, TWriteTransport}; use super::{TFieldIdentifier, TInputProtocol, TInputProtocolFactory, TListIdentifier, TMapIdentifier, TMessageIdentifier, TMessageType}; use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType}; @@ -41,32 +38,35 @@ const BINARY_PROTOCOL_VERSION_1: u32 = 0x80010000; /// Create and use a `TBinaryInputProtocol`. 
/// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TBinaryInputProtocol, TInputProtocol}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("localhost:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); /// -/// let mut i_prot = TBinaryInputProtocol::new(transport, true); +/// let mut protocol = TBinaryInputProtocol::new(channel, true); /// -/// let recvd_bool = i_prot.read_bool().unwrap(); -/// let recvd_string = i_prot.read_string().unwrap(); +/// let recvd_bool = protocol.read_bool().unwrap(); +/// let recvd_string = protocol.read_string().unwrap(); /// ``` -pub struct TBinaryInputProtocol<'a> { +#[derive(Debug)] +pub struct TBinaryInputProtocol<T> +where + T: TReadTransport, +{ strict: bool, - transport: Rc<RefCell<Box<TTransport + 'a>>>, + transport: T, } -impl<'a> TBinaryInputProtocol<'a> { +impl<'a, T> TBinaryInputProtocol<T> +where + T: TReadTransport, +{ /// Create a `TBinaryInputProtocol` that reads bytes from `transport`. /// /// Set `strict` to `true` if all incoming messages contain the protocol /// version number in the protocol header. 
- pub fn new(transport: Rc<RefCell<Box<TTransport + 'a>>>, - strict: bool) -> TBinaryInputProtocol<'a> { + pub fn new(transport: T, strict: bool) -> TBinaryInputProtocol<T> { TBinaryInputProtocol { strict: strict, transport: transport, @@ -74,11 +74,14 @@ impl<'a> TBinaryInputProtocol<'a> { } } -impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { +impl<T> TInputProtocol for TBinaryInputProtocol<T> +where + T: TReadTransport, +{ #[cfg_attr(feature = "cargo-clippy", allow(collapsible_if))] fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> { let mut first_bytes = vec![0; 4]; - self.transport.borrow_mut().read_exact(&mut first_bytes[..])?; + self.transport.read_exact(&mut first_bytes[..])?; // the thrift version header is intentionally negative // so the first check we'll do is see if the sign bit is set @@ -87,10 +90,14 @@ impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { // apparently we got a protocol-version header - check // it, and if it matches, read the rest of the fields if first_bytes[0..2] != [0x80, 0x01] { - Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::BadVersion, - message: format!("received bad version: {:?}", &first_bytes[0..2]), - })) + Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::BadVersion, + message: format!("received bad version: {:?}", &first_bytes[0..2]), + }, + ), + ) } else { let message_type: TMessageType = TryFrom::try_from(first_bytes[3])?; let name = self.read_string()?; @@ -103,17 +110,21 @@ impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { if self.strict { // we're in strict mode however, and that always // requires the protocol-version header to be written first - Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::BadVersion, - message: format!("received bad version: {:?}", &first_bytes[0..2]), - })) + Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::BadVersion, + message: format!("received bad version: {:?}", &first_bytes[0..2]), 
+ }, + ), + ) } else { // in the non-strict version the first message field // is the message name. strings (byte arrays) are length-prefixed, // so we've just read the length in the first 4 bytes let name_size = BigEndian::read_i32(&first_bytes) as usize; let mut name_buf: Vec<u8> = Vec::with_capacity(name_size); - self.transport.borrow_mut().read_exact(&mut name_buf)?; + self.transport.read_exact(&mut name_buf)?; let name = String::from_utf8(name_buf)?; // read the rest of the fields @@ -143,7 +154,7 @@ impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { TType::Stop => Ok(0), _ => self.read_i16(), }?; - Ok(TFieldIdentifier::new::<Option<String>, String, i16>(None, field_type, id)) + Ok(TFieldIdentifier::new::<Option<String>, String, i16>(None, field_type, id),) } fn read_field_end(&mut self) -> ::Result<()> { @@ -151,9 +162,12 @@ impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { } fn read_bytes(&mut self) -> ::Result<Vec<u8>> { - let num_bytes = self.transport.borrow_mut().read_i32::<BigEndian>()? as usize; + let num_bytes = self.transport.read_i32::<BigEndian>()? 
as usize; let mut buf = vec![0u8; num_bytes]; - self.transport.borrow_mut().read_exact(&mut buf).map(|_| buf).map_err(From::from) + self.transport + .read_exact(&mut buf) + .map(|_| buf) + .map_err(From::from) } fn read_bool(&mut self) -> ::Result<bool> { @@ -165,23 +179,31 @@ impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { } fn read_i8(&mut self) -> ::Result<i8> { - self.transport.borrow_mut().read_i8().map_err(From::from) + self.transport.read_i8().map_err(From::from) } fn read_i16(&mut self) -> ::Result<i16> { - self.transport.borrow_mut().read_i16::<BigEndian>().map_err(From::from) + self.transport + .read_i16::<BigEndian>() + .map_err(From::from) } fn read_i32(&mut self) -> ::Result<i32> { - self.transport.borrow_mut().read_i32::<BigEndian>().map_err(From::from) + self.transport + .read_i32::<BigEndian>() + .map_err(From::from) } fn read_i64(&mut self) -> ::Result<i64> { - self.transport.borrow_mut().read_i64::<BigEndian>().map_err(From::from) + self.transport + .read_i64::<BigEndian>() + .map_err(From::from) } fn read_double(&mut self) -> ::Result<f64> { - self.transport.borrow_mut().read_f64::<BigEndian>().map_err(From::from) + self.transport + .read_f64::<BigEndian>() + .map_err(From::from) } fn read_string(&mut self) -> ::Result<String> { @@ -224,7 +246,7 @@ impl<'a> TInputProtocol for TBinaryInputProtocol<'a> { // fn read_byte(&mut self) -> ::Result<u8> { - self.transport.borrow_mut().read_u8().map_err(From::from) + self.transport.read_u8().map_err(From::from) } } @@ -240,8 +262,8 @@ impl TBinaryInputProtocolFactory { } impl TInputProtocolFactory for TBinaryInputProtocolFactory { - fn create<'a>(&mut self, transport: Rc<RefCell<Box<TTransport + 'a>>>) -> Box<TInputProtocol + 'a> { - Box::new(TBinaryInputProtocol::new(transport, true)) as Box<TInputProtocol + 'a> + fn create(&self, transport: Box<TReadTransport + Send>) -> Box<TInputProtocol + Send> { + Box::new(TBinaryInputProtocol::new(transport, true)) } } @@ -256,32 +278,35 @@ impl 
TInputProtocolFactory for TBinaryInputProtocolFactory { /// Create and use a `TBinaryOutputProtocol`. /// /// ```no_run -/// use std::cell::RefCell; -/// use std::rc::Rc; /// use thrift::protocol::{TBinaryOutputProtocol, TOutputProtocol}; -/// use thrift::transport::{TTcpTransport, TTransport}; +/// use thrift::transport::TTcpChannel; /// -/// let mut transport = TTcpTransport::new(); -/// transport.open("localhost:9090").unwrap(); -/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>)); +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); /// -/// let mut o_prot = TBinaryOutputProtocol::new(transport, true); +/// let mut protocol = TBinaryOutputProtocol::new(channel, true); /// -/// o_prot.write_bool(true).unwrap(); -/// o_prot.write_string("test_string").unwrap(); +/// protocol.write_bool(true).unwrap(); +/// protocol.write_string("test_string").unwrap(); /// ``` -pub struct TBinaryOutputProtocol<'a> { +#[derive(Debug)] +pub struct TBinaryOutputProtocol<T> +where + T: TWriteTransport, +{ strict: bool, - transport: Rc<RefCell<Box<TTransport + 'a>>>, + pub transport: T, // FIXME: do not make public; only public for testing! } -impl<'a> TBinaryOutputProtocol<'a> { +impl<T> TBinaryOutputProtocol<T> +where + T: TWriteTransport, +{ /// Create a `TBinaryOutputProtocol` that writes bytes to `transport`. /// /// Set `strict` to `true` if all outgoing messages should contain the /// protocol version number in the protocol header. 
- pub fn new(transport: Rc<RefCell<Box<TTransport + 'a>>>, - strict: bool) -> TBinaryOutputProtocol<'a> { + pub fn new(transport: T, strict: bool) -> TBinaryOutputProtocol<T> { TBinaryOutputProtocol { strict: strict, transport: transport, @@ -289,16 +314,22 @@ impl<'a> TBinaryOutputProtocol<'a> { } fn write_transport(&mut self, buf: &[u8]) -> ::Result<()> { - self.transport.borrow_mut().write(buf).map(|_| ()).map_err(From::from) + self.transport + .write(buf) + .map(|_| ()) + .map_err(From::from) } } -impl<'a> TOutputProtocol for TBinaryOutputProtocol<'a> { +impl<T> TOutputProtocol for TBinaryOutputProtocol<T> +where + T: TWriteTransport, +{ fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> { if self.strict { let message_type: u8 = identifier.message_type.into(); let header = BINARY_PROTOCOL_VERSION_1 | (message_type as u32); - self.transport.borrow_mut().write_u32::<BigEndian>(header)?; + self.transport.write_u32::<BigEndian>(header)?; self.write_string(&identifier.name)?; self.write_i32(identifier.sequence_number) } else { @@ -322,11 +353,17 @@ impl<'a> TOutputProtocol for TBinaryOutputProtocol<'a> { fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> { if identifier.id.is_none() && identifier.field_type != TType::Stop { - return Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::Unknown, - message: format!("cannot write identifier {:?} without sequence number", - &identifier), - })); + return Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::Unknown, + message: format!( + "cannot write identifier {:?} without sequence number", + &identifier + ), + }, + ), + ); } self.write_byte(field_type_to_u8(identifier.field_type))?; @@ -359,23 +396,31 @@ impl<'a> TOutputProtocol for TBinaryOutputProtocol<'a> { } fn write_i8(&mut self, i: i8) -> ::Result<()> { - self.transport.borrow_mut().write_i8(i).map_err(From::from) + self.transport.write_i8(i).map_err(From::from) } fn 
write_i16(&mut self, i: i16) -> ::Result<()> { - self.transport.borrow_mut().write_i16::<BigEndian>(i).map_err(From::from) + self.transport + .write_i16::<BigEndian>(i) + .map_err(From::from) } fn write_i32(&mut self, i: i32) -> ::Result<()> { - self.transport.borrow_mut().write_i32::<BigEndian>(i).map_err(From::from) + self.transport + .write_i32::<BigEndian>(i) + .map_err(From::from) } fn write_i64(&mut self, i: i64) -> ::Result<()> { - self.transport.borrow_mut().write_i64::<BigEndian>(i).map_err(From::from) + self.transport + .write_i64::<BigEndian>(i) + .map_err(From::from) } fn write_double(&mut self, d: f64) -> ::Result<()> { - self.transport.borrow_mut().write_f64::<BigEndian>(d).map_err(From::from) + self.transport + .write_f64::<BigEndian>(d) + .map_err(From::from) } fn write_string(&mut self, s: &str) -> ::Result<()> { @@ -401,10 +446,12 @@ impl<'a> TOutputProtocol for TBinaryOutputProtocol<'a> { } fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> { - let key_type = identifier.key_type + let key_type = identifier + .key_type .expect("map identifier to write should contain key type"); self.write_byte(field_type_to_u8(key_type))?; - let val_type = identifier.value_type + let val_type = identifier + .value_type .expect("map identifier to write should contain value type"); self.write_byte(field_type_to_u8(val_type))?; self.write_i32(identifier.size) @@ -415,14 +462,14 @@ impl<'a> TOutputProtocol for TBinaryOutputProtocol<'a> { } fn flush(&mut self) -> ::Result<()> { - self.transport.borrow_mut().flush().map_err(From::from) + self.transport.flush().map_err(From::from) } // utility // fn write_byte(&mut self, b: u8) -> ::Result<()> { - self.transport.borrow_mut().write_u8(b).map_err(From::from) + self.transport.write_u8(b).map_err(From::from) } } @@ -438,8 +485,8 @@ impl TBinaryOutputProtocolFactory { } impl TOutputProtocolFactory for TBinaryOutputProtocolFactory { - fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> 
Box<TOutputProtocol> { - Box::new(TBinaryOutputProtocol::new(transport, true)) as Box<TOutputProtocol> + fn create(&self, transport: Box<TWriteTransport + Send>) -> Box<TOutputProtocol + Send> { + Box::new(TBinaryOutputProtocol::new(transport, true)) } } @@ -481,10 +528,14 @@ fn field_type_from_u8(b: u8) -> ::Result<TType> { 0x10 => Ok(TType::Utf8), 0x11 => Ok(TType::Utf16), unkn => { - Err(::Error::Protocol(ProtocolError { - kind: ProtocolErrorKind::InvalidData, - message: format!("cannot convert {} to TType", unkn), - })) + Err( + ::Error::Protocol( + ProtocolError { + kind: ProtocolErrorKind::InvalidData, + message: format!("cannot convert {} to TType", unkn), + }, + ), + ) } } } @@ -492,56 +543,79 @@ fn field_type_from_u8(b: u8) -> ::Result<TType> { #[cfg(test)] mod tests { - use std::rc::Rc; - use std::cell::RefCell; - - use ::protocol::{TFieldIdentifier, TMessageIdentifier, TMessageType, TInputProtocol, - TListIdentifier, TMapIdentifier, TOutputProtocol, TSetIdentifier, - TStructIdentifier, TType}; - use ::transport::{TPassThruTransport, TTransport}; - use ::transport::mem::TBufferTransport; + use protocol::{TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, + TMessageIdentifier, TMessageType, TOutputProtocol, TSetIdentifier, + TStructIdentifier, TType}; + use transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf}; use super::*; #[test] fn must_write_message_call_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); let ident = TMessageIdentifier::new("test", TMessageType::Call, 1); assert!(o_prot.write_message_begin(&ident).is_ok()); - let buf = trans.borrow().write_buffer_to_vec(); - - let expected: [u8; 16] = [0x80, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x74, 0x65, - 0x73, 0x74, 0x00, 0x00, 0x00, 0x01]; - - assert_eq!(&expected, buf.as_slice()); + let expected: [u8; 16] = [ + 0x80, + 0x01, + 0x00, + 0x01, + 0x00, + 0x00, + 0x00, + 0x04, + 0x74, + 0x65, + 0x73, + 0x74, + 0x00, + 0x00, + 
0x00, + 0x01, + ]; + + assert_eq_written_bytes!(o_prot, expected); } - #[test] fn must_write_message_reply_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); let ident = TMessageIdentifier::new("test", TMessageType::Reply, 10); assert!(o_prot.write_message_begin(&ident).is_ok()); - let buf = trans.borrow().write_buffer_to_vec(); - - let expected: [u8; 16] = [0x80, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x74, 0x65, - 0x73, 0x74, 0x00, 0x00, 0x00, 0x0A]; - - assert_eq!(&expected, buf.as_slice()); + let expected: [u8; 16] = [ + 0x80, + 0x01, + 0x00, + 0x02, + 0x00, + 0x00, + 0x00, + 0x04, + 0x74, + 0x65, + 0x73, + 0x74, + 0x00, + 0x00, + 0x00, + 0x0A, + ]; + + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_strict_message_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let sent_ident = TMessageIdentifier::new("test", TMessageType::Call, 1); assert!(o_prot.write_message_begin(&sent_ident).is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let received_ident = assert_success!(i_prot.read_message_begin()); assert_eq!(&received_ident, &sent_ident); @@ -564,24 +638,26 @@ mod tests { #[test] fn must_write_field_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); - assert!(o_prot.write_field_begin(&TFieldIdentifier::new("some_field", TType::String, 22)) - .is_ok()); + assert!( + o_prot + .write_field_begin(&TFieldIdentifier::new("some_field", TType::String, 22)) + .is_ok() + ); let expected: [u8; 3] = [0x0B, 0x00, 0x16]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_field_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let sent_field_ident = 
TFieldIdentifier::new("foo", TType::I64, 20); assert!(o_prot.write_field_begin(&sent_field_ident).is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let expected_ident = TFieldIdentifier { name: None, @@ -594,22 +670,21 @@ mod tests { #[test] fn must_write_stop_field() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert!(o_prot.write_field_stop().is_ok()); let expected: [u8; 1] = [0x00]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_field_stop() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); assert!(o_prot.write_field_stop().is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let expected_ident = TFieldIdentifier { name: None, @@ -628,23 +703,26 @@ mod tests { #[test] fn must_write_list_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); - assert!(o_prot.write_list_begin(&TListIdentifier::new(TType::Bool, 5)).is_ok()); + assert!( + o_prot + .write_list_begin(&TListIdentifier::new(TType::Bool, 5)) + .is_ok() + ); let expected: [u8; 5] = [0x02, 0x00, 0x00, 0x00, 0x05]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_list_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let ident = TListIdentifier::new(TType::List, 900); assert!(o_prot.write_list_begin(&ident).is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let received_ident = assert_success!(i_prot.read_list_begin()); assert_eq!(&received_ident, &ident); @@ -657,23 +735,26 @@ mod tests { 
#[test] fn must_write_set_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); - assert!(o_prot.write_set_begin(&TSetIdentifier::new(TType::I16, 7)).is_ok()); + assert!( + o_prot + .write_set_begin(&TSetIdentifier::new(TType::I16, 7)) + .is_ok() + ); let expected: [u8; 5] = [0x06, 0x00, 0x00, 0x00, 0x07]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_set_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let ident = TSetIdentifier::new(TType::I64, 2000); assert!(o_prot.write_set_begin(&ident).is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let received_ident_result = i_prot.read_set_begin(); assert!(received_ident_result.is_ok()); @@ -687,24 +768,26 @@ mod tests { #[test] fn must_write_map_begin() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); - assert!(o_prot.write_map_begin(&TMapIdentifier::new(TType::I64, TType::Struct, 32)) - .is_ok()); + assert!( + o_prot + .write_map_begin(&TMapIdentifier::new(TType::I64, TType::Struct, 32)) + .is_ok() + ); let expected: [u8; 6] = [0x0A, 0x0C, 0x00, 0x00, 0x00, 0x20]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_round_trip_map_begin() { - let (trans, mut i_prot, mut o_prot) = test_objects(); + let (mut i_prot, mut o_prot) = test_objects(); let ident = TMapIdentifier::new(TType::Map, TType::Set, 100); assert!(o_prot.write_map_begin(&ident).is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let received_ident = assert_success!(i_prot.read_map_begin()); assert_eq!(&received_ident, &ident); @@ -717,31 +800,29 @@ mod tests { #[test] fn 
must_write_bool_true() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert!(o_prot.write_bool(true).is_ok()); let expected: [u8; 1] = [0x01]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_write_bool_false() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); assert!(o_prot.write_bool(false).is_ok()); let expected: [u8; 1] = [0x00]; - let buf = trans.borrow().write_buffer_to_vec(); - assert_eq!(&expected, buf.as_slice()); + assert_eq_written_bytes!(o_prot, expected); } #[test] fn must_read_bool_true() { - let (trans, mut i_prot, _) = test_objects(); + let (mut i_prot, _) = test_objects(); - trans.borrow_mut().set_readable_bytes(&[0x01]); + set_readable_bytes!(i_prot, &[0x01]); let read_bool = assert_success!(i_prot.read_bool()); assert_eq!(read_bool, true); @@ -749,9 +830,9 @@ mod tests { #[test] fn must_read_bool_false() { - let (trans, mut i_prot, _) = test_objects(); + let (mut i_prot, _) = test_objects(); - trans.borrow_mut().set_readable_bytes(&[0x00]); + set_readable_bytes!(i_prot, &[0x00]); let read_bool = assert_success!(i_prot.read_bool()); assert_eq!(read_bool, false); @@ -759,9 +840,9 @@ mod tests { #[test] fn must_allow_any_non_zero_value_to_be_interpreted_as_bool_true() { - let (trans, mut i_prot, _) = test_objects(); + let (mut i_prot, _) = test_objects(); - trans.borrow_mut().set_readable_bytes(&[0xAC]); + set_readable_bytes!(i_prot, &[0xAC]); let read_bool = assert_success!(i_prot.read_bool()); assert_eq!(read_bool, true); @@ -769,52 +850,77 @@ mod tests { #[test] fn must_write_bytes() { - let (trans, _, mut o_prot) = test_objects(); + let (_, mut o_prot) = test_objects(); let bytes: [u8; 10] = [0x0A, 0xCC, 0xD1, 0x84, 0x99, 0x12, 0xAB, 0xBB, 0x45, 0xDF]; assert!(o_prot.write_bytes(&bytes).is_ok()); - let buf = trans.borrow().write_buffer_to_vec(); + let buf = 
o_prot.transport.write_bytes(); assert_eq!(&buf[0..4], [0x00, 0x00, 0x00, 0x0A]); // length assert_eq!(&buf[4..], bytes); // actual bytes } #[test] fn must_round_trip_bytes() { - let (trans, mut i_prot, mut o_prot) = test_objects(); - - let bytes: [u8; 25] = [0x20, 0xFD, 0x18, 0x84, 0x99, 0x12, 0xAB, 0xBB, 0x45, 0xDF, 0x34, - 0xDC, 0x98, 0xA4, 0x6D, 0xF3, 0x99, 0xB4, 0xB7, 0xD4, 0x9C, 0xA5, - 0xB3, 0xC9, 0x88]; + let (mut i_prot, mut o_prot) = test_objects(); + + let bytes: [u8; 25] = [ + 0x20, + 0xFD, + 0x18, + 0x84, + 0x99, + 0x12, + 0xAB, + 0xBB, + 0x45, + 0xDF, + 0x34, + 0xDC, + 0x98, + 0xA4, + 0x6D, + 0xF3, + 0x99, + 0xB4, + 0xB7, + 0xD4, + 0x9C, + 0xA5, + 0xB3, + 0xC9, + 0x88, + ]; assert!(o_prot.write_bytes(&bytes).is_ok()); - trans.borrow_mut().copy_write_buffer_to_read_buffer(); + copy_write_buffer_to_read_buffer!(o_prot); let received_bytes = assert_success!(i_prot.read_bytes()); assert_eq!(&received_bytes, &bytes); } - fn test_objects<'a> - () - -> (Rc<RefCell<Box<TBufferTransport>>>, TBinaryInputProtocol<'a>, TBinaryOutputProtocol<'a>) + fn test_objects() + -> (TBinaryInputProtocol<ReadHalf<TBufferChannel>>, + TBinaryOutputProtocol<WriteHalf<TBufferChannel>>) { + let mem = TBufferChannel::with_capacity(40, 40); - let mem = Rc::new(RefCell::new(Box::new(TBufferTransport::with_capacity(40, 40)))); + let (r_mem, w_mem) = mem.split().unwrap(); - let inner: Box<TTransport> = Box::new(TPassThruTransport { inner: mem.clone() }); - let inner = Rc::new(RefCell::new(inner)); + let i_prot = TBinaryInputProtocol::new(r_mem, true); + let o_prot = TBinaryOutputProtocol::new(w_mem, true); - let i_prot = TBinaryInputProtocol::new(inner.clone(), true); - let o_prot = TBinaryOutputProtocol::new(inner.clone(), true); - - (mem, i_prot, o_prot) + (i_prot, o_prot) } - fn assert_no_write<F: FnMut(&mut TBinaryOutputProtocol) -> ::Result<()>>(mut write_fn: F) { - let (trans, _, mut o_prot) = test_objects(); + fn assert_no_write<F>(mut write_fn: F) + where + F: FnMut(&mut 
TBinaryOutputProtocol<WriteHalf<TBufferChannel>>) -> ::Result<()>, + { + let (_, mut o_prot) = test_objects(); assert!(write_fn(&mut o_prot).is_ok()); - assert_eq!(trans.borrow().write_buffer_as_ref().len(), 0); + assert_eq!(o_prot.transport.write_bytes().len(), 0); } }
