Github user jzonthemtn commented on a diff in the pull request:
https://github.com/apache/nifi/pull/2686#discussion_r197102298
--- Diff:
nifi-nar-bundles/nifi-deeplearning4j-bundle/nifi-deeplearning4j-processors/src/main/java/org/apache/nifi/processors/deeplearning4j/DeepLearning4JMultiLayerPredictor.java
---
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.processors.deeplearning4j;
+import org.apache.nifi.annotation.behavior.EventDriven;
+import org.apache.nifi.annotation.behavior.InputRequirement;
+import org.apache.nifi.annotation.behavior.InputRequirement.Requirement;
+import org.apache.nifi.annotation.behavior.SupportsBatching;
+import org.apache.nifi.annotation.behavior.WritesAttribute;
+import org.apache.nifi.annotation.behavior.WritesAttributes;
+import org.apache.nifi.annotation.documentation.CapabilityDescription;
+import org.apache.nifi.annotation.documentation.Tags;
+import org.apache.nifi.annotation.lifecycle.OnStopped;
+import org.apache.nifi.components.PropertyDescriptor;
+import org.apache.nifi.flowfile.FlowFile;
+import org.apache.nifi.processor.ProcessContext;
+import org.apache.nifi.processor.ProcessSession;
+import org.apache.nifi.processor.Relationship;
+import org.apache.nifi.processor.exception.ProcessException;
+import org.apache.nifi.stream.io.StreamUtils;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.util.ModelSerializer;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import com.google.gson.Gson;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+@EventDriven
+@SupportsBatching
+@InputRequirement(Requirement.INPUT_REQUIRED)
+@Tags({"deeplearning4j", "dl4j", "multilayer", "predict",
"classification", "regression", "deep", "learning", "neural", "network"})
+@CapabilityDescription("The DeepLearning4JMultiLayerPredictor predicts one
or more value(s) based on provided deeplearning4j
(https://github.com/deeplearning4j) model and the content of a FlowFile. "
+ + "The processor supports both classification and regression by
extracting the record from the FlowFile body and applying the model. "
+ + "The processor supports batch by allowing multiple records to be
passed in the FlowFile body with each record separated by the 'Record
Separator' property. "
+ + "Each record can contain multiple fields with each field separated
by the 'Field Separator' property."
+ )
+@WritesAttributes({
+ @WritesAttribute(attribute =
AbstractDeepLearning4JProcessor.DEEPLEARNING4J_ERROR_MESSAGE, description =
"Deeplearning4J error message"),
+ @WritesAttribute(attribute =
AbstractDeepLearning4JProcessor.DEEPLEARNING4J_OUTPUT_SHAPE, description =
"Deeplearning4J output shape"),
+ })
+public class DeepLearning4JMultiLayerPredictor extends
AbstractDeepLearning4JProcessor {
+
+    static final Relationship REL_SUCCESS = new Relationship.Builder()
+            .name("success")
+            .description("Successful DeepLearning4j results are routed to this relationship")
+            .build();
+
+    static final Relationship REL_FAILURE = new Relationship.Builder()
+            .name("failure")
+            .description("Failed DeepLearning4j results are routed to this relationship")
+            .build();
+
+    // Serializes prediction output; Gson instances are thread-safe and reusable.
+    protected final Gson gson = new Gson();
+
+    // Lazily loaded DL4J network, cached across invocations and cleared on stop.
+    protected MultiLayerNetwork model = null;
+
+    /**
+     * Drops the cached model reference when the processor is stopped so the
+     * network can be garbage collected; it is re-loaded lazily on next use.
+     */
+    @OnStopped
+    public void close() {
+        getLogger().info("Closing");
+        model = null;
+    }
+
+    // Relationships and property descriptors are fixed, so initialize the
+    // immutable views directly instead of populating them in a static block.
+    private static final Set<Relationship> relationships = Collections.unmodifiableSet(
+            new HashSet<>(Arrays.asList(REL_SUCCESS, REL_FAILURE)));
+
+    private static final List<PropertyDescriptor> propertyDescriptors = Collections.unmodifiableList(
+            Arrays.asList(MODEL_FILE, RECORD_DIMENSIONS, CHARSET, FIELD_SEPARATOR, RECORD_SEPARATOR));
+
+    /**
+     * @return the immutable set of relationships (success/failure) supported by this processor
+     */
+    @Override
+    public Set<Relationship> getRelationships() {
+        return relationships;
+    }
+
+    /**
+     * @return the immutable list of supported property descriptors, in display order
+     */
+    @Override
+    public final List<PropertyDescriptor>
getSupportedPropertyDescriptors() {
+        return propertyDescriptors;
+    }
+
+    /**
+     * Returns the MultiLayerNetwork, loading it from the configured 'Model File'
+     * property on first use.  The method is synchronized so concurrent onTrigger
+     * threads share a single cached model instance.
+     *
+     * @param context process context used to resolve the model file path
+     * @return the loaded (and subsequently cached) model
+     * @throws IOException if the model file cannot be read or deserialized
+     */
+    protected synchronized MultiLayerNetwork getModel(ProcessContext context) throws IOException {
+        if ( model == null ) {
+            String modelFile = context.getProperty(MODEL_FILE).evaluateAttributeExpressions().getValue();
+            getLogger().debug("Loading model from {}", new Object[] {modelFile});
+
+            long start = System.currentTimeMillis();
+            // 'false' skips restoring the updater state, which is only needed
+            // for further training, not for inference.
+            model = ModelSerializer.restoreMultiLayerNetwork(modelFile, false);
+            long end = System.currentTimeMillis();
+
+            // Parameterized logging, consistent with the debug statement above.
+            getLogger().info("Time to load model {} ms", new Object[] {end - start});
+        }
+        // The field is already typed as MultiLayerNetwork; no cast needed.
+        return model;
+    }
+
+ @Override
+ public void onTrigger(final ProcessContext context, final
ProcessSession session) throws ProcessException {
+ FlowFile flowFile = session.get();
+ if ( flowFile == null ) {
+ return;
+ }
+
+ Charset charset =
Charset.forName(context.getProperty(CHARSET).evaluateAttributeExpressions(flowFile).getValue());
+ if ( flowFile.getSize() == 0 ) {
+ String message = "FlowFile query is empty";
+ getLogger().error(message);
+ flowFile = session.putAttribute(flowFile,
DEEPLEARNING4J_ERROR_MESSAGE, message);
+ session.transfer(flowFile, REL_FAILURE);
+ return;
+ }
+
+ String input = null;
+ try {
+ input = getFlowFileContents(session, charset, flowFile);
+ String fieldSeparator =
context.getProperty(FIELD_SEPARATOR).evaluateAttributeExpressions(flowFile).getValue();
+ String recordSeparator =
context.getProperty(RECORD_SEPARATOR).evaluateAttributeExpressions(flowFile).getValue();
+
+ int [] dimensions = getInputDimensions(context, charset,
flowFile, fieldSeparator);
+
+ if ( getLogger().isDebugEnabled() ) {
+ getLogger().debug("Received input {} with dimensions {}",
new Object[] { input, dimensions });
+ }
+
+ MultiLayerNetwork model = getModel(context);
--- End diff --
@mans2singh Makes sense. I'm honestly not sure if `@OnScheduled` would be
any better. Maybe someone else will comment. Thanks!
---