exceptionfactory commented on a change in pull request #5388:
URL: https://github.com/apache/nifi/pull/5388#discussion_r713522228
##########
File path:
nifi-commons/nifi-record/src/main/java/org/apache/nifi/serialization/record/ResultSetRecordSet.java
##########
@@ -298,7 +272,67 @@ private DataType getDataType(final int sqlType, final
ResultSet rs, final int co
}
}
- private static DataType getArrayBaseType(final Array array) throws
SQLException {
+ private DataType determineDataTypeToReturn(final DataType dataType, final
boolean useLogicalTypes) {
+ RecordFieldType fieldType = dataType.getFieldType();
+ if (!useLogicalTypes
+ && (fieldType == RecordFieldType.DECIMAL
+ || fieldType == RecordFieldType.DATE
+ || fieldType == RecordFieldType.TIME
+ || fieldType == RecordFieldType.TIMESTAMP)) {
+ return RecordFieldType.STRING.getDataType();
+ } else {
+ return dataType;
+ }
+ }
+
+ private DataType getArrayDataType(final ResultSet rs, final int
columnIndex, final boolean useLogicalTypes) throws SQLException {
+ // The JDBC API does not allow us to know what the base type of an
array is through the metadata.
+ // As a result, we have to obtain the actual Array for this record.
Once we have this, we can determine
+ // the base type. However, if the base type is, itself, an array, we
will simply return a base type of
+ // String because otherwise, we need the ResultSet for the array
itself, and many JDBC Drivers do not
+ // support calling Array.getResultSet() and will throw an Exception if
that is not supported.
+ try {
+ final Array array = rs.getArray(columnIndex);
+
+ if (array == null) {
+ return
RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.STRING.getDataType());
+ }
+ final DataType baseType = getArrayBaseType(array, useLogicalTypes);
+ return RecordFieldType.ARRAY.getArrayDataType(baseType);
+ } catch (SQLFeatureNotSupportedException sfnse) {
+ return
RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.STRING.getDataType());
+ }
+ }
+
+ private DataType getDecimalDataType(final ResultSet rs, final int
columnIndex) throws SQLException {
+ int decimalPrecision;
+ final int decimalScale;
+ final int resultSetPrecision =
rs.getMetaData().getPrecision(columnIndex);
+ final int resultSetScale = rs.getMetaData().getScale(columnIndex);
+ if (rs.getMetaData().getPrecision(columnIndex) > 0) {
Review comment:
Is there a reason for calling `rs.getMetaData().getPrecision()` again as
opposed to using `resultSetPrecision`?
```suggestion
if (resultSetPrecision > 0) {
```
##########
File path:
nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/ResultSetRecordSetTest.java
##########
@@ -283,41 +544,235 @@ private ResultSet
givenResultSetForArrayThrowsException(boolean featureSupported
return resultSet;
}
- private ResultSet givenResultSetForOther() throws SQLException {
+ private ResultSet givenResultSetForOther(List<RecordField> fields) throws
SQLException {
final ResultSet resultSet = Mockito.mock(ResultSet.class);
final ResultSetMetaData resultSetMetaData =
Mockito.mock(ResultSetMetaData.class);
when(resultSet.getMetaData()).thenReturn(resultSetMetaData);
- when(resultSetMetaData.getColumnCount()).thenReturn(1);
- when(resultSetMetaData.getColumnLabel(1)).thenReturn("column");
- when(resultSetMetaData.getColumnName(1)).thenReturn("column");
- when(resultSetMetaData.getColumnType(1)).thenReturn(Types.OTHER);
+ when(resultSetMetaData.getColumnCount()).thenReturn(fields.size());
+ for (int i = 0; i < fields.size(); ++i) {
+ int columnIndex = i + 1;
+
when(resultSetMetaData.getColumnLabel(columnIndex)).thenReturn(fields.get(i).getFieldName());
+
when(resultSetMetaData.getColumnName(columnIndex)).thenReturn(fields.get(i).getFieldName());
+
when(resultSetMetaData.getColumnType(columnIndex)).thenReturn(Types.OTHER);
+ }
return resultSet;
}
- private RecordSchema givenRecordSchema() {
- final List<RecordField> fields = new ArrayList<>();
+ private Record givenInputRecord() {
+ List<RecordField> inputRecordFields = new ArrayList<>(2);
+ inputRecordFields.add(new RecordField("id",
RecordFieldType.INT.getDataType()));
+ inputRecordFields.add(new RecordField("name",
RecordFieldType.STRING.getDataType()));
+ RecordSchema inputRecordSchema = new
SimpleRecordSchema(inputRecordFields);
+
+ Map<String, Object> inputRecordData = new HashMap<>(2);
+ inputRecordData.put("id", 1);
+ inputRecordData.put("name", "John");
+
+ return new MapRecord(inputRecordSchema, inputRecordData);
+ }
- for (final Object[] column : COLUMNS) {
- fields.add(new RecordField((String) column[1], (DataType)
column[3]));
+ private List<RecordField> givenFieldsThatAreOfTypeRecord(List<Record>
concreteRecords) {
+ List<RecordField> fields = new ArrayList<>(concreteRecords.size());
+ int i = 1;
+ for (Record record : concreteRecords) {
+ fields.add(new RecordField("record" + String.valueOf(i),
RecordFieldType.RECORD.getRecordDataType(record.getSchema())));
+ ++i;
}
+ return fields;
+ }
- return new SimpleRecordSchema(fields);
+ private List<RecordField> whenSchemaFieldsAreSetupForArrayType(final
List<ArrayTestData> testData,
+ final
ResultSet resultSet,
+ final
ResultSetMetaData resultSetMetaData)
+ throws SQLException {
+ List<RecordField> fields = new ArrayList<>();
+ for (int i = 0; i < testData.size(); ++i) {
+ ArrayTestData testDatum = testData.get(i);
+ int columnIndex = i + 1;
+ SqlArrayDummy arrayDummy = Mockito.mock(SqlArrayDummy.class);
+ when(arrayDummy.getArray()).thenReturn(testDatum.getTestArray());
+ when(resultSet.getArray(columnIndex)).thenReturn(arrayDummy);
+
when(resultSetMetaData.getColumnLabel(columnIndex)).thenReturn(testDatum.getFieldName());
+
when(resultSetMetaData.getColumnType(columnIndex)).thenReturn(Types.ARRAY);
+ fields.add(new RecordField(testDatum.getFieldName(),
RecordFieldType.ARRAY.getDataType()));
+ }
+ return fields;
+ }
+
+ private void thenAllDataTypesMatchInputFieldType(final List<RecordField>
inputFields, final RecordSchema resultSchema) {
+ assertEquals("The number of input fields does not match the number of
fields in the result schema.", inputFields.size(),
resultSchema.getFieldCount());
+ for (int i = 0; i < inputFields.size(); ++i) {
+ assertEquals(inputFields.get(i).getDataType(),
resultSchema.getField(i).getDataType());
+ }
+ }
+
+ private void thenAllDataTypesAreString(final RecordSchema resultSchema) {
+ for (int i = 0; i < resultSchema.getFieldCount(); ++i) {
+ assertEquals(RecordFieldType.STRING.getDataType(),
resultSchema.getField(i).getDataType());
+ }
}
- private void thenAllColumnDataTypesAreCorrect(final RecordSchema
resultSchema) {
- assertNotNull(resultSchema);
+ private void thenAllColumnDataTypesAreCorrect(TestColumn[] columns,
RecordSchema expectedSchema, RecordSchema actualSchema) {
+ assertNotNull(actualSchema);
- for (final Object[] column : COLUMNS) {
+ for (TestColumn column : columns) {
+ int fieldIndex = column.getIndex() - 1;
// The DECIMAL column with scale larger than precision will not
match so verify that instead
- DataType actualDataType = resultSchema.getField((Integer)
column[0] - 1).getDataType();
- DataType expectedDataType = (DataType) column[3];
+ DataType actualDataType =
actualSchema.getField(fieldIndex).getDataType();
+ DataType expectedDataType =
expectedSchema.getField(fieldIndex).getDataType();
if
(expectedDataType.equals(RecordFieldType.DECIMAL.getDecimalDataType(3, 10))) {
DecimalDataType decimalDataType = (DecimalDataType)
expectedDataType;
if (decimalDataType.getScale() >
decimalDataType.getPrecision()) {
expectedDataType =
RecordFieldType.DECIMAL.getDecimalDataType(decimalDataType.getScale(),
decimalDataType.getScale());
}
}
- assertEquals("For column " + column[0] + " the converted type is
not matching", expectedDataType, actualDataType);
+ assertEquals("For column " + column.getIndex() + " the converted
type is not matching", expectedDataType, actualDataType);
+ }
+ }
+
+ private void thenActualArrayElementTypesMatchExpected(Map<String,
DataType> expectedTypes, RecordSchema actualSchema) {
+ for (RecordField recordField : actualSchema.getFields()) {
+ if (recordField.getDataType() instanceof ArrayDataType) {
+ ArrayDataType arrayType = (ArrayDataType)
recordField.getDataType();
+ if
(!arrayType.getElementType().equals(expectedTypes.get(recordField.getFieldName())))
{
+ throw new AssertionError("Array element type for " +
recordField.getFieldName()
+ + " is not of expected type " +
expectedTypes.get(recordField.getFieldName()).toString());
+ }
+ } else {
+ throw new AssertionError("RecordField " +
recordField.getFieldName() + " is not instance of ArrayDataType");
+ }
+ }
+ }
+
+ private void thenAllDataTypesAreChoice(final List<RecordField>
inputFields, final RecordSchema resultSchema) {
+ assertEquals("The number of input fields does not match the number of
fields in the result schema.", inputFields.size(),
resultSchema.getFieldCount());
+
+ DataType expectedType = getBroadestChoiceDataType();
+ for (int i = 0; i < inputFields.size(); ++i) {
+ assertEquals(expectedType, resultSchema.getField(i).getDataType());
+ }
+ }
+
+ private DataType getBroadestChoiceDataType() {
+ List<DataType> dataTypes = Stream.of(RecordFieldType.BIGINT,
RecordFieldType.BOOLEAN, RecordFieldType.BYTE, RecordFieldType.CHAR,
RecordFieldType.DATE,
+ RecordFieldType.DECIMAL, RecordFieldType.DOUBLE,
RecordFieldType.FLOAT, RecordFieldType.INT, RecordFieldType.LONG,
RecordFieldType.SHORT, RecordFieldType.STRING,
+ RecordFieldType.TIME, RecordFieldType.TIMESTAMP)
+ .map(RecordFieldType::getDataType)
+ .collect(Collectors.toList());
+ return RecordFieldType.CHOICE.getChoiceDataType(dataTypes);
+ }
+
+ private static class TestColumn {
+ private final int index; // Column indexing starts from 1, not 0.
+ private final String columnName;
+ private final int sqlType;
+ private final DataType recordFieldType;
+
+ public TestColumn(final int index, final String columnName, final int
sqlType, final DataType recordFieldType) {
+ this.index = index;
+ this.columnName = columnName;
+ this.sqlType = sqlType;
+ this.recordFieldType = recordFieldType;
+ }
+
+ public int getIndex() {
+ return index;
+ }
+
+ public String getColumnName() {
+ return columnName;
+ }
+
+ public int getSqlType() {
+ return sqlType;
+ }
+
+ public DataType getRecordFieldType() {
+ return recordFieldType;
+ }
+ }
+
+ private static class SqlArrayDummy implements Array {
Review comment:
Recommend renaming to `ResultSqlArray`:
```suggestion
private static class ResultSqlArray implements Array {
```
##########
File path:
nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/ResultSetRecordSetTest.java
##########
@@ -283,41 +544,235 @@ private ResultSet
givenResultSetForArrayThrowsException(boolean featureSupported
return resultSet;
}
- private ResultSet givenResultSetForOther() throws SQLException {
+ private ResultSet givenResultSetForOther(List<RecordField> fields) throws
SQLException {
final ResultSet resultSet = Mockito.mock(ResultSet.class);
final ResultSetMetaData resultSetMetaData =
Mockito.mock(ResultSetMetaData.class);
when(resultSet.getMetaData()).thenReturn(resultSetMetaData);
- when(resultSetMetaData.getColumnCount()).thenReturn(1);
- when(resultSetMetaData.getColumnLabel(1)).thenReturn("column");
- when(resultSetMetaData.getColumnName(1)).thenReturn("column");
- when(resultSetMetaData.getColumnType(1)).thenReturn(Types.OTHER);
+ when(resultSetMetaData.getColumnCount()).thenReturn(fields.size());
+ for (int i = 0; i < fields.size(); ++i) {
+ int columnIndex = i + 1;
+
when(resultSetMetaData.getColumnLabel(columnIndex)).thenReturn(fields.get(i).getFieldName());
+
when(resultSetMetaData.getColumnName(columnIndex)).thenReturn(fields.get(i).getFieldName());
+
when(resultSetMetaData.getColumnType(columnIndex)).thenReturn(Types.OTHER);
+ }
return resultSet;
}
- private RecordSchema givenRecordSchema() {
- final List<RecordField> fields = new ArrayList<>();
+ private Record givenInputRecord() {
+ List<RecordField> inputRecordFields = new ArrayList<>(2);
+ inputRecordFields.add(new RecordField("id",
RecordFieldType.INT.getDataType()));
+ inputRecordFields.add(new RecordField("name",
RecordFieldType.STRING.getDataType()));
+ RecordSchema inputRecordSchema = new
SimpleRecordSchema(inputRecordFields);
+
+ Map<String, Object> inputRecordData = new HashMap<>(2);
+ inputRecordData.put("id", 1);
+ inputRecordData.put("name", "John");
+
+ return new MapRecord(inputRecordSchema, inputRecordData);
+ }
- for (final Object[] column : COLUMNS) {
- fields.add(new RecordField((String) column[1], (DataType)
column[3]));
+ private List<RecordField> givenFieldsThatAreOfTypeRecord(List<Record>
concreteRecords) {
+ List<RecordField> fields = new ArrayList<>(concreteRecords.size());
+ int i = 1;
+ for (Record record : concreteRecords) {
+ fields.add(new RecordField("record" + String.valueOf(i),
RecordFieldType.RECORD.getRecordDataType(record.getSchema())));
+ ++i;
}
+ return fields;
+ }
- return new SimpleRecordSchema(fields);
+ private List<RecordField> whenSchemaFieldsAreSetupForArrayType(final
List<ArrayTestData> testData,
+ final
ResultSet resultSet,
+ final
ResultSetMetaData resultSetMetaData)
+ throws SQLException {
+ List<RecordField> fields = new ArrayList<>();
+ for (int i = 0; i < testData.size(); ++i) {
+ ArrayTestData testDatum = testData.get(i);
+ int columnIndex = i + 1;
+ SqlArrayDummy arrayDummy = Mockito.mock(SqlArrayDummy.class);
+ when(arrayDummy.getArray()).thenReturn(testDatum.getTestArray());
+ when(resultSet.getArray(columnIndex)).thenReturn(arrayDummy);
+
when(resultSetMetaData.getColumnLabel(columnIndex)).thenReturn(testDatum.getFieldName());
+
when(resultSetMetaData.getColumnType(columnIndex)).thenReturn(Types.ARRAY);
+ fields.add(new RecordField(testDatum.getFieldName(),
RecordFieldType.ARRAY.getDataType()));
+ }
+ return fields;
+ }
+
+ private void thenAllDataTypesMatchInputFieldType(final List<RecordField>
inputFields, final RecordSchema resultSchema) {
+ assertEquals("The number of input fields does not match the number of
fields in the result schema.", inputFields.size(),
resultSchema.getFieldCount());
+ for (int i = 0; i < inputFields.size(); ++i) {
+ assertEquals(inputFields.get(i).getDataType(),
resultSchema.getField(i).getDataType());
+ }
+ }
+
+ private void thenAllDataTypesAreString(final RecordSchema resultSchema) {
+ for (int i = 0; i < resultSchema.getFieldCount(); ++i) {
+ assertEquals(RecordFieldType.STRING.getDataType(),
resultSchema.getField(i).getDataType());
+ }
}
- private void thenAllColumnDataTypesAreCorrect(final RecordSchema
resultSchema) {
- assertNotNull(resultSchema);
+ private void thenAllColumnDataTypesAreCorrect(TestColumn[] columns,
RecordSchema expectedSchema, RecordSchema actualSchema) {
+ assertNotNull(actualSchema);
- for (final Object[] column : COLUMNS) {
+ for (TestColumn column : columns) {
+ int fieldIndex = column.getIndex() - 1;
// The DECIMAL column with scale larger than precision will not
match so verify that instead
- DataType actualDataType = resultSchema.getField((Integer)
column[0] - 1).getDataType();
- DataType expectedDataType = (DataType) column[3];
+ DataType actualDataType =
actualSchema.getField(fieldIndex).getDataType();
+ DataType expectedDataType =
expectedSchema.getField(fieldIndex).getDataType();
if
(expectedDataType.equals(RecordFieldType.DECIMAL.getDecimalDataType(3, 10))) {
DecimalDataType decimalDataType = (DecimalDataType)
expectedDataType;
if (decimalDataType.getScale() >
decimalDataType.getPrecision()) {
expectedDataType =
RecordFieldType.DECIMAL.getDecimalDataType(decimalDataType.getScale(),
decimalDataType.getScale());
}
}
- assertEquals("For column " + column[0] + " the converted type is
not matching", expectedDataType, actualDataType);
+ assertEquals("For column " + column.getIndex() + " the converted
type is not matching", expectedDataType, actualDataType);
+ }
+ }
+
+ private void thenActualArrayElementTypesMatchExpected(Map<String,
DataType> expectedTypes, RecordSchema actualSchema) {
+ for (RecordField recordField : actualSchema.getFields()) {
+ if (recordField.getDataType() instanceof ArrayDataType) {
+ ArrayDataType arrayType = (ArrayDataType)
recordField.getDataType();
+ if
(!arrayType.getElementType().equals(expectedTypes.get(recordField.getFieldName())))
{
+ throw new AssertionError("Array element type for " +
recordField.getFieldName()
+ + " is not of expected type " +
expectedTypes.get(recordField.getFieldName()).toString());
+ }
Review comment:
Is there a reason this cannot be changed to use `assertEquals()`, or
perhaps `fail()` if none of the available assertions fits?
##########
File path:
nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/ResultSetRecordSetTest.java
##########
@@ -283,41 +544,235 @@ private ResultSet
givenResultSetForArrayThrowsException(boolean featureSupported
return resultSet;
}
- private ResultSet givenResultSetForOther() throws SQLException {
+ private ResultSet givenResultSetForOther(List<RecordField> fields) throws
SQLException {
final ResultSet resultSet = Mockito.mock(ResultSet.class);
final ResultSetMetaData resultSetMetaData =
Mockito.mock(ResultSetMetaData.class);
when(resultSet.getMetaData()).thenReturn(resultSetMetaData);
- when(resultSetMetaData.getColumnCount()).thenReturn(1);
- when(resultSetMetaData.getColumnLabel(1)).thenReturn("column");
- when(resultSetMetaData.getColumnName(1)).thenReturn("column");
- when(resultSetMetaData.getColumnType(1)).thenReturn(Types.OTHER);
+ when(resultSetMetaData.getColumnCount()).thenReturn(fields.size());
+ for (int i = 0; i < fields.size(); ++i) {
+ int columnIndex = i + 1;
+
when(resultSetMetaData.getColumnLabel(columnIndex)).thenReturn(fields.get(i).getFieldName());
+
when(resultSetMetaData.getColumnName(columnIndex)).thenReturn(fields.get(i).getFieldName());
+
when(resultSetMetaData.getColumnType(columnIndex)).thenReturn(Types.OTHER);
+ }
return resultSet;
}
- private RecordSchema givenRecordSchema() {
- final List<RecordField> fields = new ArrayList<>();
+ private Record givenInputRecord() {
+ List<RecordField> inputRecordFields = new ArrayList<>(2);
+ inputRecordFields.add(new RecordField("id",
RecordFieldType.INT.getDataType()));
+ inputRecordFields.add(new RecordField("name",
RecordFieldType.STRING.getDataType()));
+ RecordSchema inputRecordSchema = new
SimpleRecordSchema(inputRecordFields);
+
+ Map<String, Object> inputRecordData = new HashMap<>(2);
+ inputRecordData.put("id", 1);
+ inputRecordData.put("name", "John");
+
+ return new MapRecord(inputRecordSchema, inputRecordData);
+ }
- for (final Object[] column : COLUMNS) {
- fields.add(new RecordField((String) column[1], (DataType)
column[3]));
+ private List<RecordField> givenFieldsThatAreOfTypeRecord(List<Record>
concreteRecords) {
+ List<RecordField> fields = new ArrayList<>(concreteRecords.size());
+ int i = 1;
+ for (Record record : concreteRecords) {
+ fields.add(new RecordField("record" + String.valueOf(i),
RecordFieldType.RECORD.getRecordDataType(record.getSchema())));
+ ++i;
}
+ return fields;
+ }
- return new SimpleRecordSchema(fields);
+ private List<RecordField> whenSchemaFieldsAreSetupForArrayType(final
List<ArrayTestData> testData,
+ final
ResultSet resultSet,
+ final
ResultSetMetaData resultSetMetaData)
+ throws SQLException {
+ List<RecordField> fields = new ArrayList<>();
+ for (int i = 0; i < testData.size(); ++i) {
+ ArrayTestData testDatum = testData.get(i);
+ int columnIndex = i + 1;
+ SqlArrayDummy arrayDummy = Mockito.mock(SqlArrayDummy.class);
+ when(arrayDummy.getArray()).thenReturn(testDatum.getTestArray());
+ when(resultSet.getArray(columnIndex)).thenReturn(arrayDummy);
+
when(resultSetMetaData.getColumnLabel(columnIndex)).thenReturn(testDatum.getFieldName());
+
when(resultSetMetaData.getColumnType(columnIndex)).thenReturn(Types.ARRAY);
+ fields.add(new RecordField(testDatum.getFieldName(),
RecordFieldType.ARRAY.getDataType()));
+ }
+ return fields;
+ }
+
+ private void thenAllDataTypesMatchInputFieldType(final List<RecordField>
inputFields, final RecordSchema resultSchema) {
+ assertEquals("The number of input fields does not match the number of
fields in the result schema.", inputFields.size(),
resultSchema.getFieldCount());
+ for (int i = 0; i < inputFields.size(); ++i) {
+ assertEquals(inputFields.get(i).getDataType(),
resultSchema.getField(i).getDataType());
+ }
+ }
+
+ private void thenAllDataTypesAreString(final RecordSchema resultSchema) {
+ for (int i = 0; i < resultSchema.getFieldCount(); ++i) {
+ assertEquals(RecordFieldType.STRING.getDataType(),
resultSchema.getField(i).getDataType());
+ }
}
- private void thenAllColumnDataTypesAreCorrect(final RecordSchema
resultSchema) {
- assertNotNull(resultSchema);
+ private void thenAllColumnDataTypesAreCorrect(TestColumn[] columns,
RecordSchema expectedSchema, RecordSchema actualSchema) {
+ assertNotNull(actualSchema);
- for (final Object[] column : COLUMNS) {
+ for (TestColumn column : columns) {
+ int fieldIndex = column.getIndex() - 1;
// The DECIMAL column with scale larger than precision will not
match so verify that instead
- DataType actualDataType = resultSchema.getField((Integer)
column[0] - 1).getDataType();
- DataType expectedDataType = (DataType) column[3];
+ DataType actualDataType =
actualSchema.getField(fieldIndex).getDataType();
+ DataType expectedDataType =
expectedSchema.getField(fieldIndex).getDataType();
if
(expectedDataType.equals(RecordFieldType.DECIMAL.getDecimalDataType(3, 10))) {
DecimalDataType decimalDataType = (DecimalDataType)
expectedDataType;
if (decimalDataType.getScale() >
decimalDataType.getPrecision()) {
expectedDataType =
RecordFieldType.DECIMAL.getDecimalDataType(decimalDataType.getScale(),
decimalDataType.getScale());
}
}
- assertEquals("For column " + column[0] + " the converted type is
not matching", expectedDataType, actualDataType);
+ assertEquals("For column " + column.getIndex() + " the converted
type is not matching", expectedDataType, actualDataType);
+ }
+ }
+
+ private void thenActualArrayElementTypesMatchExpected(Map<String,
DataType> expectedTypes, RecordSchema actualSchema) {
+ for (RecordField recordField : actualSchema.getFields()) {
+ if (recordField.getDataType() instanceof ArrayDataType) {
+ ArrayDataType arrayType = (ArrayDataType)
recordField.getDataType();
+ if
(!arrayType.getElementType().equals(expectedTypes.get(recordField.getFieldName())))
{
+ throw new AssertionError("Array element type for " +
recordField.getFieldName()
+ + " is not of expected type " +
expectedTypes.get(recordField.getFieldName()).toString());
+ }
+ } else {
+ throw new AssertionError("RecordField " +
recordField.getFieldName() + " is not instance of ArrayDataType");
+ }
+ }
+ }
+
+ private void thenAllDataTypesAreChoice(final List<RecordField>
inputFields, final RecordSchema resultSchema) {
+ assertEquals("The number of input fields does not match the number of
fields in the result schema.", inputFields.size(),
resultSchema.getFieldCount());
+
+ DataType expectedType = getBroadestChoiceDataType();
+ for (int i = 0; i < inputFields.size(); ++i) {
+ assertEquals(expectedType, resultSchema.getField(i).getDataType());
+ }
+ }
+
+ private DataType getBroadestChoiceDataType() {
+ List<DataType> dataTypes = Stream.of(RecordFieldType.BIGINT,
RecordFieldType.BOOLEAN, RecordFieldType.BYTE, RecordFieldType.CHAR,
RecordFieldType.DATE,
+ RecordFieldType.DECIMAL, RecordFieldType.DOUBLE,
RecordFieldType.FLOAT, RecordFieldType.INT, RecordFieldType.LONG,
RecordFieldType.SHORT, RecordFieldType.STRING,
+ RecordFieldType.TIME, RecordFieldType.TIMESTAMP)
+ .map(RecordFieldType::getDataType)
+ .collect(Collectors.toList());
+ return RecordFieldType.CHOICE.getChoiceDataType(dataTypes);
+ }
+
+ private static class TestColumn {
+ private final int index; // Column indexing starts from 1, not 0.
+ private final String columnName;
+ private final int sqlType;
+ private final DataType recordFieldType;
+
+ public TestColumn(final int index, final String columnName, final int
sqlType, final DataType recordFieldType) {
+ this.index = index;
+ this.columnName = columnName;
+ this.sqlType = sqlType;
+ this.recordFieldType = recordFieldType;
+ }
+
+ public int getIndex() {
+ return index;
+ }
+
+ public String getColumnName() {
+ return columnName;
+ }
+
+ public int getSqlType() {
+ return sqlType;
+ }
+
+ public DataType getRecordFieldType() {
+ return recordFieldType;
+ }
+ }
+
+ private static class SqlArrayDummy implements Array {
+
+ @Override
+ public String getBaseTypeName() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public int getBaseType() throws SQLException {
+ return 0;
+ }
+
+ @Override
+ public Object getArray() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public Object getArray(Map<String, Class<?>> map) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public Object getArray(long index, int count) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public Object getArray(long index, int count, Map<String, Class<?>>
map) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public ResultSet getResultSet() throws SQLException {
+ return null;
+ }
+
+ @Override
+ public ResultSet getResultSet(Map<String, Class<?>> map) throws
SQLException {
+ return null;
+ }
+
+ @Override
+ public ResultSet getResultSet(long index, int count) throws
SQLException {
+ return null;
+ }
+
+ @Override
+ public ResultSet getResultSet(long index, int count, Map<String,
Class<?>> map) throws SQLException {
+ return null;
+ }
+
+ @Override
+ public void free() throws SQLException {
+
+ }
+ }
+
+ private static class BigDecimalDummy extends BigDecimal {
Review comment:
Recommend renaming to `ResultBigDecimal`:
```suggestion
private static class ResultBigDecimal extends BigDecimal {
```
##########
File path:
nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/ResultSetRecordSetTest.java
##########
@@ -272,6 +364,175 @@ public void
testCreateSchemaArrayThrowsNotSupportedException() throws SQLExcepti
assertEquals(RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.STRING.getDataType()),
resultSchema.getField(0).getDataType());
}
+ @Test
+ public void testArrayTypeWithLogicalTypes() throws SQLException {
+ testArrayType(true);
+ }
+
+ @Test
+ public void testArrayTypeNoLogicalTypes() throws SQLException {
+ testArrayType(false);
+ }
+
+ @Test
+ public void testCreateSchemaWithLogicalTypes() throws SQLException {
+ testCreateSchemaLogicalTypes(true, true);
+ }
+
+ @Test
+ public void testCreateSchemaNoLogicalTypes() throws SQLException {
+ testCreateSchemaLogicalTypes(false, true);
+ }
+
+ @Test
+ public void testCreateSchemaWithLogicalTypesNoInputSchema() throws
SQLException {
+ testCreateSchemaLogicalTypes(true, false);
+ }
+
+ @Test
+ public void testCreateSchemaNoLogicalTypesNoInputSchema() throws
SQLException {
+ testCreateSchemaLogicalTypes(false, false);
+ }
+
+ private void testArrayType(boolean useLogicalTypes) throws SQLException {
+ // GIVEN
+ List<ArrayTestData> testData =
givenArrayTypesThatRequireLogicalTypes();
+ Map<String, DataType> expectedTypes =
givenExpectedTypesForArrayTypesThatRequireLogicalTypes(useLogicalTypes);
+
+ // WHEN
+ ResultSet resultSet = Mockito.mock(ResultSet.class);
+ ResultSetMetaData resultSetMetaData =
Mockito.mock(ResultSetMetaData.class);
+ when(resultSet.getMetaData()).thenReturn(resultSetMetaData);
+ when(resultSetMetaData.getColumnCount()).thenReturn(testData.size());
+
+ List<RecordField> fields =
whenSchemaFieldsAreSetupForArrayType(testData, resultSet, resultSetMetaData);
+ RecordSchema recordSchema = new SimpleRecordSchema(fields);
+
+ ResultSetRecordSet testSubject = new ResultSetRecordSet(resultSet,
recordSchema, 10,0, useLogicalTypes);
+ RecordSchema actualSchema = testSubject.getSchema();
+
+ // THEN
+ thenActualArrayElementTypesMatchExpected(expectedTypes, actualSchema);
+ }
+
+ private void testCreateSchemaLogicalTypes(boolean useLogicalTypes, boolean
provideInputSchema) throws SQLException {
+ // GIVEN
+ TestColumn[] columns = new TestColumn[]{
+ new TestColumn(1, COLUMN_NAME_DATE, Types.DATE,
RecordFieldType.DATE.getDataType()),
+ new TestColumn(2, "time", Types.TIME,
RecordFieldType.TIME.getDataType()),
+ new TestColumn(3, "time_with_timezone",
Types.TIME_WITH_TIMEZONE, RecordFieldType.TIME.getDataType()),
+ new TestColumn(4, "timestamp", Types.TIMESTAMP,
RecordFieldType.TIMESTAMP.getDataType()),
+ new TestColumn(5, "timestamp_with_timezone",
Types.TIMESTAMP_WITH_TIMEZONE, RecordFieldType.TIMESTAMP.getDataType()),
+ new TestColumn(6, COLUMN_NAME_BIG_DECIMAL_1,
Types.DECIMAL,RecordFieldType.DECIMAL.getDecimalDataType(7, 3)),
+ new TestColumn(7, COLUMN_NAME_BIG_DECIMAL_2, Types.NUMERIC,
RecordFieldType.DECIMAL.getDecimalDataType(4, 0)),
+ new TestColumn(8, COLUMN_NAME_BIG_DECIMAL_3,
Types.JAVA_OBJECT, RecordFieldType.DECIMAL.getDecimalDataType(501, 1)),
+ new TestColumn(9, COLUMN_NAME_BIG_DECIMAL_4, Types.DECIMAL,
RecordFieldType.DECIMAL.getDecimalDataType(10, 3)),
+ new TestColumn(10, COLUMN_NAME_BIG_DECIMAL_5, Types.DECIMAL,
RecordFieldType.DECIMAL.getDecimalDataType(3, 10)),
+ };
+ final RecordSchema recordSchema = provideInputSchema ?
givenRecordSchema(columns) : null;
+
+ ResultSetMetaData resultSetMetaData =
Mockito.mock(ResultSetMetaData.class);
+ ResultSet resultSet = Mockito.mock(ResultSet.class);
+
+ RecordSchema expectedSchema = useLogicalTypes ?
givenRecordSchema(columns) : givenRecordSchemaWithOnlyStringType(columns);
+
+ // WHEN
+ setUpMocks(columns, resultSetMetaData, resultSet);
+
+ ResultSetRecordSet testSubject = new ResultSetRecordSet(resultSet,
recordSchema, 10,0, useLogicalTypes);
+ RecordSchema actualSchema = testSubject.getSchema();
+
+ // THEN
+ thenAllColumnDataTypesAreCorrect(columns, expectedSchema,
actualSchema);
+ }
+
+ private void setUpMocks(TestColumn[] columns, ResultSetMetaData
resultSetMetaData, ResultSet resultSet) throws SQLException {
+ when(resultSet.getMetaData()).thenReturn(resultSetMetaData);
+ when(resultSetMetaData.getColumnCount()).thenReturn(columns.length);
+
+ int indexOfBigDecimal = -1;
+ int index = 0;
+ for (final TestColumn column : columns) {
+
when(resultSetMetaData.getColumnLabel(column.getIndex())).thenReturn(column.getColumnName());
+
when(resultSetMetaData.getColumnName(column.getIndex())).thenReturn(column.getColumnName());
+
when(resultSetMetaData.getColumnType(column.getIndex())).thenReturn(column.getSqlType());
+
+ if (column.getRecordFieldType() instanceof DecimalDataType) {
+ DecimalDataType ddt = (DecimalDataType)
column.getRecordFieldType();
+
when(resultSetMetaData.getPrecision(column.getIndex())).thenReturn(ddt.getPrecision());
+
when(resultSetMetaData.getScale(column.getIndex())).thenReturn(ddt.getScale());
+ }
+ if (column.getSqlType() == Types.JAVA_OBJECT) {
+ indexOfBigDecimal = index + 1;
+ }
+ ++index;
+ }
+
+ // Big decimal values are necessary in order to determine precision
and scale
+ when(resultSet.getBigDecimal(indexOfBigDecimal)).thenReturn(new
BigDecimal(String.join("", Collections.nCopies(500, "1")) + ".1"));
+
+ // This will be handled by a dedicated branch for Java Objects, needs
some further details
+
when(resultSetMetaData.getColumnClassName(indexOfBigDecimal)).thenReturn(BigDecimal.class.getName());
+ }
+
+ private List<RecordField> givenFieldsThatRequireLogicalTypes() {
+ final List<RecordField> fields = new ArrayList<>();
+ fields.add(new RecordField("decimal",
RecordFieldType.DECIMAL.getDecimalDataType(30, 10)));
+ fields.add(new RecordField("date",
RecordFieldType.DATE.getDataType()));
+ fields.add(new RecordField("time",
RecordFieldType.TIME.getDataType()));
+ fields.add(new RecordField("timestamp",
RecordFieldType.TIMESTAMP.getDataType()));
+ return fields;
+ }
+
+ private RecordSchema givenRecordSchema(TestColumn[] columns) {
+ final List<RecordField> fields = new ArrayList<>(columns.length);
+
+ for (TestColumn column : columns) {
+ fields.add(new RecordField(column.getColumnName(),
column.getRecordFieldType()));
+ }
+
+ return new SimpleRecordSchema(fields);
+ }
+
+ private RecordSchema givenRecordSchemaWithOnlyStringType(TestColumn[]
columns) {
+ final List<RecordField> fields = new ArrayList<>(columns.length);
+
+ for (TestColumn column : columns) {
+ fields.add(new RecordField(column.getColumnName(),
RecordFieldType.STRING.getDataType()));
+ }
+
+ return new SimpleRecordSchema(fields);
+ }
+
+ private List<ArrayTestData> givenArrayTypesThatRequireLogicalTypes() {
+ List<ArrayTestData> testData = new ArrayList<>();
+ testData.add(new ArrayTestData("arrayBigDecimal",
+ new BigDecimalDummy[]{new BigDecimalDummy(), new
BigDecimalDummy()}));
+ testData.add(new ArrayTestData("arrayDate",
+ new Date[]{new Date(1631809132516L), new
Date(1631809132516L)}));
+ testData.add(new ArrayTestData("arrayTime",
+ new Time[]{new Time(1631809132516L), new
Time(1631809132516L)}));
+ testData.add(new ArrayTestData("arrayTimestamp",
+ new Timestamp[]{new Timestamp(1631809132516L), new
Timestamp(1631809132516L)}));
Review comment:
It might be easier to declare a `static final long` constant for the
millisecond value instead of repeating the same number multiple times.
##########
File path:
nifi-commons/nifi-record/src/test/java/org/apache/nifi/serialization/record/ResultSetRecordSetTest.java
##########
@@ -283,41 +544,235 @@ private ResultSet
givenResultSetForArrayThrowsException(boolean featureSupported
return resultSet;
}
- private ResultSet givenResultSetForOther() throws SQLException {
+ private ResultSet givenResultSetForOther(List<RecordField> fields) throws
SQLException {
final ResultSet resultSet = Mockito.mock(ResultSet.class);
final ResultSetMetaData resultSetMetaData =
Mockito.mock(ResultSetMetaData.class);
when(resultSet.getMetaData()).thenReturn(resultSetMetaData);
- when(resultSetMetaData.getColumnCount()).thenReturn(1);
- when(resultSetMetaData.getColumnLabel(1)).thenReturn("column");
- when(resultSetMetaData.getColumnName(1)).thenReturn("column");
- when(resultSetMetaData.getColumnType(1)).thenReturn(Types.OTHER);
+ when(resultSetMetaData.getColumnCount()).thenReturn(fields.size());
+ for (int i = 0; i < fields.size(); ++i) {
+ int columnIndex = i + 1;
+
when(resultSetMetaData.getColumnLabel(columnIndex)).thenReturn(fields.get(i).getFieldName());
+
when(resultSetMetaData.getColumnName(columnIndex)).thenReturn(fields.get(i).getFieldName());
+
when(resultSetMetaData.getColumnType(columnIndex)).thenReturn(Types.OTHER);
+ }
return resultSet;
}
- private RecordSchema givenRecordSchema() {
- final List<RecordField> fields = new ArrayList<>();
+ private Record givenInputRecord() {
+ List<RecordField> inputRecordFields = new ArrayList<>(2);
+ inputRecordFields.add(new RecordField("id",
RecordFieldType.INT.getDataType()));
+ inputRecordFields.add(new RecordField("name",
RecordFieldType.STRING.getDataType()));
+ RecordSchema inputRecordSchema = new
SimpleRecordSchema(inputRecordFields);
+
+ Map<String, Object> inputRecordData = new HashMap<>(2);
+ inputRecordData.put("id", 1);
+ inputRecordData.put("name", "John");
+
+ return new MapRecord(inputRecordSchema, inputRecordData);
+ }
- for (final Object[] column : COLUMNS) {
- fields.add(new RecordField((String) column[1], (DataType)
column[3]));
+ private List<RecordField> givenFieldsThatAreOfTypeRecord(List<Record>
concreteRecords) {
+ List<RecordField> fields = new ArrayList<>(concreteRecords.size());
+ int i = 1;
+ for (Record record : concreteRecords) {
+ fields.add(new RecordField("record" + String.valueOf(i),
RecordFieldType.RECORD.getRecordDataType(record.getSchema())));
+ ++i;
}
+ return fields;
+ }
- return new SimpleRecordSchema(fields);
+ private List<RecordField> whenSchemaFieldsAreSetupForArrayType(final
List<ArrayTestData> testData,
+ final
ResultSet resultSet,
+ final
ResultSetMetaData resultSetMetaData)
+ throws SQLException {
+ List<RecordField> fields = new ArrayList<>();
+ for (int i = 0; i < testData.size(); ++i) {
+ ArrayTestData testDatum = testData.get(i);
+ int columnIndex = i + 1;
+ SqlArrayDummy arrayDummy = Mockito.mock(SqlArrayDummy.class);
+ when(arrayDummy.getArray()).thenReturn(testDatum.getTestArray());
+ when(resultSet.getArray(columnIndex)).thenReturn(arrayDummy);
+
when(resultSetMetaData.getColumnLabel(columnIndex)).thenReturn(testDatum.getFieldName());
+
when(resultSetMetaData.getColumnType(columnIndex)).thenReturn(Types.ARRAY);
+ fields.add(new RecordField(testDatum.getFieldName(),
RecordFieldType.ARRAY.getDataType()));
+ }
+ return fields;
+ }
+
+ private void thenAllDataTypesMatchInputFieldType(final List<RecordField>
inputFields, final RecordSchema resultSchema) {
+ assertEquals("The number of input fields does not match the number of
fields in the result schema.", inputFields.size(),
resultSchema.getFieldCount());
+ for (int i = 0; i < inputFields.size(); ++i) {
+ assertEquals(inputFields.get(i).getDataType(),
resultSchema.getField(i).getDataType());
+ }
+ }
+
+ private void thenAllDataTypesAreString(final RecordSchema resultSchema) {
+ for (int i = 0; i < resultSchema.getFieldCount(); ++i) {
+ assertEquals(RecordFieldType.STRING.getDataType(),
resultSchema.getField(i).getDataType());
+ }
}
- private void thenAllColumnDataTypesAreCorrect(final RecordSchema
resultSchema) {
- assertNotNull(resultSchema);
+ private void thenAllColumnDataTypesAreCorrect(TestColumn[] columns,
RecordSchema expectedSchema, RecordSchema actualSchema) {
+ assertNotNull(actualSchema);
- for (final Object[] column : COLUMNS) {
+ for (TestColumn column : columns) {
+ int fieldIndex = column.getIndex() - 1;
// The DECIMAL column with scale larger than precision will not
match so verify that instead
- DataType actualDataType = resultSchema.getField((Integer)
column[0] - 1).getDataType();
- DataType expectedDataType = (DataType) column[3];
+ DataType actualDataType =
actualSchema.getField(fieldIndex).getDataType();
+ DataType expectedDataType =
expectedSchema.getField(fieldIndex).getDataType();
if
(expectedDataType.equals(RecordFieldType.DECIMAL.getDecimalDataType(3, 10))) {
DecimalDataType decimalDataType = (DecimalDataType)
expectedDataType;
if (decimalDataType.getScale() >
decimalDataType.getPrecision()) {
expectedDataType =
RecordFieldType.DECIMAL.getDecimalDataType(decimalDataType.getScale(),
decimalDataType.getScale());
}
}
- assertEquals("For column " + column[0] + " the converted type is
not matching", expectedDataType, actualDataType);
+ assertEquals("For column " + column.getIndex() + " the converted
type is not matching", expectedDataType, actualDataType);
+ }
+ }
+
+ private void thenActualArrayElementTypesMatchExpected(Map<String,
DataType> expectedTypes, RecordSchema actualSchema) {
+ for (RecordField recordField : actualSchema.getFields()) {
+ if (recordField.getDataType() instanceof ArrayDataType) {
+ ArrayDataType arrayType = (ArrayDataType)
recordField.getDataType();
+ if
(!arrayType.getElementType().equals(expectedTypes.get(recordField.getFieldName())))
{
+ throw new AssertionError("Array element type for " +
recordField.getFieldName()
+ + " is not of expected type " +
expectedTypes.get(recordField.getFieldName()).toString());
+ }
+ } else {
+ throw new AssertionError("RecordField " +
recordField.getFieldName() + " is not instance of ArrayDataType");
Review comment:
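This should be changed to call `fail()` with the same message instead of
throwing an `AssertionError` directly; the same applies to the
`AssertionError` thrown for the element type mismatch a few lines above.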
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]