This is an automated email from the ASF dual-hosted git repository.

bakaid pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/nifi-minifi-cpp.git


The following commit(s) were added to refs/heads/master by this push:
     new 09573f8  MINIFICPP-1144 - Fix HTTPCallback freeze and refactor class
09573f8 is described below

commit 09573f87841e3d70f078e11b0d17b5ed2305e68e
Author: Daniel Bakai <[email protected]>
AuthorDate: Mon Feb 3 21:58:14 2020 +0100

    MINIFICPP-1144 - Fix HTTPCallback freeze and refactor class
    
    Signed-off-by: Daniel Bakai <[email protected]>
    
    Approved by aboda and szaszm on GH
    
    This closes #740
---
 extensions/http-curl/client/HTTPCallback.h         | 220 +++++++++++++--------
 extensions/http-curl/tests/CMakeLists.txt          |   1 +
 extensions/http-curl/tests/HTTPSiteToSiteTests.cpp |   1 +
 .../tests/unit/HTTPStreamingCallbackTests.cpp      | 159 +++++++++++++++
 libminifi/include/utils/HTTPClient.h               |  42 ++--
 5 files changed, 322 insertions(+), 101 deletions(-)

diff --git a/extensions/http-curl/client/HTTPCallback.h b/extensions/http-curl/client/HTTPCallback.h
index 170f40a..12d4b92 100644
--- a/extensions/http-curl/client/HTTPCallback.h
+++ b/extensions/http-curl/client/HTTPCallback.h
@@ -18,12 +18,16 @@
 #ifndef EXTENSIONS_HTTP_CURL_CLIENT_HTTPCALLBACK_H_
 #define EXTENSIONS_HTTP_CURL_CLIENT_HTTPCALLBACK_H_
 
-#include "concurrentqueue.h"
+#include <cstring>
+#include <deque>
+#include <stdexcept>
+#include <string>
 #include <thread>
 #include <mutex>
 #include <vector>
 #include <condition_variable>
 
