This is an automated email from the ASF dual-hosted git repository.

mattyb149 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/nifi.git


The following commit(s) were added to refs/heads/main by this push:
     new ec09c56e93 NIFI-10508: When inferring data types for values, allow 
float and double to encapsulate byte/short/int/long values
ec09c56e93 is described below

commit ec09c56e933c8418c97f89bfbcf898c8338b5d06
Author: Mark Payne <[email protected]>
AuthorDate: Wed Sep 14 16:03:29 2022 -0400

    NIFI-10508: When inferring data types for values, allow float and double to 
encapsulate byte/short/int/long values
    
    Signed-off-by: Matthew Burgess <[email protected]>
    
    This closes #6421
---
 .../serialization/record/util/DataTypeSet.java     |  78 +++++++++++++++
 .../serialization/record/util/DataTypeUtils.java   | 107 ++++++++++-----------
 .../serialization/record/TestDataTypeUtils.java    |   6 ++
 .../serialization/record/util/TestDataTypeSet.java |  68 +++++++++++++
 .../validation/TestStandardSchemaValidator.java    |  11 ---
 .../nifi/queryrecord/FlowFileEnumerator.java       |  16 +--
 .../nifi/schema/inference/FieldTypeInference.java  |  12 +--
 .../schema/inference/TestFieldTypeInference.java   |  93 ++++++++----------
 8 files changed, 261 insertions(+), 130 deletions(-)

diff --git 
a/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeSet.java
 
b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeSet.java
new file mode 100644
index 0000000000..33ee847571
--- /dev/null
+++ 
b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeSet.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.nifi.serialization.record.util;
+
+import org.apache.nifi.serialization.record.DataType;
+import org.apache.nifi.serialization.record.RecordFieldType;
+import org.apache.nifi.serialization.record.type.ChoiceDataType;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * A container class, to which multiple DataTypes can be added, such that
+ * adding any two types where one is more narrow than the other will result
+ * in combining the two types into the wider type.
+ */
+public class DataTypeSet {
+    private final List<DataType> types = new ArrayList<>();
+
+    /**
+     * Adds the given data type to the set of types to consider
+     * @param dataType the data type to add
+     */
+    public void add(final DataType dataType) {
+        if (dataType == null) {
+            return;
+        }
+
+        if (dataType.getFieldType() == RecordFieldType.CHOICE) {
+            final ChoiceDataType choiceDataType = (ChoiceDataType) dataType;
+            choiceDataType.getPossibleSubTypes().forEach(this::add);
+            return;
+        }
+
+        if (types.contains(dataType)) {
+            return;
+        }
+
+        DataType toRemove = null;
+        DataType toAdd = null;
+        for (final DataType currentType : types) {
+            final Optional<DataType> widerType = 
DataTypeUtils.getWiderType(currentType, dataType);
+            if (widerType.isPresent()) {
+                toRemove = currentType;
+                toAdd = widerType.get();
+            }
+        }
+
+        if (toRemove != null) {
+            types.remove(toRemove);
+        }
+
+        types.add( toAdd == null ? dataType : toAdd );
+    }
+
+    /**
+     * @return the combined types
+     */
+    public List<DataType> getTypes() {
+        return new ArrayList<>(types);
+    }
+}
diff --git 
a/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java
 
b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java
index e281d6cfbe..21e3b71827 100644
--- 
a/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java
+++ 
b/nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/util/DataTypeUtils.java
@@ -64,7 +64,6 @@ import java.util.EnumMap;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
-import java.util.LinkedHashSet;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -590,34 +589,8 @@ public class DataTypeUtils {
                 m.forEach((k, v) -> map.put(k == null ? null : k.toString(), 
v));
             }
             return inferRecordDataType(map);
