http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/rpc_stub-test.cc ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/rpc_stub-test.cc b/be/src/kudu/rpc/rpc_stub-test.cc new file mode 100644 index 0000000..e626276 --- /dev/null +++ b/be/src/kudu/rpc/rpc_stub-test.cc @@ -0,0 +1,726 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <algorithm> +#include <atomic> +#include <csignal> +#include <cstdio> +#include <cstdlib> +#include <cstring> +#include <limits> +#include <memory> +#include <ostream> +#include <string> +#include <thread> +#include <vector> + +#include <boost/bind.hpp> +#include <boost/core/ref.hpp> +#include <boost/function.hpp> +#include <gflags/gflags.h> +#include <gflags/gflags_declare.h> +#include <glog/logging.h> +#include <glog/stl_logging.h> +#include <gtest/gtest.h> + +#include "kudu/gutil/atomicops.h" +#include "kudu/gutil/gscoped_ptr.h" +#include "kudu/gutil/ref_counted.h" +#include "kudu/gutil/stl_util.h" +#include "kudu/rpc/messenger.h" +#include "kudu/rpc/proxy.h" +#include "kudu/rpc/rpc-test-base.h" +#include "kudu/rpc/rpc_controller.h" +#include "kudu/rpc/rpc_header.pb.h" +#include "kudu/rpc/rpc_introspection.pb.h" +#include "kudu/rpc/rpcz_store.h" +#include "kudu/rpc/rtest.pb.h" +#include "kudu/rpc/rtest.proxy.h" +#include "kudu/rpc/service_pool.h" +#include "kudu/rpc/user_credentials.h" +#include "kudu/util/countdown_latch.h" +#include "kudu/util/env.h" +#include "kudu/util/metrics.h" +#include "kudu/util/monotime.h" +#include "kudu/util/net/sockaddr.h" +#include "kudu/util/pb_util.h" +#include "kudu/util/random.h" +#include "kudu/util/status.h" +#include "kudu/util/subprocess.h" +#include "kudu/util/test_macros.h" +#include "kudu/util/test_util.h" +#include "kudu/util/thread_restrictions.h" +#include "kudu/util/user.h" + +DEFINE_bool(is_panic_test_child, false, "Used by TestRpcPanic"); +DECLARE_bool(socket_inject_short_recvs); + +using kudu::pb_util::SecureDebugString; +using std::shared_ptr; +using std::string; +using std::unique_ptr; +using std::vector; +using base::subtle::NoBarrier_Load; + +namespace kudu { +namespace rpc { + +class RpcStubTest : public RpcTestBase { + public: + void SetUp() override { + RpcTestBase::SetUp(); + // Use a shorter queue length since some tests below need to start enough + // threads to saturate the queue. 
+ service_queue_length_ = 10; + ASSERT_OK(StartTestServerWithGeneratedCode(&server_addr_)); + ASSERT_OK(CreateMessenger("Client", &client_messenger_)); + } + protected: + void SendSimpleCall() { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + RpcController controller; + AddRequestPB req; + req.set_x(10); + req.set_y(20); + AddResponsePB resp; + ASSERT_OK(p.Add(req, &resp, &controller)); + ASSERT_EQ(30, resp.result()); + } + + Sockaddr server_addr_; + shared_ptr<Messenger> client_messenger_; +}; + +TEST_F(RpcStubTest, TestSimpleCall) { + SendSimpleCall(); +} + +// Regression test for a bug in which we would not properly parse a call +// response when recv() returned a 'short read'. This injects such short +// reads and then makes a number of calls. +TEST_F(RpcStubTest, TestShortRecvs) { + FLAGS_socket_inject_short_recvs = true; + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + for (int i = 0; i < 100; i++) { + NO_FATALS(SendSimpleCall()); + } +} + +// Test calls which are rather large. +// This test sends many of them at once using the async API and then +// waits for them all to return. This is meant to ensure that the +// IO threads can deal with read/write calls that don't succeed +// in sending the entire data in one go. 
+TEST_F(RpcStubTest, TestBigCallData) { + const int kNumSentAtOnce = 20; + const size_t kMessageSize = 5 * 1024 * 1024; + string data; + data.resize(kMessageSize); + + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + EchoRequestPB req; + req.set_data(data); + + vector<unique_ptr<EchoResponsePB>> resps; + vector<unique_ptr<RpcController>> controllers; + + CountDownLatch latch(kNumSentAtOnce); + for (int i = 0; i < kNumSentAtOnce; i++) { + resps.emplace_back(new EchoResponsePB); + controllers.emplace_back(new RpcController); + + p.EchoAsync(req, resps.back().get(), controllers.back().get(), + boost::bind(&CountDownLatch::CountDown, boost::ref(latch))); + } + + latch.Wait(); + + for (const auto& c : controllers) { + ASSERT_OK(c->status()); + } +} + +TEST_F(RpcStubTest, TestRespondDeferred) { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + RpcController controller; + SleepRequestPB req; + req.set_sleep_micros(1000); + req.set_deferred(true); + SleepResponsePB resp; + ASSERT_OK(p.Sleep(req, &resp, &controller)); +} + +// Test that the default user credentials are propagated to the server. +TEST_F(RpcStubTest, TestDefaultCredentialsPropagated) { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + string expected; + ASSERT_OK(GetLoggedInUser(&expected)); + + RpcController controller; + WhoAmIRequestPB req; + WhoAmIResponsePB resp; + ASSERT_OK(p.WhoAmI(req, &resp, &controller)); + ASSERT_EQ(expected, resp.credentials().real_user()); + ASSERT_FALSE(resp.credentials().has_effective_user()); +} + +// Test that the user can specify other credentials. 
+TEST_F(RpcStubTest, TestCustomCredentialsPropagated) { + const char* const kFakeUserName = "some fake user"; + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + UserCredentials creds; + creds.set_real_user(kFakeUserName); + p.set_user_credentials(creds); + + RpcController controller; + WhoAmIRequestPB req; + WhoAmIResponsePB resp; + ASSERT_OK(p.WhoAmI(req, &resp, &controller)); + ASSERT_EQ(kFakeUserName, resp.credentials().real_user()); + ASSERT_FALSE(resp.credentials().has_effective_user()); +} + +TEST_F(RpcStubTest, TestAuthorization) { + // First test calling WhoAmI() as user "alice", who is disallowed. + { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + UserCredentials creds; + creds.set_real_user("alice"); + p.set_user_credentials(creds); + + // Alice is disallowed by all RPCs. + RpcController controller; + WhoAmIRequestPB req; + WhoAmIResponsePB resp; + Status s = p.WhoAmI(req, &resp, &controller); + ASSERT_FALSE(s.ok()); + ASSERT_EQ(s.ToString(), + "Remote error: Not authorized: alice is not allowed to call this method"); + } + + // Try some calls as "bob". + { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + UserCredentials creds; + creds.set_real_user("bob"); + p.set_user_credentials(creds); + + // "bob" is allowed to call WhoAmI(). + { + RpcController controller; + WhoAmIRequestPB req; + WhoAmIResponsePB resp; + ASSERT_OK(p.WhoAmI(req, &resp, &controller)); + } + + // "bob" is not allowed to call "Sleep". + { + RpcController controller; + SleepRequestPB req; + req.set_sleep_micros(10); + SleepResponsePB resp; + Status s = p.Sleep(req, &resp, &controller); + ASSERT_EQ(s.ToString(), + "Remote error: Not authorized: bob is not allowed to call this method"); + } + } +} + +// Test that the user's remote address is accessible to the server. 
+TEST_F(RpcStubTest, TestRemoteAddress) { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + RpcController controller; + WhoAmIRequestPB req; + WhoAmIResponsePB resp; + ASSERT_OK(p.WhoAmI(req, &resp, &controller)); + ASSERT_STR_CONTAINS(resp.address(), "127.0.0.1:"); +} + +//////////////////////////////////////////////////////////// +// Tests for error cases +//////////////////////////////////////////////////////////// + +// Test sending a PB parameter with a missing field, where the client +// thinks it has sent a full PB. (eg due to version mismatch) +TEST_F(RpcStubTest, TestCallWithInvalidParam) { + Proxy p(client_messenger_, server_addr_, server_addr_.host(), + CalculatorService::static_service_name()); + + rpc_test::AddRequestPartialPB req; + req.set_x(rand()); + // AddRequestPartialPB is missing the 'y' field. + AddResponsePB resp; + RpcController controller; + Status s = p.SyncRequest("Add", req, &resp, &controller); + ASSERT_TRUE(s.IsRemoteError()) << "Bad status: " << s.ToString(); + ASSERT_STR_CONTAINS(s.ToString(), + "Invalid argument: invalid parameter for call " + "kudu.rpc_test.CalculatorService.Add: " + "missing fields: y"); +} + +// Wrapper around AtomicIncrement, since AtomicIncrement returns the 'old' +// value, and our callback needs to be a void function. +static void DoIncrement(Atomic32* count) { + base::subtle::Barrier_AtomicIncrement(count, 1); +} + +// Test sending a PB parameter with a missing field on the client side. +// This also ensures that the async callback is only called once +// (regression test for a previously-encountered bug). +TEST_F(RpcStubTest, TestCallWithMissingPBFieldClientSide) { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + RpcController controller; + AddRequestPB req; + req.set_x(10); + // Request is missing the 'y' field. 
+ AddResponsePB resp; + Atomic32 callback_count = 0; + p.AddAsync(req, &resp, &controller, boost::bind(&DoIncrement, &callback_count)); + while (NoBarrier_Load(&callback_count) == 0) { + SleepFor(MonoDelta::FromMicroseconds(10)); + } + SleepFor(MonoDelta::FromMicroseconds(100)); + ASSERT_EQ(1, NoBarrier_Load(&callback_count)); + ASSERT_STR_CONTAINS(controller.status().ToString(), + "Invalid argument: invalid parameter for call " + "kudu.rpc_test.CalculatorService.Add: missing fields: y"); +} + +TEST_F(RpcStubTest, TestResponseWithMissingField) { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + RpcController rpc; + TestInvalidResponseRequestPB req; + TestInvalidResponseResponsePB resp; + req.set_error_type(rpc_test::TestInvalidResponseRequestPB_ErrorType_MISSING_REQUIRED_FIELD); + Status s = p.TestInvalidResponse(req, &resp, &rpc); + ASSERT_STR_CONTAINS(s.ToString(), + "invalid RPC response, missing fields: response"); +} + +// Test case where the server responds with a message which is larger than the maximum +// configured RPC message size. The server should send the response, but the client +// will reject it. +TEST_F(RpcStubTest, TestResponseLargerThanFrameSize) { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + RpcController rpc; + TestInvalidResponseRequestPB req; + TestInvalidResponseResponsePB resp; + req.set_error_type(rpc_test::TestInvalidResponseRequestPB_ErrorType_RESPONSE_TOO_LARGE); + Status s = p.TestInvalidResponse(req, &resp, &rpc); + ASSERT_STR_CONTAINS(s.ToString(), "Network error: RPC frame had a length of"); +} + +// Test sending a call which isn't implemented by the server. 
+TEST_F(RpcStubTest, TestCallMissingMethod) { + Proxy p(client_messenger_, server_addr_, server_addr_.host(), + CalculatorService::static_service_name()); + + Status s = DoTestSyncCall(p, "DoesNotExist"); + ASSERT_TRUE(s.IsRemoteError()) << "Bad status: " << s.ToString(); + ASSERT_STR_CONTAINS(s.ToString(), "with an invalid method name: DoesNotExist"); +} + +TEST_F(RpcStubTest, TestApplicationError) { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + RpcController controller; + SleepRequestPB req; + SleepResponsePB resp; + req.set_sleep_micros(1); + req.set_return_app_error(true); + Status s = p.Sleep(req, &resp, &controller); + ASSERT_TRUE(s.IsRemoteError()); + EXPECT_EQ("Remote error: Got some error", s.ToString()); + EXPECT_EQ("message: \"Got some error\"\n" + "[kudu.rpc_test.CalculatorError.app_error_ext] {\n" + " extra_error_data: \"some application-specific error data\"\n" + "}\n", + SecureDebugString(*controller.error_response())); +} + +TEST_F(RpcStubTest, TestRpcPanic) { + if (!FLAGS_is_panic_test_child) { + // This is a poor man's death test. We call this same + // test case, but set the above flag, and verify that + // it aborted. gtest death tests don't work here because + // there are already threads started up. + vector<string> argv; + string executable_path; + CHECK_OK(env_->GetExecutablePath(&executable_path)); + argv.push_back(executable_path); + argv.emplace_back("--is_panic_test_child"); + argv.emplace_back("--gtest_filter=RpcStubTest.TestRpcPanic"); + Subprocess subp(argv); + subp.ShareParentStderr(false); + CHECK_OK(subp.Start()); + FILE* in = fdopen(subp.from_child_stderr_fd(), "r"); + PCHECK(in); + + // Search for string "Test method panicking!" 
somewhere in stderr + char buf[1024]; + bool found_string = false; + while (fgets(buf, sizeof(buf), in)) { + if (strstr(buf, "Test method panicking!")) { + found_string = true; + break; + } + } + CHECK(found_string); + + // Check return status + int wait_status = 0; + CHECK_OK(subp.Wait(&wait_status)); + CHECK(!WIFEXITED(wait_status)); // should not have been successful + if (WIFSIGNALED(wait_status)) { + CHECK_EQ(WTERMSIG(wait_status), SIGABRT); + } else { + // On some systems, we get exit status 134 from SIGABRT rather than + // WIFSIGNALED getting flagged. + CHECK_EQ(WEXITSTATUS(wait_status), 134); + } + return; + } else { + // Before forcing the panic, explicitly remove the test directory. This + // should be safe; this test doesn't generate any data. + CHECK_OK(env_->DeleteRecursively(test_dir_)); + + // Make an RPC which causes the server to abort. + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + RpcController controller; + PanicRequestPB req; + PanicResponsePB resp; + p.Panic(req, &resp, &controller); + } +} + +struct AsyncSleep { + AsyncSleep() : latch(1) {} + + RpcController rpc; + SleepRequestPB req; + SleepResponsePB resp; + CountDownLatch latch; +}; + +TEST_F(RpcStubTest, TestDontHandleTimedOutCalls) { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + vector<AsyncSleep*> sleeps; + ElementDeleter d(&sleeps); + + // Send enough sleep calls to occupy the worker threads. + for (int i = 0; i < n_worker_threads_; i++) { + gscoped_ptr<AsyncSleep> sleep(new AsyncSleep); + sleep->rpc.set_timeout(MonoDelta::FromSeconds(1)); + sleep->req.set_sleep_micros(1000*1000); // 1sec + p.SleepAsync(sleep->req, &sleep->resp, &sleep->rpc, + boost::bind(&CountDownLatch::CountDown, &sleep->latch)); + sleeps.push_back(sleep.release()); + } + + // We asynchronously sent the RPCs above, but the RPCs might still + // be in the queue. 
Because the RPC we send next has a lower timeout, + // it would take priority over the long-timeout RPCs. So, we have to + // wait until the above RPCs are being processed before we continue + // the test. + const Histogram* queue_time_metric = service_pool_->IncomingQueueTimeMetricForTests(); + while (queue_time_metric->TotalCount() < n_worker_threads_) { + SleepFor(MonoDelta::FromMilliseconds(1)); + } + + // Send another call with a short timeout. This shouldn't get processed, because + // it'll get stuck in the queue for longer than its timeout. + ASSERT_EVENTUALLY([&]() { + RpcController rpc; + SleepRequestPB req; + SleepResponsePB resp; + req.set_sleep_micros(1); // unused but required. + rpc.set_timeout(MonoDelta::FromMilliseconds(5)); + Status s = p.Sleep(req, &resp, &rpc); + ASSERT_TRUE(s.IsTimedOut()) << s.ToString(); + // Since our timeout was short, it's possible in rare circumstances + // that we time out the RPC on the outbound queue, in which case + // we won't trigger the desired behavior here. In that case, the + // timeout error status would have the string 'ON_OUTBOUND_QUEUE' + // instead of 'SENT', so this assertion would fail and cause the + // ASSERT_EVENTUALLY to loop. + ASSERT_STR_CONTAINS(s.ToString(), "SENT"); + }); + + for (AsyncSleep* s : sleeps) { + s->latch.Wait(); + } + + // Verify that the timedout call got short circuited before being processed. + // We may need to loop a short amount of time as we are racing with the reactor + // thread to process the remaining elements of the queue. + const Counter* timed_out_in_queue = service_pool_->RpcsTimedOutInQueueMetricForTests(); + ASSERT_EVENTUALLY([&]{ + ASSERT_EQ(1, timed_out_in_queue->value()); + }); +} + +// Test which ensures that the RPC queue accepts requests with the earliest +// deadline first (EDF), and upon overflow rejects requests with the latest deadlines. 
+// +// In particular, this simulates a workload experienced with Impala where the local +// impalad would spawn more scanner threads than the total number of handlers plus queue +// slots, guaranteeing that some of those clients would see SERVER_TOO_BUSY rejections on +// scan requests and be forced to back off and retry. Without EDF scheduling, we saw that +// the "unlucky" threads that got rejected would likely continue to get rejected upon +// retries, and some would be starved continually until they missed their overall deadline +// and failed the query. +// +// With EDF scheduling, the retries take priority over the original requests (because +// they retain their original deadlines). This prevents starvation of unlucky threads. +TEST_F(RpcStubTest, TestEarliestDeadlineFirstQueue) { + const int num_client_threads = service_queue_length_ + n_worker_threads_ + 5; + vector<std::thread> threads; + vector<int> successes(num_client_threads); + std::atomic<bool> done(false); + for (int thread_id = 0; thread_id < num_client_threads; thread_id++) { + threads.emplace_back([&, thread_id] { + Random rng(thread_id); + CalculatorServiceProxy p( + client_messenger_, server_addr_, server_addr_.host()); + while (!done.load()) { + // Set a deadline in the future. We'll keep using this same deadline + // on each of our retries. + MonoTime deadline = MonoTime::Now() + MonoDelta::FromSeconds(8); + + for (int attempt = 1; !done.load(); attempt++) { + RpcController controller; + SleepRequestPB req; + SleepResponsePB resp; + controller.set_deadline(deadline); + req.set_sleep_micros(100000); + Status s = p.Sleep(req, &resp, &controller); + if (s.ok()) { + successes[thread_id]++; + break; + } + // We expect to get SERVER_TOO_BUSY errors because we have more clients than the + // server has handlers and queue slots. No other errors are expected. 
+ CHECK(s.IsRemoteError() && + controller.error_response()->code() == rpc::ErrorStatusPB::ERROR_SERVER_TOO_BUSY) + << "Unexpected RPC failure: " << s.ToString(); + // Randomized exponential backoff (similar to that done by the scanners in the Kudu + // client.). + int backoff = (0.5 + rng.NextDoubleFraction() * 0.5) * (std::min(1 << attempt, 1000)); + VLOG(1) << "backoff " << backoff << "ms"; + SleepFor(MonoDelta::FromMilliseconds(backoff)); + } + } + }); + } + // Let the threads run for 5 seconds before stopping them. + SleepFor(MonoDelta::FromSeconds(5)); + done.store(true); + for (auto& t : threads) { + t.join(); + } + + // Before switching to earliest-deadline-first scheduling, the results + // here would typically look something like: + // 1 1 0 1 10 17 6 1 12 12 17 10 8 7 12 9 16 15 + // With the fix, we see something like: + // 9 9 9 8 9 9 9 9 9 9 9 9 9 9 9 9 9 + LOG(INFO) << "thread RPC success counts: " << successes; + + int sum = 0; + int min = std::numeric_limits<int>::max(); + for (int x : successes) { + sum += x; + min = std::min(min, x); + } + int avg = sum / successes.size(); + ASSERT_GT(min, avg / 2) + << "expected the least lucky thread to have at least half as many successes " + << "as the average thread: min=" << min << " avg=" << avg; +} + +TEST_F(RpcStubTest, TestDumpCallsInFlight) { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + AsyncSleep sleep; + sleep.req.set_sleep_micros(100 * 1000); // 100ms + p.SleepAsync(sleep.req, &sleep.resp, &sleep.rpc, + boost::bind(&CountDownLatch::CountDown, &sleep.latch)); + + // Check the running RPC status on the client messenger. 
+ DumpRunningRpcsRequestPB dump_req; + DumpRunningRpcsResponsePB dump_resp; + dump_req.set_include_traces(true); + + ASSERT_OK(client_messenger_->DumpRunningRpcs(dump_req, &dump_resp)); + LOG(INFO) << "client messenger: " << SecureDebugString(dump_resp); + ASSERT_EQ(1, dump_resp.outbound_connections_size()); + ASSERT_EQ(1, dump_resp.outbound_connections(0).calls_in_flight_size()); + ASSERT_EQ("Sleep", dump_resp.outbound_connections(0).calls_in_flight(0). + header().remote_method().method_name()); + ASSERT_GT(dump_resp.outbound_connections(0).calls_in_flight(0).micros_elapsed(), 0); + + // And the server messenger. + // We have to loop this until we find a result since the actual call is sent + // asynchronously off of the main thread (ie the server may not be handling it yet) + for (int i = 0; i < 100; i++) { + dump_resp.Clear(); + ASSERT_OK(server_messenger_->DumpRunningRpcs(dump_req, &dump_resp)); + if (dump_resp.inbound_connections_size() > 0 && + dump_resp.inbound_connections(0).calls_in_flight_size() > 0) { + break; + } + SleepFor(MonoDelta::FromMilliseconds(1)); + } + + LOG(INFO) << "server messenger: " << SecureDebugString(dump_resp); + ASSERT_EQ(1, dump_resp.inbound_connections_size()); + ASSERT_EQ(1, dump_resp.inbound_connections(0).calls_in_flight_size()); + ASSERT_EQ("Sleep", dump_resp.inbound_connections(0).calls_in_flight(0). + header().remote_method().method_name()); + ASSERT_GT(dump_resp.inbound_connections(0).calls_in_flight(0).micros_elapsed(), 0); + ASSERT_STR_CONTAINS(dump_resp.inbound_connections(0).calls_in_flight(0).trace_buffer(), + "Inserting onto call queue"); + sleep.latch.Wait(); +} + +TEST_F(RpcStubTest, TestDumpSampledCalls) { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + // Issue two calls that fall into different latency buckets. 
+ AsyncSleep sleeps[2]; + sleeps[0].req.set_sleep_micros(150 * 1000); // 150ms + sleeps[1].req.set_sleep_micros(1500 * 1000); // 1500ms + + for (auto& sleep : sleeps) { + p.SleepAsync(sleep.req, &sleep.resp, &sleep.rpc, + boost::bind(&CountDownLatch::CountDown, &sleep.latch)); + } + for (auto& sleep : sleeps) { + sleep.latch.Wait(); + } + + // Dump the sampled RPCs and expect to see the calls + // above. + + DumpRpczStoreResponsePB sampled_rpcs; + server_messenger_->rpcz_store()->DumpPB(DumpRpczStoreRequestPB(), &sampled_rpcs); + EXPECT_EQ(sampled_rpcs.methods_size(), 1); + ASSERT_STR_CONTAINS(SecureDebugString(sampled_rpcs), + " metrics {\n" + " key: \"test_sleep_us\"\n" + " value: 150000\n" + " }\n"); + ASSERT_STR_CONTAINS(SecureDebugString(sampled_rpcs), + " metrics {\n" + " key: \"test_sleep_us\"\n" + " value: 1500000\n" + " }\n"); + ASSERT_STR_CONTAINS(SecureDebugString(sampled_rpcs), + " metrics {\n" + " child_path: \"test_child\"\n" + " key: \"related_trace_metric\"\n" + " value: 1\n" + " }"); + ASSERT_STR_CONTAINS(SecureDebugString(sampled_rpcs), "SleepRequestPB"); + ASSERT_STR_CONTAINS(SecureDebugString(sampled_rpcs), "duration_ms"); +} + +namespace { +struct RefCountedTest : public RefCountedThreadSafe<RefCountedTest> { +}; + +// Test callback which takes a refcounted pointer. +// We don't use this parameter, but it's used to validate that the bound callback +// is cleared in TestCallbackClearedAfterRunning. +void MyTestCallback(CountDownLatch* latch, scoped_refptr<RefCountedTest> my_refptr) { + latch->CountDown(); +} +} // anonymous namespace + +// Verify that, after a call has returned, no copy of the call's callback +// is held. This is important when the callback holds a refcounted ptr, +// since we expect to be able to release that pointer when the call is done. 
+TEST_F(RpcStubTest, TestCallbackClearedAfterRunning) { + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + CountDownLatch latch(1); + scoped_refptr<RefCountedTest> my_refptr(new RefCountedTest); + RpcController controller; + AddRequestPB req; + req.set_x(10); + req.set_y(20); + AddResponsePB resp; + p.AddAsync(req, &resp, &controller, + boost::bind(MyTestCallback, &latch, my_refptr)); + latch.Wait(); + + // The ref count should go back down to 1. However, we need to loop a little + // bit, since the deref is happening on another thread. If the other thread gets + // descheduled directly after calling our callback, we'd fail without these sleeps. + for (int i = 0; i < 100 && !my_refptr->HasOneRef(); i++) { + SleepFor(MonoDelta::FromMilliseconds(1)); + } + ASSERT_TRUE(my_refptr->HasOneRef()); +} + +// Regression test for KUDU-1409: if the client reactor thread is blocked (e.g due to a +// process-wide pause or a slow callback) then we should not cause RPC calls to time out. +TEST_F(RpcStubTest, DontTimeOutWhenReactorIsBlocked) { + CHECK_EQ(client_messenger_->num_reactors(), 1) + << "This test requires only a single reactor. Otherwise the injected sleep might " + << "be scheduled on a different reactor than the RPC call."; + + CalculatorServiceProxy p(client_messenger_, server_addr_, server_addr_.host()); + + // Schedule a 1-second sleep on the reactor thread. + // + // This will cause us the reactor to be blocked while the call response is received, and + // still be blocked when the timeout would normally occur. Despite this, the call should + // not time out. 
+ // + // 0s 0.5s 1.2s 1.5s + // RPC call running + // |---------------------| + // Reactor blocked in sleep + // |----------------------| + // \_ RPC would normally time out + + client_messenger_->ScheduleOnReactor([](const Status& s) { + ThreadRestrictions::ScopedAllowWait allow_wait; + SleepFor(MonoDelta::FromSeconds(1)); + }, MonoDelta::FromSeconds(0.5)); + + RpcController controller; + SleepRequestPB req; + SleepResponsePB resp; + req.set_sleep_micros(800 * 1000); + controller.set_timeout(MonoDelta::FromMilliseconds(1200)); + ASSERT_OK(p.Sleep(req, &resp, &controller)); +} + +} // namespace rpc +} // namespace kudu
http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/rpcz_store.cc ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/rpcz_store.cc b/be/src/kudu/rpc/rpcz_store.cc new file mode 100644 index 0000000..2f0e9c8 --- /dev/null +++ b/be/src/kudu/rpc/rpcz_store.cc @@ -0,0 +1,272 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "kudu/rpc/rpcz_store.h" + +#include <algorithm> // IWYU pragma: keep +#include <array> +#include <cstdint> +#include <mutex> // for unique_lock +#include <ostream> +#include <string> +#include <utility> +#include <vector> + +#include <gflags/gflags.h> +#include <glog/logging.h> +#include <google/protobuf/message.h> + +#include "kudu/gutil/port.h" +#include "kudu/gutil/ref_counted.h" +#include "kudu/gutil/strings/stringpiece.h" +#include "kudu/gutil/walltime.h" +#include "kudu/rpc/inbound_call.h" +#include "kudu/rpc/rpc_header.pb.h" +#include "kudu/rpc/rpc_introspection.pb.h" +#include "kudu/rpc/service_if.h" +#include "kudu/util/atomic.h" +#include "kudu/util/flag_tags.h" +#include "kudu/util/monotime.h" +#include "kudu/util/trace.h" +#include "kudu/util/trace_metrics.h" + +DEFINE_bool(rpc_dump_all_traces, false, + "If true, dump all RPC traces at INFO level"); +TAG_FLAG(rpc_dump_all_traces, advanced); +TAG_FLAG(rpc_dump_all_traces, runtime); + +DEFINE_int32(rpc_duration_too_long_ms, 1000, + "Threshold (in milliseconds) above which a RPC is considered too long and its " + "duration and method name are logged at INFO level. The time measured is between " + "when a RPC is accepted and when its call handler completes."); +TAG_FLAG(rpc_duration_too_long_ms, advanced); +TAG_FLAG(rpc_duration_too_long_ms, runtime); + +using std::pair; +using std::string; +using std::vector; +using std::unique_ptr; + +namespace kudu { +namespace rpc { + +// Sample an RPC call once every N milliseconds within each +// bucket. If the current sample in a latency bucket is older +// than this threshold, a new sample will be taken. +static const int kSampleIntervalMs = 1000; + +static const int kBucketThresholdsMs[] = {10, 100, 1000}; +static constexpr int kNumBuckets = arraysize(kBucketThresholdsMs) + 1; + +// An instance of this class is created For each RPC method implemented +// on the server. 
It keeps several recent samples for each RPC, currently +// based on fixed time buckets. +class MethodSampler { + public: + MethodSampler() {} + ~MethodSampler() {} + + // Potentially sample a single call. + void SampleCall(InboundCall* call); + + // Dump the current samples. + void GetSamplePBs(RpczMethodPB* pb); + + private: + // Convert the trace metrics from 't' into protobuf entries in 'sample_pb'. + // This function recurses through the parent-child relationship graph, + // keeping the current tree path in 'child_path' (empty at the root). + static void GetTraceMetrics(const Trace& t, + const string& child_path, + RpczSamplePB* sample_pb); + + // An individual recorded sample. + struct Sample { + RequestHeader header; + scoped_refptr<Trace> trace; + int duration_ms; + }; + + // A sample, including the particular time at which it was + // sampled, and a lock protecting it. + struct SampleBucket { + SampleBucket() : last_sample_time(0) {} + + AtomicInt<int64_t> last_sample_time; + simple_spinlock sample_lock; + Sample sample; + }; + std::array<SampleBucket, kNumBuckets> buckets_; + + DISALLOW_COPY_AND_ASSIGN(MethodSampler); +}; + +MethodSampler* RpczStore::SamplerForCall(InboundCall* call) { + if (PREDICT_FALSE(!call->method_info())) { + return nullptr; + } + + // Most likely, we already have a sampler created for the call. + { + shared_lock<rw_spinlock> l(samplers_lock_.get_lock()); + auto it = method_samplers_.find(call->method_info()); + if (PREDICT_TRUE(it != method_samplers_.end())) { + return it->second.get(); + } + } + + // If missing, create a new sampler for this method and try to insert it. 
+ unique_ptr<MethodSampler> ms(new MethodSampler()); + std::lock_guard<percpu_rwlock> lock(samplers_lock_); + auto it = method_samplers_.find(call->method_info()); + if (it != method_samplers_.end()) { + return it->second.get(); + } + auto* ret = ms.get(); + method_samplers_[call->method_info()] = std::move(ms); + return ret; +} + +void MethodSampler::SampleCall(InboundCall* call) { + // First determine which sample bucket to put this in. + int duration_ms = call->timing().TotalDuration().ToMilliseconds(); + + SampleBucket* bucket = &buckets_[kNumBuckets - 1]; + for (int i = 0 ; i < kNumBuckets - 1; i++) { + if (duration_ms < kBucketThresholdsMs[i]) { + bucket = &buckets_[i]; + break; + } + } + + MicrosecondsInt64 now = GetMonoTimeMicros(); + int64_t us_since_trace = now - bucket->last_sample_time.Load(); + if (us_since_trace > kSampleIntervalMs * 1000) { + Sample new_sample = {call->header(), call->trace(), duration_ms}; + { + std::unique_lock<simple_spinlock> lock(bucket->sample_lock, std::try_to_lock); + // If another thread is already taking a sample, it's not worth waiting. 
+ if (!lock.owns_lock()) { + return; + } + std::swap(bucket->sample, new_sample); + bucket->last_sample_time.Store(now); + } + VLOG(2) << "Sampled call " << call->ToString(); + } +} + +void MethodSampler::GetTraceMetrics(const Trace& t, + const string& child_path, + RpczSamplePB* sample_pb) { + auto m = t.metrics().Get(); + for (const auto& e : m) { + auto* pb = sample_pb->add_metrics(); + pb->set_key(e.first); + pb->set_value(e.second); + if (!child_path.empty()) { + pb->set_child_path(child_path); + } + } + + for (const auto& child_pair : t.ChildTraces()) { + string path = child_path; + if (!path.empty()) { + path += "."; + } + path += child_pair.first.ToString(); + GetTraceMetrics(*child_pair.second.get(), path, sample_pb); + } +} + +void MethodSampler::GetSamplePBs(RpczMethodPB* method_pb) { + for (auto& bucket : buckets_) { + if (bucket.last_sample_time.Load() == 0) continue; + + std::unique_lock<simple_spinlock> lock(bucket.sample_lock); + auto* sample_pb = method_pb->add_samples(); + sample_pb->mutable_header()->CopyFrom(bucket.sample.header); + sample_pb->set_trace(bucket.sample.trace->DumpToString(Trace::INCLUDE_TIME_DELTAS)); + + GetTraceMetrics(*bucket.sample.trace.get(), "", sample_pb); + sample_pb->set_duration_ms(bucket.sample.duration_ms); + } +} + +RpczStore::RpczStore() {} +RpczStore::~RpczStore() {} + +void RpczStore::AddCall(InboundCall* call) { + LogTrace(call); + auto* sampler = SamplerForCall(call); + if (PREDICT_FALSE(!sampler)) return; + + sampler->SampleCall(call); +} + +void RpczStore::DumpPB(const DumpRpczStoreRequestPB& req, + DumpRpczStoreResponsePB* resp) { + vector<pair<RpcMethodInfo*, MethodSampler*>> samplers; + { + shared_lock<rw_spinlock> l(samplers_lock_.get_lock()); + for (const auto& p : method_samplers_) { + samplers.emplace_back(p.first, p.second.get()); + } + } + + for (const auto& p : samplers) { + auto* sampler = p.second; + + RpczMethodPB* method_pb = resp->add_methods(); + // TODO: use the actual RPC name instead of the 
request type name. + // Currently this isn't conveniently plumbed here, but the type name + // is close enough. + method_pb->set_method_name(p.first->req_prototype->GetTypeName()); + sampler->GetSamplePBs(method_pb); + } +} + +void RpczStore::LogTrace(InboundCall* call) { + int duration_ms = call->timing().TotalDuration().ToMilliseconds(); + + if (call->header_.has_timeout_millis() && call->header_.timeout_millis() > 0) { + double log_threshold = call->header_.timeout_millis() * 0.75f; + if (duration_ms > log_threshold) { + // TODO: consider pushing this onto another thread since it may be slow. + // The traces may also be too large to fit in a log message. + LOG(WARNING) << call->ToString() << " took " << duration_ms << "ms (client timeout " + << call->header_.timeout_millis() << ")."; + string s = call->trace()->DumpToString(); + if (!s.empty()) { + LOG(WARNING) << "Trace:\n" << s; + } + return; + } + } + + if (PREDICT_FALSE(FLAGS_rpc_dump_all_traces)) { + LOG(INFO) << call->ToString() << " took " << duration_ms << "ms. Trace:"; + call->trace()->Dump(&LOG(INFO), true); + } else if (duration_ms > FLAGS_rpc_duration_too_long_ms) { + LOG(INFO) << call->ToString() << " took " << duration_ms << "ms. " + << "Request Metrics: " << call->trace()->MetricsAsJSON(); + } +} + + +} // namespace rpc +} // namespace kudu http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/rpcz_store.h ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/rpcz_store.h b/be/src/kudu/rpc/rpcz_store.h new file mode 100644 index 0000000..48e4474 --- /dev/null +++ b/be/src/kudu/rpc/rpcz_store.h @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#pragma once + +#include "kudu/gutil/macros.h" + +#include <memory> +#include <unordered_map> + +#include "kudu/util/locks.h" + +namespace kudu { +namespace rpc { + +class DumpRpczStoreRequestPB; +class DumpRpczStoreResponsePB; +class InboundCall; +class MethodSampler; +struct RpcMethodInfo; + +// Responsible for storing sampled traces associated with completed calls. +// Before each call is responded to, it is added to this store. +class RpczStore { + public: + RpczStore(); + ~RpczStore(); + + // Process a single call, potentially sampling it for later analysis. + // + // If the call is sampled, it might be mutated. For example, the request + // and response might be taken from the call and stored as part of the + // sample. This should be called just before a call response is sent + // to the client. + void AddCall(InboundCall* c); + + // Dump all of the collected RPC samples in response to a user query. + void DumpPB(const DumpRpczStoreRequestPB& req, + DumpRpczStoreResponsePB* resp); + + private: + // Look up or create the particular MethodSampler instance which should + // store samples for this call. + MethodSampler* SamplerForCall(InboundCall* call); + + // Log a WARNING message if the RPC response was slow enough that the + // client likely timed out. This is based on the client-provided timeout + // value. 
+ // Also can be configured to log _all_ RPC traces for help debugging. + void LogTrace(InboundCall* call); + + percpu_rwlock samplers_lock_; + + // Protected by samplers_lock_. + std::unordered_map<RpcMethodInfo*, std::unique_ptr<MethodSampler>> method_samplers_; + + DISALLOW_COPY_AND_ASSIGN(RpczStore); +}; + +} // namespace rpc +} // namespace kudu http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/rtest.proto ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/rtest.proto b/be/src/kudu/rpc/rtest.proto new file mode 100644 index 0000000..d212cef --- /dev/null +++ b/be/src/kudu/rpc/rtest.proto @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// Test protocol for kudu RPC. +syntax = "proto2"; +package kudu.rpc_test; + +import "kudu/rpc/rpc_header.proto"; +import "kudu/rpc/rtest_diff_package.proto"; + +message AddRequestPB { + required uint32 x = 1; + required uint32 y = 2; +} + +// Used by tests to simulate an old client which is missing +// a newly added required field. 
message AddRequestPartialPB {
  required uint32 x = 1;
}

// Response for CalculatorService.Add.
message AddResponsePB {
  required uint32 result = 1;
}

message SleepRequestPB {
  required uint32 sleep_micros = 1;

  // Used in rpc_stub-test: if this is true, it will respond from a different
  // thread than the one that receives the request.
  optional bool deferred = 2 [ default = false ];

  // If set, returns a CalculatorError response.
  optional bool return_app_error = 3 [ default = false ];

  // Used in rpc-test: if this is set to true and no client timeout is set,
  // the service will respond to the client with an error.
  optional bool client_timeout_defined = 4 [ default = false ];
}

message SleepResponsePB {
}

message SleepWithSidecarRequestPB {
  required uint32 sleep_micros = 1;
  required uint32 pattern = 2;
  required uint32 num_repetitions = 3;
  required uint32 sidecar_idx = 4;
}

message SleepWithSidecarResponsePB {
}

message SendTwoStringsRequestPB {
  required uint32 random_seed = 1;
  required uint64 size1 = 2;
  required uint64 size2 = 3;
}

message SendTwoStringsResponsePB {
  required uint32 sidecar1 = 1;
  required uint32 sidecar2 = 2;
}

// Push two strings to the server as part of the request, in sidecars.
message PushTwoStringsRequestPB {
  required uint32 sidecar1_idx = 1;
  required uint32 sidecar2_idx = 2;
}

message PushTwoStringsResponsePB {
  required uint32 size1 = 1;
  required string data1 = 2;
  required uint32 size2 = 3;
  required string data2 = 4;
}

message EchoRequestPB {
  required string data = 1;
}
message EchoResponsePB {
  required string data = 1;
}

message WhoAmIRequestPB {
}
message WhoAmIResponsePB {
  required kudu.rpc.UserInformationPB credentials = 1;
  required string address = 2;
}

// Application-specific error attached to ErrorStatusPB via an extension.
message CalculatorError {
  extend kudu.rpc.ErrorStatusPB {
    optional CalculatorError app_error_ext = 101;
  }

  required string extra_error_data = 1;
}

message PanicRequestPB {}
message PanicResponsePB {}

message TestInvalidResponseRequestPB {
  enum ErrorType {
    MISSING_REQUIRED_FIELD = 1;
    RESPONSE_TOO_LARGE = 2;
  }
  required ErrorType error_type = 1;
}

message TestInvalidResponseResponsePB {
  required bytes response = 1;
}

enum FeatureFlags {
  UNKNOWN = 0;
  FOO = 1;
}

message ExactlyOnceRequestPB {
  optional uint32 sleep_for_ms = 1 [default = 0];
  required uint32 value_to_add = 2;
  optional bool randomly_fail = 3 [default = false];
}
message ExactlyOnceResponsePB {
  required uint32 current_val = 1;
  required fixed64 current_time_micros = 2;
}

service CalculatorService {
  option (kudu.rpc.default_authz_method) = "AuthorizeDisallowAlice";

  rpc Add(AddRequestPB) returns(AddResponsePB);
  rpc Sleep(SleepRequestPB) returns(SleepResponsePB) {
    option (kudu.rpc.authz_method) = "AuthorizeDisallowBob";
  }
  rpc Echo(EchoRequestPB) returns(EchoResponsePB);
  rpc WhoAmI(WhoAmIRequestPB) returns (WhoAmIResponsePB);
  rpc TestArgumentsInDiffPackage(kudu.rpc_test_diff_package.ReqDiffPackagePB)
    returns(kudu.rpc_test_diff_package.RespDiffPackagePB);
  rpc Panic(PanicRequestPB) returns (PanicResponsePB);
  rpc AddExactlyOnce(ExactlyOnceRequestPB) returns (ExactlyOnceResponsePB) {
    option (kudu.rpc.track_rpc_result) = true;
  }
  rpc TestInvalidResponse(TestInvalidResponseRequestPB) returns (TestInvalidResponseResponsePB);
}

// http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/rtest_diff_package.proto
// ----------------------------------------------------------------------
// diff --git a/be/src/kudu/rpc/rtest_diff_package.proto b/be/src/kudu/rpc/rtest_diff_package.proto
// new file mode 100644
// index 0000000..f6f9b60
// --- /dev/null
// +++ b/be/src/kudu/rpc/rtest_diff_package.proto
// @@ -0,0 +1,26 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
// Request/Response in different package to test that RPC methods
// handle arguments with packages different from the service itself.
syntax = "proto2";
package kudu.rpc_test_diff_package;

message ReqDiffPackagePB {
}
message RespDiffPackagePB {
}

// http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/sasl_common.cc
// ----------------------------------------------------------------------
// diff --git a/be/src/kudu/rpc/sasl_common.cc b/be/src/kudu/rpc/sasl_common.cc
// new file mode 100644
// index 0000000..645e854
// --- /dev/null
// +++ b/be/src/kudu/rpc/sasl_common.cc
// @@ -0,0 +1,470 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
+ +#include "kudu/rpc/sasl_common.h" + +#include <cstdio> +#include <cstring> +#include <limits> +#include <mutex> +#include <ostream> +#include <string> + +#include <boost/algorithm/string/predicate.hpp> +#include <glog/logging.h> +#include <regex.h> +#include <sasl/sasl.h> +#include <sasl/saslplug.h> + +#include "kudu/gutil/macros.h" +#include "kudu/rpc/constants.h" +#include "kudu/security/init.h" +#include "kudu/util/mutex.h" +#include "kudu/util/net/sockaddr.h" +#include "kudu/util/rw_mutex.h" + +using std::set; +using std::string; + +namespace kudu { +namespace rpc { + +const char* const kSaslMechPlain = "PLAIN"; +const char* const kSaslMechGSSAPI = "GSSAPI"; +extern const size_t kSaslMaxBufSize = 1024; + +// See WrapSaslCall(). +static __thread string* g_auth_failure_capture = nullptr; + +// Determine whether initialization was ever called +static Status sasl_init_status = Status::OK(); +static bool sasl_is_initialized = false; + +// If true, then we expect someone else has initialized SASL. +static bool g_disable_sasl_init = false; + +// If true, we expect kerberos to be enabled. +static bool has_kerberos_keytab = false; + +// Output Sasl messages. +// context: not used. +// level: logging level. +// message: message to output; +static int SaslLogCallback(void* context, int level, const char* message) { + + if (message == nullptr) return SASL_BADPARAM; + + switch (level) { + case SASL_LOG_NONE: + break; + + case SASL_LOG_ERR: + LOG(ERROR) << "SASL: " << message; + break; + + case SASL_LOG_WARN: + LOG(WARNING) << "SASL: " << message; + break; + + case SASL_LOG_NOTE: + LOG(INFO) << "SASL: " << message; + break; + + case SASL_LOG_FAIL: + // Capture authentication failures in a thread-local to be picked up + // by WrapSaslCall() below. 
+ VLOG(1) << "SASL: " << message; + if (g_auth_failure_capture) { + *g_auth_failure_capture = message; + } + break; + + case SASL_LOG_DEBUG: + VLOG(1) << "SASL: " << message; + break; + + case SASL_LOG_TRACE: + case SASL_LOG_PASS: + VLOG(3) << "SASL: " << message; + break; + } + + return SASL_OK; +} + +// Get Sasl option. +// context: not used +// plugin_name: name of plugin for which an option is being requested. +// option: option requested +// result: set to result which persists until next getopt in same thread, +// unchanged if option not found +// len: length of the result +// Return SASL_FAIL if the option is not handled, this does not fail the handshake. +static int SaslGetOption(void* context, const char* plugin_name, const char* option, + const char** result, unsigned* len) { + // Handle Sasl Library options + if (plugin_name == nullptr) { + // Return the logging level that we want the sasl library to use. + if (strcmp("log_level", option) == 0) { + int level = SASL_LOG_NOTE; + if (VLOG_IS_ON(1)) { + level = SASL_LOG_DEBUG; + } else if (VLOG_IS_ON(3)) { + level = SASL_LOG_TRACE; + } + // The library's contract for this method is that the caller gets to keep + // the returned buffer until the next call by the same thread, so we use a + // threadlocal for the buffer. + static __thread char buf[4]; + snprintf(buf, arraysize(buf), "%d", level); + *result = buf; + if (len != nullptr) *len = strlen(buf); + return SASL_OK; + } + // Options can default so don't complain. + VLOG(4) << "SaslGetOption: Unknown library option: " << option; + return SASL_FAIL; + } + VLOG(4) << "SaslGetOption: Unknown plugin: " << plugin_name; + return SASL_FAIL; +} + +// Array of callbacks for the sasl library. 
+static sasl_callback_t callbacks[] = { + { SASL_CB_LOG, reinterpret_cast<int (*)()>(&SaslLogCallback), nullptr }, + { SASL_CB_GETOPT, reinterpret_cast<int (*)()>(&SaslGetOption), nullptr }, + { SASL_CB_LIST_END, nullptr, nullptr } + // TODO(todd): provide a CANON_USER callback? This is necessary if we want + // to support some kind of auth-to-local mapping of Kerberos principals + // to local usernames. See Impala's implementation for inspiration. +}; + + +// SASL requires mutexes for thread safety, but doesn't implement +// them itself. So, we have to hook them up to our mutex implementation. +static void* SaslMutexAlloc() { + return static_cast<void*>(new Mutex()); +} +static void SaslMutexFree(void* m) { + delete static_cast<Mutex*>(m); +} +static int SaslMutexLock(void* m) { + static_cast<Mutex*>(m)->lock(); + return 0; // indicates success. +} +static int SaslMutexUnlock(void* m) { + static_cast<Mutex*>(m)->unlock(); + return 0; // indicates success. +} + +namespace internal { +void SaslSetMutex() { + sasl_set_mutex(&SaslMutexAlloc, &SaslMutexLock, &SaslMutexUnlock, &SaslMutexFree); +} +} // namespace internal + +// Sasl initialization detection methods. The OS X SASL library doesn't define +// the sasl_global_utils symbol, so we have to use less robust methods of +// detection. +#if defined(__APPLE__) +static bool SaslIsInitialized() { + return sasl_global_listmech() != nullptr; +} +static bool SaslMutexImplementationProvided() { + return SaslIsInitialized(); +} +#else + +// This symbol is exported by the SASL library but not defined +// in the headers. It's marked as an API in the library source, +// so seems safe to rely on. 
+extern "C" sasl_utils_t* sasl_global_utils; +static bool SaslIsInitialized() { + return sasl_global_utils != nullptr; +} +static bool SaslMutexImplementationProvided() { + if (!SaslIsInitialized()) return false; + void* m = sasl_global_utils->mutex_alloc(); + sasl_global_utils->mutex_free(m); + // The default implementation of mutex_alloc just returns the constant pointer 0x1. + // This is a bit of an ugly heuristic, but seems unlikely that anyone would ever + // provide a valid implementation that returns an invalid pointer value. + return m != reinterpret_cast<void*>(1); +} +#endif + +// Actually perform the initialization for the SASL subsystem. +// Meant to be called via GoogleOnceInit(). +static void DoSaslInit(bool kerberos_keytab_provided) { + VLOG(3) << "Initializing SASL library"; + + has_kerberos_keytab = kerberos_keytab_provided; + + bool sasl_initialized = SaslIsInitialized(); + if (sasl_initialized && !g_disable_sasl_init) { + LOG(WARNING) << "SASL was initialized prior to Kudu's initialization. Skipping " + << "initialization. Call kudu::client::DisableSaslInitialization() " + << "to suppress this message."; + g_disable_sasl_init = true; + } + + if (g_disable_sasl_init) { + if (!sasl_initialized) { + sasl_init_status = Status::RuntimeError( + "SASL initialization was disabled, but SASL was not externally initialized."); + return; + } + if (!SaslMutexImplementationProvided()) { + LOG(WARNING) + << "SASL appears to be initialized by code outside of Kudu " + << "but was not provided with a mutex implementation! 
If " + << "manually initializing SASL, use sasl_set_mutex(3)."; + } + return; + } + internal::SaslSetMutex(); + int result = sasl_client_init(&callbacks[0]); + if (result != SASL_OK) { + sasl_init_status = Status::RuntimeError("Could not initialize SASL client", + sasl_errstring(result, nullptr, nullptr)); + return; + } + + result = sasl_server_init(&callbacks[0], kSaslAppName); + if (result != SASL_OK) { + sasl_init_status = Status::RuntimeError("Could not initialize SASL server", + sasl_errstring(result, nullptr, nullptr)); + return; + } + + sasl_is_initialized = true; +} + +Status DisableSaslInitialization() { + if (g_disable_sasl_init) return Status::OK(); + if (sasl_is_initialized) { + return Status::IllegalState("SASL already initialized. Initialization can only be disabled " + "before first usage."); + } + g_disable_sasl_init = true; + return Status::OK(); +} + +Status SaslInit(bool kerberos_keytab_provided) { + // Only execute SASL initialization once + static std::once_flag once; + std::call_once(once, DoSaslInit, kerberos_keytab_provided); + DCHECK_EQ(kerberos_keytab_provided, has_kerberos_keytab); + + return sasl_init_status; +} + +static string CleanSaslError(const string& err) { + // In the case of GSS failures, we often get the actual error message + // buried inside a bunch of generic cruft. Use a regexp to extract it + // out. Note that we avoid std::regex because it appears to be broken + // with older libstdcxx. + static regex_t re; + static std::once_flag once; + +#if defined(__APPLE__) + static const char* kGssapiPattern = "GSSAPI Error: Miscellaneous failure \\(see text \\((.+)\\)"; +#else + static const char* kGssapiPattern = "Unspecified GSS failure. 
+" + "Minor code may provide more information +" + "\\((.+)\\)"; +#endif + + std::call_once(once, []{ CHECK_EQ(0, regcomp(&re, kGssapiPattern, REG_EXTENDED)); }); + regmatch_t matches[2]; + if (regexec(&re, err.c_str(), arraysize(matches), matches, 0) == 0) { + return err.substr(matches[1].rm_so, matches[1].rm_eo - matches[1].rm_so); + } + return err; +} + +static string SaslErrDesc(int status, sasl_conn_t* conn) { + string err; + if (conn != nullptr) { + err = sasl_errdetail(conn); + } else { + err = sasl_errstring(status, nullptr, nullptr); + } + + return CleanSaslError(err); +} + +Status WrapSaslCall(sasl_conn_t* conn, const std::function<int()>& call) { + // In many cases, the GSSAPI SASL plugin will generate a nice error + // message as a message logged at SASL_LOG_FAIL logging level, but then + // return a useless one in sasl_errstring(). So, we set a global thread-local + // variable to capture any auth failure log message while we make the + // call into the library. + // + // The thread-local thing is a bit of a hack, but the logging callback + // is set globally rather than on a per-connection basis. + string err; + g_auth_failure_capture = &err; + + // Take the 'kerberos_reinit_lock' here to avoid a possible race with ticket renewal. + if (has_kerberos_keytab) kudu::security::KerberosReinitLock()->ReadLock(); + int rc = call(); + if (has_kerberos_keytab) kudu::security::KerberosReinitLock()->ReadUnlock(); + g_auth_failure_capture = nullptr; + + switch (rc) { + case SASL_OK: + return Status::OK(); + case SASL_CONTINUE: + return Status::Incomplete(""); + case SASL_FAIL: // Generic failure (encompasses missing krb5 credentials). + case SASL_BADAUTH: // Authentication failure. + case SASL_BADMAC: // Decode failure. + case SASL_NOAUTHZ: // Authorization failure. + case SASL_NOUSER: // User not found. + case SASL_WRONGMECH: // Server doesn't support requested mechanism. + case SASL_BADSERV: { // Server failed mutual authentication. 
+ if (err.empty()) { + err = SaslErrDesc(rc, conn); + } else { + err = CleanSaslError(err); + } + return Status::NotAuthorized(err); + } + default: + return Status::RuntimeError(SaslErrDesc(rc, conn)); + } +} + +bool NeedsWrap(sasl_conn_t* sasl_conn) { + const unsigned* ssf; + int rc = sasl_getprop(sasl_conn, SASL_SSF, reinterpret_cast<const void**>(&ssf)); + CHECK_EQ(rc, SASL_OK) << "Failed to get SSF property on authenticated SASL connection"; + return *ssf != 0; +} + +uint32_t GetMaxSendBufferSize(sasl_conn_t* sasl_conn) { + const unsigned* max_buf_size; + int rc = sasl_getprop(sasl_conn, SASL_MAXOUTBUF, reinterpret_cast<const void**>(&max_buf_size)); + CHECK_EQ(rc, SASL_OK) + << "Failed to get max output buffer property on authenticated SASL connection"; + return *max_buf_size; +} + +Status SaslEncode(sasl_conn_t* conn, Slice plaintext, Slice* ciphertext) { + const char* out; + unsigned out_len; + RETURN_NOT_OK_PREPEND(WrapSaslCall(conn, [&] { + return sasl_encode(conn, + reinterpret_cast<const char*>(plaintext.data()), + plaintext.size(), + &out, &out_len); + }), "SASL encode failed"); + *ciphertext = Slice(out, out_len); + return Status::OK(); +} + +Status SaslDecode(sasl_conn_t* conn, Slice ciphertext, Slice* plaintext) { + const char* out; + unsigned out_len; + RETURN_NOT_OK_PREPEND(WrapSaslCall(conn, [&] { + return sasl_decode(conn, + reinterpret_cast<const char*>(ciphertext.data()), + ciphertext.size(), + &out, &out_len); + }), "SASL decode failed"); + *plaintext = Slice(out, out_len); + return Status::OK(); +} + +string SaslIpPortString(const Sockaddr& addr) { + string addr_str = addr.ToString(); + size_t colon_pos = addr_str.find(':'); + if (colon_pos != string::npos) { + addr_str[colon_pos] = ';'; + } + return addr_str; +} + +set<SaslMechanism::Type> SaslListAvailableMechs() { + set<SaslMechanism::Type> mechs; + + // Array of NULL-terminated strings. Array terminated with NULL. 
+ for (const char** mech_strings = sasl_global_listmech(); + mech_strings != nullptr && *mech_strings != nullptr; + mech_strings++) { + auto mech = SaslMechanism::value_of(*mech_strings); + if (mech != SaslMechanism::INVALID) { + mechs.insert(mech); + } + } + return mechs; +} + +sasl_callback_t SaslBuildCallback(int id, int (*proc)(void), void* context) { + sasl_callback_t callback; + callback.id = id; + callback.proc = proc; + callback.context = context; + return callback; +} + +Status EnableProtection(sasl_conn_t* sasl_conn, + SaslProtection::Type minimum_protection, + size_t max_recv_buf_size) { + sasl_security_properties_t sec_props; + memset(&sec_props, 0, sizeof(sec_props)); + sec_props.min_ssf = minimum_protection; + sec_props.max_ssf = std::numeric_limits<sasl_ssf_t>::max(); + sec_props.maxbufsize = max_recv_buf_size; + + RETURN_NOT_OK_PREPEND(WrapSaslCall(sasl_conn, [&] { + return sasl_setprop(sasl_conn, SASL_SEC_PROPS, &sec_props); + }), "failed to set SASL security properties"); + return Status::OK(); +} + +SaslMechanism::Type SaslMechanism::value_of(const string& mech) { + if (boost::iequals(mech, "PLAIN")) { + return PLAIN; + } + if (boost::iequals(mech, "GSSAPI")) { + return GSSAPI; + } + return INVALID; +} + +const char* SaslMechanism::name_of(SaslMechanism::Type val) { + switch (val) { + case PLAIN: return "PLAIN"; + case GSSAPI: return "GSSAPI"; + default: + return "INVALID"; + } +} + +const char* SaslProtection::name_of(SaslProtection::Type val) { + switch (val) { + case SaslProtection::kAuthentication: return "authentication"; + case SaslProtection::kIntegrity: return "integrity"; + case SaslProtection::kPrivacy: return "privacy"; + } + LOG(FATAL) << "unknown SASL protection type: " << val; +} + +} // namespace rpc +} // namespace kudu http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/sasl_common.h ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/sasl_common.h 
b/be/src/kudu/rpc/sasl_common.h new file mode 100644 index 0000000..2454cfd --- /dev/null +++ b/be/src/kudu/rpc/sasl_common.h @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef KUDU_RPC_SASL_COMMON_H +#define KUDU_RPC_SASL_COMMON_H + +#include <cstddef> +#include <cstdint> +#include <functional> +#include <set> +#include <string> + +#include <sasl/sasl.h> + +#include "kudu/gutil/port.h" +#include "kudu/util/slice.h" +#include "kudu/util/status.h" + +namespace kudu { + +class Sockaddr; + +namespace rpc { + +// Constants +extern const char* const kSaslMechPlain; +extern const char* const kSaslMechGSSAPI; +extern const size_t kSaslMaxBufSize; + +struct SaslMechanism { + enum Type { + INVALID, + PLAIN, + GSSAPI + }; + static Type value_of(const std::string& mech); + static const char* name_of(Type val); +}; + +struct SaslProtection { + enum Type { + // SASL authentication without integrity or privacy. + kAuthentication = 0, + // Integrity protection, i.e. messages are HMAC'd. + kIntegrity = 1, + // Privacy protection, i.e. messages are encrypted. + kPrivacy = 2, + }; + static const char* name_of(Type val); +}; + +// Initialize the SASL library. 
+// appname: Name of the application for logging messages & sasl plugin configuration. +// Note that this string must remain allocated for the lifetime of the program. +// This function must be called before using SASL. +// If the library initializes without error, calling more than once has no effect. +// +// Some SASL plugins take time to initialize random number generators and other things, +// so the first time this function is invoked it may execute for several seconds. +// After that, it should be very fast. This function should be invoked as early as possible +// in the application lifetime to avoid SASL initialization taking place in a +// performance-critical section. +// +// This function is thread safe and uses a static lock. +// This function should NOT be called during static initialization. +Status SaslInit(bool kerberos_keytab_provided = false) WARN_UNUSED_RESULT; + +// Disable Kudu's initialization of SASL. See equivalent method in client.h. +Status DisableSaslInitialization() WARN_UNUSED_RESULT; + +// Wrap a call into the SASL library. 'call' should be a lambda which +// returns a SASL error code. +// +// The result is translated into a Status as follows: +// +// SASL_OK: Status::OK() +// SASL_CONTINUE: Status::Incomplete() +// otherwise: Status::NotAuthorized() +// +// The Status message is beautified to be more user-friendly compared +// to the underlying sasl_errdetails() call. +Status WrapSaslCall(sasl_conn_t* conn, const std::function<int()>& call) WARN_UNUSED_RESULT; + +// Return <ip>;<port> string formatted for SASL library use. +std::string SaslIpPortString(const Sockaddr& addr); + +// Return available plugin mechanisms for the given connection. +std::set<SaslMechanism::Type> SaslListAvailableMechs(); + +// Initialize and return a libsasl2 callback data structure based on the passed args. +// id: A SASL callback identifier (e.g., SASL_CB_GETOPT). +// proc: A C-style callback with appropriate signature based on the callback id, or NULL. 
+// context: An object to pass to the callback as the context pointer, or NULL. +sasl_callback_t SaslBuildCallback(int id, int (*proc)(void), void* context); + +// Enable SASL integrity and privacy protection on the connection. Also allows +// setting the minimum required protection level, and the maximum receive buffer +// size. +Status EnableProtection(sasl_conn_t* sasl_conn, + SaslProtection::Type minimum_protection = SaslProtection::kAuthentication, + size_t max_recv_buf_size = kSaslMaxBufSize) WARN_UNUSED_RESULT; + +// Returns true if the SASL connection has been negotiated with auth-int or +// auth-conf. 'sasl_conn' must already be negotiated. +bool NeedsWrap(sasl_conn_t* sasl_conn); + +// Retrieves the negotiated maximum send buffer size for auth-int or auth-conf +// protected channels. +uint32_t GetMaxSendBufferSize(sasl_conn_t* sasl_conn) WARN_UNUSED_RESULT; + +// Encode the provided data. +// +// The plaintext data must not be longer than the negotiated maximum buffer size. +// +// The output 'ciphertext' slice is only valid until the next use of the SASL connection. +Status SaslEncode(sasl_conn_t* conn, + Slice plaintext, + Slice* ciphertext) WARN_UNUSED_RESULT; + +// Decode the provided SASL-encoded data. +// +// The decoded plaintext must not be longer than the negotiated maximum buffer size. +// +// The output 'plaintext' slice is only valid until the next use of the SASL connection. +Status SaslDecode(sasl_conn_t* conn, + Slice ciphertext, + Slice* plaintext) WARN_UNUSED_RESULT; + +// Deleter for sasl_conn_t instances, for use with gscoped_ptr after calling sasl_*_new() +struct SaslDeleter { + inline void operator()(sasl_conn_t* conn) { + sasl_dispose(&conn); + } +}; + +// Internals exposed in the header for test purposes. 
+namespace internal { +void SaslSetMutex(); +} // namespace internal + +} // namespace rpc +} // namespace kudu + +#endif http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/sasl_helper.cc ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/sasl_helper.cc b/be/src/kudu/rpc/sasl_helper.cc new file mode 100644 index 0000000..765118e --- /dev/null +++ b/be/src/kudu/rpc/sasl_helper.cc @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "kudu/rpc/sasl_helper.h" + +#include <cstring> +#include <ostream> +#include <set> +#include <string> + +#include <glog/logging.h> +#include <sasl/sasl.h> + +#include "kudu/gutil/map-util.h" +#include "kudu/gutil/port.h" +#include "kudu/gutil/strings/join.h" +#include "kudu/gutil/strings/substitute.h" +#include "kudu/rpc/constants.h" +#include "kudu/rpc/rpc_header.pb.h" +#include "kudu/rpc/sasl_common.h" +#include "kudu/util/slice.h" +#include "kudu/util/status.h" + +using std::string; + +namespace kudu { +namespace rpc { + +SaslHelper::SaslHelper(PeerType peer_type) + : peer_type_(peer_type), + global_mechs_(SaslListAvailableMechs()), + plain_enabled_(false), + gssapi_enabled_(false) { + tag_ = (peer_type_ == SERVER) ? "Server" : "Client"; +} + +void SaslHelper::set_server_fqdn(const string& domain_name) { + server_fqdn_ = domain_name; +} +const char* SaslHelper::server_fqdn() const { + return server_fqdn_.empty() ? nullptr : server_fqdn_.c_str(); +} + +const char* SaslHelper::EnabledMechsString() const { + enabled_mechs_string_ = JoinMapped(enabled_mechs_, SaslMechanism::name_of, " "); + return enabled_mechs_string_.c_str(); +} + +int SaslHelper::GetOptionCb(const char* plugin_name, const char* option, + const char** result, unsigned* len) { + DVLOG(4) << tag_ << ": GetOption Callback called. "; + DVLOG(4) << tag_ << ": GetOption Plugin name: " + << (plugin_name == nullptr ? 
"NULL" : plugin_name); + DVLOG(4) << tag_ << ": GetOption Option name: " << option; + + if (PREDICT_FALSE(result == nullptr)) { + LOG(DFATAL) << tag_ << ": SASL Library passed NULL result out-param to GetOption callback!"; + return SASL_BADPARAM; + } + + if (plugin_name == nullptr) { + // SASL library option, not a plugin option + if (strcmp(option, "mech_list") == 0) { + *result = EnabledMechsString(); + if (len != nullptr) *len = strlen(*result); + VLOG(4) << tag_ << ": Enabled mech list: " << *result; + return SASL_OK; + } + VLOG(4) << tag_ << ": GetOptionCb: Unknown library option: " << option; + } else { + VLOG(4) << tag_ << ": GetOptionCb: Unknown plugin: " << plugin_name; + } + return SASL_FAIL; +} + +Status SaslHelper::EnablePlain() { + RETURN_NOT_OK(EnableMechanism(SaslMechanism::PLAIN)); + plain_enabled_ = true; + return Status::OK(); +} + +Status SaslHelper::EnableGSSAPI() { + RETURN_NOT_OK(EnableMechanism(SaslMechanism::GSSAPI)); + gssapi_enabled_ = true; + return Status::OK(); +} + +Status SaslHelper::EnableMechanism(SaslMechanism::Type mech) { + if (PREDICT_FALSE(!ContainsKey(global_mechs_, mech))) { + return Status::InvalidArgument("unable to find SASL plugin", SaslMechanism::name_of(mech)); + } + enabled_mechs_.insert(mech); + return Status::OK(); +} + +bool SaslHelper::IsPlainEnabled() const { + return plain_enabled_; +} + +Status SaslHelper::CheckNegotiateCallId(int32_t call_id) const { + if (call_id != kNegotiateCallId) { + Status s = Status::IllegalState(strings::Substitute( + "Received illegal call-id during negotiation; expected: $0, received: $1", + kNegotiateCallId, call_id)); + LOG(DFATAL) << tag_ << ": " << s.ToString(); + return s; + } + return Status::OK(); +} + +Status SaslHelper::ParseNegotiatePB(const Slice& param_buf, NegotiatePB* msg) { + if (!msg->ParseFromArray(param_buf.data(), param_buf.size())) { + return Status::IOError(tag_ + ": Invalid SASL message, missing fields", + msg->InitializationErrorString()); + } + return 
Status::OK(); +} + +} // namespace rpc +} // namespace kudu http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/sasl_helper.h ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/sasl_helper.h b/be/src/kudu/rpc/sasl_helper.h new file mode 100644 index 0000000..aa0c8bf --- /dev/null +++ b/be/src/kudu/rpc/sasl_helper.h @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef KUDU_RPC_SASL_HELPER_H +#define KUDU_RPC_SASL_HELPER_H + +#include <cstdint> +#include <set> +#include <string> + +#include "kudu/gutil/macros.h" +#include "kudu/rpc/sasl_common.h" + +namespace kudu { + +class Slice; +class Status; + +namespace rpc { + +class NegotiatePB; + +// Helper class which contains functionality that is common to client and server +// SASL negotiations. Most of these methods are convenience methods for +// interacting with the libsasl2 library. +class SaslHelper { + public: + enum PeerType { + CLIENT, + SERVER + }; + + explicit SaslHelper(PeerType peer_type); + ~SaslHelper() = default; + + // Specify the fully-qualified domain name of the remote server. 
+ void set_server_fqdn(const std::string& domain_name); + const char* server_fqdn() const; + + // Globally-registered available SASL plugins. + const std::set<SaslMechanism::Type>& GlobalMechs() const { + return global_mechs_; + } + + // Helper functions for managing the list of active SASL mechanisms. + const std::set<SaslMechanism::Type>& EnabledMechs() const { + return enabled_mechs_; + } + + // Implements the client_mech_list / mech_list callbacks. + int GetOptionCb(const char* plugin_name, const char* option, const char** result, unsigned* len); + + // Enable the PLAIN SASL mechanism. + Status EnablePlain(); + + // Enable the GSSAPI (Kerberos) mechanism. + Status EnableGSSAPI(); + + // Check for the PLAIN SASL mechanism. + bool IsPlainEnabled() const; + + // Sanity check that the call ID is the negotiation call ID. + // Logs DFATAL if call_id does not match. + Status CheckNegotiateCallId(int32_t call_id) const; + + // Parse msg from the given Slice. + Status ParseNegotiatePB(const Slice& param_buf, NegotiatePB* msg); + + private: + Status EnableMechanism(SaslMechanism::Type mech); + + // Returns space-delimited local mechanism list string suitable for passing + // to libsasl2, such as via "mech_list" callbacks. + // The returned pointer is valid only until the next call to EnabledMechsString(). + const char* EnabledMechsString() const; + + std::string server_fqdn_; + + // Authentication types and data. + const PeerType peer_type_; + std::string tag_; + std::set<SaslMechanism::Type> global_mechs_; // Cache of global mechanisms. + std::set<SaslMechanism::Type> enabled_mechs_; // Active mechanisms. + mutable std::string enabled_mechs_string_; // Mechanism list string returned by callbacks. 
+ + bool plain_enabled_; + bool gssapi_enabled_; + + DISALLOW_COPY_AND_ASSIGN(SaslHelper); +}; + +} // namespace rpc +} // namespace kudu + +#endif // KUDU_RPC_SASL_HELPER_H http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/serialization.cc ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/serialization.cc b/be/src/kudu/rpc/serialization.cc new file mode 100644 index 0000000..473a817 --- /dev/null +++ b/be/src/kudu/rpc/serialization.cc @@ -0,0 +1,223 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "kudu/rpc/serialization.h" + +#include <limits> +#include <ostream> +#include <string> + +#include <gflags/gflags_declare.h> +#include <glog/logging.h> +#include <google/protobuf/message_lite.h> +#include <google/protobuf/io/coded_stream.h> + +#include "kudu/gutil/endian.h" +#include "kudu/gutil/port.h" +#include "kudu/gutil/stringprintf.h" +#include "kudu/gutil/strings/substitute.h" +#include "kudu/rpc/constants.h" +#include "kudu/util/faststring.h" +#include "kudu/util/logging.h" +#include "kudu/util/slice.h" +#include "kudu/util/status.h" + +DECLARE_int64(rpc_max_message_size); + +using google::protobuf::MessageLite; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::CodedOutputStream; +using strings::Substitute; + +namespace kudu { +namespace rpc { +namespace serialization { + +enum { + kHeaderPosVersion = 0, + kHeaderPosServiceClass = 1, + kHeaderPosAuthProto = 2 +}; + +void SerializeMessage(const MessageLite& message, faststring* param_buf, + int additional_size, bool use_cached_size) { + DCHECK_GE(additional_size, 0); + int pb_size = use_cached_size ? message.GetCachedSize() : message.ByteSize(); + DCHECK_EQ(message.ByteSize(), pb_size); + // Use 8-byte integers to avoid overflowing when additional_size approaches INT_MAX. + int64_t recorded_size = static_cast<int64_t>(pb_size) + + static_cast<int64_t>(additional_size); + int64_t size_with_delim = static_cast<int64_t>(pb_size) + + static_cast<int64_t>(CodedOutputStream::VarintSize32(recorded_size)); + int64_t total_size = size_with_delim + static_cast<int64_t>(additional_size); + // The message format relies on an unsigned 32-bit integer to express the size, so + // the message must not exceed this size. Since additional_size is limited to INT_MAX, + // this is a safe limitation. 
+ CHECK_LE(total_size, std::numeric_limits<uint32_t>::max()); + + if (total_size > FLAGS_rpc_max_message_size) { + LOG(WARNING) << Substitute("Serialized $0 ($1 bytes) is larger than the maximum configured " + "RPC message size ($2 bytes). " + "Sending anyway, but peer may reject the data.", + message.GetTypeName(), total_size, FLAGS_rpc_max_message_size); + } + + param_buf->resize(size_with_delim); + uint8_t* dst = param_buf->data(); + dst = CodedOutputStream::WriteVarint32ToArray(recorded_size, dst); + dst = message.SerializeWithCachedSizesToArray(dst); + CHECK_EQ(dst, param_buf->data() + size_with_delim); +} + +void SerializeHeader(const MessageLite& header, + size_t param_len, + faststring* header_buf) { + + CHECK(header.IsInitialized()) + << "RPC header missing fields: " << header.InitializationErrorString(); + + // Compute all the lengths for the packet. + size_t header_pb_len = header.ByteSize(); + size_t header_tot_len = kMsgLengthPrefixLength // Int prefix for the total length. + + CodedOutputStream::VarintSize32(header_pb_len) // Varint delimiter for header PB. + + header_pb_len; // Length for the header PB itself. + size_t total_size = header_tot_len + param_len; + + header_buf->resize(header_tot_len); + uint8_t* dst = header_buf->data(); + + // 1. The length for the whole request, not including the 4-byte + // length prefix. + NetworkByteOrder::Store32(dst, total_size - kMsgLengthPrefixLength); + dst += sizeof(uint32_t); + + // 2. The varint-prefixed RequestHeader PB + dst = CodedOutputStream::WriteVarint32ToArray(header_pb_len, dst); + dst = header.SerializeWithCachedSizesToArray(dst); + + // We should have used the whole buffer we allocated. 
+ CHECK_EQ(dst, header_buf->data() + header_tot_len); +} + +Status ParseMessage(const Slice& buf, + MessageLite* parsed_header, + Slice* parsed_main_message) { + + // First grab the total length + if (PREDICT_FALSE(buf.size() < kMsgLengthPrefixLength)) { + return Status::Corruption("Invalid packet: not enough bytes for length header", + KUDU_REDACT(buf.ToDebugString())); + } + + uint32_t total_len = NetworkByteOrder::Load32(buf.data()); + DCHECK_EQ(total_len, buf.size() - kMsgLengthPrefixLength) + << "Got mis-sized buffer: " << KUDU_REDACT(buf.ToDebugString()); + + if (total_len > std::numeric_limits<int32_t>::max()) { + return Status::Corruption(Substitute("Invalid packet: message had a length of $0, " + "but we only support messages up to $1 bytes\n", + total_len, std::numeric_limits<int32_t>::max())); + } + + CodedInputStream in(buf.data(), buf.size()); + // Protobuf enforces a 64MB total bytes limit on CodedInputStream by default. + // Override this default with the actual size of the buffer to allow messages + // larger than 64MB. 
+ in.SetTotalBytesLimit(buf.size(), -1); + in.Skip(kMsgLengthPrefixLength); + + uint32_t header_len; + if (PREDICT_FALSE(!in.ReadVarint32(&header_len))) { + return Status::Corruption("Invalid packet: missing header delimiter", + KUDU_REDACT(buf.ToDebugString())); + } + + CodedInputStream::Limit l; + l = in.PushLimit(header_len); + if (PREDICT_FALSE(!parsed_header->ParseFromCodedStream(&in))) { + return Status::Corruption("Invalid packet: header too short", + KUDU_REDACT(buf.ToDebugString())); + } + in.PopLimit(l); + + uint32_t main_msg_len; + if (PREDICT_FALSE(!in.ReadVarint32(&main_msg_len))) { + return Status::Corruption("Invalid packet: missing main msg length", + KUDU_REDACT(buf.ToDebugString())); + } + + if (PREDICT_FALSE(!in.Skip(main_msg_len))) { + return Status::Corruption( + StringPrintf("Invalid packet: data too short, expected %d byte main_msg", main_msg_len), + KUDU_REDACT(buf.ToDebugString())); + } + + if (PREDICT_FALSE(in.BytesUntilLimit() > 0)) { + return Status::Corruption( + StringPrintf("Invalid packet: %d extra bytes at end of packet", in.BytesUntilLimit()), + KUDU_REDACT(buf.ToDebugString())); + } + + *parsed_main_message = Slice(buf.data() + buf.size() - main_msg_len, + main_msg_len); + return Status::OK(); +} + +void SerializeConnHeader(uint8_t* buf) { + memcpy(reinterpret_cast<char *>(buf), kMagicNumber, kMagicNumberLength); + buf += kMagicNumberLength; + buf[kHeaderPosVersion] = kCurrentRpcVersion; + buf[kHeaderPosServiceClass] = 0; // TODO: implement + buf[kHeaderPosAuthProto] = 0; // TODO: implement +} + +// validate the entire rpc header (magic number + flags) +Status ValidateConnHeader(const Slice& slice) { + DCHECK_EQ(kMagicNumberLength + kHeaderFlagsLength, slice.size()) + << "Invalid RPC header length"; + + // validate actual magic + if (!slice.starts_with(kMagicNumber)) { + if (slice.starts_with("GET ") || + slice.starts_with("POST") || + slice.starts_with("HEAD")) { + return Status::InvalidArgument("invalid negotation, appears to be 
an HTTP client on " + "the RPC port"); + } + return Status::InvalidArgument("connection must begin with magic number", kMagicNumber); + } + + const uint8_t *data = slice.data(); + data += kMagicNumberLength; + + // validate version + if (data[kHeaderPosVersion] != kCurrentRpcVersion) { + return Status::InvalidArgument("Unsupported RPC version", + StringPrintf("Received: %d, Supported: %d", + data[kHeaderPosVersion], kCurrentRpcVersion)); + } + + // TODO: validate additional header flags: + // RPC_SERVICE_CLASS + // RPC_AUTH_PROTOCOL + + return Status::OK(); +} + +} // namespace serialization +} // namespace rpc +} // namespace kudu