Repository: flink
Updated Branches:
  refs/heads/master b181662be -> dbe707324


[FLINK-5159] [table] Add single-row broadcast nested-loop join.

- Support for arbitrary joins (cross, theta, equi) if one input has exactly one row.

This closes #2811.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/dbe70732
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/dbe70732
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/dbe70732

Branch: refs/heads/master
Commit: dbe70732434ff1396e1829080cd14f26b691489a
Parents: b181662
Author: Alexander Shoshin <alexander_shos...@epam.com>
Authored: Wed Nov 9 11:27:57 2016 +0300
Committer: Fabian Hueske <fhue...@apache.org>
Committed: Thu Dec 1 20:42:10 2016 +0100

----------------------------------------------------------------------
 .../nodes/dataset/DataSetSingleRowJoin.scala    | 193 +++++++++++++++++++
 .../api/table/plan/rules/FlinkRuleSets.scala    |   1 +
 .../dataSet/DataSetSingleRowJoinRule.scala      |  82 ++++++++
 .../api/table/runtime/MapJoinLeftRunner.scala   |  33 ++++
 .../api/table/runtime/MapJoinRightRunner.scala  |  33 ++++
 .../api/table/runtime/MapSideJoinRunner.scala   |  51 +++++
 .../flink/api/scala/batch/sql/JoinITCase.scala  |  48 +++++
 .../api/scala/batch/sql/SingleRowJoinTest.scala | 146 ++++++++++++++
 8 files changed, 587 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/dbe70732/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetSingleRowJoin.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetSingleRowJoin.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetSingleRowJoin.scala
