jfrazee commented on a change in pull request #4482:
URL: https://github.com/apache/nifi/pull/4482#discussion_r475006113



##########
File path: 
nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/SampleRecord.java
##########
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.processors.standard;
+
+import org.apache.nifi.annotation.behavior.EventDriven;
+import org.apache.nifi.annotation.behavior.InputRequirement;
+import org.apache.nifi.annotation.behavior.SideEffectFree;
+import org.apache.nifi.annotation.behavior.SupportsBatching;
+import org.apache.nifi.annotation.documentation.CapabilityDescription;
+import org.apache.nifi.annotation.documentation.Tags;
+import org.apache.nifi.components.AllowableValue;
+import org.apache.nifi.components.PropertyDescriptor;
+import org.apache.nifi.components.PropertyValue;
+import org.apache.nifi.components.ValidationContext;
+import org.apache.nifi.components.ValidationResult;
+import org.apache.nifi.components.Validator;
+import org.apache.nifi.expression.ExpressionLanguageScope;
+import org.apache.nifi.flowfile.FlowFile;
+import org.apache.nifi.flowfile.attributes.CoreAttributes;
+import org.apache.nifi.processor.AbstractProcessor;
+import org.apache.nifi.processor.ProcessContext;
+import org.apache.nifi.processor.ProcessSession;
+import org.apache.nifi.processor.Relationship;
+import org.apache.nifi.processor.exception.ProcessException;
+import org.apache.nifi.processor.util.StandardValidators;
+import org.apache.nifi.serialization.RecordReader;
+import org.apache.nifi.serialization.RecordReaderFactory;
+import org.apache.nifi.serialization.RecordSetWriter;
+import org.apache.nifi.serialization.RecordSetWriterFactory;
+import org.apache.nifi.serialization.WriteResult;
+import org.apache.nifi.serialization.record.Record;
+import org.apache.nifi.serialization.record.RecordSchema;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+@EventDriven
+@SideEffectFree
+@SupportsBatching
+@Tags({"record", "sample"})
+@InputRequirement(InputRequirement.Requirement.INPUT_REQUIRED)
+@CapabilityDescription("Samples the records of a FlowFile based on a specified 
sampling strategy (such as Reservoir Sampling). The resulting "
+        + "FlowFile may be of a fixed number of records (in the case of 
reservoir-based algorithms) or some subset of the total number of records "
+        + "(in the case of probabilistic sampling), or a deterministic number 
of records (in the case of interval sampling).")
+public class SampleRecord extends AbstractProcessor {
+
+    static final String INTERVAL_SAMPLING_KEY = "interval";
+    static final String PROBABILISTIC_SAMPLING_KEY = "probabilistic";
+    static final String RESERVOIR_SAMPLING_KEY = "reservoir";
+
+    static final AllowableValue INTERVAL_SAMPLING = new 
AllowableValue(INTERVAL_SAMPLING_KEY, "Interval Sampling",
+            "Selects every Nth record where N is the value of the 'Interval 
Value' property");
+    static final AllowableValue PROBABILISTIC_SAMPLING = new 
AllowableValue(PROBABILISTIC_SAMPLING_KEY, "Probabilistic Sampling",
+            "Selects each record with probability P where P is the value of 
the 'Selection Probability' property");
+    static final AllowableValue RESERVOIR_SAMPLING = new 
AllowableValue(RESERVOIR_SAMPLING_KEY, "Reservoir Sampling",
+            "Creates a sample of K records where each record has equal 
probability of being included, where K is "
+                    + "the value of the 'Reservoir Size' property");
+
+    static final PropertyDescriptor RECORD_READER_FACTORY = new 
PropertyDescriptor.Builder()
+            .name("record-reader")
+            .displayName("Record Reader")
+            .description("Specifies the Controller Service to use for parsing 
incoming data and determining the data's schema")
+            .identifiesControllerService(RecordReaderFactory.class)
+            .required(true)
+            .build();
+    static final PropertyDescriptor RECORD_WRITER_FACTORY = new 
PropertyDescriptor.Builder()
+            .name("record-writer")
+            .displayName("Record Writer")
+            .description("Specifies the Controller Service to use for writing 
results to a FlowFile")
+            .identifiesControllerService(RecordSetWriterFactory.class)
+            .required(true)
+            .build();
+    static final PropertyDescriptor SAMPLING_STRATEGY = new 
PropertyDescriptor.Builder()
+            .name("sample-record-sampling-strategy")
+            .displayName("Sampling Strategy")
+            .description("Specifies which method to use for sampling records 
from the incoming FlowFile")
+            .allowableValues(INTERVAL_SAMPLING, PROBABILISTIC_SAMPLING, 
RESERVOIR_SAMPLING)
+            .required(true)
+            .defaultValue(RESERVOIR_SAMPLING.getValue())
+            .addValidator(Validator.VALID)
+            .build();
+    static final PropertyDescriptor SAMPLING_INTERVAL = new 
PropertyDescriptor.Builder()
+            .name("sample-record-interval")
+            .displayName("Sampling Interval")
+            .description("Specifies the number of records to skip before 
writing a record to the outgoing FlowFile. This property is only "
+                    + "used if Sampling Strategy is set to Interval Sampling. 
A value of zero (0) will cause all records to be included in the"
+                    + "outgoing FlowFile.")
+            .required(false)
+            .addValidator(StandardValidators.NON_NEGATIVE_INTEGER_VALIDATOR)
+            
.expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
+            .build();
+    static final PropertyDescriptor SAMPLING_PROBABILITY = new 
PropertyDescriptor.Builder()
+            .name("sample-record-probability")
+            .displayName("Sampling Probability")
+            .description("Specifies the probability (as a percent from 0-100) 
of a record being included in the outgoing FlowFile. This property is only "
+                    + "used if Sampling Strategy is set to Probabilistic 
Sampling. A value of zero (0) will cause no records to be included in the"
+                    + "outgoing FlowFile, and a value of 100 will cause all 
records to be included in the outgoing FlowFile..")
+            .required(false)
+            .addValidator(StandardValidators.createLongValidator(0, 100, true))
+            
.expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
+            .build();
+    static final PropertyDescriptor RESERVOIR_SIZE = new 
PropertyDescriptor.Builder()
+            .name("sample-record-reservoir")
+            .displayName("Reservoir Size")
+            .description("Specifies the number of records to write to the 
outgoing FlowFile. This property is only used if Sampling Strategy is set to "
+                    + "reservoir-based strategies such as Reservoir Sampling 
or Weighted Random Sampling.")
+            .required(false)
+            .addValidator(StandardValidators.NON_NEGATIVE_INTEGER_VALIDATOR)
+            
.expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
+            .build();
+
+    static final PropertyDescriptor RANDOM_SEED = new 
PropertyDescriptor.Builder()
+            .name("sample-record-random-seed")
+            .displayName("Random Seed")
+            .description("Specifies a particular number to use as the seed for 
the random number generator (used by probabilistic strategies). "
+                    + "Setting this property will ensure the same records are 
selected even when using probabilistic strategies.")
+            .required(false)
+            .addValidator(StandardValidators.LONG_VALIDATOR)
+            
.expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
+            .build();
+
+    public static final Relationship REL_ORIGINAL = new Relationship.Builder()
+            .name("original")
+            .description("The original FlowFile is routed to this relationship 
if sampling is successful")
+            .autoTerminateDefault(true)
+            .build();
+    public static final Relationship REL_SUCCESS = new Relationship.Builder()
+            .name("success")
+            .description("The FlowFile is routed to this relationship if the 
sampling completed successfully")
+            .autoTerminateDefault(true)
+            .build();
+    public static final Relationship REL_FAILURE = new Relationship.Builder()
+            .name("failure")
+            .description("If a FlowFile fails processing for any reason (for 
example, any record "
+                    + "is not valid), the original FlowFile will be routed to 
this relationship")
+            .build();
+
+    private static final List<PropertyDescriptor> properties;
+    private static final Set<Relationship> relationships;
+
+    static {
+        final List<PropertyDescriptor> props = new ArrayList<>();
+        props.add(RECORD_READER_FACTORY);
+        props.add(RECORD_WRITER_FACTORY);
+        props.add(SAMPLING_STRATEGY);
+        props.add(SAMPLING_INTERVAL);
+        props.add(SAMPLING_PROBABILITY);
+        props.add(RESERVOIR_SIZE);
+        props.add(RANDOM_SEED);
+        properties = Collections.unmodifiableList(props);
+
+        final Set<Relationship> r = new HashSet<>();
+        r.add(REL_SUCCESS);
+        r.add(REL_FAILURE);
+        r.add(REL_ORIGINAL);
+        relationships = Collections.unmodifiableSet(r);
+    }
+
+    @Override
+    public Set<Relationship> getRelationships() {
+        return relationships;
+    }
+
+    @Override
+    protected List<PropertyDescriptor> getSupportedPropertyDescriptors() {
+        return properties;
+    }
+
+    @Override
+    protected Collection<ValidationResult> customValidate(ValidationContext 
validationContext) {
+
+        final List<ValidationResult> results = new 
ArrayList<>(super.customValidate(validationContext));
+
+        final String samplingStrategyValue = 
validationContext.getProperty(SAMPLING_STRATEGY).getValue();
+        if (INTERVAL_SAMPLING_KEY.equals(samplingStrategyValue)) {
+            final PropertyValue pd = 
validationContext.getProperty(SAMPLING_INTERVAL);
+            if (!pd.isSet()) {
+                results.add(new 
ValidationResult.Builder().subject(INTERVAL_SAMPLING.getDisplayName()).valid(false)
+                        .explanation(SAMPLING_INTERVAL.getDisplayName() + " 
property must be set to use " + INTERVAL_SAMPLING.getDisplayName() + " 
strategy")
+                        .build());
+            }
+        } else if (PROBABILISTIC_SAMPLING_KEY.equals(samplingStrategyValue)) {
+            final PropertyValue samplingProbabilityProperty = 
validationContext.getProperty(SAMPLING_PROBABILITY);
+            if (!samplingProbabilityProperty.isSet()) {
+                results.add(new 
ValidationResult.Builder().subject(PROBABILISTIC_SAMPLING.getDisplayName()).valid(false)
+                        .explanation(SAMPLING_PROBABILITY.getDisplayName() + " 
property must be set to use " + PROBABILISTIC_SAMPLING.getDisplayName() + " 
strategy")
+                        .build());
+            }
+        } else if (RESERVOIR_SAMPLING_KEY.equals(samplingStrategyValue)) {
+            final PropertyValue pd = 
validationContext.getProperty(RESERVOIR_SIZE);
+            if (!pd.isSet()) {
+                results.add(new 
ValidationResult.Builder().subject(RESERVOIR_SAMPLING.getDisplayName()).valid(false)
+                        .explanation(RESERVOIR_SIZE.getDisplayName() + " 
property must be set to use " + RESERVOIR_SAMPLING.getDisplayName() + " 
strategy")
+                        .build());
+            }
+        }
+        return results;
+    }
+
+    @Override
+    public void onTrigger(ProcessContext context, ProcessSession session) 
throws ProcessException {
+
+        FlowFile flowFile = session.get();
+        if (flowFile == null) {
+            return;
+        }
+
+        FlowFile sampledFlowFile = session.create(flowFile);
+        final FlowFile outFlowFile = sampledFlowFile;
+        final Map<String, String> attributes = new HashMap<>();
+        try (final InputStream inputStream = session.read(flowFile);
+             final OutputStream outputStream = session.write(sampledFlowFile)) 
{
+
+            final RecordReaderFactory recordParserFactory = 
context.getProperty(RECORD_READER_FACTORY)
+                    .asControllerService(RecordReaderFactory.class);
+            final RecordReader reader = 
recordParserFactory.createRecordReader(flowFile, inputStream, getLogger());
+            final RecordSetWriterFactory writerFactory = 
context.getProperty(RECORD_WRITER_FACTORY)
+                    .asControllerService(RecordSetWriterFactory.class);
+            final RecordSchema writeSchema = 
writerFactory.getSchema(flowFile.getAttributes(), reader.getSchema());
+            final RecordSetWriter recordSetWriter = 
writerFactory.createWriter(getLogger(), writeSchema, outputStream, outFlowFile);
+
+            final String samplingStrategyValue = 
context.getProperty(SAMPLING_STRATEGY).getValue();
+            final SamplingStrategy samplingStrategy;
+            if (INTERVAL_SAMPLING_KEY.equals(samplingStrategyValue)) {
+                final int intervalValue = 
context.getProperty(SAMPLING_INTERVAL).evaluateAttributeExpressions(outFlowFile).asInteger();
+                samplingStrategy = new 
IntervalSamplingStrategy(recordSetWriter, intervalValue);
+            } else if 
(PROBABILISTIC_SAMPLING_KEY.equals(samplingStrategyValue)) {
+                final int probabilityValue = 
context.getProperty(SAMPLING_PROBABILITY).evaluateAttributeExpressions(outFlowFile).asInteger();
+                final Long randomSeed = 
context.getProperty(RANDOM_SEED).isSet()
+                        ? 
context.getProperty(RANDOM_SEED).evaluateAttributeExpressions(outFlowFile).asLong()
+                        : null;
+                samplingStrategy = new 
ProbabilisticSamplingStrategy(recordSetWriter, probabilityValue, randomSeed);
+            } else {
+                final int reservoirSize = 
context.getProperty(RESERVOIR_SIZE).evaluateAttributeExpressions(outFlowFile).asInteger();
+                final Long randomSeed = 
context.getProperty(RANDOM_SEED).isSet()
+                        ? 
context.getProperty(RANDOM_SEED).evaluateAttributeExpressions(outFlowFile).asLong()
+                        : null;
+                samplingStrategy = new 
ReservoirSamplingStrategy(recordSetWriter, reservoirSize, randomSeed);
+            }
+            samplingStrategy.init();
+
+            Record record;
+            while ((record = reader.nextRecord()) != null) {
+                samplingStrategy.sample(record);
+            }
+
+            WriteResult writeResult = samplingStrategy.finish();
+            try {
+                recordSetWriter.flush();
+                recordSetWriter.close();
+            } catch (final IOException ioe) {
+                getLogger().warn("Failed to close Writer for {}", new 
Object[]{outFlowFile});
+            }
+
+            attributes.put("record.count", 
String.valueOf(writeResult.getRecordCount()));
+            attributes.put(CoreAttributes.MIME_TYPE.key(), 
recordSetWriter.getMimeType());
+            attributes.putAll(writeResult.getAttributes());
+        } catch (Exception e) {
+            getLogger().error("Error during transmission of records due to {}, 
routing to failure", new Object[]{e.getMessage()}, e);
+            session.transfer(flowFile, REL_FAILURE);
+            session.remove(sampledFlowFile);
+            return;
+        }
+        session.transfer(flowFile, REL_ORIGINAL);
+        sampledFlowFile = session.putAllAttributes(sampledFlowFile, 
attributes);
+        session.transfer(sampledFlowFile, REL_SUCCESS);
+    }
+
+    interface SamplingStrategy {
+        void init() throws IOException;
+
+        void sample(Record record) throws IOException;
+
+        WriteResult finish() throws IOException;
+    }
+
+    static class IntervalSamplingStrategy implements SamplingStrategy {
+        final RecordSetWriter writer;
+        final int interval;
+        int currentCount = 0;
+
+        IntervalSamplingStrategy(RecordSetWriter writer, int interval) {
+            this.writer = writer;
+            this.interval = interval;
+        }
+
+        @Override
+        public void init() throws IOException {
+            currentCount = 0;
+            writer.beginRecordSet();
+        }
+
+        @Override
+        public void sample(Record record) throws IOException {
+            if (++currentCount >= interval && interval > 0) {

Review comment:
       Since the condition also requires `interval > 0`, an interval of zero ends up writing no records.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to