[ 
https://issues.apache.org/jira/browse/NIFI-5166?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16472152#comment-16472152
 ] 

ASF GitHub Bot commented on NIFI-5166:
--------------------------------------

Github user mans2singh commented on a diff in the pull request:

    https://github.com/apache/nifi/pull/2686#discussion_r187655735
  
    --- Diff: 
nifi-nar-bundles/nifi-deeplearning4j-bundle/nifi-deeplearning4j-processors/src/main/java/org/apache/nifi/processors/deeplearning4j/DeepLearning4JPredictor.java
 ---
    @@ -0,0 +1,218 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.nifi.processors.deeplearning4j;
    +import org.apache.nifi.annotation.behavior.EventDriven;
    +import org.apache.nifi.annotation.behavior.InputRequirement;
    +import org.apache.nifi.annotation.behavior.InputRequirement.Requirement;
    +import org.apache.nifi.annotation.behavior.SupportsBatching;
    +import org.apache.nifi.annotation.behavior.WritesAttribute;
    +import org.apache.nifi.annotation.behavior.WritesAttributes;
    +import org.apache.nifi.annotation.documentation.CapabilityDescription;
    +import org.apache.nifi.annotation.documentation.Tags;
    +import org.apache.nifi.components.PropertyDescriptor;
    +import org.apache.nifi.flowfile.FlowFile;
    +import org.apache.nifi.processor.ProcessContext;
    +import org.apache.nifi.processor.ProcessSession;
    +import org.apache.nifi.processor.Relationship;
    +import org.apache.nifi.processor.exception.ProcessException;
    +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    +import org.nd4j.linalg.api.ndarray.INDArray;
    +import org.nd4j.linalg.factory.Nd4j;
    +import com.google.gson.Gson;
    +import java.io.ByteArrayInputStream;
    +import java.io.ByteArrayOutputStream;
    +import java.io.IOException;
    +import java.nio.charset.Charset;
    +import java.util.ArrayList;
    +import java.util.Arrays;
    +import java.util.Collections;
    +import java.util.HashMap;
    +import java.util.HashSet;
    +import java.util.List;
    +import java.util.Map;
    +import java.util.Set;
    +import java.util.stream.Collectors;
    +
    +@EventDriven
    +@SupportsBatching
    +@InputRequirement(Requirement.INPUT_REQUIRED)
    +@Tags({"deeplearning4j", "dl4j", "predict", "classification", 
"regression", "deep", "learning"})
    +@CapabilityDescription("The DeepLearning4JPredictor predicts one or more 
value(s) based on provided deeplearning4j (https://github.com/deeplearning4j) 
model and the content of a FlowFile. "
    +    + "The processor supports both classification and regression by 
extracting the record from the FlowFile body and applying the model. "
    +    + "The processor supports batch by allowing multiple records to be 
passed in the FlowFile body with each record separated by the 'Record 
Separator' property. "
    +    + "Each record can contain multiple fields with each field separated 
by the 'Field Separator' property."
    +    )
    +@WritesAttributes({
    +    @WritesAttribute(attribute = 
AbstractDeepLearning4JProcessor.DEEPLEARNING4J_ERROR_MESSAGE, description = 
"Deeplearning4J error message"),
    +    @WritesAttribute(attribute = 
AbstractDeepLearning4JProcessor.DEEPLEARNING4J_OUTPUT_SHAPE, description = 
"Deeplearning4J output shape"),
    +    })
    +public class DeepLearning4JPredictor extends 
AbstractDeepLearning4JProcessor {
    +
    +    static final Relationship REL_SUCCESS = new 
Relationship.Builder().name("success")
    +            .description("Successful DeepLearning4j results are routed to 
this relationship").build();
    +
    +    static final Relationship REL_FAILURE = new 
Relationship.Builder().name("failure")
    +            .description("Failed DeepLearning4j results are routed to this 
relationship").build();
    +
    +    protected Gson gson = new Gson();
    +
    +    private static final Set<Relationship> relationships;
    +    private static final List<PropertyDescriptor> propertyDescriptors;
    +    static {
    +        final Set<Relationship> tempRelationships = new HashSet<>();
    +        tempRelationships.add(REL_SUCCESS);
    +        tempRelationships.add(REL_FAILURE);
    +        relationships = Collections.unmodifiableSet(tempRelationships);
    +        final List<PropertyDescriptor> tempDescriptors = new ArrayList<>();
    +        tempDescriptors.add(MODEL_FILE);
    +        tempDescriptors.add(RECORD_DIMENSIONS);
    +        tempDescriptors.add(CHARSET);
    +        tempDescriptors.add(FIELD_SEPARATOR);
    +        tempDescriptors.add(RECORD_SEPARATOR);
    +        propertyDescriptors = 
Collections.unmodifiableList(tempDescriptors);
    +    }
    +
    +    @Override
    +    public Set<Relationship> getRelationships() {
    +        return relationships;
    +    }
    +
    +    @Override
    +    public final List<PropertyDescriptor> 
getSupportedPropertyDescriptors() {
    +        return propertyDescriptors;
    +    }
    +
    +    @Override
    +    public void onTrigger(final ProcessContext context, final 
ProcessSession session) throws ProcessException {
    +        FlowFile flowFile = session.get();
    +        if ( flowFile == null ) {
    +            return;
    +        }
    +
    +        Charset charset = 
Charset.forName(context.getProperty(CHARSET).evaluateAttributeExpressions(flowFile).getValue());
    +        if ( flowFile.getSize() == 0 ) {
    +            String message = "FlowFile query is empty";
    +            getLogger().error(message);
    +            flowFile = session.putAttribute(flowFile, 
DEEPLEARNING4J_ERROR_MESSAGE, message);
    +            session.transfer(flowFile, REL_FAILURE);
    +            return;
    +        }
    +
    +        String input = null;
    +        try {
    +            input = getFlowFileContents(session, charset, flowFile);
    +            String fieldSeparator = 
context.getProperty(FIELD_SEPARATOR).evaluateAttributeExpressions(flowFile).getValue();
    +            String recordSeparator = 
context.getProperty(RECORD_SEPARATOR).evaluateAttributeExpressions(flowFile).getValue();
    +
    +            int [] dimensions = getInputDimensions(context, charset, 
flowFile, fieldSeparator);
    +
    +            if ( getLogger().isDebugEnabled() )    {
    +                getLogger().debug("Received input {} with dimensions {}", 
new Object[] { input, dimensions });
    +            }
    +
    +            MultiLayerNetwork model = getModel(context);
    +
    +            long startTimeMillis = System.currentTimeMillis();
    +
    +            String [] inputRecords = input.split(recordSeparator);
    +
    +            List<INDArray> features = Arrays.stream(inputRecords).map(
    +                record -> {
    +                    double [] parameters = 
Arrays.stream(record.split(fieldSeparator)).mapToDouble(
    +                             field -> Double.parseDouble(field)).toArray();
    +
    +                    INDArray featureInput = Nd4j.create(parameters, 
dimensions);
    +
    +                    if ( getLogger().isDebugEnabled() ) {
    +                        getLogger().debug("Features for record {} 
parameters {} dims {} featureInput {} ",
    +                            new Object[] {record, parameters, dimensions, 
featureInput});
    +                    }
    +
    +                    return featureInput;
    +
    +                }).collect(Collectors.toList());
    +
    +           INDArray allFeatures = Nd4j.vstack(features);
    +
    +           INDArray results = model.output(allFeatures);
    +
    +           double [][] partitionedResults = new 
double[inputRecords.length][];
    +           for (int row = 0; row < inputRecords.length; row++) {
    +                INDArray result = results.getRow(row);
    +                partitionedResults[row] = 
Nd4j.toFlattened(result).toDoubleVector();
    +           }
    +
    +           String jsonResult = gson.toJson(partitionedResults);
    +           int [] shape = results.shape();
    +           String jsonShape = gson.toJson(Arrays.copyOfRange(shape, 1, 
shape.length));
    +
    +           if ( getLogger().isDebugEnabled() ) {
    +                getLogger().debug("Prediction for inputRecords {}, dims 
{}, results {}, result.shape {}, partitionedResults {}, jsonResult {}, shape 
{}, jsonShape {}",
    +                        new Object[] {inputRecords, dimensions, results, 
Arrays.toString(results.shape()), partitionedResults, jsonResult, shape, 
jsonShape});
    +           }
    +
    +           try (ByteArrayInputStream bais = new 
ByteArrayInputStream(jsonResult.getBytes(charset))) {
    +               flowFile = session.importFrom(bais, flowFile);
    +           }
    +
    +           session.putAttribute(flowFile, DEEPLEARNING4J_OUTPUT_SHAPE, 
jsonShape);
    +
    +           final long endTimeMillis = System.currentTimeMillis();
    +
    +           session.transfer(flowFile, REL_SUCCESS);
    +
    +           session.getProvenanceReporter().send(flowFile, 
makeProvenanceUrl(context),
    +                    (endTimeMillis - startTimeMillis));
    +        } catch (Exception exception) {
    +            flowFile = populateErrorAttributes(session, flowFile, 
exception.getMessage());
    +                getLogger().error("Failed to process data due to {} for 
input {}",
    +                        new Object[]{exception.getLocalizedMessage(), 
input}, exception);
    +                session.transfer(flowFile, REL_FAILURE);
    +            context.yield();
    --- End diff --
    
    Removed the context.yield() call from the catch block.  Thanks for the comments.


> Create deep learning classification and regression processor
> ------------------------------------------------------------
>
>                 Key: NIFI-5166
>                 URL: https://issues.apache.org/jira/browse/NIFI-5166
>             Project: Apache NiFi
>          Issue Type: New Feature
>          Components: Extensions
>    Affects Versions: 1.6.0
>            Reporter: Mans Singh
>            Assignee: Mans Singh
>            Priority: Minor
>              Labels: Learning, classification, deep, regression
>             Fix For: 1.7.0
>
>   Original Estimate: 168h
>  Remaining Estimate: 168h
>
> We need a deep learning classification and regression processor.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to