Repository: nifi-minifi-cpp
Updated Branches:
  refs/heads/master b8e45cbf9 -> dec7caef7

MINIFICPP-358 Added TFExtractTopLabels

This closes #232.

Signed-off-by: Marc Parisi <[email protected]>

Project: http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/repo
Commit: http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/commit/dec7caef
Tree: http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/tree/dec7caef
Diff: http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/diff/dec7caef

Branch: refs/heads/master
Commit: dec7caef7dd348a1fa80d0f4db5d4fc979f785fa
Parents: b8e45cb
Author: Andy I. Christianson <[email protected]>
Authored: Mon Jan 8 17:13:27 2018 -0500
Committer: Marc Parisi <[email protected]>
Committed: Wed Jan 10 14:16:44 2018 -0500

----------------------------------------------------------------------
 PROCESSORS.md                                 | 116 ++++++++++++-
 README.md                                     |  23 +--
 .../tensorflow/TFConvertImageToTensor.cpp     |   2 +-
 extensions/tensorflow/TFExtractTopLabels.cpp  | 173 +++++++++++++++++++
 extensions/tensorflow/TFExtractTopLabels.h    |  92 ++++++++++
 .../test/tensorflow-tests/TensorFlowTests.cpp | 117 ++++++++++++-
 6 files changed, 498 insertions(+), 25 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/blob/dec7caef/PROCESSORS.md
----------------------------------------------------------------------
diff --git a/PROCESSORS.md b/PROCESSORS.md
index 6c4eedf..85b08d9 100644
--- a/PROCESSORS.md
+++ b/PROCESSORS.md
@@ -18,24 +18,29 @@
 ## Table of Contents
 
 - [AppendHostInfo](#appendhostinfo)
+- [CompressContent](#compresscontent)
+- [ConsumeMQTT](#consumeMQTT)
 - [ExecuteProcess](#executeprocess)
 - [ExecuteScript](#executescript)
+- [ExtractText](#extracttext)
+- [FocusArchiveEntry](#focusarchiveentry)
+- [GenerateFlowFile](#generateflowfile)
 - [GetFile](#getfile)
 - [GetUSBCamera](#getusbcamera)
-- [GenerateFlowFile](#generateflowfile)
 - [InvokeHTTP](#invokehttp)
-- [LogAttribute](#logattribute)
 - [ListenHTTP](#listenhttp)
 - [ListenSyslog](#listensyslog)
+- [LogAttribute](#logattribute)
+- [ManipulateArchive](#manipulatearchive)
+- [MergeContent](#mergecontent)
+- [PublishKafka](#publishkafka)
+- [PublishMQTT](#publishMQTT)
 - [PutFile](#putfile)
 - [TailFile](#tailfile)
-- [MergeContent](#mergecontent)
-- [ExtractText](#extracttext)
-- [CompressContent](#compresscontent)
-- [FocusArchiveEntry](#focusarchiveentry)
+- [TFApplyGraph](#tfapplygraph)
+- [TFConvertImageToTensor](#tfconvertimagetotensor)
+- [TFExtractTopLabels](#tfextracttoplabels)
 - [UnfocusArchiveEntry](#unfocusarchiveentry)
-- [ManipulateArchive](#manipulatearchive)
-- [PublishKafka](#publishkafka)
 
 ## AppendHostInfo
 
@@ -535,6 +540,101 @@ default values, and whether a property supports the NiFi Expression Language.
 | - | - |
 | success | All FlowFiles are routed to this Relationship. |
 
+## TFApplyGraph
+
+### Description
+
+Applies a TensorFlow graph to the tensor protobuf supplied as input. The tensor
+is fed into the node specified by the `Input Node` property. The output
+FlowFile is a tensor protobuf extracted from the node specified by the `Output
+Node` property.
+
+TensorFlow graphs are read dynamically by feeding a graph protobuf to the
+processor with the `tf.type` attribute set to `graph`.
+
+### Properties
+
+In the list below, the names of required properties appear in bold. Any other
+properties (not in bold) are considered optional. The table also indicates any
+default values, and whether a property supports the NiFi Expression Language.
+
+| Name | Default Value | Allowable Values | Description |
+| - | - | - | - |
+| **Input Node** | | | The node of the TensorFlow graph to feed tensor inputs to |
+| **Output Node** | | | The node of the TensorFlow graph to read tensor outputs from |
+
+### Relationships
+
+| Name | Description |
+| - | - |
+| success | Successful graph application outputs as tensor protobufs |
+| retry | Inputs which fail graph application but may work if sent again |
+| failure | Failures which will not work if retried |
+
+## TFConvertImageToTensor
+
+### Description
+
+Converts the input image file into a tensor protobuf. The image will be resized
+to the given output tensor dimensions.
+
+### Properties
+
+In the list below, the names of required properties appear in bold. Any other
+properties (not in bold) are considered optional. The table also indicates any
+default values, and whether a property supports the NiFi Expression Language.
+
+| Name | Default Value | Allowable Values | Description |
+| - | - | - | - |
+| **Input Format** | | PNG, RAW | The format of the input image (PNG or RAW). RAW is RGB24. |
+| **Input Width** | | | The width, in pixels, of the input image. |
+| **Input Height** | | | The height, in pixels, of the input image. |
+| **Output Width** | | | The width, in pixels, of the output image. |
+| **Output Height** | | | The height, in pixels, of the output image. |
+| **Channels** | 3 | | The number of channels (e.g. 3 for RGB, 4 for RGBA) in the input image. |
+
+### Relationships
+
+| Name | Description |
+| - | - |
+| success | Successfully read tensor protobufs |
+| failure | Inputs which could not be converted to tensor protobufs |
+
+## TFExtractTopLabels
+
+### Description
+
+Extracts the top 5 labels for categorical inference models.
+
+Labels are fed as newline-delimited (`\n`) files in which each line is the
+label for the tensor index equal to its zero-based line number. Label files
+must be fed in with the `tf.type` attribute set to `labels`.
+
+The top 5 labels are written to the following attributes:
+
+- `tf.top_label_0`
+- `tf.top_label_1`
+- `tf.top_label_2`
+- `tf.top_label_3`
+- `tf.top_label_4`
+
+### Properties
+
+In the list below, the names of required properties appear in bold. Any other
+properties (not in bold) are considered optional. The table also indicates any
+default values, and whether a property supports the NiFi Expression Language.
+
+| Name | Default Value | Allowable Values | Description |
+| - | - | - | - |
+
+### Relationships
+
+| Name | Description |
+| - | - |
+| success | Successful FlowFiles are sent here with labels as attributes |
+| retry | Failures which might work if retried |
+| failure | Failures which will not work if retried |
+
 ## MergeContent
 
 Merges a Group of FlowFiles together based on a user-defined strategy and

http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/blob/dec7caef/README.md
----------------------------------------------------------------------
diff --git a/README.md b/README.md
index 40f7c97..ed0b834 100644
--- a/README.md
+++ b/README.md
@@ -47,26 +47,29 @@ Perspectives of the role of MiNiFi should be from the perspective of the agent a
 MiNiFi - C++ supports the following processors:
 
 * [AppendHostInfo](PROCESSORS.md#appendhostinfo)
+* [CompressContent](PROCESSORS.md#compresscontent)
+* [ConsumeMQTT](PROCESSORS.md#consumeMQTT)
 * [ExecuteProcess](PROCESSORS.md#executeprocess)
 * [ExecuteScript](PROCESSORS.md#executescript)
+* [ExtractText](PROCESSORS.md#extracttext)
+* [FocusArchiveEntry](PROCESSORS.md#focusarchiveentry)
+* [GenerateFlowFile](PROCESSORS.md#generateflowfile)
 * [GetFile](PROCESSORS.md#getfile)
 * [GetUSBCamera](PROCESSORS.md#getusbcamera)
-* [GenerateFlowFile](PROCESSORS.md#generateflowfile)
 * [InvokeHTTP](PROCESSORS.md#invokehttp)
-* [LogAttribute](PROCESSORS.md#logattribute)
 * [ListenHTTP](PROCESSORS.md#listenhttp)
 * [ListenSyslog](PROCESSORS.md#listensyslog)
-* [PutFile](PROCESSORS.md#putfile)
-* [TailFile](PROCESSORS.md#tailfile)
-* [MergeContent](PROCESSORS.md#mergecontent)
-* [ExtractText](PROCESSORS.md#extracttext)
-* [CompressContent](PROCESSORS.md#compresscontent)
-* [FocusArchiveEntry](PROCESSORS.md#focusarchiveentry)
-* [UnfocusArchiveEntry](PROCESSORS.md#unfocusarchiveentry)
+* [LogAttribute](PROCESSORS.md#logattribute)
 * [ManipulateArchive](PROCESSORS.md#manipulatearchive)
+* [MergeContent](PROCESSORS.md#mergecontent)
 * [PublishKafka](PROCESSORS.md#publishkafka)
 * [PublishMQTT](PROCESSORS.md#publishMQTT)
-* [ConsumeMQTT](PROCESSORS.md#consumeMQTT)
+* [PutFile](PROCESSORS.md#putfile)
+* [TailFile](PROCESSORS.md#tailfile)
+* [TFApplyGraph](PROCESSORS.md#tfapplygraph)
+* [TFConvertImageToTensor](PROCESSORS.md#tfconvertimagetotensor)
+* [TFExtractTopLabels](PROCESSORS.md#tfextracttoplabels)
+* [UnfocusArchiveEntry](PROCESSORS.md#unfocusarchiveentry)
 
 ## Caveats
 * 0.4.0 represents a non-GA release, APIs and interfaces are subject to change

http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/blob/dec7caef/extensions/tensorflow/TFConvertImageToTensor.cpp
----------------------------------------------------------------------
diff --git a/extensions/tensorflow/TFConvertImageToTensor.cpp b/extensions/tensorflow/TFConvertImageToTensor.cpp
index be5e7a1..803ea48 100644
--- a/extensions/tensorflow/TFConvertImageToTensor.cpp
+++ b/extensions/tensorflow/TFConvertImageToTensor.cpp
@@ -27,7 +27,7 @@ namespace processors {
 
 core::Property TFConvertImageToTensor::ImageFormat(  // NOLINT
     "Input Format",
-    "The node of the TensorFlow graph to feed tensor inputs to (PNG or RAW). RAW is RGB24.", "");
+    "The format of the input image (PNG or RAW). RAW is RGB24.", "");
 core::Property TFConvertImageToTensor::InputWidth(  // NOLINT
     "Input Width",
     "The width, in pixels, of the input image.", "");

http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/blob/dec7caef/extensions/tensorflow/TFExtractTopLabels.cpp
----------------------------------------------------------------------
diff --git a/extensions/tensorflow/TFExtractTopLabels.cpp b/extensions/tensorflow/TFExtractTopLabels.cpp
new file mode 100644
index 0000000..723f7dc
--- /dev/null
+++ b/extensions/tensorflow/TFExtractTopLabels.cpp
@@ -0,0 +1,173 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFExtractTopLabels.h"
+
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace org {
+namespace apache {
+namespace nifi {
+namespace minifi {
+namespace processors {
+
+core::Relationship TFExtractTopLabels::Success(  // NOLINT
+    "success",
+    "Successful FlowFiles are sent here with labels as attributes");
+core::Relationship TFExtractTopLabels::Retry(  // NOLINT
+    "retry",
+    "Failures which might work if retried");
+core::Relationship TFExtractTopLabels::Failure(  // NOLINT
+    "failure",
+    "Failures which will not work if retried");
+
+void TFExtractTopLabels::initialize() {
+  std::set<core::Property> properties;
+  setSupportedProperties(std::move(properties));
+
+  std::set<core::Relationship> relationships;
+  relationships.insert(Success);
+  relationships.insert(Retry);
+  relationships.insert(Failure);
+  setSupportedRelationships(std::move(relationships));
+}
+
+void TFExtractTopLabels::onSchedule(core::ProcessContext *context, core::ProcessSessionFactory *sessionFactory) {
+}
+
+void TFExtractTopLabels::onTrigger(const std::shared_ptr<core::ProcessContext> &context,
+                                   const std::shared_ptr<core::ProcessSession> &session) {
+  auto flow_file = session->get();
+
+  if (!flow_file) {
+    return;
+  }
+
+  try {
+
+    // Read labels
+    std::string tf_type;
+    flow_file->getAttribute("tf.type", tf_type);
+    std::shared_ptr<std::vector<std::string>> labels;
+
+    {
+      std::lock_guard<std::mutex> guard(labels_mtx_);
+
+      if (tf_type == "labels") {
+        logger_->log_info("Reading new labels...");
+        auto new_labels = std::make_shared<std::vector<std::string>>();
+        LabelsReadCallback cb(new_labels);
+        session->read(flow_file, &cb);
+        labels_ = new_labels;
+        logger_->log_info("Read %d new labels", labels_->size());
+        session->remove(flow_file);
+        return;
+      }
+
+      labels = labels_;
+    }
+
+    // Read input tensor from flow file
+    auto input_tensor_proto = std::make_shared<tensorflow::TensorProto>();
+    TensorReadCallback tensor_cb(input_tensor_proto);
+    session->read(flow_file, &tensor_cb);
+
+    tensorflow::Tensor input;
+    input.FromProto(*input_tensor_proto);
+    auto input_flat = input.flat<float>();
+
+    std::vector<std::pair<uint64_t, float>> scores;
+
+    for (int i = 0; i < input_flat.size(); i++) {
+      scores.emplace_back(std::make_pair(i, input_flat(i)));
+    }
+
+    std::sort(scores.begin(), scores.end(), [](const std::pair<uint64_t, float> &a,
+                                               const std::pair<uint64_t, float> &b) {
+      return a.second > b.second;
+    });
+
+    for (int i = 0; i < 5 && i < scores.size(); i++) {
+      if (!labels || scores[i].first > labels->size()) {
+        logger_->log_error("Label index is out of range (are the correct labels loaded?); routing to retry...");
+        session->transfer(flow_file, Retry);
+        return;
+      }
+      flow_file->addAttribute("tf.top_label_" + std::to_string(i), labels->at(scores[i].first));
+    }
+
+    session->transfer(flow_file, Success);
+
+  } catch (std::exception &exception) {
+    logger_->log_error("Caught Exception %s", exception.what());
+    session->transfer(flow_file, Failure);
+    this->yield();
+  } catch (...) {
+    logger_->log_error("Caught Exception");
+    session->transfer(flow_file, Failure);
+    this->yield();
+  }
+}
+
+int64_t TFExtractTopLabels::LabelsReadCallback::process(std::shared_ptr<io::BaseStream> stream) {
+  int64_t total_read = 0;
+  std::string label;
+  uint64_t max_label_len = 65536;
+  label.resize(max_label_len);
+  std::string buf;
+  uint64_t label_size = 0;
+  uint64_t buf_size = 8096;
+  buf.resize(buf_size);
+
+  while (total_read < stream->getSize()) {
+    auto read = stream->read(reinterpret_cast<uint8_t *>(&buf[0]), static_cast<int>(buf_size));
+
+    for (auto i = 0; i < read; i++) {
+      if (buf[i] == '\n' || total_read + i == stream->getSize()) {
+        labels_->emplace_back(label.substr(0, label_size));
+        label_size = 0;
+      } else {
+        label[label_size] = buf[i];
+        label_size++;
+      }
+    }
+
+    total_read += read;
+  }
+
+  return total_read;
+}
+
+int64_t TFExtractTopLabels::TensorReadCallback::process(std::shared_ptr<io::BaseStream> stream) {
+  std::string tensor_proto_buf;
+  tensor_proto_buf.resize(stream->getSize());
+  auto num_read = stream->readData(reinterpret_cast<uint8_t *>(&tensor_proto_buf[0]),
+                                   static_cast<int>(stream->getSize()));
+
+  if (num_read != stream->getSize()) {
+    throw std::runtime_error("TensorReadCallback failed to fully read flow file input stream");
+  }
+
+  tensor_proto_->ParseFromString(tensor_proto_buf);
+  return num_read;
+}
+
+} /* namespace processors */
+} /* namespace minifi */
+} /* namespace nifi */
+} /* namespace apache */
+} /* namespace org */

http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/blob/dec7caef/extensions/tensorflow/TFExtractTopLabels.h
----------------------------------------------------------------------
diff --git a/extensions/tensorflow/TFExtractTopLabels.h b/extensions/tensorflow/TFExtractTopLabels.h
new file mode 100644
index 0000000..58ed57f
--- /dev/null
+++ b/extensions/tensorflow/TFExtractTopLabels.h
@@ -0,0 +1,92 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NIFI_MINIFI_CPP_TFEXTRACTTOPLABELS_H
+#define NIFI_MINIFI_CPP_TFEXTRACTTOPLABELS_H
+
+#include <atomic>
+
+#include <core/Resource.h>
+#include <core/Processor.h>
+#include <tensorflow/core/public/session.h>
+#include <concurrentqueue.h>
+
+namespace org {
+namespace apache {
+namespace nifi {
+namespace minifi {
+namespace processors {
+
+class TFExtractTopLabels : public core::Processor {
+ public:
+  explicit TFExtractTopLabels(const std::string &name, uuid_t uuid = nullptr)
+      : Processor(name, uuid),
+        logger_(logging::LoggerFactory<TFExtractTopLabels>::getLogger()) {
+  }
+
+  static core::Relationship Success;
+  static core::Relationship Retry;
+  static core::Relationship Failure;
+
+  void initialize() override;
+  void onSchedule(core::ProcessContext *context, core::ProcessSessionFactory *sessionFactory) override;
+  void onTrigger(core::ProcessContext *context, core::ProcessSession *session) override {
+    logger_->log_error("onTrigger invocation with raw pointers is not implemented");
+  }
+  void onTrigger(const std::shared_ptr<core::ProcessContext> &context,
+                 const std::shared_ptr<core::ProcessSession> &session) override;
+
+  class LabelsReadCallback : public InputStreamCallback {
+   public:
+    explicit LabelsReadCallback(std::shared_ptr<std::vector<std::string>> labels)
+        : labels_(std::move(labels)) {
+    }
+    ~LabelsReadCallback() override = default;
+    int64_t process(std::shared_ptr<io::BaseStream> stream) override;
+
+   private:
+    std::shared_ptr<std::vector<std::string>> labels_;
+  };
+
+  class TensorReadCallback : public InputStreamCallback {
+   public:
+    explicit TensorReadCallback(std::shared_ptr<tensorflow::TensorProto> tensor_proto)
+        : tensor_proto_(std::move(tensor_proto)) {
+    }
+    ~TensorReadCallback() override = default;
+    int64_t process(std::shared_ptr<io::BaseStream> stream) override;
+
+   private:
+    std::shared_ptr<tensorflow::TensorProto> tensor_proto_;
+  };
+
+ private:
+  std::shared_ptr<logging::Logger> logger_;
+
+  std::shared_ptr<std::vector<std::string>> labels_;
+  std::mutex labels_mtx_;
+};
+
+REGISTER_RESOURCE(TFExtractTopLabels);  // NOLINT
+
+} /* namespace processors */
+} /* namespace minifi */
+} /* namespace nifi */
+} /* namespace apache */
+} /* namespace org */
+
+#endif //NIFI_MINIFI_CPP_TFEXTRACTTOPLABELS_H

http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/blob/dec7caef/libminifi/test/tensorflow-tests/TensorFlowTests.cpp
----------------------------------------------------------------------
diff --git a/libminifi/test/tensorflow-tests/TensorFlowTests.cpp b/libminifi/test/tensorflow-tests/TensorFlowTests.cpp
index e9499de..4bca07a 100644
--- a/libminifi/test/tensorflow-tests/TensorFlowTests.cpp
+++ b/libminifi/test/tensorflow-tests/TensorFlowTests.cpp
@@ -25,14 +25,14 @@
 #include <processors/GetFile.h>
 #include <processors/LogAttribute.h>
 #include <TFConvertImageToTensor.h>
+#include <TFExtractTopLabels.h>
 #include "TFApplyGraph.h"
-#include "TFConvertImageToTensor.h"
 
 #define CATCH_CONFIG_MAIN
 
 #include "../TestBase.h"
 
-TEST_CASE("TensorFlow: Apply Graph", "[executescriptTensorFlowApplyGraph]") {  // NOLINT
+TEST_CASE("TensorFlow: Apply Graph", "[tfApplyGraph]") {  // NOLINT
   TestController testController;
 
   LogTestController::getInstance().setTrace<TestPlan>();
@@ -127,7 +127,7 @@ TEST_CASE("TensorFlow: Apply Graph", "[executescriptTensorFlowApplyGraph]") { //
 
   // Read test TensorFlow graph into TFApplyGraph
   plan->runNextProcessor([&get_file, &in_graph_file, &plan](const std::shared_ptr<core::ProcessContext> context,
-                                                                const std::shared_ptr<core::ProcessSession> session) {
+                                                            const std::shared_ptr<core::ProcessSession> session) {
     // Intercept the call so that we can add an attr (won't be required when we have UpdateAttribute processor)
     auto flow_file = session->create();
     session->import(in_graph_file, flow_file, false);
@@ -171,7 +171,7 @@ TEST_CASE("TensorFlow: Apply Graph", "[executescriptTensorFlowApplyGraph]") { //
   }
 }
 
-TEST_CASE("TensorFlow: ConvertImageToTensor", "[executescriptTensorFlowConvertImageToTensor]") {  // NOLINT
+TEST_CASE("TensorFlow: ConvertImageToTensor", "[tfConvertImageToTensor]") {  // NOLINT
   TestController testController;
 
   LogTestController::getInstance().setTrace<TestPlan>();
@@ -266,8 +266,8 @@ TEST_CASE("TensorFlow: ConvertImageToTensor", "[executescriptTensorFlowConvertIm
   // Write test input image
   {
     // 2x2 single-channel 8 bit per channel
-    const uint8_t in_img_raw[2*2] = {0, 0,
-                                     0, 0};
+    const uint8_t in_img_raw[2 * 2] = {0, 0,
+                                       0, 0};
 
     std::ofstream in_file_stream(in_img_file);
     in_file_stream << in_img_raw;
@@ -299,3 +299,108 @@ TEST_CASE("TensorFlow: ConvertImageToTensor", "[executescriptTensorFlowConvertIm
           1}));  // Channels
   }
 }
+
+TEST_CASE("TensorFlow: Extract Top Labels", "[tfExtractTopLabels]") {  // NOLINT
+  TestController testController;
+
+  LogTestController::getInstance().setTrace<TestPlan>();
+  LogTestController::getInstance().setTrace<processors::TFExtractTopLabels>();
+  LogTestController::getInstance().setTrace<processors::GetFile>();
+  LogTestController::getInstance().setTrace<processors::LogAttribute>();
+
+  auto plan = testController.createPlan();
+  auto repo = std::make_shared<TestRepository>();
+
+  // Define directory for input protocol buffers
+  std::string in_dir("/tmp/gt.XXXXXX");
+  REQUIRE(testController.createTempDirectory(&in_dir[0]) != nullptr);
+
+  // Define input labels file
+  std::string in_labels_file(in_dir);
+  in_labels_file.append("/in_labels");
+
+  // Define input tensor protocol buffer file
+  std::string in_tensor_file(in_dir);
+  in_tensor_file.append("/tensor.pb");
+
+  // Build MiNiFi processing graph
+  auto get_file = plan->addProcessor(
+      "GetFile",
+      "Get Input");
+  plan->setProperty(
+      get_file,
+      processors::GetFile::Directory.getName(), in_dir);
+  plan->setProperty(
+      get_file,
+      processors::GetFile::KeepSourceFile.getName(),
+      "false");
+  plan->addProcessor(
+      "LogAttribute",
+      "Log Pre Extract",
+      core::Relationship("success", "description"),
+      true);
+  auto tf_apply = plan->addProcessor(
+      "TFExtractTopLabels",
+      "Extract",
+      core::Relationship("success", "description"),
+      true);
+  plan->addProcessor(
+      "LogAttribute",
+      "Log Post Extract",
+      core::Relationship("success", "description"),
+      true);
+
+  // Build test labels
+  {
+    // Write labels
+    std::ofstream in_file_stream(in_labels_file);
+    in_file_stream << "label_a\nlabel_b\nlabel_c\nlabel_d\nlabel_e\nlabel_f\nlabel_g\nlabel_h\nlabel_i\nlabel_j\n";
+  }
+
+  // Read labels
+  plan->runNextProcessor([&get_file, &in_labels_file, &plan](const std::shared_ptr<core::ProcessContext> context,
+                                                             const std::shared_ptr<core::ProcessSession> session) {
+    // Intercept the call so that we can add an attr (won't be required when we have UpdateAttribute processor)
+    auto flow_file = session->create();
+    session->import(in_labels_file, flow_file, false);
+    flow_file->addAttribute("tf.type", "labels");
+    session->transfer(flow_file, processors::GetFile::Success);
+    session->commit();
+  });
+
+  plan->runNextProcessor();  // Log
+  plan->runNextProcessor();  // Extract (loads labels)
+
+  // Write input tensor
+  {
+    tensorflow::Tensor input(tensorflow::DT_FLOAT, {10});
+    input.flat<float>().data()[0] = 0.000f;
+    input.flat<float>().data()[1] = 0.400f;
+    input.flat<float>().data()[2] = 0.100f;
+    input.flat<float>().data()[3] = 0.005f;
+    input.flat<float>().data()[4] = 1.000f;
+    input.flat<float>().data()[5] = 0.500f;
+    input.flat<float>().data()[6] = 0.200f;
+    input.flat<float>().data()[7] = 0.000f;
+    input.flat<float>().data()[8] = 0.300f;
+    input.flat<float>().data()[9] = 0.000f;
+    tensorflow::TensorProto tensor_proto;
+    input.AsProtoTensorContent(&tensor_proto);
+
+    std::ofstream in_file_stream(in_tensor_file);
+    tensor_proto.SerializeToOstream(&in_file_stream);
+  }
+
+  plan->reset();
+  plan->runNextProcessor();  // GetFile
+  plan->runNextProcessor();  // Log
+  plan->runNextProcessor();  // Extract
+  plan->runNextProcessor();  // Log
+
+  // Verify labels
+  REQUIRE(LogTestController::getInstance().contains("key:tf.top_label_0 value:label_e"));
+  REQUIRE(LogTestController::getInstance().contains("key:tf.top_label_1 value:label_f"));
+  REQUIRE(LogTestController::getInstance().contains("key:tf.top_label_2 value:label_b"));
+  REQUIRE(LogTestController::getInstance().contains("key:tf.top_label_3 value:label_i"));
+  REQUIRE(LogTestController::getInstance().contains("key:tf.top_label_4 value:label_g"));
+}
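The core of `TFExtractTopLabels::onTrigger` can be sketched without TensorFlow or MiNiFi. It ranks the flattened scores, maps the best indices through the loaded label list, and emits `tf.top_label_N` attributes. What follows is a standalone C++17 sketch under stated assumptions, not the processor's code. `parseLabels` and `topLabels` are hypothetical helpers, a `std::map` stands in for FlowFile attributes, and `std::nullopt` stands in for routing to `retry`.

It differs from the committed code in three deliberate ways:

- It reads labels with `std::getline`. As a result, a final label with no trailing newline is kept, and no fixed 65536-byte label buffer is involved.
- It orders only the top k entries with `std::partial_sort`.
- It rejects an index equal to `labels.size()` (`>=` rather than `>`). The committed check lets that index through to `labels->at()`, which throws, so the FlowFile lands in `failure` instead of `retry`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <istream>
#include <map>
#include <optional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: line i of the label file is the label for tensor index i.
// std::getline also yields a final label that has no trailing '\n'.
std::vector<std::string> parseLabels(std::istream &in) {
  std::vector<std::string> labels;
  std::string line;
  while (std::getline(in, line)) {
    labels.push_back(line);
  }
  return labels;
}

// Hypothetical helper: maps "tf.top_label_<rank>" to the label of each of the
// k highest scores. Returns std::nullopt when a selected index has no label,
// which is the case the processor routes to `retry`.
std::optional<std::map<std::string, std::string>> topLabels(const std::vector<float> &scores,
                                                            const std::vector<std::string> &labels,
                                                            std::size_t k = 5) {
  using Scored = std::pair<std::size_t, float>;  // (tensor index, score)
  std::vector<Scored> indexed;
  indexed.reserve(scores.size());
  for (std::size_t i = 0; i < scores.size(); ++i) {
    indexed.emplace_back(i, scores[i]);
  }

  // Only the first n entries need to end up in descending score order.
  const std::size_t n = std::min(k, indexed.size());
  std::partial_sort(indexed.begin(), indexed.begin() + static_cast<std::ptrdiff_t>(n), indexed.end(),
                    [](const Scored &a, const Scored &b) { return a.second > b.second; });

  std::map<std::string, std::string> attributes;
  for (std::size_t rank = 0; rank < n; ++rank) {
    const std::size_t index = indexed[rank].first;
    if (index >= labels.size()) {  // valid indices are 0 .. labels.size() - 1
      return std::nullopt;
    }
    attributes["tf.top_label_" + std::to_string(rank)] = labels[index];
  }
  return attributes;
}
```

Suppose the sketch receives the ten labels `label_a` through `label_j` and the ten scores written by the `Extract Top Labels` test case above. `topLabels` then returns `label_e`, `label_f`, `label_b`, `label_i` and `label_g` for ranks 0 through 4, which are the values that test expects.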
