adamdebreceni commented on a change in pull request #1219:
URL: https://github.com/apache/nifi-minifi-cpp/pull/1219#discussion_r781125574



##########
File path: extensions/splunk/tests/MockSplunkHEC.h
##########
@@ -0,0 +1,214 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include <CivetServer.h>
+#include "core/logging/Logger.h"
+#include "core/logging/LoggerConfiguration.h"
+#include "rapidjson/document.h"
+#include "rapidjson/writer.h"
+#include "rapidjson/stringbuffer.h"
+
+
+/**
+ * Base CivetWeb handler that emulates the request validation of a Splunk HTTP Event Collector.
+ *
+ * Every POST is first handed to the user-supplied assertion callback. Its headers are then
+ * checked in this order:
+ *   - Authorization header missing                  -> 401
+ *   - Authorization value differs from token_       -> 403
+ *   - X-Splunk-Request-Channel header missing       -> 400
+ * Requests that pass every check are forwarded to handlePostImpl().
+ */
+class MockSplunkHandler : public CivetHandler {
+ public:
+  /**
+   * @param token       the exact Authorization header value that is accepted
+   * @param assertions  callback invoked with every request's info; it is stored by reference,
+   *                    so the referenced std::function must outlive this handler
+   */
+  explicit MockSplunkHandler(std::string token, std::function<void(const struct mg_request_info *request_info)>& assertions)
+      : token_(std::move(token)), assertions_(assertions) {
+  }
+
+  enum HeaderResult {
+    MissingAuth,
+    InvalidAuth,
+    MissingReqChannel,
+    HeadersOk
+  };
+
+  bool handlePost(CivetServer*, struct mg_connection *conn) override {
+    switch (checkHeaders(conn)) {
+      case MissingAuth:
+        return send401(conn);
+      case InvalidAuth:
+        return send403(conn);
+      case MissingReqChannel:
+        return send400(conn);
+      case HeadersOk:
+        return handlePostImpl(conn);
+    }
+    return false;
+  }
+
+  /**
+   * Runs the assertion callback on the request, then classifies its HEC headers.
+   *
+   * Headers are looked up with mg_get_header() for two reasons:
+   *   - It only inspects the headers actually received. Scanning the whole fixed-size
+   *     http_headers array would touch unused slots that do not hold valid names.
+   *   - It compares header names case-insensitively, as HTTP requires.
+   */
+  HeaderResult checkHeaders(struct mg_connection *conn) const {
+    const struct mg_request_info* req_info = mg_get_request_info(conn);
+    assertions_(req_info);
+
+    const char* auth_header = mg_get_header(conn, "Authorization");
+    if (auth_header == nullptr)
+      return MissingAuth;
+    if (token_ != auth_header)
+      return InvalidAuth;
+
+    if (mg_get_header(conn, "X-Splunk-Request-Channel") == nullptr)
+      return MissingReqChannel;
+    return HeadersOk;
+  }
+
+  bool send400(struct mg_connection *conn) const {
+    return sendJsonResponse(conn, "400 Bad Request", "{\"text\":\"Data channel is missing\",\"code\":10}");
+  }
+
+  bool send401(struct mg_connection *conn) const {
+    return sendJsonResponse(conn, "401 Unauthorized", "{\"text\":\"Token is required\",\"code\":2}");
+  }
+
+  bool send403(struct mg_connection *conn) const {
+    return sendJsonResponse(conn, "403 Forbidden", "{\"text\":\"Invalid token\",\"code\":4}");
+  }
+
+ protected:
+  /**
+   * Writes a complete HTTP/1.1 response: status line, Content-length header and the body.
+   * @return always true, meaning the request has been handled
+   */
+  static bool sendJsonResponse(struct mg_connection *conn, const char* status, const char* body) {
+    mg_printf(conn, "HTTP/1.1 %s\r\n", status);
+    // %zu matches size_t on every platform; %lu is wrong where unsigned long is 32-bit (LLP64).
+    mg_printf(conn, "Content-length: %zu", std::strlen(body));
+    mg_printf(conn, "\r\n\r\n");
+    // The body is an argument, never the format string, so a '%' in it cannot be misinterpreted.
+    mg_printf(conn, "%s", body);
+    return true;
+  }
+
+  virtual bool handlePostImpl(struct mg_connection *conn) = 0;
+  std::string token_;
+  std::function<void(const struct mg_request_info *request_info)>& assertions_;
+};
+
+/**
+ * Mock of the HEC raw event endpoint. Once the base-class header checks pass, it always replies
+ * 200 with a fixed success body carrying ackId 808. The request body is not read.
+ */
+class RawCollectorHandler : public MockSplunkHandler {
+ public:
+  explicit RawCollectorHandler(std::string token, std::function<void(const struct mg_request_info *request_info)>& assertions)
+      : MockSplunkHandler(std::move(token), assertions) {}
+
+ protected:
+  bool handlePostImpl(struct mg_connection* conn) override {
+    constexpr const char* body = "{\"text\":\"Success\",\"code\":0,\"ackId\":808}";
+    mg_printf(conn, "HTTP/1.1 200 OK\r\n");
+    // %zu matches size_t portably (unlike %lu on LLP64 platforms).
+    mg_printf(conn, "Content-length: %zu", std::strlen(body));
+    mg_printf(conn, "\r\n\r\n");
+    // Pass the body as an argument rather than as the format string.
+    mg_printf(conn, "%s", body);
+    return true;
+  }
+};
+
+/**
+ * Mock of the HEC indexer acknowledgement endpoint.
+ *
+ * Expects a JSON body of the form {"acks":[<uint64>, ...]}. It replies
+ * {"acks":{"<id>":<bool>, ...}}, where the flag tells whether the id is in indexed_events_.
+ * Malformed input is answered with 400 "Invalid data format".
+ */
+class AckIndexerHandler : public MockSplunkHandler {
+ public:
+  explicit AckIndexerHandler(std::string token, std::vector<uint64_t> indexed_events, std::function<void(const struct mg_request_info *request_info)>& assertions)
+      : MockSplunkHandler(std::move(token), assertions), indexed_events_(std::move(indexed_events)) {}
+
+ protected:
+  bool handlePostImpl(struct mg_connection* conn) override {
+    // Read into a std::string: the result is properly sized, NUL-terminated and not truncated.
+    const std::string request_body = readRequestBody(conn);
+    rapidjson::Document post_data;
+
+    const rapidjson::ParseResult parse_result = post_data.Parse<rapidjson::kParseStopWhenDoneFlag>(request_body.c_str());
+    if (parse_result.IsError())
+      return sendInvalidFormat(conn);
+    // HasMember() asserts on non-object values, so check IsObject() first.
+    if (!post_data.IsObject() || !post_data.HasMember("acks") || !post_data["acks"].IsArray())
+      return sendInvalidFormat(conn);
+
+    std::vector<uint64_t> ids;
+    for (const auto& id : post_data["acks"].GetArray()) {
+      // GetUint64() asserts on non-uint64 values; treat them as malformed input instead.
+      if (!id.IsUint64())
+        return sendInvalidFormat(conn);
+      ids.push_back(id.GetUint64());
+    }
+
+    rapidjson::Document reply(rapidjson::kObjectType);
+    auto& allocator = reply.GetAllocator();
+    rapidjson::Value acks(rapidjson::kObjectType);
+    for (const uint64_t id : ids) {
+      rapidjson::Value key(std::to_string(id).c_str(), allocator);
+      const bool indexed = std::find(indexed_events_.begin(), indexed_events_.end(), id) != indexed_events_.end();
+      acks.AddMember(key, indexed, allocator);
+    }
+    reply.AddMember("acks", acks, allocator);
+
+    rapidjson::StringBuffer buffer;
+    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+    reply.Accept(writer);
+
+    mg_printf(conn, "HTTP/1.1 200 OK\r\n");
+    mg_printf(conn, "Content-length: %zu", buffer.GetSize());
+    mg_printf(conn, "\r\n\r\n");
+    mg_printf(conn, "%s", buffer.GetString());
+    return true;
+  }
+
+  /** Replies 400 with the HEC "Invalid data format" (code 6) body. Always returns true. */
+  bool sendInvalidFormat(struct mg_connection* conn) {
+    constexpr const char* body = "{\"text\":\"Invalid data format\",\"code\":6}";
+    mg_printf(conn, "HTTP/1.1 400 Bad Request\r\n");
+    mg_printf(conn, "Content-length: %zu", std::strlen(body));
+    mg_printf(conn, "\r\n\r\n");
+    mg_printf(conn, "%s", body);
+    return true;
+  }
+
+  /** Accumulates the request body until mg_read() reports no more data (0) or an error (< 0). */
+  static std::string readRequestBody(struct mg_connection* conn) {
+    std::string body;
+    char chunk[1024];
+    int bytes_read = 0;
+    while ((bytes_read = mg_read(conn, chunk, sizeof(chunk))) > 0) {
+      body.append(chunk, static_cast<std::size_t>(bytes_read));
+    }
+    return body;
+  }
+
+  std::vector<uint64_t> indexed_events_;
+};
+
+class MockSplunkHEC {
+ public:
+  static constexpr const char* TOKEN = "Splunk 
822f7d13-2b70-4f8c-848b-86edfc251222";
+
+  static inline std::vector<uint64_t> indexed_events = {0, 1};
+
+  explicit MockSplunkHEC(std::string port) : port_(std::move(port)) {
+    std::vector<std::string> options;
+    options.emplace_back("listening_ports");
+    options.emplace_back(port_);
+    server_.reset(new CivetServer(options, &callbacks_, &logger_));

Review comment:
       I think we need to initialize the civet library before we can use it safely, as TestServer.h does.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to