-//            // Check if all types are the same.
-//            if (map.isEmpty()) {
-//                return 
RecordFieldType.MAP.getMapDataType(RecordFieldType.STRING.getDataType());
-//            }
-//
-//            Object valueFromMap = null;
-//            Class<?> valueClass = null;
-//            for (final Object val : map.values()) {
-//                if (val == null) {
-//                    continue;
-//                }
-//
-//                valueFromMap = val;
-//                final Class<?> currentValClass = val.getClass();
-//                if (valueClass == null) {
-//                    valueClass = currentValClass;
-//                } else {
-//                    // If we have two elements that are of different types, 
then we cannot have a Map. Must be a Record.
-//                    if (valueClass != currentValClass) {
-//                        return inferRecordDataType(map);
-//                    }
-//                }
-//            }
-//
-//            // All values appear to be of the same type, so assume that it's 
a map.
-//            final DataType elementDataType = inferDataType(valueFromMap, 
RecordFieldType.STRING.getDataType());
-//            return RecordFieldType.MAP.getMapDataType(elementDataType);
         }
+
         if (value.getClass().isArray()) {
             DataType mergedDataType = null;
 
@@ -633,8 +606,9 @@ public class DataTypeUtils {
 
             return RecordFieldType.ARRAY.getArrayDataType(mergedDataType);
         }
+
         if (value instanceof Iterable) {
-            final Iterable iterable = (Iterable<?>) value;
+            final Iterable<?> iterable = (Iterable<?>) value;
 
             DataType mergedDataType = null;
             for (final Object arrayValue : iterable) {
@@ -1998,33 +1972,34 @@ public class DataTypeUtils {
                 return widerType.get();
             }
 
-            final Set<DataType> possibleTypes = new LinkedHashSet<>();
-            if (thisDataType.getFieldType() == RecordFieldType.CHOICE) {
-                possibleTypes.addAll(((ChoiceDataType) 
thisDataType).getPossibleSubTypes());
-            } else {
-                possibleTypes.add(thisDataType);
-            }
+            final DataTypeSet dataTypeSet = new DataTypeSet();
+            dataTypeSet.add(thisDataType);
+            dataTypeSet.add(otherDataType);
 
-            if (otherDataType.getFieldType() == RecordFieldType.CHOICE) {
-                possibleTypes.addAll(((ChoiceDataType) 
otherDataType).getPossibleSubTypes());
-            } else {
-                possibleTypes.add(otherDataType);
-            }
-
-            ArrayList<DataType> possibleChildTypes = new 
ArrayList<>(possibleTypes);
-            Collections.sort(possibleChildTypes, 
Comparator.comparing(DataType::getFieldType));
+            final List<DataType> possibleChildTypes = dataTypeSet.getTypes();
+            
possibleChildTypes.sort(Comparator.comparing(DataType::getFieldType));
 
             return 
RecordFieldType.CHOICE.getChoiceDataType(possibleChildTypes);
         }
     }
 
     public static Optional<DataType> getWiderType(final DataType thisDataType, 
final DataType otherDataType) {
+        if (thisDataType == null) {
+            return Optional.ofNullable(otherDataType);
+        }
+        if (otherDataType == null) {
+            return Optional.of(thisDataType);
+        }
+
         final RecordFieldType thisFieldType = thisDataType.getFieldType();
         final RecordFieldType otherFieldType = otherDataType.getFieldType();
 
         final int thisIntTypeValue = getIntegerTypeValue(thisFieldType);
         final int otherIntTypeValue = getIntegerTypeValue(otherFieldType);
-        if (thisIntTypeValue > -1 && otherIntTypeValue > -1) {
+        final boolean thisIsInt = thisIntTypeValue > -1;
+        final boolean otherIsInt = otherIntTypeValue > -1;
+
+        if (thisIsInt && otherIsInt) {
             if (thisIntTypeValue > otherIntTypeValue) {
                 return Optional.of(thisDataType);
             }
@@ -2032,25 +2007,37 @@ public class DataTypeUtils {
             return Optional.of(otherDataType);
         }
 
+        final boolean otherIsDecimal = isDecimalType(otherFieldType);
+
         switch (thisFieldType) {
-            case FLOAT:
-                if (otherFieldType == RecordFieldType.DOUBLE) {
+            case BYTE:
+            case SHORT:
+            case INT:
+            case LONG:
+                if (otherIsDecimal) {
                     return Optional.of(otherDataType);
-                } else if (otherFieldType == RecordFieldType.DECIMAL) {
+                }
+                break;
+            case FLOAT:
+                if (otherFieldType == RecordFieldType.DOUBLE || otherFieldType 
== RecordFieldType.DECIMAL) {
                     return Optional.of(otherDataType);
                 }
+                if (otherFieldType == RecordFieldType.BYTE || otherFieldType 
== RecordFieldType.SHORT || otherFieldType == RecordFieldType.INT || 
otherFieldType == RecordFieldType.LONG) {
+                    return Optional.of(thisDataType);
+                }
                 break;
             case DOUBLE:
-                if (otherFieldType == RecordFieldType.FLOAT) {
-                    return Optional.of(thisDataType);
-                } else if (otherFieldType == RecordFieldType.DECIMAL) {
+                if (otherFieldType == RecordFieldType.DECIMAL) {
                     return Optional.of(otherDataType);
                 }
+                if (otherFieldType == RecordFieldType.BYTE || otherFieldType 
== RecordFieldType.SHORT || otherFieldType == RecordFieldType.INT || 
otherFieldType == RecordFieldType.LONG
+                    || otherFieldType == RecordFieldType.FLOAT) {
+
+                    return Optional.of(thisDataType);
+                }
                 break;
             case DECIMAL:
-                if (otherFieldType == RecordFieldType.DOUBLE) {
-                    return Optional.of(thisDataType);
-                } else if (otherFieldType == RecordFieldType.FLOAT) {
+                if (otherFieldType == RecordFieldType.DOUBLE || otherFieldType 
== RecordFieldType.FLOAT || otherIsInt) {
                     return Optional.of(thisDataType);
                 } else if (otherFieldType == RecordFieldType.DECIMAL) {
                     final DecimalDataType thisDecimalDataType = 
(DecimalDataType) thisDataType;
@@ -2062,12 +2049,13 @@ public class DataTypeUtils {
                 }
                 break;
             case CHAR:
+            case UUID:
                 if (otherFieldType == RecordFieldType.STRING) {
                     return Optional.of(otherDataType);
                 }
                 break;
             case STRING:
-                if (otherFieldType == RecordFieldType.CHAR) {
+                if (otherFieldType == RecordFieldType.CHAR || otherFieldType 
== RecordFieldType.UUID) {
                     return Optional.of(thisDataType);
                 }
                 break;
@@ -2076,6 +2064,17 @@ public class DataTypeUtils {
         return Optional.empty();
     }
 
+    private static boolean isDecimalType(final RecordFieldType fieldType) {
+        switch (fieldType) {
+            case FLOAT:
+            case DOUBLE:
+            case DECIMAL:
+                return true;
+            default:
+                return false;
+        }
+    }
+
     private static int getIntegerTypeValue(final RecordFieldType fieldType) {
         switch (fieldType) {
             case BIGINT:
diff --git 
a/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/TestDataTypeUtils.java
 
b/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/TestDataTypeUtils.java
index 46cd012239..863949b486 100644
--- 
a/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/TestDataTypeUtils.java
+++ 
b/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/TestDataTypeUtils.java
@@ -94,6 +94,12 @@ public class TestDataTypeUtils {
         assertEquals(ts.getTime(), sDate.getTime(), "Times didn't match");
     }
 
+    @Test
+    public void testIntDoubleWiderType() {
+        assertEquals(Optional.of(RecordFieldType.DOUBLE.getDataType()), 
DataTypeUtils.getWiderType(RecordFieldType.INT.getDataType(), 
RecordFieldType.DOUBLE.getDataType()));
+        assertEquals(Optional.of(RecordFieldType.DOUBLE.getDataType()), 
DataTypeUtils.getWiderType(RecordFieldType.DOUBLE.getDataType(), 
RecordFieldType.INT.getDataType()));
+    }
+
     /*
      * This was a bug in NiFi 1.8 where converting from a Timestamp to a Date 
with the record path API
      * would throw an exception.
diff --git 
a/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/util/TestDataTypeSet.java
 
b/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/util/TestDataTypeSet.java
new file mode 100644
index 0000000000..14800688da
--- /dev/null
+++ 
b/nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/util/TestDataTypeSet.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.nifi.serialization.record.util;
+
+import org.apache.nifi.serialization.record.RecordFieldType;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class TestDataTypeSet {
+
+    @Test
+    public void testCombineNarrowThenWider() {
+        final DataTypeSet set = new DataTypeSet();
+        set.add(RecordFieldType.INT.getDataType());
+        set.add(RecordFieldType.DOUBLE.getDataType());
+        
assertEquals(Collections.singletonList(RecordFieldType.DOUBLE.getDataType()), 
set.getTypes());
+    }
+
+    @Test
+    public void testAddIncompatible() {
+        final DataTypeSet set = new DataTypeSet();
+        set.add(RecordFieldType.INT.getDataType());
+        set.add(RecordFieldType.BOOLEAN.getDataType());
+        assertEquals(Arrays.asList(RecordFieldType.INT.getDataType(), 
RecordFieldType.BOOLEAN.getDataType()), set.getTypes());
+
+    }
+
+    @Test
+    public void addSingleType() {
+        final DataTypeSet set = new DataTypeSet();
+        set.add(RecordFieldType.INT.getDataType());
+        
assertEquals(Collections.singletonList(RecordFieldType.INT.getDataType()), 
set.getTypes());
+
+    }
+
+    @Test
+    public void testCombineWiderThenNarrow() {
+        final DataTypeSet set = new DataTypeSet();
+        set.add(RecordFieldType.DOUBLE.getDataType());
+        set.add(RecordFieldType.INT.getDataType());
+        
assertEquals(Collections.singletonList(RecordFieldType.DOUBLE.getDataType()), 
set.getTypes());
+    }
+
+    @Test
+    public void testAddNothing() {
+        final DataTypeSet set = new DataTypeSet();
+        assertEquals(Collections.emptyList(), set.getTypes());
+    }
+}
diff --git 
a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/test/java/org/apache/nifi/schema/validation/TestStandardSchemaValidator.java
 
b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/test/java/org/apache/nifi/schema/validation/TestStandardSchemaValidator.java
index 17a109bba1..baa1063500 100644
--- 
a/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/test/java/org/apache/nifi/schema/validation/TestStandardSchemaValidator.java
+++ 
b/nifi-nar-bundles/nifi-extension-utils/nifi-record-utils/nifi-standard-record-utils/src/test/java/org/apache/nifi/schema/validation/TestStandardSchemaValidator.java
@@ -193,11 +193,6 @@ public class TestStandardSchemaValidator {
         whenValueIsAcceptedAsDataTypeThenConsideredAsValid(Integer.MAX_VALUE, 
RecordFieldType.DECIMAL);
     }
 
-    @Test
-    public void testIntegerOutsideRangeIsConsideredAsInvalid() {
-        
whenValueIsNotAcceptedAsDataTypeThenConsideredAsInvalid(MAX_PRECISE_WHOLE_IN_FLOAT.intValue()
 + 1, RecordFieldType.FLOAT);
-        // Double handles integer completely
-    }
 
     @Test
     public void testLongWithinRangeIsConsideredToBeValidFloatingPoint() {
@@ -206,12 +201,6 @@ public class TestStandardSchemaValidator {
         whenValueIsAcceptedAsDataTypeThenConsideredAsValid(Long.MAX_VALUE, 
RecordFieldType.DECIMAL);
     }
 
-    @Test
-    public void testLongOutsideRangeIsConsideredAsInvalid() {
-        
whenValueIsNotAcceptedAsDataTypeThenConsideredAsInvalid(MAX_PRECISE_WHOLE_IN_FLOAT
 + 1, RecordFieldType.FLOAT);
-        
whenValueIsNotAcceptedAsDataTypeThenConsideredAsInvalid(MAX_PRECISE_WHOLE_IN_DOUBLE
 + 1, RecordFieldType.DOUBLE);
-    }
-
     @Test
     public void testBigintWithinRangeIsConsideredToBeValidFloatingPoint() {
         
whenValueIsAcceptedAsDataTypeThenConsideredAsValid(BigInteger.valueOf(5L), 
RecordFieldType.FLOAT);
diff --git 
a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/queryrecord/FlowFileEnumerator.java
 
b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/queryrecord/FlowFileEnumerator.java
index db66c5a80e..1a0d2ba2ec 100644
--- 
a/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/queryrecord/FlowFileEnumerator.java
+++ 
b/nifi-nar-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/queryrecord/FlowFileEnumerator.java
@@ -120,17 +120,17 @@ public class FlowFileEnumerator implements 
Enumerator<Object> {
         return filtered;
     }
 
-    private Object cast(Object o) {
-        if (o == null) {
+    private Object cast(final Object toCast) {
+        if (toCast == null) {
             return null;
-        } else if (o.getClass().isArray()) {
-            List<Object> l = new ArrayList(Array.getLength(o));
-            for (int i = 0; i < Array.getLength(o); i++) {
-                l.add(Array.get(o, i));
+        } else if (toCast.getClass().isArray()) {
+            final List<Object> list = new ArrayList<>(Array.getLength(toCast));
+            for (int i = 0; i < Array.getLength(toCast); i++) {
+                list.add(Array.get(toCast, i));
             }
-            return l;
+            return list;
         } else {
-            return o;
+            return toCast;
         }
     }
 
diff --git 
a/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/main/java/org/apache/nifi/schema/inference/FieldTypeInference.java
 
b/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/main/java/org/apache/nifi/schema/inference/FieldTypeInference.java
index 1f52cb8357..a4186eef60 100644
--- 
a/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/main/java/org/apache/nifi/schema/inference/FieldTypeInference.java
+++ 
b/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/main/java/org/apache/nifi/schema/inference/FieldTypeInference.java
@@ -34,7 +34,7 @@ public class FieldTypeInference {
     // unique value for the data type, and so this paradigm allows us to avoid 
the cost of creating
     // and using the HashSet.
     private DataType singleDataType = null;
-    private Set<DataType> possibleDataTypes = new HashSet<>();
+    private final Set<DataType> possibleDataTypes = new HashSet<>();
 
     public void addPossibleDataType(final DataType dataType) {
         if (dataType == null) {
@@ -73,17 +73,17 @@ public class FieldTypeInference {
             possibleDataTypes.add(singleDataType);
         }
 
-        for (DataType possibleDataType : possibleDataTypes) {
-            RecordFieldType possibleFieldType = 
possibleDataType.getFieldType();
+        for (final DataType possibleDataType : possibleDataTypes) {
+            final RecordFieldType possibleFieldType = 
possibleDataType.getFieldType();
             if (!possibleFieldType.equals(RecordFieldType.STRING) && 
possibleFieldType.isWiderThan(additionalFieldType)) {
                 return;
             }
         }
 
-        Iterator<DataType> possibleDataTypeIterator = 
possibleDataTypes.iterator();
+        final Iterator<DataType> possibleDataTypeIterator = 
possibleDataTypes.iterator();
         while (possibleDataTypeIterator.hasNext()) {
-            DataType possibleDataType = possibleDataTypeIterator.next();
-            RecordFieldType possibleFieldType = 
possibleDataType.getFieldType();
+            final DataType possibleDataType = possibleDataTypeIterator.next();
+            final RecordFieldType possibleFieldType = 
possibleDataType.getFieldType();
 
             if (!additionalFieldType.equals(RecordFieldType.STRING) && 
additionalFieldType.isWiderThan(possibleFieldType)) {
                 possibleDataTypeIterator.remove();
diff --git 
a/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/test/java/org/apache/nifi/schema/inference/TestFieldTypeInference.java
 
b/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/test/java/org/apache/nifi/schema/inference/TestFieldTypeInference.java
index ea2f470095..80b60c61d8 100644
--- 
a/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/test/java/org/apache/nifi/schema/inference/TestFieldTypeInference.java
+++ 
b/nifi-nar-bundles/nifi-standard-services/nifi-record-serialization-services-bundle/nifi-record-serialization-services/src/test/java/org/apache/nifi/schema/inference/TestFieldTypeInference.java
@@ -20,11 +20,13 @@ import org.apache.nifi.serialization.SimpleRecordSchema;
 import org.apache.nifi.serialization.record.DataType;
 import org.apache.nifi.serialization.record.RecordField;
 import org.apache.nifi.serialization.record.RecordFieldType;
+import org.apache.nifi.serialization.record.RecordSchema;
 import org.apache.nifi.serialization.record.type.ChoiceDataType;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -40,6 +42,25 @@ public class TestFieldTypeInference {
         testSubject = new FieldTypeInference();
     }
 
+    @Test
+    public void testIntegerCombinedWithDouble() {
+        final FieldTypeInference inference = new FieldTypeInference();
+        inference.addPossibleDataType(RecordFieldType.INT.getDataType());
+        inference.addPossibleDataType(RecordFieldType.DOUBLE.getDataType());
+
+        assertEquals(RecordFieldType.DOUBLE.getDataType(), 
inference.toDataType());
+    }
+
+    @Test
+    public void testIntegerCombinedWithFloat() {
+        final FieldTypeInference inference = new FieldTypeInference();
+        inference.addPossibleDataType(RecordFieldType.INT.getDataType());
+        inference.addPossibleDataType(RecordFieldType.FLOAT.getDataType());
+
+        assertEquals(RecordFieldType.FLOAT.getDataType(), 
inference.toDataType());
+    }
+
+
     @Test
     public void testToDataTypeWith_SHORT_INT_LONG_shouldReturn_LONG() {
         // GIVEN
@@ -58,20 +79,13 @@ public class TestFieldTypeInference {
 
     @Test
     public void testToDataTypeWith_INT_FLOAT_ShouldReturn_INT_FLOAT() {
-        // GIVEN
-        List<DataType> dataTypes = Arrays.asList(
+        final List<DataType> dataTypes = Arrays.asList(
                 RecordFieldType.INT.getDataType(),
                 RecordFieldType.FLOAT.getDataType()
         );
 
-        Set<DataType> expected = new HashSet<>(Arrays.asList(
-                RecordFieldType.INT.getDataType(),
-                RecordFieldType.FLOAT.getDataType()
-        ));
-
-        // WHEN
-        // THEN
-        runWithAllPermutations(this::testToDataTypeShouldReturnChoice, 
dataTypes, expected);
+        final DataType expected = RecordFieldType.FLOAT.getDataType();
+        runWithAllPermutations(this::testToDataTypeShouldReturnSingleType, 
dataTypes, expected);
     }
 
     @Test
@@ -94,52 +108,39 @@ public class TestFieldTypeInference {
     }
 
     @Test
-    public void 
testToDataTypeWith_INT_FLOAT_STRING_shouldReturn_INT_FLOAT_STRING() {
-        // GIVEN
-        List<DataType> dataTypes = Arrays.asList(
+    public void 
testToDataTypeWith_INT_FLOAT_STRING_shouldReturn_FLOAT_STRING() {
+        final List<DataType> dataTypes = Arrays.asList(
                 RecordFieldType.INT.getDataType(),
                 RecordFieldType.FLOAT.getDataType(),
                 RecordFieldType.STRING.getDataType()
         );
 
-        Set<DataType> expected = new HashSet<>(Arrays.asList(
-                RecordFieldType.INT.getDataType(),
+        final Set<DataType> expected = new HashSet<>(Arrays.asList(
                 RecordFieldType.FLOAT.getDataType(),
                 RecordFieldType.STRING.getDataType()
         ));
 
-        // WHEN
-        // THEN
         runWithAllPermutations(this::testToDataTypeShouldReturnChoice, 
dataTypes, expected);
     }
 
     @Test
     public void testToDataTypeWithMultipleRecord() {
-        // GIVEN
-        String fieldName = "fieldName";
-        DataType fieldType1 = RecordFieldType.INT.getDataType();
-        DataType fieldType2 = RecordFieldType.FLOAT.getDataType();
-        DataType fieldType3 = RecordFieldType.STRING.getDataType();
-
-        List<DataType> dataTypes = Arrays.asList(
-                
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, 
fieldType1)),
-                
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, 
fieldType2)),
-                
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, 
fieldType3)),
-                
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, 
fieldType2))
+        final String fieldName = "fieldName";
+        final DataType intType = RecordFieldType.INT.getDataType();
+        final DataType floatType = RecordFieldType.FLOAT.getDataType();
+        final DataType stringType = RecordFieldType.STRING.getDataType();
+
+        final List<DataType> dataTypes = Arrays.asList(
+            
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, 
intType)),
+            
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, 
floatType)),
+            
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, 
stringType)),
+            
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(fieldName, 
floatType))
         );
 
-        DataType expected = 
RecordFieldType.RECORD.getRecordDataType(createRecordSchema(
-                fieldName,
-                RecordFieldType.CHOICE.getChoiceDataType(
-                        fieldType1,
-                        fieldType2,
-                        fieldType3
-                )
-        ));
+        final RecordSchema expectedSchema = createRecordSchema(fieldName, 
RecordFieldType.CHOICE.getChoiceDataType(floatType, stringType));
+        final DataType expecteDataType = 
RecordFieldType.RECORD.getRecordDataType(expectedSchema);
 
-        // WHEN
-        // THEN
-        runWithAllPermutations(this::testToDataTypeShouldReturnSingleType, 
dataTypes, expected);
+        runWithAllPermutations(this::testToDataTypeShouldReturnSingleType, 
dataTypes, expecteDataType);
     }
 
     @Test
@@ -192,8 +193,8 @@ public class TestFieldTypeInference {
     }
 
     private SimpleRecordSchema createRecordSchema(String fieldName, DataType 
fieldType) {
-        return new SimpleRecordSchema(Arrays.asList(
-                new RecordField(fieldName, fieldType)
+        return new SimpleRecordSchema(Collections.singletonList(
+            new RecordField(fieldName, fieldType)
         ));
     }
 
@@ -202,28 +203,18 @@ public class TestFieldTypeInference {
     }
 
     private Void testToDataTypeShouldReturnChoice(List<DataType> dataTypes, 
Set<DataType> expected) {
-        // GIVEN
         dataTypes.forEach(testSubject::addPossibleDataType);
 
-        // WHEN
         DataType actual = testSubject.toDataType();
-
-        // THEN
         assertEquals(expected, new HashSet<>(((ChoiceDataType) 
actual).getPossibleSubTypes()));
-
         return null;
     }
 
     private Void testToDataTypeShouldReturnSingleType(List<DataType> 
dataTypes, DataType expected) {
-        // GIVEN
         dataTypes.forEach(testSubject::addPossibleDataType);
 
-        // WHEN
         DataType actual = testSubject.toDataType();
-
-        // THEN
         assertEquals(expected, actual);
-
         return null;
     }
 }

Reply via email to