This is an automated email from the ASF dual-hosted git repository. twalthr pushed a commit to branch release-1.11 in repository https://gitbox.apache.org/repos/asf/flink.git
commit c3ff1de47cd01d7448d325c42d9ad76681e8c85d Author: Timo Walther <twal...@apache.org> AuthorDate: Mon May 18 11:07:22 2020 +0200 [hotfix][table] Reduce friction around logical type roots --- .../flink/table/types/logical/LogicalTypeRoot.java | 10 + .../types/logical/utils/LogicalTypeChecks.java | 13 + .../types/logical/utils/LogicalTypeUtils.java | 38 +- .../flink/table/planner/codegen/CodeGenUtils.scala | 499 ++++++++++++--------- .../planner/codegen/EqualiserCodeGenerator.scala | 17 +- .../table/planner/codegen/ExpressionReducer.scala | 6 +- .../table/planner/codegen/GenerateUtils.scala | 237 ++++++---- .../codegen/agg/batch/AggCodeGenHelper.scala | 34 +- .../table/runtime/typeutils/TypeCheckUtils.java | 18 +- 9 files changed, 531 insertions(+), 341 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/LogicalTypeRoot.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/LogicalTypeRoot.java index 0079a8d..e2a97f6 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/LogicalTypeRoot.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/LogicalTypeRoot.java @@ -37,6 +37,16 @@ import java.util.Set; * {@code SYMBOL}, or {@code RAW}). * * <p>See the type-implementing classes for a more detailed description of each type. + * + * <p>Note to implementers: Whenever we perform a match against a type root (e.g. using a switch/case + * statement), it is recommended to: + * <ul> + * <li>Order the items by the type root definition in this class for easy readability. + * <li>Think about the behavior of all type roots for the implementation. A default fallback is + * dangerous when introducing a new type root in the future. 
+ * <li>In many <b>runtime</b> cases, resolve the indirection of {@link #DISTINCT_TYPE}: + * {@code return myMethod(((DistinctType) type).getSourceType())} + * </ul> */ @PublicEvolving public enum LogicalTypeRoot { diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java index 8a3e301..b6117f4 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java @@ -108,6 +108,9 @@ public final class LogicalTypeChecks { /** * Checks if the given type is a composite type. * + * <p>Use {@link #getFieldCount(LogicalType)}, {@link #getFieldNames(LogicalType)}, + * {@link #getFieldTypes(LogicalType)} for unified handling of composite types. + * * @param logicalType Logical data type to check * @return True if the type is composite type. */ @@ -198,6 +201,16 @@ public final class LogicalTypeChecks { return logicalType.accept(FIELD_NAMES_EXTRACTOR); } + /** + * Returns the field types of row and structured types.
+ */ + public static List<LogicalType> getFieldTypes(LogicalType logicalType) { + if (logicalType instanceof DistinctType) { + return getFieldTypes(((DistinctType) logicalType).getSourceType()); + } + return logicalType.getChildren(); + } + private LogicalTypeChecks() { // no instantiation } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeUtils.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeUtils.java index 5e8be86..033d711 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeUtils.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeUtils.java @@ -26,6 +26,7 @@ import org.apache.flink.table.data.RawValueData; import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.StringData; import org.apache.flink.table.data.TimestampData; +import org.apache.flink.table.types.logical.DistinctType; import org.apache.flink.table.types.logical.LocalZonedTimestampType; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.TimestampType; @@ -45,14 +46,23 @@ public final class LogicalTypeUtils { /** * Returns the conversion class for the given {@link LogicalType} that is used by the - * table runtime. + * table runtime as an internal data structure.
* * @see RowData */ public static Class<?> toInternalConversionClass(LogicalType type) { + // ordered by type root definition switch (type.getTypeRoot()) { + case CHAR: + case VARCHAR: + return StringData.class; case BOOLEAN: return Boolean.class; + case BINARY: + case VARBINARY: + return byte[].class; + case DECIMAL: + return DecimalData.class; case TINYINT: return Byte.class; case SMALLINT: @@ -65,32 +75,32 @@ public final class LogicalTypeUtils { case BIGINT: case INTERVAL_DAY_TIME: return Long.class; - case TIMESTAMP_WITHOUT_TIME_ZONE: - case TIMESTAMP_WITH_LOCAL_TIME_ZONE: - return TimestampData.class; case FLOAT: return Float.class; case DOUBLE: return Double.class; - case CHAR: - case VARCHAR: - return StringData.class; - case DECIMAL: - return DecimalData.class; + case TIMESTAMP_WITHOUT_TIME_ZONE: + case TIMESTAMP_WITH_LOCAL_TIME_ZONE: + return TimestampData.class; + case TIMESTAMP_WITH_TIME_ZONE: + throw new UnsupportedOperationException("Unsupported type: " + type); case ARRAY: return ArrayData.class; - case MAP: case MULTISET: + case MAP: return MapData.class; case ROW: + case STRUCTURED_TYPE: return RowData.class; - case BINARY: - case VARBINARY: - return byte[].class; + case DISTINCT_TYPE: + return toInternalConversionClass(((DistinctType) type).getSourceType()); case RAW: return RawValueData.class; + case NULL: + case SYMBOL: + case UNRESOLVED: default: - throw new UnsupportedOperationException("Unsupported type: " + type); + throw new IllegalArgumentException("Illegal type: " + type); } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala index 58b7010..6e62a3f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala +++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala @@ -40,10 +40,12 @@ import org.apache.flink.table.runtime.util.MurmurHashUtil import org.apache.flink.table.types.DataType import org.apache.flink.table.types.logical.LogicalTypeRoot._ import org.apache.flink.table.types.logical._ -import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasRoot +import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{getFieldCount, getPrecision, getScale, hasRoot} import org.apache.flink.table.types.logical.utils.LogicalTypeUtils.toInternalConversionClass import org.apache.flink.types.{Row, RowKind} +import scala.annotation.tailrec + object CodeGenUtils { // ------------------------------- DEFAULT TERMS ------------------------------------------ @@ -161,117 +163,118 @@ object CodeGenUtils { // works, but for boxed types we need this: // Float a = 1.0f; // Byte b = (byte)(float) a; + @tailrec def primitiveTypeTermForType(t: LogicalType): String = t.getTypeRoot match { - case INTEGER => "int" - case BIGINT => "long" - case SMALLINT => "short" + // ordered by type root definition + case BOOLEAN => "boolean" case TINYINT => "byte" + case SMALLINT => "short" + case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH => "int" + case BIGINT | INTERVAL_DAY_TIME => "long" case FLOAT => "float" case DOUBLE => "double" - case BOOLEAN => "boolean" - - case DATE => "int" - case TIME_WITHOUT_TIME_ZONE => "int" - case INTERVAL_YEAR_MONTH => "int" - case INTERVAL_DAY_TIME => "long" - + case DISTINCT_TYPE => primitiveTypeTermForType(t.asInstanceOf[DistinctType].getSourceType) case _ => boxedTypeTermForType(t) } + @tailrec def boxedTypeTermForType(t: LogicalType): String = t.getTypeRoot match { - case INTEGER => className[JInt] - case BIGINT => className[JLong] - case SMALLINT => className[JShort] + // ordered by type root definition + case CHAR | VARCHAR => BINARY_STRING + case BOOLEAN => 
className[JBoolean] + case BINARY | VARBINARY => "byte[]" + case DECIMAL => className[DecimalData] case TINYINT => className[JByte] + case SMALLINT => className[JShort] + case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH => className[JInt] + case BIGINT | INTERVAL_DAY_TIME => className[JLong] case FLOAT => className[JFloat] case DOUBLE => className[JDouble] - case BOOLEAN => className[JBoolean] - - case DATE => className[JInt] - case TIME_WITHOUT_TIME_ZONE => className[JInt] - case INTERVAL_YEAR_MONTH => className[JInt] - case INTERVAL_DAY_TIME => className[JLong] - - case VARCHAR | CHAR => BINARY_STRING - case VARBINARY | BINARY => "byte[]" - - case DECIMAL => className[DecimalData] + case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE => className[TimestampData] + case TIMESTAMP_WITH_TIME_ZONE => + throw new UnsupportedOperationException("Unsupported type: " + t) case ARRAY => className[ArrayData] case MULTISET | MAP => className[MapData] - case ROW => className[RowData] + case ROW | STRUCTURED_TYPE => className[RowData] case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE => className[TimestampData] - + case DISTINCT_TYPE => boxedTypeTermForType(t.asInstanceOf[DistinctType].getSourceType) + case NULL => className[JObject] // special case for untyped null literals case RAW => className[BinaryRawValueData[_]] - - // special case for untyped null literals - case NULL => className[JObject] + case SYMBOL | UNRESOLVED => + throw new IllegalArgumentException("Illegal type: " + t) } /** * Gets the default value for a primitive type, and null for generic types */ + @tailrec def primitiveDefaultValue(t: LogicalType): String = t.getTypeRoot match { - case INTEGER | TINYINT | SMALLINT => "-1" - case BIGINT => "-1L" + // ordered by type root definition + case CHAR | VARCHAR => s"$BINARY_STRING.EMPTY_UTF8" + case BOOLEAN => "false" + case TINYINT | SMALLINT | INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH => "-1" + 
case BIGINT | INTERVAL_DAY_TIME => "-1L" case FLOAT => "-1.0f" case DOUBLE => "-1.0d" - case BOOLEAN => "false" - case VARCHAR | CHAR => s"$BINARY_STRING.EMPTY_UTF8" - case DATE | TIME_WITHOUT_TIME_ZONE => "-1" - case INTERVAL_YEAR_MONTH => "-1" - case INTERVAL_DAY_TIME => "-1L" + case DISTINCT_TYPE => primitiveDefaultValue(t.asInstanceOf[DistinctType].getSourceType) case _ => "null" } - /** - * If it's internally compatible, don't need to DataStructure converter. - * clazz != classOf[Row] => Row can only infer GenericType[Row]. - */ - def isInternalClass(t: DataType): Boolean = { - val clazz = t.getConversionClass - clazz != classOf[Object] && clazz != classOf[Row] && - (classOf[RowData].isAssignableFrom(clazz) || - clazz == toInternalConversionClass(fromDataTypeToLogicalType(t))) - } - + @tailrec def hashCodeForType( - ctx: CodeGeneratorContext, t: LogicalType, term: String): String = t.getTypeRoot match { - case BOOLEAN => s"${className[JBoolean]}.hashCode($term)" - case TINYINT => s"${className[JByte]}.hashCode($term)" - case SMALLINT => s"${className[JShort]}.hashCode($term)" - case INTEGER => s"${className[JInt]}.hashCode($term)" - case BIGINT => s"${className[JLong]}.hashCode($term)" + ctx: CodeGeneratorContext, + t: LogicalType, + term: String) + : String = t.getTypeRoot match { + // ordered by type root definition + case VARCHAR | CHAR => + s"$term.hashCode()" + case BOOLEAN => + s"${className[JBoolean]}.hashCode($term)" + case BINARY | VARBINARY => + s"${className[MurmurHashUtil]}.hashUnsafeBytes($term, $BYTE_ARRAY_BASE_OFFSET, $term.length)" + case DECIMAL => + s"$term.hashCode()" + case TINYINT => + s"${className[JByte]}.hashCode($term)" + case SMALLINT => + s"${className[JShort]}.hashCode($term)" + case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH => + s"${className[JInt]}.hashCode($term)" + case BIGINT | INTERVAL_DAY_TIME => s"${className[JLong]}.hashCode($term)" case FLOAT => s"${className[JFloat]}.hashCode($term)" case DOUBLE => 
s"${className[JDouble]}.hashCode($term)" - case VARCHAR | CHAR => s"$term.hashCode()" - case VARBINARY | BINARY => s"${className[MurmurHashUtil]}.hashUnsafeBytes(" + - s"$term, $BYTE_ARRAY_BASE_OFFSET, $term.length)" - case DECIMAL => s"$term.hashCode()" - case DATE => s"${className[JInt]}.hashCode($term)" - case TIME_WITHOUT_TIME_ZONE => s"${className[JInt]}.hashCode($term)" case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE => s"$term.hashCode()" - case INTERVAL_YEAR_MONTH => s"${className[JInt]}.hashCode($term)" + case TIMESTAMP_WITH_TIME_ZONE | ARRAY | MULTISET | MAP => + throw new UnsupportedOperationException("Unsupported type: " + t) case INTERVAL_DAY_TIME => s"${className[JLong]}.hashCode($term)" - case ARRAY => throw new IllegalArgumentException(s"Not support type to hash: $t") - case ROW => - val rowType = t.asInstanceOf[RowType] + case ROW | STRUCTURED_TYPE => + val fieldCount = getFieldCount(t) val subCtx = CodeGeneratorContext(ctx.tableConfig) val genHash = HashCodeGenerator.generateRowHash( - subCtx, rowType, "SubHashRow", (0 until rowType.getFieldCount).toArray) + subCtx, t, "SubHashRow", (0 until fieldCount).toArray) ctx.addReusableInnerClass(genHash.getClassName, genHash.getCode) val refs = ctx.addReusableObject(subCtx.references.toArray, "subRefs") val hashFunc = newName("hashFunc") ctx.addReusableMember(s"${classOf[HashFunction].getCanonicalName} $hashFunc;") ctx.addReusableInitStatement(s"$hashFunc = new ${genHash.getClassName}($refs);") s"$hashFunc.hashCode($term)" + case DISTINCT_TYPE => + hashCodeForType(ctx, t.asInstanceOf[DistinctType].getSourceType, term) case RAW => - val gt = t.asInstanceOf[TypeInformationRawType[_]] - val serTerm = ctx.addReusableObject( - gt.getTypeInformation.createSerializer(new ExecutionConfig), "serializer") + val serializer = t match { + case rt: RawType[_] => + rt.getTypeSerializer + case tirt: TypeInformationRawType[_] => + tirt.getTypeInformation.createSerializer(new ExecutionConfig) + } + val 
serTerm = ctx.addReusableObject(serializer, "serializer") s"$BINARY_RAW_VALUE.getJavaObjectFromRawValueData($term, $serTerm).hashCode()" + case NULL | SYMBOL | UNRESOLVED => + throw new IllegalArgumentException("Illegal type: " + t) } // ---------------------------------------------------------------------------------------------- @@ -406,6 +409,11 @@ object CodeGenUtils { throw new CodeGenException("Integer expression type expected.") } + def udfFieldName(udf: UserDefinedFunction): String = s"function_${udf.functionIdentifier}" + + def genLogInfo(logTerm: String, format: String, argTerm: String): String = + s"""$logTerm.info("$format", $argTerm);""" + // -------------------------------------------------------------------------------- // DataFormat Operations // -------------------------------------------------------------------------------- @@ -419,44 +427,50 @@ object CodeGenUtils { fieldType: LogicalType) : String = rowFieldReadAccess(ctx, index.toString, rowTerm, fieldType) + @tailrec def rowFieldReadAccess( ctx: CodeGeneratorContext, indexTerm: String, rowTerm: String, - t: LogicalType) : String = - t.getTypeRoot match { - // primitive types - case BOOLEAN => s"$rowTerm.getBoolean($indexTerm)" - case TINYINT => s"$rowTerm.getByte($indexTerm)" - case SMALLINT => s"$rowTerm.getShort($indexTerm)" - case INTEGER => s"$rowTerm.getInt($indexTerm)" - case BIGINT => s"$rowTerm.getLong($indexTerm)" - case FLOAT => s"$rowTerm.getFloat($indexTerm)" - case DOUBLE => s"$rowTerm.getDouble($indexTerm)" - case VARCHAR | CHAR => s"(($BINARY_STRING) $rowTerm.getString($indexTerm))" - case VARBINARY | BINARY => s"$rowTerm.getBinary($indexTerm)" + t: LogicalType) + : String = t.getTypeRoot match { + // ordered by type root definition + case CHAR | VARCHAR => + s"(($BINARY_STRING) $rowTerm.getString($indexTerm))" + case BOOLEAN => + s"$rowTerm.getBoolean($indexTerm)" + case BINARY | VARBINARY => + s"$rowTerm.getBinary($indexTerm)" case DECIMAL => - val dt = 
t.asInstanceOf[DecimalType] - s"$rowTerm.getDecimal($indexTerm, ${dt.getPrecision}, ${dt.getScale})" - - // temporal types - case DATE => s"$rowTerm.getInt($indexTerm)" - case TIME_WITHOUT_TIME_ZONE => s"$rowTerm.getInt($indexTerm)" - case TIMESTAMP_WITHOUT_TIME_ZONE => - val dt = t.asInstanceOf[TimestampType] - s"$rowTerm.getTimestamp($indexTerm, ${dt.getPrecision})" - case TIMESTAMP_WITH_LOCAL_TIME_ZONE => - val dt = t.asInstanceOf[LocalZonedTimestampType] - s"$rowTerm.getTimestamp($indexTerm, ${dt.getPrecision})" - case INTERVAL_YEAR_MONTH => s"$rowTerm.getInt($indexTerm)" - case INTERVAL_DAY_TIME => s"$rowTerm.getLong($indexTerm)" - - // complex types - case ARRAY => s"$rowTerm.getArray($indexTerm)" - case MULTISET | MAP => s"$rowTerm.getMap($indexTerm)" - case ROW => s"$rowTerm.getRow($indexTerm, ${t.asInstanceOf[RowType].getFieldCount})" - - case RAW => s"(($BINARY_RAW_VALUE) $rowTerm.getRawValue($indexTerm))" + s"$rowTerm.getDecimal($indexTerm, ${getPrecision(t)}, ${getScale(t)})" + case TINYINT => + s"$rowTerm.getByte($indexTerm)" + case SMALLINT => + s"$rowTerm.getShort($indexTerm)" + case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH => + s"$rowTerm.getInt($indexTerm)" + case BIGINT | INTERVAL_DAY_TIME => + s"$rowTerm.getLong($indexTerm)" + case FLOAT => + s"$rowTerm.getFloat($indexTerm)" + case DOUBLE => + s"$rowTerm.getDouble($indexTerm)" + case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE => + s"$rowTerm.getTimestamp($indexTerm, ${getPrecision(t)})" + case TIMESTAMP_WITH_TIME_ZONE => + throw new UnsupportedOperationException("Unsupported type: " + t) + case ARRAY => + s"$rowTerm.getArray($indexTerm)" + case MULTISET | MAP => + s"$rowTerm.getMap($indexTerm)" + case ROW | STRUCTURED_TYPE => + s"$rowTerm.getRow($indexTerm, ${getFieldCount(t)})" + case DISTINCT_TYPE => + rowFieldReadAccess(ctx, indexTerm, rowTerm, t.asInstanceOf[DistinctType].getSourceType) + case RAW => + s"(($BINARY_RAW_VALUE) 
$rowTerm.getRawValue($indexTerm))" + case NULL | SYMBOL | UNRESOLVED => + throw new IllegalArgumentException("Illegal type: " + t) } // -------------------------- RowData Set Field ------------------------------- @@ -549,14 +563,22 @@ object CodeGenUtils { def binaryRowSetNull(index: Int, rowTerm: String, t: LogicalType): String = binaryRowSetNull(index.toString, rowTerm, t) - def binaryRowSetNull(indexTerm: String, rowTerm: String, t: LogicalType): String = t match { - case d: DecimalType if !DecimalData.isCompact(d.getPrecision) => - s"$rowTerm.setDecimal($indexTerm, null, ${d.getPrecision})" - case d: TimestampType if !TimestampData.isCompact(d.getPrecision) => - s"$rowTerm.setTimestamp($indexTerm, null, ${d.getPrecision})" - case d: LocalZonedTimestampType if !TimestampData.isCompact(d.getPrecision) => - s"$rowTerm.setTimestamp($indexTerm, null, ${d.getPrecision})" - case _ => s"$rowTerm.setNullAt($indexTerm)" + @tailrec + def binaryRowSetNull( + indexTerm: String, + rowTerm: String, + t: LogicalType) + : String = t.getTypeRoot match { + // ordered by type root definition + case DECIMAL if !DecimalData.isCompact(getPrecision(t)) => + s"$rowTerm.setDecimal($indexTerm, null, ${getPrecision(t)})" + case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE + if !TimestampData.isCompact(getPrecision(t)) => + s"$rowTerm.setTimestamp($indexTerm, null, ${getPrecision(t)})" + case DISTINCT_TYPE => + binaryRowSetNull(indexTerm, rowTerm, t.asInstanceOf[DistinctType].getSourceType) + case _ => + s"$rowTerm.setNullAt($indexTerm)" } def binaryRowFieldSetAccess( @@ -566,75 +588,102 @@ object CodeGenUtils { fieldValTerm: String): String = binaryRowFieldSetAccess(index.toString, binaryRowTerm, fieldType, fieldValTerm) + @tailrec def binaryRowFieldSetAccess( index: String, binaryRowTerm: String, t: LogicalType, - fieldValTerm: String): String = - t.getTypeRoot match { - case INTEGER => s"$binaryRowTerm.setInt($index, $fieldValTerm)" - case BIGINT => 
s"$binaryRowTerm.setLong($index, $fieldValTerm)" - case SMALLINT => s"$binaryRowTerm.setShort($index, $fieldValTerm)" - case TINYINT => s"$binaryRowTerm.setByte($index, $fieldValTerm)" - case FLOAT => s"$binaryRowTerm.setFloat($index, $fieldValTerm)" - case DOUBLE => s"$binaryRowTerm.setDouble($index, $fieldValTerm)" - case BOOLEAN => s"$binaryRowTerm.setBoolean($index, $fieldValTerm)" - case DATE => s"$binaryRowTerm.setInt($index, $fieldValTerm)" - case TIME_WITHOUT_TIME_ZONE => s"$binaryRowTerm.setInt($index, $fieldValTerm)" - case TIMESTAMP_WITHOUT_TIME_ZONE => - val dt = t.asInstanceOf[TimestampType] - s"$binaryRowTerm.setTimestamp($index, $fieldValTerm, ${dt.getPrecision})" - case TIMESTAMP_WITH_LOCAL_TIME_ZONE => - val dt = t.asInstanceOf[LocalZonedTimestampType] - s"$binaryRowTerm.setTimestamp($index, $fieldValTerm, ${dt.getPrecision})" - case INTERVAL_YEAR_MONTH => s"$binaryRowTerm.setInt($index, $fieldValTerm)" - case INTERVAL_DAY_TIME => s"$binaryRowTerm.setLong($index, $fieldValTerm)" - case DECIMAL => - val dt = t.asInstanceOf[DecimalType] - s"$binaryRowTerm.setDecimal($index, $fieldValTerm, ${dt.getPrecision})" - case _ => - throw new CodeGenException("Fail to find binary row field setter method of LogicalType " - + t + ".") - } + fieldValTerm: String) + : String = t.getTypeRoot match { + // ordered by type root definition + case BOOLEAN => + s"$binaryRowTerm.setBoolean($index, $fieldValTerm)" + case DECIMAL => + s"$binaryRowTerm.setDecimal($index, $fieldValTerm, ${getPrecision(t)})" + case TINYINT => + s"$binaryRowTerm.setByte($index, $fieldValTerm)" + case SMALLINT => + s"$binaryRowTerm.setShort($index, $fieldValTerm)" + case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH => + s"$binaryRowTerm.setInt($index, $fieldValTerm)" + case BIGINT | INTERVAL_DAY_TIME => + s"$binaryRowTerm.setLong($index, $fieldValTerm)" + case FLOAT => + s"$binaryRowTerm.setFloat($index, $fieldValTerm)" + case DOUBLE => + s"$binaryRowTerm.setDouble($index, 
$fieldValTerm)" + case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE => + s"$binaryRowTerm.setTimestamp($index, $fieldValTerm, ${getPrecision(t)})" + case DISTINCT_TYPE => + binaryRowFieldSetAccess( + index, + binaryRowTerm, + t.asInstanceOf[DistinctType].getSourceType, + fieldValTerm) + case _ => + throw new CodeGenException( + "Fail to find binary row field setter method of LogicalType " + t + ".") + } // -------------------------- BoxedWrapperRowData Set Field ------------------------------- + @tailrec def boxedWrapperRowFieldSetAccess( rowTerm: String, indexTerm: String, fieldTerm: String, - t: LogicalType): String = - t.getTypeRoot match { - case INTEGER => s"$rowTerm.setInt($indexTerm, $fieldTerm)" - case BIGINT => s"$rowTerm.setLong($indexTerm, $fieldTerm)" - case SMALLINT => s"$rowTerm.setShort($indexTerm, $fieldTerm)" - case TINYINT => s"$rowTerm.setByte($indexTerm, $fieldTerm)" - case FLOAT => s"$rowTerm.setFloat($indexTerm, $fieldTerm)" - case DOUBLE => s"$rowTerm.setDouble($indexTerm, $fieldTerm)" - case BOOLEAN => s"$rowTerm.setBoolean($indexTerm, $fieldTerm)" - case DATE => s"$rowTerm.setInt($indexTerm, $fieldTerm)" - case TIME_WITHOUT_TIME_ZONE => s"$rowTerm.setInt($indexTerm, $fieldTerm)" - case INTERVAL_YEAR_MONTH => s"$rowTerm.setInt($indexTerm, $fieldTerm)" - case INTERVAL_DAY_TIME => s"$rowTerm.setLong($indexTerm, $fieldTerm)" - case _ => s"$rowTerm.setNonPrimitiveValue($indexTerm, $fieldTerm)" - } + t: LogicalType) + : String = t.getTypeRoot match { + // ordered by type root definition + case BOOLEAN => + s"$rowTerm.setBoolean($indexTerm, $fieldTerm)" + case TINYINT => + s"$rowTerm.setByte($indexTerm, $fieldTerm)" + case SMALLINT => + s"$rowTerm.setShort($indexTerm, $fieldTerm)" + case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH => + s"$rowTerm.setInt($indexTerm, $fieldTerm)" + case BIGINT | INTERVAL_DAY_TIME => + s"$rowTerm.setLong($indexTerm, $fieldTerm)" + case FLOAT => + s"$rowTerm.setFloat($indexTerm, 
$fieldTerm)" + case DOUBLE => + s"$rowTerm.setDouble($indexTerm, $fieldTerm)" + case DISTINCT_TYPE => + boxedWrapperRowFieldSetAccess( + rowTerm, + indexTerm, + fieldTerm, + t.asInstanceOf[DistinctType].getSourceType) + case _ => + s"$rowTerm.setNonPrimitiveValue($indexTerm, $fieldTerm)" + } // -------------------------- BinaryArray Set Access ------------------------------- + @tailrec def binaryArraySetNull( index: Int, arrayTerm: String, - t: LogicalType): String = t.getTypeRoot match { - case BOOLEAN => s"$arrayTerm.setNullBoolean($index)" - case TINYINT => s"$arrayTerm.setNullByte($index)" - case SMALLINT => s"$arrayTerm.setNullShort($index)" - case INTEGER => s"$arrayTerm.setNullInt($index)" - case FLOAT => s"$arrayTerm.setNullFloat($index)" - case DOUBLE => s"$arrayTerm.setNullDouble($index)" - case TIME_WITHOUT_TIME_ZONE => s"$arrayTerm.setNullInt($index)" - case DATE => s"$arrayTerm.setNullInt($index)" - case INTERVAL_YEAR_MONTH => s"$arrayTerm.setNullInt($index)" - case _ => s"$arrayTerm.setNullLong($index)" + t: LogicalType) + : String = t.getTypeRoot match { + // ordered by type root definition + case BOOLEAN => + s"$arrayTerm.setNullBoolean($index)" + case TINYINT => + s"$arrayTerm.setNullByte($index)" + case SMALLINT => + s"$arrayTerm.setNullShort($index)" + case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH => + s"$arrayTerm.setNullInt($index)" + case FLOAT => + s"$arrayTerm.setNullFloat($index)" + case DOUBLE => + s"$arrayTerm.setNullDouble($index)" + case DISTINCT_TYPE => + binaryArraySetNull(index, arrayTerm, t) + case _ => + s"$arrayTerm.setNullLong($index)" } // -------------------------- BinaryWriter Write ------------------------------- @@ -642,17 +691,22 @@ object CodeGenUtils { def binaryWriterWriteNull(index: Int, writerTerm: String, t: LogicalType): String = binaryWriterWriteNull(index.toString, writerTerm, t) + @tailrec def binaryWriterWriteNull( indexTerm: String, writerTerm: String, - t: LogicalType): String = t match { - 
case d: DecimalType if !DecimalData.isCompact(d.getPrecision) => - s"$writerTerm.writeDecimal($indexTerm, null, ${d.getPrecision})" - case d: TimestampType if !TimestampData.isCompact(d.getPrecision) => - s"$writerTerm.writeTimestamp($indexTerm, null, ${d.getPrecision})" - case d: LocalZonedTimestampType if !TimestampData.isCompact(d.getPrecision) => - s"$writerTerm.writeTimestamp($indexTerm, null, ${d.getPrecision})" - case _ => s"$writerTerm.setNullAt($indexTerm)" + t: LogicalType) + : String = t.getTypeRoot match { + // ordered by type root definition + case DECIMAL if !DecimalData.isCompact(getPrecision(t)) => + s"$writerTerm.writeDecimal($indexTerm, null, ${getPrecision(t)})" + case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE + if !TimestampData.isCompact(getPrecision(t)) => + s"$writerTerm.writeTimestamp($indexTerm, null, ${getPrecision(t)})" + case DISTINCT_TYPE => + binaryWriterWriteNull(indexTerm, writerTerm, t.asInstanceOf[DistinctType].getSourceType) + case _ => + s"$writerTerm.setNullAt($indexTerm)" } def binaryWriterWriteField( @@ -663,50 +717,74 @@ object CodeGenUtils { fieldType: LogicalType): String = binaryWriterWriteField(ctx, index.toString, fieldValTerm, writerTerm, fieldType) + @tailrec def binaryWriterWriteField( ctx: CodeGeneratorContext, indexTerm: String, fieldValTerm: String, writerTerm: String, - t: LogicalType): String = - t.getTypeRoot match { - case INTEGER => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)" - case BIGINT => s"$writerTerm.writeLong($indexTerm, $fieldValTerm)" - case SMALLINT => s"$writerTerm.writeShort($indexTerm, $fieldValTerm)" - case TINYINT => s"$writerTerm.writeByte($indexTerm, $fieldValTerm)" - case FLOAT => s"$writerTerm.writeFloat($indexTerm, $fieldValTerm)" - case DOUBLE => s"$writerTerm.writeDouble($indexTerm, $fieldValTerm)" - case BOOLEAN => s"$writerTerm.writeBoolean($indexTerm, $fieldValTerm)" - case VARBINARY | BINARY => s"$writerTerm.writeBinary($indexTerm, $fieldValTerm)" - case 
VARCHAR | CHAR => s"$writerTerm.writeString($indexTerm, $fieldValTerm)" - case DECIMAL => - val dt = t.asInstanceOf[DecimalType] - s"$writerTerm.writeDecimal($indexTerm, $fieldValTerm, ${dt.getPrecision})" - case DATE => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)" - case TIME_WITHOUT_TIME_ZONE => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)" - case TIMESTAMP_WITHOUT_TIME_ZONE => - val dt = t.asInstanceOf[TimestampType] - s"$writerTerm.writeTimestamp($indexTerm, $fieldValTerm, ${dt.getPrecision})" - case TIMESTAMP_WITH_LOCAL_TIME_ZONE => - val dt = t.asInstanceOf[LocalZonedTimestampType] - s"$writerTerm.writeTimestamp($indexTerm, $fieldValTerm, ${dt.getPrecision})" - case INTERVAL_YEAR_MONTH => s"$writerTerm.writeInt($indexTerm, $fieldValTerm)" - case INTERVAL_DAY_TIME => s"$writerTerm.writeLong($indexTerm, $fieldValTerm)" - - // complex types - case ARRAY => - val ser = ctx.addReusableTypeSerializer(t) - s"$writerTerm.writeArray($indexTerm, $fieldValTerm, $ser)" - case MULTISET | MAP => - val ser = ctx.addReusableTypeSerializer(t) - s"$writerTerm.writeMap($indexTerm, $fieldValTerm, $ser)" - case ROW => - val ser = ctx.addReusableTypeSerializer(t) - s"$writerTerm.writeRow($indexTerm, $fieldValTerm, $ser)" - case RAW => - val ser = ctx.addReusableTypeSerializer(t) - s"$writerTerm.writeRawValue($indexTerm, $fieldValTerm, $ser)" - } + t: LogicalType) + : String = t.getTypeRoot match { + // ordered by type root definition + case CHAR | VARCHAR => + s"$writerTerm.writeString($indexTerm, $fieldValTerm)" + case BOOLEAN => + s"$writerTerm.writeBoolean($indexTerm, $fieldValTerm)" + case BINARY | VARBINARY => + s"$writerTerm.writeBinary($indexTerm, $fieldValTerm)" + case DECIMAL => + s"$writerTerm.writeDecimal($indexTerm, $fieldValTerm, ${getPrecision(t)})" + case TINYINT => + s"$writerTerm.writeByte($indexTerm, $fieldValTerm)" + case SMALLINT => + s"$writerTerm.writeShort($indexTerm, $fieldValTerm)" + case INTEGER | DATE | TIME_WITHOUT_TIME_ZONE | 
INTERVAL_YEAR_MONTH => + s"$writerTerm.writeInt($indexTerm, $fieldValTerm)" + case BIGINT | INTERVAL_DAY_TIME => + s"$writerTerm.writeLong($indexTerm, $fieldValTerm)" + case FLOAT => + s"$writerTerm.writeFloat($indexTerm, $fieldValTerm)" + case DOUBLE => + s"$writerTerm.writeDouble($indexTerm, $fieldValTerm)" + case TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE => + s"$writerTerm.writeTimestamp($indexTerm, $fieldValTerm, ${getPrecision(t)})" + case TIMESTAMP_WITH_TIME_ZONE => + throw new UnsupportedOperationException("Unsupported type: " + t) + case ARRAY => + val ser = ctx.addReusableTypeSerializer(t) + s"$writerTerm.writeArray($indexTerm, $fieldValTerm, $ser)" + case MULTISET | MAP => + val ser = ctx.addReusableTypeSerializer(t) + s"$writerTerm.writeMap($indexTerm, $fieldValTerm, $ser)" + case ROW | STRUCTURED_TYPE => + val ser = ctx.addReusableTypeSerializer(t) + s"$writerTerm.writeRow($indexTerm, $fieldValTerm, $ser)" + case DISTINCT_TYPE => + binaryWriterWriteField( + ctx, + indexTerm, + fieldValTerm, + writerTerm, + t.asInstanceOf[DistinctType].getSourceType) + case RAW => + val ser = ctx.addReusableTypeSerializer(t) + s"$writerTerm.writeRawValue($indexTerm, $fieldValTerm, $ser)" + case NULL | SYMBOL | UNRESOLVED => + throw new IllegalArgumentException("Illegal type: " + t); + } + + // -------------------------- Data Structure Conversion ------------------------------- + + /** + * If the conversion class is internally compatible, no DataStructure converter is needed. + * clazz != classOf[Row] => Row can only infer GenericType[Row].
+ */ + def isInternalClass(t: DataType): Boolean = { + val clazz = t.getConversionClass + clazz != classOf[Object] && clazz != classOf[Row] && + (classOf[RowData].isAssignableFrom(clazz) || + clazz == toInternalConversionClass(fromDataTypeToLogicalType(t))) + } private def isConverterIdentity(t: DataType): Boolean = { DataFormatConverters.getConverterForDataType(t).isInstanceOf[IdentityConverter[_]] @@ -808,9 +886,4 @@ object CodeGenUtils { s"${internalExpr.nullTerm} ? null : ($externalResultTerm)" } } - - def udfFieldName(udf: UserDefinedFunction): String = s"function_${udf.functionIdentifier}" - - def genLogInfo(logTerm: String, format: String, argTerm: String): String = - s"""$logTerm.info("$format", $argTerm);""" } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala index 174158d..850d55f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala @@ -24,8 +24,10 @@ import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens.generateE import org.apache.flink.table.runtime.generated.{GeneratedRecordEqualiser, RecordEqualiser} import org.apache.flink.table.runtime.types.PlannerTypeUtils import org.apache.flink.table.types.logical.LogicalTypeRoot._ -import org.apache.flink.table.types.logical.{LogicalType, RowType} +import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{getFieldTypes, isCompositeType} +import org.apache.flink.table.types.logical.{DistinctType, LogicalType} +import scala.annotation.tailrec import scala.collection.JavaConverters._ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) { @@ -57,9 +59,9 @@ class 
EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) { // TODO merge ScalarOperatorGens.generateEquals. val (equalsCode, equalsResult) = if (isInternalPrimitive(fieldType)) { ("", s"$leftFieldTerm == $rightFieldTerm") - } else if (isRowData(fieldType)) { + } else if (isCompositeType(fieldType)) { val equaliserGenerator = new EqualiserCodeGenerator( - fieldType.asInstanceOf[RowType].getChildren.asScala.toArray) + getFieldTypes(fieldType).asScala.toArray) val generatedEqualiser = equaliserGenerator .generateRecordEqualiser("field$" + i + "GeneratedEqualiser") val generatedEqualiserTerm = ctx.addReusableObject( @@ -128,15 +130,14 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType]) { new GeneratedRecordEqualiser(className, functionCode, ctx.references.toArray) } + @tailrec private def isInternalPrimitive(t: LogicalType): Boolean = t.getTypeRoot match { case _ if PlannerTypeUtils.isPrimitive(t) => true - case DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH |INTERVAL_DAY_TIME => true - case _ => false - } + case DATE | TIME_WITHOUT_TIME_ZONE | INTERVAL_YEAR_MONTH | INTERVAL_DAY_TIME => true + + case DISTINCT_TYPE => isInternalPrimitive(t.asInstanceOf[DistinctType].getSourceType) - private def isRowData(t: LogicalType): Boolean = t match { - case _: RowType => true case _ => false } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala index 950a35b..82a0122 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala @@ -30,11 +30,9 @@ import org.apache.flink.table.planner.codegen.FunctionCodeGenerator.generateFunc import 
org.apache.flink.table.planner.plan.utils.PythonUtil.containsPythonCall import org.apache.flink.table.types.logical.RowType import org.apache.flink.table.util.TimestampStringUtils.fromLocalDateTime - import org.apache.calcite.avatica.util.ByteString import org.apache.calcite.rex.{RexBuilder, RexExecutor, RexNode} import org.apache.calcite.sql.`type`.SqlTypeName - import java.io.File import scala.collection.JavaConverters._ @@ -72,7 +70,9 @@ class ExpressionReducer( // we don't support object literals yet, we skip those constant expressions case (SqlTypeName.ANY, _) | + (SqlTypeName.OTHER, _) | (SqlTypeName.ROW, _) | + (SqlTypeName.STRUCTURED, _) | (SqlTypeName.ARRAY, _) | (SqlTypeName.MAP, _) | (SqlTypeName.MULTISET, _) => None @@ -133,7 +133,9 @@ class ExpressionReducer( unreduced.getType.getSqlTypeName match { // we insert the original expression for object literals case SqlTypeName.ANY | + SqlTypeName.OTHER | SqlTypeName.ROW | + SqlTypeName.STRUCTURED | SqlTypeName.ARRAY | SqlTypeName.MAP | SqlTypeName.MULTISET => diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala index 4fc52d2..9d0fe44 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala @@ -34,12 +34,13 @@ import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.GeneratedExpression.{ALWAYS_NULL, NEVER_NULL, NO_CODE} import org.apache.flink.table.planner.codegen.calls.CurrentTimePointCallGen import org.apache.flink.table.planner.plan.utils.SortUtil -import org.apache.flink.table.runtime.types.PlannerTypeUtils import org.apache.flink.table.runtime.typeutils.TypeCheckUtils.{isCharacterString, 
isReference, isTemporal} import org.apache.flink.table.types.logical.LogicalTypeRoot._ import org.apache.flink.table.types.logical._ +import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{getFieldCount, getFieldTypes} import org.apache.flink.table.util.TimestampStringUtils.toLocalDateTime +import scala.annotation.tailrec import scala.collection.mutable /** @@ -209,39 +210,47 @@ object GenerateUtils { /** * Generates a record declaration statement. The record can be any type of RowData or * other types. + * * @param t the record type * @param clazz the specified class of the type (only used when RowType) * @param recordTerm the record term to be declared * @param recordWriterTerm the record writer term (only used when BinaryRowData type) * @return the record declaration statement - */ + */ + @tailrec def generateRecordStatement( t: LogicalType, clazz: Class[_], recordTerm: String, - recordWriterTerm: Option[String] = None): String = { - t match { - case rt: RowType if clazz == classOf[BinaryRowData] => - val writerTerm = recordWriterTerm.getOrElse( - throw new CodeGenException("No writer is specified when writing BinaryRowData record.") - ) - val binaryRowWriter = className[BinaryRowWriter] - val typeTerm = clazz.getCanonicalName - s""" - |final $typeTerm $recordTerm = new $typeTerm(${rt.getFieldCount}); - |final $binaryRowWriter $writerTerm = new $binaryRowWriter($recordTerm); - |""".stripMargin.trim - case rt: RowType if clazz == classOf[GenericRowData] || - clazz == classOf[BoxedWrapperRowData] => - val typeTerm = clazz.getCanonicalName - s"final $typeTerm $recordTerm = new $typeTerm(${rt.getFieldCount});" - case _: RowType if clazz == classOf[JoinedRowData] => - val typeTerm = clazz.getCanonicalName - s"final $typeTerm $recordTerm = new $typeTerm();" - case _ => - val typeTerm = boxedTypeTermForType(t) - s"final $typeTerm $recordTerm = new $typeTerm();" - } + recordWriterTerm: Option[String] = None) + : String = t.getTypeRoot match { + // ordered 
by type root definition + case ROW | STRUCTURED_TYPE if clazz == classOf[BinaryRowData] => + val writerTerm = recordWriterTerm.getOrElse( + throw new CodeGenException("No writer is specified when writing BinaryRowData record.") + ) + val binaryRowWriter = className[BinaryRowWriter] + val typeTerm = clazz.getCanonicalName + s""" + |final $typeTerm $recordTerm = new $typeTerm(${getFieldCount(t)}); + |final $binaryRowWriter $writerTerm = new $binaryRowWriter($recordTerm); + |""".stripMargin.trim + case ROW | STRUCTURED_TYPE if clazz == classOf[GenericRowData] || + clazz == classOf[BoxedWrapperRowData] => + val typeTerm = clazz.getCanonicalName + s"final $typeTerm $recordTerm = new $typeTerm(${getFieldCount(t)});" + case ROW | STRUCTURED_TYPE if clazz == classOf[JoinedRowData] => + val typeTerm = clazz.getCanonicalName + s"final $typeTerm $recordTerm = new $typeTerm();" + case DISTINCT_TYPE => + generateRecordStatement( + t.asInstanceOf[DistinctType].getSourceType, + clazz, + recordTerm, + recordWriterTerm) + case _ => + val typeTerm = boxedTypeTermForType(t) + s"final $typeTerm $recordTerm = new $typeTerm();" } def generateNullLiteral( @@ -273,6 +282,7 @@ object GenerateUtils { literalValue = Some(literalValue)) } + @tailrec def generateLiteral( ctx: CodeGeneratorContext, literalType: LogicalType, @@ -282,10 +292,41 @@ object GenerateUtils { } // non-null values literalType.getTypeRoot match { + // ordered by type root definition + case CHAR | VARCHAR => + val escapedValue = StringEscapeUtils.ESCAPE_JAVA.translate(literalValue.toString) + val field = ctx.addReusableStringConstants(escapedValue) + generateNonNullLiteral(literalType, field, StringData.fromString(escapedValue)) case BOOLEAN => generateNonNullLiteral(literalType, literalValue.toString, literalValue) + case BINARY | VARBINARY => + val bytesVal = literalValue.asInstanceOf[ByteString].getBytes + val fieldTerm = ctx.addReusableObject( + bytesVal, "binary", bytesVal.getClass.getCanonicalName) + 
generateNonNullLiteral(literalType, fieldTerm, bytesVal) + + case DECIMAL => + val dt = literalType.asInstanceOf[DecimalType] + val precision = dt.getPrecision + val scale = dt.getScale + val fieldTerm = newName("decimal") + val decimalClass = className[DecimalData] + val fieldDecimal = + s""" + |$decimalClass $fieldTerm = + | $DECIMAL_UTIL.castFrom("${literalValue.toString}", $precision, $scale); + |""".stripMargin + ctx.addReusableMember(fieldDecimal) + val value = DecimalData.fromBigDecimal( + literalValue.asInstanceOf[JBigDecimal], precision, scale) + if (value == null) { + generateNullLiteral(literalType, ctx.nullCheck) + } else { + generateNonNullLiteral(literalType, fieldTerm, value) + } + case TINYINT => val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal]) generateNonNullLiteral(literalType, decimal.byteValue().toString, decimal.byteValue()) @@ -335,36 +376,6 @@ object GenerateUtils { case _ => generateNonNullLiteral( literalType, doubleValue.toString + "d", doubleValue) } - case DECIMAL => - val dt = literalType.asInstanceOf[DecimalType] - val precision = dt.getPrecision - val scale = dt.getScale - val fieldTerm = newName("decimal") - val decimalClass = className[DecimalData] - val fieldDecimal = - s""" - |$decimalClass $fieldTerm = - | $DECIMAL_UTIL.castFrom("${literalValue.toString}", $precision, $scale); - |""".stripMargin - ctx.addReusableMember(fieldDecimal) - val value = DecimalData.fromBigDecimal( - literalValue.asInstanceOf[JBigDecimal], precision, scale) - if (value == null) { - generateNullLiteral(literalType, ctx.nullCheck) - } else { - generateNonNullLiteral(literalType, fieldTerm, value) - } - - case VARCHAR | CHAR => - val escapedValue = StringEscapeUtils.ESCAPE_JAVA.translate(literalValue.toString) - val field = ctx.addReusableStringConstants(escapedValue) - generateNonNullLiteral(literalType, field, StringData.fromString(escapedValue)) - - case VARBINARY | BINARY => - val bytesVal = literalValue.asInstanceOf[ByteString].getBytes 
- val fieldTerm = ctx.addReusableObject( - bytesVal, "binary", bytesVal.getClass.getCanonicalName) - generateNonNullLiteral(literalType, fieldTerm, bytesVal) case DATE => generateNonNullLiteral(literalType, literalValue.toString, literalValue) @@ -384,6 +395,9 @@ object GenerateUtils { ctx.addReusableMember(fieldTimestamp) generateNonNullLiteral(literalType, fieldTerm, ts) + case TIMESTAMP_WITH_TIME_ZONE => + throw new UnsupportedOperationException("Unsupported type: " + literalType) + case TIMESTAMP_WITH_LOCAL_TIME_ZONE => val fieldTerm = newName("timestampWithLocalZone") val ins = @@ -420,13 +434,19 @@ object GenerateUtils { s"Decimal '$decimal' can not be converted to interval of milliseconds.") } + case DISTINCT_TYPE => + generateLiteral(ctx, literalType.asInstanceOf[DistinctType].getSourceType, literalValue) + // Symbol type for special flags e.g. TRIM's BOTH, LEADING, TRAILING case RAW if literalType.asInstanceOf[TypeInformationRawType[_]] .getTypeInformation.getTypeClass.isAssignableFrom(classOf[Enum[_]]) => generateSymbol(literalValue.asInstanceOf[Enum[_]]) - case t@_ => - throw new CodeGenException(s"Type not supported: $t") + case SYMBOL => + throw new UnsupportedOperationException() // TODO support symbol? 
+ + case ARRAY | MULTISET | MAP | ROW | STRUCTURED_TYPE | NULL | UNRESOLVED => + throw new CodeGenException(s"Type not supported: $literalType") } } @@ -546,10 +566,15 @@ object GenerateUtils { index: Int, deepCopy: Boolean = false): GeneratedExpression = { - val fieldType = inputType match { - case ct: RowType => ct.getTypeAt(index) - case _ => inputType + @tailrec + def getFieldType(t: LogicalType, pos: Int): LogicalType = t.getTypeRoot match { + // ordered by type root definition + case ROW | STRUCTURED_TYPE => t.getChildren.get(pos) + case DISTINCT_TYPE => getFieldType(t.asInstanceOf[DistinctType].getSourceType, pos) + case _ => t } + + val fieldType = getFieldType(inputType, index) val resultTypeTerm = primitiveTypeTermForType(fieldType) val defaultValue = primitiveDefaultValue(fieldType) val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables( @@ -636,14 +661,16 @@ object GenerateUtils { } } + @tailrec def generateFieldAccess( ctx: CodeGeneratorContext, inputType: LogicalType, inputTerm: String, - index: Int): GeneratedExpression = - inputType match { - case ct: RowType => - val fieldType = ct.getTypeAt(index) + index: Int) + : GeneratedExpression = inputType.getTypeRoot match { + // ordered by type root definition + case ROW | STRUCTURED_TYPE => + val fieldType = getFieldTypes(inputType).get(index) val resultTypeTerm = primitiveTypeTermForType(fieldType) val defaultValue = primitiveDefaultValue(fieldType) val readCode = rowFieldReadAccess(ctx, index.toString, inputTerm, fieldType) @@ -667,6 +694,13 @@ object GenerateUtils { } GeneratedExpression(fieldTerm, nullTerm, inputCode, fieldType) + case DISTINCT_TYPE => + generateFieldAccess( + ctx, + inputType.asInstanceOf[DistinctType].getSourceType, + inputTerm, + index) + case _ => val fieldTypeTerm = boxedTypeTermForType(inputType) val inputCode = s"($fieldTypeTerm) $inputTerm" @@ -674,23 +708,30 @@ object GenerateUtils { } /** - * Generates code for comparing two field. 
+ * Generates code for comparing two fields. */ + @tailrec def generateCompare( ctx: CodeGeneratorContext, t: LogicalType, nullsIsLast: Boolean, leftTerm: String, - rightTerm: String): String = t.getTypeRoot match { - case BOOLEAN => s"($leftTerm == $rightTerm ? 0 : ($leftTerm ? 1 : -1))" - case DATE | TIME_WITHOUT_TIME_ZONE => - s"($leftTerm > $rightTerm ? 1 : $leftTerm < $rightTerm ? -1 : 0)" - case _ if PlannerTypeUtils.isPrimitive(t) => - s"($leftTerm > $rightTerm ? 1 : $leftTerm < $rightTerm ? -1 : 0)" - case VARBINARY | BINARY => + rightTerm: String) + : String = t.getTypeRoot match { + // ordered by type root definition + case CHAR | VARCHAR | DECIMAL | TIMESTAMP_WITHOUT_TIME_ZONE | TIMESTAMP_WITH_LOCAL_TIME_ZONE => + s"$leftTerm.compareTo($rightTerm)" + case BOOLEAN => + s"($leftTerm == $rightTerm ? 0 : ($leftTerm ? 1 : -1))" + case BINARY | VARBINARY => val sortUtil = classOf[org.apache.flink.table.runtime.operators.sort.SortUtil] .getCanonicalName s"$sortUtil.compareBinary($leftTerm, $rightTerm)" + case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | DATE | TIME_WITHOUT_TIME_ZONE | + INTERVAL_YEAR_MONTH | INTERVAL_DAY_TIME => + s"($leftTerm > $rightTerm ? 1 : $leftTerm < $rightTerm ? -1 : 0)" + case TIMESTAMP_WITH_TIME_ZONE | MULTISET | MAP => + throw new UnsupportedOperationException() // TODO support MULTISET and MAP? 
case ARRAY => val at = t.asInstanceOf[ArrayType] val compareFunc = newName("compareArray") @@ -706,13 +747,13 @@ object GenerateUtils { """ ctx.addReusableMember(funcCode) s"$compareFunc($leftTerm, $rightTerm)" - case ROW => - val rowType = t.asInstanceOf[RowType] - val orders = (0 until rowType.getFieldCount).map(_ => true).toArray + case ROW | STRUCTURED_TYPE => + val fieldCount = getFieldCount(t) + val orders = (0 until fieldCount).map(_ => true).toArray val comparisons = generateRowCompare( ctx, - (0 until rowType.getFieldCount).toArray, - rowType.getChildren.toArray(Array[LogicalType]()), + (0 until fieldCount).toArray, + getFieldTypes(t).toArray(Array[LogicalType]()), orders, SortUtil.getNullDefaultOrders(orders), "a", @@ -727,18 +768,38 @@ object GenerateUtils { """ ctx.addReusableMember(funcCode) s"$compareFunc($leftTerm, $rightTerm)" + case DISTINCT_TYPE => + generateCompare( + ctx, + t.asInstanceOf[DistinctType].getSourceType, + nullsIsLast, + leftTerm, + rightTerm) case RAW => - val rawType = t.asInstanceOf[TypeInformationRawType[_]] - val ser = ctx.addReusableObject( - rawType.getTypeInformation.createSerializer(new ExecutionConfig), "serializer") - val comp = ctx.addReusableObject( - rawType.getTypeInformation.asInstanceOf[AtomicTypeInfo[_]] - .createComparator(true, new ExecutionConfig), - "comparator") - s""" - |$comp.compare($leftTerm.toObject($ser), $rightTerm.toObject($ser)) - """.stripMargin - case other => s"$leftTerm.compareTo($rightTerm)" + t match { + case rawType: RawType[_] => + val clazz = rawType.getOriginatingClass + if (!classOf[Comparable[_]].isAssignableFrom(clazz)) { + throw new CodeGenException( + s"Raw type class '$clazz' must implement ${className[Comparable[_]]} to be used " + + s"in a comparision of two '${rawType.asSummaryString()}' types.") + } + val serializer = rawType.getTypeSerializer + val serializerTerm = ctx.addReusableObject(serializer, "serializer") + s"((${className[Comparable[_]]}) 
$leftTerm.toObject($serializerTerm))" + + s".compareTo($rightTerm.toObject($serializerTerm))" + + case rawType: TypeInformationRawType[_] => + val serializer = rawType.getTypeInformation.createSerializer(new ExecutionConfig) + val ser = ctx.addReusableObject(serializer, "serializer") + val comp = ctx.addReusableObject( + rawType.getTypeInformation.asInstanceOf[AtomicTypeInfo[_]] + .createComparator(true, new ExecutionConfig), + "comparator") + s"$comp.compare($leftTerm.toObject($ser), $rightTerm.toObject($ser))" + } + case NULL | SYMBOL | UNRESOLVED => + throw new IllegalArgumentException("Illegal type: " + t) } /** diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala index 7493aa2..1fbacc0 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala @@ -18,7 +18,6 @@ package org.apache.flink.table.planner.codegen.agg.batch -import org.apache.flink.api.common.ExecutionConfig import org.apache.flink.runtime.util.SingleElementIterator import org.apache.flink.streaming.api.operators.OneInputStreamOperator import org.apache.flink.table.data.{GenericRowData, RowData} @@ -39,12 +38,13 @@ import org.apache.flink.table.runtime.types.InternalSerializers import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.{fromDataTypeToLogicalType, fromLogicalTypeToDataType} import org.apache.flink.table.types.DataType import org.apache.flink.table.types.logical.LogicalTypeRoot._ -import org.apache.flink.table.types.logical.{LogicalType, RowType} - +import org.apache.flink.table.types.logical.{DistinctType, LogicalType, RowType} import 
org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.rex.RexNode import org.apache.calcite.tools.RelBuilder +import scala.annotation.tailrec + /** * Batch aggregate code generate helper. */ @@ -360,16 +360,7 @@ object AggCodeGenHelper { aggBufferExprs.zip(initAggBufferExprs).map { case (aggBufVar, initExpr) => - val resultCode = aggBufVar.resultType.getTypeRoot match { - case VARCHAR | CHAR | ROW | ARRAY | MULTISET | MAP => - val serializer = InternalSerializers.create( - aggBufVar.resultType, new ExecutionConfig) - val term = ctx.addReusableObject( - serializer, "serializer", serializer.getClass.getCanonicalName) - val typeTerm = boxedTypeTermForType(aggBufVar.resultType) - s"($typeTerm) $term.copy(${initExpr.resultTerm})" - case _ => initExpr.resultTerm - } + val resultCode = genElementCopyTerm(ctx, aggBufVar.resultType, initExpr.resultTerm) s""" |${initExpr.code} |${aggBufVar.nullTerm} = ${initExpr.nullTerm}; @@ -378,6 +369,23 @@ object AggCodeGenHelper { } mkString "\n" } + @tailrec + private def genElementCopyTerm( + ctx: CodeGeneratorContext, + t: LogicalType, + inputTerm: String) + : String = t.getTypeRoot match { + case CHAR | VARCHAR | ARRAY | MULTISET | MAP | ROW | STRUCTURED_TYPE => + val serializer = InternalSerializers.create(t) + val term = ctx.addReusableObject( + serializer, "serializer", serializer.getClass.getCanonicalName) + val typeTerm = boxedTypeTermForType(t) + s"($typeTerm) $term.copy($inputTerm)" + case DISTINCT_TYPE => + genElementCopyTerm(ctx, t.asInstanceOf[DistinctType].getSourceType, inputTerm) + case _ => inputTerm + } + private[flink] def genAggregateByFlatAggregateBuffer( isMerge: Boolean, ctx: CodeGeneratorContext, diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java index 319021d..9a58882 100644 --- 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java @@ -18,6 +18,7 @@ package org.apache.flink.table.runtime.typeutils; +import org.apache.flink.table.types.logical.DistinctType; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.LogicalTypeFamily; import org.apache.flink.table.types.logical.TimestampKind; @@ -60,9 +61,10 @@ public class TypeCheckUtils { } public static boolean isTimeInterval(LogicalType type) { + // ordered by type root definition switch (type.getTypeRoot()) { - case INTERVAL_DAY_TIME: case INTERVAL_YEAR_MONTH: + case INTERVAL_DAY_TIME: return true; default: return false; @@ -122,22 +124,28 @@ public class TypeCheckUtils { } public static boolean isMutable(LogicalType type) { - // the internal representation of String is StringData which is mutable + // ordered by type root definition switch (type.getTypeRoot()) { - case VARCHAR: case CHAR: + case VARCHAR: // the internal representation of String is StringData which is mutable case ARRAY: case MULTISET: case MAP: case ROW: + case STRUCTURED_TYPE: case RAW: return true; + case TIMESTAMP_WITH_TIME_ZONE: + throw new UnsupportedOperationException("Unsupported type: " + type); + case DISTINCT_TYPE: + return isMutable(((DistinctType) type).getSourceType()); default: return false; } } public static boolean isReference(LogicalType type) { + // ordered by type root definition switch (type.getTypeRoot()) { case BOOLEAN: case TINYINT: @@ -153,6 +161,10 @@ public class TypeCheckUtils { case INTERVAL_YEAR_MONTH: case INTERVAL_DAY_TIME: return false; + case TIMESTAMP_WITH_TIME_ZONE: + throw new UnsupportedOperationException("Unsupported type: " + type); + case DISTINCT_TYPE: + return isReference(((DistinctType) type).getSourceType()); default: return true; }