Github user cloud-fan commented on a diff in the pull request: https://github.com/apache/spark/pull/20750#discussion_r174873582 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala --- @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types.{UserDefinedType, _} +import org.apache.spark.unsafe.Platform + +/** + * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] it produces, a consumer + * should copy the row if it is being buffered. This class is not thread safe. + * + * @param expressions that produces the resulting fields. These expressions must be bound + * to a schema. + */ +class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection { + import InterpretedUnsafeProjection._ + + /** Number of (top level) fields in the resulting row. 
*/ + private[this] val numFields = expressions.length + + /** Array that expression results. */ + private[this] val values = new Array[Any](numFields) + + /** The row representing the expression results. */ + private[this] val intermediate = new GenericInternalRow(values) + + /** The row returned by the projection. */ + private[this] val result = new UnsafeRow(numFields) + + /** The buffer which holds the resulting row's backing data. */ + private[this] val holder = new BufferHolder(result, numFields * 32) + + /** The writer that writes the intermediate result to the result row. */ + private[this] val writer: InternalRow => Unit = { + val rowWriter = new UnsafeRowWriter(holder, numFields) + val baseWriter = generateStructWriter( + holder, + rowWriter, + expressions.map(e => StructField("", e.dataType, e.nullable))) + if (!expressions.exists(_.nullable)) { + // No nullable fields. The top-level null bit mask will always be zeroed out. + baseWriter + } else { + // Zero out the null bit mask before we write the row. + row => { + rowWriter.zeroOutNullBytes() + baseWriter(row) + } + } + } + + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } + + override def apply(row: InternalRow): UnsafeRow = { + // Put the expression results in the intermediate row. + var i = 0 + while (i < numFields) { + values(i) = expressions(i).eval(row) + i += 1 + } + + // Write the intermediate row to an unsafe row. + holder.reset() + writer(intermediate) + result.setTotalSize(holder.totalSize()) + result + } +} + +/** + * Helper functions for creating an [[InterpretedUnsafeProjection]]. + */ +object InterpretedUnsafeProjection extends UnsafeProjectionCreator { + + /** + * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. 
+ */ + override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + // We need to make sure that we do not reuse stateful expressions. + val cleanedExpressions = exprs.map(_.transform { + case s: Stateful => s.freshCopy() + }) + new InterpretedUnsafeProjection(cleanedExpressions.toArray) + } + + /** + * Generate a struct writer function. The generated function writes an [[InternalRow]] to the + * given buffer using the given [[UnsafeRowWriter]]. + */ + private def generateStructWriter( + bufferHolder: BufferHolder, + rowWriter: UnsafeRowWriter, --- End diff -- We could add `UnsafeWriter#getBufferHolder` so that we don't need to pass two parameters here: the writer is already constructed with its `BufferHolder` (e.g. `new UnsafeRowWriter(holder, numFields)`), so the holder can be obtained from the writer itself.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org