[ https://issues.apache.org/jira/browse/FLINK-7755?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16212508#comment-16212508 ]
ASF GitHub Bot commented on FLINK-7755: --------------------------------------- Github user xccui commented on a diff in the pull request: https://github.com/apache/flink/pull/4858#discussion_r145938669 --- Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetJoin.scala --- @@ -156,65 +163,394 @@ class DataSetJoin( val leftDataSet = left.asInstanceOf[DataSetRel].translateToPlan(tableEnv) val rightDataSet = right.asInstanceOf[DataSetRel].translateToPlan(tableEnv) - val (joinOperator, nullCheck) = joinType match { - case JoinRelType.INNER => (leftDataSet.join(rightDataSet), false) - case JoinRelType.LEFT => (leftDataSet.leftOuterJoin(rightDataSet), true) - case JoinRelType.RIGHT => (leftDataSet.rightOuterJoin(rightDataSet), true) - case JoinRelType.FULL => (leftDataSet.fullOuterJoin(rightDataSet), true) + joinType match { + case JoinRelType.INNER => + addInnerJoin( + leftDataSet, + rightDataSet, + leftKeys.toArray, + rightKeys.toArray, + returnType, + config) + case JoinRelType.LEFT => + addLeftOuterJoin( + leftDataSet, + rightDataSet, + leftKeys.toArray, + rightKeys.toArray, + returnType, + config) + case JoinRelType.RIGHT => + addRightOuterJoin( + leftDataSet, + rightDataSet, + leftKeys.toArray, + rightKeys.toArray, + returnType, + config) + case JoinRelType.FULL => + addFullOuterJoin( + leftDataSet, + rightDataSet, + leftKeys.toArray, + rightKeys.toArray, + returnType, + config) } + } - if (nullCheck && !config.getNullCheck) { - throw TableException("Null check in TableConfig must be enabled for outer joins.") - } + private def addInnerJoin( + left: DataSet[Row], + right: DataSet[Row], + leftKeys: Array[Int], + rightKeys: Array[Int], + resultType: TypeInformation[Row], + config: TableConfig): DataSet[Row] = { val generator = new FunctionCodeGenerator( config, - nullCheck, - leftDataSet.getType, - Some(rightDataSet.getType)) + false, + left.getType, + Some(right.getType)) val conversion = 
generator.generateConverterResultExpression( - returnType, + resultType, joinRowType.getFieldNames) - var body = "" + val condition = generator.generateExpression(joinCondition) + val body = + s""" + |${condition.code} + |if (${condition.resultTerm}) { + | ${conversion.code} + | ${generator.collectorTerm}.collect(${conversion.resultTerm}); + |} + |""".stripMargin - if (joinInfo.isEqui) { - // only equality condition - body = s""" - |${conversion.code} - |${generator.collectorTerm}.collect(${conversion.resultTerm}); - |""".stripMargin - } - else { - val nonEquiPredicates = joinInfo.getRemaining(this.cluster.getRexBuilder) - val condition = generator.generateExpression(nonEquiPredicates) - body = s""" - |${condition.code} - |if (${condition.resultTerm}) { - | ${conversion.code} - | ${generator.collectorTerm}.collect(${conversion.resultTerm}); - |} - |""".stripMargin - } val genFunction = generator.generateFunction( ruleDescription, classOf[FlatJoinFunction[Row, Row, Row]], body, - returnType) + resultType) val joinFun = new FlatJoinRunner[Row, Row, Row]( genFunction.name, genFunction.code, genFunction.returnType) - val joinOpName = - s"where: (${joinConditionToString(joinRowType, joinCondition, getExpressionString)}), " + - s"join: (${joinSelectionToString(joinRowType)})" + left.join(right) + .where(leftKeys: _*) + .equalTo(rightKeys: _*) + .`with`(joinFun) + .name(getJoinOpName) + } + + private def addLeftOuterJoin( + left: DataSet[Row], + right: DataSet[Row], + leftKeys: Array[Int], + rightKeys: Array[Int], + resultType: TypeInformation[Row], + config: TableConfig): DataSet[Row] = { + + if (!config.getNullCheck) { + throw TableException("Null check in TableConfig must be enabled for outer joins.") + } + + val joinOpName = getJoinOpName + + // replace field names by indexed names for easier key handling + val leftType = new RowTypeInfo(left.getType.asInstanceOf[RowTypeInfo].getFieldTypes: _*) + val rightType = right.getType.asInstanceOf[RowTypeInfo] + + // partition 
and sort left input + // this step ensures we can reuse the sorting for all following operations + // (groupBy->join->groupBy) + val partitionedSortedLeft: DataSet[Row] = partitionAndSort(left, leftKeys) + + // deduplicate the rows of the left input + val deduplicatedRowsLeft: DataSet[Row] = deduplicateRows(partitionedSortedLeft, leftType) + + // create JoinFunction to evaluate join predicate + val predFun = generatePredicateFunction(leftType, rightType, config) + val joinOutType = new RowTypeInfo(leftType, rightType, Types.INT) + val joinFun = new LeftOuterJoinRunner(predFun.name, predFun.code, joinOutType) + + // join left and right inputs, evaluate join predicate, and emit join pairs + val nestedLeftKeys = leftKeys.map(i => s"f0.f$i") + val joinPairs = deduplicatedRowsLeft.leftOuterJoin(right, JoinHint.REPARTITION_SORT_MERGE) + .where(nestedLeftKeys: _*) + .equalTo(rightKeys: _*) + .`with`(joinFun) + .withForwardedFieldsFirst("f0->f0") + .name(joinOpName) + + // create GroupReduceFunction to generate the join result + val convFun = generateConversionFunction(leftType, rightType, resultType, config) + val reduceFun = new LeftOuterJoinGroupReduceRunner( + convFun.name, + convFun.code, + convFun.returnType) + + // convert join pairs to result. + // This step ensures we preserve the rows of the left input. 
+ joinPairs + .groupBy("f0") + .reduceGroup(reduceFun) + .name(joinOpName) + .returns(resultType) + } + + private def addRightOuterJoin( + left: DataSet[Row], + right: DataSet[Row], + leftKeys: Array[Int], + rightKeys: Array[Int], + resultType: TypeInformation[Row], + config: TableConfig): DataSet[Row] = { + + if (!config.getNullCheck) { + throw TableException("Null check in TableConfig must be enabled for outer joins.") + } + + val joinOpName = getJoinOpName - joinOperator - .where(leftKeys.toArray: _*) - .equalTo(rightKeys.toArray: _*) + // replace field names by indexed names for easier key handling + val leftType = left.getType.asInstanceOf[RowTypeInfo] + val rightType = new RowTypeInfo(right.getType.asInstanceOf[RowTypeInfo].getFieldTypes: _*) + + // partition and sort right input + // this step ensures we can reuse the sorting for all following operations + // (groupBy->join->groupBy) + val partitionedSortedRight: DataSet[Row] = partitionAndSort(right, rightKeys) + + // deduplicate the rows of the right input + val deduplicatedRowsRight: DataSet[Row] = deduplicateRows(partitionedSortedRight, rightType) + + // create JoinFunction to evaluate join predicate + val predFun = generatePredicateFunction(leftType, rightType, config) + val joinOutType = new RowTypeInfo(leftType, rightType, Types.INT) + val joinFun = new RightOuterJoinRunner(predFun.name, predFun.code, joinOutType) + + // join left and right inputs, evaluate join predicate, and emit join pairs + val nestedRightKeys = rightKeys.map(i => s"f0.f$i") + val joinPairs = left.rightOuterJoin(deduplicatedRowsRight, JoinHint.REPARTITION_SORT_MERGE) + .where(leftKeys: _*) + .equalTo(nestedRightKeys: _*) .`with`(joinFun) + .withForwardedFieldsSecond("f0->f1") + .name(joinOpName) + + // create GroupReduceFunction to generate the join result + val convFun = generateConversionFunction(leftType, rightType, resultType, config) + val reduceFun = new RightOuterJoinGroupReduceRunner( + convFun.name, + convFun.code, + 
convFun.returnType) + + // convert join pairs to result + // This step ensures we preserve the rows of the right input. + joinPairs + .groupBy("f1") + .reduceGroup(reduceFun) .name(joinOpName) + .returns(resultType) } + + private def addFullOuterJoin( + left: DataSet[Row], + right: DataSet[Row], + leftKeys: Array[Int], + rightKeys: Array[Int], + resultType: TypeInformation[Row], + config: TableConfig): DataSet[Row] = { + + if (!config.getNullCheck) { + throw TableException("Null check in TableConfig must be enabled for outer joins.") + } + + val joinOpName = getJoinOpName + + // replace field names by indexed names for easier key handling + val leftType = new RowTypeInfo(left.getType.asInstanceOf[RowTypeInfo].getFieldTypes: _*) + val rightType = new RowTypeInfo(right.getType.asInstanceOf[RowTypeInfo].getFieldTypes: _*) + + // partition and sort left and right input + // this step ensures we can reuse the sorting for all following operations + // (groupBy->join->groupBy), except the second grouping to preserve right rows. 
+ val partitionedSortedLeft: DataSet[Row] = partitionAndSort(left, leftKeys) + val partitionedSortedRight: DataSet[Row] = partitionAndSort(right, rightKeys) + + // deduplicate the rows of the left and right input + val deduplicatedRowsLeft: DataSet[Row] = deduplicateRows(partitionedSortedLeft, leftType) + val deduplicatedRowsRight: DataSet[Row] = deduplicateRows(partitionedSortedRight, rightType) + + // create JoinFunction to evaluate join predicate + val predFun = generatePredicateFunction(leftType, rightType, config) + val joinOutType = new RowTypeInfo(leftType, rightType, Types.INT, Types.INT) + val joinFun = new FullOuterJoinRunner(predFun.name, predFun.code, joinOutType) + + // join left and right inputs, evaluate join predicate, and emit join pairs + val nestedLeftKeys = leftKeys.map(i => s"f0.f$i") + val nestedRightKeys = rightKeys.map(i => s"f0.f$i") + val joinPairs = deduplicatedRowsLeft + .fullOuterJoin(deduplicatedRowsRight, JoinHint.REPARTITION_SORT_MERGE) + .where(nestedLeftKeys: _*) + .equalTo(nestedRightKeys: _*) + .`with`(joinFun) + .withForwardedFieldsFirst("f0->f0") + .withForwardedFieldsSecond("f0->f1") + .name(joinOpName) + + // create GroupReduceFunctions to generate the join result + val convFun = generateConversionFunction(leftType, rightType, resultType, config) + val leftReduceFun = new LeftFullOuterJoinGroupReduceRunner( + convFun.name, + convFun.code, + convFun.returnType) + val rightReduceFun = new RightFullOuterJoinGroupReduceRunner( + convFun.name, + convFun.code, + convFun.returnType) + + // compute joined (left + right) and left preserved (left + null) + val joinedAndLeftPreserved = joinPairs + // filter for pairs with left row + .filter(new FilterFunction[Row](){ + override def filter(row: Row): Boolean = row.getField(0) != null}) + .groupBy("f0") + .reduceGroup(leftReduceFun) + .name(joinOpName) + .returns(resultType) + + // compute right preserved (null + right) + val rightPreserved = joinPairs + // filter for pairs with right row 
+ .filter(new FilterFunction[Row](){ + override def filter(row: Row): Boolean = row.getField(1) != null}) + .groupBy("f1") + .reduceGroup(rightReduceFun) + .name(joinOpName) + .returns(resultType) + + // union joined (left + right), left preserved (left + null), and right preserved (null + right) + joinedAndLeftPreserved.union(rightPreserved) + } + + private def getJoinOpName: String = { + s"where: (${joinConditionToString(joinRowType, joinCondition, getExpressionString)}), " + + s"join: (${joinSelectionToString(joinRowType)})" + } + + /** Returns an array of indicies with some indicies being a prefix. */ + private def getFullIndiciesWithPrefix(keys: Array[Int], numFields: Int): Array[Int] = { + // get indicies of all fields which are not keys + val nonKeys = (0 until numFields).filter(i => !keys.contains(i)) + // return all field indicies prefixed by keys + keys ++ nonKeys + } + + /** + * Partitions the data set on the join keys and sort it on all field with the join keys being a + * prefix. + */ + private def partitionAndSort( + dataSet: DataSet[Row], + partitionKeys: Array[Int]): DataSet[Row] = { + + // construct full sort keys with partitionKeys being a prefix + val sortKeys = getFullIndiciesWithPrefix(partitionKeys, dataSet.getType.getArity) + // partition + val partitioned: DataSet[Row] = dataSet.partitionByHash(partitionKeys: _*) + // sort on all fields + sortKeys.foldLeft(partitioned: DataSet[Row]) { (d, i) => + d.sortPartition(i, Order.ASCENDING).asInstanceOf[DataSet[Row]] + } + } + + /** + * Deduplicates the rows of a data set and emits a row for each unique row with with the first + * field being the unique row and the second field being the number of duplicates of the row. + */ + private def deduplicateRows( --- End diff -- The function name is a little bit misleading. How about `foldIdenticalRows`? 
> Null values are not correctly handled by batch inner and outer joins > -------------------------------------------------------------------- > > Key: FLINK-7755 > URL: https://issues.apache.org/jira/browse/FLINK-7755 > Project: Flink > Issue Type: Bug > Components: Table API & SQL > Affects Versions: 1.4.0, 1.3.2 > Reporter: Fabian Hueske > Assignee: Fabian Hueske > Priority: Blocker > Fix For: 1.4.0, 1.3.3 > > > Join predicates of batch joins are not correctly evaluated according to > three-valued logic. > This affects inner as well as outer joins. > The problem is that some equality predicates are only evaluated by the > internal join algorithms of Flink which are based on {{TypeComparator}}. The > field {{TypeComparator}}s for {{Row}} are implemented such that {{null == > null}} results in {{TRUE}} to ensure correct ordering and grouping. However, > three-valued logic requires that {{null == null}} results in {{UNKNOWN}} (or > null). The code generator implements this logic correctly, but for equality > predicates, no code is generated. > For outer joins, the problem is a bit trickier because these do not support > code-generated predicates yet (see FLINK-5520). FLINK-5498 proposes a > solution for this issue. > We also need to extend several of the existing tests and add null values to > ensure that the join logic is correctly implemented. -- This message was sent by Atlassian JIRA (v6.4.14#64029)