http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/c7db60aa/be/src/kudu/rpc/sasl_common.cc ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/sasl_common.cc b/be/src/kudu/rpc/sasl_common.cc new file mode 100644 index 0000000..9f14413 --- /dev/null +++ b/be/src/kudu/rpc/sasl_common.cc @@ -0,0 +1,459 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "kudu/rpc/sasl_common.h" + +#include <string.h> + +#include <algorithm> +#include <limits> +#include <mutex> +#include <string> + +#include <boost/algorithm/string/predicate.hpp> +#include <gflags/gflags.h> +#include <glog/logging.h> +#include <regex.h> +#include <sasl/sasl.h> +#include <sasl/saslplug.h> + +#include "kudu/gutil/macros.h" +#include "kudu/gutil/once.h" +#include "kudu/gutil/stringprintf.h" +#include "kudu/rpc/constants.h" +#include "kudu/util/flag_tags.h" +#include "kudu/util/mutex.h" +#include "kudu/util/net/sockaddr.h" +#include "kudu/util/rw_mutex.h" +#include "kudu/security/init.h" + +using std::set; + +DECLARE_string(keytab_file); + +namespace kudu { +namespace rpc { + +const char* const kSaslMechPlain = "PLAIN"; +const char* const kSaslMechGSSAPI = "GSSAPI"; +extern const size_t kSaslMaxOutBufLen = 1024; + +// See WrapSaslCall(). +static __thread string* g_auth_failure_capture = nullptr; + +// Determine whether initialization was ever called +static Status sasl_init_status = Status::OK(); +static bool sasl_is_initialized = false; + +// If true, then we expect someone else has initialized SASL. +static bool g_disable_sasl_init = false; + +// Output Sasl messages. +// context: not used. +// level: logging level. +// message: message to output; +static int SaslLogCallback(void* context, int level, const char* message) { + + if (message == nullptr) return SASL_BADPARAM; + + switch (level) { + case SASL_LOG_NONE: + break; + + case SASL_LOG_ERR: + LOG(ERROR) << "SASL: " << message; + break; + + case SASL_LOG_WARN: + LOG(WARNING) << "SASL: " << message; + break; + + case SASL_LOG_NOTE: + LOG(INFO) << "SASL: " << message; + break; + + case SASL_LOG_FAIL: + // Capture authentication failures in a thread-local to be picked up + // by WrapSaslCall() below. 
+ VLOG(1) << "SASL: " << message; + if (g_auth_failure_capture) { + *g_auth_failure_capture = message; + } + break; + + case SASL_LOG_DEBUG: + VLOG(1) << "SASL: " << message; + break; + + case SASL_LOG_TRACE: + case SASL_LOG_PASS: + VLOG(3) << "SASL: " << message; + break; + } + + return SASL_OK; +} + +// Get Sasl option. +// context: not used +// plugin_name: name of plugin for which an option is being requested. +// option: option requested +// result: set to result which persists until next getopt in same thread, +// unchanged if option not found +// len: length of the result +// Return SASL_FAIL if the option is not handled, this does not fail the handshake. +static int SaslGetOption(void* context, const char* plugin_name, const char* option, + const char** result, unsigned* len) { + // Handle Sasl Library options + if (plugin_name == nullptr) { + // Return the logging level that we want the sasl library to use. + if (strcmp("log_level", option) == 0) { + int level = SASL_LOG_NOTE; + if (VLOG_IS_ON(1)) { + level = SASL_LOG_DEBUG; + } else if (VLOG_IS_ON(3)) { + level = SASL_LOG_TRACE; + } + // The library's contract for this method is that the caller gets to keep + // the returned buffer until the next call by the same thread, so we use a + // threadlocal for the buffer. + static __thread char buf[4]; + snprintf(buf, arraysize(buf), "%d", level); + *result = buf; + if (len != nullptr) *len = strlen(buf); + return SASL_OK; + } + // Options can default so don't complain. + VLOG(4) << "SaslGetOption: Unknown library option: " << option; + return SASL_FAIL; + } + VLOG(4) << "SaslGetOption: Unknown plugin: " << plugin_name; + return SASL_FAIL; +} + +// Array of callbacks for the sasl library. 
+static sasl_callback_t callbacks[] = { + { SASL_CB_LOG, reinterpret_cast<int (*)()>(&SaslLogCallback), nullptr }, + { SASL_CB_GETOPT, reinterpret_cast<int (*)()>(&SaslGetOption), nullptr }, + { SASL_CB_LIST_END, nullptr, nullptr } + // TODO(todd): provide a CANON_USER callback? This is necessary if we want + // to support some kind of auth-to-local mapping of Kerberos principals + // to local usernames. See Impala's implementation for inspiration. +}; + + +// SASL requires mutexes for thread safety, but doesn't implement +// them itself. So, we have to hook them up to our mutex implementation. +static void* SaslMutexAlloc() { + return static_cast<void*>(new Mutex()); +} +static void SaslMutexFree(void* m) { + delete static_cast<Mutex*>(m); +} +static int SaslMutexLock(void* m) { + static_cast<Mutex*>(m)->lock(); + return 0; // indicates success. +} +static int SaslMutexUnlock(void* m) { + static_cast<Mutex*>(m)->unlock(); + return 0; // indicates success. +} + +namespace internal { +void SaslSetMutex() { + sasl_set_mutex(&SaslMutexAlloc, &SaslMutexLock, &SaslMutexUnlock, &SaslMutexFree); +} +} // namespace internal + +// Sasl initialization detection methods. The OS X SASL library doesn't define +// the sasl_global_utils symbol, so we have to use less robust methods of +// detection. +#if defined(__APPLE__) +static bool SaslIsInitialized() { + return sasl_global_listmech() != nullptr; +} +static bool SaslMutexImplementationProvided() { + return SaslIsInitialized(); +} +#else + +// This symbol is exported by the SASL library but not defined +// in the headers. It's marked as an API in the library source, +// so seems safe to rely on. 
+extern "C" sasl_utils_t* sasl_global_utils; +static bool SaslIsInitialized() { + return sasl_global_utils != nullptr; +} +static bool SaslMutexImplementationProvided() { + if (!SaslIsInitialized()) return false; + void* m = sasl_global_utils->mutex_alloc(); + sasl_global_utils->mutex_free(m); + // The default implementation of mutex_alloc just returns the constant pointer 0x1. + // This is a bit of an ugly heuristic, but seems unlikely that anyone would ever + // provide a valid implementation that returns an invalid pointer value. + return m != reinterpret_cast<void*>(1); +} +#endif + +// Actually perform the initialization for the SASL subsystem. +// Meant to be called via GoogleOnceInit(). +static void DoSaslInit() { + VLOG(3) << "Initializing SASL library"; + + bool sasl_initialized = SaslIsInitialized(); + if (sasl_initialized && !g_disable_sasl_init) { + LOG(WARNING) << "SASL was initialized prior to Kudu's initialization. Skipping " + << "initialization. Call kudu::client::DisableSaslInitialization() " + << "to suppress this message."; + g_disable_sasl_init = true; + } + + if (g_disable_sasl_init) { + if (!sasl_initialized) { + sasl_init_status = Status::RuntimeError( + "SASL initialization was disabled, but SASL was not externally initialized."); + return; + } + if (!SaslMutexImplementationProvided()) { + LOG(WARNING) + << "SASL appears to be initialized by code outside of Kudu " + << "but was not provided with a mutex implementation! 
If " + << "manually initializing SASL, use sasl_set_mutex(3)."; + } + return; + } + internal::SaslSetMutex(); + int result = sasl_client_init(&callbacks[0]); + if (result != SASL_OK) { + sasl_init_status = Status::RuntimeError("Could not initialize SASL client", + sasl_errstring(result, nullptr, nullptr)); + return; + } + + result = sasl_server_init(&callbacks[0], kSaslAppName); + if (result != SASL_OK) { + sasl_init_status = Status::RuntimeError("Could not initialize SASL server", + sasl_errstring(result, nullptr, nullptr)); + return; + } + + sasl_is_initialized = true; +} + +Status DisableSaslInitialization() { + if (g_disable_sasl_init) return Status::OK(); + if (sasl_is_initialized) { + return Status::IllegalState("SASL already initialized. Initialization can only be disabled " + "before first usage."); + } + g_disable_sasl_init = true; + return Status::OK(); +} + +Status SaslInit() { + // Only execute SASL initialization once + static GoogleOnceType once = GOOGLE_ONCE_INIT; + GoogleOnceInit(&once, &DoSaslInit); + return sasl_init_status; +} + +static string CleanSaslError(const string& err) { + // In the case of GSS failures, we often get the actual error message + // buried inside a bunch of generic cruft. Use a regexp to extract it + // out. Note that we avoid std::regex because it appears to be broken + // with older libstdcxx. + static regex_t re; + static std::once_flag once; + +#if defined(__APPLE__) + static const char* kGssapiPattern = "GSSAPI Error: Miscellaneous failure \\(see text \\((.+)\\)"; +#else + static const char* kGssapiPattern = "Unspecified GSS failure. 
+" + "Minor code may provide more information +" + "\\((.+)\\)"; +#endif + + std::call_once(once, []{ CHECK_EQ(0, regcomp(&re, kGssapiPattern, REG_EXTENDED)); }); + regmatch_t matches[2]; + if (regexec(&re, err.c_str(), arraysize(matches), matches, 0) == 0) { + return err.substr(matches[1].rm_so, matches[1].rm_eo - matches[1].rm_so); + } + return err; +} + +static string SaslErrDesc(int status, sasl_conn_t* conn) { + string err; + if (conn != nullptr) { + err = sasl_errdetail(conn); + } else { + err = sasl_errstring(status, nullptr, nullptr); + } + + return CleanSaslError(err); +} + +Status WrapSaslCall(sasl_conn_t* conn, const std::function<int()>& call) { + // In many cases, the GSSAPI SASL plugin will generate a nice error + // message as a message logged at SASL_LOG_FAIL logging level, but then + // return a useless one in sasl_errstring(). So, we set a global thread-local + // variable to capture any auth failure log message while we make the + // call into the library. + // + // The thread-local thing is a bit of a hack, but the logging callback + // is set globally rather than on a per-connection basis. + string err; + g_auth_failure_capture = &err; + + // Take the 'kerberos_reinit_lock' here to avoid a possible race with ticket renewal. + bool kerberos_supported = !FLAGS_keytab_file.empty(); + if (kerberos_supported) kudu::security::KerberosReinitLock()->ReadLock(); + int rc = call(); + if (kerberos_supported) kudu::security::KerberosReinitLock()->ReadUnlock(); + g_auth_failure_capture = nullptr; + + switch (rc) { + case SASL_OK: + return Status::OK(); + case SASL_CONTINUE: + return Status::Incomplete(""); + case SASL_FAIL: // Generic failure (encompasses missing krb5 credentials). + case SASL_BADAUTH: // Authentication failure. + case SASL_BADMAC: // Decode failure. + case SASL_NOAUTHZ: // Authorization failure. + case SASL_NOUSER: // User not found. + case SASL_WRONGMECH: // Server doesn't support requested mechanism. 
+ case SASL_BADSERV: { // Server failed mutual authentication. + if (err.empty()) { + err = SaslErrDesc(rc, conn); + } else { + err = CleanSaslError(err); + } + return Status::NotAuthorized(err); + } + default: + return Status::RuntimeError(SaslErrDesc(rc, conn)); + } +} + +Status SaslEncode(sasl_conn_t* conn, const std::string& plaintext, std::string* encoded) { + size_t offset = 0; + + // The SASL library can only encode up to a maximum amount at a time, so we + // have to call encode multiple times if our input is larger than this max. + while (offset < plaintext.size()) { + const char* out; + unsigned out_len; + size_t len = std::min(kSaslMaxOutBufLen, plaintext.size() - offset); + + RETURN_NOT_OK(WrapSaslCall(conn, [&]() { + return sasl_encode(conn, plaintext.data() + offset, len, &out, &out_len); + })); + + encoded->append(out, out_len); + offset += len; + } + + return Status::OK(); +} + +Status SaslDecode(sasl_conn_t* conn, const string& encoded, string* plaintext) { + size_t offset = 0; + + // The SASL library can only decode up to a maximum amount at a time, so we + // have to call decode multiple times if our input is larger than this max. + while (offset < encoded.size()) { + const char* out; + unsigned out_len; + size_t len = std::min(kSaslMaxOutBufLen, encoded.size() - offset); + + RETURN_NOT_OK(WrapSaslCall(conn, [&]() { + return sasl_decode(conn, encoded.data() + offset, len, &out, &out_len); + })); + + plaintext->append(out, out_len); + offset += len; + } + + return Status::OK(); +} + +string SaslIpPortString(const Sockaddr& addr) { + string addr_str = addr.ToString(); + size_t colon_pos = addr_str.find(':'); + if (colon_pos != string::npos) { + addr_str[colon_pos] = ';'; + } + return addr_str; +} + +set<SaslMechanism::Type> SaslListAvailableMechs() { + set<SaslMechanism::Type> mechs; + + // Array of NULL-terminated strings. Array terminated with NULL. 
+ for (const char** mech_strings = sasl_global_listmech(); + mech_strings != nullptr && *mech_strings != nullptr; + mech_strings++) { + auto mech = SaslMechanism::value_of(*mech_strings); + if (mech != SaslMechanism::INVALID) { + mechs.insert(mech); + } + } + return mechs; +} + +sasl_callback_t SaslBuildCallback(int id, int (*proc)(void), void* context) { + sasl_callback_t callback; + callback.id = id; + callback.proc = proc; + callback.context = context; + return callback; +} + +Status EnableIntegrityProtection(sasl_conn_t* sasl_conn) { + sasl_security_properties_t sec_props; + memset(&sec_props, 0, sizeof(sec_props)); + sec_props.min_ssf = 1; + sec_props.max_ssf = std::numeric_limits<sasl_ssf_t>::max(); + sec_props.maxbufsize = kSaslMaxOutBufLen; + + RETURN_NOT_OK_PREPEND(WrapSaslCall(sasl_conn, [&] () { + return sasl_setprop(sasl_conn, SASL_SEC_PROPS, &sec_props); + }), "failed to set SASL security properties"); + return Status::OK(); +} + +SaslMechanism::Type SaslMechanism::value_of(const string& mech) { + if (boost::iequals(mech, "PLAIN")) { + return PLAIN; + } + if (boost::iequals(mech, "GSSAPI")) { + return GSSAPI; + } + return INVALID; +} + +const char* SaslMechanism::name_of(SaslMechanism::Type val) { + switch (val) { + case PLAIN: return "PLAIN"; + case GSSAPI: return "GSSAPI"; + default: + return "INVALID"; + } +} + +} // namespace rpc +} // namespace kudu
http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/c7db60aa/be/src/kudu/rpc/sasl_common.h ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/sasl_common.h b/be/src/kudu/rpc/sasl_common.h new file mode 100644 index 0000000..6022f9e --- /dev/null +++ b/be/src/kudu/rpc/sasl_common.h @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef KUDU_RPC_SASL_COMMON_H +#define KUDU_RPC_SASL_COMMON_H + +#include <stdint.h> // Required for sasl/sasl.h + +#include <string> +#include <set> + +#include <sasl/sasl.h> + +#include "kudu/util/status.h" + +namespace kudu { + +class Sockaddr; + +namespace rpc { + +using std::string; + +// Constants +extern const char* const kSaslMechPlain; +extern const char* const kSaslMechGSSAPI; +extern const size_t kSaslMaxOutBufLen; + +struct SaslMechanism { + enum Type { + INVALID, + PLAIN, + GSSAPI + }; + static Type value_of(const std::string& mech); + static const char* name_of(Type val); +}; + +// Initialize the SASL library. +// appname: Name of the application for logging messages & sasl plugin configuration. +// Note that this string must remain allocated for the lifetime of the program. 
+// This function must be called before using SASL. +// If the library initializes without error, calling more than once has no effect. +// +// Some SASL plugins take time to initialize random number generators and other things, +// so the first time this function is invoked it may execute for several seconds. +// After that, it should be very fast. This function should be invoked as early as possible +// in the application lifetime to avoid SASL initialization taking place in a +// performance-critical section. +// +// This function is thread safe and uses a static lock. +// This function should NOT be called during static initialization. +Status SaslInit() WARN_UNUSED_RESULT; + +// Disable Kudu's initialization of SASL. See equivalent method in client.h. +Status DisableSaslInitialization() WARN_UNUSED_RESULT; + +// Wrap a call into the SASL library. 'call' should be a lambda which +// returns a SASL error code. +// +// The result is translated into a Status as follows: +// +// SASL_OK: Status::OK() +// SASL_CONTINUE: Status::Incomplete() +// otherwise: Status::NotAuthorized() +// +// The Status message is beautified to be more user-friendly compared +// to the underlying sasl_errdetails() call. +Status WrapSaslCall(sasl_conn_t* conn, const std::function<int()>& call) WARN_UNUSED_RESULT; + +// Return <ip>;<port> string formatted for SASL library use. +string SaslIpPortString(const Sockaddr& addr); + +// Return available plugin mechanisms for the given connection. +std::set<SaslMechanism::Type> SaslListAvailableMechs(); + +// Initialize and return a libsasl2 callback data structure based on the passed args. +// id: A SASL callback identifier (e.g., SASL_CB_GETOPT). +// proc: A C-style callback with appropriate signature based on the callback id, or NULL. +// context: An object to pass to the callback as the context pointer, or NULL. 
+sasl_callback_t SaslBuildCallback(int id, int (*proc)(void), void* context); + +// Require integrity protection on the SASL connection. Should be called before +// initiating the SASL negotiation. +Status EnableIntegrityProtection(sasl_conn_t* sasl_conn) WARN_UNUSED_RESULT; + +// Encode the provided data and append it to 'encoded'. +Status SaslEncode(sasl_conn_t* conn, + const std::string& plaintext, + std::string* encoded) WARN_UNUSED_RESULT; + +// Decode the provided SASL-encoded data and append it to 'plaintext'. +Status SaslDecode(sasl_conn_t* conn, + const std::string& encoded, + std::string* plaintext) WARN_UNUSED_RESULT; + +// Deleter for sasl_conn_t instances, for use with gscoped_ptr after calling sasl_*_new() +struct SaslDeleter { + inline void operator()(sasl_conn_t* conn) { + sasl_dispose(&conn); + } +}; + +// Internals exposed in the header for test purposes. +namespace internal { +void SaslSetMutex(); +} // namespace internal + +} // namespace rpc +} // namespace kudu + +#endif http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/c7db60aa/be/src/kudu/rpc/sasl_helper.cc ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/sasl_helper.cc b/be/src/kudu/rpc/sasl_helper.cc new file mode 100644 index 0000000..53f9d08 --- /dev/null +++ b/be/src/kudu/rpc/sasl_helper.cc @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "kudu/rpc/sasl_helper.h" + +#include <string> + +#include <glog/logging.h> +#include <google/protobuf/message_lite.h> + +#include "kudu/gutil/macros.h" +#include "kudu/gutil/map-util.h" +#include "kudu/gutil/port.h" +#include "kudu/gutil/strings/join.h" +#include "kudu/gutil/strings/substitute.h" +#include "kudu/rpc/constants.h" +#include "kudu/rpc/rpc_header.pb.h" +#include "kudu/rpc/sasl_common.h" +#include "kudu/rpc/serialization.h" +#include "kudu/util/status.h" + +using std::string; + +namespace kudu { +namespace rpc { + +using google::protobuf::MessageLite; + +SaslHelper::SaslHelper(PeerType peer_type) + : peer_type_(peer_type), + global_mechs_(SaslListAvailableMechs()), + plain_enabled_(false), + gssapi_enabled_(false) { + tag_ = (peer_type_ == SERVER) ? "Server" : "Client"; +} + +void SaslHelper::set_server_fqdn(const string& domain_name) { + server_fqdn_ = domain_name; +} +const char* SaslHelper::server_fqdn() const { + return server_fqdn_.empty() ? nullptr : server_fqdn_.c_str(); +} + +const char* SaslHelper::EnabledMechsString() const { + enabled_mechs_string_ = JoinMapped(enabled_mechs_, SaslMechanism::name_of, " "); + return enabled_mechs_string_.c_str(); +} + +int SaslHelper::GetOptionCb(const char* plugin_name, const char* option, + const char** result, unsigned* len) { + DVLOG(4) << tag_ << ": GetOption Callback called. "; + DVLOG(4) << tag_ << ": GetOption Plugin name: " + << (plugin_name == nullptr ? 
"NULL" : plugin_name); + DVLOG(4) << tag_ << ": GetOption Option name: " << option; + + if (PREDICT_FALSE(result == nullptr)) { + LOG(DFATAL) << tag_ << ": SASL Library passed NULL result out-param to GetOption callback!"; + return SASL_BADPARAM; + } + + if (plugin_name == nullptr) { + // SASL library option, not a plugin option + if (strcmp(option, "mech_list") == 0) { + *result = EnabledMechsString(); + if (len != nullptr) *len = strlen(*result); + VLOG(4) << tag_ << ": Enabled mech list: " << *result; + return SASL_OK; + } + VLOG(4) << tag_ << ": GetOptionCb: Unknown library option: " << option; + } else { + VLOG(4) << tag_ << ": GetOptionCb: Unknown plugin: " << plugin_name; + } + return SASL_FAIL; +} + +Status SaslHelper::EnablePlain() { + RETURN_NOT_OK(EnableMechanism(SaslMechanism::PLAIN)); + plain_enabled_ = true; + return Status::OK(); +} + +Status SaslHelper::EnableGSSAPI() { + RETURN_NOT_OK(EnableMechanism(SaslMechanism::GSSAPI)); + gssapi_enabled_ = true; + return Status::OK(); +} + +Status SaslHelper::EnableMechanism(SaslMechanism::Type mech) { + if (PREDICT_FALSE(!ContainsKey(global_mechs_, mech))) { + return Status::InvalidArgument("unable to find SASL plugin", SaslMechanism::name_of(mech)); + } + enabled_mechs_.insert(mech); + return Status::OK(); +} + +bool SaslHelper::IsPlainEnabled() const { + return plain_enabled_; +} + +Status SaslHelper::CheckNegotiateCallId(int32_t call_id) const { + if (call_id != kNegotiateCallId) { + Status s = Status::IllegalState(strings::Substitute( + "Received illegal call-id during negotiation; expected: $0, received: $1", + kNegotiateCallId, call_id)); + LOG(DFATAL) << tag_ << ": " << s.ToString(); + return s; + } + return Status::OK(); +} + +Status SaslHelper::ParseNegotiatePB(const Slice& param_buf, NegotiatePB* msg) { + if (!msg->ParseFromArray(param_buf.data(), param_buf.size())) { + return Status::IOError(tag_ + ": Invalid SASL message, missing fields", + msg->InitializationErrorString()); + } + return 
Status::OK(); +} + +} // namespace rpc +} // namespace kudu http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/c7db60aa/be/src/kudu/rpc/sasl_helper.h ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/sasl_helper.h b/be/src/kudu/rpc/sasl_helper.h new file mode 100644 index 0000000..0a3107c --- /dev/null +++ b/be/src/kudu/rpc/sasl_helper.h @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef KUDU_RPC_SASL_HELPER_H +#define KUDU_RPC_SASL_HELPER_H + +#include <set> +#include <string> + +#include <sasl/sasl.h> + +#include "kudu/rpc/sasl_common.h" +#include "kudu/util/status.h" + +namespace kudu { + +class Sockaddr; + +namespace rpc { + +class NegotiatePB; + +// Helper class which contains functionality that is common to client and server +// SASL negotiations. Most of these methods are convenience methods for +// interacting with the libsasl2 library. +class SaslHelper { + public: + enum PeerType { + CLIENT, + SERVER + }; + + explicit SaslHelper(PeerType peer_type); + ~SaslHelper() = default; + + // Specify the fully-qualified domain name of the remote server. 
+ void set_server_fqdn(const std::string& domain_name); + const char* server_fqdn() const; + + // Globally-registered available SASL plugins. + const std::set<SaslMechanism::Type>& GlobalMechs() const { + return global_mechs_; + } + + // Helper functions for managing the list of active SASL mechanisms. + const std::set<SaslMechanism::Type>& EnabledMechs() const { + return enabled_mechs_; + } + + // Implements the client_mech_list / mech_list callbacks. + int GetOptionCb(const char* plugin_name, const char* option, const char** result, unsigned* len); + + // Enable the PLAIN SASL mechanism. + Status EnablePlain(); + + // Enable the GSSAPI (Kerberos) mechanism. + Status EnableGSSAPI(); + + // Check for the PLAIN SASL mechanism. + bool IsPlainEnabled() const; + + // Sanity check that the call ID is the negotiation call ID. + // Logs DFATAL if call_id does not match. + Status CheckNegotiateCallId(int32_t call_id) const; + + // Parse msg from the given Slice. + Status ParseNegotiatePB(const Slice& param_buf, NegotiatePB* msg); + + private: + Status EnableMechanism(SaslMechanism::Type mech); + + // Returns space-delimited local mechanism list string suitable for passing + // to libsasl2, such as via "mech_list" callbacks. + // The returned pointer is valid only until the next call to EnabledMechsString(). + const char* EnabledMechsString() const; + + std::string server_fqdn_; + + // Authentication types and data. + const PeerType peer_type_; + std::string tag_; + std::set<SaslMechanism::Type> global_mechs_; // Cache of global mechanisms. + std::set<SaslMechanism::Type> enabled_mechs_; // Active mechanisms. + mutable std::string enabled_mechs_string_; // Mechanism list string returned by callbacks. 
+ + bool plain_enabled_; + bool gssapi_enabled_; + + DISALLOW_COPY_AND_ASSIGN(SaslHelper); +}; + +} // namespace rpc +} // namespace kudu + +#endif // KUDU_RPC_SASL_HELPER_H http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/c7db60aa/be/src/kudu/rpc/serialization.cc ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/serialization.cc b/be/src/kudu/rpc/serialization.cc new file mode 100644 index 0000000..dbb0fc5 --- /dev/null +++ b/be/src/kudu/rpc/serialization.cc @@ -0,0 +1,199 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "kudu/rpc/serialization.h" + +#include <glog/logging.h> +#include <google/protobuf/message_lite.h> +#include <google/protobuf/io/coded_stream.h> + +#include "kudu/gutil/endian.h" +#include "kudu/gutil/stringprintf.h" +#include "kudu/gutil/strings/substitute.h" +#include "kudu/rpc/constants.h" +#include "kudu/util/faststring.h" +#include "kudu/util/logging.h" +#include "kudu/util/slice.h" +#include "kudu/util/status.h" + +DECLARE_int32(rpc_max_message_size); + +using google::protobuf::MessageLite; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::CodedOutputStream; +using strings::Substitute; + +namespace kudu { +namespace rpc { +namespace serialization { + +enum { + kHeaderPosVersion = 0, + kHeaderPosServiceClass = 1, + kHeaderPosAuthProto = 2 +}; + +void SerializeMessage(const MessageLite& message, faststring* param_buf, + int additional_size, bool use_cached_size) { + int pb_size = use_cached_size ? message.GetCachedSize() : message.ByteSize(); + DCHECK_EQ(message.ByteSize(), pb_size); + int recorded_size = pb_size + additional_size; + int size_with_delim = pb_size + CodedOutputStream::VarintSize32(recorded_size); + int total_size = size_with_delim + additional_size; + + if (total_size > FLAGS_rpc_max_message_size) { + LOG(WARNING) << Substitute("Serialized $0 ($1 bytes) is larger than the maximum configured " + "RPC message size ($2 bytes). 
" + "Sending anyway, but peer may reject the data.", + message.GetTypeName(), total_size, FLAGS_rpc_max_message_size); + } + + param_buf->resize(size_with_delim); + uint8_t* dst = param_buf->data(); + dst = CodedOutputStream::WriteVarint32ToArray(recorded_size, dst); + dst = message.SerializeWithCachedSizesToArray(dst); + CHECK_EQ(dst, param_buf->data() + size_with_delim); +} + +void SerializeHeader(const MessageLite& header, + size_t param_len, + faststring* header_buf) { + + CHECK(header.IsInitialized()) + << "RPC header missing fields: " << header.InitializationErrorString(); + + // Compute all the lengths for the packet. + size_t header_pb_len = header.ByteSize(); + size_t header_tot_len = kMsgLengthPrefixLength // Int prefix for the total length. + + CodedOutputStream::VarintSize32(header_pb_len) // Varint delimiter for header PB. + + header_pb_len; // Length for the header PB itself. + size_t total_size = header_tot_len + param_len; + + header_buf->resize(header_tot_len); + uint8_t* dst = header_buf->data(); + + // 1. The length for the whole request, not including the 4-byte + // length prefix. + NetworkByteOrder::Store32(dst, total_size - kMsgLengthPrefixLength); + dst += sizeof(uint32_t); + + // 2. The varint-prefixed RequestHeader PB + dst = CodedOutputStream::WriteVarint32ToArray(header_pb_len, dst); + dst = header.SerializeWithCachedSizesToArray(dst); + + // We should have used the whole buffer we allocated. 
+  CHECK_EQ(dst, header_buf->data() + header_tot_len);
+}
+
+// Parses a framed RPC message laid out as: a kMsgLengthPrefixLength-byte total
+// length (read with NetworkByteOrder::Load32), a varint-delimited header PB that
+// is parsed into 'parsed_header', and a varint-delimited main message. On success
+// 'parsed_main_message' points into 'buf' (no copy is made), so it is only valid
+// while 'buf' is. Returns Status::Corruption if the buffer is truncated, a
+// delimiter or the header cannot be parsed, or there are trailing bytes.
+Status ParseMessage(const Slice& buf,
+                    MessageLite* parsed_header,
+                    Slice* parsed_main_message) {
+
+  // First grab the total length
+  if (PREDICT_FALSE(buf.size() < kMsgLengthPrefixLength)) {
+    return Status::Corruption("Invalid packet: not enough bytes for length header",
+                              KUDU_REDACT(buf.ToDebugString()));
+  }
+
+  int total_len = NetworkByteOrder::Load32(buf.data());
+  DCHECK_EQ(total_len + kMsgLengthPrefixLength, buf.size())
+    << "Got mis-sized buffer: " << KUDU_REDACT(buf.ToDebugString());
+
+  CodedInputStream in(buf.data(), buf.size());
+  in.Skip(kMsgLengthPrefixLength);
+
+  uint32_t header_len;
+  if (PREDICT_FALSE(!in.ReadVarint32(&header_len))) {
+    return Status::Corruption("Invalid packet: missing header delimiter",
+                              KUDU_REDACT(buf.ToDebugString()));
+  }
+
+  // Restrict the header parse to exactly 'header_len' bytes.
+  CodedInputStream::Limit l;
+  l = in.PushLimit(header_len);
+  if (PREDICT_FALSE(!parsed_header->ParseFromCodedStream(&in))) {
+    return Status::Corruption("Invalid packet: header too short",
+                              KUDU_REDACT(buf.ToDebugString()));
+  }
+  in.PopLimit(l);
+
+  uint32_t main_msg_len;
+  if (PREDICT_FALSE(!in.ReadVarint32(&main_msg_len))) {
+    return Status::Corruption("Invalid packet: missing main msg length",
+                              KUDU_REDACT(buf.ToDebugString()));
+  }
+
+  if (PREDICT_FALSE(!in.Skip(main_msg_len))) {
+    // 'main_msg_len' is an unsigned varint taken straight off the wire, so it is
+    // formatted with %u; %d would render values above INT_MAX as negative.
+    return Status::Corruption(
+        StringPrintf("Invalid packet: data too short, expected %u byte main_msg", main_msg_len),
+        KUDU_REDACT(buf.ToDebugString()));
+  }
+
+  if (PREDICT_FALSE(in.BytesUntilLimit() > 0)) {
+    return Status::Corruption(
+        StringPrintf("Invalid packet: %d extra bytes at end of packet", in.BytesUntilLimit()),
+        KUDU_REDACT(buf.ToDebugString()));
+  }
+
+  // The main message is the final 'main_msg_len' bytes of the buffer, as the
+  // trailing-bytes check above guarantees nothing follows it.
+  *parsed_main_message = Slice(buf.data() + buf.size() - main_msg_len,
+                               main_msg_len);
+  return Status::OK();
+}
+
+void SerializeConnHeader(uint8_t* buf) {
+  memcpy(reinterpret_cast<char *>(buf), kMagicNumber, kMagicNumberLength);
+  buf += kMagicNumberLength;
+  buf[kHeaderPosVersion] =
kCurrentRpcVersion; + buf[kHeaderPosServiceClass] = 0; // TODO: implement + buf[kHeaderPosAuthProto] = 0; // TODO: implement +} + +// validate the entire rpc header (magic number + flags) +Status ValidateConnHeader(const Slice& slice) { + DCHECK_EQ(kMagicNumberLength + kHeaderFlagsLength, slice.size()) + << "Invalid RPC header length"; + + // validate actual magic + if (!slice.starts_with(kMagicNumber)) { + if (slice.starts_with("GET ") || + slice.starts_with("POST") || + slice.starts_with("HEAD")) { + return Status::InvalidArgument("invalid negotation, appears to be an HTTP client on " + "the RPC port"); + } + return Status::InvalidArgument("connection must begin with magic number", kMagicNumber); + } + + const uint8_t *data = slice.data(); + data += kMagicNumberLength; + + // validate version + if (data[kHeaderPosVersion] != kCurrentRpcVersion) { + return Status::InvalidArgument("Unsupported RPC version", + StringPrintf("Received: %d, Supported: %d", + data[kHeaderPosVersion], kCurrentRpcVersion)); + } + + // TODO: validate additional header flags: + // RPC_SERVICE_CLASS + // RPC_AUTH_PROTOCOL + + return Status::OK(); +} + +} // namespace serialization +} // namespace rpc +} // namespace kudu http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/c7db60aa/be/src/kudu/rpc/serialization.h ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/serialization.h b/be/src/kudu/rpc/serialization.h new file mode 100644 index 0000000..26df3a7 --- /dev/null +++ b/be/src/kudu/rpc/serialization.h @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef KUDU_RPC_SERIALIZATION_H +#define KUDU_RPC_SERIALIZATION_H + +#include <inttypes.h> +#include <string.h> + +namespace google { +namespace protobuf { +class MessageLite; +} // namespace protobuf +} // namespace google + +namespace kudu { + +class Status; +class faststring; +class Slice; + +namespace rpc { +namespace serialization { + +// Serialize the request param into a buffer which is allocated by this function. +// Uses the message's cached size by calling MessageLite::GetCachedSize(). +// In : 'message' Protobuf Message to serialize +// 'additional_size' Optional argument which increases the recorded size +// within param_buf. This argument is necessary if there will be +// additional sidecars appended onto the message (that aren't part of +// the protobuf itself). +// 'use_cached_size' Additional optional argument whether to use the cached +// or explicit byte size by calling MessageLite::GetCachedSize() or +// MessageLite::ByteSize(), respectively. +// Out: The faststring 'param_buf' to be populated with the serialized bytes. +// The faststring's length is only determined by the amount that +// needs to be serialized for the protobuf (i.e., no additional space +// is reserved for 'additional_size', which only affects the +// size indicator prefix in 'param_buf'). +void SerializeMessage(const google::protobuf::MessageLite& message, + faststring* param_buf, int additional_size = 0, + bool use_cached_size = false); + +// Serialize the request or response header into a buffer which is allocated +// by this function. 
+// Includes leading 32-bit length of the buffer. +// In: Protobuf Header to serialize, +// Length of the message param following this header in the frame. +// Out: faststring to be populated with the serialized bytes. +void SerializeHeader(const google::protobuf::MessageLite& header, + size_t param_len, + faststring* header_buf); + +// Deserialize the request. +// In: data buffer Slice. +// Out: parsed_header PB initialized, +// parsed_main_message pointing to offset in original buffer containing +// the main payload. +Status ParseMessage(const Slice& buf, + google::protobuf::MessageLite* parsed_header, + Slice* parsed_main_message); + +// Serialize the RPC connection header (magic number + flags). +// buf must have 7 bytes available (kMagicNumberLength + kHeaderFlagsLength). +void SerializeConnHeader(uint8_t* buf); + +// Validate the entire rpc header (magic number + flags). +Status ValidateConnHeader(const Slice& slice); + + +} // namespace serialization +} // namespace rpc +} // namespace kudu +#endif // KUDU_RPC_SERIALIZATION_H http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/c7db60aa/be/src/kudu/rpc/server_negotiation.cc ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/server_negotiation.cc b/be/src/kudu/rpc/server_negotiation.cc new file mode 100644 index 0000000..5e6d070 --- /dev/null +++ b/be/src/kudu/rpc/server_negotiation.cc @@ -0,0 +1,980 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "kudu/rpc/server_negotiation.h" + +#include <limits> +#include <memory> +#include <set> +#include <string> + +#include <gflags/gflags.h> +#include <glog/logging.h> +#include <google/protobuf/message_lite.h> +#include <sasl/sasl.h> + +#include "kudu/gutil/casts.h" +#include "kudu/gutil/endian.h" +#include "kudu/gutil/map-util.h" +#include "kudu/gutil/strings/split.h" +#include "kudu/gutil/strings/substitute.h" +#include "kudu/rpc/blocking_ops.h" +#include "kudu/rpc/constants.h" +#include "kudu/rpc/messenger.h" +#include "kudu/rpc/serialization.h" +#include "kudu/security/cert.h" +#include "kudu/security/crypto.h" +#include "kudu/security/init.h" +#include "kudu/security/tls_context.h" +#include "kudu/security/tls_handshake.h" +#include "kudu/security/tls_socket.h" +#include "kudu/security/token_verifier.h" +#include "kudu/util/fault_injection.h" +#include "kudu/util/flag_tags.h" +#include "kudu/util/logging.h" +#include "kudu/util/net/net_util.h" +#include "kudu/util/net/sockaddr.h" +#include "kudu/util/net/socket.h" +#include "kudu/util/scoped_cleanup.h" +#include "kudu/util/trace.h" + +using std::set; +using std::string; +using std::unique_ptr; + +// Fault injection flags. +DEFINE_double(rpc_inject_invalid_authn_token_ratio, 0, + "If set higher than 0, AuthenticateByToken() randomly injects " + "errors replying with FATAL_INVALID_AUTHENTICATION_TOKEN code. " + "The flag's value corresponds to the probability of the fault " + "injection event. 
Used for only for tests."); +TAG_FLAG(rpc_inject_invalid_authn_token_ratio, runtime); +TAG_FLAG(rpc_inject_invalid_authn_token_ratio, unsafe); + +DECLARE_bool(rpc_encrypt_loopback_connections); + +DEFINE_string(trusted_subnets, + "127.0.0.0/8,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,169.254.0.0/16", + "A trusted subnet whitelist. If set explicitly, all unauthenticated " + "or unencrypted connections are prohibited except the ones from the " + "specified address blocks. Otherwise, private network (127.0.0.0/8, etc.) " + "and local subnets of all local network interfaces will be used. Set it " + "to '0.0.0.0/0' to allow unauthenticated/unencrypted connections from all " + "remote IP addresses. However, if network access is not otherwise restricted " + "by a firewall, malicious users may be able to gain unauthorized access."); +TAG_FLAG(trusted_subnets, advanced); +TAG_FLAG(trusted_subnets, evolving); + +static bool ValidateTrustedSubnets(const char* /*flagname*/, const string& value) { + if (value.empty()) { + return true; + } + + for (const auto& t : strings::Split(value, ",", strings::SkipEmpty())) { + kudu::Network network; + kudu::Status s = network.ParseCIDRString(t.ToString()); + if (!s.ok()) { + LOG(ERROR) << "Invalid subnet address: " << t + << ". 
Subnet must be specified in CIDR notation."; + return false; + } + } + + return true; +} + +DEFINE_validator(trusted_subnets, &ValidateTrustedSubnets); + +namespace kudu { +namespace rpc { + +namespace { +vector<Network>* g_trusted_subnets = nullptr; +} // anonymous namespace + +static int ServerNegotiationGetoptCb(ServerNegotiation* server_negotiation, + const char* plugin_name, + const char* option, + const char** result, + unsigned* len) { + return server_negotiation->GetOptionCb(plugin_name, option, result, len); +} + +static int ServerNegotiationPlainAuthCb(sasl_conn_t* conn, + ServerNegotiation* server_negotiation, + const char* user, + const char* pass, + unsigned passlen, + struct propctx* propctx) { + return server_negotiation->PlainAuthCb(conn, user, pass, passlen, propctx); +} + +ServerNegotiation::ServerNegotiation(unique_ptr<Socket> socket, + const security::TlsContext* tls_context, + const security::TokenVerifier* token_verifier, + RpcEncryption encryption) + : socket_(std::move(socket)), + helper_(SaslHelper::SERVER), + tls_context_(tls_context), + encryption_(encryption), + tls_negotiated_(false), + token_verifier_(token_verifier), + negotiated_authn_(AuthenticationType::INVALID), + negotiated_mech_(SaslMechanism::INVALID), + deadline_(MonoTime::Max()) { + callbacks_.push_back(SaslBuildCallback(SASL_CB_GETOPT, + reinterpret_cast<int (*)()>(&ServerNegotiationGetoptCb), this)); + callbacks_.push_back(SaslBuildCallback(SASL_CB_SERVER_USERDB_CHECKPASS, + reinterpret_cast<int (*)()>(&ServerNegotiationPlainAuthCb), this)); + callbacks_.push_back(SaslBuildCallback(SASL_CB_LIST_END, nullptr, nullptr)); +} + +Status ServerNegotiation::EnablePlain() { + return helper_.EnablePlain(); +} + +Status ServerNegotiation::EnableGSSAPI() { + return helper_.EnableGSSAPI(); +} + +SaslMechanism::Type ServerNegotiation::negotiated_mechanism() const { + return negotiated_mech_; +} + +void ServerNegotiation::set_server_fqdn(const string& domain_name) { + 
helper_.set_server_fqdn(domain_name); +} + +void ServerNegotiation::set_deadline(const MonoTime& deadline) { + deadline_ = deadline; +} + +Status ServerNegotiation::Negotiate() { + TRACE("Beginning negotiation"); + + // Wait until starting negotiation to check that the socket, tls_context, and + // token_verifier are not null, since they do not need to be set for + // PreflightCheckGSSAPI. + DCHECK(socket_); + DCHECK(tls_context_); + DCHECK(token_verifier_); + + // Ensure we can use blocking calls on the socket during negotiation. + RETURN_NOT_OK(EnsureBlockingMode(socket_.get())); + + faststring recv_buf; + + // Step 1: Read the connection header. + RETURN_NOT_OK(ValidateConnectionHeader(&recv_buf)); + + { // Step 2: Receive and respond to the NEGOTIATE step message. + NegotiatePB request; + RETURN_NOT_OK(RecvNegotiatePB(&request, &recv_buf)); + RETURN_NOT_OK(HandleNegotiate(request)); + TRACE("Negotiated authn=$0", AuthenticationTypeToString(negotiated_authn_)); + } + + // Step 3: if both ends support TLS, do a TLS handshake. + if (encryption_ != RpcEncryption::DISABLED && + tls_context_->has_cert() && + ContainsKey(client_features_, TLS)) { + RETURN_NOT_OK(tls_context_->InitiateHandshake(security::TlsHandshakeType::SERVER, + &tls_handshake_)); + + if (negotiated_authn_ != AuthenticationType::CERTIFICATE) { + // The server does not need to verify the client's certificate unless it's + // being used for authentication. + tls_handshake_.set_verification_mode(security::TlsVerificationMode::VERIFY_NONE); + } + + while (true) { + NegotiatePB request; + RETURN_NOT_OK(RecvNegotiatePB(&request, &recv_buf)); + Status s = HandleTlsHandshake(request); + if (s.ok()) break; + if (!s.IsIncomplete()) return s; + } + tls_negotiated_ = true; + } + + // Rejects any connection from public routable IPs if encryption + // is disabled. See KUDU-1875. 
+ if (!tls_negotiated_) { + Sockaddr addr; + RETURN_NOT_OK(socket_->GetPeerAddress(&addr)); + + if (!IsTrustedConnection(addr)) { + // Receives client response before sending error + // message, even though the response is never used, + // to avoid risk condition that connection gets + // closed before client receives server's error + // message. + NegotiatePB request; + RETURN_NOT_OK(RecvNegotiatePB(&request, &recv_buf)); + + Status s = Status::NotAuthorized("unencrypted connections from publicly routable " + "IPs are prohibited. See --trusted_subnets flag " + "for more information.", + addr.ToString()); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + } + + // Step 4: Authentication + switch (negotiated_authn_) { + case AuthenticationType::SASL: + RETURN_NOT_OK(AuthenticateBySasl(&recv_buf)); + break; + case AuthenticationType::TOKEN: + RETURN_NOT_OK(AuthenticateByToken(&recv_buf)); + break; + case AuthenticationType::CERTIFICATE: + RETURN_NOT_OK(AuthenticateByCertificate()); + break; + case AuthenticationType::INVALID: LOG(FATAL) << "unreachable"; + } + + // Step 5: Receive connection context. + RETURN_NOT_OK(RecvConnectionContext(&recv_buf)); + + TRACE("Negotiation successful"); + return Status::OK(); +} + +Status ServerNegotiation::PreflightCheckGSSAPI() { + // TODO(todd): the error messages that come from this function on el6 + // are relatively useless due to the following krb5 bug: + // http://krbdev.mit.edu/rt/Ticket/Display.html?id=6973 + // This may not be useful anymore given the keytab login that happens + // in security/init.cc. + + // Initialize a ServerNegotiation with a null socket, and enable + // only GSSAPI. + // + // We aren't going to actually send/receive any messages, but + // this makes it easier to reuse the initialization code. 
+ ServerNegotiation server(nullptr, nullptr, nullptr, RpcEncryption::OPTIONAL); + Status s = server.EnableGSSAPI(); + if (!s.ok()) { + return Status::RuntimeError(s.message()); + } + + RETURN_NOT_OK(server.InitSaslServer()); + + // Start the SASL server as if we were accepting a connection. + const char* server_out = nullptr; // ignored + uint32_t server_out_len = 0; + s = WrapSaslCall(server.sasl_conn_.get(), [&]() { + return sasl_server_start( + server.sasl_conn_.get(), + kSaslMechGSSAPI, + "", 0, // Pass a 0-length token. + &server_out, &server_out_len); + }); + + // We expect 'Incomplete' status to indicate that the first step of negotiation + // was correct. + if (s.IsIncomplete()) return Status::OK(); + + string err_msg = s.message().ToString(); + if (err_msg == "Permission denied") { + // For bad keytab permissions, we get a rather vague message. So, + // we make it more specific for better usability. + err_msg = "error accessing keytab: " + err_msg; + } + return Status::RuntimeError(err_msg); +} + +Status ServerNegotiation::RecvNegotiatePB(NegotiatePB* msg, faststring* recv_buf) { + RequestHeader header; + Slice param_buf; + RETURN_NOT_OK(ReceiveFramedMessageBlocking(socket(), recv_buf, &header, ¶m_buf, deadline_)); + Status s = helper_.CheckNegotiateCallId(header.call_id()); + if (!s.ok()) { + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_INVALID_RPC_HEADER, s)); + return s; + } + + s = helper_.ParseNegotiatePB(param_buf, msg); + if (!s.ok()) { + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_DESERIALIZING_REQUEST, s)); + return s; + } + + TRACE("Received $0 NegotiatePB request", NegotiatePB::NegotiateStep_Name(msg->step())); + return Status::OK(); +} + +Status ServerNegotiation::SendNegotiatePB(const NegotiatePB& msg) { + ResponseHeader header; + header.set_call_id(kNegotiateCallId); + + DCHECK(socket_); + DCHECK(msg.IsInitialized()) << "message must be initialized"; + DCHECK(msg.has_step()) << "message must have a step"; + + TRACE("Sending $0 NegotiatePB 
response", NegotiatePB::NegotiateStep_Name(msg.step())); + return SendFramedMessageBlocking(socket(), header, msg, deadline_); +} + +Status ServerNegotiation::SendError(ErrorStatusPB::RpcErrorCodePB code, const Status& err) { + DCHECK(!err.ok()); + + // Create header with negotiation-specific callId + ResponseHeader header; + header.set_call_id(kNegotiateCallId); + header.set_is_error(true); + + // Get RPC error code from Status object + ErrorStatusPB msg; + msg.set_code(code); + msg.set_message(err.ToString()); + + TRACE("Sending RPC error: $0: $1", ErrorStatusPB::RpcErrorCodePB_Name(code), err.ToString()); + RETURN_NOT_OK(SendFramedMessageBlocking(socket(), header, msg, deadline_)); + + return Status::OK(); +} + +Status ServerNegotiation::ValidateConnectionHeader(faststring* recv_buf) { + TRACE("Waiting for connection header"); + size_t num_read; + const size_t conn_header_len = kMagicNumberLength + kHeaderFlagsLength; + recv_buf->resize(conn_header_len); + RETURN_NOT_OK(socket_->BlockingRecv(recv_buf->data(), conn_header_len, &num_read, deadline_)); + DCHECK_EQ(conn_header_len, num_read); + + RETURN_NOT_OK(serialization::ValidateConnHeader(*recv_buf)); + TRACE("Connection header received"); + return Status::OK(); +} + +// calls sasl_server_init() and sasl_server_new() +Status ServerNegotiation::InitSaslServer() { + RETURN_NOT_OK(SaslInit()); + + // TODO(unknown): Support security flags. + unsigned secflags = 0; + + sasl_conn_t* sasl_conn = nullptr; + RETURN_NOT_OK_PREPEND(WrapSaslCall(nullptr /* no conn */, [&]() { + return sasl_server_new( + // Registered name of the service using SASL. Required. + kSaslProtoName, + // The fully qualified domain name of this server. + helper_.server_fqdn(), + // Permits multiple user realms on server. NULL == use default. + nullptr, + // Local and remote IP address strings. We don't use any mechanisms + // which need these. + nullptr, + nullptr, + // Connection-specific callbacks. + &callbacks_[0], + // Security flags. 
+ secflags, + &sasl_conn); + }), "Unable to create new SASL server"); + sasl_conn_.reset(sasl_conn); + return Status::OK(); +} + +Status ServerNegotiation::HandleNegotiate(const NegotiatePB& request) { + if (request.step() != NegotiatePB::NEGOTIATE) { + Status s = Status::NotAuthorized("expected NEGOTIATE step", + NegotiatePB::NegotiateStep_Name(request.step())); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + TRACE("Received NEGOTIATE request from client"); + + // Fill in the set of features supported by the client. + for (int flag : request.supported_features()) { + // We only add the features that our local build knows about. + RpcFeatureFlag feature_flag = RpcFeatureFlag_IsValid(flag) ? + static_cast<RpcFeatureFlag>(flag) : UNKNOWN; + if (feature_flag != UNKNOWN) { + client_features_.insert(feature_flag); + } + } + + if (encryption_ == RpcEncryption::REQUIRED && + !ContainsKey(client_features_, RpcFeatureFlag::TLS)) { + Status s = Status::NotAuthorized("client does not support required TLS encryption"); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + + // Find the set of mutually supported authentication types. + set<AuthenticationType> authn_types; + if (request.authn_types().empty()) { + // If the client doesn't send any support authentication types, we assume + // support for SASL. This preserves backwards compatibility with clients who + // don't support security features. + authn_types.insert(AuthenticationType::SASL); + } else { + for (const auto& type : request.authn_types()) { + switch (type.type_case()) { + case AuthenticationTypePB::kSasl: + authn_types.insert(AuthenticationType::SASL); + break; + case AuthenticationTypePB::kToken: + authn_types.insert(AuthenticationType::TOKEN); + break; + case AuthenticationTypePB::kCertificate: + // We only provide authenticated TLS if the certificates are generated + // by the internal CA. 
+ if (!tls_context_->is_external_cert()) { + authn_types.insert(AuthenticationType::CERTIFICATE); + } + break; + case AuthenticationTypePB::TYPE_NOT_SET: { + Sockaddr addr; + RETURN_NOT_OK(socket_->GetPeerAddress(&addr)); + KLOG_EVERY_N_SECS(WARNING, 60) + << "client supports unknown authentication type, consider updating server, address: " + << addr.ToString(); + break; + } + } + } + + if (authn_types.empty()) { + Status s = Status::NotSupported("no mutually supported authentication types"); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + } + + if (encryption_ != RpcEncryption::DISABLED && + ContainsKey(authn_types, AuthenticationType::CERTIFICATE) && + tls_context_->has_signed_cert()) { + // If the client supports it and we are locally configured with TLS and have + // a CA-signed cert, choose cert authn. + // TODO(KUDU-1924): consider adding the fingerprint of the CA cert which signed + // the client's cert to the authentication message. + negotiated_authn_ = AuthenticationType::CERTIFICATE; + } else if (ContainsKey(authn_types, AuthenticationType::TOKEN) && + token_verifier_->GetMaxKnownKeySequenceNumber() >= 0 && + encryption_ != RpcEncryption::DISABLED && + tls_context_->has_signed_cert()) { + // If the client supports it, we have a TSK to verify the client's token, + // and we have a signed-cert so the client can verify us, choose token authn. + // TODO(KUDU-1924): consider adding the TSK sequence number to the authentication + // message. + negotiated_authn_ = AuthenticationType::TOKEN; + } else { + // Otherwise we always can fallback to SASL. + DCHECK(ContainsKey(authn_types, AuthenticationType::SASL)); + negotiated_authn_ = AuthenticationType::SASL; + } + + // Fill in the NEGOTIATE step response for the client. + NegotiatePB response; + response.set_step(NegotiatePB::NEGOTIATE); + + // Tell the client which features we support. 
+ server_features_ = kSupportedServerRpcFeatureFlags; + if (tls_context_->has_cert() && encryption_ != RpcEncryption::DISABLED) { + server_features_.insert(TLS); + // If the remote peer is local, then we allow using TLS for authentication + // without encryption or integrity. + if (socket_->IsLoopbackConnection() && !FLAGS_rpc_encrypt_loopback_connections) { + server_features_.insert(TLS_AUTHENTICATION_ONLY); + } + } + + for (RpcFeatureFlag feature : server_features_) { + response.add_supported_features(feature); + } + + switch (negotiated_authn_) { + case AuthenticationType::CERTIFICATE: + response.add_authn_types()->mutable_certificate(); + break; + case AuthenticationType::TOKEN: + response.add_authn_types()->mutable_token(); + break; + case AuthenticationType::SASL: { + response.add_authn_types()->mutable_sasl(); + const set<SaslMechanism::Type>& server_mechs = helper_.EnabledMechs(); + if (PREDICT_FALSE(server_mechs.empty())) { + // This will happen if no mechanisms are enabled before calling Init() + Status s = Status::NotAuthorized("SASL server mechanism list is empty!"); + LOG(ERROR) << s.ToString(); + TRACE("Sending FATAL_UNAUTHORIZED response to client"); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + + for (auto mechanism : server_mechs) { + response.add_sasl_mechanisms()->set_mechanism(SaslMechanism::name_of(mechanism)); + } + break; + } + case AuthenticationType::INVALID: LOG(FATAL) << "unreachable"; + } + + return SendNegotiatePB(response); +} + +Status ServerNegotiation::HandleTlsHandshake(const NegotiatePB& request) { + if (PREDICT_FALSE(request.step() != NegotiatePB::TLS_HANDSHAKE)) { + Status s = Status::NotAuthorized("expected TLS_HANDSHAKE step", + NegotiatePB::NegotiateStep_Name(request.step())); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + + if (PREDICT_FALSE(!request.has_tls_handshake())) { + Status s = Status::NotAuthorized( + "No TLS handshake token in TLS_HANDSHAKE 
request from client"); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + + string token; + Status s = tls_handshake_.Continue(request.tls_handshake(), &token); + + if (PREDICT_FALSE(!s.IsIncomplete() && !s.ok())) { + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + + // Regardless of whether this is the final handshake roundtrip (in which case + // Continue would have returned OK), we still need to return a response. + RETURN_NOT_OK(SendTlsHandshake(std::move(token))); + RETURN_NOT_OK(s); + + // TLS handshake is finished. + if (ContainsKey(server_features_, TLS_AUTHENTICATION_ONLY) && + ContainsKey(client_features_, TLS_AUTHENTICATION_ONLY)) { + TRACE("Negotiated auth-only $0 with cipher suite $1", + tls_handshake_.GetProtocol(), tls_handshake_.GetCipherSuite()); + return tls_handshake_.FinishNoWrap(*socket_); + } + + TRACE("Negotiated $0 with cipher suite $1", + tls_handshake_.GetProtocol(), tls_handshake_.GetCipherSuite()); + return tls_handshake_.Finish(&socket_); +} + +Status ServerNegotiation::SendTlsHandshake(string tls_token) { + NegotiatePB msg; + msg.set_step(NegotiatePB::TLS_HANDSHAKE); + msg.mutable_tls_handshake()->swap(tls_token); + return SendNegotiatePB(msg); +} + +Status ServerNegotiation::AuthenticateBySasl(faststring* recv_buf) { + RETURN_NOT_OK(InitSaslServer()); + + NegotiatePB request; + RETURN_NOT_OK(RecvNegotiatePB(&request, recv_buf)); + Status s = HandleSaslInitiate(request); + + while (s.IsIncomplete()) { + RETURN_NOT_OK(RecvNegotiatePB(&request, recv_buf)); + s = HandleSaslResponse(request); + } + RETURN_NOT_OK(s); + + const char* c_username = nullptr; + int rc = sasl_getprop(sasl_conn_.get(), SASL_USERNAME, + reinterpret_cast<const void**>(&c_username)); + // We expect that SASL_USERNAME will always get set. 
+ CHECK(rc == SASL_OK && c_username != nullptr) << "No username on authenticated connection"; + if (negotiated_mech_ == SaslMechanism::GSSAPI) { + // The SASL library doesn't include the user's realm in the username if it's the + // same realm as the default realm of the server. So, we pass it back through the + // Kerberos library to add back the realm if necessary. + string principal = c_username; + RETURN_NOT_OK_PREPEND(security::CanonicalizeKrb5Principal(&principal), + "could not canonicalize krb5 principal"); + + // Map the principal to the corresponding local username. For example, admins + // can set up mappings so that joe@REMOTEREALM becomes something like 'remote-joe' + // locally for the purposes of group mapping, ACLs, etc. + string local_name; + RETURN_NOT_OK_PREPEND(security::MapPrincipalToLocalName(principal, &local_name), + strings::Substitute("could not map krb5 principal '$0' to username", + principal)); + authenticated_user_.SetAuthenticatedByKerberos(std::move(local_name), std::move(principal)); + } else { + authenticated_user_.SetUnauthenticated(c_username); + } + return Status::OK(); +} + +Status ServerNegotiation::AuthenticateByToken(faststring* recv_buf) { + // Sanity check that TLS has been negotiated. Receiving the token on an + // unencrypted channel is a big no-no. + CHECK(tls_negotiated_); + + // Receive the token from the client. + NegotiatePB pb; + RETURN_NOT_OK(RecvNegotiatePB(&pb, recv_buf)); + + if (pb.step() != NegotiatePB::TOKEN_EXCHANGE) { + Status s = Status::NotAuthorized("expected TOKEN_EXCHANGE step", + NegotiatePB::NegotiateStep_Name(pb.step())); + } + if (!pb.has_authn_token()) { + Status s = Status::NotAuthorized("TOKEN_EXCHANGE message must include an authentication token"); + } + + // TODO(KUDU-1924): propagate the specific token verification failure back to the client, + // so it knows how to intelligently retry. 
+ security::TokenPB token; + auto verification_result = token_verifier_->VerifyTokenSignature(pb.authn_token(), &token); + switch (verification_result) { + case security::VerificationResult::VALID: break; + + case security::VerificationResult::INVALID_TOKEN: + case security::VerificationResult::INVALID_SIGNATURE: + case security::VerificationResult::EXPIRED_TOKEN: + case security::VerificationResult::EXPIRED_SIGNING_KEY: { + // These errors indicate the client should get a new token and try again. + Status s = Status::NotAuthorized(VerificationResultToString(verification_result)); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_INVALID_AUTHENTICATION_TOKEN, s)); + return s; + } + + case security::VerificationResult::UNKNOWN_SIGNING_KEY: { + // The server doesn't recognize the signing key. This indicates that the + // server has not been updated with the most recent TSKs, so tell the + // client to try again later. + Status s = Status::NotAuthorized(VerificationResultToString(verification_result)); + RETURN_NOT_OK(SendError(ErrorStatusPB::ERROR_UNAVAILABLE, s)); + return s; + } + case security::VerificationResult::INCOMPATIBLE_FEATURE: { + Status s = Status::NotAuthorized(VerificationResultToString(verification_result)); + // These error types aren't recoverable by having the client get a new token. + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + } + + if (!token.has_authn()) { + Status s = Status::NotAuthorized("non-authentication token presented for authentication"); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + if (!token.authn().has_username()) { + // This is a runtime error because there should be no way a client could + // get a signed authn token without a subject. 
+ Status s = Status::RuntimeError("authentication token has no username"); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_INVALID_AUTHENTICATION_TOKEN, s)); + return s; + } + + if (PREDICT_FALSE(FLAGS_rpc_inject_invalid_authn_token_ratio > 0)) { + security::VerificationResult res; + int sel = rand() % 4; + switch (sel) { + case 0: + res = security::VerificationResult::INVALID_TOKEN; + break; + case 1: + res = security::VerificationResult::INVALID_SIGNATURE; + break; + case 2: + res = security::VerificationResult::EXPIRED_TOKEN; + break; + case 3: + res = security::VerificationResult::EXPIRED_SIGNING_KEY; + break; + } + const Status s = kudu::fault_injection::MaybeReturnFailure( + FLAGS_rpc_inject_invalid_authn_token_ratio, + Status::NotAuthorized(VerificationResultToString(res))); + if (!s.ok()) { + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_INVALID_AUTHENTICATION_TOKEN, s)); + return s; + } + } + + authenticated_user_.SetAuthenticatedByToken(token.authn().username()); + + // Respond with success message. + pb.Clear(); + pb.set_step(NegotiatePB::TOKEN_EXCHANGE); + return SendNegotiatePB(pb); +} + +Status ServerNegotiation::AuthenticateByCertificate() { + // Sanity check that TLS has been negotiated. Cert-based authentication is + // only possible with TLS. + CHECK(tls_negotiated_); + + // Grab the subject from the client's cert. 
+ security::Cert cert; + RETURN_NOT_OK(tls_handshake_.GetRemoteCert(&cert)); + + boost::optional<string> user_id = cert.UserId(); + boost::optional<string> principal = cert.KuduKerberosPrincipal(); + + if (!user_id) { + Status s = Status::NotAuthorized("did not find expected X509 userId extension in cert"); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_INVALID_AUTHENTICATION_TOKEN, s)); + return s; + } + + authenticated_user_.SetAuthenticatedByClientCert(*user_id, std::move(principal)); + + return Status::OK(); +} + +Status ServerNegotiation::HandleSaslInitiate(const NegotiatePB& request) { + if (PREDICT_FALSE(request.step() != NegotiatePB::SASL_INITIATE)) { + Status s = Status::NotAuthorized("expected SASL_INITIATE step", + NegotiatePB::NegotiateStep_Name(request.step())); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + TRACE("Received SASL_INITIATE request from client"); + + if (request.sasl_mechanisms_size() != 1) { + Status s = Status::NotAuthorized( + "SASL_INITIATE request must include exactly one SASL mechanism, found", + std::to_string(request.sasl_mechanisms_size())); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + + const string& mechanism = request.sasl_mechanisms(0).mechanism(); + TRACE("Client requested to use mechanism: $0", mechanism); + + negotiated_mech_ = SaslMechanism::value_of(mechanism); + + // Rejects any connection from public routable IPs if authentication mechanism + // is plain. See KUDU-1875. + if (negotiated_mech_ == SaslMechanism::PLAIN) { + Sockaddr addr; + RETURN_NOT_OK(socket_->GetPeerAddress(&addr)); + + if (!IsTrustedConnection(addr)) { + Status s = Status::NotAuthorized("unauthenticated connections from publicly " + "routable IPs are prohibited. 
See " + "--trusted_subnets flag for more information.", + addr.ToString()); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + } + + // If the negotiated mechanism is GSSAPI (Kerberos), configure SASL to use + // integrity protection so that the channel bindings and nonce can be + // verified. + if (negotiated_mech_ == SaslMechanism::GSSAPI) { + RETURN_NOT_OK(EnableIntegrityProtection(sasl_conn_.get())); + } + + const char* server_out = nullptr; + uint32_t server_out_len = 0; + TRACE("Calling sasl_server_start()"); + + Status s = WrapSaslCall(sasl_conn_.get(), [&]() { + return sasl_server_start( + sasl_conn_.get(), // The SASL connection context created by init() + mechanism.c_str(), // The mechanism requested by the client. + request.token().c_str(), // Optional string the client gave us. + request.token().length(), // Client string len. + &server_out, // The output of the SASL library, might not be NULL terminated + &server_out_len); // Output len. + }); + + if (PREDICT_FALSE(!s.ok() && !s.IsIncomplete())) { + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + + // We have a valid mechanism match + if (s.ok()) { + DCHECK(server_out_len == 0); + RETURN_NOT_OK(SendSaslSuccess()); + } else { // s.IsIncomplete() (equivalent to SASL_CONTINUE) + RETURN_NOT_OK(SendSaslChallenge(server_out, server_out_len)); + } + return s; +} + +Status ServerNegotiation::HandleSaslResponse(const NegotiatePB& request) { + if (PREDICT_FALSE(request.step() != NegotiatePB::SASL_RESPONSE)) { + Status s = Status::NotAuthorized("expected SASL_RESPONSE step", + NegotiatePB::NegotiateStep_Name(request.step())); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } + TRACE("Received SASL_RESPONSE request from client"); + + if (!request.has_token()) { + Status s = Status::NotAuthorized("no token in SASL_RESPONSE from client"); + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; + } 
+ + const char* server_out = nullptr; + uint32_t server_out_len = 0; + TRACE("Calling sasl_server_step()"); + Status s = WrapSaslCall(sasl_conn_.get(), [&]() { + return sasl_server_step( + sasl_conn_.get(), // The SASL connection context created by init() + request.token().c_str(), // Optional string the client gave us + request.token().length(), // Client string len + &server_out, // The output of the SASL library, might not be NULL terminated + &server_out_len); // Output len + }); + + if (s.ok()) { + DCHECK(server_out_len == 0); + return SendSaslSuccess(); + } + if (s.IsIncomplete()) { + return SendSaslChallenge(server_out, server_out_len); + } + RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s)); + return s; +} + +Status ServerNegotiation::SendSaslChallenge(const char* challenge, unsigned clen) { + NegotiatePB response; + response.set_step(NegotiatePB::SASL_CHALLENGE); + response.mutable_token()->assign(challenge, clen); + RETURN_NOT_OK(SendNegotiatePB(response)); + return Status::Incomplete(""); +} + +Status ServerNegotiation::SendSaslSuccess() { + NegotiatePB response; + response.set_step(NegotiatePB::SASL_SUCCESS); + + if (negotiated_mech_ == SaslMechanism::GSSAPI) { + // Send a nonce to the client. + nonce_ = string(); + RETURN_NOT_OK(security::GenerateNonce(nonce_.get_ptr())); + response.set_nonce(*nonce_); + + if (tls_negotiated_) { + // Send the channel bindings to the client. 
+ security::Cert cert; + RETURN_NOT_OK(tls_handshake_.GetLocalCert(&cert)); + + string plaintext_channel_bindings; + RETURN_NOT_OK(cert.GetServerEndPointChannelBindings(&plaintext_channel_bindings)); + RETURN_NOT_OK(SaslEncode(sasl_conn_.get(), + plaintext_channel_bindings, + response.mutable_channel_bindings())); + } + } + + RETURN_NOT_OK(SendNegotiatePB(response)); + return Status::OK(); +} + +Status ServerNegotiation::RecvConnectionContext(faststring* recv_buf) { + TRACE("Waiting for connection context"); + RequestHeader header; + Slice param_buf; + RETURN_NOT_OK(ReceiveFramedMessageBlocking(socket(), recv_buf, &header, ¶m_buf, deadline_)); + DCHECK(header.IsInitialized()); + + if (header.call_id() != kConnectionContextCallId) { + return Status::NotAuthorized("expected ConnectionContext callid, received", + std::to_string(header.call_id())); + } + + ConnectionContextPB conn_context; + if (!conn_context.ParseFromArray(param_buf.data(), param_buf.size())) { + return Status::NotAuthorized("invalid ConnectionContextPB message, missing fields", + conn_context.InitializationErrorString()); + } + + if (nonce_) { + Status s; + // Validate that the client returned the correct SASL protected nonce. 
+ if (!conn_context.has_encoded_nonce()) { + return Status::NotAuthorized("ConnectionContextPB wrapped nonce missing"); + } + + string decoded_nonce; + s = SaslDecode(sasl_conn_.get(), conn_context.encoded_nonce(), &decoded_nonce); + if (!s.ok()) { + return Status::NotAuthorized("failed to decode nonce", s.message()); + } + + if (*nonce_ != decoded_nonce) { + Sockaddr addr; + RETURN_NOT_OK(socket_->GetPeerAddress(&addr)); + LOG(WARNING) << "Received an invalid connection nonce from client " + << addr.ToString() + << ", this could indicate a replay attack"; + return Status::NotAuthorized("nonce mismatch"); + } + } + + return Status::OK(); +} + +int ServerNegotiation::GetOptionCb(const char* plugin_name, + const char* option, + const char** result, + unsigned* len) { + return helper_.GetOptionCb(plugin_name, option, result, len); +} + +int ServerNegotiation::PlainAuthCb(sasl_conn_t* /*conn*/, + const char* /*user*/, + const char* /*pass*/, + unsigned /*passlen*/, + struct propctx* /*propctx*/) { + TRACE("Received PLAIN auth."); + if (PREDICT_FALSE(!helper_.IsPlainEnabled())) { + LOG(DFATAL) << "Password authentication callback called while PLAIN auth disabled"; + return SASL_BADPARAM; + } + // We always allow PLAIN authentication to succeed. + return SASL_OK; +} + +bool ServerNegotiation::IsTrustedConnection(const Sockaddr& addr) { + static std::once_flag once; + std::call_once(once, [] { + g_trusted_subnets = new vector<Network>(); + CHECK_OK(Network::ParseCIDRStrings(FLAGS_trusted_subnets, g_trusted_subnets)); + + // If --trusted_subnets is not set explicitly, local subnets of all local network + // interfaces as well as the default private subnets will be used. 
+ if (google::GetCommandLineFlagInfoOrDie("trusted_subnets").is_default) { + std::vector<Network> local_networks; + WARN_NOT_OK(GetLocalNetworks(&local_networks), + "Unable to get local networks."); + + g_trusted_subnets->insert(g_trusted_subnets->end(), + local_networks.begin(), + local_networks.end()); + } + }); + + return std::any_of(g_trusted_subnets->begin(), g_trusted_subnets->end(), + [&](const Network& t) { return t.WithinNetwork(addr); }); +} + +} // namespace rpc +} // namespace kudu http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/c7db60aa/be/src/kudu/rpc/server_negotiation.h ---------------------------------------------------------------------- diff --git a/be/src/kudu/rpc/server_negotiation.h b/be/src/kudu/rpc/server_negotiation.h new file mode 100644 index 0000000..e9e945a --- /dev/null +++ b/be/src/kudu/rpc/server_negotiation.h @@ -0,0 +1,248 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+#pragma once
+
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include <sasl/sasl.h>
+
+#include "kudu/rpc/negotiation.h"
+#include "kudu/rpc/remote_user.h"
+#include "kudu/rpc/rpc_header.pb.h"
+#include "kudu/rpc/sasl_common.h"
+#include "kudu/rpc/sasl_helper.h"
+#include "kudu/security/tls_handshake.h"
+#include "kudu/util/monotime.h"
+#include "kudu/util/net/socket.h"
+#include "kudu/util/status.h"
+
+namespace kudu {
+
+class Slice;
+
+namespace security {
+class TlsContext;
+class TokenVerifier;
+}
+
+namespace rpc {
+
+// Class for doing KRPC negotiation with a remote client over a bidirectional socket.
+// Operations on this class are NOT thread-safe.
+class ServerNegotiation {
+ public:
+  // Creates a new server negotiation instance, taking ownership of the
+  // provided socket. After completing the negotiation process by setting the
+  // desired options and calling Negotiate(), the socket can be retrieved with
+  // release_socket().
+  //
+  // The provided TlsContext must outlive this negotiation instance.
+  ServerNegotiation(std::unique_ptr<Socket> socket,
+                    const security::TlsContext* tls_context,
+                    const security::TokenVerifier* token_verifier,
+                    RpcEncryption encryption);
+
+  // Enable PLAIN authentication.
+  // Despite PLAIN authentication taking a username and password, we disregard
+  // the password and use this as an "unauthenticated" mode.
+  // Must be called before Negotiate().
+  Status EnablePlain();
+
+  // Enable GSSAPI (Kerberos) authentication.
+  // Must be called before Negotiate().
+  Status EnableGSSAPI();
+
+  // Returns the SASL mechanism negotiated by this connection.
+  // Must be called after Negotiate().
+  SaslMechanism::Type negotiated_mechanism() const;
+
+  // Returns the negotiated authentication type for the connection.
+  // Must be called after Negotiate(); DCHECKs that it is not INVALID.
+  AuthenticationType negotiated_authn() const {
+    DCHECK_NE(negotiated_authn_, AuthenticationType::INVALID);
+    return negotiated_authn_;
+  }
+
+  // Returns true if TLS was negotiated.
+  // Must be called after Negotiate().
+  bool tls_negotiated() const {
+    return tls_negotiated_;
+  }
+
+  // Returns a copy of the set of RPC system features supported by the remote client.
+  // Must be called after Negotiate().
+  std::set<RpcFeatureFlag> client_features() const {
+    return client_features_;
+  }
+
+  // Moves out the set of RPC system features supported by the remote client.
+  // Must be called after Negotiate().
+  // Subsequent calls to this method or client_features() will return an empty set.
+  std::set<RpcFeatureFlag> take_client_features() {
+    return std::move(client_features_);
+  }
+
+  // Moves out the user that was authenticated.
+  // Must be called after a successful Negotiate().
+  //
+  // Subsequent calls return the moved-from (unspecified) RemoteUser.
+  RemoteUser take_authenticated_user() {
+    return std::move(authenticated_user_);
+  }
+
+  // Specify the server's fully-qualified domain name (NOTE(review): presumably this local server's FQDN - confirm).
+  // Must be called before Negotiate(). Required for some mechanisms.
+  void set_server_fqdn(const std::string& domain_name);
+
+  // Set deadline for connection negotiation.
+  void set_deadline(const MonoTime& deadline);
+
+  Socket* socket() const { return socket_.get(); }
+
+  // Returns the socket owned by this server negotiation. The caller will own
+  // the socket after this call, and the negotiation instance should no longer
+  // be used. Must be called after Negotiate().
+  std::unique_ptr<Socket> release_socket() { return std::move(socket_); }
+
+  // Negotiate with the remote client. Should only be called once per
+  // ServerNegotiation and socket instance, after all options have been set.
+  //
+  // Returns OK on success, otherwise may return NotAuthorized, NotSupported, or
+  // another non-OK status.
+  Status Negotiate() WARN_UNUSED_RESULT;
+
+  // SASL callback for plugin options, supported mechanisms, etc. (delegates to helper_).
+  // Returns SASL_FAIL if the option is not handled, which does not fail the handshake.
+  int GetOptionCb(const char* plugin_name, const char* option,
+                  const char** result, unsigned* len);
+
+  // SASL callback for PLAIN authentication via SASL_CB_SERVER_USERDB_CHECKPASS.
+  int PlainAuthCb(sasl_conn_t* conn, const char* user, const char* pass,
+                  unsigned passlen, struct propctx* propctx);
+
+  // Perform a "pre-flight check" that everything required to act as a Kerberos
+  // server is properly set up.
+  static Status PreflightCheckGSSAPI() WARN_UNUSED_RESULT;
+
+ private:
+
+  // Parse a negotiate request from the client, deserializing it into 'msg'.
+  // If the request is malformed, sends an error message to the client.
+  Status RecvNegotiatePB(NegotiatePB* msg, faststring* recv_buf) WARN_UNUSED_RESULT;
+
+  // Encode and send the specified negotiate response message to the client.
+  Status SendNegotiatePB(const NegotiatePB& msg) WARN_UNUSED_RESULT;
+
+  // Encode and send the specified RPC error message to the client.
+  // Calls Status.ToString() for the embedded error message.
+  Status SendError(ErrorStatusPB::RpcErrorCodePB code, const Status& err) WARN_UNUSED_RESULT;
+
+  // Parse and validate connection header.
+  Status ValidateConnectionHeader(faststring* recv_buf) WARN_UNUSED_RESULT;
+
+  // Initialize the SASL server negotiation instance.
+  Status InitSaslServer() WARN_UNUSED_RESULT;
+
+  // Handle case when client sends NEGOTIATE request. Builds the set of
+  // client-supported RPC features, determines a mutually supported
+  // authentication type to use for the connection, and sends a NEGOTIATE
+  // response.
+  Status HandleNegotiate(const NegotiatePB& request) WARN_UNUSED_RESULT;
+
+  // Handle a TLS_HANDSHAKE request message from the client.
+  Status HandleTlsHandshake(const NegotiatePB& request) WARN_UNUSED_RESULT;
+
+  // Send a TLS_HANDSHAKE response message to the client with the provided token.
+  Status SendTlsHandshake(std::string tls_token) WARN_UNUSED_RESULT;
+
+  // Authenticate the client using SASL. Populates the 'authenticated_user_'
+  // field with the SASL principal.
+  // 'recv_buf' allows a receive buffer to be reused.
+  Status AuthenticateBySasl(faststring* recv_buf) WARN_UNUSED_RESULT;
+
+  // Authenticate the client using a signed authn token. Populates the
+  // 'authenticated_user_' field with the token's username.
+  // 'recv_buf' allows a receive buffer to be reused.
+  Status AuthenticateByToken(faststring* recv_buf) WARN_UNUSED_RESULT;
+
+  // Authenticate the client using the client's TLS certificate. Populates the
+  // 'authenticated_user_' field with the cert's userId extension (and Kerberos principal, if any).
+  Status AuthenticateByCertificate() WARN_UNUSED_RESULT;
+
+  // Handle case when client sends SASL_INITIATE request.
+  // Returns Status::OK if the SASL negotiation is complete, or
+  // Status::Incomplete if a SASL_RESPONSE step is expected.
+  Status HandleSaslInitiate(const NegotiatePB& request) WARN_UNUSED_RESULT;
+
+  // Handle SASL_RESPONSE request; returns OK when done, Incomplete if another step follows.
+  Status HandleSaslResponse(const NegotiatePB& request) WARN_UNUSED_RESULT;
+
+  // Send a SASL_CHALLENGE response to the client with a challenge token; returns Incomplete.
+  Status SendSaslChallenge(const char* challenge, unsigned clen) WARN_UNUSED_RESULT;
+
+  // Send a SASL_SUCCESS response to the client (with nonce/channel bindings for GSSAPI).
+  Status SendSaslSuccess() WARN_UNUSED_RESULT;
+
+  // Receive and validate the ConnectionContextPB, including the nonce if one was sent.
+  Status RecvConnectionContext(faststring* recv_buf) WARN_UNUSED_RESULT;
+
+  // Returns true if connection is from trusted subnets or local networks.
+  bool IsTrustedConnection(const Sockaddr& addr);
+
+  // The socket to the remote client.
+  std::unique_ptr<Socket> socket_;
+
+  // SASL state.
+  std::vector<sasl_callback_t> callbacks_;
+  std::unique_ptr<sasl_conn_t, SaslDeleter> sasl_conn_;
+  SaslHelper helper_;
+  boost::optional<std::string> nonce_;  // Set for GSSAPI in SendSaslSuccess(); checked on ConnectionContext.
+
+  // TLS state.
+  const security::TlsContext* tls_context_;
+  security::TlsHandshake tls_handshake_;
+  const RpcEncryption encryption_;
+  bool tls_negotiated_;
+
+  // TSK state.
+  const security::TokenVerifier* token_verifier_;
+
+  // The set of features supported by the client and server. Filled in during negotiation.
+  std::set<RpcFeatureFlag> client_features_;
+  std::set<RpcFeatureFlag> server_features_;
+
+  // The successfully-authenticated user, if applicable. Filled in during
+  // negotiation.
+  RemoteUser authenticated_user_;
+
+  // The authentication type. Filled in during negotiation.
+  AuthenticationType negotiated_authn_;
+
+  // The SASL mechanism. Filled in during negotiation if the negotiated
+  // authentication type is SASL.
+  SaslMechanism::Type negotiated_mech_;
+
+  // Negotiation timeout deadline.
+  MonoTime deadline_;
+};
+
+} // namespace rpc
+} // namespace kudu