+#include "core/logging/LoggerConfiguration.h"
 #include "utils/ByteArrayCallback.h"
 
 namespace org {
@@ -33,148 +37,205 @@ namespace minifi {
 namespace utils {
 
 /**
- * will stream as items are processed.
+ * The original class here was deadlock-prone, undocumented and was a smorgasbord of multithreading primitives used inconsistently.
+ * This is a rewrite based on the contract inferred from this class's usage in utils::HTTPClient
+ * through HTTPStream and the non-buggy part of the behaviour of the original class.
+ * Based on these:
+ *  - this class provides a mechanism through which chunks of data can be inserted on a producer thread, while a
+ *    consumer thread simultaneously reads this stream of data in CURLOPT_READFUNCTION to supply a POST or PUT request
+ *    body with data utilizing HTTP chunked transfer encoding
+ *  - once a chunk of data is completely processed, we can discard it (i.e. the consumer will not seek backwards)
+ *  - if we expect that more data will be available, but there is none available at the current time, we should block
+ *    the consumer thread until either new data becomes available, or we are closed, signaling that there will be no
+ *    new data
+ *  - we signal that we have provided all data by returning a nullptr from getBuffer. After this no further calls asking
+ *    for data should be made on us
+ *  - we keep a current buffer and change this buffer once the consumer requests an offset which can no longer be served
+ *    by the current buffer
+ *  - because of this, all functions that request data at a specific offset are implicit seeks and potentially modify
+ *    the current buffer
  */
 class HttpStreamingCallback : public ByteInputCallBack {
  public:
   HttpStreamingCallback()
-      : is_alive_(true),
-        ptr(nullptr) {
-    previous_pos_ = 0;
-    rolling_count_ = 0;
+      : logger_(logging::LoggerFactory<HttpStreamingCallback>::getLogger()),
+        is_alive_(true),
+        total_bytes_loaded_(0U),
+        current_buffer_start_(0U),
+        current_pos_(0U),
+        ptr_(nullptr) {
   }
 
-  virtual ~HttpStreamingCallback() {
-
-  }
+  virtual ~HttpStreamingCallback() = default;
 
+  /**
+   * Signals that no more data will be added. Buffers that are already queued can still be consumed;
+   * once they are exhausted, getBuffer() returns nullptr instead of blocking.
+   */
   void close() {
+    logger_->log_trace("close() called");
+    std::unique_lock<std::mutex> lock(mutex_);
     is_alive_ = false;
     cv.notify_all();
   }
 
-  virtual void seek(size_t pos) {
-    if ((pos - previous_pos_) >= current_vec_.size() || current_vec_.size() == 0)
-      load_buffer();
+  void seek(size_t pos) override {
+    logger_->log_trace("seek(pos: %zu) called", pos);
+    std::unique_lock<std::mutex> lock(mutex_);
+    seekInner(lock, pos);
   }
 
-  virtual int64_t process(std::shared_ptr<io::BaseStream> stream) {
-
+  int64_t process(std::shared_ptr<io::BaseStream> stream) override {
     std::vector<char> vec;
 
     if (stream->getSize() > 0) {
       vec.resize(stream->getSize());
-
       stream->readData(reinterpret_cast<uint8_t*>(vec.data()), stream->getSize());
     }
 
-    size_t added_size = vec.size();
-
-    byte_arrays_.enqueue(std::move(vec));
-
-    cv.notify_all();
-
-    return added_size;
-
+    return processInner(std::move(vec));
   }
 
-  virtual int64_t process(uint8_t *vector, size_t size) {
-
+  /**
+   * Enqueues a copy of the size bytes at data; empty input (size 0, data may be nullptr) is ignored and 0 is returned.
+   */
+  virtual int64_t process(const uint8_t* data, size_t size) {
     std::vector<char> vec;
+    vec.resize(size);
+    if (size > 0U) {
+      // Skipped for empty input: memcpy with a null pointer is undefined behaviour even for 0 bytes
+      memcpy(vec.data(), reinterpret_cast<const char*>(data), size);
+    }
 
-    if (size > 0) {
-      vec.resize(size);
+    return processInner(std::move(vec));
+  }
 
-      memcpy(vec.data(), vector, size);
+  void write(std::string content) override {
+    std::vector<char> vec;
+    vec.assign(content.begin(), content.end());
 
-      size_t added_size = vec.size();
+    (void) processInner(std::move(vec));
+  }
 
-      byte_arrays_.enqueue(std::move(vec));
+  char* getBuffer(size_t pos) override {
+    logger_->log_trace("getBuffer(pos: %zu) called", pos);
 
-      cv.notify_all();
+    std::unique_lock<std::mutex> lock(mutex_);
 
-      return added_size;
-    } else {
-      return 0;
+    seekInner(lock, pos);
+    if (ptr_ == nullptr) {
+      return nullptr;
     }
 
-  }
+    size_t relative_pos = pos - current_buffer_start_;
+    current_pos_ = pos;
 
-  virtual void write(std::string content) {
-    std::vector<char> vec;
-    vec.assign(content.begin(), content.end());
-    byte_arrays_.enqueue(vec);
+    return ptr_ + relative_pos;
   }
 
-  virtual char *getBuffer(size_t pos) {
+  const size_t getRemaining(size_t pos) override {
+    logger_->log_trace("getRemaining(pos: %zu) called", pos);
 
-    // if there is no space remaining in our current buffer,
-    // we should load the next. If none exists after that we have no more buffer
-    std::lock_guard<std::recursive_mutex> lock(mutex_);
+    std::unique_lock<std::mutex> lock(mutex_);
+    seekInner(lock, pos);
+    return total_bytes_loaded_ - pos;
+  }
 
-    if ((pos - previous_pos_) >= current_vec_.size() || current_vec_.size() == 0)
-      load_buffer();
+  const size_t getBufferSize() override {
+    logger_->log_trace("getBufferSize() called");
 
-    if (ptr == nullptr)
-      return nullptr;
-
-    size_t absolute_position = pos - previous_pos_;
+    std::unique_lock<std::mutex> lock(mutex_);
+    // This is needed to make sure that the first buffer is loaded
+    seekInner(lock, current_pos_);
+    return total_bytes_loaded_;
+  }
 
-    current_pos_ = pos;
+ private:
 
-    return ptr + absolute_position;
+  /**
+   * Loads the next available buffer
+   * @param lock unique_lock which *must* own the lock
+   */
+  inline void loadNextBuffer(std::unique_lock<std::mutex>& lock) {
+    cv.wait(lock, [&] {
+      return !byte_arrays_.empty() || !is_alive_;
+    });
+
+    if (byte_arrays_.empty()) {
+      logger_->log_trace("loadNextBuffer() ran out of buffers");
+      ptr_ = nullptr;
+    } else {
+      current_vec_ = std::move(byte_arrays_.front());
+      byte_arrays_.pop_front();
+
+      ptr_ = current_vec_.data();
+      current_buffer_start_ = total_bytes_loaded_;
+      current_pos_ = current_buffer_start_;
+      total_bytes_loaded_ += current_vec_.size();
+      logger_->log_trace("loadNextBuffer() loaded new buffer, ptr_: %p, size: %zu, current_buffer_start_: %zu, current_pos_: %zu, total_bytes_loaded_: %zu",
+          ptr_,
+          current_vec_.size(),
+          current_buffer_start_,
+          current_pos_,
+          total_bytes_loaded_);
+    }
   }
 
-  virtual const size_t getRemaining(size_t pos) {
-    return current_vec_.size();
-  }
+  /**
+   * Common implementation for placing a buffer into the queue
+   * @param vec the buffer to be inserted
+   * @return the number of bytes processed (the size of vec)
+   */
+  int64_t processInner(std::vector<char>&& vec) {
+    size_t size = vec.size();
 
-  virtual const size_t getBufferSize() {
-    std::lock_guard<std::recursive_mutex> lock(mutex_);
+    logger_->log_trace("processInner() called, vec.data(): %p, vec.size(): %zu", vec.data(), size);
 
-    if (ptr == nullptr || current_pos_ >= rolling_count_) {
-      load_buffer();
+    if (size == 0U) {
+      return 0U;
     }
-    return rolling_count_;
-  }
 
- private:
+    std::unique_lock<std::mutex> lock(mutex_);
+    byte_arrays_.emplace_back(std::move(vec));
+    cv.notify_all();
+
+    return size;
+  }
 
-  inline void load_buffer() {
-    std::unique_lock<std::recursive_mutex> lock(mutex_);
-    cv.wait(lock, [&] {return byte_arrays_.size_approx() > 0 || is_alive_==false;});
-    if (!is_alive_ && byte_arrays_.size_approx() == 0) {
-      lock.unlock();
-      return;
+  /**
+   * Seeks to the specified position
+   * @param lock unique_lock which *must* own the lock
+   * @param pos position to seek to
+   */
+  void seekInner(std::unique_lock<std::mutex>& lock, size_t pos) {
+    logger_->log_trace("seekInner() called, current_pos_: %zu, pos: %zu", current_pos_, pos);
+    if (pos < current_pos_) {
+      const std::string errstr = "Seeking backwards is not supported, tried to seek from " + std::to_string(current_pos_) + " to " + std::to_string(pos);
+      logger_->log_error("%s", errstr);
+      throw std::logic_error(errstr);
     }
-    try {
-      if (byte_arrays_.try_dequeue(current_vec_)) {
-        ptr = &current_vec_[0];
-        previous_pos_.store(rolling_count_.load());
-        current_pos_ = 0;
-        rolling_count_ += current_vec_.size();
-      } else {
-        ptr = nullptr;
+    while ((pos - current_buffer_start_) >= current_vec_.size()) {
+      loadNextBuffer(lock);
+      if (ptr_ == nullptr) {
+        break;
       }
-      lock.unlock();
-    } catch (...) {
-      lock.unlock();
     }
   }
 
-  std::atomic<bool> is_alive_;
-  std::atomic<size_t> rolling_count_;
-  std::condition_variable_any cv;
-  std::atomic<size_t> previous_pos_;
-  std::atomic<size_t> current_pos_;
+  std::shared_ptr<logging::Logger> logger_;
 
-  std::recursive_mutex mutex_;
+  std::mutex mutex_;
+  std::condition_variable cv;
 
-  moodycamel::ConcurrentQueue<std::vector<char>> byte_arrays_;
+  bool is_alive_;
+  size_t total_bytes_loaded_;
+  size_t current_buffer_start_;
+  size_t current_pos_;
 
-  char *ptr;
+  std::deque<std::vector<char>> byte_arrays_;
 
   std::vector<char> current_vec_;
+  char *ptr_;
 };
 
 } /* namespace utils */
diff --git a/extensions/http-curl/tests/CMakeLists.txt b/extensions/http-curl/tests/CMakeLists.txt
index e12e956..ab9176d 100644
--- a/extensions/http-curl/tests/CMakeLists.txt
+++ b/extensions/http-curl/tests/CMakeLists.txt
@@ -67,6 +67,7 @@ ENDFOREACH()
 message("-- Finished building ${CURL_INT_TEST_COUNT} libcURL integration test file(s)...")
 
 add_test(NAME HTTPClientTests COMMAND "HTTPClientTests" WORKING_DIRECTORY ${TEST_DIR})
+add_test(NAME HTTPStreamingCallbackTests COMMAND "HTTPStreamingCallbackTests" WORKING_DIRECTORY ${TEST_DIR})
 
 add_test(NAME HttpGetIntegrationTest COMMAND HttpGetIntegrationTest "${TEST_RESOURCES}/TestHTTPGet.yml"  "${TEST_RESOURCES}/")
 add_test(NAME C2UpdateTest COMMAND C2UpdateTest "${TEST_RESOURCES}/TestHTTPGet.yml"  "${TEST_RESOURCES}/")
diff --git a/extensions/http-curl/tests/HTTPSiteToSiteTests.cpp b/extensions/http-curl/tests/HTTPSiteToSiteTests.cpp
index 0d2ec0c..39ef776 100644
--- a/extensions/http-curl/tests/HTTPSiteToSiteTests.cpp
+++ b/extensions/http-curl/tests/HTTPSiteToSiteTests.cpp
@@ -66,6 +66,7 @@ class SiteToSiteTestHarness : public CoapIntegrationBase {
     LogTestController::getInstance().setTrace<minifi::controllers::SSLContextService>();
     LogTestController::getInstance().setInfo<minifi::FlowController>();
     LogTestController::getInstance().setDebug<core::ConfigurableComponent>();
+    LogTestController::getInstance().setTrace<utils::HttpStreamingCallback>();
 
     std::fstream file;
     ss << dir << utils::file::FileUtils::get_separator() << "tstFile.ext";
diff --git a/extensions/http-curl/tests/unit/HTTPStreamingCallbackTests.cpp b/extensions/http-curl/tests/unit/HTTPStreamingCallbackTests.cpp
new file mode 100644
index 0000000..567531f
--- /dev/null
+++ b/extensions/http-curl/tests/unit/HTTPStreamingCallbackTests.cpp
@@ -0,0 +1,168 @@
+/**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <thread>
+#include <mutex>
+#include <vector>
+#include <string>
+#include <chrono>
+#include <cstring>
+#include <cstdint>
+#include <iostream>
+#include <stdexcept>
+
+#include "client/HTTPCallback.h"
+#include "TestBase.h"
+
+class HttpStreamingCallbackTestsFixture {
+ public:
+  HttpStreamingCallbackTestsFixture() {
+    LogTestController::getInstance().setTrace<utils::HttpStreamingCallback>();
+  }
+
+  virtual ~HttpStreamingCallbackTestsFixture() {
+    // Release a consumer thread that may still be blocked waiting for data
+    // (e.g. when the test body failed before calling close()), otherwise join() would hang
+    callback_.close();
+    if (consumer_thread_.joinable()) {
+      consumer_thread_.join();
+    }
+    LogTestController::getInstance().reset();
+  }
+
+  /**
+   * Starts a consumer thread that reads the callback with the same getBufferSize/getBuffer/seek sequence
+   * as HTTPRequestResponse::send_write and appends everything to content_ until getBuffer() returns nullptr.
+   */
+  void startConsumerThread() {
+    if (consumer_thread_.joinable()) {
+      throw std::logic_error("Consumer thread already started");
+    }
+    consumer_thread_ = std::thread([this](){
+      std::cerr << "Consumer thread started" << std::endl;
+
+      size_t current_pos = 0U;
+
+      while (true) {
+        size_t buffer_size = callback_.getBufferSize();
+        if (current_pos <= buffer_size) {
+          size_t len = buffer_size - current_pos;
+          char* ptr = callback_.getBuffer(current_pos);
+          if (ptr == nullptr) {
+            break;
+          }
+          {
+            std::unique_lock<std::mutex> lock(content_mutex_);
+            content_.resize(content_.size() + len);
+            memcpy(content_.data() + current_pos, ptr, len);
+          }
+          current_pos += len;
+          callback_.seek(current_pos);
+        }
+      }
+    });
+  }
+
+  std::string getContent() {
+    std::unique_lock<std::mutex> lock(content_mutex_);
+    return std::string(content_.data(), content_.size());
+  }
+
+  std::string waitForCompletionAndGetContent() {
+    if (consumer_thread_.joinable()) {
+      consumer_thread_.join();
+    }
+    return getContent();
+  }
+
+ protected:
+  utils::HttpStreamingCallback callback_;
+  std::mutex content_mutex_;
+  std::vector<char> content_;
+  std::thread consumer_thread_;
+};
+
+
+TEST_CASE_METHOD(HttpStreamingCallbackTestsFixture, "HttpStreamingCallback empty", "[basic]") {
+  SECTION("with wait") {
+    startConsumerThread();
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+  }
+
+  callback_.close();
+
+  SECTION("without wait") {
+    startConsumerThread();
+  }
+
+  std::string content = waitForCompletionAndGetContent();
+
+  REQUIRE(0U == content.length());
+}
+
+TEST_CASE_METHOD(HttpStreamingCallbackTestsFixture, "HttpStreamingCallback one buffer", "[basic]") {
+  SECTION("with wait") {
+    startConsumerThread();
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+  }
+
+  std::string input = "foobar";
+  callback_.process(reinterpret_cast<const uint8_t*>(input.c_str()), input.length());
+  callback_.close();
+
+  SECTION("without wait") {
+    startConsumerThread();
+  }
+
+  std::string content = waitForCompletionAndGetContent();
+
+  REQUIRE(input == content);
+}
+
+TEST_CASE_METHOD(HttpStreamingCallbackTestsFixture, "HttpStreamingCallback multiple buffers", "[basic]") {
+  SECTION("with wait") {
+    startConsumerThread();
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+  }
+
+  std::string input;
+  for (size_t i = 0U; i < 16U; i++) {
+    std::string chunk = std::to_string(i << 16);
+    input += chunk;
+    callback_.process(reinterpret_cast<const uint8_t*>(chunk.c_str()), chunk.length());
+    if (i == 7U) {
+      SECTION("with staggered wait" + std::to_string(i)) {
+        startConsumerThread();
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+      }
+    }
+  }
+  SECTION("with wait before close") {
+    startConsumerThread();
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+  }
+  callback_.close();
+
+  SECTION("without wait") {
+    startConsumerThread();
+  }
+
+  std::string content = waitForCompletionAndGetContent();
+
+  REQUIRE(input == content);
+}
diff --git a/libminifi/include/utils/HTTPClient.h b/libminifi/include/utils/HTTPClient.h
index 8cdc3ef..46a066d 100644
--- a/libminifi/include/utils/HTTPClient.h
+++ b/libminifi/include/utils/HTTPClient.h
@@ -168,6 +168,9 @@ class HTTPRequestResponse {
 
  public:
 
+  // Returned from the curl callbacks below to abort the transfer (same value as libcurl's CURL_READFUNC_ABORT)
+  static const size_t CALLBACK_ABORT = 0x10000000;
+
   const std::vector<char> &getData() {
     return data;
   }
@@ -185,11 +188,19 @@ class HTTPRequestResponse {
    * Receive HTTP Response.
    */
   static size_t recieve_write(char * data, size_t size, size_t nmemb, void * p) {
-    HTTPReadCallback *callback = static_cast<HTTPReadCallback*>(p);
-    if (callback->stop)
-      return 0x10000000;
-    callback->ptr->write(data, (size * nmemb));
-    return (size * nmemb);
+    try {
+      if (p == nullptr) {
+        return CALLBACK_ABORT;
+      }
+      HTTPReadCallback *callback = static_cast<HTTPReadCallback *>(p);
+      if (callback->stop) {
+        return CALLBACK_ABORT;
+      }
+      callback->ptr->write(data, (size * nmemb));
+      return (size * nmemb);
+    } catch (...) {
+      return CALLBACK_ABORT;
+    }
   }
 
   /**
@@ -201,15 +212,18 @@ class HTTPRequestResponse {
    */
 
   static size_t send_write(char * data, size_t size, size_t nmemb, void * p) {
-    if (p != 0) {
-      HTTPUploadCallback *callback = (HTTPUploadCallback*) p;
-      if (callback->stop)
-        return 0x10000000;
+    try {
+      if (p == nullptr) {
+        return CALLBACK_ABORT;
+      }
+      HTTPUploadCallback *callback = static_cast<HTTPUploadCallback *>(p);
+      if (callback->stop) {
+        return CALLBACK_ABORT;
+      }
       size_t buffer_size = callback->ptr->getBufferSize();
       if (callback->getPos() <= buffer_size) {
         size_t len = buffer_size - callback->pos;
-        if (len <= 0)
-        {
+        if (len == 0) {
           return 0;
         }
         char *ptr = callback->ptr->getBuffer(callback->getPos());
@@ -219,16 +233,15 @@ class HTTPRequestResponse {
         }
         if (len > size * nmemb)
           len = size * nmemb;
-        auto strr = std::string(ptr,len);
         memcpy(data, ptr, len);
         callback->pos += len;
         callback->ptr->seek(callback->getPos());
         return len;
       }
-    } else {
-      return 0x10000000;
+      return 0;
+    } catch (...) {
+      return CALLBACK_ABORT;
     }
-    return 0;
   }
 
   int read_data(uint8_t *buf, size_t size) {

Reply via email to