new file mode 100644
index 0000000..0306c00
--- /dev/null
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetSingleRowJoin.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.nodes.dataset
+
+import org.apache.calcite.plan._
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.metadata.RelMetadataQuery
+import org.apache.calcite.rel.{BiRel, RelNode, RelWriter}
+import org.apache.calcite.rex.RexNode
+import org.apache.flink.api.common.functions.{FlatJoinFunction, 
FlatMapFunction}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.table.codegen.CodeGenerator
+import org.apache.flink.api.table.runtime.{MapJoinLeftRunner, 
MapJoinRightRunner}
+import org.apache.flink.api.table.typeutils.TypeConverter.determineReturnType
+import org.apache.flink.api.table.{BatchTableEnvironment, TableConfig}
+
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+
+/**
+  * Flink RelNode that executes a join where one of the inputs is a single row.
+  *
+  * The multi-row input is processed by a flat-map function, and the single-row input is
+  * attached to that flat-map as a broadcast set. A code-generated join function evaluates
+  * the join condition for each record and emits the converted result if it holds.
+  *
+  * @param cluster         cluster handed to the [[BiRel]] constructor
+  * @param traitSet        trait set handed to the [[BiRel]] constructor
+  * @param leftNode        left join input
+  * @param rightNode       right join input
+  * @param leftIsSingle    true if the left input is the single-row input, false if it is
+  *                        the right input
+  * @param rowRelDataType  row type reported by this node
+  * @param joinCondition   condition a pair of rows has to satisfy to be emitted
+  * @param joinRowType     row type whose field names are used for the result conversion and
+  *                        for rendering the join condition
+  * @param ruleDescription name given to the generated join function
+  */
+class DataSetSingleRowJoin(
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    leftNode: RelNode,
+    rightNode: RelNode,
+    leftIsSingle: Boolean,
+    rowRelDataType: RelDataType,
+    joinCondition: RexNode,
+    joinRowType: RelDataType,
+    ruleDescription: String)
+  extends BiRel(cluster, traitSet, leftNode, rightNode)
+  with DataSetRel {
+
+  override def deriveRowType(): RelDataType = rowRelDataType
+
+  /** Creates a copy of this node with the given traits and the two given inputs. */
+  override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+    new DataSetSingleRowJoin(
+      cluster,
+      traitSet,
+      inputs.get(0),
+      inputs.get(1),
+      leftIsSingle,
+      getRowType,
+      joinCondition,
+      joinRowType,
+      ruleDescription)
+  }
+
+  override def toString: String = {
+    s"$joinTypeToString(where: ($joinConditionToString), join: ($joinSelectionToString))"
+  }
+
+  /** Adds the join condition, the output fields and the join type to the plan explanation. */
+  override def explainTerms(pw: RelWriter): RelWriter = {
+    super.explainTerms(pw)
+      .item("where", joinConditionToString)
+      .item("join", joinSelectionToString)
+      .item("joinType", joinTypeToString)
+  }
+
+  /**
+    * Derives the cost from the multi-row input only: its row count is used for the first two
+    * cost components and row count times estimated row size for the third.
+    */
+  override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+    // The input that is NOT the single row is the one that is iterated over.
+    val child = if (leftIsSingle) {
+      this.getRight
+    } else {
+      this.getLeft
+    }
+    val rowCnt = metadata.getRowCount(child)
+    val rowSize = this.estimateRowSize(child.getRowType)
+    planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * rowSize)
+  }
+
+  /**
+    * Translates this node into a flat-map over the multi-row input with the single-row
+    * input registered as a broadcast set.
+    *
+    * @param tableEnv     environment used to translate the inputs and to obtain the config
+    * @param expectedType optional type the result is requested in
+    * @return the data set produced by the flat-map operator
+    */
+  override def translateToPlan(
+      tableEnv: BatchTableEnvironment,
+      expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {
+
+    val leftDataSet = left.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
+    val rightDataSet = right.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
+    val broadcastSetName = "joinSet"
+    val mapSideJoin = generateMapFunction(
+      tableEnv.getConfig,
+      leftDataSet.getType,
+      rightDataSet.getType,
+      leftIsSingle,
+      joinCondition,
+      broadcastSetName,
+      expectedType)
+
+    val (multiRowDataSet, singleRowDataSet) =
+      if (leftIsSingle) {
+        (rightDataSet, leftDataSet)
+      } else {
+        (leftDataSet, rightDataSet)
+      }
+
+    multiRowDataSet
+      .flatMap(mapSideJoin)
+      .withBroadcastSet(singleRowDataSet, broadcastSetName)
+      .name(getMapOperatorName)
+      .asInstanceOf[DataSet[Any]]
+  }
+
+  /**
+    * Generates the join function code and wraps it into the runner that matches the side
+    * of the single-row input.
+    *
+    * @param config                table configuration (null check, efficient type usage)
+    * @param inputType1            type of the left join input
+    * @param inputType2            type of the right join input
+    * @param firstIsSingle         true if the left input is the single-row input
+    * @param joinCondition         condition evaluated by the generated function
+    * @param broadcastInputSetName name of the broadcast set holding the single row
+    * @param expectedType          optional type the result is requested in
+    * @return flat-map function to be applied to the multi-row input
+    */
+  private def generateMapFunction(
+      config: TableConfig,
+      inputType1: TypeInformation[Any],
+      inputType2: TypeInformation[Any],
+      firstIsSingle: Boolean,
+      joinCondition: RexNode,
+      broadcastInputSetName: String,
+      expectedType: Option[TypeInformation[Any]]): FlatMapFunction[Any, Any] = {
+
+    val codeGenerator = new CodeGenerator(
+      config,
+      false,
+      inputType1,
+      Some(inputType2))
+
+    val returnType = determineReturnType(
+      getRowType,
+      expectedType,
+      config.getNullCheck,
+      config.getEfficientTypeUsage)
+
+    val conversion = codeGenerator.generateConverterResultExpression(
+      returnType,
+      joinRowType.getFieldNames)
+
+    val condition = codeGenerator.generateExpression(joinCondition)
+
+    // Emit the converted row only if the join condition evaluates to true.
+    val joinMethodBody = s"""
+                  |${condition.code}
+                  |if (${condition.resultTerm}) {
+                  |  ${conversion.code}
+                  |  ${codeGenerator.collectorTerm}.collect(${conversion.resultTerm});
+                  |}
+                  |""".stripMargin
+
+    val genFunction = codeGenerator.generateFunction(
+      ruleDescription,
+      classOf[FlatJoinFunction[Any, Any, Any]],
+      joinMethodBody,
+      returnType)
+
+    // The runner decides on which side of the generated join the broadcast row is passed.
+    if (firstIsSingle) {
+      new MapJoinRightRunner[Any, Any, Any](
+        genFunction.name,
+        genFunction.code,
+        genFunction.returnType,
+        broadcastInputSetName)
+    } else {
+      new MapJoinLeftRunner[Any, Any, Any](
+        genFunction.name,
+        genFunction.code,
+        genFunction.returnType,
+        broadcastInputSetName)
+    }
+  }
+
+  /** Name assigned to the flat-map operator. */
+  private def getMapOperatorName: String = {
+    s"where: ($joinConditionToString), join: ($joinSelectionToString)"
+  }
+
+  /** Comma-separated field names of this node's row type. */
+  private def joinSelectionToString: String = {
+    getRowType.getFieldNames.asScala.toList.mkString(", ")
+  }
+
+  /** Join condition rendered with the field names of the join row type. */
+  private def joinConditionToString: String = {
+    val inFields = joinRowType.getFieldNames.asScala.toList
+    getExpressionString(joinCondition, inFields, None)
+  }
+
+  private def joinTypeToString: String = {
+    "NestedLoopJoin"
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/dbe70732/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
index 26c025e..9e20df4 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
@@ -101,6 +101,7 @@ object FlinkRuleSets {
     DataSetAggregateWithNullValuesRule.INSTANCE,
     DataSetCalcRule.INSTANCE,
     DataSetJoinRule.INSTANCE,
+    DataSetSingleRowJoinRule.INSTANCE,
     DataSetScanRule.INSTANCE,
     DataSetUnionRule.INSTANCE,
     DataSetIntersectRule.INSTANCE,

http://git-wip-us.apache.org/repos/asf/flink/blob/dbe70732/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala
new file mode 100644
index 0000000..8109fcf
--- /dev/null
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.rules.dataSet
+
+import org.apache.calcite.plan.volcano.RelSubset
+import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.core.JoinRelType
+import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalJoin}
+import org.apache.flink.api.table.plan.nodes.dataset.{DataSetConvention, 
DataSetSingleRowJoin}
+
+/**
+  * Converter rule that turns an inner [[LogicalJoin]] into a [[DataSetSingleRowJoin]] if at
+  * least one of its inputs is a global aggregation, i.e. a [[LogicalAggregate]] with an
+  * empty group set.
+  */
+class DataSetSingleRowJoinRule
+  extends ConverterRule(
+      classOf[LogicalJoin],
+      Convention.NONE,
+      DataSetConvention.INSTANCE,
+      "DataSetSingleRowJoinRule") {
+
+  /** Matches inner joins whose right or left input is a global aggregation. */
+  override def matches(call: RelOptRuleCall): Boolean = {
+    val join = call.rel(0).asInstanceOf[LogicalJoin]
+
+    isInnerJoin(join) &&
+      (isGlobalAggregation(join.getRight) || isGlobalAggregation(join.getLeft))
+  }
+
+  private def isInnerJoin(join: LogicalJoin): Boolean = {
+    join.getJoinType == JoinRelType.INNER
+  }
+
+  /**
+    * Checks whether the node is an aggregation without grouping keys. A [[RelSubset]] is
+    * resolved to its original node first; any other kind of node yields false instead of
+    * failing on a cast.
+    */
+  private def isGlobalAggregation(node: RelNode): Boolean = node match {
+    case subset: RelSubset => isGlobalAggregation(subset.getOriginal)
+    case agg: LogicalAggregate => agg.getGroupSet.isEmpty
+    case _ => false
+  }
+
+  /**
+    * Creates the [[DataSetSingleRowJoin]] for the given join, with both inputs converted to
+    * the DataSet convention.
+    */
+  override def convert(rel: RelNode): RelNode = {
+    val join = rel.asInstanceOf[LogicalJoin]
+    val traitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
+    val dataSetLeftNode = RelOptRule.convert(join.getLeft, DataSetConvention.INSTANCE)
+    val dataSetRightNode = RelOptRule.convert(join.getRight, DataSetConvention.INSTANCE)
+    // If the left input is not the global aggregation, the right one is (see matches).
+    val leftIsSingle = isGlobalAggregation(join.getLeft)
+
+    new DataSetSingleRowJoin(
+      rel.getCluster,
+      traitSet,
+      dataSetLeftNode,
+      dataSetRightNode,
+      leftIsSingle,
+      rel.getRowType,
+      join.getCondition,
+      join.getRowType,
+      description)
+  }
+}
+
+/** Companion object exposing a ready-to-use instance of the rule. */
+object DataSetSingleRowJoinRule {
+  val INSTANCE: RelOptRule = new DataSetSingleRowJoinRule()
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/dbe70732/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/MapJoinLeftRunner.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/MapJoinLeftRunner.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/MapJoinLeftRunner.scala
new file mode 100644
index 0000000..76650c2
--- /dev/null
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/MapJoinLeftRunner.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.runtime
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.util.Collector
+
+/**
+  * Map-side join runner whose flat-mapped records are the first (left) join argument; the
+  * broadcast single row is passed as the second argument.
+  */
+class MapJoinLeftRunner[IN1, IN2, OUT](
+    name: String,
+    code: String,
+    @transient returnType: TypeInformation[OUT],
+    broadcastSetName: String)
+  extends MapSideJoinRunner[IN1, IN2, IN2, IN1, OUT](
+    name,
+    code,
+    returnType,
+    broadcastSetName) {
+
+  override def flatMap(multiInput: IN1, out: Collector[OUT]): Unit = {
+    function.join(multiInput, singleInput, out)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/dbe70732/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/MapJoinRightRunner.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/MapJoinRightRunner.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/MapJoinRightRunner.scala
new file mode 100644
index 0000000..52b01cf
--- /dev/null
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/MapJoinRightRunner.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.runtime
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.util.Collector
+
+/**
+  * Map-side join runner whose flat-mapped records are the second (right) join argument; the
+  * broadcast single row is passed as the first argument.
+  */
+class MapJoinRightRunner[IN1, IN2, OUT](
+    name: String,
+    code: String,
+    @transient returnType: TypeInformation[OUT],
+    broadcastSetName: String)
+  extends MapSideJoinRunner[IN1, IN2, IN1, IN2, OUT](
+    name,
+    code,
+    returnType,
+    broadcastSetName) {
+
+  override def flatMap(multiInput: IN2, out: Collector[OUT]): Unit = {
+    function.join(singleInput, multiInput, out)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/dbe70732/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/MapSideJoinRunner.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/MapSideJoinRunner.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/MapSideJoinRunner.scala
new file mode 100644
index 0000000..b355d49
--- /dev/null
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/MapSideJoinRunner.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.runtime
+
+import org.apache.flink.api.common.functions.{FlatJoinFunction, 
RichFlatMapFunction}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable
+import org.apache.flink.api.table.codegen.Compiler
+import org.apache.flink.configuration.Configuration
+import org.slf4j.LoggerFactory
+
+/**
+  * Base class for flat-map functions that join every incoming record with a single row
+  * taken from a broadcast set, using a join function compiled from generated code.
+  *
+  * @param name             name of the generated join function class
+  * @param code             source code of the generated join function
+  * @param returnType       type information of the produced records
+  * @param broadcastSetName name of the broadcast set that holds the single row
+  * @tparam IN1       type of the first argument of the generated join function
+  * @tparam IN2       type of the second argument of the generated join function
+  * @tparam SINGLE_IN type of the broadcast single row
+  * @tparam MULTI_IN  type of the flat-mapped records
+  * @tparam OUT       type of the produced records
+  */
+abstract class MapSideJoinRunner[IN1, IN2, SINGLE_IN, MULTI_IN, OUT](
+    name: String,
+    code: String,
+    @transient returnType: TypeInformation[OUT],
+    broadcastSetName: String)
+  extends RichFlatMapFunction[MULTI_IN, OUT]
+    with ResultTypeQueryable[OUT]
+    with Compiler[FlatJoinFunction[IN1, IN2, OUT]] {
+
+  val LOG = LoggerFactory.getLogger(this.getClass)
+
+  protected var function: FlatJoinFunction[IN1, IN2, OUT] = _
+  protected var singleInput: SINGLE_IN = _
+
+  /**
+    * Compiles and instantiates the generated join function and reads the single row from
+    * the broadcast set.
+    *
+    * @throws IllegalStateException if the broadcast set contains no row
+    */
+  override def open(parameters: Configuration): Unit = {
+    LOG.debug(s"Compiling FlatJoinFunction: $name \n\n Code:\n$code")
+    val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code)
+    LOG.debug("Instantiating FlatJoinFunction.")
+    function = clazz.newInstance()
+
+    val broadcastSet = getRuntimeContext.getBroadcastVariable[SINGLE_IN](broadcastSetName)
+    // Fail with a meaningful message instead of a bare IndexOutOfBoundsException.
+    if (broadcastSet.isEmpty) {
+      throw new IllegalStateException(
+        s"Broadcast set '$broadcastSetName' is empty, but a single row is expected.")
+    }
+    singleInput = broadcastSet.get(0)
+  }
+
+  override def getProducedType: TypeInformation[OUT] = returnType
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/dbe70732/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala
 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala
index 6a02fd4..68f63c3 100644
--- 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala
+++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala
@@ -314,4 +314,52 @@ class JoinITCase(
     val results = tEnv.sql(sqlQuery).toDataSet[Row].collect()
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
+
+  /** A cross join of two regular tables is expected to fail with a TableException. */
+  @Test(expected = classOf[TableException])
+  def testCrossJoin(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val leftTable =
+      CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a1, 'a2, 'a3)
+    val rightTable =
+      CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('b1, 'b2, 'b3)
+    tEnv.registerTable("A", leftTable)
+    tEnv.registerTable("B", rightTable)
+
+    val crossJoinQuery = "SELECT a1, b1 FROM A CROSS JOIN B"
+    tEnv.sql(crossJoinQuery).count
+  }
+
+  /** Cross join with a global count on the left side: the count is prepended to each row. */
+  @Test
+  def testCrossJoinWithLeftSingleRowInput(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val tableA =
+      CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a1, 'a2, 'a3)
+    tEnv.registerTable("A", tableA)
+
+    val sqlQuery = "SELECT * FROM (SELECT count(*) FROM A) CROSS JOIN A"
+    val expected = Seq(
+      "3,1,1,Hi",
+      "3,2,2,Hello",
+      "3,3,2,Hello world").mkString("\n")
+    val result = tEnv.sql(sqlQuery).collect()
+    TestBaseUtils.compareResultAsText(result.asJava, expected)
+  }
+
+  /** Cross join with a global count on the right side: the count is appended to each row. */
+  @Test
+  def testCrossJoinWithRightSingleRowInput(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val tableA =
+      CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a1, 'a2, 'a3)
+    tEnv.registerTable("A", tableA)
+
+    val sqlQuery = "SELECT * FROM A CROSS JOIN (SELECT count(*) FROM A)"
+    val expected = Seq(
+      "1,1,Hi,3",
+      "2,2,Hello,3",
+      "3,2,Hello world,3").mkString("\n")
+    val result = tEnv.sql(sqlQuery).collect()
+    TestBaseUtils.compareResultAsText(result.asJava, expected)
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/dbe70732/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/SingleRowJoinTest.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/SingleRowJoinTest.scala
 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/SingleRowJoinTest.scala
new file mode 100644
index 0000000..f56b9ae
--- /dev/null
+++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/SingleRowJoinTest.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.batch.sql
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.utils.TableTestBase
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.junit.Test
+
+/**
+  * Plan tests for joins where one input is a global aggregation. Each test builds the
+  * expected plan bottom-up and verifies it against the plan of the SQL query.
+  */
+class SingleRowJoinTest extends TableTestBase {
+
+  /** The equality predicate is expected in the calc on top of the single-row join. */
+  @Test
+  def testSingleRowEquiJoin(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, String)]("A", 'a1, 'a2)
+
+    val query =
+      "SELECT a1, a2 " +
+      "FROM A, (SELECT count(a1) AS cnt FROM A) " +
+      "WHERE a1 = cnt"
+
+    val valuesNode = unaryNode(
+      "DataSetValues",
+      batchTableNode(0),
+      tuples(List(null, null)),
+      term("values", "a1", "a2"))
+    val unionNode = unaryNode(
+      "DataSetUnion",
+      valuesNode,
+      term("union","a1","a2"))
+    val aggregateNode = unaryNode(
+      "DataSetAggregate",
+      unionNode,
+      term("select", "COUNT(a1) AS cnt"))
+    val joinNode = binaryNode(
+      "DataSetSingleRowJoin",
+      batchTableNode(0),
+      aggregateNode,
+      term("where", "true"),
+      term("join", "a1", "a2", "cnt"),
+      term("joinType", "NestedLoopJoin"))
+    val expected = unaryNode(
+      "DataSetCalc",
+      joinNode,
+      term("select", "a1", "a2"),
+      term("where", "=(CAST(a1), cnt)"))
+
+    util.verifySql(query, expected)
+  }
+
+  /** The less-than predicate is expected in the calc on top of the single-row join. */
+  @Test
+  def testSingleRowNotEquiJoin(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, String)]("A", 'a1, 'a2)
+
+    val query =
+      "SELECT a1, a2 " +
+      "FROM A, (SELECT count(a1) AS cnt FROM A) " +
+      "WHERE a1 < cnt"
+
+    val valuesNode = unaryNode(
+      "DataSetValues",
+      batchTableNode(0),
+      tuples(List(null, null)),
+      term("values", "a1", "a2"))
+    val unionNode = unaryNode(
+      "DataSetUnion",
+      valuesNode,
+      term("union","a1","a2"))
+    val aggregateNode = unaryNode(
+      "DataSetAggregate",
+      unionNode,
+      term("select", "COUNT(a1) AS cnt"))
+    val joinNode = binaryNode(
+      "DataSetSingleRowJoin",
+      batchTableNode(0),
+      aggregateNode,
+      term("where", "true"),
+      term("join", "a1", "a2", "cnt"),
+      term("joinType", "NestedLoopJoin"))
+    val expected = unaryNode(
+      "DataSetCalc",
+      joinNode,
+      term("select", "a1", "a2"),
+      term("where", "<(a1, cnt)"))
+
+    util.verifySql(query, expected)
+  }
+
+  /** The conjunctive predicate is expected inside the single-row join itself. */
+  @Test
+  def testSingleRowJoinWithComplexPredicate(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, Long)]("A", 'a1, 'a2)
+    util.addTable[(Int, Long)]("B", 'b1, 'b2)
+
+    val query =
+      "SELECT a1, a2, b1, b2 " +
+        "FROM A, (SELECT min(b1) AS b1, max(b2) AS b2 FROM B) " +
+        "WHERE a1 < b1 AND a2 = b2"
+
+    val valuesNode = unaryNode(
+      "DataSetValues",
+      batchTableNode(1),
+      tuples(List(null, null)),
+      term("values", "b1", "b2"))
+    val unionNode = unaryNode(
+      "DataSetUnion",
+      valuesNode,
+      term("union","b1","b2"))
+    val aggregateNode = unaryNode(
+      "DataSetAggregate",
+      unionNode,
+      term("select", "MIN(b1) AS b1", "MAX(b2) AS b2"))
+    val expected = binaryNode(
+      "DataSetSingleRowJoin",
+      batchTableNode(0),
+      aggregateNode,
+      term("where", "AND(<(a1, b1)", "=(a2, b2))"),
+      term("join", "a1", "a2", "b1", "b2"),
+      term("joinType", "NestedLoopJoin"))
+
+    util.verifySql(query, expected)
+  }
+}

Reply via email to