http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 45a3213..971770a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -42,7 +42,7 @@ private[sql] trait RunnableCommand extends LogicalPlan with logical.Command {
  * A physical operator that executes the run method of a `RunnableCommand` and
  * saves the result to prevent multiple executions.
  */
-private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan {
+private[sql] case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan {
   /**
    * A concrete command should override this lazy field to wrap up any side effects caused by the
    * command or any other computation that should be evaluated exactly once. The value of this field
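The scaladoc above states the contract of the renamed ExecutedCommandExec node: the wrapped RunnableCommand's side effects are evaluated exactly once, and the resulting rows are cached for any later execution of the physical plan. The sketch below is a minimal stand-in for that pattern in plain Scala; the names (Command, ExecutedCommandSketch, sideEffectResult, executeCollect) are hypothetical and the real Spark classes are private[sql], so this only illustrates the lazy-field caching described in the comment, not the actual API.

    // Stand-in for the "evaluated exactly once" behaviour described in the scaladoc.
    trait Command { def run(): Seq[String] }

    case class ExecutedCommandSketch(cmd: Command) {
      // Evaluated at most once, mirroring the result-caching lazy field of the physical node.
      private lazy val sideEffectResult: Seq[String] = cmd.run()
      def executeCollect(): Seq[String] = sideEffectResult
    }

    val counting = new Command {
      var runs = 0
      def run(): Seq[String] = { runs += 1; Seq("run #" + runs) }
    }
    val exec = ExecutedCommandSketch(counting)
    exec.executeCollect()  // first call runs the command's side effect
    exec.executeCollect()  // later calls reuse the cached result; run() is not invoked again

The planner strategies in this same patch construct the real node in exactly this wrapping style, e.g. ExecutedCommandExec(InsertIntoDataSource(l, query, overwrite)) :: Nil in DataSourceStrategy below.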
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index ac3c52e..9bebd74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -32,8 +32,8 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.DataSourceScan.{INPUT_PATHS, PUSHED_FILTERS} -import org.apache.spark.sql.execution.command.ExecutedCommand +import org.apache.spark.sql.execution.DataSourceScanExec.{INPUT_PATHS, PUSHED_FILTERS} +import org.apache.spark.sql.execution.command.ExecutedCommandExec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -105,12 +105,12 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil case l @ LogicalRelation(baseRelation: TableScan, _, _) => - execution.DataSourceScan.create( + execution.DataSourceScanExec.create( l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _), part, query, overwrite, false) if part.isEmpty => - ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil + ExecutedCommandExec(InsertIntoDataSource(l, query, overwrite)) :: Nil case _ => Nil } @@ -214,22 +214,22 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Don't request columns that are only referenced by pushed filters. .filterNot(handledSet.contains) - val scan = execution.DataSourceScan.create( + val scan = execution.DataSourceScanExec.create( projects.map(_.toAttribute), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, metadata) - filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) + filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan) } else { // Don't request columns that are only referenced by pushed filters. 
val requestedColumns = (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq - val scan = execution.DataSourceScan.create( + val scan = execution.DataSourceScanExec.create( requestedColumns, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, metadata) - execution.Project( - projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) + execution.ProjectExec( + projects, filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan)) } } http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index c1a97de..751daa0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{DataSourceScan, SparkPlan} +import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan} /** * A strategy for planning scans over collections of files that might be partitioned or bucketed @@ -192,7 +192,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { } val scan = - DataSourceScan.create( + DataSourceScanExec.create( readDataColumns ++ partitionColumns, new FileScanRDD( files.sqlContext, @@ -205,11 +205,11 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { "ReadSchema" -> prunedDataSchema.simpleString)) val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And) - val withFilter = afterScanFilter.map(execution.Filter(_, scan)).getOrElse(scan) + val withFilter = afterScanFilter.map(execution.FilterExec(_, scan)).getOrElse(scan) val withProjections = if (projects == withFilter.output) { withFilter } else { - execution.Project(projects, withFilter) + execution.ProjectExec(projects, withFilter) } withProjections :: Nil http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index e6079ec..5b96ab1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -49,9 +49,9 @@ package object debug { } def codegenString(plan: SparkPlan): String = { - val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]() + val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() plan transform { - case s: WholeStageCodegen => + case s: WholeStageCodegenExec => codegenSubtrees += s s case s => s @@ -86,11 +86,11 @@ package object debug { val debugPlan = plan transform { case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) => visited += new TreeNodeRef(s) - 
DebugNode(s) + DebugExec(s) } debugPrint(s"Results returned: ${debugPlan.execute().count()}") debugPlan.foreach { - case d: DebugNode => d.dumpStats() + case d: DebugExec => d.dumpStats() case _ => } } @@ -104,7 +104,7 @@ package object debug { } } - private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport { + private[sql] case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { def output: Seq[Attribute] = child.output implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala deleted file mode 100644 index 87a113e..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.exchange - -import scala.concurrent.{ExecutionContext, Future} -import scala.concurrent.duration._ - -import org.apache.spark.broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} -import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.ThreadUtils - -/** - * A [[BroadcastExchange]] collects, transforms and finally broadcasts the result of a transformed - * SparkPlan. 
- */ -case class BroadcastExchange( - mode: BroadcastMode, - child: SparkPlan) extends Exchange { - - override private[sql] lazy val metrics = Map( - "dataSize" -> SQLMetrics.createLongMetric(sparkContext, "data size (bytes)"), - "collectTime" -> SQLMetrics.createLongMetric(sparkContext, "time to collect (ms)"), - "buildTime" -> SQLMetrics.createLongMetric(sparkContext, "time to build (ms)"), - "broadcastTime" -> SQLMetrics.createLongMetric(sparkContext, "time to broadcast (ms)")) - - override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) - - override def sameResult(plan: SparkPlan): Boolean = plan match { - case p: BroadcastExchange => - mode.compatibleWith(p.mode) && child.sameResult(p.child) - case _ => false - } - - @transient - private val timeout: Duration = { - val timeoutValue = sqlContext.conf.broadcastTimeout - if (timeoutValue < 0) { - Duration.Inf - } else { - timeoutValue.seconds - } - } - - @transient - private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { - // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - Future { - // This will run in another thread. Set the execution id so that we can connect these jobs - // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { - val beforeCollect = System.nanoTime() - // Note that we use .executeCollect() because we don't want to convert data to Scala types - val input: Array[InternalRow] = child.executeCollect() - val beforeBuild = System.nanoTime() - longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000 - longMetric("dataSize") += input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum - - // Construct and broadcast the relation. - val relation = mode.transform(input) - val beforeBroadcast = System.nanoTime() - longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000 - - val broadcasted = sparkContext.broadcast(relation) - longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 - broadcasted - } - }(BroadcastExchange.executionContext) - } - - override protected def doPrepare(): Unit = { - // Materialize the future. - relationFuture - } - - override protected def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException( - "BroadcastExchange does not support the execute() code path.") - } - - override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] - } -} - -object BroadcastExchange { - private[execution] val executionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("broadcast-exchange", 128)) -} http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala new file mode 100644 index 0000000..573ca19 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration._ + +import org.apache.spark.broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.ThreadUtils + +/** + * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of + * a transformed SparkPlan. + */ +case class BroadcastExchangeExec( + mode: BroadcastMode, + child: SparkPlan) extends Exchange { + + override private[sql] lazy val metrics = Map( + "dataSize" -> SQLMetrics.createLongMetric(sparkContext, "data size (bytes)"), + "collectTime" -> SQLMetrics.createLongMetric(sparkContext, "time to collect (ms)"), + "buildTime" -> SQLMetrics.createLongMetric(sparkContext, "time to build (ms)"), + "broadcastTime" -> SQLMetrics.createLongMetric(sparkContext, "time to broadcast (ms)")) + + override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) + + override def sameResult(plan: SparkPlan): Boolean = plan match { + case p: BroadcastExchangeExec => + mode.compatibleWith(p.mode) && child.sameResult(p.child) + case _ => false + } + + @transient + private val timeout: Duration = { + val timeoutValue = sqlContext.conf.broadcastTimeout + if (timeoutValue < 0) { + Duration.Inf + } else { + timeoutValue.seconds + } + } + + @transient + private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + Future { + // This will run in another thread. Set the execution id so that we can connect these jobs + // with the correct execution. + SQLExecution.withExecutionId(sparkContext, executionId) { + val beforeCollect = System.nanoTime() + // Note that we use .executeCollect() because we don't want to convert data to Scala types + val input: Array[InternalRow] = child.executeCollect() + val beforeBuild = System.nanoTime() + longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000 + longMetric("dataSize") += input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + + // Construct and broadcast the relation. 
+ val relation = mode.transform(input) + val beforeBroadcast = System.nanoTime() + longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000 + + val broadcasted = sparkContext.broadcast(relation) + longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 + broadcasted + } + }(BroadcastExchangeExec.executionContext) + } + + override protected def doPrepare(): Unit = { + // Materialize the future. + relationFuture + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException( + "BroadcastExchange does not support the execute() code path.") + } + + override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + } +} + +object BroadcastExchangeExec { + private[execution] val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-exchange", 128)) +} http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 4864db7..446571a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -160,7 +160,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child case (child, BroadcastDistribution(mode)) => - BroadcastExchange(mode, child) + BroadcastExchangeExec(mode, child) case (child, distribution) => ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) } @@ -237,7 +237,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { if (requiredOrdering.nonEmpty) { // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. 
if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { - Sort(requiredOrdering, global = false, child = child) + SortExec(requiredOrdering, global = false, child = child) } else { child } http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index df7ad48..9da9df6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} +import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructType * differs significantly, the concept is similar to the exchange operator described in * "Volcano -- An Extensible and Parallel Query Evaluation System" by Goetz Graefe. */ -abstract class Exchange extends UnaryNode { +abstract class Exchange extends UnaryExecNode { override def output: Seq[Attribute] = child.output } @@ -45,7 +45,8 @@ abstract class Exchange extends UnaryNode { * logically identical output will have distinct sets of output attribute ids, so we need to * preserve the original ids because they're what downstream operators are expecting. */ -case class ReusedExchange(override val output: Seq[Attribute], child: Exchange) extends LeafNode { +case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchange) + extends LeafExecNode { override def sameResult(plan: SparkPlan): Boolean = { // Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here. @@ -86,7 +87,7 @@ case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] { if (samePlan.isDefined) { // Keep the output of this exchange, the following plans require that to resolve // attributes. - ReusedExchange(exchange.output, samePlan.get) + ReusedExchangeExec(exchange.output, samePlan.get) } else { sameSchema += exchange exchange http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala deleted file mode 100644 index 89487c6..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ /dev/null @@ -1,401 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.TaskContext -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.LongType - -/** - * Performs an inner hash join of two child relations. When the output RDD of this operator is - * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcast relation. This data is then placed in a Spark broadcast variable. The streamed - * relation is not shuffled. - */ -case class BroadcastHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - buildSide: BuildSide, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) - extends BinaryNode with HashJoin with CodegenSupport { - - override private[sql] lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - - override def requiredChildDistribution: Seq[Distribution] = { - val mode = HashedRelationBroadcastMode(buildKeys) - buildSide match { - case BuildLeft => - BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil - case BuildRight => - UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil - } - } - - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() - streamedPlan.execute().mapPartitions { streamedIter => - val hashed = broadcastRelation.value.asReadOnlyCopy() - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) - join(streamedIter, hashed, numOutputRows) - } - } - - override def inputRDDs(): Seq[RDD[InternalRow]] = { - streamedPlan.asInstanceOf[CodegenSupport].inputRDDs() - } - - override def doProduce(ctx: CodegenContext): String = { - streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) - } - - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - joinType match { - case Inner => codegenInner(ctx, input) - case LeftOuter | RightOuter => codegenOuter(ctx, input) - case LeftSemi => codegenSemi(ctx, input) - case LeftAnti => codegenAnti(ctx, input) - case x => - throw new IllegalArgumentException( - s"BroadcastHashJoin should not take $x as the JoinType") - } - } - - /** - * Returns a tuple of Broadcast of HashedRelation and the variable name for it. 
- */ - private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { - // create a name for HashedRelation - val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() - val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) - val relationTerm = ctx.freshName("relation") - val clsName = broadcastRelation.value.getClass.getName - ctx.addMutableState(clsName, relationTerm, - s""" - | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); - | incPeakExecutionMemory($relationTerm.estimatedSize()); - """.stripMargin) - (broadcastRelation, relationTerm) - } - - /** - * Returns the code for generating join key for stream side, and expression of whether the key - * has any null in it or not. - */ - private def genStreamSideJoinKey( - ctx: CodegenContext, - input: Seq[ExprCode]): (ExprCode, String) = { - ctx.currentVars = input - if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { - // generate the join key as Long - val ev = streamedKeys.head.genCode(ctx) - (ev, ev.isNull) - } else { - // generate the join key as UnsafeRow - val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) - (ev, s"${ev.value}.anyNull()") - } - } - - /** - * Generates the code for variable of build side. - */ - private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { - ctx.currentVars = null - ctx.INPUT_ROW = matched - buildPlan.output.zipWithIndex.map { case (a, i) => - val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx) - if (joinType == Inner) { - ev - } else { - // the variables are needed even there is no matched rows - val isNull = ctx.freshName("isNull") - val value = ctx.freshName("value") - val code = s""" - |boolean $isNull = true; - |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; - |if ($matched != null) { - | ${ev.code} - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; - |} - """.stripMargin - ExprCode(code, isNull, value) - } - } - } - - /** - * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi - * and Left Anti joins. - */ - private def getJoinCondition( - ctx: CodegenContext, - input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = { - val matched = ctx.freshName("matched") - val buildVars = genBuildSideVars(ctx, matched) - val checkCondition = if (condition.isDefined) { - val expr = condition.get - // evaluate the variables from build side that used by condition - val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) - // filter the output via condition - ctx.currentVars = input ++ buildVars - val ev = - BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) - s""" - |$eval - |${ev.code} - |if (${ev.isNull} || !${ev.value}) continue; - """.stripMargin - } else { - "" - } - (matched, checkCondition, buildVars) - } - - /** - * Generates the code for Inner join. 
- */ - private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { - val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) - val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) - val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input) - val numOutput = metricTerm(ctx, "numOutputRows") - - val resultVars = buildSide match { - case BuildLeft => buildVars ++ input - case BuildRight => input ++ buildVars - } - if (broadcastRelation.value.keyIsUnique) { - s""" - |// generate join key for stream side - |${keyEv.code} - |// find matches from HashedRelation - |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); - |if ($matched == null) continue; - |$checkCondition - |$numOutput.add(1); - |${consume(ctx, resultVars)} - """.stripMargin - - } else { - ctx.copyResult = true - val matches = ctx.freshName("matches") - val iteratorCls = classOf[Iterator[UnsafeRow]].getName - s""" - |// generate join key for stream side - |${keyEv.code} - |// find matches from HashRelation - |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); - |if ($matches == null) continue; - |while ($matches.hasNext()) { - | UnsafeRow $matched = (UnsafeRow) $matches.next(); - | $checkCondition - | $numOutput.add(1); - | ${consume(ctx, resultVars)} - |} - """.stripMargin - } - } - - /** - * Generates the code for left or right outer join. - */ - private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { - val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) - val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) - val matched = ctx.freshName("matched") - val buildVars = genBuildSideVars(ctx, matched) - val numOutput = metricTerm(ctx, "numOutputRows") - - // filter the output via condition - val conditionPassed = ctx.freshName("conditionPassed") - val checkCondition = if (condition.isDefined) { - val expr = condition.get - // evaluate the variables from build side that used by condition - val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) - ctx.currentVars = input ++ buildVars - val ev = - BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) - s""" - |boolean $conditionPassed = true; - |${eval.trim} - |${ev.code} - |if ($matched != null) { - | $conditionPassed = !${ev.isNull} && ${ev.value}; - |} - """.stripMargin - } else { - s"final boolean $conditionPassed = true;" - } - - val resultVars = buildSide match { - case BuildLeft => buildVars ++ input - case BuildRight => input ++ buildVars - } - if (broadcastRelation.value.keyIsUnique) { - s""" - |// generate join key for stream side - |${keyEv.code} - |// find matches from HashedRelation - |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); - |${checkCondition.trim} - |if (!$conditionPassed) { - | $matched = null; - | // reset the variables those are already evaluated. - | ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")} - |} - |$numOutput.add(1); - |${consume(ctx, resultVars)} - """.stripMargin - - } else { - ctx.copyResult = true - val matches = ctx.freshName("matches") - val iteratorCls = classOf[Iterator[UnsafeRow]].getName - val found = ctx.freshName("found") - s""" - |// generate join key for stream side - |${keyEv.code} - |// find matches from HashRelation - |$iteratorCls $matches = $anyNull ? 
null : ($iteratorCls)$relationTerm.get(${keyEv.value}); - |boolean $found = false; - |// the last iteration of this loop is to emit an empty row if there is no matched rows. - |while ($matches != null && $matches.hasNext() || !$found) { - | UnsafeRow $matched = $matches != null && $matches.hasNext() ? - | (UnsafeRow) $matches.next() : null; - | ${checkCondition.trim} - | if (!$conditionPassed) continue; - | $found = true; - | $numOutput.add(1); - | ${consume(ctx, resultVars)} - |} - """.stripMargin - } - } - - /** - * Generates the code for left semi join. - */ - private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = { - val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) - val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) - val (matched, checkCondition, _) = getJoinCondition(ctx, input) - val numOutput = metricTerm(ctx, "numOutputRows") - if (broadcastRelation.value.keyIsUnique) { - s""" - |// generate join key for stream side - |${keyEv.code} - |// find matches from HashedRelation - |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); - |if ($matched == null) continue; - |$checkCondition - |$numOutput.add(1); - |${consume(ctx, input)} - """.stripMargin - } else { - val matches = ctx.freshName("matches") - val iteratorCls = classOf[Iterator[UnsafeRow]].getName - val found = ctx.freshName("found") - s""" - |// generate join key for stream side - |${keyEv.code} - |// find matches from HashRelation - |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); - |if ($matches == null) continue; - |boolean $found = false; - |while (!$found && $matches.hasNext()) { - | UnsafeRow $matched = (UnsafeRow) $matches.next(); - | $checkCondition - | $found = true; - |} - |if (!$found) continue; - |$numOutput.add(1); - |${consume(ctx, input)} - """.stripMargin - } - } - - /** - * Generates the code for anti join. - */ - private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = { - val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) - val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) - val (matched, checkCondition, _) = getJoinCondition(ctx, input) - val numOutput = metricTerm(ctx, "numOutputRows") - - if (broadcastRelation.value.keyIsUnique) { - s""" - |// generate join key for stream side - |${keyEv.code} - |// Check if the key has nulls. - |if (!($anyNull)) { - | // Check if the HashedRelation exists. - | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value}); - | if ($matched != null) { - | // Evaluate the condition. - | $checkCondition - | } - |} - |$numOutput.add(1); - |${consume(ctx, input)} - """.stripMargin - } else { - val matches = ctx.freshName("matches") - val iteratorCls = classOf[Iterator[UnsafeRow]].getName - val found = ctx.freshName("found") - s""" - |// generate join key for stream side - |${keyEv.code} - |// Check if the key has nulls. - |if (!($anyNull)) { - | // Check if the HashedRelation exists. - | $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value}); - | if ($matches != null) { - | // Evaluate the condition. 
- | boolean $found = false; - | while (!$found && $matches.hasNext()) { - | UnsafeRow $matched = (UnsafeRow) $matches.next(); - | $checkCondition - | $found = true; - | } - | if ($found) continue; - | } - |} - |$numOutput.add(1); - |${consume(ctx, input)} - """.stripMargin - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala new file mode 100644 index 0000000..51399e1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -0,0 +1,401 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.LongType + +/** + * Performs an inner hash join of two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcast relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. 
+ */ +case class BroadcastHashJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) + extends BinaryExecNode with HashJoin with CodegenSupport { + + override private[sql] lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + val mode = HashedRelationBroadcastMode(buildKeys) + buildSide match { + case BuildLeft => + BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil + case BuildRight => + UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil + } + } + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() + streamedPlan.execute().mapPartitions { streamedIter => + val hashed = broadcastRelation.value.asReadOnlyCopy() + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) + join(streamedIter, hashed, numOutputRows) + } + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + streamedPlan.asInstanceOf[CodegenSupport].inputRDDs() + } + + override def doProduce(ctx: CodegenContext): String = { + streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + joinType match { + case Inner => codegenInner(ctx, input) + case LeftOuter | RightOuter => codegenOuter(ctx, input) + case LeftSemi => codegenSemi(ctx, input) + case LeftAnti => codegenAnti(ctx, input) + case x => + throw new IllegalArgumentException( + s"BroadcastHashJoin should not take $x as the JoinType") + } + } + + /** + * Returns a tuple of Broadcast of HashedRelation and the variable name for it. + */ + private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { + // create a name for HashedRelation + val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() + val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) + val relationTerm = ctx.freshName("relation") + val clsName = broadcastRelation.value.getClass.getName + ctx.addMutableState(clsName, relationTerm, + s""" + | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); + | incPeakExecutionMemory($relationTerm.estimatedSize()); + """.stripMargin) + (broadcastRelation, relationTerm) + } + + /** + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. + */ + private def genStreamSideJoinKey( + ctx: CodegenContext, + input: Seq[ExprCode]): (ExprCode, String) = { + ctx.currentVars = input + if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { + // generate the join key as Long + val ev = streamedKeys.head.genCode(ctx) + (ev, ev.isNull) + } else { + // generate the join key as UnsafeRow + val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) + (ev, s"${ev.value}.anyNull()") + } + } + + /** + * Generates the code for variable of build side. 
+ */ + private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { + ctx.currentVars = null + ctx.INPUT_ROW = matched + buildPlan.output.zipWithIndex.map { case (a, i) => + val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx) + if (joinType == Inner) { + ev + } else { + // the variables are needed even there is no matched rows + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + val code = s""" + |boolean $isNull = true; + |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |if ($matched != null) { + | ${ev.code} + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + |} + """.stripMargin + ExprCode(code, isNull, value) + } + } + } + + /** + * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi + * and Left Anti joins. + */ + private def getJoinCondition( + ctx: CodegenContext, + input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = { + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) + val checkCondition = if (condition.isDefined) { + val expr = condition.get + // evaluate the variables from build side that used by condition + val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) + // filter the output via condition + ctx.currentVars = input ++ buildVars + val ev = + BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) + s""" + |$eval + |${ev.code} + |if (${ev.isNull} || !${ev.value}) continue; + """.stripMargin + } else { + "" + } + (matched, checkCondition, buildVars) + } + + /** + * Generates the code for Inner join. + */ + private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") + + val resultVars = buildSide match { + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars + } + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |if ($matched == null) continue; + |$checkCondition + |$numOutput.add(1); + |${consume(ctx, resultVars)} + """.stripMargin + + } else { + ctx.copyResult = true + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); + |if ($matches == null) continue; + |while ($matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + |} + """.stripMargin + } + } + + /** + * Generates the code for left or right outer join. 
+ */ + private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) + val numOutput = metricTerm(ctx, "numOutputRows") + + // filter the output via condition + val conditionPassed = ctx.freshName("conditionPassed") + val checkCondition = if (condition.isDefined) { + val expr = condition.get + // evaluate the variables from build side that used by condition + val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) + ctx.currentVars = input ++ buildVars + val ev = + BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) + s""" + |boolean $conditionPassed = true; + |${eval.trim} + |${ev.code} + |if ($matched != null) { + | $conditionPassed = !${ev.isNull} && ${ev.value}; + |} + """.stripMargin + } else { + s"final boolean $conditionPassed = true;" + } + + val resultVars = buildSide match { + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars + } + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |${checkCondition.trim} + |if (!$conditionPassed) { + | $matched = null; + | // reset the variables those are already evaluated. + | ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")} + |} + |$numOutput.add(1); + |${consume(ctx, resultVars)} + """.stripMargin + + } else { + ctx.copyResult = true + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + val found = ctx.freshName("found") + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); + |boolean $found = false; + |// the last iteration of this loop is to emit an empty row if there is no matched rows. + |while ($matches != null && $matches.hasNext() || !$found) { + | UnsafeRow $matched = $matches != null && $matches.hasNext() ? + | (UnsafeRow) $matches.next() : null; + | ${checkCondition.trim} + | if (!$conditionPassed) continue; + | $found = true; + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + |} + """.stripMargin + } + } + + /** + * Generates the code for left semi join. + */ + private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, _) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |if ($matched == null) continue; + |$checkCondition + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } else { + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + val found = ctx.freshName("found") + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$iteratorCls $matches = $anyNull ? 
null : ($iteratorCls)$relationTerm.get(${keyEv.value}); + |if ($matches == null) continue; + |boolean $found = false; + |while (!$found && $matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition + | $found = true; + |} + |if (!$found) continue; + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } + } + + /** + * Generates the code for anti join. + */ + private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, _) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") + + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// Check if the key has nulls. + |if (!($anyNull)) { + | // Check if the HashedRelation exists. + | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + | if ($matched != null) { + | // Evaluate the condition. + | $checkCondition + | } + |} + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } else { + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + val found = ctx.freshName("found") + s""" + |// generate join key for stream side + |${keyEv.code} + |// Check if the key has nulls. + |if (!($anyNull)) { + | // Check if the HashedRelation exists. + | $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value}); + | if ($matches != null) { + | // Evaluate the condition. + | boolean $found = false; + | while (!$found && $matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition + | $found = true; + | } + | if ($found) continue; + | } + |} + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala deleted file mode 100644 index 4ba710c..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ /dev/null @@ -1,331 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.collection.{BitSet, CompactBuffer} - -case class BroadcastNestedLoopJoin( - left: SparkPlan, - right: SparkPlan, - buildSide: BuildSide, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { - - override private[sql] lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - /** BuildRight means the right relation <=> the broadcast relation. */ - private val (streamed, broadcast) = buildSide match { - case BuildRight => (left, right) - case BuildLeft => (right, left) - } - - override def requiredChildDistribution: Seq[Distribution] = buildSide match { - case BuildLeft => - BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil - case BuildRight => - UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil - } - - private[this] def genResultProjection: InternalRow => InternalRow = { - if (joinType == LeftSemi) { - UnsafeProjection.create(output, output) - } else { - // Always put the stream side on left to simplify implementation - // both of left and right side could be null - UnsafeProjection.create( - output, (streamed.output ++ broadcast.output).map(_.withNullability(true))) - } - } - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - override def output: Seq[Attribute] = { - joinType match { - case Inner => - left.output ++ right.output - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case LeftExistence(_) => - left.output - case x => - throw new IllegalArgumentException( - s"BroadcastNestedLoopJoin should not take $x as the JoinType") - } - } - - @transient private lazy val boundCondition = { - if (condition.isDefined) { - newPredicate(condition.get, streamed.output ++ broadcast.output) - } else { - (r: InternalRow) => true - } - } - - /** - * The implementation for InnerJoin. - */ - private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { - streamed.execute().mapPartitionsInternal { streamedIter => - val buildRows = relation.value - val joinedRow = new JoinedRow - - streamedIter.flatMap { streamedRow => - val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r)) - if (condition.isDefined) { - joinedRows.filter(boundCondition) - } else { - joinedRows - } - } - } - } - - /** - * The implementation for these joins: - * - * LeftOuter with BuildRight - * RightOuter with BuildLeft - */ - private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { - streamed.execute().mapPartitionsInternal { streamedIter => - val buildRows = relation.value - val joinedRow = new JoinedRow - val nulls = new GenericMutableRow(broadcast.output.size) - - // Returns an iterator to avoid copy the rows. 
- new Iterator[InternalRow] { - // current row from stream side - private var streamRow: InternalRow = null - // have found a match for current row or not - private var foundMatch: Boolean = false - // the matched result row - private var resultRow: InternalRow = null - // the next index of buildRows to try - private var nextIndex: Int = 0 - - private def findNextMatch(): Boolean = { - if (streamRow == null) { - if (!streamedIter.hasNext) { - return false - } - streamRow = streamedIter.next() - nextIndex = 0 - foundMatch = false - } - while (nextIndex < buildRows.length) { - resultRow = joinedRow(streamRow, buildRows(nextIndex)) - nextIndex += 1 - if (boundCondition(resultRow)) { - foundMatch = true - return true - } - } - if (!foundMatch) { - resultRow = joinedRow(streamRow, nulls) - streamRow = null - true - } else { - resultRow = null - streamRow = null - findNextMatch() - } - } - - override def hasNext(): Boolean = { - resultRow != null || findNextMatch() - } - override def next(): InternalRow = { - val r = resultRow - resultRow = null - r - } - } - } - } - - /** - * The implementation for these joins: - * - * LeftSemi with BuildRight - * Anti with BuildRight - */ - private def leftExistenceJoin( - relation: Broadcast[Array[InternalRow]], - exists: Boolean): RDD[InternalRow] = { - assert(buildSide == BuildRight) - streamed.execute().mapPartitionsInternal { streamedIter => - val buildRows = relation.value - val joinedRow = new JoinedRow - - if (condition.isDefined) { - streamedIter.filter(l => - buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists - ) - } else if (buildRows.nonEmpty == exists) { - streamedIter - } else { - Iterator.empty - } - } - } - - /** - * The implementation for these joins: - * - * LeftOuter with BuildLeft - * RightOuter with BuildRight - * FullOuter - * LeftSemi with BuildLeft - * Anti with BuildLeft - */ - private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { - /** All rows that either match both-way, or rows from streamed joined with nulls. 
*/ - val streamRdd = streamed.execute() - - val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter => - val buildRows = relation.value - val matched = new BitSet(buildRows.length) - val joinedRow = new JoinedRow - - streamedIter.foreach { streamedRow => - var i = 0 - while (i < buildRows.length) { - if (boundCondition(joinedRow(streamedRow, buildRows(i)))) { - matched.set(i) - } - i += 1 - } - } - Seq(matched).toIterator - } - - val matchedBroadcastRows = matchedBuildRows.fold( - new BitSet(relation.value.length) - )(_ | _) - - if (joinType == LeftSemi) { - assert(buildSide == BuildLeft) - val buf: CompactBuffer[InternalRow] = new CompactBuffer() - var i = 0 - val rel = relation.value - while (i < rel.length) { - if (matchedBroadcastRows.get(i)) { - buf += rel(i).copy() - } - i += 1 - } - return sparkContext.makeRDD(buf) - } - - val notMatchedBroadcastRows: Seq[InternalRow] = { - val nulls = new GenericMutableRow(streamed.output.size) - val buf: CompactBuffer[InternalRow] = new CompactBuffer() - var i = 0 - val buildRows = relation.value - val joinedRow = new JoinedRow - joinedRow.withLeft(nulls) - while (i < buildRows.length) { - if (!matchedBroadcastRows.get(i)) { - buf += joinedRow.withRight(buildRows(i)).copy() - } - i += 1 - } - buf - } - - if (joinType == LeftAnti) { - return sparkContext.makeRDD(notMatchedBroadcastRows) - } - - val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => - val buildRows = relation.value - val joinedRow = new JoinedRow - val nulls = new GenericMutableRow(broadcast.output.size) - - streamedIter.flatMap { streamedRow => - var i = 0 - var foundMatch = false - val matchedRows = new CompactBuffer[InternalRow] - - while (i < buildRows.length) { - if (boundCondition(joinedRow(streamedRow, buildRows(i)))) { - matchedRows += joinedRow.copy() - foundMatch = true - } - i += 1 - } - - if (!foundMatch && joinType == FullOuter) { - matchedRows += joinedRow(streamedRow, nulls).copy() - } - matchedRows.iterator - } - } - - sparkContext.union( - matchedStreamRows, - sparkContext.makeRDD(notMatchedBroadcastRows) - ) - } - - protected override def doExecute(): RDD[InternalRow] = { - val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() - - val resultRdd = (joinType, buildSide) match { - case (Inner, _) => - innerJoin(broadcastedRelation) - case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => - outerJoin(broadcastedRelation) - case (LeftSemi, BuildRight) => - leftExistenceJoin(broadcastedRelation, exists = true) - case (LeftAnti, BuildRight) => - leftExistenceJoin(broadcastedRelation, exists = false) - case _ => - /** - * LeftOuter with BuildLeft - * RightOuter with BuildRight - * FullOuter - * LeftSemi with BuildLeft - * Anti with BuildLeft - */ - defaultJoin(broadcastedRelation) - } - - val numOutputRows = longMetric("numOutputRows") - resultRdd.mapPartitionsInternal { iter => - val resultProj = genResultProjection - iter.map { r => - numOutputRows += 1 - resultProj(r) - } - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala new file mode 100644 index 0000000..51afa00 --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.collection.{BitSet, CompactBuffer} + +case class BroadcastNestedLoopJoinExec( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryExecNode { + + override private[sql] lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + + /** BuildRight means the right relation <=> the broadcast relation. */ + private val (streamed, broadcast) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + + override def requiredChildDistribution: Seq[Distribution] = buildSide match { + case BuildLeft => + BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil + case BuildRight => + UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil + } + + private[this] def genResultProjection: InternalRow => InternalRow = { + if (joinType == LeftSemi) { + UnsafeProjection.create(output, output) + } else { + // Always put the stream side on left to simplify implementation + // both of left and right side could be null + UnsafeProjection.create( + output, (streamed.output ++ broadcast.output).map(_.withNullability(true))) + } + } + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def output: Seq[Attribute] = { + joinType match { + case Inner => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case LeftExistence(_) => + left.output + case x => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not take $x as the JoinType") + } + } + + @transient private lazy val boundCondition = { + if (condition.isDefined) { + newPredicate(condition.get, streamed.output ++ broadcast.output) + } else { + (r: InternalRow) => true + } + } + + /** + * The implementation for InnerJoin. 
+ */ + private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + + streamedIter.flatMap { streamedRow => + val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r)) + if (condition.isDefined) { + joinedRows.filter(boundCondition) + } else { + joinedRows + } + } + } + } + + /** + * The implementation for these joins: + * + * LeftOuter with BuildRight + * RightOuter with BuildLeft + */ + private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + val nulls = new GenericMutableRow(broadcast.output.size) + + // Returns an iterator to avoid copy the rows. + new Iterator[InternalRow] { + // current row from stream side + private var streamRow: InternalRow = null + // have found a match for current row or not + private var foundMatch: Boolean = false + // the matched result row + private var resultRow: InternalRow = null + // the next index of buildRows to try + private var nextIndex: Int = 0 + + private def findNextMatch(): Boolean = { + if (streamRow == null) { + if (!streamedIter.hasNext) { + return false + } + streamRow = streamedIter.next() + nextIndex = 0 + foundMatch = false + } + while (nextIndex < buildRows.length) { + resultRow = joinedRow(streamRow, buildRows(nextIndex)) + nextIndex += 1 + if (boundCondition(resultRow)) { + foundMatch = true + return true + } + } + if (!foundMatch) { + resultRow = joinedRow(streamRow, nulls) + streamRow = null + true + } else { + resultRow = null + streamRow = null + findNextMatch() + } + } + + override def hasNext(): Boolean = { + resultRow != null || findNextMatch() + } + override def next(): InternalRow = { + val r = resultRow + resultRow = null + r + } + } + } + } + + /** + * The implementation for these joins: + * + * LeftSemi with BuildRight + * Anti with BuildRight + */ + private def leftExistenceJoin( + relation: Broadcast[Array[InternalRow]], + exists: Boolean): RDD[InternalRow] = { + assert(buildSide == BuildRight) + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + + if (condition.isDefined) { + streamedIter.filter(l => + buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists + ) + } else if (buildRows.nonEmpty == exists) { + streamedIter + } else { + Iterator.empty + } + } + } + + /** + * The implementation for these joins: + * + * LeftOuter with BuildLeft + * RightOuter with BuildRight + * FullOuter + * LeftSemi with BuildLeft + * Anti with BuildLeft + */ + private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + /** All rows that either match both-way, or rows from streamed joined with nulls. 
*/ + val streamRdd = streamed.execute() + + val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val matched = new BitSet(buildRows.length) + val joinedRow = new JoinedRow + + streamedIter.foreach { streamedRow => + var i = 0 + while (i < buildRows.length) { + if (boundCondition(joinedRow(streamedRow, buildRows(i)))) { + matched.set(i) + } + i += 1 + } + } + Seq(matched).toIterator + } + + val matchedBroadcastRows = matchedBuildRows.fold( + new BitSet(relation.value.length) + )(_ | _) + + if (joinType == LeftSemi) { + assert(buildSide == BuildLeft) + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val rel = relation.value + while (i < rel.length) { + if (matchedBroadcastRows.get(i)) { + buf += rel(i).copy() + } + i += 1 + } + return sparkContext.makeRDD(buf) + } + + val notMatchedBroadcastRows: Seq[InternalRow] = { + val nulls = new GenericMutableRow(streamed.output.size) + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val buildRows = relation.value + val joinedRow = new JoinedRow + joinedRow.withLeft(nulls) + while (i < buildRows.length) { + if (!matchedBroadcastRows.get(i)) { + buf += joinedRow.withRight(buildRows(i)).copy() + } + i += 1 + } + buf + } + + if (joinType == LeftAnti) { + return sparkContext.makeRDD(notMatchedBroadcastRows) + } + + val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + val nulls = new GenericMutableRow(broadcast.output.size) + + streamedIter.flatMap { streamedRow => + var i = 0 + var foundMatch = false + val matchedRows = new CompactBuffer[InternalRow] + + while (i < buildRows.length) { + if (boundCondition(joinedRow(streamedRow, buildRows(i)))) { + matchedRows += joinedRow.copy() + foundMatch = true + } + i += 1 + } + + if (!foundMatch && joinType == FullOuter) { + matchedRows += joinedRow(streamedRow, nulls).copy() + } + matchedRows.iterator + } + } + + sparkContext.union( + matchedStreamRows, + sparkContext.makeRDD(notMatchedBroadcastRows) + ) + } + + protected override def doExecute(): RDD[InternalRow] = { + val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() + + val resultRdd = (joinType, buildSide) match { + case (Inner, _) => + innerJoin(broadcastedRelation) + case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => + outerJoin(broadcastedRelation) + case (LeftSemi, BuildRight) => + leftExistenceJoin(broadcastedRelation, exists = true) + case (LeftAnti, BuildRight) => + leftExistenceJoin(broadcastedRelation, exists = false) + case _ => + /** + * LeftOuter with BuildLeft + * RightOuter with BuildRight + * FullOuter + * LeftSemi with BuildLeft + * Anti with BuildLeft + */ + defaultJoin(broadcastedRelation) + } + + val numOutputRows = longMetric("numOutputRows") + resultRdd.mapPartitionsInternal { iter => + val resultProj = genResultProjection + iter.map { r => + numOutputRows += 1 + resultProj(r) + } + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala deleted file mode 100644 index b1de52b..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ 
/dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark._ -import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.CompletionIterator -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter - -/** - * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, - * will be much faster than building the right partition for every row in left RDD, it also - * materialize the right RDD (in case of the right RDD is nondeterministic). - */ -private[spark] -class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) - extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { - - override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { - // We will not sort the rows, so prefixComparator and recordComparator are null. 
- val sorter = UnsafeExternalSorter.create( - context.taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - context, - null, - null, - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - false) - - val partition = split.asInstanceOf[CartesianPartition] - for (y <- rdd2.iterator(partition.s2, context)) { - sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0) - } - - // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] - def createIter(): Iterator[UnsafeRow] = { - val iter = sorter.getIterator - val unsafeRow = new UnsafeRow(numFieldsOfRight) - new Iterator[UnsafeRow] { - override def hasNext: Boolean = { - iter.hasNext - } - override def next(): UnsafeRow = { - iter.loadNext() - unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) - unsafeRow - } - } - } - - val resultIter = - for (x <- rdd1.iterator(partition.s1, context); - y <- createIter()) yield (x, y) - CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( - resultIter, sorter.cleanupResources) - } -} - - -case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output: Seq[Attribute] = left.output ++ right.output - - override private[sql] lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]] - val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] - - val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) - pair.mapPartitionsInternal { iter => - val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) - iter.map { r => - numOutputRows += 1 - joiner.join(r._1, r._2) - } - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala new file mode 100644 index 0000000..3ce7c0e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark._ +import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + +/** + * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, + * will be much faster than building the right partition for every row in left RDD, it also + * materialize the right RDD (in case of the right RDD is nondeterministic). + */ +private[spark] +class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) + extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { + + override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { + // We will not sort the rows, so prefixComparator and recordComparator are null. + val sorter = UnsafeExternalSorter.create( + context.taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + context, + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + false) + + val partition = split.asInstanceOf[CartesianPartition] + for (y <- rdd2.iterator(partition.s2, context)) { + sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0) + } + + // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] + def createIter(): Iterator[UnsafeRow] = { + val iter = sorter.getIterator + val unsafeRow = new UnsafeRow(numFieldsOfRight) + new Iterator[UnsafeRow] { + override def hasNext: Boolean = { + iter.hasNext + } + override def next(): UnsafeRow = { + iter.loadNext() + unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) + unsafeRow + } + } + } + + val resultIter = + for (x <- rdd1.iterator(partition.s1, context); + y <- createIter()) yield (x, y) + CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( + resultIter, sorter.cleanupResources) + } +} + + +case class CartesianProductExec(left: SparkPlan, right: SparkPlan) extends BinaryExecNode { + override def output: Seq[Attribute] = left.output ++ right.output + + override private[sql] lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]] + val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] + + val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) + pair.mapPartitionsInternal { iter => + val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) + iter.map { r => + numOutputRows += 1 + joiner.join(r._1, r._2) + } + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
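Editor's note: the diffs above rename the physical join operators with an `Exec` suffix (`BroadcastNestedLoopJoin` becomes `BroadcastNestedLoopJoinExec`, `CartesianProduct` becomes `CartesianProductExec`); the operator bodies are moved essentially verbatim. To make the join strategy easier to follow outside the Spark codebase, here is a minimal, self-contained sketch of the nested-loop idea that `BroadcastNestedLoopJoinExec`'s `outerJoin` path implements: every streamed row is compared against every row of the broadcast (in-memory) side, and a left-outer row falls back to nulls when no build row satisfies the condition. This is plain Scala with no Spark dependencies; the `Row` alias, `leftOuterNestedLoopJoin`, and the sample data are illustrative placeholders, not Spark APIs.

// Hypothetical standalone sketch; not Spark code.
object BroadcastNestedLoopJoinSketch {
  type Row = Seq[Any]

  // Left-outer nested loop join: the stream side drives the loop, the build side is an
  // already-collected in-memory array (the "broadcast" relation in the diff above).
  def leftOuterNestedLoopJoin(
      streamed: Iterator[Row],
      build: Array[Row],
      buildWidth: Int,
      condition: (Row, Row) => Boolean): Iterator[Row] = {
    val nulls: Row = Seq.fill(buildWidth)(null)
    streamed.flatMap { s =>
      // All build rows that satisfy the join condition for this streamed row.
      val matches = build.iterator.filter(b => condition(s, b)).map(b => s ++ b)
      // No match: emit the streamed row padded with nulls (left-outer semantics).
      if (matches.hasNext) matches else Iterator.single(s ++ nulls)
    }
  }

  def main(args: Array[String]): Unit = {
    val streamed = Iterator[Row](Seq(1, "a"), Seq(2, "b"))
    val build = Array[Row](Seq(1, "x"), Seq(1, "y"))
    // Join on equality of the first column; the second streamed row has no match.
    val out = leftOuterNestedLoopJoin(streamed, build, buildWidth = 2,
      (s, b) => s.head == b.head)
    out.foreach(println)
    // Prints: List(1, a, 1, x), List(1, a, 1, y), List(2, b, null, null)
  }
}

The real operator additionally binds the condition once per partition, tracks matched build rows in a BitSet for the BuildLeft/FullOuter cases, and reports `numOutputRows` through SQL metrics, but the row-pairing logic is the same nested loop shown here.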
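Similarly, the heart of `CartesianProductExec` is the `UnsafeCartesianRDD` shown in the diff: the right-hand partition is buffered once and then replayed for every left row, instead of re-evaluating the right iterator per left row (which also pins down a nondeterministic right side, as the code comment notes). Spark buffers into an `UnsafeExternalSorter` so the cached rows can spill to disk; the sketch below uses a plain in-memory `ArrayBuffer` instead, and `cartesian` and its arguments are illustrative names only, not Spark APIs.

// Hypothetical standalone sketch of the caching idea; not Spark code.
import scala.collection.mutable.ArrayBuffer

object CartesianProductSketch {
  def cartesian[A, B](left: Iterator[A], right: Iterator[B]): Iterator[(A, B)] = {
    // Materialize the right side exactly once, then replay it for each left element.
    val cached = ArrayBuffer.empty[B]
    cached ++= right
    left.flatMap(a => cached.iterator.map(b => (a, b)))
  }

  def main(args: Array[String]): Unit = {
    val pairs = cartesian(Iterator(1, 2, 3), Iterator("x", "y"))
    pairs.foreach(println) // (1,x), (1,y), (2,x), (2,y), (3,x), (3,y)
  }
}

In the operator itself, each pair is then fed through a `GenerateUnsafeRowJoiner` projection and counted against the `numOutputRows` metric, as the `doExecute` body in the diff shows.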
