http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d6abb29d/be/src/kudu/util/interval_tree-inl.h ---------------------------------------------------------------------- diff --git a/be/src/kudu/util/interval_tree-inl.h b/be/src/kudu/util/interval_tree-inl.h new file mode 100644 index 0000000..7637317 --- /dev/null +++ b/be/src/kudu/util/interval_tree-inl.h @@ -0,0 +1,444 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+#ifndef KUDU_UTIL_INTERVAL_TREE_INL_H +#define KUDU_UTIL_INTERVAL_TREE_INL_H + +#include <algorithm> +#include <vector> + +#include "kudu/util/interval_tree.h" + +namespace kudu { + +template<class Traits> +IntervalTree<Traits>::IntervalTree(const IntervalVector &intervals) + : root_(NULL) { + if (!intervals.empty()) { + root_ = CreateNode(intervals); + } +} + +template<class Traits> +IntervalTree<Traits>::~IntervalTree() { + delete root_; +} + +template<class Traits> +template<class QueryPointType> +void IntervalTree<Traits>::FindContainingPoint(const QueryPointType &query, + IntervalVector *results) const { + if (root_) { + root_->FindContainingPoint(query, results); + } +} + +template<class Traits> +template<class Callback, class QueryContainer> +void IntervalTree<Traits>::ForEachIntervalContainingPoints( + const QueryContainer& queries, + const Callback& cb) const { + if (root_) { + root_->ForEachIntervalContainingPoints(queries.begin(), queries.end(), cb); + } +} + + +template<class Traits> +void IntervalTree<Traits>::FindIntersectingInterval(const interval_type &query, + IntervalVector *results) const { + if (root_) { + root_->FindIntersectingInterval(query, results); + } +} + +template<class Traits> +static bool LessThan(const typename Traits::point_type &a, + const typename Traits::point_type &b) { + return Traits::compare(a, b) < 0; +} + +// Select a split point which attempts to evenly divide 'in' into three groups: +// (a) those that are fully left of the split point +// (b) those that overlap the split point. +// (c) those that are fully right of the split point +// These three groups are stored in the output parameters '*left', '*overlapping', +// and '*right', respectively. The selected split point is stored in *split_point. 
+// +// For example, the input interval set: +// +// |------1-------| |-----2-----| +// |--3--| |---4--| |----5----| +// | +// Resulting split: | Partition point +// | +// +// *left: intervals 1 and 3 +// *overlapping: interval 4 +// *right: intervals 2 and 5 +template<class Traits> +void IntervalTree<Traits>::Partition(const IntervalVector &in, + point_type *split_point, + IntervalVector *left, + IntervalVector *overlapping, + IntervalVector *right) { + CHECK(!in.empty()); + + // Pick a split point which is the median of all of the interval boundaries. + std::vector<point_type> endpoints; + endpoints.reserve(in.size() * 2); + for (const interval_type &interval : in) { + endpoints.push_back(Traits::get_left(interval)); + endpoints.push_back(Traits::get_right(interval)); + } + std::sort(endpoints.begin(), endpoints.end(), LessThan<Traits>); + *split_point = endpoints[endpoints.size() / 2]; + + // Partition into the groups based on the determined split point. + for (const interval_type &interval : in) { + if (Traits::compare(Traits::get_right(interval), *split_point) < 0) { + // | split point + // |------------| | + // interval + left->push_back(interval); + } else if (Traits::compare(Traits::get_left(interval), *split_point) > 0) { + // | split point + // | |------------| + // interval + right->push_back(interval); + } else { + // | split point + // | + // |------------| + // interval + overlapping->push_back(interval); + } + } +} + +template<class Traits> +typename IntervalTree<Traits>::node_type *IntervalTree<Traits>::CreateNode( + const IntervalVector &intervals) { + IntervalVector left, right, overlap; + point_type split_point; + + // First partition the input intervals and select a split point + Partition(intervals, &split_point, &left, &overlap, &right); + + // Recursively subdivide the intervals which are fully left or fully + // right of the split point into subtree nodes. + node_type *left_node = !left.empty() ? 
CreateNode(left) : NULL; + node_type *right_node = !right.empty() ? CreateNode(right) : NULL; + + return new node_type(split_point, left_node, overlap, right_node); +} + +namespace interval_tree_internal { + +// Node in the interval tree. +template<typename Traits> +class ITNode { + private: + // Import types. + typedef std::vector<typename Traits::interval_type> IntervalVector; + typedef typename Traits::interval_type interval_type; + typedef typename Traits::point_type point_type; + + public: + ITNode(point_type split_point, + ITNode<Traits> *left, + const IntervalVector &overlap, + ITNode<Traits> *right); + ~ITNode(); + + // See IntervalTree::FindContainingPoint(...) + template<class QueryPointType> + void FindContainingPoint(const QueryPointType &query, + IntervalVector *results) const; + + // See IntervalTree::ForEachIntervalContainingPoints(). + // We use iterators here since as recursion progresses down the tree, we + // process sub-sequences of the original set of query points. + template<class Callback, class ItType> + void ForEachIntervalContainingPoints(ItType begin_queries, + ItType end_queries, + const Callback& cb) const; + + // See IntervalTree::FindIntersectingInterval(...) + void FindIntersectingInterval(const interval_type &query, + IntervalVector *results) const; + + private: + // Comparators for sorting lists of intervals. + static bool SortByAscLeft(const interval_type &a, const interval_type &b); + static bool SortByDescRight(const interval_type &a, const interval_type &b); + + // Partition point of this node. + point_type split_point_; + + // Those nodes that overlap with split_point_, in ascending order by their left side. + IntervalVector overlapping_by_asc_left_; + + // Those nodes that overlap with split_point_, in descending order by their right side. + IntervalVector overlapping_by_desc_right_; + + // Tree node for intervals fully left of split_point_, or NULL. 
+ ITNode *left_; + + // Tree node for intervals fully right of split_point_, or NULL. + ITNode *right_; + + DISALLOW_COPY_AND_ASSIGN(ITNode); +}; + +template<class Traits> +bool ITNode<Traits>::SortByAscLeft(const interval_type &a, const interval_type &b) { + return Traits::compare(Traits::get_left(a), Traits::get_left(b)) < 0; +} + +template<class Traits> +bool ITNode<Traits>::SortByDescRight(const interval_type &a, const interval_type &b) { + return Traits::compare(Traits::get_right(a), Traits::get_right(b)) > 0; +} + +template <class Traits> +ITNode<Traits>::ITNode(typename Traits::point_type split_point, + ITNode<Traits> *left, const IntervalVector &overlap, + ITNode<Traits> *right) + : split_point_(std::move(split_point)), left_(left), right_(right) { + // Store two copies of the set of intervals which overlap the split point: + // 1) Sorted by ascending left boundary + overlapping_by_asc_left_.assign(overlap.begin(), overlap.end()); + std::sort(overlapping_by_asc_left_.begin(), overlapping_by_asc_left_.end(), SortByAscLeft); + // 2) Sorted by descending right boundary + overlapping_by_desc_right_.assign(overlap.begin(), overlap.end()); + std::sort(overlapping_by_desc_right_.begin(), overlapping_by_desc_right_.end(), SortByDescRight); +} + +template<class Traits> +ITNode<Traits>::~ITNode() { + if (left_) delete left_; + if (right_) delete right_; +} + +template<class Traits> +template<class Callback, class ItType> +void ITNode<Traits>::ForEachIntervalContainingPoints(ItType begin_queries, + ItType end_queries, + const Callback& cb) const { + if (begin_queries == end_queries) return; + + typedef decltype(*begin_queries) QueryPointType; + const auto& partitioner = [&](const QueryPointType& query_point) { + return Traits::compare(query_point, split_point_) < 0; + }; + + // Partition the query points into those less than the split_point_ and those greater + // than or equal to the split_point_. 
Because the input queries are already sorted, we + // can use 'std::partition_point' instead of 'std::partition'. + // + // The resulting 'partition_point' is the first query point in the second group. + // + // Complexity: O(log(number of query points)) + DCHECK(std::is_partitioned(begin_queries, end_queries, partitioner)); + auto partition_point = std::partition_point(begin_queries, end_queries, partitioner); + + // Recurse left: any query points left of the split point may intersect + // with non-overlapping intervals fully-left of our split point. + if (left_ != NULL) { + left_->ForEachIntervalContainingPoints(begin_queries, partition_point, cb); + } + + // Handle the query points < split_point + // + // split_point_ + // | + // [------] \ + // [-------] | overlapping_by_asc_left_ + // [--------] / + // Q Q Q + // ^ ^ \___ not handled (right of split_point_) + // | | + // \___\___ these points will be handled here + // + + // Lower bound of query points still relevant. + auto rem_queries = begin_queries; + for (const interval_type &interval : overlapping_by_asc_left_) { + const auto& interval_left = Traits::get_left(interval); + // Find those query points which are right of the left side of the interval. + // 'first_match' here is the first query point >= interval_left. + // Complexity: O(log(num_queries)) + // + // TODO(todd): The non-batched implementation is O(log(num_intervals) * num_queries) + // whereas this loop ends up O(num_intervals * log(num_queries)). So, for + // small numbers of queries this is not the fastest way to structure these loops. 
+ auto first_match = std::partition_point( + rem_queries, partition_point, + [&](const QueryPointType& query_point) { + return Traits::compare(query_point, interval_left) < 0; + }); + for (auto it = first_match; it != partition_point; ++it) { + cb(*it, interval); + } + // Since the intervals are sorted in ascending-left order, we can start + // the search for the next interval at the first match in this interval. + // (any query point which was left of the current interval will also be left + // of all future intervals). + rem_queries = std::move(first_match); + } + + // Handle the query points >= split_point + // + // split_point_ + // | + // [--------] \ + // [-------] | overlapping_by_desc_right_ + // [------] / + // Q Q Q + // | \______\___ these points will be handled here + // | + // \___ not handled (left of split_point_) + + // Upper bound of query points still relevant. + rem_queries = end_queries; + for (const interval_type &interval : overlapping_by_desc_right_) { + const auto& interval_right = Traits::get_right(interval); + // Find the first query point which is > the right side of the interval. + auto first_non_match = std::partition_point( + partition_point, rem_queries, + [&](const QueryPointType& query_point) { + return Traits::compare(query_point, interval_right) <= 0; + }); + for (auto it = partition_point; it != first_non_match; ++it) { + cb(*it, interval); + } + // Same logic as above: if a query point was fully right of 'interval', + // then it will be fully right of all following intervals because they are + // sorted by descending-right. 
+ rem_queries = std::move(first_non_match); + } + + if (right_ != NULL) { + while (partition_point != end_queries && + Traits::compare(*partition_point, split_point_) == 0) { + ++partition_point; + } + right_->ForEachIntervalContainingPoints(partition_point, end_queries, cb); + } +} + +template<class Traits> +template<class QueryPointType> +void ITNode<Traits>::FindContainingPoint(const QueryPointType &query, + IntervalVector *results) const { + int cmp = Traits::compare(query, split_point_); + if (cmp < 0) { + // None of the intervals in right_ may intersect this. + if (left_ != NULL) { + left_->FindContainingPoint(query, results); + } + + // Any intervals which start before the query point and overlap the split point + // must therefore contain the query point. + auto p = std::partition_point( + overlapping_by_asc_left_.cbegin(), overlapping_by_asc_left_.cend(), + [&](const interval_type& interval) { + return Traits::compare(Traits::get_left(interval), query) <= 0; + }); + results->insert(results->end(), overlapping_by_asc_left_.cbegin(), p); + } else if (cmp > 0) { + // None of the intervals in left_ may intersect this. + if (right_ != NULL) { + right_->FindContainingPoint(query, results); + } + + // Any intervals which end after the query point and overlap the split point + // must therefore contain the query point. + auto p = std::partition_point( + overlapping_by_desc_right_.cbegin(), overlapping_by_desc_right_.cend(), + [&](const interval_type& interval) { + return Traits::compare(Traits::get_right(interval), query) >= 0; + }); + results->insert(results->end(), overlapping_by_desc_right_.cbegin(), p); + } else { + DCHECK_EQ(cmp, 0); + // The query is exactly our split point -- in this case we've already got + // the computed list of overlapping intervals. 
+ results->insert(results->end(), overlapping_by_asc_left_.begin(), + overlapping_by_asc_left_.end()); + } +} + +template<class Traits> +void ITNode<Traits>::FindIntersectingInterval(const interval_type &query, + IntervalVector *results) const { + if (Traits::compare(Traits::get_right(query), split_point_) < 0) { + // The interval is fully left of the split point. So, it may not overlap + // with any in 'right_' + if (left_ != NULL) { + left_->FindIntersectingInterval(query, results); + } + + // Any intervals whose left edge is <= the query interval's right edge + // intersect the query interval. 'std::partition_point' returns the first + // such interval which does not meet that criterion, so we insert all + // up to that point. + auto first_greater = std::partition_point( + overlapping_by_asc_left_.cbegin(), overlapping_by_asc_left_.cend(), + [&](const interval_type& interval) { + return Traits::compare(Traits::get_left(interval), Traits::get_right(query)) <= 0; + }); + results->insert(results->end(), overlapping_by_asc_left_.cbegin(), first_greater); + } else if (Traits::compare(Traits::get_left(query), split_point_) > 0) { + // The interval is fully right of the split point. So, it may not overlap + // with any in 'left_'. + if (right_ != NULL) { + right_->FindIntersectingInterval(query, results); + } + + // Any intervals whose right edge is >= the query interval's left edge + // intersect the query interval. 'std::partition_point' returns the first + // such interval which does not meet that criterion, so we insert all + // up to that point. + auto first_lesser = std::partition_point( + overlapping_by_desc_right_.cbegin(), overlapping_by_desc_right_.cend(), + [&](const interval_type& interval) { + return Traits::compare(Traits::get_right(interval), Traits::get_left(query)) >= 0; + }); + results->insert(results->end(), overlapping_by_desc_right_.cbegin(), first_lesser); + } else { + // The query interval contains the split point. 
Therefore all other intervals + // which also contain the split point are intersecting. + results->insert(results->end(), overlapping_by_asc_left_.begin(), + overlapping_by_asc_left_.end()); + + // The query interval may _also_ intersect some in either child. + if (left_ != NULL) { + left_->FindIntersectingInterval(query, results); + } + if (right_ != NULL) { + right_->FindIntersectingInterval(query, results); + } + } +} + + +} // namespace interval_tree_internal + +} // namespace kudu + +#endif
http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d6abb29d/be/src/kudu/util/interval_tree-test.cc ---------------------------------------------------------------------- diff --git a/be/src/kudu/util/interval_tree-test.cc b/be/src/kudu/util/interval_tree-test.cc new file mode 100644 index 0000000..34a1d07 --- /dev/null +++ b/be/src/kudu/util/interval_tree-test.cc @@ -0,0 +1,347 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// All rights reserved. + +#include <stdlib.h> + +#include <algorithm> +#include <tuple> + +#include <boost/optional.hpp> +#include <glog/stl_logging.h> +#include <gtest/gtest.h> + +#include "kudu/gutil/stringprintf.h" +#include "kudu/gutil/strings/substitute.h" +#include "kudu/util/interval_tree.h" +#include "kudu/util/interval_tree-inl.h" +#include "kudu/util/test_util.h" + +using std::vector; +using std::string; +using strings::Substitute; + +namespace kudu { + +// Test harness. +class TestIntervalTree : public KuduTest { +}; + +// Simple interval class for integer intervals. 
+struct IntInterval { + IntInterval(int left, int right, int id = -1) + : left(left), + right(right), + id(id) { + } + + bool Intersects(const IntInterval &other) const { + if (other.left > right) return false; + if (left > other.right) return false; + return true; + } + + string ToString() const { + return strings::Substitute("[$0, $1]($2) ", left, right, id); + } + + int left, right, id; +}; + +// A wrapper around an int which can be compared with IntTraits::compare() +// but also can keep a counter of how many times it has been compared. Used +// for TestBigO below. +struct CountingQueryPoint { + explicit CountingQueryPoint(int v) + : val(v), + count(new int(0)) { + } + + int val; + std::shared_ptr<int> count; +}; + +// Traits definition for intervals made up of ints on either end. +struct IntTraits { + typedef int point_type; + typedef IntInterval interval_type; + static point_type get_left(const IntInterval &x) { + return x.left; + } + static point_type get_right(const IntInterval &x) { + return x.right; + } + static int compare(int a, int b) { + if (a < b) return -1; + if (a > b) return 1; + return 0; + } + + static int compare(const CountingQueryPoint& q, int b) { + (*q.count)++; + return compare(q.val, b); + } + static int compare(int a, const CountingQueryPoint& b) { + return -compare(b, a); + } + +}; + +// Compare intervals in an arbitrary but consistent way - this is only +// used for verifying that the two algorithms come up with the same results. +// It's not necessary to define this to use an interval tree. +static bool CompareIntervals(const IntInterval &a, const IntInterval &b) { + return std::make_tuple(a.left, a.right, a.id) < + std::make_tuple(b.left, b.right, b.id); +} + +// Stringify a list of int intervals, for easy test error reporting. 
+static string Stringify(const vector<IntInterval> &intervals) { + string ret; + bool first = true; + for (const IntInterval &interval : intervals) { + if (!first) { + ret.append(","); + } + ret.append(interval.ToString()); + } + return ret; +} + +// Find any intervals in 'intervals' which contain 'query_point' by brute force. +static void FindContainingBruteForce(const vector<IntInterval> &intervals, + int query_point, + vector<IntInterval> *results) { + for (const IntInterval &i : intervals) { + if (query_point >= i.left && query_point <= i.right) { + results->push_back(i); + } + } +} + + +// Find any intervals in 'intervals' which intersect 'query_interval' by brute force. +static void FindIntersectingBruteForce(const vector<IntInterval> &intervals, + IntInterval query_interval, + vector<IntInterval> *results) { + for (const IntInterval &i : intervals) { + if (query_interval.Intersects(i)) { + results->push_back(i); + } + } +} + + +// Verify that IntervalTree::FindContainingPoint yields the same results as the naive +// brute-force O(n) algorithm. +static void VerifyFindContainingPoint(const vector<IntInterval> all_intervals, + const IntervalTree<IntTraits> &tree, + int query_point) { + vector<IntInterval> results; + tree.FindContainingPoint(query_point, &results); + std::sort(results.begin(), results.end(), CompareIntervals); + + vector<IntInterval> brute_force; + FindContainingBruteForce(all_intervals, query_point, &brute_force); + std::sort(brute_force.begin(), brute_force.end(), CompareIntervals); + + SCOPED_TRACE(Stringify(all_intervals) + StringPrintf(" (q=%d)", query_point)); + EXPECT_EQ(Stringify(brute_force), Stringify(results)); +} + +// Verify that IntervalTree::FindIntersectingInterval yields the same results as the naive +// brute-force O(n) algorithm. 
+static void VerifyFindIntersectingInterval(const vector<IntInterval> all_intervals, + const IntervalTree<IntTraits> &tree, + const IntInterval &query_interval) { + vector<IntInterval> results; + tree.FindIntersectingInterval(query_interval, &results); + std::sort(results.begin(), results.end(), CompareIntervals); + + vector<IntInterval> brute_force; + FindIntersectingBruteForce(all_intervals, query_interval, &brute_force); + std::sort(brute_force.begin(), brute_force.end(), CompareIntervals); + + SCOPED_TRACE(Stringify(all_intervals) + + StringPrintf(" (q=[%d,%d])", query_interval.left, query_interval.right)); + EXPECT_EQ(Stringify(brute_force), Stringify(results)); +} + +static vector<IntInterval> CreateRandomIntervals(int n = 100) { + vector<IntInterval> intervals; + for (int i = 0; i < n; i++) { + int l = rand() % 100; // NOLINT(runtime/threadsafe_fn) + int r = l + rand() % 20; // NOLINT(runtime/threadsafe_fn) + intervals.push_back(IntInterval(l, r, i)); + } + return intervals; +} + +TEST_F(TestIntervalTree, TestBasic) { + vector<IntInterval> intervals; + intervals.push_back(IntInterval(1, 2, 1)); + intervals.push_back(IntInterval(3, 4, 2)); + intervals.push_back(IntInterval(1, 4, 3)); + IntervalTree<IntTraits> t(intervals); + + for (int i = 0; i <= 5; i++) { + VerifyFindContainingPoint(intervals, t, i); + + for (int j = i; j <= 5; j++) { + VerifyFindIntersectingInterval(intervals, t, IntInterval(i, j, 0)); + } + } +} + +TEST_F(TestIntervalTree, TestRandomized) { + SeedRandom(); + + // Generate 100 random intervals spanning 0-200 and build an interval tree from them. + vector<IntInterval> intervals = CreateRandomIntervals(); + IntervalTree<IntTraits> t(intervals); + + // Test that we get the correct result on every possible query. 
+ for (int i = -1; i < 201; i++) { + VerifyFindContainingPoint(intervals, t, i); + } + + // Test that we get the correct result for random intervals + for (int i = 0; i < 100; i++) { + int l = rand() % 100; // NOLINT(runtime/threadsafe_fn) + int r = l + rand() % 100; // NOLINT(runtime/threadsafe_fn) + VerifyFindIntersectingInterval(intervals, t, IntInterval(l, r)); + } +} + +TEST_F(TestIntervalTree, TestEmpty) { + vector<IntInterval> empty; + IntervalTree<IntTraits> t(empty); + + VerifyFindContainingPoint(empty, t, 1); + VerifyFindIntersectingInterval(empty, t, IntInterval(1, 2, 0)); +} + +TEST_F(TestIntervalTree, TestBigO) { +#ifndef NDEBUG + LOG(WARNING) << "big-O results are not valid if DCHECK is enabled"; + return; +#endif + SeedRandom(); + + LOG(INFO) << "num_int\tnum_q\tresults\tsimple\tbatch"; + for (int num_intervals = 1; num_intervals < 2000; num_intervals *= 2) { + vector<IntInterval> intervals = CreateRandomIntervals(num_intervals); + IntervalTree<IntTraits> t(intervals); + for (int num_queries = 1; num_queries < 2000; num_queries *= 2) { + vector<CountingQueryPoint> queries; + for (int i = 0; i < num_queries; i++) { + queries.emplace_back(rand() % 100); + } + std::sort(queries.begin(), queries.end(), + [](const CountingQueryPoint& a, + const CountingQueryPoint& b) { + return a.val < b.val; + }); + + // Test using batch algorithm. + int num_results_batch = 0; + t.ForEachIntervalContainingPoints( + queries, + [&](CountingQueryPoint query_point, const IntInterval& interval) { + num_results_batch++; + }); + int num_comparisons_batch = 0; + for (const auto& q : queries) { + num_comparisons_batch += *q.count; + *q.count = 0; + } + + // Test using one-by-one queries. 
+ int num_results_simple = 0; + for (auto& q : queries) { + vector<IntInterval> intervals; + t.FindContainingPoint(q, &intervals); + num_results_simple += intervals.size(); + } + int num_comparisons_simple = 0; + for (const auto& q : queries) { + num_comparisons_simple += *q.count; + } + ASSERT_EQ(num_results_simple, num_results_batch); + + LOG(INFO) << num_intervals << "\t" << num_queries << "\t" << num_results_simple << "\t" + << num_comparisons_simple << "\t" << num_comparisons_batch; + } + } +} + +TEST_F(TestIntervalTree, TestMultiQuery) { + SeedRandom(); + const int kNumQueries = 1; + vector<IntInterval> intervals = CreateRandomIntervals(10); + IntervalTree<IntTraits> t(intervals); + + // Generate random queries. + vector<int> queries; + for (int i = 0; i < kNumQueries; i++) { + queries.push_back(rand() % 100); + } + std::sort(queries.begin(), queries.end()); + + vector<pair<string, int>> results_simple; + for (int q : queries) { + vector<IntInterval> intervals; + t.FindContainingPoint(q, &intervals); + for (const auto& interval : intervals) { + results_simple.emplace_back(interval.ToString(), q); + } + } + + vector<pair<string, int>> results_batch; + t.ForEachIntervalContainingPoints( + queries, + [&](int query_point, const IntInterval& interval) { + results_batch.emplace_back(interval.ToString(), query_point); + }); + + // Check the property that, when the batch query points are in sorted order, + // the results are grouped by interval, and within each interval, sorted by + // query point. Each interval may have at most two groups. + boost::optional<pair<string, int>> prev = boost::none; + std::map<string, int> intervals_seen; + for (int i = 0; i < results_batch.size(); i++) { + const auto& cur = results_batch[i]; + // If it's another query point hitting the same interval, + // make sure the query points are returned in order. 
+ if (prev && prev->first == cur.first) { + EXPECT_GE(cur.second, prev->second) << prev->first; + } else { + // It's the start of a new interval's data. Make sure that we don't + // see the same interval twice. + EXPECT_LE(++intervals_seen[cur.first], 2) + << "Saw more than two groups for interval " << cur.first; + } + prev = cur; + } + + std::sort(results_simple.begin(), results_simple.end()); + std::sort(results_batch.begin(), results_batch.end()); + ASSERT_EQ(results_simple, results_batch); +} + +} // namespace kudu http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d6abb29d/be/src/kudu/util/interval_tree.h ---------------------------------------------------------------------- diff --git a/be/src/kudu/util/interval_tree.h b/be/src/kudu/util/interval_tree.h new file mode 100644 index 0000000..a677528 --- /dev/null +++ b/be/src/kudu/util/interval_tree.h @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// Implements an Interval Tree. See http://en.wikipedia.org/wiki/Interval_tree +// or CLRS for a full description of the data structure. +// +// Callers of this class should also include interval_tree-inl.h for function +// definitions. 
+#ifndef KUDU_UTIL_INTERVAL_TREE_H
+#define KUDU_UTIL_INTERVAL_TREE_H
+
+#include <glog/logging.h>
+
+#include <vector>
+
+#include "kudu/gutil/macros.h"
+
+namespace kudu {
+
+namespace interval_tree_internal {
+template<class Traits>
+class ITNode;
+}  // namespace interval_tree_internal
+
+// Implements an Interval Tree.
+//
+// An Interval Tree is a data structure which stores a set of intervals and supports
+// efficient searches to determine which intervals in that set overlap a query
+// point or interval. These operations are O(lg n + k) where 'n' is the number of
+// intervals in the tree and 'k' is the number of results returned for a given query.
+//
+// This particular implementation is a static tree -- intervals may not be added or
+// removed once the tree is instantiated.
+//
+// This class also assumes that all intervals are "closed" intervals -- the intervals
+// are inclusive of their start and end points.
+//
+// The Traits class should have the following members:
+//   Traits::point_type
+//     a typedef for what a "point" in the range is
+//
+//   Traits::interval_type
+//     a typedef for an interval
+//
+//   static point_type get_left(const interval_type &)
+//   static point_type get_right(const interval_type &)
+//     accessors which fetch the left and right bound of the interval, respectively.
+//
+//   static int compare(const point_type &a, const point_type &b)
+//     return < 0 if a < b, 0 if a == b, > 0 if a > b
+//
+// The point-query methods below may also be called with a query type other
+// than point_type, provided Traits supplies a compare() overload able to
+// compare that type against point_type.
+//
+// See interval_tree-test.cc for an example Traits class for 'int' ranges.
+template<class Traits>
+class IntervalTree {
+ private:
+  // Import types from the traits class to make code more readable.
+  typedef typename Traits::interval_type interval_type;
+  typedef typename Traits::point_type point_type;
+
+  // And some convenience types.
+  typedef std::vector<interval_type> IntervalVector;
+  typedef interval_tree_internal::ITNode<Traits> node_type;
+
+ public:
+  // Construct an Interval Tree containing the given set of intervals.
+  // 'intervals' may be empty, in which case every query yields no results.
+  explicit IntervalTree(const IntervalVector &intervals);
+
+  ~IntervalTree();
+
+  // Find all intervals in the tree which contain the query point.
+  // The resulting intervals are added to the 'results' vector.
+  // The vector is not cleared first.
+  //
+  // NOTE: 'QueryPointType' is usually point_type, but can be any other
+  // type for which there exists the appropriate Traits::compare(...) method.
+  template<class QueryPointType>
+  void FindContainingPoint(const QueryPointType &query,
+                           IntervalVector *results) const;
+
+  // For each of the query points in the STL container 'queries', find all
+  // intervals in the tree which may contain those points. Calls 'cb(point, interval)'
+  // for each such interval.
+  //
+  // The points in the query container must be comparable to 'point_type'
+  // using Traits::compare().
+  //
+  // The implementation sequences the calls to 'cb' with the following guarantees:
+  // 1) all of the results corresponding to a given interval will be yielded in at
+  //    most two "groups" of calls (i.e. sub-sequences of calls with the same interval).
+  // 2) within each "group" of calls, the query points will be in ascending order.
+  //
+  // For example, the callback sequence may be:
+  //
+  //  cb(q1, interval_1) -
+  //  cb(q2, interval_1)  | first group of interval_1
+  //  cb(q6, interval_1)  |
+  //  cb(q7, interval_1) -
+  //
+  //  cb(q2, interval_2) -
+  //  cb(q3, interval_2)  | first group of interval_2
+  //  cb(q4, interval_2) -
+  //
+  //  cb(q3, interval_1) -
+  //  cb(q4, interval_1)  | second group of interval_1
+  //  cb(q5, interval_1) -
+  //
+  //  cb(q2, interval_3) -
+  //  cb(q3, interval_3)  | first group of interval_3
+  //  cb(q4, interval_3) -
+  //
+  //  cb(q5, interval_2) -
+  //  cb(q6, interval_2)  | second group of interval_2
+  //  cb(q7, interval_2) -
+  //
+  // REQUIRES: The input points must be pre-sorted or else this will return invalid
+  // results.
+  template<class Callback, class QueryContainer>
+  void ForEachIntervalContainingPoints(const QueryContainer& queries,
+                                       const Callback& cb) const;
+
+  // Find all intervals in the tree which intersect the given interval.
+  // The resulting intervals are added to the 'results' vector.
+  // The vector is not cleared first.
+  void FindIntersectingInterval(const interval_type &query,
+                                IntervalVector *results) const;
+ private:
+  // Splits 'in' around a chosen '*split_point' into the intervals lying fully
+  // to its left, those overlapping it, and those lying fully to its right.
+  static void Partition(const IntervalVector &in,
+                        point_type *split_point,
+                        IntervalVector *left,
+                        IntervalVector *overlapping,
+                        IntervalVector *right);
+
+  // Create a node containing the given intervals, recursively splitting down the tree.
+  static node_type *CreateNode(const IntervalVector &intervals);
+
+  // Owned root of the tree; NULL when the tree was built from no intervals.
+  node_type *root_;
+
+  DISALLOW_COPY_AND_ASSIGN(IntervalTree);
+};
+
+
+}  // namespace kudu
+
+#endif  // KUDU_UTIL_INTERVAL_TREE_H
http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d6abb29d/be/src/kudu/util/jsonreader-test.cc
----------------------------------------------------------------------
diff --git a/be/src/kudu/util/jsonreader-test.cc b/be/src/kudu/util/jsonreader-test.cc
new file mode 100644
index 0000000..3c54cc7
--- /dev/null
+++ b/be/src/kudu/util/jsonreader-test.cc
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.
+// See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "kudu/gutil/integral_types.h"
+#include "kudu/gutil/strings/substitute.h"
+#include "kudu/util/jsonreader.h"
+#include "kudu/util/test_macros.h"
+
+using rapidjson::Value;
+using std::string;
+using std::vector;
+using strings::Substitute;
+
+namespace kudu {
+
+// Whitespace-only input must fail Init() with a Corruption status that
+// carries rapidjson's parse error text.
+TEST(JsonReaderTest, Corrupt) {
+  JsonReader r("");
+  Status s = r.Init();
+  ASSERT_TRUE(s.IsCorruption());
+  ASSERT_STR_CONTAINS(
+      s.ToString(), "JSON text is corrupt: Text only contains white space(s)");
+}
+
+// Empty containers parse successfully. Looking up any field of an empty
+// object is NotFound; an empty top-level array is an (empty) object array.
+TEST(JsonReaderTest, Empty) {
+  JsonReader r("{}");
+  ASSERT_OK(r.Init());
+  JsonReader r2("[]");
+  ASSERT_OK(r2.Init());
+
+  // Not found.
+  ASSERT_TRUE(r.ExtractInt32(r.root(), "foo", nullptr).IsNotFound());
+  ASSERT_TRUE(r.ExtractInt64(r.root(), "foo", nullptr).IsNotFound());
+  ASSERT_TRUE(r.ExtractString(r.root(), "foo", nullptr).IsNotFound());
+  ASSERT_TRUE(r.ExtractObject(r.root(), "foo", nullptr).IsNotFound());
+  ASSERT_TRUE(r.ExtractObjectArray(r.root(), "foo", nullptr).IsNotFound());
+
+  // The empty array converts to an empty object array, but not to an object.
+  vector<const Value*> objs;
+  ASSERT_OK(r2.ExtractObjectArray(r2.root(), nullptr, &objs));
+  ASSERT_TRUE(objs.empty());
+  ASSERT_TRUE(r2.ExtractObject(r2.root(), nullptr, nullptr).IsInvalidArgument());
+}
+
+TEST(JsonReaderTest, Basic) {
+  JsonReader r("{ \"foo\" : \"bar\" }");
+  ASSERT_OK(r.Init());
+  string foo;
+  ASSERT_OK(r.ExtractString(r.root(), "foo", &foo));
+  ASSERT_EQ("bar", foo);
+
+  // Bad types.
+  ASSERT_TRUE(r.ExtractInt32(r.root(), "foo", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractInt64(r.root(), "foo", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObject(r.root(), "foo", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObjectArray(r.root(), "foo", nullptr).IsInvalidArgument());
+}
+
+TEST(JsonReaderTest, LessBasic) {
+  string doc = Substitute(
+      "{ \"small\" : 1, \"big\" : $0, \"null\" : null, \"empty\" : \"\" }", kint64max);
+  JsonReader r(doc);
+  ASSERT_OK(r.Init());
+  int32_t small;
+  ASSERT_OK(r.ExtractInt32(r.root(), "small", &small));
+  ASSERT_EQ(1, small);
+  int64_t big;
+  ASSERT_OK(r.ExtractInt64(r.root(), "big", &big));
+  ASSERT_EQ(kint64max, big);
+  string str;
+  ASSERT_OK(r.ExtractString(r.root(), "null", &str));
+  ASSERT_EQ("", str);
+  ASSERT_OK(r.ExtractString(r.root(), "empty", &str));
+  ASSERT_EQ("", str);
+
+  // Bad types.
+  ASSERT_TRUE(r.ExtractString(r.root(), "small", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObject(r.root(), "small", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObjectArray(r.root(), "small", nullptr).IsInvalidArgument());
+
+  ASSERT_TRUE(r.ExtractInt32(r.root(), "big", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractString(r.root(), "big", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObject(r.root(), "big", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObjectArray(r.root(), "big", nullptr).IsInvalidArgument());
+
+  ASSERT_TRUE(r.ExtractInt32(r.root(), "null", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractInt64(r.root(), "null", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObject(r.root(), "null", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObjectArray(r.root(), "null", nullptr).IsInvalidArgument());
+
+  ASSERT_TRUE(r.ExtractInt32(r.root(), "empty", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractInt64(r.root(), "empty", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObject(r.root(), "empty", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObjectArray(r.root(), "empty", nullptr).IsInvalidArgument());
+}
+
+TEST(JsonReaderTest, Objects) {
+  JsonReader r("{ \"foo\" : { \"1\" : 1 } }");
+  ASSERT_OK(r.Init());
+
+  const Value* foo = nullptr;
+  ASSERT_OK(r.ExtractObject(r.root(), "foo", &foo));
+  ASSERT_TRUE(foo);
+
+  int32_t one;
+  ASSERT_OK(r.ExtractInt32(foo, "1", &one));
+  ASSERT_EQ(1, one);
+
+  // Bad types.
+  ASSERT_TRUE(r.ExtractInt32(r.root(), "foo", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractInt64(r.root(), "foo", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractString(r.root(), "foo", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObjectArray(r.root(), "foo", nullptr).IsInvalidArgument());
+}
+
+TEST(JsonReaderTest, TopLevelArray) {
+  JsonReader r("[ { \"name\" : \"foo\" }, { \"name\" : \"bar\" } ]");
+  ASSERT_OK(r.Init());
+
+  vector<const Value*> objs;
+  ASSERT_OK(r.ExtractObjectArray(r.root(), nullptr, &objs));
+  ASSERT_EQ(2U, objs.size());
+  string name;
+  ASSERT_OK(r.ExtractString(objs[0], "name", &name));
+  ASSERT_EQ("foo", name);
+  ASSERT_OK(r.ExtractString(objs[1], "name", &name));
+  ASSERT_EQ("bar", name);
+
+  // Bad types.
+  ASSERT_TRUE(r.ExtractInt32(r.root(), nullptr, nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractInt64(r.root(), nullptr, nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractString(r.root(), nullptr, nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObject(r.root(), nullptr, nullptr).IsInvalidArgument());
+}
+
+TEST(JsonReaderTest, NestedArray) {
+  JsonReader r("{ \"foo\" : [ { \"val\" : 0 }, { \"val\" : 1 }, { \"val\" : 2 } ] }");
+  ASSERT_OK(r.Init());
+
+  vector<const Value*> foo;
+  ASSERT_OK(r.ExtractObjectArray(r.root(), "foo", &foo));
+  ASSERT_EQ(3U, foo.size());
+  int i = 0;
+  for (const Value* v : foo) {
+    int32_t number;
+    ASSERT_OK(r.ExtractInt32(v, "val", &number));
+    ASSERT_EQ(i, number);
+    i++;
+  }
+
+  // Bad types.
+  ASSERT_TRUE(r.ExtractInt32(r.root(), "foo", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractInt64(r.root(), "foo", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractString(r.root(), "foo", nullptr).IsInvalidArgument());
+  ASSERT_TRUE(r.ExtractObject(r.root(), "foo", nullptr).IsInvalidArgument());
+}
+
+}  // namespace kudu
http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d6abb29d/be/src/kudu/util/jsonreader.cc
----------------------------------------------------------------------
diff --git a/be/src/kudu/util/jsonreader.cc b/be/src/kudu/util/jsonreader.cc
new file mode 100644
index 0000000..e39761d
--- /dev/null
+++ b/be/src/kudu/util/jsonreader.cc
@@ -0,0 +1,124 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "kudu/util/jsonreader.h"
+
+#include <utility>
+
+#include "kudu/gutil/strings/substitute.h"
+
+using rapidjson::Value;
+using std::string;
+using std::vector;
+using strings::Substitute;
+
+namespace kudu {
+
+JsonReader::JsonReader(string text) : text_(std::move(text)) {}
+
+JsonReader::~JsonReader() {
+}
+
+// Parses 'text_' into 'document_'. On failure, rapidjson's parse error message
+// is attached to the returned Corruption status.
+Status JsonReader::Init() {
+  document_.Parse<0>(text_.c_str());
+  if (document_.HasParseError()) {
+    return Status::Corruption("JSON text is corrupt", document_.GetParseError());
+  }
+  return Status::OK();
+}
+
+Status JsonReader::ExtractInt32(const Value* object,
+                                const char* field,
+                                int32_t* result) const {
+  const Value* val;
+  RETURN_NOT_OK(ExtractField(object, field, &val));
+  if (PREDICT_FALSE(!val->IsInt())) {
+    return Status::InvalidArgument(Substitute(
+        "Wrong type during field extraction: expected int32 but got $0",
+        val->GetType()));
+  }
+  // IsInt() also accepts negative numbers, so read the value back as signed:
+  // GetUint() asserts that the value is representable as an unsigned int.
+  *result = val->GetInt();
+  return Status::OK();
+}
+
+Status JsonReader::ExtractInt64(const Value* object,
+                                const char* field,
+                                int64_t* result) const {
+  const Value* val;
+  RETURN_NOT_OK(ExtractField(object, field, &val));
+  if (PREDICT_FALSE(!val->IsInt64())) {
+    return Status::InvalidArgument(Substitute(
+        "Wrong type during field extraction: expected int64 but got $0",
+        val->GetType()));
+  }
+  // As above: IsInt64() admits negative numbers, which GetUint64() rejects.
+  *result = val->GetInt64();
+  return Status::OK();
+}
+
+// A JSON null is accepted and yields the empty string.
+Status JsonReader::ExtractString(const Value* object,
+                                 const char* field,
+                                 string* result) const {
+  const Value* val;
+  RETURN_NOT_OK(ExtractField(object, field, &val));
+  if (PREDICT_FALSE(!val->IsString())) {
+    if (val->IsNull()) {
+      *result = "";
+      return Status::OK();
+    }
+    return Status::InvalidArgument(Substitute(
+        "Wrong type during field extraction: expected string but got $0",
+        val->GetType()));
+  }
+  result->assign(val->GetString());
+  return Status::OK();
+}
+
+Status JsonReader::ExtractObject(const Value* object,
+                                 const char* field,
+                                 const Value** result) const {
+  const Value* val;
+  RETURN_NOT_OK(ExtractField(object, field, &val));
+  if (PREDICT_FALSE(!val->IsObject())) {
+    return Status::InvalidArgument(Substitute(
+        "Wrong type during field extraction: expected object but got $0",
+        val->GetType()));
+  }
+  *result = val;
+  return Status::OK();
+}
+
+// Appends a pointer to each array element to 'result' (which is not cleared).
+Status JsonReader::ExtractObjectArray(const Value* object,
+                                      const char* field,
+                                      vector<const Value*>* result) const {
+  const Value* val;
+  RETURN_NOT_OK(ExtractField(object, field, &val));
+  if (PREDICT_FALSE(!val->IsArray())) {
+    return Status::InvalidArgument(Substitute(
+        "Wrong type during field extraction: expected object array but got $0",
+        val->GetType()));
+  }
+  for (Value::ConstValueIterator iter = val->Begin(); iter != val->End(); ++iter) {
+    result->push_back(iter);
+  }
+  return Status::OK();
+}
+
+// Resolves 'field' within 'object', or returns 'object' itself when 'field'
+// is null.
+Status JsonReader::ExtractField(const Value* object,
+                                const char* field,
+                                const Value** result) const {
+  if (!field) {
+    *result = object;
+    return Status::OK();
+  }
+  // rapidjson's member lookup is only defined for objects (it asserts
+  // otherwise), so reject e.g. arrays or scalars with an error instead.
+  if (PREDICT_FALSE(!object->IsObject())) {
+    return Status::InvalidArgument(Substitute(
+        "Cannot extract field '$0': expected object but got $1",
+        field, object->GetType()));
+  }
+  if (PREDICT_FALSE(!object->HasMember(field))) {
+    return Status::NotFound("Missing field", field);
+  }
+  *result = &(*object)[field];
+  return Status::OK();
+}
+
+}  // namespace kudu
http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d6abb29d/be/src/kudu/util/jsonreader.h
----------------------------------------------------------------------
diff --git a/be/src/kudu/util/jsonreader.h b/be/src/kudu/util/jsonreader.h
new file mode 100644
index 0000000..2d9e982
--- /dev/null
+++ b/be/src/kudu/util/jsonreader.h
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#ifndef KUDU_UTIL_JSONREADER_H_
+#define KUDU_UTIL_JSONREADER_H_
+
+#include <stdint.h>
+#include <string>
+#include <vector>
+
+#include <rapidjson/document.h>
+
+#include "kudu/gutil/gscoped_ptr.h"
+#include "kudu/gutil/macros.h"
+#include "kudu/util/status.h"
+
+namespace kudu {
+
+// Wraps the JSON parsing functionality of rapidjson::Document.
+//
+// Unlike JsonWriter, this class does not hide rapidjson internals from
+// clients. That's because there's just no easy way to implement object and
+// array parsing otherwise. At most, this class aspires to be a simpler
+// error-handling wrapper for reading and parsing.
+class JsonReader {
+ public:
+  explicit JsonReader(std::string text);
+  ~JsonReader();
+
+  // Parses the text given to the constructor. Returns Status::Corruption if
+  // it is not valid JSON.
+  Status Init();
+
+  // Extractor methods.
+  //
+  // If 'field' is not null, will look for a field with that name in the
+  // given object, returning Status::NotFound if it cannot be found. If
+  // 'field' is null, will try to convert 'object' directly into the
+  // desired type.
+  //
+  // A value of the wrong JSON type yields Status::InvalidArgument. On any
+  // error, 'result' is left untouched.
+
+  Status ExtractInt32(const rapidjson::Value* object,
+                      const char* field,
+                      int32_t* result) const;
+
+  Status ExtractInt64(const rapidjson::Value* object,
+                      const char* field,
+                      int64_t* result) const;
+
+  // A JSON null is accepted and extracted as the empty string.
+  Status ExtractString(const rapidjson::Value* object,
+                       const char* field,
+                       std::string* result) const;
+
+  // 'result' is only valid for as long as JsonReader is alive.
+  Status ExtractObject(const rapidjson::Value* object,
+                       const char* field,
+                       const rapidjson::Value** result) const;
+
+  // Appends one pointer per array element to 'result'; the vector is not
+  // cleared first. 'result' is only valid for as long as JsonReader is alive.
+  Status ExtractObjectArray(const rapidjson::Value* object,
+                            const char* field,
+                            std::vector<const rapidjson::Value*>* result) const;
+
+  // The parsed document; pass this as 'object' to extract top-level values.
+  const rapidjson::Value* root() const { return &document_; }
+
+ private:
+  Status ExtractField(const rapidjson::Value* object,
+                      const char* field,
+                      const rapidjson::Value** result) const;
+
+  std::string text_;
+  rapidjson::Document document_;
+
+  DISALLOW_COPY_AND_ASSIGN(JsonReader);
+};
+
+} // namespace kudu
+
+#endif // KUDU_UTIL_JSONREADER_H_
http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d6abb29d/be/src/kudu/util/jsonwriter-test.cc
----------------------------------------------------------------------
diff --git a/be/src/kudu/util/jsonwriter-test.cc b/be/src/kudu/util/jsonwriter-test.cc
new file mode 100644
index 0000000..08d54de
--- /dev/null
+++ b/be/src/kudu/util/jsonwriter-test.cc
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "kudu/gutil/strings/substitute.h"
+#include "kudu/util/jsonwriter.h"
+#include "kudu/util/jsonwriter_test.pb.h"
+#include "kudu/util/logging.h"
+#include "kudu/util/test_util.h"
+
+using jsonwriter_test::TestAllTypes;
+
+namespace kudu {
+
+class TestJsonWriter : public KuduTest {};
+
+// A message with no fields set serializes to an empty JSON object in either mode.
+TEST_F(TestJsonWriter, TestPBEmpty) {
+  TestAllTypes pb;
+  ASSERT_EQ("{}", JsonWriter::ToJson(pb, JsonWriter::PRETTY));
+  ASSERT_EQ("{}", JsonWriter::ToJson(pb, JsonWriter::COMPACT));
+}
+
+// Every singular field type is emitted in field-number order; string and bytes
+// fields marked with the REDACT option are replaced by "<redacted>" when log
+// redaction is enabled.
+TEST_F(TestJsonWriter, TestPBAllFieldTypes) {
+  // SetCommandLineOption() returns an empty string on failure.
+  ASSERT_NE("", gflags::SetCommandLineOption("redact", "log"));
+  TestAllTypes pb;
+  pb.set_optional_int32(1);
+  pb.set_optional_int64(2);
+  pb.set_optional_uint32(3);
+  pb.set_optional_uint64(4);
+  pb.set_optional_sint32(5);
+  pb.set_optional_sint64(6);
+  pb.set_optional_fixed32(7);
+  pb.set_optional_fixed64(8);
+  pb.set_optional_sfixed32(9);
+  pb.set_optional_sfixed64(10);
+  pb.set_optional_float(11);
+  pb.set_optional_double(12);
+  pb.set_optional_bool(true);
+  pb.set_optional_string("hello world");
+  pb.set_optional_redacted_string("secret!");
+  pb.set_optional_bytes("hello bytes");
+  pb.set_optional_redacted_bytes("secret!");
+  pb.set_optional_nested_enum(TestAllTypes::FOO);
+  ASSERT_EQ("{\n"
+            "    \"optional_int32\": 1,\n"
+            "    \"optional_int64\": 2,\n"
+            "    \"optional_uint32\": 3,\n"
+            "    \"optional_uint64\": 4,\n"
+            "    \"optional_sint32\": 5,\n"
+            "    \"optional_sint64\": 6,\n"
+            "    \"optional_fixed32\": 7,\n"
+            "    \"optional_fixed64\": 8,\n"
+            "    \"optional_sfixed32\": 9,\n"
+            "    \"optional_sfixed64\": 10,\n"
+            "    \"optional_float\": 11,\n"
+            "    \"optional_double\": 12,\n"
+            "    \"optional_bool\": true,\n"
+            "    \"optional_string\": \"hello world\",\n"
+            "    \"optional_redacted_string\": \"<redacted>\",\n"
+            "    \"optional_bytes\": \"hello bytes\",\n"
+            "    \"optional_redacted_bytes\": \"<redacted>\",\n"
+            "    \"optional_nested_enum\": \"FOO\"\n"
+            "}", JsonWriter::ToJson(pb, JsonWriter::PRETTY));
+  ASSERT_EQ("{"
+            "\"optional_int32\":1,"
+            "\"optional_int64\":2,"
+            "\"optional_uint32\":3,"
+            "\"optional_uint64\":4,"
+            "\"optional_sint32\":5,"
+            "\"optional_sint64\":6,"
+            "\"optional_fixed32\":7,"
+            "\"optional_fixed64\":8,"
+            "\"optional_sfixed32\":9,"
+            "\"optional_sfixed64\":10,"
+            "\"optional_float\":11,"
+            "\"optional_double\":12,"
+            "\"optional_bool\":true,"
+            "\"optional_string\":\"hello world\","
+            "\"optional_redacted_string\":\"<redacted>\","
+            "\"optional_bytes\":\"hello bytes\","
+            "\"optional_redacted_bytes\":\"<redacted>\","
+            "\"optional_nested_enum\":\"FOO\""
+            "}", JsonWriter::ToJson(pb, JsonWriter::COMPACT));
+}
+
+// Repeated fields become JSON arrays; redaction applies element-wise.
+TEST_F(TestJsonWriter, TestPBRepeatedPrimitives) {
+  ASSERT_NE("", gflags::SetCommandLineOption("redact", "log"));
+  TestAllTypes pb;
+  for (int i = 0; i <= 3; i++) {
+    pb.add_repeated_int32(i);
+    pb.add_repeated_string(strings::Substitute("hi $0", i));
+    pb.add_repeated_redacted_string("secret!");
+    pb.add_repeated_redacted_bytes("secret!");
+  }
+  ASSERT_EQ("{\n"
+            "    \"repeated_int32\": [\n"
+            "        0,\n"
+            "        1,\n"
+            "        2,\n"
+            "        3\n"
+            "    ],\n"
+            "    \"repeated_string\": [\n"
+            "        \"hi 0\",\n"
+            "        \"hi 1\",\n"
+            "        \"hi 2\",\n"
+            "        \"hi 3\"\n"
+            "    ],\n"
+            "    \"repeated_redacted_string\": [\n"
+            "        \"<redacted>\",\n"
+            "        \"<redacted>\",\n"
+            "        \"<redacted>\",\n"
+            "        \"<redacted>\"\n"
+            "    ],\n"
+            "    \"repeated_redacted_bytes\": [\n"
+            "        \"<redacted>\",\n"
+            "        \"<redacted>\",\n"
+            "        \"<redacted>\",\n"
+            "        \"<redacted>\"\n"
+            "    ]\n"
+            "}", JsonWriter::ToJson(pb, JsonWriter::PRETTY));
+  ASSERT_EQ("{\"repeated_int32\":[0,1,2,3],"
+            "\"repeated_string\":[\"hi 0\",\"hi 1\",\"hi 2\",\"hi 3\"],"
+            "\"repeated_redacted_string\":[\"<redacted>\",\"<redacted>\","
+            "\"<redacted>\",\"<redacted>\"],"
+            "\"repeated_redacted_bytes\":[\"<redacted>\",\"<redacted>\","
+            "\"<redacted>\",\"<redacted>\"]}",
+            JsonWriter::ToJson(pb, JsonWriter::COMPACT));
+}
+
+// Singular and repeated sub-messages are serialized recursively.
+TEST_F(TestJsonWriter, TestPBNestedMessage) {
+  TestAllTypes pb;
+  pb.add_repeated_nested_message()->set_int_field(12345);
+  pb.mutable_optional_nested_message()->set_int_field(54321);
+  ASSERT_EQ("{\n"
+            "    \"optional_nested_message\": {\n"
+            "        \"int_field\": 54321\n"
+            "    },\n"
+            "    \"repeated_nested_message\": [\n"
+            "        {\n"
+            "            \"int_field\": 12345\n"
+            "        }\n"
+            "    ]\n"
+            "}", JsonWriter::ToJson(pb, JsonWriter::PRETTY));
+  ASSERT_EQ("{\"optional_nested_message\":{\"int_field\":54321},"
+            "\"repeated_nested_message\":"
+            "[{\"int_field\":12345}]}",
+            JsonWriter::ToJson(pb, JsonWriter::COMPACT));
+}
+
+} // namespace kudu
http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d6abb29d/be/src/kudu/util/jsonwriter.cc
----------------------------------------------------------------------
diff --git a/be/src/kudu/util/jsonwriter.cc b/be/src/kudu/util/jsonwriter.cc
new file mode 100644
index 0000000..ef9d5ba
--- /dev/null
+++ b/be/src/kudu/util/jsonwriter.cc
@@ -0,0 +1,327 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#include "kudu/util/jsonwriter.h"
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <glog/logging.h>
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/descriptor.pb.h>
+#include <google/protobuf/message.h>
+#include <rapidjson/prettywriter.h>
+#include <rapidjson/rapidjson.h>
+
+#include "kudu/util/logging.h"
+#include "kudu/util/pb_util.pb.h"
+
+using google::protobuf::FieldDescriptor;
+using google::protobuf::Message;
+using google::protobuf::Reflection;
+
+using std::ostringstream;
+using std::string;
+using std::vector;
+
+namespace kudu {
+
+// rapidjson output-stream adapter that appends every character it receives
+// straight onto the caller's ostringstream. Squeasel hands us an
+// ostringstream to fill, so writing into it directly avoids staging the JSON
+// in an intermediate buffer and copying it afterwards.
+class UTF8StringStreamBuffer {
+ public:
+  explicit UTF8StringStreamBuffer(ostringstream* out);
+
+  // rapidjson's output hook: receives the serialized JSON one character at a time.
+  void Put(rapidjson::UTF8<>::Ch c);
+
+ private:
+  ostringstream* out_;  // Not owned.
+};
+
+// Common runtime interface over rapidjson's Writer and PrettyWriter, which
+// share no base class of their own. JsonWriterImpl<T> (below) adapts each
+// concrete rapidjson writer to it, letting JsonWriter choose one at runtime.
+class JsonWriterIf {
+ public:
+  virtual ~JsonWriterIf() = default;
+
+  // Scalar values.
+  virtual void Null() = 0;
+  virtual void Bool(bool b) = 0;
+  virtual void Int(int i) = 0;
+  virtual void Uint(unsigned u) = 0;
+  virtual void Int64(int64_t i64) = 0;
+  virtual void Uint64(uint64_t u64) = 0;
+  virtual void Double(double d) = 0;
+
+  // String values (also used for object keys).
+  virtual void String(const char* str, size_t length) = 0;
+  virtual void String(const char* str) = 0;
+  virtual void String(const string& str) = 0;
+
+  // Container delimiters.
+  virtual void StartObject() = 0;
+  virtual void EndObject() = 0;
+  virtual void StartArray() = 0;
+  virtual void EndArray() = 0;
+};
+
+// Adapts the different rapidjson Writer implementations to our virtual
+// interface above.
+template<class T>
+class JsonWriterImpl : public JsonWriterIf {
+ public:
+  explicit JsonWriterImpl(ostringstream* out);
+
+  virtual void Null() OVERRIDE;
+  virtual void Bool(bool b) OVERRIDE;
+  virtual void Int(int i) OVERRIDE;
+  virtual void Uint(unsigned u) OVERRIDE;
+  virtual void Int64(int64_t i64) OVERRIDE;
+  virtual void Uint64(uint64_t u64) OVERRIDE;
+  virtual void Double(double d) OVERRIDE;
+  virtual void String(const char* str, size_t length) OVERRIDE;
+  virtual void String(const char* str) OVERRIDE;
+  virtual void String(const std::string& str) OVERRIDE;
+
+  virtual void StartObject() OVERRIDE;
+  virtual void EndObject() OVERRIDE;
+  virtual void StartArray() OVERRIDE;
+  virtual void EndArray() OVERRIDE;
+
+ private:
+  // Must be declared before 'writer_': members are initialized in declaration
+  // order, and 'writer_' is constructed from 'stream_'.
+  UTF8StringStreamBuffer stream_;
+  T writer_;
+  DISALLOW_COPY_AND_ASSIGN(JsonWriterImpl);
+};
+
+//
+// JsonWriter
+//
+
+typedef rapidjson::PrettyWriter<UTF8StringStreamBuffer> PrettyWriterClass;
+typedef rapidjson::Writer<UTF8StringStreamBuffer> CompactWriterClass;
+
+JsonWriter::JsonWriter(ostringstream* out, Mode m) {
+  switch (m) {
+    case PRETTY:
+      impl_.reset(new JsonWriterImpl<PrettyWriterClass>(DCHECK_NOTNULL(out)));
+      break;
+    case COMPACT:
+      impl_.reset(new JsonWriterImpl<CompactWriterClass>(DCHECK_NOTNULL(out)));
+      break;
+  }
+  // No 'default' label so the compiler still flags unhandled Mode values; an
+  // out-of-range value would otherwise leave 'impl_' null and crash later.
+  CHECK(impl_ != nullptr) << "Unknown JsonWriter mode: " << m;
+}
+JsonWriter::~JsonWriter() {
+}
+
+void JsonWriter::Null() { impl_->Null(); }
+void JsonWriter::Bool(bool b) { impl_->Bool(b); }
+void JsonWriter::Int(int i) { impl_->Int(i); }
+void JsonWriter::Uint(unsigned u) { impl_->Uint(u); }
+void JsonWriter::Int64(int64_t i64) { impl_->Int64(i64); }
+void JsonWriter::Uint64(uint64_t u64) { impl_->Uint64(u64); }
+void JsonWriter::Double(double d) { impl_->Double(d); }
+void JsonWriter::String(const char* str, size_t length) { impl_->String(str, length); }
+void JsonWriter::String(const char* str) { impl_->String(str); }
+void JsonWriter::String(const string& str) { impl_->String(str); }
+void JsonWriter::StartObject() { impl_->StartObject(); }
+void JsonWriter::EndObject() { impl_->EndObject(); }
+void JsonWriter::StartArray() { impl_->StartArray(); }
+void JsonWriter::EndArray() { impl_->EndArray(); }
+
+// Specializations for common primitive metric types.
+template<> void JsonWriter::Value(const bool& val) {
+  Bool(val);
+}
+template<> void JsonWriter::Value(const int32_t& val) {
+  Int(val);
+}
+template<> void JsonWriter::Value(const uint32_t& val) {
+  Uint(val);
+}
+template<> void JsonWriter::Value(const int64_t& val) {
+  Int64(val);
+}
+template<> void JsonWriter::Value(const uint64_t& val) {
+  Uint64(val);
+}
+template<> void JsonWriter::Value(const double& val) {
+  Double(val);
+}
+template<> void JsonWriter::Value(const string& val) {
+  String(val);
+}
+
+#if defined(__APPLE__)
+// On macOS size_t is a distinct type from uint64_t, so it needs its own specialization.
+template<> void JsonWriter::Value(const size_t& val) {
+  Uint64(val);
+}
+#endif
+
+// Emits 'pb' as a JSON object keyed by field name. Only the fields returned by
+// Reflection::ListFields() are written; repeated fields become JSON arrays.
+void JsonWriter::Protobuf(const Message& pb) {
+  const Reflection* reflection = pb.GetReflection();
+  vector<const FieldDescriptor*> fields;
+  reflection->ListFields(pb, &fields);
+
+  StartObject();
+  for (const FieldDescriptor* field : fields) {
+    String(field->name());
+    if (field->is_repeated()) {
+      StartArray();
+      const int size = reflection->FieldSize(pb, field);
+      for (int i = 0; i < size; i++) {
+        ProtobufRepeatedField(pb, field, i);
+      }
+      EndArray();
+    } else {
+      ProtobufField(pb, field);
+    }
+  }
+  EndObject();
+}
+
+// Writes the value of singular 'field' of 'pb'. Enums are written by value
+// name, floats are widened to double, string/bytes fields carrying the REDACT
+// option go through KUDU_MAYBE_REDACT_IF, and sub-messages recurse.
+void JsonWriter::ProtobufField(const Message& pb, const FieldDescriptor* field) {
+  const Reflection* reflection = pb.GetReflection();
+  switch (field->cpp_type()) {
+    case FieldDescriptor::CPPTYPE_INT32:
+      Int(reflection->GetInt32(pb, field));
+      break;
+    case FieldDescriptor::CPPTYPE_INT64:
+      Int64(reflection->GetInt64(pb, field));
+      break;
+    case FieldDescriptor::CPPTYPE_UINT32:
+      Uint(reflection->GetUInt32(pb, field));
+      break;
+    case FieldDescriptor::CPPTYPE_UINT64:
+      Uint64(reflection->GetUInt64(pb, field));
+      break;
+    case FieldDescriptor::CPPTYPE_DOUBLE:
+      Double(reflection->GetDouble(pb, field));
+      break;
+    case FieldDescriptor::CPPTYPE_FLOAT:
+      Double(reflection->GetFloat(pb, field));
+      break;
+    case FieldDescriptor::CPPTYPE_BOOL:
+      Bool(reflection->GetBool(pb, field));
+      break;
+    case FieldDescriptor::CPPTYPE_ENUM:
+      String(reflection->GetEnum(pb, field)->name());
+      break;
+    case FieldDescriptor::CPPTYPE_STRING:
+      String(KUDU_MAYBE_REDACT_IF(field->options().GetExtension(REDACT),
+                                  reflection->GetString(pb, field)));
+      break;
+    case FieldDescriptor::CPPTYPE_MESSAGE:
+      Protobuf(reflection->GetMessage(pb, field));
+      break;
+    default:
+      LOG(FATAL) << "Unknown cpp_type: " << field->cpp_type();
+  }
+}
+
+// Same as ProtobufField(), but for element 'index' of repeated 'field'.
+void JsonWriter::ProtobufRepeatedField(const Message& pb, const FieldDescriptor* field, int index) {
+  const Reflection* reflection = pb.GetReflection();
+  switch (field->cpp_type()) {
+    case FieldDescriptor::CPPTYPE_INT32:
+      Int(reflection->GetRepeatedInt32(pb, field, index));
+      break;
+    case FieldDescriptor::CPPTYPE_INT64:
+      Int64(reflection->GetRepeatedInt64(pb, field, index));
+      break;
+    case FieldDescriptor::CPPTYPE_UINT32:
+      Uint(reflection->GetRepeatedUInt32(pb, field, index));
+      break;
+    case FieldDescriptor::CPPTYPE_UINT64:
+      Uint64(reflection->GetRepeatedUInt64(pb, field, index));
+      break;
+    case FieldDescriptor::CPPTYPE_DOUBLE:
+      Double(reflection->GetRepeatedDouble(pb, field, index));
+      break;
+    case FieldDescriptor::CPPTYPE_FLOAT:
+      Double(reflection->GetRepeatedFloat(pb, field, index));
+      break;
+    case FieldDescriptor::CPPTYPE_BOOL:
+      Bool(reflection->GetRepeatedBool(pb, field, index));
+      break;
+    case FieldDescriptor::CPPTYPE_ENUM:
+      String(reflection->GetRepeatedEnum(pb, field, index)->name());
+      break;
+    case FieldDescriptor::CPPTYPE_STRING:
+      String(KUDU_MAYBE_REDACT_IF(field->options().GetExtension(REDACT),
+                                  reflection->GetRepeatedString(pb, field, index)));
+      break;
+    case FieldDescriptor::CPPTYPE_MESSAGE:
+      Protobuf(reflection->GetRepeatedMessage(pb, field, index));
+      break;
+    default:
+      LOG(FATAL) << "Unknown cpp_type: " << field->cpp_type();
+  }
+}
+
+string JsonWriter::ToJson(const Message& pb, Mode mode) {
+  ostringstream buffer;
+  JsonWriter json(&buffer, mode);
+  json.Protobuf(pb);
+  return buffer.str();
+}
+
+//
+// UTF8StringStreamBuffer: each character rapidjson emits is appended to the
+// wrapped ostringstream.
+//
+
+UTF8StringStreamBuffer::UTF8StringStreamBuffer(ostringstream* out)
+    : out_(DCHECK_NOTNULL(out)) {}
+
+void UTF8StringStreamBuffer::Put(rapidjson::UTF8<>::Ch c) {
+  out_->put(c);
+}
+
+//
+// JsonWriterImpl<T>: every call is forwarded unchanged to the wrapped
+// rapidjson writer.
+//
+
+template<class T>
+JsonWriterImpl<T>::JsonWriterImpl(ostringstream* out)
+    : stream_(DCHECK_NOTNULL(out)),
+      writer_(stream_) {}
+
+template<class T>
+void JsonWriterImpl<T>::Null() {
+  writer_.Null();
+}
+
+template<class T>
+void JsonWriterImpl<T>::Bool(bool b) {
+  writer_.Bool(b);
+}
+
+template<class T>
+void JsonWriterImpl<T>::Int(int i) {
+  writer_.Int(i);
+}
+
+template<class T>
+void JsonWriterImpl<T>::Uint(unsigned u) {
+  writer_.Uint(u);
+}
+
+template<class T>
+void JsonWriterImpl<T>::Int64(int64_t i64) {
+  writer_.Int64(i64);
+}
+
+template<class T>
+void JsonWriterImpl<T>::Uint64(uint64_t u64) {
+  writer_.Uint64(u64);
+}
+
+template<class T>
+void JsonWriterImpl<T>::Double(double d) {
+  writer_.Double(d);
+}
+
+template<class T>
+void JsonWriterImpl<T>::String(const char* str, size_t length) {
+  writer_.String(str, length);
+}
+
+template<class T>
+void JsonWriterImpl<T>::String(const char* str) {
+  writer_.String(str);
+}
+
+template<class T>
+void JsonWriterImpl<T>::String(const string& str) {
+  // Pass the explicit length so embedded NUL bytes are preserved.
+  writer_.String(str.data(), str.size());
+}
+
+template<class T>
+void JsonWriterImpl<T>::StartObject() {
+  writer_.StartObject();
+}
+
+template<class T>
+void JsonWriterImpl<T>::EndObject() {
+  writer_.EndObject();
+}
+
+template<class T>
+void JsonWriterImpl<T>::StartArray() {
+  writer_.StartArray();
+}
+
+template<class T>
+void JsonWriterImpl<T>::EndArray() {
+  writer_.EndArray();
+}
+
+}  // namespace kudu
http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d6abb29d/be/src/kudu/util/jsonwriter.h
----------------------------------------------------------------------
diff --git a/be/src/kudu/util/jsonwriter.h b/be/src/kudu/util/jsonwriter.h
new file mode 100644
index 0000000..d3fb604
--- /dev/null
+++ b/be/src/kudu/util/jsonwriter.h
@@ -0,0 +1,98 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#ifndef KUDU_UTIL_JSONWRITER_H
+#define KUDU_UTIL_JSONWRITER_H
+
+#include <inttypes.h>
+#include <stddef.h>
+
+#include <iosfwd>
+#include <memory>
+#include <string>
+
+#include "kudu/gutil/macros.h"
+
+namespace google {
+namespace protobuf {
+class Message;
+class FieldDescriptor;
+} // namespace protobuf
+} // namespace google
+
+namespace kudu {
+
+class JsonWriterIf;
+
+// Acts as a pimpl for rapidjson so that not all metrics users must bring in the
+// rapidjson library, which is template-based and therefore hard to forward-declare.
+//
+// This class implements all the methods of rapidjson::Writer (and
+// rapidjson::PrettyWriter), plus an additional convenience method for
+// String(std::string).
+//
+// We take a std::ostringstream in the constructor because Mongoose / Squeasel
+// uses a string stream for output buffering.
+class JsonWriter {
+ public:
+  enum Mode {
+    // Pretty-print the JSON, with nice indentation, newlines, etc.
+    PRETTY,
+    // Print the JSON as compactly as possible.
+    COMPACT
+  };
+
+  // Output is appended to '*out', which must be non-null (checked in debug
+  // builds) and must outlive this writer.
+  JsonWriter(std::ostringstream* out, Mode mode);
+  ~JsonWriter();
+
+  void Null();
+  void Bool(bool b);
+  void Int(int i);
+  void Uint(unsigned u);
+  void Int64(int64_t i64);
+  void Uint64(uint64_t u64);
+  void Double(double d);
+  void String(const char* str, size_t length);
+  void String(const char* str);
+  void String(const std::string& str);
+
+  // Convert the given protobuf message to JSON.
+  // The output respects redaction for 'string' and 'bytes' fields.
+  void Protobuf(const google::protobuf::Message& message);
+
+  // Writes 'val' using the JSON value type matching T. Only explicit
+  // specializations are provided (in jsonwriter.cc): bool, int32_t, uint32_t,
+  // int64_t, uint64_t, double and std::string, plus size_t on macOS. There is
+  // no generic definition for other types.
+  template<typename T>
+  void Value(const T& val);
+
+  void StartObject();
+  void EndObject();
+  void StartArray();
+  void EndArray();
+
+  // Convert the given protobuf to JSON format.
+  static std::string ToJson(const google::protobuf::Message& pb,
+                            Mode mode);
+
+ private:
+  void ProtobufField(const google::protobuf::Message& pb,
+                     const google::protobuf::FieldDescriptor* field);
+  void ProtobufRepeatedField(const google::protobuf::Message& pb,
+                             const google::protobuf::FieldDescriptor* field,
+                             int index);
+
+  std::unique_ptr<JsonWriterIf> impl_;
+  DISALLOW_COPY_AND_ASSIGN(JsonWriter);
+};
+
+} // namespace kudu
+
+#endif // KUDU_UTIL_JSONWRITER_H
http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d6abb29d/be/src/kudu/util/jsonwriter_test.proto
----------------------------------------------------------------------
diff --git a/be/src/kudu/util/jsonwriter_test.proto b/be/src/kudu/util/jsonwriter_test.proto
new file mode 100644
index 0000000..b6f0300
--- /dev/null
+++ b/be/src/kudu/util/jsonwriter_test.proto
@@ -0,0 +1,79 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+syntax = "proto2";
+package jsonwriter_test;
+
+import "kudu/util/pb_util.proto";
+
+// This proto includes every type of field in both singular and repeated
+// forms. This is mostly copied from 'unittest.proto' in the protobuf source
+// (hence the odd field numbers which skip some).
+message TestAllTypes {
+  message NestedMessage {
+    optional int32 int_field = 1;
+  }
+
+  enum NestedEnum {
+    FOO = 1;
+    BAR = 2;
+    BAZ = 3;
+  }
+
+  // Singular
+  optional int32 optional_int32 = 1;
+  optional int64 optional_int64 = 2;
+  optional uint32 optional_uint32 = 3;
+  optional uint64 optional_uint64 = 4;
+  optional sint32 optional_sint32 = 5;
+  optional sint64 optional_sint64 = 6;
+  optional fixed32 optional_fixed32 = 7;
+  optional fixed64 optional_fixed64 = 8;
+  optional sfixed32 optional_sfixed32 = 9;
+  optional sfixed64 optional_sfixed64 = 10;
+  optional float optional_float = 11;
+  optional double optional_double = 12;
+  optional bool optional_bool = 13;
+  optional string optional_string = 14;
+  optional string optional_redacted_string = 15 [ (kudu.REDACT) = true ];
+  optional bytes optional_bytes = 16;
+  optional bytes optional_redacted_bytes = 17 [ (kudu.REDACT) = true ];
+
+  optional NestedMessage optional_nested_message = 18;
+  optional NestedEnum optional_nested_enum = 21;
+
+  // Repeated
+  repeated int32 repeated_int32 = 31;
+  repeated int64 repeated_int64 = 32;
+  repeated uint32 repeated_uint32 = 33;
+  repeated uint64 repeated_uint64 = 34;
+  repeated sint32 repeated_sint32 = 35;
+  repeated sint64 repeated_sint64 = 36;
+  repeated fixed32 repeated_fixed32 = 37;
+  repeated fixed64 repeated_fixed64 = 38;
+  repeated sfixed32 repeated_sfixed32 = 39;
+  repeated sfixed64 repeated_sfixed64 = 40;
+  repeated float repeated_float = 41;
+  repeated double repeated_double = 42;
+  repeated bool repeated_bool = 43;
+  repeated string repeated_string = 44;
+  repeated bytes repeated_bytes = 45;
+  repeated string repeated_redacted_string = 46 [ (kudu.REDACT) = true ];
+  repeated string repeated_redacted_bytes = 47 [ (kudu.REDACT) = true ];  // NOTE(review): declared 'string' although named *_bytes and field 17 is 'bytes'; confirm intent.
+
+  repeated NestedMessage repeated_nested_message = 48;
+  repeated NestedEnum repeated_nested_enum = 51;
+}
http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d6abb29d/be/src/kudu/util/kernel_stack_watchdog.cc
----------------------------------------------------------------------
diff --git a/be/src/kudu/util/kernel_stack_watchdog.cc b/be/src/kudu/util/kernel_stack_watchdog.cc
new file mode 100644
index 0000000..829431f
--- /dev/null
+++ b/be/src/kudu/util/kernel_stack_watchdog.cc
@@ -0,0 +1,199 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +#include "kudu/util/kernel_stack_watchdog.h" + +#include <boost/bind.hpp> +#include <glog/logging.h> +#include <gflags/gflags.h> +#include <string> + +#include "kudu/util/debug-util.h" +#include "kudu/util/debug/leakcheck_disabler.h" +#include "kudu/util/env.h" +#include "kudu/util/faststring.h" +#include "kudu/util/flag_tags.h" +#include "kudu/util/thread.h" +#include "kudu/util/status.h" +#include "kudu/gutil/map-util.h" +#include "kudu/gutil/strings/substitute.h" + +DEFINE_int32(hung_task_check_interval_ms, 200, + "Number of milliseconds in between checks for hung threads"); +TAG_FLAG(hung_task_check_interval_ms, hidden); + +using std::lock_guard; +using strings::Substitute; + +namespace kudu { + +DEFINE_STATIC_THREAD_LOCAL(KernelStackWatchdog::TLS, + KernelStackWatchdog, tls_); + +KernelStackWatchdog::KernelStackWatchdog() + : log_collector_(nullptr), + finish_(1) { + + // During creation of the stack watchdog thread, we need to disable using + // the stack watchdog itself. Otherwise, the 'StartThread' function will + // try to call back into initializing the stack watchdog, and will self-deadlock. 
+ CHECK_OK(Thread::CreateWithFlags( + "kernel-watchdog", "kernel-watcher", + boost::bind(&KernelStackWatchdog::RunThread, this), + Thread::NO_STACK_WATCHDOG, + &thread_)); +} + +KernelStackWatchdog::~KernelStackWatchdog() { + finish_.CountDown(); + CHECK_OK(ThreadJoiner(thread_.get()).Join()); +} + +void KernelStackWatchdog::SaveLogsForTests(bool save_logs) { + lock_guard<simple_spinlock> l(log_lock_); + if (save_logs) { + log_collector_.reset(new vector<string>()); + } else { + log_collector_.reset(); + } +} + +vector<string> KernelStackWatchdog::LoggedMessagesForTests() const { + lock_guard<simple_spinlock> l(log_lock_); + CHECK(log_collector_) << "Must call SaveLogsForTests(true) first"; + return *log_collector_; +} + +void KernelStackWatchdog::Register(TLS* tls) { + int64_t tid = Thread::CurrentThreadId(); + lock_guard<simple_spinlock> l(tls_lock_); + InsertOrDie(&tls_by_tid_, tid, tls); +} + +void KernelStackWatchdog::Unregister() { + int64_t tid = Thread::CurrentThreadId(); + MutexLock l(unregister_lock_); + lock_guard<simple_spinlock> l2(tls_lock_); + CHECK(tls_by_tid_.erase(tid)); +} + +Status GetKernelStack(pid_t p, string* ret) { + faststring buf; + RETURN_NOT_OK(ReadFileToString(Env::Default(), Substitute("/proc/$0/stack", p), &buf)); + *ret = buf.ToString(); + return Status::OK(); +} + +void KernelStackWatchdog::RunThread() { + while (true) { + MonoDelta delta = MonoDelta::FromMilliseconds(FLAGS_hung_task_check_interval_ms); + if (finish_.WaitFor(delta)) { + // Watchdog exiting. + break; + } + + // Prevent threads from unregistering between the snapshot loop and the sending of + // signals. This makes it safe for us to access their TLS. We might delay the thread + // exit a bit, but it would be unusual for any code to block on a thread exit, whereas + // it's relatively important for threads to _start_ quickly. + MutexLock l(unregister_lock_); + + // Take the snapshot of the thread information under a short lock. 
+ // + // 'lock_' prevents new threads from starting, so we don't want to do any lengthy work + // (such as gathering stack traces) under this lock. + TLSMap tls_map_copy; + { + lock_guard<simple_spinlock> l(tls_lock_); + tls_map_copy = tls_by_tid_; + } + + MicrosecondsInt64 now = GetMonoTimeMicros(); + for (const auto& entry : tls_map_copy) { + pid_t p = entry.first; + TLS::Data* tls = &entry.second->data_; + TLS::Data tls_copy; + tls->SnapshotCopy(&tls_copy); + for (int i = 0; i < tls_copy.depth_; i++) { + const TLS::Frame* frame = &tls_copy.frames_[i]; + + int paused_ms = (now - frame->start_time_) / 1000; + if (paused_ms > frame->threshold_ms_) { + string kernel_stack; + Status s = GetKernelStack(p, &kernel_stack); + if (!s.ok()) { + // Can't read the kernel stack of the pid, just ignore it. + kernel_stack = "(could not read kernel stack)"; + } + + string user_stack = DumpThreadStack(p); + + lock_guard<simple_spinlock> l(log_lock_); + LOG_STRING(WARNING, log_collector_.get()) + << "Thread " << p << " stuck at " << frame->status_ + << " for " << paused_ms << "ms" << ":\n" + << "Kernel stack:\n" << kernel_stack << "\n" + << "User stack:\n" << user_stack; + } + } + } + } +} + +KernelStackWatchdog::TLS* KernelStackWatchdog::GetTLS() { + // Disable leak check. LSAN sometimes gets false positives on thread locals. + // See: https://github.com/google/sanitizers/issues/757 + debug::ScopedLeakCheckDisabler d; + INIT_STATIC_THREAD_LOCAL(KernelStackWatchdog::TLS, tls_); + return tls_; +} + +KernelStackWatchdog::TLS::TLS() { + memset(&data_, 0, sizeof(data_)); + KernelStackWatchdog::GetInstance()->Register(this); +} + +KernelStackWatchdog::TLS::~TLS() { + KernelStackWatchdog::GetInstance()->Unregister(); +} + +// Optimistic concurrency control approach to snapshot the value of another +// thread's TLS, even though that thread might be changing it. +// +// Called by the watchdog thread to see if a target thread is currently in the +// middle of a watched section. 
+void KernelStackWatchdog::TLS::Data::SnapshotCopy(Data* copy) const {
+  while (true) {
+    Atomic32 v_0 = base::subtle::Acquire_Load(&seq_lock_);  // Sequence value before the copy.
+    if (v_0 & 1) {
+      // If the value is odd, then the thread is in the middle of modifying
+      // its TLS, and we have to spin.
+      base::subtle::PauseCPU();
+      continue;
+    }
+    ANNOTATE_IGNORE_READS_BEGIN();  // The copy below may race with the owning thread by design.
+    memcpy(copy, this, sizeof(*copy));
+    ANNOTATE_IGNORE_READS_END();
+    Atomic32 v_1 = base::subtle::Release_Load(&seq_lock_);  // Sequence value after the copy.
+
+    // If the value hasn't changed since we started the copy, then
+    // we know that the copy was a consistent snapshot; otherwise retry.
+    if (v_1 == v_0) break;
+  }
+}
+
+} // namespace kudu
