http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/rpc-test-base.h ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/rpc-test-base.h b/be/src/kudu/rpc/rpc-test-base.h new file mode 100644 index 0000000..2fb742e --- /dev/null +++ b/be/src/kudu/rpc/rpc-test-base.h @@ -0,0 +1,661 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#ifndef KUDU_RPC_RPC_TEST_BASE_H +#define KUDU_RPC_RPC_TEST_BASE_H + +#include <algorithm> +#include <atomic> +#include <memory> +#include <string> + +#include "kudu/gutil/walltime.h" +#include "kudu/rpc/acceptor_pool.h" +#include "kudu/rpc/messenger.h" +#include "kudu/rpc/proxy.h" +#include "kudu/rpc/reactor.h" +#include "kudu/rpc/remote_method.h" +#include "kudu/rpc/result_tracker.h" +#include "kudu/rpc/rpc_context.h" +#include "kudu/rpc/rpc_controller.h" +#include "kudu/rpc/rpc_sidecar.h" +#include "kudu/rpc/rtest.pb.h" +#include "kudu/rpc/rtest.proxy.h" +#include "kudu/rpc/rtest.service.h" +#include "kudu/rpc/service_if.h" +#include "kudu/rpc/service_pool.h" +#include "kudu/security/security-test-util.h" +#include "kudu/util/env.h" +#include "kudu/util/faststring.h" +#include "kudu/util/mem_tracker.h" +#include "kudu/util/monotime.h" +#include "kudu/util/net/sockaddr.h" +#include "kudu/util/path_util.h" +#include "kudu/util/pb_util.h" +#include "kudu/util/random.h" +#include "kudu/util/random_util.h" +#include "kudu/util/stopwatch.h" +#include "kudu/util/test_macros.h" +#include "kudu/util/test_util.h" +#include "kudu/util/trace.h" + +DECLARE_bool(rpc_encrypt_loopback_connections); + +namespace kudu { +namespace rpc { + +using kudu::rpc_test::AddRequestPB; +using kudu::rpc_test::AddResponsePB; +using kudu::rpc_test::CalculatorError; +using kudu::rpc_test::CalculatorServiceIf; +using kudu::rpc_test::CalculatorServiceProxy; +using kudu::rpc_test::EchoRequestPB; +using kudu::rpc_test::EchoResponsePB; +using kudu::rpc_test::ExactlyOnceRequestPB; +using kudu::rpc_test::ExactlyOnceResponsePB; +using kudu::rpc_test::FeatureFlags; +using kudu::rpc_test::PanicRequestPB; +using kudu::rpc_test::PanicResponsePB; +using kudu::rpc_test::PushTwoStringsRequestPB; +using kudu::rpc_test::PushTwoStringsResponsePB; +using kudu::rpc_test::SendTwoStringsRequestPB; +using kudu::rpc_test::SendTwoStringsResponsePB; +using kudu::rpc_test::SleepRequestPB; +using kudu::rpc_test::SleepResponsePB; +using kudu::rpc_test::SleepWithSidecarRequestPB; +using kudu::rpc_test::SleepWithSidecarResponsePB; +using kudu::rpc_test::TestInvalidResponseRequestPB; +using kudu::rpc_test::TestInvalidResponseResponsePB; +using kudu::rpc_test::WhoAmIRequestPB; +using kudu::rpc_test::WhoAmIResponsePB; +using kudu::rpc_test_diff_package::ReqDiffPackagePB; +using kudu::rpc_test_diff_package::RespDiffPackagePB; + +// Implementation of CalculatorService which just implements the generic +// RPC handler (no generated code). +class GenericCalculatorService : public ServiceIf { + public: + static const char *kFullServiceName; + static const char *kAddMethodName; + static const char *kSleepMethodName; + static const char *kSleepWithSidecarMethodName; + static const char *kPushTwoStringsMethodName; + static const char *kSendTwoStringsMethodName; + static const char *kAddExactlyOnce; + + static const char* kFirstString; + static const char* kSecondString; + + GenericCalculatorService() { + } + + // To match the argument list of the generated CalculatorService. + explicit GenericCalculatorService(const scoped_refptr<MetricEntity>& entity, + const scoped_refptr<ResultTracker>& result_tracker) { + // this test doesn't generate metrics, so we ignore the argument. + } + + void Handle(InboundCall *incoming) override { + if (incoming->remote_method().method_name() == kAddMethodName) { + DoAdd(incoming); + } else if (incoming->remote_method().method_name() == kSleepMethodName) { + DoSleep(incoming); + } else if (incoming->remote_method().method_name() == kSleepWithSidecarMethodName) { + DoSleepWithSidecar(incoming); + } else if (incoming->remote_method().method_name() == kSendTwoStringsMethodName) { + DoSendTwoStrings(incoming); + } else if (incoming->remote_method().method_name() == kPushTwoStringsMethodName) { + DoPushTwoStrings(incoming); + } else { + incoming->RespondFailure(ErrorStatusPB::ERROR_NO_SUCH_METHOD, + Status::InvalidArgument("bad method")); + } + } + + std::string service_name() const override { return kFullServiceName; } + static std::string static_service_name() { return kFullServiceName; } + + private: + void DoAdd(InboundCall *incoming) { + Slice param(incoming->serialized_request()); + AddRequestPB req; + if (!req.ParseFromArray(param.data(), param.size())) { + LOG(FATAL) << "couldn't parse: " << param.ToDebugString(); + } + + AddResponsePB resp; + resp.set_result(req.x() + req.y()); + incoming->RespondSuccess(resp); + } + + void DoSendTwoStrings(InboundCall* incoming) { + Slice param(incoming->serialized_request()); + SendTwoStringsRequestPB req; + if (!req.ParseFromArray(param.data(), param.size())) { + LOG(FATAL) << "couldn't parse: " << param.ToDebugString(); + } + + std::unique_ptr<faststring> first(new faststring); + std::unique_ptr<faststring> second(new faststring); + + Random r(req.random_seed()); + first->resize(req.size1()); + RandomString(first->data(), req.size1(), &r); + + second->resize(req.size2()); + RandomString(second->data(), req.size2(), &r); + + SendTwoStringsResponsePB resp; + int idx1, idx2; + CHECK_OK(incoming->AddOutboundSidecar( + RpcSidecar::FromFaststring(std::move(first)), &idx1)); + CHECK_OK(incoming->AddOutboundSidecar( + RpcSidecar::FromFaststring(std::move(second)), &idx2)); + resp.set_sidecar1(idx1); + resp.set_sidecar2(idx2); + + incoming->RespondSuccess(resp); + } + + void DoPushTwoStrings(InboundCall* incoming) { + Slice param(incoming->serialized_request()); + PushTwoStringsRequestPB req; + if (!req.ParseFromArray(param.data(), param.size())) { + LOG(FATAL) << "couldn't parse: " << param.ToDebugString(); + } + + Slice sidecar1; + CHECK_OK(incoming->GetInboundSidecar(req.sidecar1_idx(), &sidecar1)); + + Slice sidecar2; + CHECK_OK(incoming->GetInboundSidecar(req.sidecar2_idx(), &sidecar2)); + + // Check that reading non-existant sidecars doesn't work. + Slice tmp; + CHECK(!incoming->GetInboundSidecar(req.sidecar2_idx() + 2, &tmp).ok()); + + PushTwoStringsResponsePB resp; + resp.set_size1(sidecar1.size()); + resp.set_data1(reinterpret_cast<const char*>(sidecar1.data()), sidecar1.size()); + resp.set_size2(sidecar2.size()); + resp.set_data2(reinterpret_cast<const char*>(sidecar2.data()), sidecar2.size()); + + // Drop the sidecars etc, just to confirm that it's safe to do so. + CHECK_GT(incoming->GetTransferSize(), 0); + incoming->DiscardTransfer(); + CHECK_EQ(0, incoming->GetTransferSize()); + incoming->RespondSuccess(resp); + } + + void DoSleep(InboundCall *incoming) { + Slice param(incoming->serialized_request()); + SleepRequestPB req; + if (!req.ParseFromArray(param.data(), param.size())) { + incoming->RespondFailure(ErrorStatusPB::ERROR_INVALID_REQUEST, + Status::InvalidArgument("Couldn't parse pb", + req.InitializationErrorString())); + return; + } + + LOG(INFO) << "got call: " << pb_util::SecureShortDebugString(req); + SleepFor(MonoDelta::FromMicroseconds(req.sleep_micros())); + MonoDelta duration(MonoTime::Now().GetDeltaSince(incoming->GetTimeReceived())); + CHECK_GE(duration.ToMicroseconds(), req.sleep_micros()); + SleepResponsePB resp; + incoming->RespondSuccess(resp); + } + + void DoSleepWithSidecar(InboundCall *incoming) { + Slice param(incoming->serialized_request()); + SleepWithSidecarRequestPB req; + if (!req.ParseFromArray(param.data(), param.size())) { + incoming->RespondFailure(ErrorStatusPB::ERROR_INVALID_REQUEST, + Status::InvalidArgument("Couldn't parse pb", + req.InitializationErrorString())); + return; + } + + LOG(INFO) << "got call: " << pb_util::SecureShortDebugString(req); + SleepFor(MonoDelta::FromMicroseconds(req.sleep_micros())); + + uint32_t pattern = req.pattern(); + uint32_t num_repetitions = req.num_repetitions(); + Slice sidecar; + CHECK_OK(incoming->GetInboundSidecar(req.sidecar_idx(), &sidecar)); + CHECK_EQ(sidecar.size(), sizeof(uint32) * num_repetitions); + const uint32_t *data = reinterpret_cast<const uint32_t*>(sidecar.data()); + for (int i = 0; i < num_repetitions; ++i) CHECK_EQ(data[i], pattern); + + SleepResponsePB resp; + incoming->RespondSuccess(resp); + } +}; + +class CalculatorService : public CalculatorServiceIf { + public: + explicit CalculatorService(const scoped_refptr<MetricEntity>& entity, + const scoped_refptr<ResultTracker> result_tracker) + : CalculatorServiceIf(entity, result_tracker), + exactly_once_test_val_(0) { + } + + void Add(const AddRequestPB *req, AddResponsePB *resp, RpcContext *context) override { + CHECK_GT(context->GetTransferSize(), 0); + resp->set_result(req->x() + req->y()); + context->RespondSuccess(); + } + + void Sleep(const SleepRequestPB *req, SleepResponsePB *resp, RpcContext *context) override { + if (req->return_app_error()) { + CalculatorError my_error; + my_error.set_extra_error_data("some application-specific error data"); + context->RespondApplicationError(CalculatorError::app_error_ext.number(), + "Got some error", my_error); + return; + } + + // Respond w/ error if the RPC specifies that the client deadline is set, + // but it isn't. + if (req->client_timeout_defined()) { + MonoTime deadline = context->GetClientDeadline(); + if (deadline == MonoTime::Max()) { + CalculatorError my_error; + my_error.set_extra_error_data("Timeout not set"); + context->RespondApplicationError(CalculatorError::app_error_ext.number(), + "Missing required timeout", my_error); + return; + } + } + + if (req->deferred()) { + // Spawn a new thread which does the sleep and responds later. + scoped_refptr<Thread> thread; + CHECK_OK(Thread::Create("rpc-test", "deferred", + &CalculatorService::DoSleep, this, req, context, + &thread)); + return; + } + DoSleep(req, context); + } + + void Echo(const EchoRequestPB *req, EchoResponsePB *resp, RpcContext *context) override { + resp->set_data(req->data()); + context->RespondSuccess(); + } + + void WhoAmI(const WhoAmIRequestPB* /*req*/, + WhoAmIResponsePB* resp, + RpcContext* context) override { + const RemoteUser& user = context->remote_user(); + resp->mutable_credentials()->set_real_user(user.username()); + resp->set_address(context->remote_address().ToString()); + context->RespondSuccess(); + } + + void TestArgumentsInDiffPackage(const ReqDiffPackagePB *req, + RespDiffPackagePB *resp, + ::kudu::rpc::RpcContext *context) override { + context->RespondSuccess(); + } + + void Panic(const PanicRequestPB* req, PanicResponsePB* resp, RpcContext* context) override { + TRACE("Got panic request"); + PANIC_RPC(context, "Test method panicking!"); + } + + void TestInvalidResponse(const TestInvalidResponseRequestPB* req, + TestInvalidResponseResponsePB* resp, + RpcContext* context) override { + switch (req->error_type()) { + case rpc_test::TestInvalidResponseRequestPB_ErrorType_MISSING_REQUIRED_FIELD: + // Respond without setting the 'resp->response' protobuf field, which is + // marked as required. This exercises the error path of invalid responses. + context->RespondSuccess(); + break; + case rpc_test::TestInvalidResponseRequestPB_ErrorType_RESPONSE_TOO_LARGE: + resp->mutable_response()->resize(FLAGS_rpc_max_message_size + 1000); + context->RespondSuccess(); + break; + default: + LOG(FATAL); + } + } + + bool SupportsFeature(uint32_t feature) const override { + return feature == FeatureFlags::FOO; + } + + void AddExactlyOnce(const ExactlyOnceRequestPB* req, ExactlyOnceResponsePB* resp, + ::kudu::rpc::RpcContext* context) override { + if (req->sleep_for_ms() > 0) { + usleep(req->sleep_for_ms() * 1000); + } + // If failures are enabled, cause them some percentage of the time. + if (req->randomly_fail()) { + if (rand() % 10 < 3) { + context->RespondFailure(Status::ServiceUnavailable("Random injected failure.")); + return; + } + } + int result = exactly_once_test_val_ += req->value_to_add(); + resp->set_current_val(result); + resp->set_current_time_micros(GetCurrentTimeMicros()); + context->RespondSuccess(); + } + + bool AuthorizeDisallowAlice(const google::protobuf::Message* /*req*/, + google::protobuf::Message* /*resp*/, + RpcContext* context) override { + if (context->remote_user().username() == "alice") { + context->RespondFailure(Status::NotAuthorized("alice is not allowed to call this method")); + return false; + } + return true; + } + + bool AuthorizeDisallowBob(const google::protobuf::Message* /*req*/, + google::protobuf::Message* /*resp*/, + RpcContext* context) override { + if (context->remote_user().username() == "bob") { + context->RespondFailure(Status::NotAuthorized("bob is not allowed to call this method")); + return false; + } + return true; + } + + private: + void DoSleep(const SleepRequestPB *req, + RpcContext *context) { + TRACE_COUNTER_INCREMENT("test_sleep_us", req->sleep_micros()); + if (Trace::CurrentTrace()) { + scoped_refptr<Trace> child_trace(new Trace()); + Trace::CurrentTrace()->AddChildTrace("test_child", child_trace.get()); + ADOPT_TRACE(child_trace.get()); + TRACE_COUNTER_INCREMENT("related_trace_metric", 1); + } + + SleepFor(MonoDelta::FromMicroseconds(req->sleep_micros())); + context->RespondSuccess(); + } + + std::atomic_int exactly_once_test_val_; + +}; + +const char *GenericCalculatorService::kFullServiceName = "kudu.rpc.GenericCalculatorService"; +const char *GenericCalculatorService::kAddMethodName = "Add"; +const char *GenericCalculatorService::kSleepMethodName = "Sleep"; +const char *GenericCalculatorService::kSleepWithSidecarMethodName = "SleepWithSidecar"; +const char *GenericCalculatorService::kPushTwoStringsMethodName = "PushTwoStrings"; +const char *GenericCalculatorService::kSendTwoStringsMethodName = "SendTwoStrings"; +const char *GenericCalculatorService::kAddExactlyOnce = "AddExactlyOnce"; + +const char *GenericCalculatorService::kFirstString = + "1111111111111111111111111111111111111111111111111111111111"; +const char *GenericCalculatorService::kSecondString = + "2222222222222222222222222222222222222222222222222222222222222222222222"; + +class RpcTestBase : public KuduTest { + public: + RpcTestBase() + : n_worker_threads_(3), + service_queue_length_(100), + n_server_reactor_threads_(3), + keepalive_time_ms_(1000), + metric_entity_(METRIC_ENTITY_server.Instantiate(&metric_registry_, "test.rpc_test")) { + } + + void TearDown() override { + if (service_pool_) { + server_messenger_->UnregisterService(service_name_); + service_pool_->Shutdown(); + } + if (server_messenger_) { + server_messenger_->Shutdown(); + } + KuduTest::TearDown(); + } + + protected: + Status CreateMessenger(const std::string& name, + std::shared_ptr<Messenger>* messenger, + int n_reactors = 1, + bool enable_ssl = false, + const std::string& rpc_certificate_file = "", + const std::string& rpc_private_key_file = "", + const std::string& rpc_ca_certificate_file = "", + const std::string& rpc_private_key_password_cmd = "") { + MessengerBuilder bld(name); + + if (enable_ssl) { + FLAGS_rpc_encrypt_loopback_connections = true; + bld.set_epki_cert_key_files(rpc_certificate_file, rpc_private_key_file); + bld.set_epki_certificate_authority_file(rpc_ca_certificate_file); + bld.set_epki_private_password_key_cmd(rpc_private_key_password_cmd); + bld.set_rpc_encryption("required"); + bld.enable_inbound_tls(); + } + + bld.set_num_reactors(n_reactors); + bld.set_connection_keepalive_time(MonoDelta::FromMilliseconds(keepalive_time_ms_)); + if (keepalive_time_ms_ >= 0) { + // In order for the keepalive timing to be accurate, we need to scan connections + // significantly more frequently than the keepalive time. This "coarse timer" + // granularity determines this. + bld.set_coarse_timer_granularity( + MonoDelta::FromMilliseconds(std::min(keepalive_time_ms_ / 5, 100))); + } + bld.set_metric_entity(metric_entity_); + return bld.Build(messenger); + } + + Status DoTestSyncCall(const Proxy &p, const char *method, + CredentialsPolicy policy = CredentialsPolicy::ANY_CREDENTIALS) { + AddRequestPB req; + req.set_x(rand()); + req.set_y(rand()); + AddResponsePB resp; + RpcController controller; + controller.set_timeout(MonoDelta::FromMilliseconds(10000)); + controller.set_credentials_policy(policy); + RETURN_NOT_OK(p.SyncRequest(method, req, &resp, &controller)); + + CHECK_EQ(req.x() + req.y(), resp.result()); + return Status::OK(); + } + + void DoTestSidecar(const Proxy &p, int size1, int size2) { + const uint32_t kSeed = 12345; + + SendTwoStringsRequestPB req; + req.set_size1(size1); + req.set_size2(size2); + req.set_random_seed(kSeed); + + SendTwoStringsResponsePB resp; + RpcController controller; + controller.set_timeout(MonoDelta::FromMilliseconds(10000)); + CHECK_OK(p.SyncRequest(GenericCalculatorService::kSendTwoStringsMethodName, + req, &resp, &controller)); + + Slice first = GetSidecarPointer(controller, resp.sidecar1(), size1); + Slice second = GetSidecarPointer(controller, resp.sidecar2(), size2); + Random rng(kSeed); + faststring expected; + + expected.resize(size1); + RandomString(expected.data(), size1, &rng); + CHECK_EQ(0, first.compare(Slice(expected))); + + expected.resize(size2); + RandomString(expected.data(), size2, &rng); + CHECK_EQ(0, second.compare(Slice(expected))); + } + + Status DoTestOutgoingSidecar(const Proxy &p, int size1, int size2) { + PushTwoStringsRequestPB request; + RpcController controller; + + int idx1; + std::string s1(size1, 'a'); + CHECK_OK(controller.AddOutboundSidecar(RpcSidecar::FromSlice(Slice(s1)), &idx1)); + + int idx2; + std::string s2(size2, 'b'); + CHECK_OK(controller.AddOutboundSidecar(RpcSidecar::FromSlice(Slice(s2)), &idx2)); + + request.set_sidecar1_idx(idx1); + request.set_sidecar2_idx(idx2); + + PushTwoStringsResponsePB resp; + KUDU_RETURN_NOT_OK(p.SyncRequest(GenericCalculatorService::kPushTwoStringsMethodName, + request, &resp, &controller)); + CHECK_EQ(size1, resp.size1()); + CHECK_EQ(resp.data1(), s1); + CHECK_EQ(size2, resp.size2()); + CHECK_EQ(resp.data2(), s2); + return Status::OK(); + } + + void DoTestOutgoingSidecarExpectOK(const Proxy &p, int size1, int size2) { + CHECK_OK(DoTestOutgoingSidecar(p, size1, size2)); + } + + void DoTestExpectTimeout(const Proxy& p, + const MonoDelta& timeout, + bool* is_negotiaton_error = nullptr) { + SleepRequestPB req; + SleepResponsePB resp; + // Sleep for 500ms longer than the call timeout. + int sleep_micros = timeout.ToMicroseconds() + 500 * 1000; + req.set_sleep_micros(sleep_micros); + + RpcController c; + c.set_timeout(timeout); + Stopwatch sw; + sw.start(); + Status s = p.SyncRequest(GenericCalculatorService::kSleepMethodName, req, &resp, &c); + sw.stop(); + ASSERT_FALSE(s.ok()); + if (is_negotiaton_error != nullptr) { + *is_negotiaton_error = c.negotiation_failed(); + } + + int expected_millis = timeout.ToMilliseconds(); + int elapsed_millis = sw.elapsed().wall_millis(); + + // We shouldn't timeout significantly faster than our configured timeout. + EXPECT_GE(elapsed_millis, expected_millis - 10); + // And we also shouldn't take the full time that we asked for + EXPECT_LT(elapsed_millis * 1000, sleep_micros); + EXPECT_TRUE(s.IsTimedOut()); + LOG(INFO) << "status: " << s.ToString() << ", seconds elapsed: " << sw.elapsed().wall_seconds(); + } + + Status StartTestServer(Sockaddr *server_addr, + bool enable_ssl = false, + const std::string& rpc_certificate_file = "", + const std::string& rpc_private_key_file = "", + const std::string& rpc_ca_certificate_file = "", + const std::string& rpc_private_key_password_cmd = "") { + return DoStartTestServer<GenericCalculatorService>( + server_addr, enable_ssl, rpc_certificate_file, rpc_private_key_file, + rpc_ca_certificate_file, rpc_private_key_password_cmd); + } + + Status StartTestServerWithGeneratedCode(Sockaddr *server_addr, bool enable_ssl = false) { + return DoStartTestServer<CalculatorService>(server_addr, enable_ssl); + } + + Status StartTestServerWithCustomMessenger(Sockaddr *server_addr, + const std::shared_ptr<Messenger>& messenger, bool enable_ssl = false) { + return DoStartTestServer<GenericCalculatorService>( + server_addr, enable_ssl, "", "", "", "", messenger); + } + + // Start a simple socket listening on a local port, returning the address. + // This isn't an RPC server -- just a plain socket which can be helpful for testing. + Status StartFakeServer(Socket *listen_sock, Sockaddr *listen_addr) { + Sockaddr bind_addr; + bind_addr.set_port(0); + RETURN_NOT_OK(listen_sock->Init(0)); + RETURN_NOT_OK(listen_sock->BindAndListen(bind_addr, 1)); + RETURN_NOT_OK(listen_sock->GetSocketAddress(listen_addr)); + LOG(INFO) << "Bound to: " << listen_addr->ToString(); + return Status::OK(); + } + + private: + + static Slice GetSidecarPointer(const RpcController& controller, int idx, + int expected_size) { + Slice sidecar; + CHECK_OK(controller.GetInboundSidecar(idx, &sidecar)); + CHECK_EQ(expected_size, sidecar.size()); + return Slice(sidecar.data(), expected_size); + } + + template<class ServiceClass> + Status DoStartTestServer(Sockaddr *server_addr, + bool enable_ssl = false, + const std::string& rpc_certificate_file = "", + const std::string& rpc_private_key_file = "", + const std::string& rpc_ca_certificate_file = "", + const std::string& rpc_private_key_password_cmd = "", + const std::shared_ptr<Messenger>& messenger = nullptr) { + if (!messenger) { + RETURN_NOT_OK(CreateMessenger( + "TestServer", &server_messenger_, n_server_reactor_threads_, enable_ssl, + rpc_certificate_file, rpc_private_key_file, rpc_ca_certificate_file, + rpc_private_key_password_cmd)); + } else { + server_messenger_ = messenger; + } + std::shared_ptr<AcceptorPool> pool; + RETURN_NOT_OK(server_messenger_->AddAcceptorPool(Sockaddr(), &pool)); + RETURN_NOT_OK(pool->Start(2)); + *server_addr = pool->bind_address(); + mem_tracker_ = MemTracker::CreateTracker(-1, "result_tracker"); + result_tracker_.reset(new ResultTracker(mem_tracker_)); + + gscoped_ptr<ServiceIf> service(new ServiceClass(metric_entity_, result_tracker_)); + service_name_ = service->service_name(); + scoped_refptr<MetricEntity> metric_entity = server_messenger_->metric_entity(); + service_pool_ = new ServicePool(std::move(service), metric_entity, service_queue_length_); + server_messenger_->RegisterService(service_name_, service_pool_); + RETURN_NOT_OK(service_pool_->Init(n_worker_threads_)); + + return Status::OK(); + } + + protected: + std::string service_name_; + std::shared_ptr<Messenger> server_messenger_; + scoped_refptr<ServicePool> service_pool_; + std::shared_ptr<kudu::MemTracker> mem_tracker_; + scoped_refptr<ResultTracker> result_tracker_; + int n_worker_threads_; + int service_queue_length_; + int n_server_reactor_threads_; + int keepalive_time_ms_; + + MetricRegistry metric_registry_; + scoped_refptr<MetricEntity> metric_entity_; +}; + +} // namespace rpc +} // namespace kudu +#endif
http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/rpc-test.cc ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/rpc-test.cc b/be/src/kudu/rpc/rpc-test.cc new file mode 100644 index 0000000..077b5a3 --- /dev/null +++ b/be/src/kudu/rpc/rpc-test.cc @@ -0,0 +1,1364 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "kudu/rpc/rpc-test-base.h" + +#include <cerrno> +#include <cstdint> +#include <cstdlib> +#include <cstring> +#include <limits> +#include <memory> +#include <ostream> +#include <set> +#include <string> +#include <unistd.h> +#include <unordered_map> +#include <vector> + +#include <boost/bind.hpp> +#include <boost/core/ref.hpp> +#include <boost/function.hpp> +#include <gflags/gflags_declare.h> +#include <glog/logging.h> +#include <gtest/gtest.h> + +#include "kudu/gutil/casts.h" +#include "kudu/gutil/gscoped_ptr.h" +#include "kudu/gutil/map-util.h" +#include "kudu/gutil/ref_counted.h" +#include "kudu/gutil/stl_util.h" +#include "kudu/gutil/strings/substitute.h" +#include "kudu/rpc/acceptor_pool.h" +#include "kudu/rpc/constants.h" +#include "kudu/rpc/messenger.h" +#include "kudu/rpc/outbound_call.h" +#include "kudu/rpc/proxy.h" +#include "kudu/rpc/reactor.h" +#include "kudu/rpc/rpc_controller.h" +#include "kudu/rpc/rpc_introspection.pb.h" +#include "kudu/rpc/rpc_sidecar.h" +#include "kudu/rpc/rtest.pb.h" +#include "kudu/rpc/serialization.h" +#include "kudu/rpc/transfer.h" +#include "kudu/security/test/test_certs.h" +#include "kudu/util/countdown_latch.h" +#include "kudu/util/env.h" +#include "kudu/util/metrics.h" +#include "kudu/util/monotime.h" +#include "kudu/util/net/sockaddr.h" +#include "kudu/util/net/socket.h" +#include "kudu/util/random.h" +#include "kudu/util/scoped_cleanup.h" +#include "kudu/util/slice.h" +#include "kudu/util/status.h" +#include "kudu/util/test_macros.h" +#include "kudu/util/test_util.h" +#include "kudu/util/thread.h" + +METRIC_DECLARE_histogram(handler_latency_kudu_rpc_test_CalculatorService_Sleep); +METRIC_DECLARE_histogram(rpc_incoming_queue_time); + +DECLARE_bool(rpc_reopen_outbound_connections); +DECLARE_int32(rpc_negotiation_inject_delay_ms); + +using std::shared_ptr; +using std::string; +using std::unique_ptr; +using std::unordered_map; +using std::vector; + +namespace kudu { +namespace rpc { + +class TestRpc : public RpcTestBase, public ::testing::WithParamInterface<bool> { +}; + +// This is used to run all parameterized tests with and without SSL. +INSTANTIATE_TEST_CASE_P(OptionalSSL, TestRpc, testing::Values(false, true)); + +TEST_F(TestRpc, TestSockaddr) { + Sockaddr addr1, addr2; + addr1.set_port(1000); + addr2.set_port(2000); + // port is ignored when comparing Sockaddr objects + ASSERT_FALSE(addr1 < addr2); + ASSERT_FALSE(addr2 < addr1); + ASSERT_EQ(1000, addr1.port()); + ASSERT_EQ(2000, addr2.port()); + ASSERT_EQ(string("0.0.0.0:1000"), addr1.ToString()); + ASSERT_EQ(string("0.0.0.0:2000"), addr2.ToString()); + Sockaddr addr3(addr1); + ASSERT_EQ(string("0.0.0.0:1000"), addr3.ToString()); +} + +TEST_P(TestRpc, TestMessengerCreateDestroy) { + shared_ptr<Messenger> messenger; + ASSERT_OK(CreateMessenger("TestCreateDestroy", &messenger, 1, GetParam())); + LOG(INFO) << "started messenger " << messenger->name(); + messenger->Shutdown(); +} + +// Test starting and stopping a messenger. This is a regression +// test for a segfault seen in early versions of the RPC code, +// in which shutting down the acceptor would trigger an assert, +// making our tests flaky. +TEST_P(TestRpc, TestAcceptorPoolStartStop) { + int n_iters = AllowSlowTests() ? 100 : 5; + for (int i = 0; i < n_iters; i++) { + shared_ptr<Messenger> messenger; + ASSERT_OK(CreateMessenger("TestAcceptorPoolStartStop", &messenger, 1, GetParam())); + shared_ptr<AcceptorPool> pool; + ASSERT_OK(messenger->AddAcceptorPool(Sockaddr(), &pool)); + Sockaddr bound_addr; + ASSERT_OK(pool->GetBoundAddress(&bound_addr)); + ASSERT_NE(0, bound_addr.port()); + ASSERT_OK(pool->Start(2)); + messenger->Shutdown(); + } +} + +TEST_F(TestRpc, TestConnHeaderValidation) { + MessengerBuilder mb("TestRpc.TestConnHeaderValidation"); + const int conn_hdr_len = kMagicNumberLength + kHeaderFlagsLength; + uint8_t buf[conn_hdr_len]; + serialization::SerializeConnHeader(buf); + ASSERT_OK(serialization::ValidateConnHeader(Slice(buf, conn_hdr_len))); +} + +// Regression test for KUDU-2041 +TEST_P(TestRpc, TestNegotiationDeadlock) { + bool enable_ssl = GetParam(); + + // The deadlock would manifest in cases where the number of concurrent connection + // requests >= the number of threads. 1 thread and 1 cnxn to ourself is just the easiest + // way to reproduce the issue, because the server negotiation task must get queued after + // the client negotiation task if they share the same thread pool. + MessengerBuilder mb("TestRpc.TestNegotiationDeadlock"); + mb.set_min_negotiation_threads(1) + .set_max_negotiation_threads(1) + .set_metric_entity(metric_entity_); + if (enable_ssl) mb.enable_inbound_tls(); + + shared_ptr<Messenger> messenger; + CHECK_OK(mb.Build(&messenger)); + + Sockaddr server_addr; + ASSERT_OK(StartTestServerWithCustomMessenger(&server_addr, messenger, enable_ssl)); + + Proxy p(messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName)); +} + +// Test making successful RPC calls. +TEST_P(TestRpc, TestCall) { + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + LOG(INFO) << "Connecting to " << server_addr.ToString(); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + ASSERT_STR_CONTAINS(p.ToString(), strings::Substitute("kudu.rpc.GenericCalculatorService@" + "{remote=$0, user_credentials=", + server_addr.ToString())); + + for (int i = 0; i < 10; i++) { + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName)); + } +} + +// Test for KUDU-2091 and KUDU-2220. +TEST_P(TestRpc, TestCallWithChainCertAndChainCA) { + bool enable_ssl = GetParam(); + // We're only interested in running this test with TLS enabled. + if (!enable_ssl) return; + + string rpc_certificate_file; + string rpc_private_key_file; + string rpc_ca_certificate_file; + ASSERT_OK(security::CreateTestSSLCertSignedByChain(GetTestDataDirectory(), + &rpc_certificate_file, + &rpc_private_key_file, + &rpc_ca_certificate_file)); + // Set up server. + Sockaddr server_addr; + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + SCOPED_TRACE(strings::Substitute("Connecting to $0", server_addr.ToString())); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl, + rpc_certificate_file, rpc_private_key_file, rpc_ca_certificate_file)); + + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + ASSERT_STR_CONTAINS(p.ToString(), strings::Substitute("kudu.rpc.GenericCalculatorService@" + "{remote=$0, user_credentials=", + server_addr.ToString())); + + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName)); +} + +// Test for KUDU-2041. +TEST_P(TestRpc, TestCallWithChainCertAndRootCA) { + bool enable_ssl = GetParam(); + // We're only interested in running this test with TLS enabled. + if (!enable_ssl) return; + + string rpc_certificate_file; + string rpc_private_key_file; + string rpc_ca_certificate_file; + ASSERT_OK(security::CreateTestSSLCertWithChainSignedByRoot(GetTestDataDirectory(), + &rpc_certificate_file, + &rpc_private_key_file, + &rpc_ca_certificate_file)); + // Set up server. + Sockaddr server_addr; + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + SCOPED_TRACE(strings::Substitute("Connecting to $0", server_addr.ToString())); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl, + rpc_certificate_file, rpc_private_key_file, rpc_ca_certificate_file)); + + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + ASSERT_STR_CONTAINS(p.ToString(), strings::Substitute("kudu.rpc.GenericCalculatorService@" + "{remote=$0, user_credentials=", + server_addr.ToString())); + + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName)); +} + +// Test making successful RPC calls while using a TLS certificate with a password protected +// private key. +TEST_P(TestRpc, TestCallWithPasswordProtectedKey) { + bool enable_ssl = GetParam(); + // We're only interested in running this test with TLS enabled. + if (!enable_ssl) return; + + string rpc_certificate_file; + string rpc_private_key_file; + string rpc_ca_certificate_file; + string rpc_private_key_password_cmd; + string passwd; + ASSERT_OK(security::CreateTestSSLCertWithEncryptedKey(GetTestDataDirectory(), + &rpc_certificate_file, + &rpc_private_key_file, + &passwd)); + rpc_ca_certificate_file = rpc_certificate_file; + rpc_private_key_password_cmd = strings::Substitute("echo $0", passwd); + // Set up server. + Sockaddr server_addr; + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + SCOPED_TRACE(strings::Substitute("Connecting to $0", server_addr.ToString())); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl, + rpc_certificate_file, rpc_private_key_file, rpc_ca_certificate_file, + rpc_private_key_password_cmd)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + ASSERT_STR_CONTAINS(p.ToString(), strings::Substitute("kudu.rpc.GenericCalculatorService@" + "{remote=$0, user_credentials=", + server_addr.ToString())); + + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName)); +} + +// Test that using a TLS certificate with a password protected private key and providing +// the wrong password for that private key, causes a server startup failure. +TEST_P(TestRpc, TestCallWithBadPasswordProtectedKey) { + bool enable_ssl = GetParam(); + // We're only interested in running this test with TLS enabled. + if (!enable_ssl) return; + + string rpc_certificate_file; + string rpc_private_key_file; + string rpc_ca_certificate_file; + string rpc_private_key_password_cmd; + string passwd; + ASSERT_OK(security::CreateTestSSLCertWithEncryptedKey(GetTestDataDirectory(), + &rpc_certificate_file, + &rpc_private_key_file, + &passwd)); + // Overwrite the password with an invalid one. + passwd = "badpassword"; + rpc_ca_certificate_file = rpc_certificate_file; + rpc_private_key_password_cmd = strings::Substitute("echo $0", passwd); + // Verify that the server fails to start up. + Sockaddr server_addr; + Status s = StartTestServer(&server_addr, enable_ssl, rpc_certificate_file, rpc_private_key_file, + rpc_ca_certificate_file, rpc_private_key_password_cmd); + ASSERT_TRUE(s.IsRuntimeError()); + ASSERT_STR_CONTAINS(s.ToString(), "failed to load private key file"); +} + +// Test that connecting to an invalid server properly throws an error. +TEST_P(TestRpc, TestCallToBadServer) { + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, GetParam())); + Sockaddr addr; + addr.set_port(0); + Proxy p(client_messenger, addr, addr.host(), + GenericCalculatorService::static_service_name()); + + // Loop a few calls to make sure that we properly set up and tear down + // the connections. + for (int i = 0; i < 5; i++) { + Status s = DoTestSyncCall(p, GenericCalculatorService::kAddMethodName); + LOG(INFO) << "Status: " << s.ToString(); + ASSERT_TRUE(s.IsNetworkError()) << "unexpected status: " << s.ToString(); + } +} + +// Test that RPC calls can be failed with an error status on the server. +TEST_P(TestRpc, TestInvalidMethodCall) { + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + LOG(INFO) << "Connecting to " << server_addr.ToString(); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + // Call the method which fails. + Status s = DoTestSyncCall(p, "ThisMethodDoesNotExist"); + ASSERT_TRUE(s.IsRemoteError()) << "unexpected status: " << s.ToString(); + ASSERT_STR_CONTAINS(s.ToString(), "bad method"); +} + +// Test that the error message returned when connecting to the wrong service +// is reasonable. +TEST_P(TestRpc, TestWrongService) { + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client with the wrong service name. + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, "localhost", "WrongServiceName"); + + // Call the method which fails. + Status s = DoTestSyncCall(p, "ThisMethodDoesNotExist"); + ASSERT_TRUE(s.IsRemoteError()) << "unexpected status: " << s.ToString(); + ASSERT_STR_CONTAINS(s.ToString(), + "Service unavailable: service WrongServiceName " + "not registered on TestServer"); +} + +// Test that we can still make RPC connections even if many fds are in use. +// This is a regression test for KUDU-650. +TEST_P(TestRpc, TestHighFDs) { + // This test can only run if ulimit is set high. + const int kNumFakeFiles = 3500; + const int kMinUlimit = kNumFakeFiles + 100; + if (env_->GetResourceLimit(Env::ResourceLimitType::OPEN_FILES_PER_PROCESS) < kMinUlimit) { + LOG(INFO) << "Test skipped: must increase ulimit -n to at least " << kMinUlimit; + return; + } + + // Open a bunch of fds just to increase our fd count. + vector<RandomAccessFile*> fake_files; + ElementDeleter d(&fake_files); + for (int i = 0; i < kNumFakeFiles; i++) { + unique_ptr<RandomAccessFile> f; + CHECK_OK(Env::Default()->NewRandomAccessFile("/dev/zero", &f)); + fake_files.push_back(f.release()); + } + + // Set up server and client, and verify we can make a successful call. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName)); +} + +// Test that connections are kept alive between calls. +TEST_P(TestRpc, TestConnectionKeepalive) { + // Only run one reactor per messenger, so we can grab the metrics from that + // one without having to check all. + n_server_reactor_threads_ = 1; + keepalive_time_ms_ = 500; + + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + LOG(INFO) << "Connecting to " << server_addr.ToString(); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName)); + + SleepFor(MonoDelta::FromMilliseconds(5)); + + ReactorMetrics metrics; + ASSERT_OK(server_messenger_->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(1, metrics.num_server_connections_) << "Server should have 1 server connection"; + ASSERT_EQ(0, metrics.num_client_connections_) << "Server should have 0 client connections"; + + ASSERT_OK(client_messenger->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(0, metrics.num_server_connections_) << "Client should have 0 server connections"; + ASSERT_EQ(1, metrics.num_client_connections_) << "Client should have 1 client connections"; + + SleepFor(MonoDelta::FromMilliseconds(2 * keepalive_time_ms_)); + + // After sleeping, the keepalive timer should have closed both sides of + // the connection. + ASSERT_OK(server_messenger_->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(0, metrics.num_server_connections_) << "Server should have 0 server connections"; + ASSERT_EQ(0, metrics.num_client_connections_) << "Server should have 0 client connections"; + + ASSERT_OK(client_messenger->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(0, metrics.num_server_connections_) << "Client should have 0 server connections"; + ASSERT_EQ(0, metrics.num_client_connections_) << "Client should have 0 client connections"; +} + +// Test that idle connection is kept alive when 'keepalive_time_ms_' is set to -1. +TEST_P(TestRpc, TestConnectionAlwaysKeepalive) { + // Only run one reactor per messenger, so we can grab the metrics from that + // one without having to check all. + n_server_reactor_threads_ = 1; + keepalive_time_ms_ = -1; + + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + LOG(INFO) << "Connecting to " << server_addr.ToString(); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName)); + + ReactorMetrics metrics; + ASSERT_OK(server_messenger_->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(1, metrics.num_server_connections_) << "Server should have 1 server connection"; + ASSERT_EQ(0, metrics.num_client_connections_) << "Server should have 0 client connections"; + + ASSERT_OK(client_messenger->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(0, metrics.num_server_connections_) << "Client should have 0 server connections"; + ASSERT_EQ(1, metrics.num_client_connections_) << "Client should have 1 client connections"; + + SleepFor(MonoDelta::FromSeconds(3)); + + // After sleeping, the connection should still be alive. + ASSERT_OK(server_messenger_->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(1, metrics.num_server_connections_) << "Server should have 1 server connections"; + ASSERT_EQ(0, metrics.num_client_connections_) << "Server should have 0 client connections"; + + ASSERT_OK(client_messenger->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(0, metrics.num_server_connections_) << "Client should have 0 server connections"; + ASSERT_EQ(1, metrics.num_client_connections_) << "Client should have 1 client connections"; +} + +// Test that the metrics on a per connection level work accurately. +TEST_P(TestRpc, TestClientConnectionMetrics) { + // Only run one reactor per messenger, so we can grab the metrics from that + // one without having to check all. + n_server_reactor_threads_ = 1; + keepalive_time_ms_ = -1; + + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + LOG(INFO) << "Connecting to " << server_addr.ToString(); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + // Cause the reactor thread to be blocked for 2 seconds. + server_messenger_->ScheduleOnReactor(boost::bind(sleep, 2), MonoDelta::FromSeconds(0)); + + RpcController controller; + DumpRunningRpcsRequestPB dump_req; + DumpRunningRpcsResponsePB dump_resp; + dump_req.set_include_traces(false); + + // We'll send several calls asynchronously to force RPC queueing on the sender side. + int n_calls = 1000; + AddRequestPB add_req; + add_req.set_x(rand()); + add_req.set_y(rand()); + AddResponsePB add_resp; + + vector<unique_ptr<RpcController>> controllers; + CountDownLatch latch(n_calls); + for (int i = 0; i < n_calls; i++) { + controllers.emplace_back(new RpcController()); + p.AsyncRequest(GenericCalculatorService::kAddMethodName, add_req, &add_resp, + controllers.back().get(), boost::bind(&CountDownLatch::CountDown, boost::ref(latch))); + } + + // Since we blocked the only reactor thread for sometime, we should see RPCs queued on the + // OutboundTransfer queue, unless the main thread is very slow. + ASSERT_OK(client_messenger->DumpRunningRpcs(dump_req, &dump_resp)); + ASSERT_EQ(1, dump_resp.outbound_connections_size()); + ASSERT_GT(dump_resp.outbound_connections(0).outbound_queue_size(), 0); + + // Wait for the calls to be marked finished. + latch.Wait(); + + // Verify that all the RPCs have finished. + for (const auto& controller : controllers) { + ASSERT_TRUE(controller->finished()); + } +} + +// Test that outbound connections to the same server are reopen upon every RPC +// call when the 'rpc_reopen_outbound_connections' flag is set. +TEST_P(TestRpc, TestReopenOutboundConnections) { + // Set the flag to enable special mode: close and reopen already established + // outbound connections. + FLAGS_rpc_reopen_outbound_connections = true; + + // Only run one reactor per messenger, so we can grab the metrics from that + // one without having to check all. + n_server_reactor_threads_ = 1; + + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + LOG(INFO) << "Connecting to " << server_addr.ToString(); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + // Verify the initial counters. + ReactorMetrics metrics; + ASSERT_OK(server_messenger_->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(0, metrics.total_client_connections_); + ASSERT_EQ(0, metrics.total_server_connections_); + ASSERT_OK(client_messenger->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(0, metrics.total_client_connections_); + ASSERT_EQ(0, metrics.total_server_connections_); + + // Run several iterations, just in case. + for (int i = 0; i < 32; ++i) { + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName)); + ASSERT_OK(server_messenger_->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(0, metrics.total_client_connections_); + ASSERT_EQ(i + 1, metrics.total_server_connections_); + ASSERT_OK(client_messenger->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(i + 1, metrics.total_client_connections_); + ASSERT_EQ(0, metrics.total_server_connections_); + } +} + +// Test that an outbound connection is closed and a new one is open if going +// from ANY_CREDENTIALS to PRIMARY_CREDENTIALS policy for RPC calls to the same +// destination. +// Test that changing from PRIMARY_CREDENTIALS policy to ANY_CREDENTIALS policy +// re-uses the connection established with PRIMARY_CREDENTIALS policy. +TEST_P(TestRpc, TestCredentialsPolicy) { + // Only run one reactor per messenger, so we can grab the metrics from that + // one without having to check all. + n_server_reactor_threads_ = 1; + + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + LOG(INFO) << "Connecting to " << server_addr.ToString(); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + // Verify the initial counters. + ReactorMetrics metrics; + ASSERT_OK(server_messenger_->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(0, metrics.total_client_connections_); + ASSERT_EQ(0, metrics.total_server_connections_); + ASSERT_OK(client_messenger->reactors_[0]->GetMetrics(&metrics)); + ASSERT_EQ(0, metrics.total_client_connections_); + ASSERT_EQ(0, metrics.total_server_connections_); + + // Make an RPC call with ANY_CREDENTIALS policy. + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName)); + ASSERT_OK(server_messenger_->reactors_[0]->GetMetrics(&metrics)); + EXPECT_EQ(0, metrics.total_client_connections_); + EXPECT_EQ(1, metrics.total_server_connections_); + EXPECT_EQ(1, metrics.num_server_connections_); + EXPECT_OK(client_messenger->reactors_[0]->GetMetrics(&metrics)); + EXPECT_EQ(1, metrics.total_client_connections_); + EXPECT_EQ(0, metrics.total_server_connections_); + EXPECT_EQ(1, metrics.num_client_connections_); + + // This is to allow all the data to be sent so the connection becomes idle. + SleepFor(MonoDelta::FromMilliseconds(5)); + + // Make an RPC call with PRIMARY_CREDENTIALS policy. Currently open connection + // with ANY_CREDENTIALS policy should be closed and a new one established + // with PRIMARY_CREDENTIALS policy. + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName, + CredentialsPolicy::PRIMARY_CREDENTIALS)); + ASSERT_OK(server_messenger_->reactors_[0]->GetMetrics(&metrics)); + EXPECT_EQ(0, metrics.total_client_connections_); + EXPECT_EQ(2, metrics.total_server_connections_); + EXPECT_EQ(1, metrics.num_server_connections_); + EXPECT_OK(client_messenger->reactors_[0]->GetMetrics(&metrics)); + EXPECT_EQ(2, metrics.total_client_connections_); + EXPECT_EQ(0, metrics.total_server_connections_); + EXPECT_EQ(1, metrics.num_client_connections_); + + // Make another RPC call with ANY_CREDENTIALS policy. The already established + // connection with PRIMARY_CREDENTIALS policy should be re-used because + // the ANY_CREDENTIALS policy satisfies the PRIMARY_CREDENTIALS policy which + // the currently open connection has been established with. + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName)); + ASSERT_OK(server_messenger_->reactors_[0]->GetMetrics(&metrics)); + EXPECT_EQ(0, metrics.total_client_connections_); + EXPECT_EQ(2, metrics.total_server_connections_); + EXPECT_EQ(1, metrics.num_server_connections_); + EXPECT_OK(client_messenger->reactors_[0]->GetMetrics(&metrics)); + EXPECT_EQ(2, metrics.total_client_connections_); + EXPECT_EQ(0, metrics.total_server_connections_); + EXPECT_EQ(1, metrics.num_client_connections_); +} + +// Test that a call which takes longer than the keepalive time +// succeeds -- i.e that we don't consider a connection to be "idle" on the +// server if there is a call outstanding on it. +TEST_P(TestRpc, TestCallLongerThanKeepalive) { + // Set a short keepalive. + keepalive_time_ms_ = 1000; + + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + // Make a call which sleeps longer than the keepalive. + RpcController controller; + SleepRequestPB req; + req.set_sleep_micros(3 * 1000 * 1000); // 3 seconds. + req.set_deferred(true); + SleepResponsePB resp; + ASSERT_OK(p.SyncRequest(GenericCalculatorService::kSleepMethodName, + req, &resp, &controller)); +} + +// Test that the RpcSidecar transfers the expected messages. +TEST_P(TestRpc, TestRpcSidecar) { + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, GetParam())); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + // Test a zero-length sidecar + DoTestSidecar(p, 0, 0); + + // Test some small sidecars + DoTestSidecar(p, 123, 456); + + // Test some larger sidecars to verify that we properly handle the case where + // we can't write the whole response to the socket in a single call. + DoTestSidecar(p, 3000 * 1024, 2000 * 1024); + + DoTestOutgoingSidecarExpectOK(p, 0, 0); + DoTestOutgoingSidecarExpectOK(p, 123, 456); + DoTestOutgoingSidecarExpectOK(p, 3000 * 1024, 2000 * 1024); +} + +TEST_P(TestRpc, TestRpcSidecarLimits) { + { + // Test that the limits on the number of sidecars is respected. + RpcController controller; + string s = "foo"; + int idx; + for (int i = 0; i < TransferLimits::kMaxSidecars; ++i) { + ASSERT_OK(controller.AddOutboundSidecar(RpcSidecar::FromSlice(Slice(s)), &idx)); + } + + ASSERT_TRUE(controller.AddOutboundSidecar( + RpcSidecar::FromSlice(Slice(s)), &idx).IsRuntimeError()); + } + + // Construct a string to use as a maximal payload in following tests + string max_string(TransferLimits::kMaxTotalSidecarBytes, 'a'); + + { + // Test that limit on the total size of sidecars is respected. The maximal payload + // reaches the limit exactly. + RpcController controller; + int idx; + ASSERT_OK(controller.AddOutboundSidecar(RpcSidecar::FromSlice(Slice(max_string)), &idx)); + + // Trying to add another byte will fail. + int dummy = 0; + string s2(1, 'b'); + Status max_sidecar_status = + controller.AddOutboundSidecar(RpcSidecar::FromSlice(Slice(s2)), &dummy); + ASSERT_FALSE(max_sidecar_status.ok()); + ASSERT_STR_MATCHES(max_sidecar_status.ToString(), "Total size of sidecars"); + } + + // Test two cases: + // 1) The RPC has maximal size and exceeds rpc_max_message_size. This tests the + // functionality of rpc_max_message_size. The server will close the connection + // immediately. + // 2) The RPC has maximal size, but rpc_max_message_size has been set to a higher + // value. This tests the client's ability to send the maximal message. + // The server will reject the message after it has been transferred. + // This test is disabled for TSAN due to high memory requirements. + std::vector<int64_t> rpc_max_message_values; + rpc_max_message_values.push_back(FLAGS_rpc_max_message_size); +#ifndef THREAD_SANITIZER + rpc_max_message_values.push_back(std::numeric_limits<int64_t>::max()); +#endif + for (int64_t rpc_max_message_size_val : rpc_max_message_values) { + // Set rpc_max_message_size + FLAGS_rpc_max_message_size = rpc_max_message_size_val; + + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, GetParam())); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + RpcController controller; + // KUDU-2305: Test with a maximal payload to verify that the implementation + // can handle the limits. + int idx; + ASSERT_OK(controller.AddOutboundSidecar(RpcSidecar::FromSlice(Slice(max_string)), &idx)); + + PushTwoStringsRequestPB request; + request.set_sidecar1_idx(idx); + request.set_sidecar2_idx(idx); + PushTwoStringsResponsePB resp; + Status status = p.SyncRequest(GenericCalculatorService::kPushTwoStringsMethodName, + request, &resp, &controller); + ASSERT_TRUE(status.IsNetworkError()) << "Unexpected error: " << status.ToString(); + // Remote responds to extra-large payloads by closing the connection. + ASSERT_STR_MATCHES(status.ToString(), + // Linux + "Connection reset by peer" + // While reading from socket. + "|recv got EOF from" + // Linux, SSL enabled + "|failed to read from TLS socket" + // macOS, while writing to socket. + "|Protocol wrong type for socket" + // macOS, sendmsg(): the sum of the iov_len values overflows an ssize_t + "|sendmsg error: Invalid argument"); + } +} + +// Test that timeouts are properly handled. +TEST_P(TestRpc, TestCallTimeout) { + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + // Test a very short timeout - we expect this will time out while the + // call is still trying to connect, or in the send queue. This was triggering ASAN failures + // before. + ASSERT_NO_FATAL_FAILURE(DoTestExpectTimeout(p, MonoDelta::FromNanoseconds(1))); + + // Test a longer timeout - expect this will time out after we send the request, + // but shorter than our threshold for two-stage timeout handling. + ASSERT_NO_FATAL_FAILURE(DoTestExpectTimeout(p, MonoDelta::FromMilliseconds(200))); + + // Test a longer timeout - expect this will trigger the "two-stage timeout" + // code path. + ASSERT_NO_FATAL_FAILURE(DoTestExpectTimeout(p, MonoDelta::FromMilliseconds(1500))); +} + +// Inject 500ms delay in negotiation, and send a call with a short timeout, followed by +// one with a long timeout. The call with the long timeout should succeed even though +// the previous one failed. +// +// This is a regression test against prior behavior where the connection negotiation +// was assigned the timeout of the first call on that connection. So, if the first +// call had a short timeout, the later call would also inherit the timed-out negotiation. +TEST_P(TestRpc, TestCallTimeoutDoesntAffectNegotiation) { + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + FLAGS_rpc_negotiation_inject_delay_ms = 500; + ASSERT_NO_FATAL_FAILURE(DoTestExpectTimeout(p, MonoDelta::FromMilliseconds(50))); + ASSERT_OK(DoTestSyncCall(p, GenericCalculatorService::kAddMethodName)); + + // Only the second call should have been received by the server, because we + // don't bother sending an already-timed-out call. + auto metric_map = server_messenger_->metric_entity()->UnsafeMetricsMapForTests(); + auto* metric = FindOrDie(metric_map, &METRIC_rpc_incoming_queue_time).get(); + ASSERT_EQ(1, down_cast<Histogram*>(metric)->TotalCount()); +} + +static void AcceptAndReadForever(Socket* listen_sock) { + // Accept the TCP connection. + Socket server_sock; + Sockaddr remote; + CHECK_OK(listen_sock->Accept(&server_sock, &remote, 0)); + + MonoTime deadline = MonoTime::Now() + MonoDelta::FromSeconds(10); + + size_t nread; + uint8_t buf[1024]; + while (server_sock.BlockingRecv(buf, sizeof(buf), &nread, deadline).ok()) { + } +} + +// Starts a fake listening socket which never actually negotiates. +// Ensures that the client gets a reasonable status code in this case. +TEST_F(TestRpc, TestNegotiationTimeout) { + // Set up a simple socket server which accepts a connection. + Sockaddr server_addr; + Socket listen_sock; + ASSERT_OK(StartFakeServer(&listen_sock, &server_addr)); + + // Create another thread to accept the connection on the fake server. + scoped_refptr<Thread> acceptor_thread; + ASSERT_OK(Thread::Create("test", "acceptor", + AcceptAndReadForever, &listen_sock, + &acceptor_thread)); + + // Set up client. + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + bool is_negotiation_error = false; + ASSERT_NO_FATAL_FAILURE(DoTestExpectTimeout( + p, MonoDelta::FromMilliseconds(100), &is_negotiation_error)); + EXPECT_TRUE(is_negotiation_error); + + acceptor_thread->Join(); +} + +// Test that client calls get failed properly when the server they're connected to +// shuts down. +TEST_F(TestRpc, TestServerShutsDown) { + // Set up a simple socket server which accepts a connection. + Sockaddr server_addr; + Socket listen_sock; + ASSERT_OK(StartFakeServer(&listen_sock, &server_addr)); + + // Set up client. + LOG(INFO) << "Connecting to " << server_addr.ToString(); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + // Send a call. + AddRequestPB req; + req.set_x(rand()); + req.set_y(rand()); + AddResponsePB resp; + + vector<unique_ptr<RpcController>> controllers; + + // We'll send several calls async, and ensure that they all + // get the error status when the connection drops. + int n_calls = 5; + + CountDownLatch latch(n_calls); + for (int i = 0; i < n_calls; i++) { + controllers.emplace_back(new RpcController()); + p.AsyncRequest(GenericCalculatorService::kAddMethodName, req, &resp, controllers.back().get(), + boost::bind(&CountDownLatch::CountDown, boost::ref(latch))); + } + + // Accept the TCP connection. + Socket server_sock; + Sockaddr remote; + ASSERT_OK(listen_sock.Accept(&server_sock, &remote, 0)); + + // The call is still in progress at this point. + for (const auto& controller : controllers) { + ASSERT_FALSE(controller->finished()); + } + + // Shut down the socket. + ASSERT_OK(listen_sock.Close()); + ASSERT_OK(server_sock.Close()); + + // Wait for the call to be marked finished. + latch.Wait(); + + // Should get the appropriate error on the client for all calls; + for (const auto& controller : controllers) { + ASSERT_TRUE(controller->finished()); + Status s = controller->status(); + ASSERT_TRUE(s.IsNetworkError()) << + "Unexpected status: " << s.ToString(); + + // Any of these errors could happen, depending on whether we were + // in the middle of sending a call while the connection died, or + // if we were already waiting for responses. + // + // ECONNREFUSED is possible because the sending of the calls is async. + // For example, the following interleaving: + // - Enqueue 3 calls + // - Reactor wakes up, creates connection, starts writing calls + // - Enqueue 2 more calls + // - Shut down socket + // - Reactor wakes up, tries to write more of the first 3 calls, gets error + // - Reactor shuts down connection + // - Reactor sees the 2 remaining calls, makes a new connection + // - Because the socket is shut down, gets ECONNREFUSED. + // + // EINVAL is possible if the controller socket had already disconnected by + // the time it trys to set the SO_SNDTIMEO socket option as part of the + // normal blocking SASL handshake. + ASSERT_TRUE(s.posix_code() == EPIPE || + s.posix_code() == ECONNRESET || + s.posix_code() == ESHUTDOWN || + s.posix_code() == ECONNREFUSED || + s.posix_code() == EINVAL) + << "Unexpected status: " << s.ToString(); + } +} + +// Test handler latency metric. +TEST_P(TestRpc, TestRpcHandlerLatencyMetric) { + + const uint64_t sleep_micros = 20 * 1000; + + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServerWithGeneratedCode(&server_addr, enable_ssl)); + + // Set up client. + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + CalculatorService::static_service_name()); + + RpcController controller; + SleepRequestPB req; + req.set_sleep_micros(sleep_micros); + req.set_deferred(true); + SleepResponsePB resp; + ASSERT_OK(p.SyncRequest("Sleep", req, &resp, &controller)); + + const unordered_map<const MetricPrototype*, scoped_refptr<Metric> > metric_map = + server_messenger_->metric_entity()->UnsafeMetricsMapForTests(); + + scoped_refptr<Histogram> latency_histogram = down_cast<Histogram *>( + FindOrDie(metric_map, + &METRIC_handler_latency_kudu_rpc_test_CalculatorService_Sleep).get()); + + LOG(INFO) << "Sleep() min lat: " << latency_histogram->MinValueForTests(); + LOG(INFO) << "Sleep() mean lat: " << latency_histogram->MeanValueForTests(); + LOG(INFO) << "Sleep() max lat: " << latency_histogram->MaxValueForTests(); + LOG(INFO) << "Sleep() #calls: " << latency_histogram->TotalCount(); + + ASSERT_EQ(1, latency_histogram->TotalCount()); + ASSERT_GE(latency_histogram->MaxValueForTests(), sleep_micros); + ASSERT_TRUE(latency_histogram->MinValueForTests() == latency_histogram->MaxValueForTests()); + + // TODO: Implement an incoming queue latency test. + // For now we just assert that the metric exists. + ASSERT_TRUE(FindOrDie(metric_map, &METRIC_rpc_incoming_queue_time)); +} + +static void DestroyMessengerCallback(shared_ptr<Messenger>* messenger, + CountDownLatch* latch) { + messenger->reset(); + latch->CountDown(); +} + +TEST_P(TestRpc, TestRpcCallbackDestroysMessenger) { + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, GetParam())); + Sockaddr bad_addr; + CountDownLatch latch(1); + + AddRequestPB req; + req.set_x(rand()); + req.set_y(rand()); + AddResponsePB resp; + RpcController controller; + controller.set_timeout(MonoDelta::FromMilliseconds(1)); + { + Proxy p(client_messenger, bad_addr, "xxx-host", "xxx-service"); + p.AsyncRequest("my-fake-method", req, &resp, &controller, + boost::bind(&DestroyMessengerCallback, &client_messenger, &latch)); + } + latch.Wait(); +} + +// Test that setting the client timeout / deadline gets propagated to RPC +// services. +TEST_P(TestRpc, TestRpcContextClientDeadline) { + const uint64_t sleep_micros = 20 * 1000; + + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServerWithGeneratedCode(&server_addr, enable_ssl)); + + // Set up client. + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + CalculatorService::static_service_name()); + + SleepRequestPB req; + req.set_sleep_micros(sleep_micros); + req.set_client_timeout_defined(true); + SleepResponsePB resp; + RpcController controller; + Status s = p.SyncRequest("Sleep", req, &resp, &controller); + ASSERT_TRUE(s.IsRemoteError()); + ASSERT_STR_CONTAINS(s.ToString(), "Missing required timeout"); + + controller.Reset(); + controller.set_timeout(MonoDelta::FromMilliseconds(1000)); + ASSERT_OK(p.SyncRequest("Sleep", req, &resp, &controller)); +} + +// Test that setting an call-level application feature flag to an unknown value +// will make the server reject the call. +TEST_P(TestRpc, TestApplicationFeatureFlag) { + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServerWithGeneratedCode(&server_addr, enable_ssl)); + + // Set up client. + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + CalculatorService::static_service_name()); + + { // Supported flag + AddRequestPB req; + req.set_x(1); + req.set_y(2); + AddResponsePB resp; + RpcController controller; + controller.RequireServerFeature(FeatureFlags::FOO); + Status s = p.SyncRequest("Add", req, &resp, &controller); + SCOPED_TRACE(strings::Substitute("supported response: $0", s.ToString())); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(resp.result(), 3); + } + + { // Unsupported flag + AddRequestPB req; + req.set_x(1); + req.set_y(2); + AddResponsePB resp; + RpcController controller; + controller.RequireServerFeature(FeatureFlags::FOO); + controller.RequireServerFeature(99); + Status s = p.SyncRequest("Add", req, &resp, &controller); + SCOPED_TRACE(strings::Substitute("unsupported response: $0", s.ToString())); + ASSERT_TRUE(s.IsRemoteError()); + } +} + +TEST_P(TestRpc, TestApplicationFeatureFlagUnsupportedServer) { + auto savedFlags = kSupportedServerRpcFeatureFlags; + auto cleanup = MakeScopedCleanup([&] () { kSupportedServerRpcFeatureFlags = savedFlags; }); + kSupportedServerRpcFeatureFlags = {}; + + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServerWithGeneratedCode(&server_addr, enable_ssl)); + + // Set up client. + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + CalculatorService::static_service_name()); + + { // Required flag + AddRequestPB req; + req.set_x(1); + req.set_y(2); + AddResponsePB resp; + RpcController controller; + controller.RequireServerFeature(FeatureFlags::FOO); + Status s = p.SyncRequest("Add", req, &resp, &controller); + SCOPED_TRACE(strings::Substitute("supported response: $0", s.ToString())); + ASSERT_TRUE(s.IsNotSupported()); + } + + { // No required flag + AddRequestPB req; + req.set_x(1); + req.set_y(2); + AddResponsePB resp; + RpcController controller; + Status s = p.SyncRequest("Add", req, &resp, &controller); + SCOPED_TRACE(strings::Substitute("supported response: $0", s.ToString())); + ASSERT_TRUE(s.ok()); + } +} + +TEST_P(TestRpc, TestCancellation) { + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + LOG(INFO) << "Connecting to " << server_addr.ToString(); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + for (int i = OutboundCall::READY; i <= OutboundCall::FINISHED_SUCCESS; ++i) { + FLAGS_rpc_inject_cancellation_state = i; + switch (i) { + case OutboundCall::READY: + case OutboundCall::ON_OUTBOUND_QUEUE: + case OutboundCall::SENDING: + case OutboundCall::SENT: + ASSERT_TRUE(DoTestOutgoingSidecar(p, 0, 0).IsAborted()); + ASSERT_TRUE(DoTestOutgoingSidecar(p, 123, 456).IsAborted()); + ASSERT_TRUE(DoTestOutgoingSidecar(p, 3000 * 1024, 2000 * 1024).IsAborted()); + break; + case OutboundCall::NEGOTIATION_TIMED_OUT: + case OutboundCall::TIMED_OUT: + DoTestExpectTimeout(p, MonoDelta::FromMilliseconds(1000)); + break; + case OutboundCall::CANCELLED: + break; + case OutboundCall::FINISHED_NEGOTIATION_ERROR: + case OutboundCall::FINISHED_ERROR: { + AddRequestPB req; + req.set_x(1); + req.set_y(2); + AddResponsePB resp; + RpcController controller; + controller.RequireServerFeature(FeatureFlags::FOO); + controller.RequireServerFeature(99); + Status s = p.SyncRequest("Add", req, &resp, &controller); + ASSERT_TRUE(s.IsRemoteError()); + break; + } + case OutboundCall::FINISHED_SUCCESS: + DoTestOutgoingSidecarExpectOK(p, 0, 0); + DoTestOutgoingSidecarExpectOK(p, 123, 456); + DoTestOutgoingSidecarExpectOK(p, 3000 * 1024, 2000 * 1024); + break; + } + } + client_messenger->Shutdown(); +} + +#define TEST_PAYLOAD_SIZE (1 << 23) +#define TEST_SLEEP_TIME_MS (500) + +static void SleepCallback(uint8_t* payload, CountDownLatch* latch) { + // Overwrites the payload which the sidecar is pointing to. The server + // checks if the payload matches the expected pattern to detect cases + // in which the payload is overwritten while it's being sent. + memset(payload, 0, TEST_PAYLOAD_SIZE); + latch->CountDown(); +} + +// Test to verify that sidecars aren't corrupted when cancelling an async RPC. +TEST_P(TestRpc, TestCancellationAsync) { + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + LOG(INFO) << "Connecting to " << server_addr.ToString(); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + RpcController controller; + + // The payload to be used during the RPC. + gscoped_array<uint8_t> payload(new uint8_t[TEST_PAYLOAD_SIZE]); + + // Used to generate sleep time between invoking RPC and requesting cancellation. + Random rand(SeedRandom()); + + for (int i = 0; i < 10; ++i) { + SleepWithSidecarRequestPB req; + SleepWithSidecarResponsePB resp; + + // Initialize the payload with non-zero pattern. + memset(payload.get(), 0xff, TEST_PAYLOAD_SIZE); + req.set_sleep_micros(TEST_SLEEP_TIME_MS); + req.set_pattern(0xffffffff); + req.set_num_repetitions(TEST_PAYLOAD_SIZE / sizeof(uint32_t)); + + int idx; + Slice s(payload.get(), TEST_PAYLOAD_SIZE); + CHECK_OK(controller.AddOutboundSidecar(RpcSidecar::FromSlice(s), &idx)); + req.set_sidecar_idx(idx); + + CountDownLatch latch(1); + p.AsyncRequest(GenericCalculatorService::kSleepWithSidecarMethodName, + req, &resp, &controller, + boost::bind(SleepCallback, payload.get(), &latch)); + // Sleep for a while before cancelling the RPC. + if (i > 0) SleepFor(MonoDelta::FromMicroseconds(rand.Uniform64(i * 30))); + controller.Cancel(); + latch.Wait(); + ASSERT_TRUE(controller.status().IsAborted() || controller.status().ok()); + controller.Reset(); + } + client_messenger->Shutdown(); +} + +// This function loops for 40 iterations and for each iteration, sends an async RPC +// and sleeps for some time between 1 to 100 microseconds before cancelling the RPC. +// This serves as a helper function for TestCancellationMultiThreads() to exercise +// cancellation when there are concurrent RPCs. +static void SendAndCancelRpcs(Proxy* p, const Slice& slice) { + RpcController controller; + + // Used to generate sleep time between invoking RPC and requesting cancellation. + Random rand(SeedRandom()); + + auto end_time = MonoTime::Now() + MonoDelta::FromSeconds( + AllowSlowTests() ? 15 : 3); + + int i = 0; + while (MonoTime::Now() < end_time) { + controller.Reset(); + PushTwoStringsRequestPB request; + PushTwoStringsResponsePB resp; + int idx; + CHECK_OK(controller.AddOutboundSidecar(RpcSidecar::FromSlice(slice), &idx)); + request.set_sidecar1_idx(idx); + CHECK_OK(controller.AddOutboundSidecar(RpcSidecar::FromSlice(slice), &idx)); + request.set_sidecar2_idx(idx); + + CountDownLatch latch(1); + p->AsyncRequest(GenericCalculatorService::kPushTwoStringsMethodName, + request, &resp, &controller, + boost::bind(&CountDownLatch::CountDown, boost::ref(latch))); + + if ((i++ % 8) != 0) { + // Sleep for a while before cancelling the RPC. + SleepFor(MonoDelta::FromMicroseconds(rand.Uniform64(100))); + controller.Cancel(); + } + latch.Wait(); + CHECK(controller.status().IsAborted() || controller.status().IsServiceUnavailable() || + controller.status().ok()) << controller.status().ToString(); + } +} + +// Test to exercise cancellation when there are multiple concurrent RPCs from the +// same client to the same server. +TEST_P(TestRpc, TestCancellationMultiThreads) { + // Set up server. + Sockaddr server_addr; + bool enable_ssl = GetParam(); + ASSERT_OK(StartTestServer(&server_addr, enable_ssl)); + + // Set up client. + LOG(INFO) << "Connecting to " << server_addr.ToString(); + shared_ptr<Messenger> client_messenger; + ASSERT_OK(CreateMessenger("Client", &client_messenger, 1, enable_ssl)); + Proxy p(client_messenger, server_addr, server_addr.host(), + GenericCalculatorService::static_service_name()); + + // Buffer used for sidecars by SendAndCancelRpcs(). + string buf(16 * 1024 * 1024, 'a'); + Slice slice(buf); + + // Start a bunch of threads which invoke async RPC and cancellation. + std::vector<scoped_refptr<Thread>> threads; + for (int i = 0; i < 30; ++i) { + scoped_refptr<Thread> rpc_thread; + ASSERT_OK(Thread::Create("test", "rpc", SendAndCancelRpcs, &p, slice, &rpc_thread)); + threads.push_back(rpc_thread); + } + // Wait for all threads to complete. + for (scoped_refptr<Thread>& rpc_thread : threads) { + rpc_thread->Join(); + } + client_messenger->Shutdown(); +} + +} // namespace rpc +} // namespace kudu http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/rpc.cc ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/rpc.cc b/be/src/kudu/rpc/rpc.cc new file mode 100644 index 0000000..84ea892 --- /dev/null +++ b/be/src/kudu/rpc/rpc.cc @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "kudu/rpc/rpc.h" + +#include <cstdlib> +#include <string> + +#include <boost/bind.hpp> // IWYU pragma: keep +#include <boost/function.hpp> +#include <glog/logging.h> + +#include "kudu/gutil/strings/substitute.h" +#include "kudu/rpc/messenger.h" +#include "kudu/rpc/rpc_header.pb.h" + +using std::shared_ptr; +using std::string; +using strings::Substitute; +using strings::SubstituteAndAppend; + +namespace kudu { + +namespace rpc { + +bool RpcRetrier::HandleResponse(Rpc* rpc, Status* out_status) { + DCHECK(rpc); + DCHECK(out_status); + + // Always retry TOO_BUSY and UNAVAILABLE errors. + const Status controller_status = controller_.status(); + if (controller_status.IsRemoteError()) { + const ErrorStatusPB* err = controller_.error_response(); + if (err && + err->has_code() && + (err->code() == ErrorStatusPB::ERROR_SERVER_TOO_BUSY || + err->code() == ErrorStatusPB::ERROR_UNAVAILABLE)) { + // The UNAVAILABLE code is a broader counterpart of the + // SERVER_TOO_BUSY. In both cases it's necessary to retry a bit later. + DelayedRetry(rpc, controller_status); + return true; + } + } + + *out_status = controller_status; + return false; +} + +void RpcRetrier::DelayedRetry(Rpc* rpc, const Status& why_status) { + if (!why_status.ok() && (last_error_.ok() || last_error_.IsTimedOut())) { + last_error_ = why_status; + } + // Add some jitter to the retry delay. + // + // If the delay causes us to miss our deadline, RetryCb will fail the + // RPC on our behalf. + int num_ms = ++attempt_num_ + ((rand() % 5)); + messenger_->ScheduleOnReactor(boost::bind(&RpcRetrier::DelayedRetryCb, + this, + rpc, _1), + MonoDelta::FromMilliseconds(num_ms)); +} + +void RpcRetrier::DelayedRetryCb(Rpc* rpc, const Status& status) { + Status new_status = status; + if (new_status.ok()) { + // Has this RPC timed out? + if (deadline_.Initialized()) { + if (MonoTime::Now() > deadline_) { + string err_str = Substitute("$0 passed its deadline", rpc->ToString()); + if (!last_error_.ok()) { + SubstituteAndAppend(&err_str, ": $0", last_error_.ToString()); + } + new_status = Status::TimedOut(err_str); + } + } + } + if (new_status.ok()) { + controller_.Reset(); + rpc->SendRpc(); + } else { + rpc->SendRpcCb(new_status); + } +} + +} // namespace rpc +} // namespace kudu