You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@nifi.apache.org by ph...@apache.org on 2018/01/10 19:17:03 UTC
nifi-minifi-cpp git commit: MINIFICPP-358 Added TFExtractTopLabels
Repository: nifi-minifi-cpp
Updated Branches:
refs/heads/master b8e45cbf9 -> dec7caef7
MINIFICPP-358 Added TFExtractTopLabels
This closes #232.
Signed-off-by: Marc Parisi <ph...@apache.org>
Project: http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/repo
Commit: http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/commit/dec7caef
Tree: http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/tree/dec7caef
Diff: http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/diff/dec7caef
Branch: refs/heads/master
Commit: dec7caef7dd348a1fa80d0f4db5d4fc979f785fa
Parents: b8e45cb
Author: Andy I. Christianson <an...@andyic.org>
Authored: Mon Jan 8 17:13:27 2018 -0500
Committer: Marc Parisi <ph...@apache.org>
Committed: Wed Jan 10 14:16:44 2018 -0500
----------------------------------------------------------------------
PROCESSORS.md | 116 ++++++++++++-
README.md | 23 +--
.../tensorflow/TFConvertImageToTensor.cpp | 2 +-
extensions/tensorflow/TFExtractTopLabels.cpp | 173 +++++++++++++++++++
extensions/tensorflow/TFExtractTopLabels.h | 92 ++++++++++
.../test/tensorflow-tests/TensorFlowTests.cpp | 117 ++++++++++++-
6 files changed, 498 insertions(+), 25 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/blob/dec7caef/PROCESSORS.md
----------------------------------------------------------------------
diff --git a/PROCESSORS.md b/PROCESSORS.md
index 6c4eedf..85b08d9 100644
--- a/PROCESSORS.md
+++ b/PROCESSORS.md
@@ -18,24 +18,29 @@
## Table of Contents
- [AppendHostInfo](#appendhostinfo)
+- [CompressContent](#compresscontent)
+- [ConsumeMQTT](#consumemqtt)
- [ExecuteProcess](#executeprocess)
- [ExecuteScript](#executescript)
+- [ExtractText](#extracttext)
+- [FocusArchiveEntry](#focusarchiveentry)
+- [GenerateFlowFile](#generateflowfile)
- [GetFile](#getfile)
- [GetUSBCamera](#getusbcamera)
-- [GenerateFlowFile](#generateflowfile)
- [InvokeHTTP](#invokehttp)
-- [LogAttribute](#logattribute)
- [ListenHTTP](#listenhttp)
- [ListenSyslog](#listensyslog)
+- [LogAttribute](#logattribute)
+- [ManipulateArchive](#manipulatearchive)
+- [MergeContent](#mergecontent)
+- [PublishKafka](#publishkafka)
+- [PublishMQTT](#publishmqtt)
- [PutFile](#putfile)
- [TailFile](#tailfile)
-- [MergeContent](#mergecontent)
-- [ExtractText](#extracttext)
-- [CompressContent](#compresscontent)
-- [FocusArchiveEntry](#focusarchiveentry)
+- [TFApplyGraph](#tfapplygraph)
+- [TFConvertImageToTensor](#tfconvertimagetotensor)
+- [TFExtractTopLabels](#tfextracttoplabels)
- [UnfocusArchiveEntry](#unfocusarchiveentry)
-- [ManipulateArchive](#manipulatearchive)
-- [PublishKafka](#publishkafka)
## AppendHostInfo
@@ -535,6 +540,101 @@ default values, and whether a property supports the NiFi Expression Language.
| - | - |
| success | All FlowFiles are routed to this Relationship. |
+## TFApplyGraph
+
+### Description
+
+Applies a TensorFlow graph to the tensor protobuf supplied as input. The tensor
+is fed into the node specified by the `Input Node` property. The output
+FlowFile is a tensor protobuf extracted from the node specified by the `Output
+Node` property.
+
+TensorFlow graphs are read dynamically by feeding a graph protobuf to the
+processor with the `tf.type` property set to `graph`.
+
+### Properties
+
+In the list below, the names of required properties appear in bold. Any other
+properties (not in bold) are considered optional. The table also indicates any
+default values, and whether a property supports the NiFi Expression Language.
+
+| Name | Default Value | Allowable Values | Description |
+| - | - | - | - |
+| **Input Node** | | | The node of the TensorFlow graph to feed tensor inputs to |
+| **Output Node** | | | The node of the TensorFlow graph to read tensor outputs from |
+
+### Relationships
+
+| Name | Description |
+| - | - |
+| success | Successful graph application outputs as tensor protobufs |
+| retry | Inputs which fail graph application but may work if sent again |
+| failure | Failures which will not work if retried |
+
+## TFConvertImageToTensor
+
+### Description
+
+Converts the input image file into a tensor protobuf. The image will be resized
+to the given output tensor dimensions.
+
+### Properties
+
+In the list below, the names of required properties appear in bold. Any other
+properties (not in bold) are considered optional. The table also indicates any
+default values, and whether a property supports the NiFi Expression Language.
+
+| Name | Default Value | Allowable Values | Description |
+| - | - | - | - |
+| **Input Format** | | PNG, RAW | The format of the input image (PNG or RAW). RAW is RGB24. |
+| **Input Width** | | | The width, in pixels, of the input image. |
+| **Input Height** | | | The height, in pixels, of the input image. |
+| **Output Width** | | | The width, in pixels, of the output image. |
+| **Output Height** | | | The height, in pixels, of the output image. |
+| **Channels** | 3 | | The number of channels (e.g. 3 for RGB, 4 for RGBA) in the input image. |
+
+### Relationships
+
+| Name | Description |
+| - | - |
+| success | Successfully read tensor protobufs |
+| failure | Inputs which could not be converted to tensor protobufs |
+
+## TFExtractTopLabels
+
+### Description
+
+Extracts the top 5 labels for categorical inference models.
+
+Labels are fed as newline (`\n`)-delimited files where each line is the label
+for the tensor index equal to its zero-based line number. Label files must be
+fed in as FlowFiles with the `tf.type` attribute set to `labels`.
+
+The top 5 labels are written to the following attributes:
+
+- `tf.top_label_0`
+- `tf.top_label_1`
+- `tf.top_label_2`
+- `tf.top_label_3`
+- `tf.top_label_4`
+
+### Properties
+
+In the list below, the names of required properties appear in bold. Any other
+properties (not in bold) are considered optional. The table also indicates any
+default values, and whether a property supports the NiFi Expression Language.
+
+| Name | Default Value | Allowable Values | Description |
+| - | - | - | - |
+
+### Relationships
+
+| Name | Description |
+| - | - |
+| success | Successful FlowFiles are sent here with labels as attributes |
+| retry | Failures which might work if retried |
+| failure | Failures which will not work if retried |
+
## MergeContent
Merges a Group of FlowFiles together based on a user-defined strategy and
http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/blob/dec7caef/README.md
----------------------------------------------------------------------
diff --git a/README.md b/README.md
index 40f7c97..ed0b834 100644
--- a/README.md
+++ b/README.md
@@ -47,26 +47,29 @@ Perspectives of the role of MiNiFi should be from the perspective of the agent a
MiNiFi - C++ supports the following processors:
* [AppendHostInfo](PROCESSORS.md#appendhostinfo)
+* [CompressContent](PROCESSORS.md#compresscontent)
+* [ConsumeMQTT](PROCESSORS.md#consumeMQTT)
* [ExecuteProcess](PROCESSORS.md#executeprocess)
* [ExecuteScript](PROCESSORS.md#executescript)
+* [ExtractText](PROCESSORS.md#extracttext)
+* [FocusArchiveEntry](PROCESSORS.md#focusarchiveentry)
+* [GenerateFlowFile](PROCESSORS.md#generateflowfile)
* [GetFile](PROCESSORS.md#getfile)
* [GetUSBCamera](PROCESSORS.md#getusbcamera)
-* [GenerateFlowFile](PROCESSORS.md#generateflowfile)
* [InvokeHTTP](PROCESSORS.md#invokehttp)
-* [LogAttribute](PROCESSORS.md#logattribute)
* [ListenHTTP](PROCESSORS.md#listenhttp)
* [ListenSyslog](PROCESSORS.md#listensyslog)
-* [PutFile](PROCESSORS.md#putfile)
-* [TailFile](PROCESSORS.md#tailfile)
-* [MergeContent](PROCESSORS.md#mergecontent)
-* [ExtractText](PROCESSORS.md#extracttext)
-* [CompressContent](PROCESSORS.md#compresscontent)
-* [FocusArchiveEntry](PROCESSORS.md#focusarchiveentry)
-* [UnfocusArchiveEntry](PROCESSORS.md#unfocusarchiveentry)
+* [LogAttribute](PROCESSORS.md#logattribute)
* [ManipulateArchive](PROCESSORS.md#manipulatearchive)
+* [MergeContent](PROCESSORS.md#mergecontent)
* [PublishKafka](PROCESSORS.md#publishkafka)
* [PublishMQTT](PROCESSORS.md#publishMQTT)
-* [ConsumeMQTT](PROCESSORS.md#consumeMQTT)
+* [PutFile](PROCESSORS.md#putfile)
+* [TailFile](PROCESSORS.md#tailfile)
+* [TFApplyGraph](PROCESSORS.md#tfapplygraph)
+* [TFConvertImageToTensor](PROCESSORS.md#tfconvertimagetotensor)
+* [TFExtractTopLabels](PROCESSORS.md#tfextracttoplabels)
+* [UnfocusArchiveEntry](PROCESSORS.md#unfocusarchiveentry)
## Caveats
* 0.4.0 represents a non-GA release, APIs and interfaces are subject to change
http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/blob/dec7caef/extensions/tensorflow/TFConvertImageToTensor.cpp
----------------------------------------------------------------------
diff --git a/extensions/tensorflow/TFConvertImageToTensor.cpp b/extensions/tensorflow/TFConvertImageToTensor.cpp
index be5e7a1..803ea48 100644
--- a/extensions/tensorflow/TFConvertImageToTensor.cpp
+++ b/extensions/tensorflow/TFConvertImageToTensor.cpp
@@ -27,7 +27,7 @@ namespace processors {
core::Property TFConvertImageToTensor::ImageFormat( // NOLINT
"Input Format",
- "The node of the TensorFlow graph to feed tensor inputs to (PNG or RAW). RAW is RGB24.", "");
+ "The format of the input image (PNG or RAW). RAW is RGB24.", "");
core::Property TFConvertImageToTensor::InputWidth( // NOLINT
"Input Width",
"The width, in pixels, of the input image.", "");
http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/blob/dec7caef/extensions/tensorflow/TFExtractTopLabels.cpp
----------------------------------------------------------------------
diff --git a/extensions/tensorflow/TFExtractTopLabels.cpp b/extensions/tensorflow/TFExtractTopLabels.cpp
new file mode 100644
index 0000000..723f7dc
--- /dev/null
+++ b/extensions/tensorflow/TFExtractTopLabels.cpp
@@ -0,0 +1,173 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFExtractTopLabels.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace org {
+namespace apache {
+namespace nifi {
+namespace minifi {
+namespace processors {
+
+// Relationships offered by TFExtractTopLabels; all three are registered in
+// initialize(). Each is constructed from a name and a description.
+
+// Target for FlowFiles whose top labels were written as attributes.
+core::Relationship TFExtractTopLabels::Success( // NOLINT
+ "success",
+ "Successful FlowFiles are sent here with labels as attributes");
+// Target used by onTrigger() when no usable label exists for a top score.
+core::Relationship TFExtractTopLabels::Retry( // NOLINT
+ "retry",
+ "Failures which might work if retried");
+// Target used by onTrigger() when an exception is caught during processing.
+core::Relationship TFExtractTopLabels::Failure( // NOLINT
+ "failure",
+ "Failures which will not work if retried");
+
+/**
+ * Registers what this processor supports: an empty property set and the
+ * Success, Retry and Failure relationships.
+ */
+void TFExtractTopLabels::initialize() {
+  // No configurable properties are registered for this processor.
+  setSupportedProperties(std::set<core::Property>{});
+
+  setSupportedRelationships(std::set<core::Relationship>{Success, Retry, Failure});
+}
+
+// No-op: initialize() registers an empty property set, so there is no
+// configuration to read from the context here.
+void TFExtractTopLabels::onSchedule(core::ProcessContext *context, core::ProcessSessionFactory *sessionFactory) {
+}
+
+/**
+ * Handles one FlowFile per invocation.
+ *
+ * A FlowFile whose `tf.type` attribute equals `labels` is parsed as a
+ * newline-delimited label list, replaces the cached labels and is then removed
+ * from the session. Any other FlowFile is read as a serialized
+ * tensorflow::TensorProto of floats; the labels of its (up to) five highest
+ * scores are written to the `tf.top_label_<rank>` attributes (rank 0 is the
+ * highest score) and the FlowFile is routed to Success.
+ *
+ * Routing on error:
+ *  - Retry:   no labels are cached, or a top score has no matching label.
+ *  - Failure: the tensor cannot be decoded, is not DT_FLOAT, or an exception
+ *             is thrown while processing (the processor also yields).
+ *
+ * @param context processing context (unused)
+ * @param session session used to fetch, read, update and route the FlowFile
+ */
+void TFExtractTopLabels::onTrigger(const std::shared_ptr<core::ProcessContext> &context,
+                                   const std::shared_ptr<core::ProcessSession> &session) {
+  auto flow_file = session->get();
+
+  if (!flow_file) {
+    return;
+  }
+
+  try {
+    // Read labels
+    std::string tf_type;
+    flow_file->getAttribute("tf.type", tf_type);
+    std::shared_ptr<std::vector<std::string>> labels;
+
+    {
+      std::lock_guard<std::mutex> guard(labels_mtx_);
+
+      if (tf_type == "labels") {
+        logger_->log_info("Reading new labels...");
+        auto new_labels = std::make_shared<std::vector<std::string>>();
+        LabelsReadCallback cb(new_labels);
+        session->read(flow_file, &cb);
+        labels_ = new_labels;
+        // size() yields size_t; cast so the argument matches the %d conversion.
+        logger_->log_info("Read %d new labels", static_cast<int>(labels_->size()));
+        session->remove(flow_file);
+        return;
+      }
+
+      // Take a reference to the current labels while holding the lock; the
+      // remaining work uses this local copy of the pointer.
+      labels = labels_;
+    }
+
+    // Read input tensor from flow file
+    auto input_tensor_proto = std::make_shared<tensorflow::TensorProto>();
+    TensorReadCallback tensor_cb(input_tensor_proto);
+    session->read(flow_file, &tensor_cb);
+
+    tensorflow::Tensor input;
+
+    // FromProto() reports an unusable proto through its return value, and
+    // flat<float>() must only be called on a DT_FLOAT tensor, so check both
+    // before touching the data.
+    if (!input.FromProto(*input_tensor_proto) || input.dtype() != tensorflow::DT_FLOAT) {
+      logger_->log_error("Input is not a valid float tensor; routing to failure...");
+      session->transfer(flow_file, Failure);
+      return;
+    }
+
+    auto input_flat = input.flat<float>();
+
+    // Pair every score with its tensor index so the index survives sorting.
+    std::vector<std::pair<uint64_t, float>> scores;
+    scores.reserve(static_cast<size_t>(input_flat.size()));
+
+    for (decltype(input_flat.size()) i = 0; i < input_flat.size(); i++) {
+      scores.emplace_back(static_cast<uint64_t>(i), input_flat(i));
+    }
+
+    // Only the five best entries are needed, so order just that prefix
+    // (descending by score).
+    const size_t top_count = std::min<size_t>(5, scores.size());
+    std::partial_sort(scores.begin(), scores.begin() + top_count, scores.end(),
+                      [](const std::pair<uint64_t, float> &a,
+                         const std::pair<uint64_t, float> &b) {
+                        return a.second > b.second;
+                      });
+
+    // Validate every index before adding any attribute, so a FlowFile routed
+    // to Retry never carries a partial set of label attributes.
+    for (size_t i = 0; i < top_count; i++) {
+      // Valid label indices are [0, size()), hence >= rather than >.
+      if (!labels || scores[i].first >= labels->size()) {
+        logger_->log_error("Label index is out of range (are the correct labels loaded?); routing to retry...");
+        session->transfer(flow_file, Retry);
+        return;
+      }
+    }
+
+    for (size_t i = 0; i < top_count; i++) {
+      flow_file->addAttribute("tf.top_label_" + std::to_string(i), labels->at(scores[i].first));
+    }
+
+    session->transfer(flow_file, Success);
+
+  } catch (const std::exception &exception) {
+    logger_->log_error("Caught Exception %s", exception.what());
+    session->transfer(flow_file, Failure);
+    this->yield();
+  } catch (...) {
+    logger_->log_error("Caught Exception");
+    session->transfer(flow_file, Failure);
+    this->yield();
+  }
+}
+
+/**
+ * Splits the stream content on '\n' and appends each line to labels_, so that
+ * the label for tensor index N is the N-th (zero-based) line of the input.
+ *
+ * Empty lines are kept as empty labels so later lines keep their index. A
+ * final line that is not terminated by '\n' is appended as well; a trailing
+ * '\n' does not produce an extra empty label. Labels may be of any length.
+ *
+ * @param stream input stream of the FlowFile holding the labels
+ * @return number of bytes consumed from the stream
+ * @throws std::runtime_error if the stream yields no data before the size
+ *         reported by getSize() has been consumed
+ */
+int64_t TFExtractTopLabels::LabelsReadCallback::process(std::shared_ptr<io::BaseStream> stream) {
+  const auto stream_size = static_cast<int64_t>(stream->getSize());
+  int64_t total_read = 0;
+  std::string label;
+  std::string buf;
+  const uint64_t buf_size = 8096;
+  buf.resize(buf_size);
+
+  while (total_read < stream_size) {
+    auto read = stream->read(reinterpret_cast<uint8_t *>(&buf[0]), static_cast<int>(buf_size));
+
+    // A non-positive count would otherwise spin forever or move total_read
+    // backwards; fail instead so the caller does not install truncated labels.
+    if (read <= 0) {
+      throw std::runtime_error("LabelsReadCallback failed to fully read flow file input stream");
+    }
+
+    for (decltype(read) i = 0; i < read; i++) {
+      if (buf[i] == '\n') {
+        labels_->emplace_back(label);
+        label.clear();
+      } else {
+        label.push_back(buf[i]);
+      }
+    }
+
+    total_read += read;
+  }
+
+  // Flush a last label that was not terminated by a newline.
+  if (!label.empty()) {
+    labels_->emplace_back(label);
+  }
+
+  return total_read;
+}
+
+/**
+ * Reads the whole stream into memory and parses it into tensor_proto_.
+ *
+ * @param stream input stream of the FlowFile holding a serialized TensorProto
+ * @return number of bytes read from the stream
+ * @throws std::runtime_error if fewer bytes than getSize() could be read, or
+ *         if the bytes cannot be parsed as a TensorProto
+ */
+int64_t TFExtractTopLabels::TensorReadCallback::process(std::shared_ptr<io::BaseStream> stream) {
+  const auto stream_size = stream->getSize();
+  std::string tensor_proto_buf;
+  tensor_proto_buf.resize(stream_size);
+  const auto num_read = stream->readData(reinterpret_cast<uint8_t *>(&tensor_proto_buf[0]),
+                                         static_cast<int>(stream_size));
+
+  // Reject error codes explicitly before comparing against the unsigned size.
+  if (num_read < 0 || static_cast<uint64_t>(num_read) != stream_size) {
+    throw std::runtime_error("TensorReadCallback failed to fully read flow file input stream");
+  }
+
+  // ParseFromString() signals malformed input only through its return value.
+  if (!tensor_proto_->ParseFromString(tensor_proto_buf)) {
+    throw std::runtime_error("TensorReadCallback failed to parse tensor protocol buffer");
+  }
+
+  return num_read;
+}
+
+} /* namespace processors */
+} /* namespace minifi */
+} /* namespace nifi */
+} /* namespace apache */
+} /* namespace org */
http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/blob/dec7caef/extensions/tensorflow/TFExtractTopLabels.h
----------------------------------------------------------------------
diff --git a/extensions/tensorflow/TFExtractTopLabels.h b/extensions/tensorflow/TFExtractTopLabels.h
new file mode 100644
index 0000000..58ed57f
--- /dev/null
+++ b/extensions/tensorflow/TFExtractTopLabels.h
@@ -0,0 +1,92 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NIFI_MINIFI_CPP_TFEXTRACTTOPLABELS_H
+#define NIFI_MINIFI_CPP_TFEXTRACTTOPLABELS_H
+
+#include <atomic>
+
+#include <core/Resource.h>
+#include <core/Processor.h>
+#include <tensorflow/core/public/session.h>
+#include <concurrentqueue.h>
+
+namespace org {
+namespace apache {
+namespace nifi {
+namespace minifi {
+namespace processors {
+
+/**
+ * Processor that annotates tensor FlowFiles with the labels of their highest
+ * scores.
+ *
+ * FlowFiles whose `tf.type` attribute is `labels` supply a newline-delimited
+ * label list that is cached in labels_. Other FlowFiles are read as serialized
+ * tensorflow::TensorProto messages; the labels of the top scores are written
+ * to `tf.top_label_<rank>` attributes.
+ */
+class TFExtractTopLabels : public core::Processor {
+ public:
+ // Creates the processor and obtains its class-specific logger.
+ explicit TFExtractTopLabels(const std::string &name, uuid_t uuid = nullptr)
+ : Processor(name, uuid),
+ logger_(logging::LoggerFactory<TFExtractTopLabels>::getLogger()) {
+ }
+
+ // Relationships this processor can route FlowFiles to.
+ static core::Relationship Success;
+ static core::Relationship Retry;
+ static core::Relationship Failure;
+
+ void initialize() override;
+ void onSchedule(core::ProcessContext *context, core::ProcessSessionFactory *sessionFactory) override;
+ // Raw-pointer overload is not supported; it only logs an error.
+ void onTrigger(core::ProcessContext *context, core::ProcessSession *session) override {
+ logger_->log_error("onTrigger invocation with raw pointers is not implemented");
+ }
+ // Processes a single FlowFile (either a label list or a tensor).
+ void onTrigger(const std::shared_ptr<core::ProcessContext> &context,
+ const std::shared_ptr<core::ProcessSession> &session) override;
+
+ // Stream callback that appends one entry per input line to the supplied
+ // label vector.
+ class LabelsReadCallback : public InputStreamCallback {
+ public:
+ explicit LabelsReadCallback(std::shared_ptr<std::vector<std::string>> labels)
+ : labels_(std::move(labels)) {
+ }
+ ~LabelsReadCallback() override = default;
+ int64_t process(std::shared_ptr<io::BaseStream> stream) override;
+
+ private:
+ // Destination for the parsed labels; shared with the creator of the callback.
+ std::shared_ptr<std::vector<std::string>> labels_;
+ };
+
+ // Stream callback that parses the stream content into the supplied
+ // TensorProto.
+ class TensorReadCallback : public InputStreamCallback {
+ public:
+ explicit TensorReadCallback(std::shared_ptr<tensorflow::TensorProto> tensor_proto)
+ : tensor_proto_(std::move(tensor_proto)) {
+ }
+ ~TensorReadCallback() override = default;
+ int64_t process(std::shared_ptr<io::BaseStream> stream) override;
+
+ private:
+ // Destination for the parsed tensor; shared with the creator of the callback.
+ std::shared_ptr<tensorflow::TensorProto> tensor_proto_;
+ };
+
+ private:
+ std::shared_ptr<logging::Logger> logger_;
+
+ // Most recently loaded label list; null until a labels FlowFile is received.
+ std::shared_ptr<std::vector<std::string>> labels_;
+ // Guards reads and replacement of labels_ in onTrigger().
+ std::mutex labels_mtx_;
+};
+
+REGISTER_RESOURCE(TFExtractTopLabels); // NOLINT
+
+} /* namespace processors */
+} /* namespace minifi */
+} /* namespace nifi */
+} /* namespace apache */
+} /* namespace org */
+
+#endif //NIFI_MINIFI_CPP_TFEXTRACTTOPLABELS_H
http://git-wip-us.apache.org/repos/asf/nifi-minifi-cpp/blob/dec7caef/libminifi/test/tensorflow-tests/TensorFlowTests.cpp
----------------------------------------------------------------------
diff --git a/libminifi/test/tensorflow-tests/TensorFlowTests.cpp b/libminifi/test/tensorflow-tests/TensorFlowTests.cpp
index e9499de..4bca07a 100644
--- a/libminifi/test/tensorflow-tests/TensorFlowTests.cpp
+++ b/libminifi/test/tensorflow-tests/TensorFlowTests.cpp
@@ -25,14 +25,14 @@
#include <processors/GetFile.h>
#include <processors/LogAttribute.h>
#include <TFConvertImageToTensor.h>
+#include <TFExtractTopLabels.h>
#include "TFApplyGraph.h"
-#include "TFConvertImageToTensor.h"
#define CATCH_CONFIG_MAIN
#include "../TestBase.h"
-TEST_CASE("TensorFlow: Apply Graph", "[executescriptTensorFlowApplyGraph]") { // NOLINT
+TEST_CASE("TensorFlow: Apply Graph", "[tfApplyGraph]") { // NOLINT
TestController testController;
LogTestController::getInstance().setTrace<TestPlan>();
@@ -127,7 +127,7 @@ TEST_CASE("TensorFlow: Apply Graph", "[executescriptTensorFlowApplyGraph]") { //
// Read test TensorFlow graph into TFApplyGraph
plan->runNextProcessor([&get_file, &in_graph_file, &plan](const std::shared_ptr<core::ProcessContext> context,
- const std::shared_ptr<core::ProcessSession> session) {
+ const std::shared_ptr<core::ProcessSession> session) {
// Intercept the call so that we can add an attr (won't be required when we have UpdateAttribute processor)
auto flow_file = session->create();
session->import(in_graph_file, flow_file, false);
@@ -171,7 +171,7 @@ TEST_CASE("TensorFlow: Apply Graph", "[executescriptTensorFlowApplyGraph]") { //
}
}
-TEST_CASE("TensorFlow: ConvertImageToTensor", "[executescriptTensorFlowConvertImageToTensor]") { // NOLINT
+TEST_CASE("TensorFlow: ConvertImageToTensor", "[tfConvertImageToTensor]") { // NOLINT
TestController testController;
LogTestController::getInstance().setTrace<TestPlan>();
@@ -266,8 +266,8 @@ TEST_CASE("TensorFlow: ConvertImageToTensor", "[executescriptTensorFlowConvertIm
// Write test input image
{
// 2x2 single-channel 8 bit per channel
- const uint8_t in_img_raw[2*2] = {0, 0,
- 0, 0};
+ const uint8_t in_img_raw[2 * 2] = {0, 0,
+ 0, 0};
std::ofstream in_file_stream(in_img_file);
in_file_stream << in_img_raw;
@@ -299,3 +299,108 @@ TEST_CASE("TensorFlow: ConvertImageToTensor", "[executescriptTensorFlowConvertIm
1})); // Channels
}
}
+
+// Feeds a ten-entry label file and a ten-element score tensor through
+// GetFile -> LogAttribute -> TFExtractTopLabels -> LogAttribute and checks
+// that the five highest-scoring labels are logged as tf.top_label_0..4.
+TEST_CASE("TensorFlow: Extract Top Labels", "[tfExtractTopLabels]") {  // NOLINT
+  TestController testController;
+
+  auto &log_controller = LogTestController::getInstance();
+  log_controller.setTrace<TestPlan>();
+  log_controller.setTrace<processors::TFExtractTopLabels>();
+  log_controller.setTrace<processors::GetFile>();
+  log_controller.setTrace<processors::LogAttribute>();
+
+  auto plan = testController.createPlan();
+  auto repo = std::make_shared<TestRepository>();
+
+  // Temporary directory that GetFile picks the inputs up from
+  std::string in_dir("/tmp/gt.XXXXXX");
+  REQUIRE(testController.createTempDirectory(&in_dir[0]) != nullptr);
+
+  // Paths of the label list and of the serialized tensor
+  const std::string in_labels_file = in_dir + "/in_labels";
+  const std::string in_tensor_file = in_dir + "/tensor.pb";
+
+  // Assemble the flow: GetFile -> Log -> Extract -> Log
+  const core::Relationship success_relationship("success", "description");
+
+  auto get_file = plan->addProcessor("GetFile", "Get Input");
+  plan->setProperty(get_file, processors::GetFile::Directory.getName(), in_dir);
+  plan->setProperty(get_file, processors::GetFile::KeepSourceFile.getName(), "false");
+  plan->addProcessor("LogAttribute", "Log Pre Extract", success_relationship, true);
+  auto extract_top_labels = plan->addProcessor("TFExtractTopLabels", "Extract", success_relationship, true);
+  plan->addProcessor("LogAttribute", "Log Post Extract", success_relationship, true);
+
+  // Write the labels file; the stream is closed at the end of the scope
+  {
+    std::ofstream labels_stream(in_labels_file);
+    labels_stream << "label_a\nlabel_b\nlabel_c\nlabel_d\nlabel_e\nlabel_f\nlabel_g\nlabel_h\nlabel_i\nlabel_j\n";
+  }
+
+  // Stand in for GetFile so the FlowFile can be tagged with tf.type=labels
+  // (won't be required when we have UpdateAttribute processor)
+  plan->runNextProcessor([&in_labels_file](const std::shared_ptr<core::ProcessContext> context,
+                                           const std::shared_ptr<core::ProcessSession> session) {
+    auto labels_flow_file = session->create();
+    session->import(in_labels_file, labels_flow_file, false);
+    labels_flow_file->addAttribute("tf.type", "labels");
+    session->transfer(labels_flow_file, processors::GetFile::Success);
+    session->commit();
+  });
+
+  plan->runNextProcessor();  // Log
+  plan->runNextProcessor();  // Extract (loads labels)
+
+  // Serialize the score tensor; index i corresponds to the i-th label
+  {
+    const float score_values[10] = {0.000f, 0.400f, 0.100f, 0.005f, 1.000f,
+                                    0.500f, 0.200f, 0.000f, 0.300f, 0.000f};
+    tensorflow::Tensor input(tensorflow::DT_FLOAT, {10});
+    float *input_data = input.flat<float>().data();
+
+    for (int i = 0; i < 10; i++) {
+      input_data[i] = score_values[i];
+    }
+
+    tensorflow::TensorProto tensor_proto;
+    input.AsProtoTensorContent(&tensor_proto);
+
+    std::ofstream tensor_stream(in_tensor_file);
+    tensor_proto.SerializeToOstream(&tensor_stream);
+  }
+
+  // Second pass through the flow, this time with the tensor as input
+  plan->reset();
+  plan->runNextProcessor();  // GetFile
+  plan->runNextProcessor();  // Log
+  plan->runNextProcessor();  // Extract
+  plan->runNextProcessor();  // Log
+
+  // Expected ranking by score: e (1.0), f (0.5), b (0.4), i (0.3), g (0.2)
+  REQUIRE(log_controller.contains("key:tf.top_label_0 value:label_e"));
+  REQUIRE(log_controller.contains("key:tf.top_label_1 value:label_f"));
+  REQUIRE(log_controller.contains("key:tf.top_label_2 value:label_b"));
+  REQUIRE(log_controller.contains("key:tf.top_label_3 value:label_i"));
+  REQUIRE(log_controller.contains("key:tf.top_label_4 value:label_g"));
+}