KUDU-2191 (5/n): Add Kerberos SASL support to the HMS client The bulk of this commit is adding a new Thrift transport type, SaslClientTransport, which facilitates SASL GSSAPI negotiation, as well as integrity/privacy channel protection. The new transport is based on Impala's version with some significant changes:
- Impala has a client and server SASL transport, necessitating a common superclass (SaslTransport). Since we only need a client transport, I collapsed all of the logic into a single class, which I think makes the code easier to follow. - The transport uses Kudu helper types where possible, e.g., faststring buffers, and our existing SASL utility infrastructure. - Integrity and privacy channel protection are implemented. There are no standalone unit tests for the transport, since that would require implementing the server-specific counterpart. Instead, the class is tested indirectly by using the HMS client to communicate with a Kerberos-enabled HMS instance. Change-Id: I8f217ae05fd36c8ee88fe20eeccd73d49233a345 Reviewed-on: http://gerrit.cloudera.org:8080/8692 Tested-by: Kudu Jenkins Reviewed-by: Todd Lipcon <t...@apache.org> Project: http://git-wip-us.apache.org/repos/asf/kudu/repo Commit: http://git-wip-us.apache.org/repos/asf/kudu/commit/57fe0c3d Tree: http://git-wip-us.apache.org/repos/asf/kudu/tree/57fe0c3d Diff: http://git-wip-us.apache.org/repos/asf/kudu/diff/57fe0c3d Branch: refs/heads/master Commit: 57fe0c3db086c9fc61fbcef9b2d879422b387a7e Parents: 20ba3c7 Author: Dan Burkert <danburk...@apache.org> Authored: Tue Nov 14 18:14:06 2017 -0800 Committer: Dan Burkert <danburk...@apache.org> Committed: Fri Feb 23 21:09:17 2018 +0000 ---------------------------------------------------------------------- src/kudu/hms/CMakeLists.txt | 7 +- src/kudu/hms/hms_client-test.cc | 119 +++++- src/kudu/hms/hms_client.cc | 54 ++- src/kudu/hms/hms_client.h | 15 +- src/kudu/hms/mini_hms.cc | 91 ++++- src/kudu/hms/mini_hms.h | 16 + src/kudu/hms/sasl_client_transport.cc | 402 +++++++++++++++++++ src/kudu/hms/sasl_client_transport.h | 176 ++++++++ .../mini-cluster/external_mini_cluster-test.cc | 25 +- src/kudu/mini-cluster/external_mini_cluster.cc | 11 + src/kudu/rpc/client_negotiation.cc | 8 +- src/kudu/rpc/sasl_common.cc | 89 ++-- src/kudu/rpc/sasl_common.h | 55 ++- 
src/kudu/rpc/server_negotiation.cc | 9 +- 14 files changed, 994 insertions(+), 83 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/hms/CMakeLists.txt ---------------------------------------------------------------------- diff --git a/src/kudu/hms/CMakeLists.txt b/src/kudu/hms/CMakeLists.txt index e449047..f50eead 100644 --- a/src/kudu/hms/CMakeLists.txt +++ b/src/kudu/hms/CMakeLists.txt @@ -24,10 +24,13 @@ target_link_libraries(hms_thrift thrift) add_dependencies(hms_thrift ${HMS_THRIFT_TGTS}) set(HMS_SRCS - hms_client.cc) + hms_client.cc + sasl_client_transport.cc) set(HMS_DEPS + gflags glog hms_thrift + krpc kudu_util) add_library(kudu_hms ${HMS_SRCS}) @@ -62,6 +65,7 @@ set(MINI_HMS_SRCS add_library(mini_hms ${MINI_HMS_SRCS}) target_link_libraries(mini_hms gutil + krpc kudu_test_util kudu_util) add_dependencies(mini_hms hms-plugin) @@ -71,6 +75,7 @@ if (NOT NO_TESTS) set(KUDU_TEST_LINK_LIBS kudu_hms mini_hms + mini_kdc ${KUDU_MIN_TEST_LIBS}) # This test has to run serially since otherwise starting the HMS can take a very http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/hms/hms_client-test.cc ---------------------------------------------------------------------- diff --git a/src/kudu/hms/hms_client-test.cc b/src/kudu/hms/hms_client-test.cc index 9ec0f4c..87601eb 100644 --- a/src/kudu/hms/hms_client-test.cc +++ b/src/kudu/hms/hms_client-test.cc @@ -22,12 +22,15 @@ #include <utility> #include <vector> +#include <boost/optional/optional.hpp> #include <glog/stl_logging.h> // IWYU pragma: keep #include <gtest/gtest.h> #include "kudu/hms/hive_metastore_constants.h" #include "kudu/hms/hive_metastore_types.h" #include "kudu/hms/mini_hms.h" +#include "kudu/rpc/sasl_common.h" +#include "kudu/security/test/mini_kdc.h" #include "kudu/util/monotime.h" #include "kudu/util/net/net_util.h" #include "kudu/util/net/sockaddr.h" @@ -36,6 +39,8 @@ #include 
"kudu/util/test_macros.h" #include "kudu/util/test_util.h" +using boost::optional; +using kudu::rpc::SaslProtection; using std::make_pair; using std::string; using std::vector; @@ -43,7 +48,8 @@ using std::vector; namespace kudu { namespace hms { -class HmsClientTest : public KuduTest { +class HmsClientTest : public KuduTest, + public ::testing::WithParamInterface<optional<SaslProtection::Type>> { public: Status CreateTable(HmsClient* client, @@ -75,11 +81,47 @@ class HmsClientTest : public KuduTest { } }; -TEST_F(HmsClientTest, TestHmsOperations) { +INSTANTIATE_TEST_CASE_P(ProtectionTypes, + HmsClientTest, + ::testing::Values(boost::none + , SaslProtection::kIntegrity +// On macos, krb5 has issues repeatedly spinning up new KDCs ('unable to reach +// any KDC in realm KRBTEST.COM, tried 1 KDC'). Integrity protection gives us +// good coverage, so we disable the other variants. +#ifndef __APPLE__ + , SaslProtection::kAuthentication + , SaslProtection::kPrivacy +#endif + )); + +TEST_P(HmsClientTest, TestHmsOperations) { + optional<SaslProtection::Type> protection = GetParam(); + MiniKdc kdc; MiniHms hms; + HmsClientOptions hms_client_opts; + + if (protection) { + ASSERT_OK(kdc.Start()); + + string spn = "hive/127.0.0.1"; + string ktpath; + ASSERT_OK(kdc.CreateServiceKeytab(spn, &ktpath)); + + ASSERT_OK(rpc::SaslInit()); + hms.EnableKerberos(kdc.GetEnvVars()["KRB5_CONFIG"], + spn, + ktpath, + *protection); + + ASSERT_OK(kdc.CreateUserPrincipal("alice")); + ASSERT_OK(kdc.Kinit("alice")); + ASSERT_OK(kdc.SetKrb5Environment()); + hms_client_opts.enable_kerberos = true; + } + ASSERT_OK(hms.Start()); - HmsClient client(hms.address(), HmsClientOptions()); + HmsClient client(hms.address(), hms_client_opts); ASSERT_OK(client.Start()); // Create a database. 
@@ -200,6 +242,77 @@ TEST_F(HmsClientTest, TestHmsOperations) { ASSERT_OK(client.Stop()); } +TEST_P(HmsClientTest, TestLargeObjects) { + optional<SaslProtection::Type> protection = GetParam(); + + MiniKdc kdc; + MiniHms hms; + HmsClientOptions hms_client_opts; + + if (protection) { + ASSERT_OK(kdc.Start()); + + string spn = "hive/127.0.0.1"; + string ktpath; + ASSERT_OK(kdc.CreateServiceKeytab(spn, &ktpath)); + + ASSERT_OK(rpc::SaslInit()); + hms.EnableKerberos(kdc.GetEnvVars()["KRB5_CONFIG"], + spn, + ktpath, + *protection); + + ASSERT_OK(kdc.CreateUserPrincipal("alice")); + ASSERT_OK(kdc.Kinit("alice")); + ASSERT_OK(kdc.SetKrb5Environment()); + hms_client_opts.enable_kerberos = true; + } + + ASSERT_OK(hms.Start()); + + HmsClient client(hms.address(), hms_client_opts); + ASSERT_OK(client.Start()); + + string database_name = "default"; + string table_name = "big_table"; + + hive::Table table; + table.dbName = database_name; + table.tableName = table_name; + table.tableType = "MANAGED_TABLE"; + hive::FieldSchema partition_key; + partition_key.name = "c1"; + partition_key.type = "int"; + table.partitionKeys.emplace_back(std::move(partition_key)); + + ASSERT_OK(client.CreateTable(table)); + + // Add a bunch of partitions to the table. This ensures we can send and + // receive really large thrift objects. We have to add the partitions in small + // batches, otherwise Derby chokes. 
+ const int kBatches = 25; + const int kPartitionsPerBatch = 40; + + for (int batch_idx = 0; batch_idx < kBatches; batch_idx++) { + vector<hive::Partition> partitions; + for (int partition_idx = 0; partition_idx < kPartitionsPerBatch; partition_idx++) { + hive::Partition partition; + partition.dbName = database_name; + partition.tableName = table_name; + partition.values = { std::to_string(batch_idx * kPartitionsPerBatch + partition_idx) }; + partitions.emplace_back(std::move(partition)); + } + + ASSERT_OK(client.AddPartitions(database_name, table_name, std::move(partitions))); + } + + ASSERT_OK(client.GetTable(database_name, table_name, &table)); + + vector<hive::Partition> partitions; + ASSERT_OK(client.GetPartitions(database_name, table_name, &partitions)); + ASSERT_EQ(kBatches * kPartitionsPerBatch, partitions.size()); +} + TEST_F(HmsClientTest, TestHmsFaultHandling) { MiniHms hms; ASSERT_OK(hms.Start()); http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/hms/hms_client.cc ---------------------------------------------------------------------- diff --git a/src/kudu/hms/hms_client.cc b/src/kudu/hms/hms_client.cc index c2c2a67..6e2f4c6 100644 --- a/src/kudu/hms/hms_client.cc +++ b/src/kudu/hms/hms_client.cc @@ -37,16 +37,29 @@ #include "kudu/gutil/strings/substitute.h" #include "kudu/hms/ThriftHiveMetastore.h" #include "kudu/hms/hive_metastore_constants.h" +#include "kudu/hms/sasl_client_transport.h" +#include "kudu/util/flag_tags.h" #include "kudu/util/net/net_util.h" #include "kudu/util/status.h" #include "kudu/util/stopwatch.h" +// Default to 100 MiB to match Thrift TSaslTransport.receiveSaslMessage and the +// HMS metastore.server.max.message.size config. 
+DEFINE_int32(hms_client_max_buf_size, 100 * 1024 * 1024, + "Maximum size of Hive MetaStore objects that can be received by the " + "HMS client in bytes."); +TAG_FLAG(hms_client_max_buf_size, experimental); +// Note: despite being marked as a runtime flag, the new buf size value will +// only take effect for new HMS clients. +TAG_FLAG(hms_client_max_buf_size, runtime); + using apache::thrift::TException; using apache::thrift::protocol::TBinaryProtocol; using apache::thrift::protocol::TJSONProtocol; using apache::thrift::transport::TBufferedTransport; using apache::thrift::transport::TMemoryBuffer; using apache::thrift::transport::TSocket; +using apache::thrift::transport::TTransport; using apache::thrift::transport::TTransportException; using std::make_shared; using std::shared_ptr; @@ -76,6 +89,8 @@ namespace hms { return Status::IllegalState((msg), e.what()); \ } catch (const hive::MetaException& e) { \ return Status::RemoteError((msg), e.what()); \ + } catch (const SaslException& e) { \ + return e.status().CloneAndPrepend((msg)); \ } catch (const TTransportException& e) { \ switch (e.getType()) { \ case TTransportException::TIMED_OUT: return Status::TimedOut((msg), e.what()); \ @@ -85,6 +100,8 @@ namespace hms { } \ } catch (const TException& e) { \ return Status::IOError((msg), e.what()); \ + } catch (const std::exception& e) { \ + return Status::RuntimeError((msg), e.what()); \ } const char* const HmsClient::kKuduTableIdKey = "kudu.table_id"; @@ -122,7 +139,16 @@ HmsClient::HmsClient(const HostPort& hms_address, const HmsClientOptions& option socket->setSendTimeout(options.send_timeout.ToMilliseconds()); socket->setRecvTimeout(options.recv_timeout.ToMilliseconds()); socket->setConnTimeout(options.conn_timeout.ToMilliseconds()); - auto transport = make_shared<TBufferedTransport>(std::move(socket)); + shared_ptr<TTransport> transport; + + if (options.enable_kerberos) { + transport = make_shared<SaslClientTransport>(hms_address.host(), + std::move(socket), + 
FLAGS_hms_client_max_buf_size); + } else { + transport = make_shared<TBufferedTransport>(std::move(socket)); + } + auto protocol = make_shared<TBinaryProtocol>(std::move(transport)); client_ = hive::ThriftHiveMetastoreClient(std::move(protocol)); } @@ -267,6 +293,32 @@ Status HmsClient::GetNotificationEvents(int64_t last_event_id, return Status::OK(); } +Status HmsClient::AddPartitions(const string& database_name, + const string& table_name, + vector<hive::Partition> partitions) { + SCOPED_LOG_SLOW_EXECUTION(WARNING, kSlowExecutionWarningThresholdMs, "add HMS table partitions"); + hive::AddPartitionsRequest request; + hive::AddPartitionsResult response; + + request.dbName = database_name; + request.tblName = table_name; + request.parts = std::move(partitions); + + HMS_RET_NOT_OK(client_.add_partitions_req(response, request), + "failed to add Hive MetaStore table partitions"); + return Status::OK(); +} + +Status HmsClient::GetPartitions(const string& database_name, + const string& table_name, + vector<hive::Partition>* partitions) { + SCOPED_LOG_SLOW_EXECUTION(WARNING, kSlowExecutionWarningThresholdMs, "get HMS table partitions"); + HMS_RET_NOT_OK(client_.get_partitions(*partitions, database_name, table_name, -1), + "failed to get Hive Metastore table partitions"); + return Status::OK(); +} + + Status HmsClient::DeserializeJsonTable(Slice json, hive::Table* table) { shared_ptr<TMemoryBuffer> membuffer(new TMemoryBuffer(json.size())); membuffer->write(json.data(), json.size()); http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/hms/hms_client.h ---------------------------------------------------------------------- diff --git a/src/kudu/hms/hms_client.h b/src/kudu/hms/hms_client.h index 0ce17f9..a71fcf6 100644 --- a/src/kudu/hms/hms_client.h +++ b/src/kudu/hms/hms_client.h @@ -52,6 +52,9 @@ struct HmsClientOptions { // Thrift socket connect timeout. 
MonoDelta conn_timeout = MonoDelta::FromSeconds(60); + + // Whether to use SASL Kerberos authentication when connecting to the HMS. + bool enable_kerberos = false; }; // A client for the Hive MetaStore. @@ -76,8 +79,6 @@ struct HmsClientOptions { // handling connection retries, because the higher-level construct which is // handling HA deployments will naturally want to retry across HMS instances as // opposed to retrying repeatedly on a single instance. -// -// TODO(dan): this client does not yet handle Kerberized HMS deployments. class HmsClient { public: @@ -160,6 +161,16 @@ class HmsClient { int32_t max_events, std::vector<hive::NotificationEvent>* events) WARN_UNUSED_RESULT; + // Adds partitions to an HMS table. + Status AddPartitions(const std::string& database_name, + const std::string& table_name, + std::vector<hive::Partition> partitions) WARN_UNUSED_RESULT; + + // Retrieves the partitions of an HMS table. + Status GetPartitions(const std::string& database_name, + const std::string& table_name, + std::vector<hive::Partition>* partitions) WARN_UNUSED_RESULT; + // Deserializes a JSON encoded table. 
// // Notification event log messages often include table objects serialized as http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/hms/mini_hms.cc ---------------------------------------------------------------------- diff --git a/src/kudu/hms/mini_hms.cc b/src/kudu/hms/mini_hms.cc index c7de53d..078341d 100644 --- a/src/kudu/hms/mini_hms.cc +++ b/src/kudu/hms/mini_hms.cc @@ -38,6 +38,7 @@ #include "kudu/util/subprocess.h" #include "kudu/util/test_util.h" +using kudu::rpc::SaslProtection; using std::map; using std::string; using std::unique_ptr; @@ -52,8 +53,21 @@ MiniHms::~MiniHms() { WARN_NOT_OK(Stop(), "Failed to stop MiniHms"); } -namespace { +void MiniHms::EnableKerberos(string krb5_conf, + string service_principal, + string keytab_file, + SaslProtection::Type protection) { + CHECK(!hms_process_); + CHECK(!krb5_conf.empty()); + CHECK(!service_principal.empty()); + CHECK(!keytab_file.empty()); + krb5_conf_ = std::move(krb5_conf); + service_principal_ = std::move(service_principal); + keytab_file_ = std::move(keytab_file); + protection_ = protection; +} +namespace { Status FindHomeDir(const char* name, const string& bin_dir, string* home_dir) { string name_upper; ToUpperCase(name, &name_upper); @@ -67,7 +81,6 @@ Status FindHomeDir(const char* name, const string& bin_dir, string* home_dir) { } return Status::OK(); } - } // anonymous namespace Status MiniHms::Start() { @@ -92,6 +105,7 @@ Status MiniHms::Start() { auto tmp_dir = GetTestDataDirectory(); RETURN_NOT_OK(CreateHiveSite(tmp_dir)); + RETURN_NOT_OK(CreateCoreSite(tmp_dir)); // Comma-separated list of additional jars to add to the HMS classpath. 
string aux_jars = Substitute("$0/hms-plugin.jar", bin_dir); @@ -101,7 +115,11 @@ Status MiniHms::Start() { { "HIVE_AUX_JARS_PATH", aux_jars }, { "HIVE_CONF_DIR", tmp_dir }, { "JAVA_TOOL_OPTIONS", "-Dhive.log.level=WARN -Dhive.root.logger=console" }, + { "HADOOP_CONF_DIR", tmp_dir }, }; + if (!krb5_conf_.empty()) { + env_vars["JAVA_TOOL_OPTIONS"] += Substitute(" -Djava.security.krb5.conf=$0", krb5_conf_); + } // Start the HMS. hms_process_.reset(new Subprocess({ @@ -150,10 +168,19 @@ Status MiniHms::Resume() { } Status MiniHms::CreateHiveSite(const string& tmp_dir) const { - // 'datanucleus.schema.autoCreateAll' and 'hive.metastore.schema.verification' - // allow Hive to startup and run without first running the schemaTool. - // 'hive.metastore.event.db.listener.timetolive' configures how long the - // Metastore will store notification log events before GCing them. + + // - datanucleus.schema.autoCreateAll + // - hive.metastore.schema.verification + // Allow Hive to startup and run without first running the schemaTool. + // + // - hive.metastore.event.db.listener.timetolive + // Configures how long the Metastore will store notification log events + // before GCing them. + // + // - hive.metastore.sasl.enabled + // - hive.metastore.kerberos.keytab.file + // - hive.metastore.kerberos.principal + // Configures the HMS to use Kerberos for its Thrift RPC interface. 
static const string kFileTemplate = R"( <configuration> <property> @@ -188,17 +215,67 @@ Status MiniHms::CreateHiveSite(const string& tmp_dir) const { <name>hive.metastore.event.db.listener.timetolive</name> <value>$0s</value> </property> + + <property> + <name>hive.metastore.sasl.enabled</name> + <value>$2</value> + </property> + + <property> + <name>hive.metastore.kerberos.keytab.file</name> + <value>$3</value> + </property> + + <property> + <name>hive.metastore.kerberos.principal</name> + <value>$4</value> + </property> + + <property> + <name>hadoop.rpc.protection</name> + <value>$5</value> + </property> </configuration> )"; string file_contents = strings::Substitute(kFileTemplate, notification_log_ttl_.ToSeconds(), - tmp_dir); + tmp_dir, + !keytab_file_.empty(), + keytab_file_, + service_principal_, + SaslProtection::name_of(protection_)); return WriteStringToFile(Env::Default(), file_contents, JoinPathSegments(tmp_dir, "hive-site.xml")); } +Status MiniHms::CreateCoreSite(const string& tmp_dir) const { + + // - hadoop.security.authentication + // The HMS uses Hadoop's UGI contraption which will refuse to login a user + // with Kerberos unless this special property is set. The property must + // not be in hive-site.xml because a new Configuration object is created + // to search for the property, and it only checks places Hadoop knows + // about. + + static const string kFileTemplate = R"( +<configuration> + <property> + <name>hadoop.security.authentication</name> + <value>$0</value> + </property> +</configuration> + )"; + + string file_contents = strings::Substitute(kFileTemplate, + keytab_file_.empty() ? 
"simple" : "kerberos"); + + return WriteStringToFile(Env::Default(), + file_contents, + JoinPathSegments(tmp_dir, "core-site.xml")); +} + } // namespace hms } // namespace kudu http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/hms/mini_hms.h ---------------------------------------------------------------------- diff --git a/src/kudu/hms/mini_hms.h b/src/kudu/hms/mini_hms.h index 4bef966..52f480b 100644 --- a/src/kudu/hms/mini_hms.h +++ b/src/kudu/hms/mini_hms.h @@ -24,6 +24,7 @@ #include <glog/logging.h> #include "kudu/gutil/port.h" +#include "kudu/rpc/sasl_common.h" #include "kudu/util/monotime.h" #include "kudu/util/net/net_util.h" #include "kudu/util/status.h" @@ -45,6 +46,12 @@ class MiniHms { notification_log_ttl_ = ttl; } + // Configures the mini HMS to use Kerberos. + void EnableKerberos(std::string krb5_conf, + std::string service_principal, + std::string keytab_file, + rpc::SaslProtection::Type protection); + // Starts the mini Hive metastore. // // If the MiniHms has already been started and stopped, it will be restarted @@ -71,12 +78,21 @@ class MiniHms { // Creates a hive-site.xml for the mini HMS. Status CreateHiveSite(const std::string& tmp_dir) const WARN_UNUSED_RESULT; + // Creates a core-site.xml for the mini HMS. + Status CreateCoreSite(const std::string& tmp_dir) const WARN_UNUSED_RESULT; + // Waits for the metastore process to bind to a port. 
Status WaitForHmsPorts() WARN_UNUSED_RESULT; std::unique_ptr<Subprocess> hms_process_; MonoDelta notification_log_ttl_ = MonoDelta::FromSeconds(86400); uint16_t port_ = 0; + + // Kerberos configuration + std::string krb5_conf_; + std::string service_principal_; + std::string keytab_file_; + rpc::SaslProtection::Type protection_ = rpc::SaslProtection::kAuthentication; }; } // namespace hms http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/hms/sasl_client_transport.cc ---------------------------------------------------------------------- diff --git a/src/kudu/hms/sasl_client_transport.cc b/src/kudu/hms/sasl_client_transport.cc new file mode 100644 index 0000000..95d1222 --- /dev/null +++ b/src/kudu/hms/sasl_client_transport.cc @@ -0,0 +1,402 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "kudu/hms/sasl_client_transport.h" + +#include <algorithm> +#include <cstring> +#include <limits> +#include <memory> +#include <ostream> +#include <string> + +#include <glog/logging.h> +#include <thrift/transport/TTransport.h> + +#include "kudu/gutil/endian.h" +#include "kudu/gutil/port.h" +#include "kudu/gutil/strings/human_readable.h" +#include "kudu/gutil/strings/substitute.h" +#include "kudu/rpc/sasl_common.h" +#include "kudu/rpc/sasl_helper.h" +#include "kudu/util/faststring.h" +#include "kudu/util/logging.h" +#include "kudu/util/slice.h" +#include "kudu/util/status.h" + +using apache::thrift::transport::TTransportException; +using std::shared_ptr; +using std::string; +using strings::Substitute; + +namespace kudu { + +using rpc::SaslMechanism; +using rpc::WrapSaslCall; + +namespace hms { + +namespace { + +// SASL negotiation frames are sent with an 8-bit status and a 32-bit length. +const uint32_t kSaslHeaderSize = sizeof(uint8_t) + sizeof(uint32_t); + +// Frame headers consist of a 32-bit length. +const uint32_t kFrameHeaderSize = sizeof(uint32_t); + +// SASL SASL_CB_GETOPT callback function. +int GetoptCb(SaslClientTransport* client_transport, + const char* plugin_name, + const char* option, + const char** result, + unsigned* len) { + return client_transport->GetOptionCb(plugin_name, option, result, len); +} + +// SASL SASL_CB_CANON_USER callback function. +int CanonUserCb(sasl_conn_t* /*conn*/, + void* /*context*/, + const char* in, unsigned inlen, + unsigned /*flags*/, + const char* /*user_realm*/, + char* out, unsigned out_max, unsigned* out_len) { + CHECK_LE(inlen, out_max); + memcpy(out, in, inlen); + *out_len = inlen; + return SASL_OK; +} + +// SASL SASL_CB_USER callback function. +int UserCb(void* /*context*/, int id, const char** result, unsigned* len) { + CHECK_EQ(SASL_CB_USER, id); + + // Setting the username to the empty string causes the remote end to use the + // clients Kerberos principal, which is correct. 
+ *result = ""; + if (len != nullptr) *len = 0; + return SASL_OK; +} +} // anonymous namespace + +SaslClientTransport::SaslClientTransport(const string& server_fqdn, + shared_ptr<TTransport> transport, + size_t max_recv_buf_size) + : transport_(std::move(transport)), + sasl_helper_(rpc::SaslHelper::CLIENT), + sasl_callbacks_({ + rpc::SaslBuildCallback(SASL_CB_GETOPT, reinterpret_cast<int (*)()>(&GetoptCb), this), + rpc::SaslBuildCallback(SASL_CB_CANON_USER, + reinterpret_cast<int (*)()>(&CanonUserCb), + this), + rpc::SaslBuildCallback(SASL_CB_USER, reinterpret_cast<int (*)()>(&UserCb), nullptr), + rpc::SaslBuildCallback(SASL_CB_LIST_END, nullptr, nullptr) + }), + needs_wrap_(false), + max_recv_buf_size_(max_recv_buf_size), + // Set a reasonable max send buffer size for negotiation. Once negotiation + // is complete the negotiated value will be used. + max_send_buf_size_(64 * 1024) { + sasl_helper_.set_server_fqdn(server_fqdn); + sasl_helper_.EnableGSSAPI(); + ResetWriteBuf(); +} + +bool SaslClientTransport::isOpen() { + return transport_->isOpen(); +} + +bool SaslClientTransport::peek() { + return !read_slice_.empty() || transport_->peek(); +} + +void SaslClientTransport::open() { + transport_->open(); + DCHECK(transport_->isOpen()); + try { + Negotiate(); + } catch (...) 
{ + transport_->close(); + throw; + } +} + +void SaslClientTransport::close() { + transport_->close(); + sasl_conn_.reset(); +} + +void SaslClientTransport::ReadFrame() { + DCHECK_EQ(0, read_buf_.size()); + DCHECK(read_slice_.empty()); + + uint8_t payload_len_buf[kFrameHeaderSize]; + transport_->readAll(payload_len_buf, kFrameHeaderSize); + size_t payload_len = NetworkByteOrder::Load32(payload_len_buf); + + if (payload_len > 1024 * 1024) { + KLOG_EVERY_N_SECS(WARNING, 60) << "Received large Thrift SASL frame: " + << HumanReadableNumBytes::ToString(payload_len); + if (payload_len > max_recv_buf_size_) { + throw TTransportException(Substitute("Thrift SASL frame is too long: $0/$1", + HumanReadableNumBytes::ToString(payload_len), + HumanReadableNumBytes::ToString(max_recv_buf_size_))); + } + } + + read_buf_.reserve(kFrameHeaderSize + payload_len); + read_buf_.append(payload_len_buf, kFrameHeaderSize); + read_buf_.resize(kFrameHeaderSize + payload_len); + transport_->readAll(&read_buf_.data()[kFrameHeaderSize], payload_len); + + if (needs_wrap_) { + // Point read_slice_ directly at the SASL library's internal buffer. This + // avoids having to copy the decoded data back into read_buf_. + Status s = rpc::SaslDecode(sasl_conn_.get(), read_buf_, &read_slice_); + if (!s.ok()) { + throw SaslException(s); + } + ResetReadBuf(); + } else { + read_slice_ = read_buf_; + read_slice_.remove_prefix(kFrameHeaderSize); + } +} + +uint32_t SaslClientTransport::read(uint8_t* buf, uint32_t len) { + // If there is nothing left to read in the buffer, then fill it. + if (read_slice_.empty()) { + ReadFrame(); + } + + uint32_t n = std::min(read_slice_.size(), static_cast<size_t>(len)); + memcpy(buf, read_slice_.data(), n); + read_slice_.remove_prefix(n); + if (read_slice_.empty()) { + ResetReadBuf(); + } + return n; +} + +void SaslClientTransport::write(const uint8_t* buf, uint32_t len) { + // Check that we've already preallocated space in the buffer for the frame-header. 
+ DCHECK(write_buf_.size() >= kFrameHeaderSize); + + // Check if the amount to write would overflow a frame. + while (write_buf_.size() + len > max_send_buf_size_) { + uint32_t n = max_send_buf_size_ - write_buf_.size(); + write_buf_.append(buf, n); + flush(); + buf += n; + len -= n; + } + + write_buf_.append(buf, len); +} + +void SaslClientTransport::flush() { + if (needs_wrap_) { + Slice plaintext(write_buf_); + plaintext.remove_prefix(kFrameHeaderSize); + Slice ciphertext; + Status s = rpc::SaslEncode(sasl_conn_.get(), plaintext, &ciphertext); + if (!s.ok()) { + throw SaslException(s); + } + + // Note: when the SASL C library encodes the plaintext, it prefixes the + // ciphertext with the length. Since this happens to match the SASL/Thrift + // frame format, we can send the ciphertext unmodified to the remote server. + transport_->write(ciphertext.data(), ciphertext.size()); + } else { + size_t payload_len = write_buf_.size() - kFrameHeaderSize; + NetworkByteOrder::Store32(write_buf_.data(), payload_len); + transport_->write(write_buf_.data(), write_buf_.size()); + } + + transport_->flush(); + ResetWriteBuf(); +} + +void SaslClientTransport::Negotiate() { + SetupSaslContext(); + + faststring recv_buf; + SendSaslStart(); + + for (;;) { + NegotiationStatus status = ReceiveSaslMessage(&recv_buf); + + if (status == TSASL_COMPLETE) { + throw SaslException( + Status::IllegalState("Received SASL COMPLETE status, but handshake is not finished")); + } + CHECK_EQ(status, TSASL_OK); + + const char* out; + unsigned out_len; + Status s = WrapSaslCall(sasl_conn_.get(), [&] { + return sasl_client_step(sasl_conn_.get(), + reinterpret_cast<const char*>(recv_buf.data()), + recv_buf.size(), + nullptr, + &out, + &out_len); + }); + + if (PREDICT_FALSE(!s.IsIncomplete() && !s.ok())) { + throw SaslException(std::move(s)); + } + + SendSaslMessage(status, Slice(out, out_len)); + transport_->flush(); + + if (s.ok()) { + break; + } + } + + NegotiationStatus status = 
ReceiveSaslMessage(&recv_buf); + if (status != TSASL_COMPLETE) { + throw SaslException( + Status::IllegalState("Received SASL OK status, but expected SASL COMPLETE")); + } + DCHECK_EQ(0, recv_buf.size()); + + needs_wrap_ = rpc::NeedsWrap(sasl_conn_.get()); + max_send_buf_size_ = rpc::GetMaxSendBufferSize(sasl_conn_.get()); + VLOG(2) << "Thrift SASL GSSAPI negotiation complete" + << "; needs wrap: " << (needs_wrap_ ? "true" : "false") + << ", max send frame length: " + << HumanReadableNumBytes::ToStringWithoutRounding(max_send_buf_size_) + << ", max receive frame length: " + << HumanReadableNumBytes::ToStringWithoutRounding(max_recv_buf_size_); +} + +void SaslClientTransport::SendSaslMessage(NegotiationStatus status, Slice payload) { + uint8_t header[kSaslHeaderSize]; + header[0] = status; + DCHECK_LT(payload.size(), std::numeric_limits<int32_t>::max()); + NetworkByteOrder::Store32(&header[1], payload.size()); + transport_->write(header, kSaslHeaderSize); + if (!payload.empty()) { + transport_->write(payload.data(), payload.size()); + } +} + +NegotiationStatus SaslClientTransport::ReceiveSaslMessage(faststring* payload) { + // Read the fixed-length message header. + uint8_t header[kSaslHeaderSize]; + transport_->readAll(header, kSaslHeaderSize); + size_t len = NetworkByteOrder::Load32(&header[1]); + + // Handle status errors. + switch (header[0]) { + case TSASL_OK: + case TSASL_COMPLETE: break; + case TSASL_BAD: + case TSASL_ERROR: + throw SaslException(Status::RuntimeError("SASL peer indicated failure")); + // The Thrift client should never receive TSASL_START. + case TSASL_START: + default: + throw SaslException(Status::RuntimeError("Unexpected SASL status", + std::to_string(header[0]))); + } + + // Read the message payload. 
+ if (len > max_recv_buf_size_) { + throw SaslException(Status::RuntimeError(Substitute( + "SASL negotiation message payload exceeds maximum length: $0/$1", + HumanReadableNumBytes::ToString(len), + HumanReadableNumBytes::ToString(max_recv_buf_size_)))); + } + payload->resize(len); + transport_->readAll(payload->data(), len); + + return static_cast<NegotiationStatus>(header[0]); +} + +void SaslClientTransport::SendSaslStart() { + const char* init_msg = nullptr; + unsigned init_msg_len = 0; + const char* negotiated_mech = nullptr; + + Status s = WrapSaslCall(sasl_conn_.get(), [&] { + return sasl_client_start( + sasl_conn_.get(), // The SASL connection context created by sasl_client_new() + SaslMechanism::name_of(SaslMechanism::GSSAPI), // The mechanism to use. + nullptr, // Disables INTERACT return if NULL. + &init_msg, // Filled in on success. + &init_msg_len, // Filled in on success. + &negotiated_mech); // Filled in on success. + }); + + if (PREDICT_FALSE(!s.IsIncomplete() && !s.ok())) { + throw SaslException(std::move(s)); + } + + // Check that the SASL library is using the mechanism that we picked. + DCHECK_EQ(SaslMechanism::value_of(negotiated_mech), SaslMechanism::GSSAPI); + s = rpc::EnableProtection(sasl_conn_.get(), + rpc::SaslProtection::kAuthentication, + max_recv_buf_size_); + if (!s.ok()) { + throw SaslException(s); + } + + // These two calls comprise a single message in the thrift-sasl protocol. + SendSaslMessage(TSASL_START, Slice(negotiated_mech)); + SendSaslMessage(TSASL_OK, Slice(init_msg, init_msg_len)); + transport_->flush(); +} + +int SaslClientTransport::GetOptionCb(const char* plugin_name, const char* option, + const char** result, unsigned* len) { + return sasl_helper_.GetOptionCb(plugin_name, option, result, len); +} + +void SaslClientTransport::SetupSaslContext() { + sasl_conn_t* sasl_conn = nullptr; + Status s = WrapSaslCall(nullptr /* no conn */, [&] { + return sasl_client_new( + // TODO(dan): make the service name configurable. 
+ "hive", // Registered name of the service using SASL. Required. + sasl_helper_.server_fqdn(), // The fully qualified domain name of the remote server. + nullptr, // Local and remote IP address strings. (we don't use + nullptr, // any mechanisms which require this info.) + sasl_callbacks_.data(), // Connection-specific callbacks. + 0, // flags + &sasl_conn); + }); + if (!s.ok()) { + throw SaslException(s); + } + sasl_conn_.reset(sasl_conn); +} + +void SaslClientTransport::ResetReadBuf() { + read_buf_.clear(); + read_buf_.shrink_to_fit(); +} + +void SaslClientTransport::ResetWriteBuf() { + write_buf_.resize(kFrameHeaderSize); + write_buf_.shrink_to_fit(); +} + +} // namespace hms +} // namespace kudu http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/hms/sasl_client_transport.h ---------------------------------------------------------------------- diff --git a/src/kudu/hms/sasl_client_transport.h b/src/kudu/hms/sasl_client_transport.h new file mode 100644 index 0000000..a2bc7f3 --- /dev/null +++ b/src/kudu/hms/sasl_client_transport.h @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <cstddef> +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include <sasl/sasl.h> +#include <thrift/transport/TTransportException.h> +#include <thrift/transport/TVirtualTransport.h> + +#include "kudu/rpc/sasl_helper.h" +#include "kudu/util/faststring.h" +#include "kudu/util/slice.h" +#include "kudu/util/status.h" + +namespace apache { +namespace thrift { +namespace transport { +class TTransport; +} // namespace transport +} // namespace thrift +} // namespace apache + +namespace kudu { +namespace rpc { +struct SaslDeleter; +} // namespace rpc +namespace hms { + +// An exception representing a SASL or Kerberos failure. +class SaslException : public apache::thrift::transport::TTransportException { + public: + explicit SaslException(Status status) + : TTransportException(status.ToString()), + status_(std::move(status)) { + } + + const Status& status() const { + return status_; + } + + private: + Status status_; +}; + +// An enum describing the possible states of the SASL negotiation protocol. +enum NegotiationStatus { + TSASL_INVALID = -1, + TSASL_START = 1, + TSASL_OK = 2, + TSASL_BAD = 3, + TSASL_ERROR = 4, + TSASL_COMPLETE = 5 +}; + +// A Thrift transport which uses SASL GSSAPI to authenticate as a client to a +// remote server. +// +// SaslClientTransport internally holds buffers, so it does not need the +// underlying transport to be buffered. 
+class SaslClientTransport + : public apache::thrift::transport::TVirtualTransport<SaslClientTransport> { + public: + SaslClientTransport(const std::string& server_fqdn, + std::shared_ptr<TTransport> transport, + size_t max_recv_buf_size); + + ~SaslClientTransport() override = default; + + bool isOpen() override; + + bool peek() override; + + void open() override; + + void close() override; + + uint32_t read(uint8_t* buf, uint32_t len); + + void write(const uint8_t* buf, uint32_t len); + + void flush() override; + + int GetOptionCb(const char* plugin_name, const char* option, + const char** result, unsigned* len); + + private: + + // Runs SASL negotiation with the remote server. + void Negotiate(); + + // Sends a SASL negotiation message to the underlying transport. + // + // Send a SASL negotiation message using the Thrift framing protocol: + // + // - 1 byte of status + // - 4 bytes of remaining length + // - var-len payload + void SendSaslMessage(NegotiationStatus status, Slice payload); + + // Receives a SASL negotiation message from the underlying transport. + // + // The returned negotiation status will be of type OK or COMPLETE, all + // other statuses result in an exception. + NegotiationStatus ReceiveSaslMessage(faststring* payload); + + // Initializes SASL state. + void SetupSaslContext(); + + // Sends the initial SASL connection message. + void SendSaslStart(); + + // Reads a frame from the underlying transport, storing the payload into + // read_slice_. If the connection is using SASL auth-conf or auth-int + // protection the data is automatically decoded. + void ReadFrame(); + + // Resets the read buffer to empty, and deallocates its internal buffer. + void ResetReadBuf(); + + // Resets the write buffer to the size of a frame header, and deallocates its + // internal buffer. + void ResetWriteBuf(); + + // The underlying transport. Typically a TCP socket. + std::shared_ptr<TTransport> transport_; + + // SASL state. 
+ rpc::SaslHelper sasl_helper_; + std::unique_ptr<sasl_conn_t, rpc::SaslDeleter> sasl_conn_; + std::vector<sasl_callback_t> sasl_callbacks_; + + // Whether the connection is using auth-int or auth-conf protection. + bool needs_wrap_; + + // The negotiated SASL maximum buffer sizes. These correspond to the maximum + // sized frames that can be received or sent. + // + // Note: the Java implementation of the Thrift SASL transport does not respect + // the negotiated maximum buffer size (THRIFT-4483) and never splits a message + // into multiple frames, so we end up having to set the recv buf size to match + // the largest serialized Thrift message we want to be able to receive. + size_t max_recv_buf_size_; + size_t max_send_buf_size_; + + // The read buffer and slice. The slice points to the remaining frame data + // which hasn't been read yet. + faststring read_buf_; + Slice read_slice_; + + // The write buffer. + faststring write_buf_; +}; + +} // namespace hms +} // namespace kudu http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/mini-cluster/external_mini_cluster-test.cc ---------------------------------------------------------------------- diff --git a/src/kudu/mini-cluster/external_mini_cluster-test.cc b/src/kudu/mini-cluster/external_mini_cluster-test.cc index c6787f4..8cd357c 100644 --- a/src/kudu/mini-cluster/external_mini_cluster-test.cc +++ b/src/kudu/mini-cluster/external_mini_cluster-test.cc @@ -76,12 +76,12 @@ class ExternalMiniClusterTest : public KuduTest, public testing::WithParamInterface<pair<Kerberos, HiveMetastore>> { }; -// TODO(dan): Add ENABLED/ENABLED when the mini HMS supports Kerberos. 
INSTANTIATE_TEST_CASE_P(KerberosOnAndOff, ExternalMiniClusterTest, testing::Values(make_pair(Kerberos::DISABLED, HiveMetastore::DISABLED), make_pair(Kerberos::ENABLED, HiveMetastore::DISABLED), - make_pair(Kerberos::DISABLED, HiveMetastore::ENABLED))); + make_pair(Kerberos::DISABLED, HiveMetastore::ENABLED), + make_pair(Kerberos::ENABLED, HiveMetastore::ENABLED))); void SmokeTestKerberizedCluster(ExternalMiniClusterOptions opts) { ASSERT_TRUE(opts.enable_kerberos); @@ -190,6 +190,17 @@ TEST_P(ExternalMiniClusterTest, TestBasicOperation) { ASSERT_EQ(ts_rpc.ToString(), ts->bound_rpc_hostport().ToString()); ASSERT_EQ(ts_http.ToString(), ts->bound_http_hostport().ToString()); + // Verify that the HMS is reachable. + if (opts.enable_hive_metastore) { + hms::HmsClientOptions hms_client_opts; + hms_client_opts.enable_kerberos = opts.enable_kerberos; + hms::HmsClient hms_client(cluster.hms()->address(), hms_client_opts); + ASSERT_OK(hms_client.Start()); + vector<string> tables; + ASSERT_OK(hms_client.GetAllTables("default", &tables)); + ASSERT_TRUE(tables.empty()) << "tables: " << tables; + } + // Verify that, in a Kerberized cluster, if we drop our Kerberos environment, // we can't make RPCs to a server. if (opts.enable_kerberos) { @@ -202,16 +213,6 @@ TEST_P(ExternalMiniClusterTest, TestBasicOperation) { "but client does not have Kerberos credentials available"); } - // Verify that the HMS is reachable. - if (opts.enable_hive_metastore) { - hms::HmsClient hms_client(cluster.hms()->address(), hms::HmsClientOptions()); - ASSERT_OK(hms_client.Start()); - vector<string> tables; - ASSERT_OK(hms_client.GetAllTables("default", &tables)); - ASSERT_TRUE(tables.empty()) << "tables: " << tables; - } - - // Test that if we inject a fault into a tablet server's boot process // ExternalTabletServer::Restart() still returns OK, even if the tablet server crashed. 
ts->Shutdown(); ts->mutable_flags()->push_back("--fault_before_start=1.0"); http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/mini-cluster/external_mini_cluster.cc ---------------------------------------------------------------------- diff --git a/src/kudu/mini-cluster/external_mini_cluster.cc b/src/kudu/mini-cluster/external_mini_cluster.cc index 55a75fa..b90c4ef 100644 --- a/src/kudu/mini-cluster/external_mini_cluster.cc +++ b/src/kudu/mini-cluster/external_mini_cluster.cc @@ -44,6 +44,7 @@ #include "kudu/master/master.proxy.h" #include "kudu/rpc/messenger.h" #include "kudu/rpc/rpc_controller.h" +#include "kudu/rpc/sasl_common.h" #include "kudu/security/test/mini_kdc.h" #include "kudu/server/server_base.pb.h" #include "kudu/server/server_base.proxy.h" @@ -185,6 +186,16 @@ Status ExternalMiniCluster::Start() { if (opts_.enable_hive_metastore) { hms_.reset(new hms::MiniHms()); + + if (opts_.enable_kerberos) { + string spn = "hive/127.0.0.1"; + string ktpath; + RETURN_NOT_OK_PREPEND(kdc_->CreateServiceKeytab(spn, &ktpath), + "could not create keytab"); + hms_->EnableKerberos(kdc_->GetEnvVars()["KRB5_CONFIG"], spn, ktpath, + rpc::SaslProtection::kAuthentication); + } + RETURN_NOT_OK_PREPEND(hms_->Start(), "Failed to start the Hive Metastore"); } http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/rpc/client_negotiation.cc ---------------------------------------------------------------------- diff --git a/src/kudu/rpc/client_negotiation.cc b/src/kudu/rpc/client_negotiation.cc index 71dde92..02175f6 100644 --- a/src/kudu/rpc/client_negotiation.cc +++ b/src/kudu/rpc/client_negotiation.cc @@ -596,7 +596,7 @@ Status ClientNegotiation::SendSaslInitiate() { // integrity protection so that the channel bindings and nonce can be // verified. 
if (negotiated_mech_ == SaslMechanism::GSSAPI) { - RETURN_NOT_OK(EnableIntegrityProtection(sasl_conn_.get())); + RETURN_NOT_OK(EnableProtection(sasl_conn_.get(), SaslProtection::kIntegrity)); } NegotiatePB msg; @@ -662,7 +662,7 @@ Status ClientNegotiation::HandleSaslSuccess(const NegotiatePB& response) { RETURN_NOT_OK_PREPEND(cert.GetServerEndPointChannelBindings(&expected_channel_bindings), "failed to generate channel bindings"); - string received_channel_bindings; + Slice received_channel_bindings; RETURN_NOT_OK_PREPEND(SaslDecode(sasl_conn_.get(), response.channel_bindings(), &received_channel_bindings), @@ -704,7 +704,9 @@ Status ClientNegotiation::SendConnectionContext() { if (nonce_) { // Reply with the SASL-protected nonce. We only set the nonce when using SASL GSSAPI. - RETURN_NOT_OK(SaslEncode(sasl_conn_.get(), *nonce_, conn_context.mutable_encoded_nonce())); + Slice ciphertext; + RETURN_NOT_OK(SaslEncode(sasl_conn_.get(), *nonce_, &ciphertext)); + *conn_context.mutable_encoded_nonce() = ciphertext.ToString(); } return SendFramedMessageBlocking(socket(), header, conn_context, deadline_); http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/rpc/sasl_common.cc ---------------------------------------------------------------------- diff --git a/src/kudu/rpc/sasl_common.cc b/src/kudu/rpc/sasl_common.cc index a377d16..645e854 100644 --- a/src/kudu/rpc/sasl_common.cc +++ b/src/kudu/rpc/sasl_common.cc @@ -17,7 +17,6 @@ #include "kudu/rpc/sasl_common.h" -#include <algorithm> #include <cstdio> #include <cstring> #include <limits> @@ -46,7 +45,7 @@ namespace rpc { const char* const kSaslMechPlain = "PLAIN"; const char* const kSaslMechGSSAPI = "GSSAPI"; -extern const size_t kSaslMaxOutBufLen = 1024; +extern const size_t kSaslMaxBufSize = 1024; // See WrapSaslCall(). 
static __thread string* g_auth_failure_capture = nullptr; @@ -351,45 +350,44 @@ Status WrapSaslCall(sasl_conn_t* conn, const std::function<int()>& call) { } } -Status SaslEncode(sasl_conn_t* conn, const std::string& plaintext, std::string* encoded) { - size_t offset = 0; - - // The SASL library can only encode up to a maximum amount at a time, so we - // have to call encode multiple times if our input is larger than this max. - while (offset < plaintext.size()) { - const char* out; - unsigned out_len; - size_t len = std::min(kSaslMaxOutBufLen, plaintext.size() - offset); - - RETURN_NOT_OK(WrapSaslCall(conn, [&]() { - return sasl_encode(conn, plaintext.data() + offset, len, &out, &out_len); - })); +bool NeedsWrap(sasl_conn_t* sasl_conn) { + const unsigned* ssf; + int rc = sasl_getprop(sasl_conn, SASL_SSF, reinterpret_cast<const void**>(&ssf)); + CHECK_EQ(rc, SASL_OK) << "Failed to get SSF property on authenticated SASL connection"; + return *ssf != 0; +} - encoded->append(out, out_len); - offset += len; - } +uint32_t GetMaxSendBufferSize(sasl_conn_t* sasl_conn) { + const unsigned* max_buf_size; + int rc = sasl_getprop(sasl_conn, SASL_MAXOUTBUF, reinterpret_cast<const void**>(&max_buf_size)); + CHECK_EQ(rc, SASL_OK) + << "Failed to get max output buffer property on authenticated SASL connection"; + return *max_buf_size; +} +Status SaslEncode(sasl_conn_t* conn, Slice plaintext, Slice* ciphertext) { + const char* out; + unsigned out_len; + RETURN_NOT_OK_PREPEND(WrapSaslCall(conn, [&] { + return sasl_encode(conn, + reinterpret_cast<const char*>(plaintext.data()), + plaintext.size(), + &out, &out_len); + }), "SASL encode failed"); + *ciphertext = Slice(out, out_len); return Status::OK(); } -Status SaslDecode(sasl_conn_t* conn, const string& encoded, string* plaintext) { - size_t offset = 0; - - // The SASL library can only decode up to a maximum amount at a time, so we - // have to call decode multiple times if our input is larger than this max. 
- while (offset < encoded.size()) { - const char* out; - unsigned out_len; - size_t len = std::min(kSaslMaxOutBufLen, encoded.size() - offset); - - RETURN_NOT_OK(WrapSaslCall(conn, [&]() { - return sasl_decode(conn, encoded.data() + offset, len, &out, &out_len); - })); - - plaintext->append(out, out_len); - offset += len; - } - +Status SaslDecode(sasl_conn_t* conn, Slice ciphertext, Slice* plaintext) { + const char* out; + unsigned out_len; + RETURN_NOT_OK_PREPEND(WrapSaslCall(conn, [&] { + return sasl_decode(conn, + reinterpret_cast<const char*>(ciphertext.data()), + ciphertext.size(), + &out, &out_len); + }), "SASL decode failed"); + *plaintext = Slice(out, out_len); return Status::OK(); } @@ -425,14 +423,16 @@ sasl_callback_t SaslBuildCallback(int id, int (*proc)(void), void* context) { return callback; } -Status EnableIntegrityProtection(sasl_conn_t* sasl_conn) { +Status EnableProtection(sasl_conn_t* sasl_conn, + SaslProtection::Type minimum_protection, + size_t max_recv_buf_size) { sasl_security_properties_t sec_props; memset(&sec_props, 0, sizeof(sec_props)); - sec_props.min_ssf = 1; + sec_props.min_ssf = minimum_protection; sec_props.max_ssf = std::numeric_limits<sasl_ssf_t>::max(); - sec_props.maxbufsize = kSaslMaxOutBufLen; + sec_props.maxbufsize = max_recv_buf_size; - RETURN_NOT_OK_PREPEND(WrapSaslCall(sasl_conn, [&] () { + RETURN_NOT_OK_PREPEND(WrapSaslCall(sasl_conn, [&] { return sasl_setprop(sasl_conn, SASL_SEC_PROPS, &sec_props); }), "failed to set SASL security properties"); return Status::OK(); @@ -457,5 +457,14 @@ const char* SaslMechanism::name_of(SaslMechanism::Type val) { } } +const char* SaslProtection::name_of(SaslProtection::Type val) { + switch (val) { + case SaslProtection::kAuthentication: return "authentication"; + case SaslProtection::kIntegrity: return "integrity"; + case SaslProtection::kPrivacy: return "privacy"; + } + LOG(FATAL) << "unknown SASL protection type: " << val; +} + } // namespace rpc } // namespace kudu 
http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/rpc/sasl_common.h ---------------------------------------------------------------------- diff --git a/src/kudu/rpc/sasl_common.h b/src/kudu/rpc/sasl_common.h index 888e7cb..2454cfd 100644 --- a/src/kudu/rpc/sasl_common.h +++ b/src/kudu/rpc/sasl_common.h @@ -19,13 +19,15 @@ #define KUDU_RPC_SASL_COMMON_H #include <cstddef> +#include <cstdint> #include <functional> -#include <string> #include <set> +#include <string> #include <sasl/sasl.h> #include "kudu/gutil/port.h" +#include "kudu/util/slice.h" #include "kudu/util/status.h" namespace kudu { @@ -37,7 +39,7 @@ namespace rpc { // Constants extern const char* const kSaslMechPlain; extern const char* const kSaslMechGSSAPI; -extern const size_t kSaslMaxOutBufLen; +extern const size_t kSaslMaxBufSize; struct SaslMechanism { enum Type { @@ -49,6 +51,18 @@ struct SaslMechanism { static const char* name_of(Type val); }; +struct SaslProtection { + enum Type { + // SASL authentication without integrity or privacy. + kAuthentication = 0, + // Integrity protection, i.e. messages are HMAC'd. + kIntegrity = 1, + // Privacy protection, i.e. messages are encrypted. + kPrivacy = 2, + }; + static const char* name_of(Type val); +}; + // Initialize the SASL library. // appname: Name of the application for logging messages & sasl plugin configuration. // Note that this string must remain allocated for the lifetime of the program. @@ -93,19 +107,38 @@ std::set<SaslMechanism::Type> SaslListAvailableMechs(); // context: An object to pass to the callback as the context pointer, or NULL. sasl_callback_t SaslBuildCallback(int id, int (*proc)(void), void* context); -// Require integrity protection on the SASL connection. Should be called before -// initiating the SASL negotiation. -Status EnableIntegrityProtection(sasl_conn_t* sasl_conn) WARN_UNUSED_RESULT; +// Enable SASL integrity and privacy protection on the connection. 
Also allows +// setting the minimum required protection level, and the maximum receive buffer +// size. +Status EnableProtection(sasl_conn_t* sasl_conn, + SaslProtection::Type minimum_protection = SaslProtection::kAuthentication, + size_t max_recv_buf_size = kSaslMaxBufSize) WARN_UNUSED_RESULT; -// Encode the provided data and append it to 'encoded'. +// Returns true if the SASL connection has been negotiated with auth-int or +// auth-conf. 'sasl_conn' must already be negotiated. +bool NeedsWrap(sasl_conn_t* sasl_conn); + +// Retrieves the negotiated maximum send buffer size for auth-int or auth-conf +// protected channels. +uint32_t GetMaxSendBufferSize(sasl_conn_t* sasl_conn) WARN_UNUSED_RESULT; + +// Encode the provided data. +// +// The plaintext data must not be longer than the negotiated maximum buffer size. +// +// The output 'ciphertext' slice is only valid until the next use of the SASL connection. Status SaslEncode(sasl_conn_t* conn, - const std::string& plaintext, - std::string* encoded) WARN_UNUSED_RESULT; + Slice plaintext, + Slice* ciphertext) WARN_UNUSED_RESULT; -// Decode the provided SASL-encoded data and append it to 'plaintext'. +// Decode the provided SASL-encoded data. +// +// The decoded plaintext must not be longer than the negotiated maximum buffer size. +// +// The output 'plaintext' slice is only valid until the next use of the SASL connection. 
Status SaslDecode(sasl_conn_t* conn, - const std::string& encoded, - std::string* plaintext) WARN_UNUSED_RESULT; + Slice ciphertext, + Slice* plaintext) WARN_UNUSED_RESULT; // Deleter for sasl_conn_t instances, for use with gscoped_ptr after calling sasl_*_new() struct SaslDeleter { http://git-wip-us.apache.org/repos/asf/kudu/blob/57fe0c3d/src/kudu/rpc/server_negotiation.cc ---------------------------------------------------------------------- diff --git a/src/kudu/rpc/server_negotiation.cc b/src/kudu/rpc/server_negotiation.cc index b12b54b..a623853 100644 --- a/src/kudu/rpc/server_negotiation.cc +++ b/src/kudu/rpc/server_negotiation.cc @@ -789,7 +789,7 @@ Status ServerNegotiation::HandleSaslInitiate(const NegotiatePB& request) { // integrity protection so that the channel bindings and nonce can be // verified. if (negotiated_mech_ == SaslMechanism::GSSAPI) { - RETURN_NOT_OK(EnableIntegrityProtection(sasl_conn_.get())); + RETURN_NOT_OK(EnableProtection(sasl_conn_.get(), SaslProtection::kIntegrity)); } const char* server_out = nullptr; @@ -884,9 +884,12 @@ Status ServerNegotiation::SendSaslSuccess() { string plaintext_channel_bindings; RETURN_NOT_OK(cert.GetServerEndPointChannelBindings(&plaintext_channel_bindings)); + + Slice ciphertext; RETURN_NOT_OK(SaslEncode(sasl_conn_.get(), plaintext_channel_bindings, - response.mutable_channel_bindings())); + &ciphertext)); + *response.mutable_channel_bindings() = ciphertext.ToString(); } } @@ -919,7 +922,7 @@ Status ServerNegotiation::RecvConnectionContext(faststring* recv_buf) { return Status::NotAuthorized("ConnectionContextPB wrapped nonce missing"); } - string decoded_nonce; + Slice decoded_nonce; s = SaslDecode(sasl_conn_.get(), conn_context.encoded_nonce(), &decoded_nonce); if (!s.ok()) { return Status::NotAuthorized("failed to decode nonce", s.message());