[https://issues.apache.org/jira/browse/NIFI-5166?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16518809#comment-16518809]

ASF GitHub Bot commented on NIFI-5166:
--------------------------------------

Github user mans2singh commented on a diff in the pull request:

    https://github.com/apache/nifi/pull/2686#discussion_r196994774
  
    --- Diff: 
nifi-nar-bundles/nifi-deeplearning4j-bundle/nifi-deeplearning4j-processors/src/main/java/org/apache/nifi/processors/deeplearning4j/DeepLearning4JMultiLayerPredictor.java
 ---
    @@ -0,0 +1,240 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.nifi.processors.deeplearning4j;
    +import org.apache.nifi.annotation.behavior.EventDriven;
    +import org.apache.nifi.annotation.behavior.InputRequirement;
    +import org.apache.nifi.annotation.behavior.InputRequirement.Requirement;
    +import org.apache.nifi.annotation.behavior.SupportsBatching;
    +import org.apache.nifi.annotation.behavior.WritesAttribute;
    +import org.apache.nifi.annotation.behavior.WritesAttributes;
    +import org.apache.nifi.annotation.documentation.CapabilityDescription;
    +import org.apache.nifi.annotation.documentation.Tags;
    +import org.apache.nifi.annotation.lifecycle.OnStopped;
    +import org.apache.nifi.components.PropertyDescriptor;
    +import org.apache.nifi.flowfile.FlowFile;
    +import org.apache.nifi.processor.ProcessContext;
    +import org.apache.nifi.processor.ProcessSession;
    +import org.apache.nifi.processor.Relationship;
    +import org.apache.nifi.processor.exception.ProcessException;
    +import org.apache.nifi.stream.io.StreamUtils;
    +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    +import org.deeplearning4j.util.ModelSerializer;
    +import org.nd4j.linalg.api.ndarray.INDArray;
    +import org.nd4j.linalg.factory.Nd4j;
    +import com.google.gson.Gson;
    +import java.io.IOException;
    +import java.io.InputStream;
    +import java.nio.charset.Charset;
    +import java.util.ArrayList;
    +import java.util.Arrays;
    +import java.util.Collections;
    +import java.util.HashMap;
    +import java.util.HashSet;
    +import java.util.List;
    +import java.util.Map;
    +import java.util.Set;
    +import java.util.stream.Collectors;
    +
    +@EventDriven
    +@SupportsBatching
    +@InputRequirement(Requirement.INPUT_REQUIRED)
    +@Tags({"deeplearning4j", "dl4j", "multilayer", "predict", 
"classification", "regression", "deep", "learning", "neural", "network"})
    +@CapabilityDescription("The DeepLearning4JMultiLayerPredictor predicts one 
or more value(s) based on provided deeplearning4j 
(https://github.com/deeplearning4j) model and the content of a FlowFile. "
    +    + "The processor supports both classification and regression by 
extracting the record from the FlowFile body and applying the model. "
    +    + "The processor supports batch by allowing multiple records to be 
passed in the FlowFile body with each record separated by the 'Record 
Separator' property. "
    +    + "Each record can contain multiple fields with each field separated 
by the 'Field Separator' property."
    +    )
    +@WritesAttributes({
    +    @WritesAttribute(attribute = 
AbstractDeepLearning4JProcessor.DEEPLEARNING4J_ERROR_MESSAGE, description = 
"Deeplearning4J error message"),
    +    @WritesAttribute(attribute = 
AbstractDeepLearning4JProcessor.DEEPLEARNING4J_OUTPUT_SHAPE, description = 
"Deeplearning4J output shape"),
    +    })
    +public class DeepLearning4JMultiLayerPredictor extends 
AbstractDeepLearning4JProcessor {
    +
    +    /** Route for FlowFiles whose DeepLearning4j prediction succeeded. */
    +    static final Relationship REL_SUCCESS = new Relationship.Builder()
    +            .name("success")
    +            .description("Successful DeepLearning4j results are routed to this relationship")
    +            .build();
    +
    +    /** Route for FlowFiles that could not be processed (for example, empty content). */
    +    static final Relationship REL_FAILURE = new Relationship.Builder()
    +            .name("failure")
    +            .description("Failed DeepLearning4j results are routed to this relationship")
    +            .build();
    +
    +    /** Shared Gson instance; presumably used to serialize results later in onTrigger (not visible in this hunk). */
    +    protected final Gson gson = new Gson();
    +
    +    /** Lazily restored network: null until getModel() loads it, and reset to null by close(). */
    +    protected MultiLayerNetwork model;
    +
    +    /**
    +     * Drops the cached network when the processor is stopped, so the next
    +     * {@link #getModel(ProcessContext)} call restores it again from the MODEL_FILE property.
    +     * Synchronized on the same monitor as getModel, which guards the {@code model} field,
    +     * so the cleared value is guaranteed to be visible to threads that call getModel later.
    +     */
    +    @OnStopped
    +    public synchronized void close() {
    +        getLogger().info("Closing");
    +        model = null;
    +    }
    +
    +    private static final Set<Relationship> relationships;
    +    private static final List<PropertyDescriptor> propertyDescriptors;
    +
    +    // The relationships and supported properties never change for this processor,
    +    // so both collections are built once and exposed read-only.
    +    static {
    +        relationships = Collections.unmodifiableSet(
    +                new HashSet<>(Arrays.asList(REL_SUCCESS, REL_FAILURE)));
    +        propertyDescriptors = Collections.unmodifiableList(new ArrayList<>(Arrays.asList(
    +                MODEL_FILE,
    +                RECORD_DIMENSIONS,
    +                CHARSET,
    +                FIELD_SEPARATOR,
    +                RECORD_SEPARATOR)));
    +    }
    +
    +    /**
    +     * @return the unmodifiable set holding the success and failure relationships
    +     */
    +    @Override
    +    public Set<Relationship> getRelationships() {
    +        return DeepLearning4JMultiLayerPredictor.relationships;
    +    }
    +
    +    /**
    +     * @return the unmodifiable, ordered list of properties this processor accepts
    +     */
    +    @Override
    +    public final List<PropertyDescriptor> getSupportedPropertyDescriptors() {
    +        return DeepLearning4JMultiLayerPredictor.propertyDescriptors;
    +    }
    +
    +    protected synchronized MultiLayerNetwork getModel(ProcessContext 
context) throws IOException {
    +        if ( model == null ) {
    +            String modelFile = 
context.getProperty(MODEL_FILE).evaluateAttributeExpressions().getValue();
    +            getLogger().debug("Loading model from {}", new Object[] 
{modelFile});
    +
    +            long start = System.currentTimeMillis();
    +            model = 
ModelSerializer.restoreMultiLayerNetwork(modelFile,false);
    +            long end = System.currentTimeMillis();
    +
    +            getLogger().info("Time to load model " + (end-start) +  " ms");
    +        }
    +        return (MultiLayerNetwork)model;
    +    }
    +
    +    @Override
    +    public void onTrigger(final ProcessContext context, final 
ProcessSession session) throws ProcessException {
    +        FlowFile flowFile = session.get();
    +        if ( flowFile == null ) {
    +            return;
    +        }
    +
    +        Charset charset = 
Charset.forName(context.getProperty(CHARSET).evaluateAttributeExpressions(flowFile).getValue());
    +        if ( flowFile.getSize() == 0 ) {
    +            String message = "FlowFile query is empty";
    +            getLogger().error(message);
    +            flowFile = session.putAttribute(flowFile, 
DEEPLEARNING4J_ERROR_MESSAGE, message);
    +            session.transfer(flowFile, REL_FAILURE);
    +            return;
    +        }
    +
    +        String input = null;
    +        try {
    +            input = getFlowFileContents(session, charset, flowFile);
    +            String fieldSeparator = 
context.getProperty(FIELD_SEPARATOR).evaluateAttributeExpressions(flowFile).getValue();
    +            String recordSeparator = 
context.getProperty(RECORD_SEPARATOR).evaluateAttributeExpressions(flowFile).getValue();
    +
    +            int [] dimensions = getInputDimensions(context, charset, 
flowFile, fieldSeparator);
    +
    +            if ( getLogger().isDebugEnabled() )    {
    +                getLogger().debug("Received input {} with dimensions {}", 
new Object[] { input, dimensions });
    +            }
    +
    +            MultiLayerNetwork model = getModel(context);
    --- End diff --
    
    @jzonthemtn - I was going for lazy loading, but it's not a problem to change it.


> Create deep learning classification and regression processor
> ------------------------------------------------------------
>
>                 Key: NIFI-5166
>                 URL: https://issues.apache.org/jira/browse/NIFI-5166
>             Project: Apache NiFi
>          Issue Type: New Feature
>          Components: Extensions
>    Affects Versions: 1.6.0
>            Reporter: Mans Singh
>            Assignee: Mans Singh
>            Priority: Minor
>              Labels: Learning, classification, deep, regression
>   Original Estimate: 168h
>  Remaining Estimate: 168h
>
> We need a deep learning classification and regression processor.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to