[FLINK-3226] Implement a CodeGenerator for an efficient translation to DataSet programs
This closes #1595 Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/a4ad9dd5 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/a4ad9dd5 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/a4ad9dd5 Branch: refs/heads/tableOnCalcite Commit: a4ad9dd566353c578ad660fda7039757a605f27d Parents: 99f60c8 Author: twalthr <twal...@apache.org> Authored: Fri Feb 5 17:40:54 2016 +0100 Committer: Fabian Hueske <fhue...@apache.org> Committed: Fri Feb 12 11:34:09 2016 +0100 ---------------------------------------------------------------------- .../api/java/table/JavaBatchTranslator.scala | 11 +- .../flink/api/java/table/TableEnvironment.scala | 15 +- .../api/scala/table/ScalaBatchTranslator.scala | 6 +- .../api/scala/table/TableConversions.scala | 9 +- .../api/scala/table/TableEnvironment.scala | 75 ++ .../apache/flink/api/table/TableConfig.scala | 30 +- .../apache/flink/api/table/TableException.scala | 23 + .../api/table/codegen/CodeGenException.scala | 24 + .../flink/api/table/codegen/CodeGenUtils.scala | 176 ++++ .../flink/api/table/codegen/CodeGenerator.scala | 752 ++++++++++++++++++ .../table/codegen/ExpressionCodeGenerator.scala | 794 ------------------- .../api/table/codegen/GenerateFilter.scala | 99 --- .../flink/api/table/codegen/GenerateJoin.scala | 171 ---- .../table/codegen/GenerateResultAssembler.scala | 119 --- .../api/table/codegen/GenerateSelect.scala | 84 -- .../api/table/codegen/GeneratedExpression.scala | 27 + .../api/table/codegen/GeneratedFunction.scala | 23 + .../flink/api/table/codegen/Indenter.scala | 10 +- .../api/table/codegen/OperatorCodeGen.scala | 367 +++++++++ .../flink/api/table/expressions/literals.scala | 6 +- .../flink/api/table/plan/TypeConverter.scala | 83 +- .../plan/nodes/dataset/DataSetExchange.scala | 11 +- .../plan/nodes/dataset/DataSetFlatMap.scala | 24 +- .../plan/nodes/dataset/DataSetGroupReduce.scala | 7 +- .../table/plan/nodes/dataset/DataSetJoin.scala | 7 +- 
.../table/plan/nodes/dataset/DataSetMap.scala | 26 +- .../plan/nodes/dataset/DataSetReduce.scala | 7 +- .../table/plan/nodes/dataset/DataSetRel.scala | 8 +- .../table/plan/nodes/dataset/DataSetSort.scala | 11 +- .../plan/nodes/dataset/DataSetSource.scala | 190 ++--- .../table/plan/nodes/dataset/DataSetUnion.scala | 7 +- .../plan/rules/dataset/DataSetFilterRule.scala | 50 +- .../plan/rules/dataset/DataSetProjectRule.scala | 38 +- .../runtime/ExpressionAggregateFunction.scala | 100 --- .../runtime/ExpressionFilterFunction.scala | 50 -- .../table/runtime/ExpressionJoinFunction.scala | 57 -- .../runtime/ExpressionSelectFunction.scala | 56 -- .../flink/api/table/runtime/FlatMapRunner.scala | 51 ++ .../api/table/runtime/FunctionCompiler.scala | 35 + .../flink/api/table/runtime/MapRunner.scala | 50 ++ .../flink/api/table/runtime/package.scala | 23 - .../api/table/typeinfo/RenameOperator.scala | 36 - .../table/typeinfo/RenamingProxyTypeInfo.scala | 143 ---- .../flink/api/table/typeinfo/RowTypeInfo.scala | 23 +- .../api/java/table/test/AggregationsITCase.java | 1 - .../flink/api/java/table/test/AsITCase.java | 4 +- .../api/java/table/test/CastingITCase.java | 13 +- .../api/java/table/test/ExpressionsITCase.java | 20 +- .../flink/api/java/table/test/FilterITCase.java | 26 +- .../table/test/GroupedAggregationsITCase.java | 1 - .../flink/api/java/table/test/JoinITCase.java | 1 - .../api/java/table/test/PojoGroupingITCase.java | 4 +- .../flink/api/java/table/test/SelectITCase.java | 22 +- .../table/test/StringExpressionsITCase.java | 7 +- .../flink/api/java/table/test/UnionITCase.java | 1 - .../api/scala/table/test/CastingITCase.scala | 11 +- .../scala/table/test/ExpressionsITCase.scala | 33 +- .../api/scala/table/test/FilterITCase.scala | 59 +- .../api/scala/table/test/SelectITCase.scala | 27 +- .../table/test/StringExpressionsITCase.scala | 6 +- .../api/table/test/TableProgramsTestBase.scala | 97 +++ .../typeinfo/RenamingProxyTypeInfoTest.scala | 75 -- 
.../api/table/typeinfo/RowComparatorTest.scala | 3 +- .../api/table/typeinfo/RowSerializerTest.scala | 20 +- 64 files changed, 2199 insertions(+), 2146 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/a4ad9dd5/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/JavaBatchTranslator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/JavaBatchTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/JavaBatchTranslator.scala index f70f477..7e8ee77 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/JavaBatchTranslator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/JavaBatchTranslator.scala @@ -25,7 +25,7 @@ import org.apache.calcite.tools.Programs import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.{DataSet => JavaDataSet} import org.apache.flink.api.table.plan._ -import org.apache.flink.api.table.Table +import org.apache.flink.api.table.{TableConfig, Table} import org.apache.flink.api.table.plan.nodes.dataset.{DataSetConvention, DataSetRel} import org.apache.flink.api.table.plan.rules.FlinkRuleSets import org.apache.flink.api.table.plan.schema.DataSetTable @@ -34,7 +34,7 @@ import org.apache.flink.api.table.plan.schema.DataSetTable * [[PlanTranslator]] for creating [[Table]]s from Java [[org.apache.flink.api.java.DataSet]]s and * translating them back to Java [[org.apache.flink.api.java.DataSet]]s. 
*/ -class JavaBatchTranslator extends PlanTranslator { +class JavaBatchTranslator(config: TableConfig) extends PlanTranslator { type Representation[A] = JavaDataSet[A] @@ -68,7 +68,7 @@ class JavaBatchTranslator extends PlanTranslator { println("Input Plan:") println("-----------") println(RelOptUtil.toString(lPlan)) - + // decorrelate val decorPlan = RelDecorrelator.decorrelateQuery(lPlan) @@ -96,7 +96,10 @@ class JavaBatchTranslator extends PlanTranslator { dataSetPlan match { case node: DataSetRel => - node.translateToPlan.asInstanceOf[JavaDataSet[A]] + node.translateToPlan( + config, + Some(tpe.asInstanceOf[TypeInformation[Any]]) + ).asInstanceOf[JavaDataSet[A]] case _ => ??? } http://git-wip-us.apache.org/repos/asf/flink/blob/a4ad9dd5/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/TableEnvironment.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/TableEnvironment.scala index 01e38db..2027037 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/TableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/TableEnvironment.scala @@ -20,7 +20,7 @@ package org.apache.flink.api.java.table import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.DataSet import org.apache.flink.api.java.typeutils.TypeExtractor -import org.apache.flink.api.table.Table +import org.apache.flink.api.table.{TableConfig, Table} /** * Environment for working with the Table API. @@ -30,6 +30,13 @@ import org.apache.flink.api.table.Table */ class TableEnvironment { + private val config = new TableConfig() + + /** + * Returns the table config to define the runtime behavior of the Table API. 
+ */ + def getConfig = config + /** * Transforms the given DataSet to a [[org.apache.flink.api.table.Table]]. * The fields of the DataSet type are renamed to the given set of fields: @@ -44,7 +51,7 @@ class TableEnvironment { * are named a and b. */ def fromDataSet[T](set: DataSet[T], fields: String): Table = { - new JavaBatchTranslator().createTable(set, fields) + new JavaBatchTranslator(config).createTable(set, fields) } /** @@ -53,7 +60,7 @@ class TableEnvironment { * [[org.apache.flink.api.table.Table]] fields. */ def fromDataSet[T](set: DataSet[T]): Table = { - new JavaBatchTranslator().createTable(set) + new JavaBatchTranslator(config).createTable(set) } /** @@ -64,7 +71,7 @@ class TableEnvironment { */ @SuppressWarnings(Array("unchecked")) def toDataSet[T](table: Table, clazz: Class[T]): DataSet[T] = { - new JavaBatchTranslator().translate[T](table.relNode)( + new JavaBatchTranslator(config).translate[T](table.relNode)( TypeExtractor.createTypeInfo(clazz).asInstanceOf[TypeInformation[T]]) } http://git-wip-us.apache.org/repos/asf/flink/blob/a4ad9dd5/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/ScalaBatchTranslator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/ScalaBatchTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/ScalaBatchTranslator.scala index cc92c37..642654a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/ScalaBatchTranslator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/ScalaBatchTranslator.scala @@ -22,7 +22,7 @@ import org.apache.calcite.rel.RelNode import org.apache.flink.api.java.table.JavaBatchTranslator import org.apache.flink.api.scala.wrap import org.apache.flink.api.table.plan._ -import org.apache.flink.api.table.Table +import org.apache.flink.api.table.{TableConfig, 
Table} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.scala.DataSet @@ -32,9 +32,9 @@ import scala.reflect.ClassTag * [[PlanTranslator]] for creating [[Table]]s from Scala [[DataSet]]s and * translating them back to Scala [[DataSet]]s. */ -class ScalaBatchTranslator extends PlanTranslator { +class ScalaBatchTranslator(config: TableConfig = TableConfig.DEFAULT) extends PlanTranslator { - private val javaTranslator = new JavaBatchTranslator + private val javaTranslator = new JavaBatchTranslator(config) type Representation[A] = DataSet[A] http://git-wip-us.apache.org/repos/asf/flink/blob/a4ad9dd5/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableConversions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableConversions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableConversions.scala index fdcd804..74c8ee8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableConversions.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableConversions.scala @@ -29,11 +29,18 @@ import org.apache.flink.api.table._ class TableConversions(table: Table) { /** - * Converts the [[Table]] to a [[DataSet]]. + * Converts the [[Table]] to a [[DataSet]] using the default configuration. */ def toDataSet[T: TypeInformation]: DataSet[T] = { new ScalaBatchTranslator().translate[T](table.relNode) } + /** + * Converts the [[Table]] to a [[DataSet]] using a custom configuration. 
+ */ + def toDataSet[T: TypeInformation](config: TableConfig): DataSet[T] = { + new ScalaBatchTranslator(config).translate[T](table.relNode) + } + } http://git-wip-us.apache.org/repos/asf/flink/blob/a4ad9dd5/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableEnvironment.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableEnvironment.scala new file mode 100644 index 0000000..a05bb48 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableEnvironment.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.scala.table + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.scala.DataSet +import org.apache.flink.api.table.expressions.Expression +import org.apache.flink.api.table.{TableConfig, Table} + +/** + * Environment for working with the Table API. + * + * This can be used to convert a [[DataSet]] to a [[Table]] and back again. 
You + * can also use the provided methods to create a [[Table]] directly from a data source. + */ +class TableEnvironment { + + private val config = new TableConfig() + + /** + * Returns the table config to define the runtime behavior of the Table API. + */ + def getConfig = config + + /** + * Converts the [[DataSet]] to a [[Table]]. The field names can be specified like this: + * + * {{{ + * val in: DataSet[(String, Int)] = ... + * val table = in.as('a, 'b) + * }}} + * + * This results in a [[Table]] that has field `a` of type `String` and field `b` + * of type `Int`. + */ + def fromDataSet[T](set: DataSet[T], fields: Expression*): Table = { + new ScalaBatchTranslator(config).createTable(set, fields.toArray) + } + + /** + * Transforms the given DataSet to a [[org.apache.flink.api.table.Table]]. + * The fields of the DataSet type are used to name the + * [[org.apache.flink.api.table.Table]] fields. + */ + def fromDataSet[T](set: DataSet[T]): Table = { + new ScalaBatchTranslator(config).createTable(set) + } + + /** + * Converts the given [[org.apache.flink.api.table.Table]] to + * a DataSet. The given type must have exactly the same fields as the + * [[org.apache.flink.api.table.Table]]. That is, the names of the + * fields and the types must match. 
+ */ + def toDataSet[T: TypeInformation](table: Table): DataSet[T] = { + new ScalaBatchTranslator(config).translate[T](table.relNode) + } + +} + http://git-wip-us.apache.org/repos/asf/flink/blob/a4ad9dd5/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableConfig.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableConfig.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableConfig.scala index ffa2bec..e93d37d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableConfig.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableConfig.scala @@ -35,9 +35,15 @@ class TableConfig extends Serializable { private var nullCheck: Boolean = false /** + * Defines if efficient types (such as Tuple types or Atomic types) + * should be used within operators where possible. + */ + private var efficientTypeUsage = false + + /** * Sets the timezone for date/time/timestamp conversions. */ - def setTimeZone(timeZone: TimeZone) = { + def setTimeZone(timeZone: TimeZone): Unit = { require(timeZone != null, "timeZone must not be null.") this.timeZone = timeZone } @@ -55,12 +61,30 @@ class TableConfig extends Serializable { /** * Sets the NULL check. If enabled, all fields need to be checked for NULL first. */ - def setNullCheck(nullCheck: Boolean) = { + def setNullCheck(nullCheck: Boolean): Unit = { this.nullCheck = nullCheck } + /** + * Returns the usage of efficient types. If enabled, efficient types (such as Tuple types + * or Atomic types) are used within operators where possible. + * + * NOTE: Currently, this is an experimental feature. + */ + def getEfficientTypeUsage = efficientTypeUsage + + /** + * Sets the usage of efficient types. If enabled, efficient types (such as Tuple types + * or Atomic types) are used within operators where possible. 
+ * + * NOTE: Currently, this is an experimental feature. + */ + def setEfficientTypeUsage(efficientTypeUsage: Boolean): Unit = { + this.efficientTypeUsage = efficientTypeUsage + } + } object TableConfig { - val DEFAULT = new TableConfig() + def DEFAULT = new TableConfig() } http://git-wip-us.apache.org/repos/asf/flink/blob/a4ad9dd5/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableException.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableException.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableException.scala new file mode 100644 index 0000000..3e298a4 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableException.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table + +/** + * General Exception for all errors during table handling. 
+ */ +class TableException(msg: String) extends RuntimeException(msg) http://git-wip-us.apache.org/repos/asf/flink/blob/a4ad9dd5/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenException.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenException.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenException.scala new file mode 100644 index 0000000..8b7559f --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenException.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.table.codegen + +/** + * Exception for all errors occurring during code generation. 
+ */ +class CodeGenException(msg: String) extends RuntimeException(msg) http://git-wip-us.apache.org/repos/asf/flink/blob/a4ad9dd5/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala new file mode 100644 index 0000000..5bd1467 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.api.table.codegen + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo._ +import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation} +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.api.java.typeutils.{PojoTypeInfo, TupleTypeInfo} +import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo +import org.apache.flink.api.table.typeinfo.RowTypeInfo + +object CodeGenUtils { + + private val nameCounter = new AtomicInteger + + def newName(name: String): String = { + s"$name$$${nameCounter.getAndIncrement}" + } + + // when casting we first need to unbox Primitives, for example, + // float a = 1.0f; + // byte b = (byte) a; + // works, but for boxed types we need this: + // Float a = 1.0f; + // Byte b = (byte)(float) a; + def primitiveTypeTermForTypeInfo(tpe: TypeInformation[_]): String = tpe match { + case INT_TYPE_INFO => "int" + case LONG_TYPE_INFO => "long" + case SHORT_TYPE_INFO => "short" + case BYTE_TYPE_INFO => "byte" + case FLOAT_TYPE_INFO => "float" + case DOUBLE_TYPE_INFO => "double" + case BOOLEAN_TYPE_INFO => "boolean" + case CHAR_TYPE_INFO => "char" + + // From PrimitiveArrayTypeInfo we would get class "int[]", scala reflections + // does not seem to like this, so we manually give the correct type here. 
+ case INT_PRIMITIVE_ARRAY_TYPE_INFO => "int[]" + case LONG_PRIMITIVE_ARRAY_TYPE_INFO => "long[]" + case SHORT_PRIMITIVE_ARRAY_TYPE_INFO => "short[]" + case BYTE_PRIMITIVE_ARRAY_TYPE_INFO => "byte[]" + case FLOAT_PRIMITIVE_ARRAY_TYPE_INFO => "float[]" + case DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO => "double[]" + case BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO => "boolean[]" + case CHAR_PRIMITIVE_ARRAY_TYPE_INFO => "char[]" + + case _ => + tpe.getTypeClass.getCanonicalName + } + + def boxedTypeTermForTypeInfo(tpe: TypeInformation[_]): String = tpe match { + // From PrimitiveArrayTypeInfo we would get class "int[]", scala reflections + // does not seem to like this, so we manually give the correct type here. + case INT_PRIMITIVE_ARRAY_TYPE_INFO => "int[]" + case LONG_PRIMITIVE_ARRAY_TYPE_INFO => "long[]" + case SHORT_PRIMITIVE_ARRAY_TYPE_INFO => "short[]" + case BYTE_PRIMITIVE_ARRAY_TYPE_INFO => "byte[]" + case FLOAT_PRIMITIVE_ARRAY_TYPE_INFO => "float[]" + case DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO => "double[]" + case BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO => "boolean[]" + case CHAR_PRIMITIVE_ARRAY_TYPE_INFO => "char[]" + + case _ => + tpe.getTypeClass.getCanonicalName + } + + def primitiveDefaultValue(tpe: TypeInformation[_]): String = tpe match { + case INT_TYPE_INFO => "-1" + case LONG_TYPE_INFO => "-1" + case SHORT_TYPE_INFO => "-1" + case BYTE_TYPE_INFO => "-1" + case FLOAT_TYPE_INFO => "-1.0f" + case DOUBLE_TYPE_INFO => "-1.0d" + case BOOLEAN_TYPE_INFO => "false" + case STRING_TYPE_INFO => "\"<empty>\"" + case CHAR_TYPE_INFO => "'\\0'" + case _ => "null" + } + + def requireNumeric(genExpr: GeneratedExpression) = genExpr.resultType match { + case nti: NumericTypeInfo[_] => // ok + case _ => throw new CodeGenException("Numeric expression type expected.") + } + + def requireString(genExpr: GeneratedExpression) = genExpr.resultType match { + case STRING_TYPE_INFO => // ok + case _ => throw new CodeGenException("String expression type expected.") + } + + def requireBoolean(genExpr: 
GeneratedExpression) = genExpr.resultType match { + case BOOLEAN_TYPE_INFO => // ok + case _ => throw new CodeGenException("Boolean expression type expected.") + } + + def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType) + + def isReference(typeInfo: TypeInformation[_]): Boolean = typeInfo match { + case INT_TYPE_INFO + | LONG_TYPE_INFO + | SHORT_TYPE_INFO + | BYTE_TYPE_INFO + | FLOAT_TYPE_INFO + | DOUBLE_TYPE_INFO + | BOOLEAN_TYPE_INFO + | CHAR_TYPE_INFO => false + case _ => true + } + + def isNumeric(genExpr: GeneratedExpression): Boolean = isNumeric(genExpr.resultType) + + def isNumeric(typeInfo: TypeInformation[_]): Boolean = typeInfo match { + case nti: NumericTypeInfo[_] => true + case _ => false + } + + def isString(genExpr: GeneratedExpression): Boolean = isString(genExpr.resultType) + + def isString(typeInfo: TypeInformation[_]): Boolean = typeInfo match { + case STRING_TYPE_INFO => true + case _ => false + } + + def isBoolean(genExpr: GeneratedExpression): Boolean = isBoolean(genExpr.resultType) + + def isBoolean(typeInfo: TypeInformation[_]): Boolean = typeInfo match { + case BOOLEAN_TYPE_INFO => true + case _ => false + } + + // ---------------------------------------------------------------------------------------------- + + sealed abstract class FieldAccessor + + case class ObjectFieldAccessor(fieldName: String) extends FieldAccessor + + case class ObjectMethodAccessor(methodName: String) extends FieldAccessor + + case class ProductAccessor(i: Int) extends FieldAccessor + + def fieldAccessorFor(compType: CompositeType[_], index: Int): FieldAccessor = { + compType match { + case ri: RowTypeInfo => + ProductAccessor(index) + + case cc: CaseClassTypeInfo[_] => + ObjectMethodAccessor(cc.getFieldNames()(index)) + + case javaTup: TupleTypeInfo[_] => + ObjectFieldAccessor("f" + index) + + case pj: PojoTypeInfo[_] => + ObjectFieldAccessor(pj.getFieldNames()(index)) + + case _ => throw new CodeGenException("Unsupported 
composite type.") + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/a4ad9dd5/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala new file mode 100644 index 0000000..a4ae4b1 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala @@ -0,0 +1,752 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.api.table.codegen + +import org.apache.calcite.rex._ +import org.apache.calcite.sql.`type`.SqlTypeName._ +import org.apache.calcite.sql.fun.SqlStdOperatorTable._ +import org.apache.flink.api.common.functions.{FlatMapFunction, Function, MapFunction} +import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation} +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.api.java.typeutils.{PojoTypeInfo, TupleTypeInfo} +import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo +import org.apache.flink.api.table.TableConfig +import org.apache.flink.api.table.codegen.CodeGenUtils._ +import org.apache.flink.api.table.codegen.Indenter.toISC +import org.apache.flink.api.table.codegen.OperatorCodeGen._ +import org.apache.flink.api.table.plan.TypeConverter.sqlTypeToTypeInfo +import org.apache.flink.api.table.typeinfo.RowTypeInfo + +import scala.collection.JavaConversions._ +import scala.collection.mutable + +/** + * A code generator for generating Flink [[org.apache.flink.api.common.functions.Function]]s. 
+ * + * @param config configuration that determines runtime behavior + * @param input1 type information about the first input of the Function + * @param input2 type information about the second input if the Function is binary + */ +class CodeGenerator( + config: TableConfig, + input1: TypeInformation[Any], + input2: Option[TypeInformation[Any]] = None) + extends RexVisitor[GeneratedExpression] { + + // set of member statements that will be added only once + // we use a LinkedHashSet to keep the insertion order + private val reusableMemberStatements = mutable.LinkedHashSet[String]() + + // set of constructor statements that will be added only once + // we use a LinkedHashSet to keep the insertion order + private val reusableInitStatements = mutable.LinkedHashSet[String]() + + // map of initial input unboxing expressions that will be added only once + // (inputTerm, index) -> expr + private val reusableInputUnboxingExprs = mutable.Map[(String, Int), GeneratedExpression]() + + /** + * @return code block of statements that need to be placed in the member area of the Function + * (e.g. 
member variables and their initialization) + */ + def reuseMemberCode(): String = { + reusableMemberStatements.mkString("", "\n", "\n") + } + + /** + * @return code block of statements that need to be placed in the constructor of the Function + */ + def reuseInitCode(): String = { + reusableInitStatements.mkString("", "\n", "\n") + } + + /** + * @return code block of statements that unbox input variables to a primitive variable + * and a corresponding null flag variable + */ + def reuseInputUnboxingCode(): String = { + reusableInputUnboxingExprs.values.map(_.code).mkString("", "\n", "\n") + } + + /** + * @return term of the (casted and possibly boxed) first input + */ + def input1Term = "in1" + + /** + * @return term of the (casted and possibly boxed) second input + */ + def input2Term = "in2" + + /** + * @return term of the (casted) output collector + */ + def collectorTerm = "c" + + /** + * @return term of the output record (possibly defined in the member area e.g. Row, Tuple) + */ + def outRecordTerm = "out" + + /** + * @return whether null checking is enabled + */ + def nullCheck: Boolean = config.getNullCheck + + /** + * Generates an expression from a RexNode. If objects or variables can be reused, they will be + * added to reusable code sections internally. + * + * @param rex Calcite row expression + * @return instance of GeneratedExpression + */ + def generateExpression(rex: RexNode): GeneratedExpression = { + rex.accept(this) + } + + /** + * Generates a [[org.apache.flink.api.common.functions.Function]] that can be passed to Java + * compiler. + * + * @param name Class name of the Function. Does not have to be unique but has to be a valid Java class + * identifier. + * @param clazz Flink Function to be generated. + * @param bodyCode code contents of the SAM (Single Abstract Method). Inputs, collector, or + * output record can be accessed via the given term methods. + * @param returnType expected return type + * @tparam T Flink Function to be generated.
+ * @return instance of GeneratedFunction + */ + def generateFunction[T <: Function]( + name: String, + clazz: Class[T], + bodyCode: String, + returnType: TypeInformation[Any]) + : GeneratedFunction[T] = { + val funcName = newName(name) + + // Janino does not support generics, that's why we need + // manual casting here + val samHeader = + // FlatMapFunction + if (clazz == classOf[FlatMapFunction[_,_]]) { + val inputTypeTerm = boxedTypeTermForTypeInfo(input1) + (s"void flatMap(Object _in1, org.apache.flink.util.Collector $collectorTerm)", + s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;") + } + + // MapFunction + else if (clazz == classOf[MapFunction[_,_]]) { + val inputTypeTerm = boxedTypeTermForTypeInfo(input1) + ("Object map(Object _in1)", + s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;") + } + + else { + // TODO more functions + throw new CodeGenException("Unsupported Function.") + } + + val funcCode = j""" + public class $funcName + implements ${clazz.getCanonicalName} { + + ${reuseMemberCode()} + + public $funcName() { + ${reuseInitCode()} + } + + @Override + public ${samHeader._1} { + ${samHeader._2} + ${reuseInputUnboxingCode()} + $bodyCode + } + } + """.stripMargin + + GeneratedFunction(funcName, returnType, funcCode) + } + + /** + * Generates an expression that converts the first input (and second input) into the given type. + * If two inputs are converted, the second input is appended. If objects or variables can + * be reused, they will be added to reusable code sections internally. The evaluation result + * may be stored in the global result variable (see [[outRecordTerm]]). + * + * @param returnType conversion target type. Inputs and output must have the same arity. 
+ * @return instance of GeneratedExpression + */ + def generateConverterResultExpression( + returnType: TypeInformation[_ <: Any]) + : GeneratedExpression = { + val input1AccessExprs = for (i <- 0 until input1.getArity) + yield generateInputAccess(input1, input1Term, i) + + val input2AccessExprs = input2 match { + case Some(ti) => for (i <- 0 until ti.getArity) + yield generateInputAccess(ti, input2Term, i) + case None => Seq() // add nothing + } + + generateResultExpression(input1AccessExprs ++ input2AccessExprs, returnType) + } + + /** + * Generates an expression from a sequence of RexNode. If objects or variables can be reused, + * they will be added to reusable code sections internally. The evaluation result + * may be stored in the global result variable (see [[outRecordTerm]]). + * + * @param returnType conversion target type. Type must have the same arity than rexNodes. + * @param rexNodes sequence of RexNode + * @return instance of GeneratedExpression + */ + def generateResultExpression( + returnType: TypeInformation[_ <: Any], + rexNodes: Seq[RexNode]) + : GeneratedExpression = { + val fieldExprs = rexNodes.map(generateExpression) + generateResultExpression(fieldExprs, returnType) + } + + /** + * Generates an expression from a sequence of other expressions. If objects or variables can + * be reused, they will be added to reusable code sections internally. The evaluation result + * may be stored in the global result variable (see [[outRecordTerm]]). + * + * @param fieldExprs + * @param returnType conversion target type. Type must have the same arity than fieldExprs. 
+ * @return instance of GeneratedExpression + */ + def generateResultExpression( + fieldExprs: Seq[GeneratedExpression], + returnType: TypeInformation[_ <: Any]) + : GeneratedExpression = { + // TODO disable arity check for Rows and derive row arity from fieldExprs + // initial type check + if (returnType.getArity != fieldExprs.length) { + throw new CodeGenException("Arity of result type does not match number of expressions.") + } + // type check + returnType match { + case ct: CompositeType[_] => + fieldExprs.zipWithIndex foreach { + case (fieldExpr, i) if fieldExpr.resultType != ct.getTypeAt(i) => + throw new CodeGenException("Incompatible types of expression and result type.") + case _ => // ok + } + case at: AtomicType[_] if at != fieldExprs.head.resultType => + throw new CodeGenException("Incompatible types of expression and result type.") + case _ => // ok + } + + val returnTypeTerm = boxedTypeTermForTypeInfo(returnType) + + // generate result expression + returnType match { + case ri: RowTypeInfo => + addReusableOutRecord(ri) + val resultSetters: String = fieldExprs.zipWithIndex map { + case (fieldExpr, i) => + if (nullCheck) { + s""" + |${fieldExpr.code} + |if (${fieldExpr.nullTerm}) { + | $outRecordTerm.setField($i, null); + |} + |else { + | $outRecordTerm.setField($i, ${fieldExpr.resultTerm}); + |} + |""".stripMargin + } + else { + s""" + |${fieldExpr.code} + |$outRecordTerm.setField($i, ${fieldExpr.resultTerm}); + |""".stripMargin + } + } mkString "\n" + + GeneratedExpression(outRecordTerm, "false", resultSetters, returnType) + + case pj: PojoTypeInfo[_] => + addReusableOutRecord(pj) + val resultSetters: String = fieldExprs.zip(pj.getFieldNames) map { + case (fieldExpr, fieldName) => + if (nullCheck) { + s""" + |${fieldExpr.code} + |if (${fieldExpr.nullTerm}) { + | $outRecordTerm.$fieldName = null; + |} + |else { + | $outRecordTerm.$fieldName = ${fieldExpr.resultTerm}; + |} + |""".stripMargin + } + else { + s""" + |${fieldExpr.code} + 
|$outRecordTerm.$fieldName = ${fieldExpr.resultTerm}; + |""".stripMargin + } + } mkString "\n" + + GeneratedExpression(outRecordTerm, "false", resultSetters, returnType) + + case tup: TupleTypeInfo[_] => + addReusableOutRecord(tup) + val resultSetters: String = fieldExprs.zipWithIndex map { + case (fieldExpr, i) => + val fieldName = "f" + i + if (nullCheck) { + s""" + |${fieldExpr.code} + |if (${fieldExpr.nullTerm}) { + | throw new NullPointerException("Null result cannot be stored in a Tuple."); + |} + |else { + | $outRecordTerm.$fieldName = ${fieldExpr.resultTerm}; + |} + |""".stripMargin + } + else { + s""" + |${fieldExpr.code} + |$outRecordTerm.$fieldName = ${fieldExpr.resultTerm}; + |""".stripMargin + } + } mkString "\n" + + GeneratedExpression(outRecordTerm, "false", resultSetters, returnType) + + case cc: CaseClassTypeInfo[_] => + val fieldCodes: String = fieldExprs.map(_.code).mkString("\n") + val constructorParams: String = fieldExprs.map(_.resultTerm).mkString(", ") + val resultTerm = newName(outRecordTerm) + + val nullCheckCode = if (nullCheck) { + fieldExprs map { (fieldExpr) => + s""" + |if (${fieldExpr.nullTerm}) { + | throw new NullPointerException("Null result cannot be stored in a Case Class."); + |} + |""".stripMargin + } mkString "\n" + } else { + "" + } + + val resultCode = + s""" + |$fieldCodes + |$nullCheckCode + |$returnTypeTerm $resultTerm = new $returnTypeTerm($constructorParams); + |""".stripMargin + + GeneratedExpression(resultTerm, "false", resultCode, returnType) + + case a: AtomicType[_] => + val fieldExpr = fieldExprs.head + val nullCheckCode = if (nullCheck) { + s""" + |if (${fieldExpr.nullTerm}) { + | throw new NullPointerException("Null result cannot be used for atomic types."); + |} + |""".stripMargin + } else { + "" + } + val resultCode = + s""" + |${fieldExpr.code} + |$nullCheckCode + |""".stripMargin + + GeneratedExpression(fieldExpr.resultTerm, "false", resultCode, returnType) + + case _ => + throw new 
CodeGenException(s"Unsupported result type: $returnType") + } + } + + // ---------------------------------------------------------------------------------------------- + // RexVisitor methods + // ---------------------------------------------------------------------------------------------- + + override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = { + // if inputRef index is within size of input1 we work with input1, input2 otherwise + val input = if (inputRef.getIndex < input1.getArity) { + (input1, input1Term) + } else { + (input2.getOrElse(throw new CodeGenException("Invalid input access.")), input2Term) + } + + val index = if (input._1 == input1) { + inputRef.getIndex + } else { + inputRef.getIndex - input1.getArity + } + + generateInputAccess(input._1, input._2, index) + } + + override def visitFieldAccess(rexFieldAccess: RexFieldAccess): GeneratedExpression = ??? + + override def visitLiteral(literal: RexLiteral): GeneratedExpression = { + val resultType = sqlTypeToTypeInfo(literal.getType.getSqlTypeName) + val value = literal.getValue3 + literal.getType.getSqlTypeName match { + case BOOLEAN => + generateNonNullLiteral(resultType, literal.getValue3.toString) + case TINYINT => + val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) + if (decimal.isValidByte) { + generateNonNullLiteral(resultType, decimal.byteValue().toString) + } + else { + throw new CodeGenException("Decimal can not be converted to byte.") + } + case SMALLINT => + val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) + if (decimal.isValidShort) { + generateNonNullLiteral(resultType, decimal.shortValue().toString) + } + else { + throw new CodeGenException("Decimal can not be converted to short.") + } + case INTEGER => + val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) + if (decimal.isValidShort) { + generateNonNullLiteral(resultType, decimal.intValue().toString) + } + else { + throw new CodeGenException("Decimal can not be 
converted to integer.") + } + case BIGINT => + val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) + if (decimal.isValidLong) { + generateNonNullLiteral(resultType, decimal.longValue().toString) + } + else { + throw new CodeGenException("Decimal can not be converted to long.") + } + case FLOAT => + val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) + if (decimal.isValidFloat) { + generateNonNullLiteral(resultType, decimal.floatValue().toString + "f") + } + else { + throw new CodeGenException("Decimal can not be converted to float.") + } + case DOUBLE => + val decimal = BigDecimal(value.asInstanceOf[java.math.BigDecimal]) + if (decimal.isValidDouble) { + generateNonNullLiteral(resultType, decimal.doubleValue().toString) + } + else { + throw new CodeGenException("Decimal can not be converted to double.") + } + case VARCHAR | CHAR => + generateNonNullLiteral(resultType, "\"" + value.toString + "\"") + case NULL => + generateNullLiteral(resultType) + case _ => ??? // TODO more types + } + } + + override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression = ??? + + override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression = ??? + + override def visitRangeRef(rangeRef: RexRangeRef): GeneratedExpression = ??? + + override def visitDynamicParam(dynamicParam: RexDynamicParam): GeneratedExpression = ??? 
+ + override def visitCall(call: RexCall): GeneratedExpression = { + val operands = call.getOperands.map(_.accept(this)) + val resultType = sqlTypeToTypeInfo(call.getType.getSqlTypeName) + + call.getOperator match { + // arithmetic + case PLUS if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateArithmeticOperator("+", nullCheck, resultType, left, right) + + case PLUS if isString(resultType) => + val left = operands.head + val right = operands(1) + requireString(left) + generateArithmeticOperator("+", nullCheck, resultType, left, right) + + case MINUS if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateArithmeticOperator("-", nullCheck, resultType, left, right) + + case MULTIPLY if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateArithmeticOperator("*", nullCheck, resultType, left, right) + + case DIVIDE if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateArithmeticOperator("/", nullCheck, resultType, left, right) + + case MOD if isNumeric(resultType) => + val left = operands.head + val right = operands(1) + requireNumeric(left) + requireNumeric(right) + generateArithmeticOperator("%", nullCheck, resultType, left, right) + + case UNARY_MINUS if isNumeric(resultType) => + val operand = operands.head + requireNumeric(operand) + generateUnaryArithmeticOperator("-", nullCheck, resultType, operand) + + case UNARY_PLUS if isNumeric(resultType) => + val operand = operands.head + requireNumeric(operand) + generateUnaryArithmeticOperator("+", nullCheck, resultType, operand) + + // comparison + case EQUALS => + val left = operands.head + val right = operands(1) + generateEquals(nullCheck, left, right) + + case NOT_EQUALS => + val left = 
operands.head + val right = operands(1) + generateNotEquals(nullCheck, left, right) + + case GREATER_THAN => + val left = operands.head + val right = operands(1) + generateComparison(">", nullCheck, left, right) + + case GREATER_THAN_OR_EQUAL => + val left = operands.head + val right = operands(1) + generateComparison(">=", nullCheck, left, right) + + case LESS_THAN => + val left = operands.head + val right = operands(1) + generateComparison("<", nullCheck, left, right) + + case LESS_THAN_OR_EQUAL => + val left = operands.head + val right = operands(1) + generateComparison("<=", nullCheck, left, right) + + case IS_NULL => + val operand = operands.head + generateIsNull(nullCheck, operand) + + case IS_NOT_NULL => + val operand = operands.head + generateIsNotNull(nullCheck, operand) + + // logic + case AND => + operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) => + requireBoolean(left) + requireBoolean(right) + generateAnd(nullCheck, left, right) + } + + case OR => + operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) => + requireBoolean(left) + requireBoolean(right) + generateOr(nullCheck, left, right) + } + + case NOT => + val operand = operands.head + requireBoolean(operand) + generateNot(nullCheck, operand) + + case call@_ => + throw new CodeGenException(s"Unsupported call: $call") + } + } + + override def visitOver(over: RexOver): GeneratedExpression = ??? 
+ + // ---------------------------------------------------------------------------------------------- + // generator helping methods + // ---------------------------------------------------------------------------------------------- + + private def generateInputAccess( + inputType: TypeInformation[Any], + inputTerm: String, + index: Int) + : GeneratedExpression = { + // if input has been used before, we can reuse the code that + // has already been generated + val inputExpr = reusableInputUnboxingExprs.get((inputTerm, index)) match { + // input access and boxing has already been generated + case Some(expr) => + expr + + // generate input access and boxing + case None => + val newExpr = inputType match { + case ct: CompositeType[_] => + val accessor = fieldAccessorFor(ct, index) + val fieldType: TypeInformation[Any] = ct.getTypeAt(index) + val fieldTypeTerm = boxedTypeTermForTypeInfo(fieldType) + + val inputCode = accessor match { + case ObjectFieldAccessor(fieldName) => + s"($fieldTypeTerm) $inputTerm.$fieldName" + + case ObjectMethodAccessor(methodName) => + s"($fieldTypeTerm) $inputTerm.$methodName()" + + case ProductAccessor(i) => + s"($fieldTypeTerm) $inputTerm.productElement($i)" + } + generateInputUnboxing(fieldType, inputCode) + case at: AtomicType[_] => + val fieldTypeTerm = boxedTypeTermForTypeInfo(at) + val inputCode = s"($fieldTypeTerm) $inputTerm" + generateInputUnboxing(at, inputCode) + case _ => + throw new CodeGenException("Unsupported type for input access.") + } + reusableInputUnboxingExprs((inputTerm, index)) = newExpr + newExpr + } + // hide the generated code as it will be executed only once + GeneratedExpression(inputExpr.resultTerm, inputExpr.nullTerm, "", inputExpr.resultType) + } + + private def generateInputUnboxing( + inputType: TypeInformation[Any], + inputCode: String) + : GeneratedExpression = { + val tmpTerm = newName("tmp") + val resultTerm = newName("result") + val nullTerm = newName("isNull") + val tmpTypeTerm = 
boxedTypeTermForTypeInfo(inputType) + val resultTypeTerm = primitiveTypeTermForTypeInfo(inputType) + val defaultValue = primitiveDefaultValue(inputType) + + val wrappedCode = if (nullCheck && !isReference(inputType)) { + s""" + |$tmpTypeTerm $tmpTerm = $inputCode; + |boolean $nullTerm = $tmpTerm == null; + |$resultTypeTerm $resultTerm; + |if ($nullTerm) { + | $resultTerm = $defaultValue; + |} + |else { + | $resultTerm = $tmpTerm; + |} + |""".stripMargin + } else if (nullCheck) { + s""" + |$resultTypeTerm $resultTerm = $inputCode; + |boolean $nullTerm = $inputCode == null; + |""".stripMargin + } else { + s""" + |$resultTypeTerm $resultTerm = $inputCode; + |""".stripMargin + } + + GeneratedExpression(resultTerm, nullTerm, wrappedCode, inputType) + } + + private def generateNonNullLiteral( + literalType: TypeInformation[_], + literalCode: String) + : GeneratedExpression = { + val resultTerm = newName("result") + val nullTerm = newName("isNull") + val resultTypeTerm = primitiveTypeTermForTypeInfo(literalType) + + val resultCode = if (nullCheck) { + s""" + |$resultTypeTerm $resultTerm = $literalCode; + |boolean $nullTerm = false; + |""".stripMargin + } else { + s""" + |$resultTypeTerm $resultTerm = $literalCode; + |""".stripMargin + } + + GeneratedExpression(resultTerm, nullTerm, resultCode, literalType) + } + + private def generateNullLiteral(resultType: TypeInformation[_]): GeneratedExpression = { + val resultTerm = newName("result") + val nullTerm = newName("isNull") + val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType) + val defaultValue = primitiveDefaultValue(resultType) + + val wrappedCode = if (nullCheck) { + s""" + |$resultTypeTerm $resultTerm = null; + |boolean $nullTerm = true; + |""".stripMargin + } else { + s""" + |$resultTypeTerm $resultTerm = $defaultValue; + |""".stripMargin + } + + GeneratedExpression(resultTerm, nullTerm, wrappedCode, resultType) + } + + // 
---------------------------------------------------------------------------------------------- + + /** + * Registers the member statement that declares and instantiates the reusable output record + * (see outRecordTerm) for the given type. For a RowTypeInfo the arity is passed to the + * constructor, all other types are instantiated with their no-argument constructor. The + * statement is added to the set of reusable member statements. + * + * @param ti type information of the output record + * @return the result of adding the statement to the set of member statements + */ + def addReusableOutRecord(ti: TypeInformation[_]) = { + val statement = ti match { + case rt: RowTypeInfo => + s""" + |${ti.getTypeClass.getCanonicalName} $outRecordTerm = + | new ${ti.getTypeClass.getCanonicalName}(${rt.getArity}); + |""".stripMargin + case _ => + s""" + |${ti.getTypeClass.getCanonicalName} $outRecordTerm = + | new ${ti.getTypeClass.getCanonicalName}(); + |""".stripMargin + } + reusableMemberStatements.add(statement) + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/a4ad9dd5/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala deleted file mode 100644 index 9592f2e..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala +++ /dev/null @@ -1,794 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.api.table.codegen - -import java.util.Date -import java.util.concurrent.atomic.AtomicInteger - -import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, PrimitiveArrayTypeInfo, TypeInformation} -import org.apache.flink.api.common.typeutils.CompositeType -import org.apache.flink.api.java.typeutils.{PojoTypeInfo, TupleTypeInfo} -import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo -import org.apache.flink.api.table.expressions._ -import org.apache.flink.api.table.typeinfo.{RenamingProxyTypeInfo, RowTypeInfo} -import org.apache.flink.api.table.{ExpressionException, TableConfig, expressions} -import org.codehaus.janino.SimpleCompiler -import org.slf4j.LoggerFactory - -import scala.collection.mutable - -/** Base class for all code generation classes. This provides the functionality for generating - * code from an [[Expression]] tree. Derived classes must embed this in a lambda function - * to form an executable code block. - * - * @param inputs List of input variable names with corresponding [[TypeInformation]]. - * @param cl The ClassLoader that is used to create the Scala reflection ToolBox - * @param config General configuration specifying runtime behaviour. - * @tparam R The type of the generated code block. In most cases a lambda function such - * as "(IN1, IN2) => OUT". 
- */ -abstract class ExpressionCodeGenerator[R]( - inputs: Seq[(String, CompositeType[_])], - cl: ClassLoader, - config: TableConfig) { - protected val log = LoggerFactory.getLogger(classOf[ExpressionCodeGenerator[_]]) - - import scala.reflect.runtime.universe._ - import scala.reflect.runtime.{universe => ru} - - if (cl == null) { - throw new IllegalArgumentException("ClassLoader must not be null.") - } - - val compiler = new SimpleCompiler() - compiler.setParentClassLoader(cl) - - protected val reusableMemberStatements = mutable.Set[String]() - - protected val reusableInitStatements = mutable.Set[String]() - - protected def reuseMemberCode(): String = { - reusableMemberStatements.mkString("", "\n", "\n") - } - - protected def reuseInitCode(): String = { - reusableInitStatements.mkString("", "\n", "\n") - } - - protected def nullCheck: Boolean = config.getNullCheck - - // This is to be implemented by subclasses, we have it like this - // so that we only call it from here with the Scala Reflection Lock. - protected def generateInternal(): R - - final def generate(): R = { - generateInternal() - } - - protected def generateExpression(expr: Expression): GeneratedExpression = { - generateExpressionInternal(expr) - } - - protected def generateExpressionInternal(expr: Expression): GeneratedExpression = { - // protected def generateExpression(expr: Expression): GeneratedExpression = { - val nullTerm = freshName("isNull") - val resultTerm = freshName("result") - - // For binary predicates that must only be evaluated when both operands are non-null. 
- // This will write to nullTerm and resultTerm, so don't use those term names - // after using this function - def generateIfNonNull(left: Expression, right: Expression, resultType: TypeInformation[_]) - (expr: (String, String) => String): String = { - val leftCode = generateExpression(left) - val rightCode = generateExpression(right) - - val leftTpe = typeTermForTypeInfo(left.typeInfo) - val rightTpe = typeTermForTypeInfo(right.typeInfo) - val resultTpe = typeTermForTypeInfo(resultType) - - if (nullCheck) { - leftCode.code + "\n" + - rightCode.code + "\n" + - s""" - |boolean $nullTerm = ${leftCode.nullTerm} || ${rightCode.nullTerm}; - |$resultTpe $resultTerm; - |if ($nullTerm) { - | $resultTerm = ${defaultPrimitive(resultType)}; - |} else { - | $resultTerm = ${expr(leftCode.resultTerm, rightCode.resultTerm)}; - |} - """.stripMargin - } else { - leftCode.code + "\n" + - rightCode.code + "\n" + - s""" - |$resultTpe $resultTerm = ${expr(leftCode.resultTerm, rightCode.resultTerm)}; - """.stripMargin - } - } - - def cleanedExpr(e: Expression): Expression = { - e match { - case expressions.Naming(namedExpr, _) => cleanedExpr(namedExpr) - case _ => e - } - } - - val cleanedExpression = cleanedExpr(expr) - val resultTpe = typeTermForTypeInfo(cleanedExpression.typeInfo) - - val code: String = cleanedExpression match { - - case expressions.Literal(null, typeInfo) => - if (nullCheck) { - s""" - |boolean $nullTerm = true; - |$resultTpe resultTerm = null; - """.stripMargin - } else { - s""" - |$resultTpe resultTerm = null; - """.stripMargin - } - - case expressions.Literal(intValue: Int, INT_TYPE_INFO) => - if (nullCheck) { - s""" - |boolean $nullTerm = false; - |$resultTpe $resultTerm = $intValue; - """.stripMargin - } else { - s""" - |$resultTpe $resultTerm = $intValue; - """.stripMargin - } - - case expressions.Literal(longValue: Long, LONG_TYPE_INFO) => - if (nullCheck) { - s""" - |boolean $nullTerm = false; - |$resultTpe $resultTerm = ${longValue}L; - """.stripMargin - } 
else { - s""" - |$resultTpe $resultTerm = ${longValue}L; - """.stripMargin - } - - - case expressions.Literal(doubleValue: Double, DOUBLE_TYPE_INFO) => - if (nullCheck) { - s""" - |boolean $nullTerm = false; - |$resultTpe $resultTerm = $doubleValue; - """.stripMargin - } else { - s""" - |$resultTpe $resultTerm = $doubleValue; - """.stripMargin - } - - case expressions.Literal(floatValue: Float, FLOAT_TYPE_INFO) => - if (nullCheck) { - s""" - |boolean $nullTerm = false; - |$resultTpe $resultTerm = ${floatValue}f; - """.stripMargin - } else { - s""" - |$resultTpe $resultTerm = ${floatValue}f; - """.stripMargin - } - - case expressions.Literal(strValue: String, STRING_TYPE_INFO) => - if (nullCheck) { - s""" - |boolean $nullTerm = false; - |$resultTpe $resultTerm = "$strValue"; - """.stripMargin - } else { - s""" - |$resultTpe $resultTerm = "$strValue"; - """.stripMargin - } - - case expressions.Literal(boolValue: Boolean, BOOLEAN_TYPE_INFO) => - if (nullCheck) { - s""" - |boolean $nullTerm = false; - |$resultTpe $resultTerm = $boolValue; - """.stripMargin - } else { - s""" - $resultTpe $resultTerm = $boolValue; - """.stripMargin - } - - case expressions.Literal(dateValue: Date, DATE_TYPE_INFO) => - val dateName = s"""date_${dateValue.getTime}""" - val dateStmt = s"""static java.util.Date $dateName - |= new java.util.Date(${dateValue.getTime});""".stripMargin - reusableMemberStatements.add(dateStmt) - - if (nullCheck) { - s""" - |boolean $nullTerm = false; - |$resultTpe $resultTerm = $dateName; - """.stripMargin - } else { - s""" - |$resultTpe $resultTerm = $dateName; - """.stripMargin - } - - case Substring(str, beginIndex, endIndex) => - val strCode = generateExpression(str) - val beginIndexCode = generateExpression(beginIndex) - val endIndexCode = generateExpression(endIndex) - if (nullCheck) { - strCode.code + - beginIndexCode.code + - endIndexCode.code + - s""" - boolean $nullTerm = - ${strCode.nullTerm} || ${beginIndexCode.nullTerm} || ${endIndexCode.nullTerm}; - 
$resultTpe $resultTerm; - if ($nullTerm) { - $resultTerm = ${defaultPrimitive(str.typeInfo)}; - } else { - if (${endIndexCode.resultTerm} == Int.MaxValue) { - $resultTerm = (${strCode.resultTerm}).substring(${beginIndexCode.resultTerm}); - } else { - $resultTerm = (${strCode.resultTerm}).substring( - ${beginIndexCode.resultTerm}, - ${endIndexCode.resultTerm}); - } - } - """.stripMargin - } else { - strCode.code + - beginIndexCode.code + - endIndexCode.code + - s""" - $resultTpe $resultTerm; - - if (${endIndexCode.resultTerm} == Integer.MAX_VALUE) { - $resultTerm = (${strCode.resultTerm}).substring(${beginIndexCode.resultTerm}); - } else { - $resultTerm = (${strCode.resultTerm}).substring( - ${beginIndexCode.resultTerm}, - ${endIndexCode.resultTerm}); - } - """ - } - - case expressions.Cast(child: Expression, STRING_TYPE_INFO) - if child.typeInfo == BasicTypeInfo.DATE_TYPE_INFO => - val childGen = generateExpression(child) - - addTimestampFormatter() - - val castCode = if (nullCheck) { - s""" - |boolean $nullTerm = ${childGen.nullTerm}; - |$resultTpe $resultTerm; - |if ($nullTerm) { - | $resultTerm = null; - |} else { - | $resultTerm = timestampFormatter.format(${childGen.resultTerm}); - |} - """.stripMargin - } else { - s""" - |$resultTpe $resultTerm = timestampFormatter.format(${childGen.resultTerm}); - """.stripMargin - } - childGen.code + castCode - - case expressions.Cast(child: Expression, STRING_TYPE_INFO) => - val childGen = generateExpression(child) - val castCode = if (nullCheck) { - s""" - |boolean $nullTerm = ${childGen.nullTerm}; - |$resultTpe $resultTerm; - |if ($nullTerm) { - | $resultTerm = null; - |} else { - | $resultTerm = "" + ${childGen.resultTerm}; - |} - """.stripMargin - } else { - s""" - |$resultTpe $resultTerm = "" + ${childGen.resultTerm}; - """.stripMargin - } - childGen.code + castCode - - case expressions.Cast(child: Expression, DATE_TYPE_INFO) - if child.typeInfo == BasicTypeInfo.LONG_TYPE_INFO => - val childGen = 
generateExpression(child) - val castCode = if (nullCheck) { - s""" - |boolean $nullTerm = ${childGen.nullTerm}; - |$resultTpe $resultTerm; - |if ($nullTerm) { - | $resultTerm = null; - |} else { - | $resultTerm = new java.util.Date(${childGen.resultTerm}); - |} - """.stripMargin - } else { - s""" - |$resultTpe $resultTerm = new java.util.Date(${childGen.resultTerm}); - """.stripMargin - } - childGen.code + castCode - - case expressions.Cast(child: Expression, DATE_TYPE_INFO) - if child.typeInfo == BasicTypeInfo.STRING_TYPE_INFO => - val childGen = generateExpression(child) - - addDateFormatter() - addTimeFormatter() - addTimestampFormatter() - - // tries to parse - // "2011-05-03 15:51:36.234" - // then "2011-05-03" - // then "15:51:36" - // then "1446473775" - val parsedName = freshName("parsed") - val parsingCode = - s""" - |java.util.Date $parsedName = null; - |try { - | $parsedName = timestampFormatter.parse(${childGen.resultTerm}); - |} catch (java.text.ParseException e1) { - | try { - | $parsedName = dateFormatter.parse(${childGen.resultTerm}); - | } catch (java.text.ParseException e2) { - | try { - | $parsedName = timeFormatter.parse(${childGen.resultTerm}); - | } catch (java.text.ParseException e3) { - | $parsedName = new java.util.Date(Long.valueOf(${childGen.resultTerm})); - | } - | } - |} - """.stripMargin - - val castCode = if (nullCheck) { - s""" - |boolean $nullTerm = ${childGen.nullTerm}; - |$resultTpe $resultTerm; - |if ($nullTerm) { - | $resultTerm = null; - |} else { - | $parsingCode - | $resultTerm = $parsedName; - |} - """.stripMargin - } else { - s""" - |$parsingCode - |$resultTpe $resultTerm = $parsedName; - """.stripMargin - } - childGen.code + castCode - - case expressions.Cast(child: Expression, DATE_TYPE_INFO) => - throw new ExpressionException("Only Long and String can be casted to Date.") - - case expressions.Cast(child: Expression, LONG_TYPE_INFO) - if child.typeInfo == BasicTypeInfo.DATE_TYPE_INFO => - val childGen = 
generateExpression(child) - val castCode = if (nullCheck) { - s""" - |boolean $nullTerm = ${childGen.nullTerm}; - |$resultTpe $resultTerm; - |if ($nullTerm) { - | $resultTerm = null; - |} else { - | $resultTerm = ${childGen.resultTerm}.getTime(); - |} - """.stripMargin - } else { - s""" - |$resultTpe $resultTerm = ${childGen.resultTerm}.getTime(); - """.stripMargin - } - childGen.code + castCode - - case expressions.Cast(child: Expression, tpe: BasicTypeInfo[_]) - if child.typeInfo == BasicTypeInfo.DATE_TYPE_INFO => - throw new ExpressionException("Date can only be casted to Long or String.") - - case expressions.Cast(child: Expression, tpe: BasicTypeInfo[_]) - if child.typeInfo == BasicTypeInfo.STRING_TYPE_INFO => - val childGen = generateExpression(child) - val fromTpe = typeTermForTypeInfoForCast(child.typeInfo) - val toTpe = typeTermForTypeInfoForCast(tpe) - - val castCode = if (nullCheck) { - s""" - |boolean $nullTerm = ${childGen.nullTerm}; - |$resultTpe $resultTerm = - | ${tpe.getTypeClass.getCanonicalName}.valueOf(${childGen.resultTerm}); - """.stripMargin - } else { - s""" - |$resultTpe $resultTerm = - | ${tpe.getTypeClass.getCanonicalName}.valueOf(${childGen.resultTerm}); - """.stripMargin - } - - childGen.code + castCode - - case expressions.Cast(child: Expression, tpe: BasicTypeInfo[_]) - if child.typeInfo.isBasicType => - val childGen = generateExpression(child) - val fromTpe = typeTermForTypeInfoForCast(child.typeInfo) - val toTpe = typeTermForTypeInfoForCast(tpe) - val castCode = if (nullCheck) { - s""" - |boolean $nullTerm = ${childGen.nullTerm}; - |$resultTpe $resultTerm; - |if ($nullTerm) { - | $resultTerm = null; - |} else { - | $resultTerm = ($toTpe)($fromTpe) ${childGen.resultTerm}; - |} - """.stripMargin - } else { - s"$resultTpe $resultTerm = ($toTpe)($fromTpe) ${childGen.resultTerm};\n" - } - childGen.code + castCode - - case ResolvedFieldReference(fieldName, fieldTpe: TypeInformation[_]) => - inputs find { i => i._2.hasField(fieldName)} 
match { - case Some((inputName, inputTpe)) => - val fieldCode = getField(newTermName(inputName), inputTpe, fieldName, fieldTpe) - if (nullCheck) { - s""" - |$resultTpe $resultTerm = $fieldCode; - |boolean $nullTerm = $resultTerm == null; - """.stripMargin - } else { - s"""$resultTpe $resultTerm = $fieldCode;""" - } - - case None => throw new ExpressionException("Could not get accessor for " + fieldName - + " in inputs " + inputs.mkString(", ") + ".") - } - - case GreaterThan(left, right) => - generateIfNonNull(left, right, BOOLEAN_TYPE_INFO) { - (leftTerm, rightTerm) => s"$leftTerm > $rightTerm" - } - - case GreaterThanOrEqual(left, right) => - generateIfNonNull(left, right, BOOLEAN_TYPE_INFO) { - (leftTerm, rightTerm) => s"$leftTerm >= $rightTerm" - } - - case LessThan(left, right) => - generateIfNonNull(left, right, BOOLEAN_TYPE_INFO) { - (leftTerm, rightTerm) => s"$leftTerm < $rightTerm" - } - - case LessThanOrEqual(left, right) => - generateIfNonNull(left, right, BOOLEAN_TYPE_INFO) { - (leftTerm, rightTerm) => s"$leftTerm <= $rightTerm" - } - - case EqualTo(left, right) => - generateIfNonNull(left, right, BOOLEAN_TYPE_INFO) { - (leftTerm, rightTerm) => s"$leftTerm.equals($rightTerm)" - } - - case NotEqualTo(left, right) => - generateIfNonNull(left, right, BOOLEAN_TYPE_INFO) { - (leftTerm, rightTerm) => s"!($leftTerm.equals($rightTerm))" - } - - case And(left, right) => - generateIfNonNull(left, right, BOOLEAN_TYPE_INFO) { - (leftTerm, rightTerm) => s"$leftTerm && $rightTerm" - } - - case Or(left, right) => - generateIfNonNull(left, right, BOOLEAN_TYPE_INFO) { - (leftTerm, rightTerm) => s"$leftTerm || $rightTerm" - } - - case Plus(left, right) => - generateIfNonNull(left, right, expr.typeInfo) { - (leftTerm, rightTerm) => s"$leftTerm + $rightTerm" - } - - case Minus(left, right) => - generateIfNonNull(left, right, expr.typeInfo) { - (leftTerm, rightTerm) => s"$leftTerm - $rightTerm" - } - - case Div(left, right) => - generateIfNonNull(left, right, expr.typeInfo) 
{ - (leftTerm, rightTerm) => s"$leftTerm / $rightTerm" - } - - case Mul(left, right) => - generateIfNonNull(left, right, expr.typeInfo) { - (leftTerm, rightTerm) => s"$leftTerm * $rightTerm" - } - - case Mod(left, right) => - generateIfNonNull(left, right, expr.typeInfo) { - (leftTerm, rightTerm) => s"$leftTerm % $rightTerm" - } - - case UnaryMinus(child) => - val childCode = generateExpression(child) - if (nullCheck) { - childCode.code + - s""" - |boolean $nullTerm = ${childCode.nullTerm}; - |$resultTpe $resultTerm; - |if ($nullTerm) { - | $resultTerm = ${defaultPrimitive(child.typeInfo)}; - |} else { - | $resultTerm = -(${childCode.resultTerm}); - |} - """.stripMargin - } else { - childCode.code + - s""" - |$resultTpe $resultTerm = -(${childCode.resultTerm}); - """.stripMargin - } - - case Not(child) => - val childCode = generateExpression(child) - if (nullCheck) { - childCode.code + - s""" - |boolean $nullTerm = ${childCode.nullTerm}; - |$resultTpe $resultTerm; - |if ($nullTerm) { - | $resultTerm = ${defaultPrimitive(child.typeInfo)}; - |} else { - | $resultTerm = !(${childCode.resultTerm}); - |} - """.stripMargin - } else { - childCode.code + - s""" - |$resultTpe $resultTerm = !(${childCode.resultTerm}); - """.stripMargin - } - - case IsNull(child) => - val childCode = generateExpression(child) - if (nullCheck) { - childCode.code + - s""" - |$resultTpe $resultTerm = ${childCode.nullTerm}; - """.stripMargin - } else { - childCode.code + - s""" - |$resultTpe $resultTerm = (${childCode.resultTerm}) == null; - """.stripMargin - } - - case IsNotNull(child) => - val childCode = generateExpression(child) - if (nullCheck) { - childCode.code + - s""" - |$resultTpe $resultTerm = !${childCode.nullTerm}; - """.stripMargin - } else { - childCode.code + - s""" - |$resultTpe $resultTerm = (${childCode.resultTerm}) != null; - """.stripMargin - } - - case Abs(child) => - val childCode = generateExpression(child) - if (nullCheck) { - childCode.code + - s""" - |boolean $nullTerm 
= ${childCode.nullTerm}; - |$resultTpe $resultTerm; - |if ($nullTerm) { - | $resultTerm = ${defaultPrimitive(child.typeInfo)}; - |} else { - | $resultTerm = Math.abs(${childCode.resultTerm}); - |} - """.stripMargin - } else { - childCode.code + - s""" - |$resultTpe $resultTerm = Math.abs(${childCode.resultTerm}); - """.stripMargin - } - - case _ => throw new ExpressionException("Could not generate code for expression " + expr) - } - - GeneratedExpression(code, resultTerm, nullTerm) - } - - case class GeneratedExpression(code: String, resultTerm: String, nullTerm: String) - - def freshName(name: String): String = { - s"$name$$${freshNameCounter.getAndIncrement}" - } - - val freshNameCounter = new AtomicInteger - - protected def getField( - inputTerm: TermName, - inputType: CompositeType[_], - fieldName: String, - fieldType: TypeInformation[_]): String = { - val accessor = fieldAccessorFor(inputType, fieldName) - val fieldTpe = typeTermForTypeInfo(fieldType) - - accessor match { - case ObjectFieldAccessor(fieldName) => - val fieldTerm = newTermName(fieldName) - s"($fieldTpe) $inputTerm.$fieldTerm" - - case ObjectMethodAccessor(methodName) => - val methodTerm = newTermName(methodName) - s"($fieldTpe) $inputTerm.$methodTerm()" - - case ProductAccessor(i) => - s"($fieldTpe) $inputTerm.productElement($i)" - - } - } - - sealed abstract class FieldAccessor - - case class ObjectFieldAccessor(fieldName: String) extends FieldAccessor - - case class ObjectMethodAccessor(methodName: String) extends FieldAccessor - - case class ProductAccessor(i: Int) extends FieldAccessor - - def fieldAccessorFor(elementType: CompositeType[_], fieldName: String): FieldAccessor = { - elementType match { - case ri: RowTypeInfo => - ProductAccessor(elementType.getFieldIndex(fieldName)) - - case cc: CaseClassTypeInfo[_] => - ObjectMethodAccessor(fieldName) - - case javaTup: TupleTypeInfo[_] => - ObjectFieldAccessor(fieldName) - - case pj: PojoTypeInfo[_] => - ObjectFieldAccessor(fieldName) - - case 
proxy: RenamingProxyTypeInfo[_] => - val underlying = proxy.getUnderlyingType - val fieldIndex = proxy.getFieldIndex(fieldName) - fieldAccessorFor(underlying, underlying.getFieldNames()(fieldIndex)) - } - } - - protected def defaultPrimitive(tpe: TypeInformation[_]): String = tpe match { - case BasicTypeInfo.INT_TYPE_INFO => "-1" - case BasicTypeInfo.LONG_TYPE_INFO => "-1" - case BasicTypeInfo.SHORT_TYPE_INFO => "-1" - case BasicTypeInfo.BYTE_TYPE_INFO => "-1" - case BasicTypeInfo.FLOAT_TYPE_INFO => "-1.0f" - case BasicTypeInfo.DOUBLE_TYPE_INFO => "-1.0d" - case BasicTypeInfo.BOOLEAN_TYPE_INFO => "false" - case BasicTypeInfo.STRING_TYPE_INFO => "\"<empty>\"" - case BasicTypeInfo.CHAR_TYPE_INFO => "'\\0'" - case _ => "null" - } - - protected def typeTermForTypeInfo(tpe: TypeInformation[_]): String = tpe match { - - // From PrimitiveArrayTypeInfo we would get class "int[]", scala reflections - // does not seem to like this, so we manually give the correct type here. - case PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO => "int[]" - case PrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO => "long[]" - case PrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO => "short[]" - case PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO => "byte[]" - case PrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO => "float[]" - case PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO => "double[]" - case PrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO => "boolean[]" - case PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO => "char[]" - - case _ => - tpe.getTypeClass.getCanonicalName - - } - - // when casting we first need to unbox Primitives, for example, - // float a = 1.0f; - // byte b = (byte) a; - // works, but for boxed types we need this: - // Float a = 1.0f; - // Byte b = (byte)(float) a; - protected def typeTermForTypeInfoForCast(tpe: TypeInformation[_]): String = tpe match { - - case BasicTypeInfo.INT_TYPE_INFO => "int" - case 
BasicTypeInfo.LONG_TYPE_INFO => "long" - case BasicTypeInfo.SHORT_TYPE_INFO => "short" - case BasicTypeInfo.BYTE_TYPE_INFO => "byte" - case BasicTypeInfo.FLOAT_TYPE_INFO => "float" - case BasicTypeInfo.DOUBLE_TYPE_INFO => "double" - case BasicTypeInfo.BOOLEAN_TYPE_INFO => "boolean" - case BasicTypeInfo.CHAR_TYPE_INFO => "char" - - // From PrimitiveArrayTypeInfo we would get class "int[]", scala reflections - // does not seem to like this, so we manually give the correct type here. - case PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO => "int[]" - case PrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO => "long[]" - case PrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO => "short[]" - case PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO => "byte[]" - case PrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO => "float[]" - case PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO => "double[]" - case PrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO => "boolean[]" - case PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO => "char[]" - - case _ => - tpe.getTypeClass.getCanonicalName - - } - - def addDateFormatter(): Unit = { - reusableMemberStatements.add(s""" - |java.text.SimpleDateFormat dateFormatter = - | new java.text.SimpleDateFormat("yyyy-MM-dd"); - |""".stripMargin) - - reusableInitStatements.add(s""" - |dateFormatter.setTimeZone(config.getTimeZone()); - |""".stripMargin) - } - - def addTimeFormatter(): Unit = { - reusableMemberStatements.add(s""" - |java.text.SimpleDateFormat timeFormatter = - | new java.text.SimpleDateFormat("HH:mm:ss"); - |""".stripMargin) - - reusableInitStatements.add(s""" - |timeFormatter.setTimeZone(config.getTimeZone()); - |""".stripMargin) - } - - def addTimestampFormatter(): Unit = { - reusableMemberStatements.add(s""" - |java.text.SimpleDateFormat timestampFormatter = - | new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS"); - |""".stripMargin) - - reusableInitStatements.add(s""" - 
|timestampFormatter.setTimeZone(config.getTimeZone()); - |""".stripMargin) - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/a4ad9dd5/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/GenerateFilter.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/GenerateFilter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/GenerateFilter.scala deleted file mode 100644 index 50b8c69..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/GenerateFilter.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.api.table.codegen - -import java.io.StringReader - -import org.apache.flink.api.common.functions.FilterFunction -import org.apache.flink.api.common.typeutils.CompositeType -import org.apache.flink.api.table.TableConfig -import org.apache.flink.api.table.codegen.Indenter._ -import org.apache.flink.api.table.expressions.Expression -import org.slf4j.LoggerFactory - -/** - * Code generator for a unary predicate, i.e. a Filter. 
- */ -class GenerateFilter[T]( - inputType: CompositeType[T], - predicate: Expression, - cl: ClassLoader, - config: TableConfig) extends ExpressionCodeGenerator[FilterFunction[T]]( - Seq(("in0", inputType)), - cl = cl, - config) { - - val LOG = LoggerFactory.getLogger(this.getClass) - - override protected def generateInternal(): FilterFunction[T] = { - val pred = generateExpression(predicate) - - val tpe = typeTermForTypeInfo(inputType) - - val generatedName = freshName("GeneratedFilter") - - // Janino does not support generics, so we need to cast by hand - val code = if (nullCheck) { - j""" - public class $generatedName - implements org.apache.flink.api.common.functions.FilterFunction<$tpe> { - - org.apache.flink.api.table.TableConfig config = null; - - public $generatedName(org.apache.flink.api.table.TableConfig config) { - this.config = config; - } - - public boolean filter(Object _in0) { - $tpe in0 = ($tpe) _in0; - ${pred.code} - if (${pred.nullTerm}) { - return false; - } else { - return ${pred.resultTerm}; - } - } - } - """ - } else { - j""" - public class $generatedName - implements org.apache.flink.api.common.functions.FilterFunction<$tpe> { - - org.apache.flink.api.table.TableConfig config = null; - - public $generatedName(org.apache.flink.api.table.TableConfig config) { - this.config = config; - } - - public boolean filter(Object _in0) { - $tpe in0 = ($tpe) _in0; - ${pred.code} - return ${pred.resultTerm}; - } - } - """ - } - - LOG.debug(s"""Generated unary predicate "$predicate":\n$code""") - compiler.cook(new StringReader(code)) - val clazz = compiler.getClassLoader().loadClass(generatedName) - val constructor = clazz.getConstructor(classOf[TableConfig]) - constructor.newInstance(config).asInstanceOf[FilterFunction[T]] - } -}