Github user hvanhovell commented on a diff in the pull request:
https://github.com/apache/spark/pull/20750#discussion_r175065574
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
---
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder,
UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.types.{UserDefinedType, _}
+import org.apache.spark.unsafe.Platform
+
+/**
+ * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] it produces, so a
+ * consumer should copy the row if it is being buffered. This class is not thread safe.
+ *
+ * @param expressions the expressions that produce the resulting fields. These expressions
+ * must be bound to a schema.
+ */
+class InterpretedUnsafeProjection(expressions: Array[Expression]) extends
UnsafeProjection {
+ import InterpretedUnsafeProjection._
+
+ /** Number of (top level) fields in the resulting row. */
+ private[this] val numFields = expressions.length
+
+ /** Array that holds the expression results, one slot per top-level field. */
+ private[this] val values = new Array[Any](numFields)
+
+ /** The row representing the expression results; it wraps `values`. */
+ private[this] val intermediate = new GenericInternalRow(values)
+
+ /** The row returned by the projection. The same instance is returned by every `apply`. */
+ private[this] val result = new UnsafeRow(numFields)
+
+ /**
+ * The buffer which holds the resulting row's backing data. The second argument is presumably
+ * an initial-size hint (32 bytes per field) — confirm against BufferHolder.
+ */
+ private[this] val holder = new BufferHolder(result, numFields * 32)
+
+ /** The writer that writes the intermediate result to the result row. */
+ private[this] val writer: InternalRow => Unit = {
+ val rowWriter = new UnsafeRowWriter(holder, numFields)
+ val baseWriter = generateStructWriter(
+ holder,
+ rowWriter,
+ expressions.map(e => StructField("", e.dataType, e.nullable)))
+ if (!expressions.exists(_.nullable)) {
+ // No nullable fields. The top-level null bit mask stays zeroed out, so it does not
+ // need to be reset before every write.
+ baseWriter
+ } else {
+ // Zero out the null bit mask before we write each row; the writer and its buffer are
+ // reused between rows, so null bits from a previous row must be cleared first.
+ row => {
+ rowWriter.zeroOutNullBytes()
+ baseWriter(row)
+ }
+ }
+ }
+
+ /** Initializes every nondeterministic node of every expression tree for the given partition. */
+ override def initialize(partitionIndex: Int): Unit = {
+ expressions.foreach(_.foreach {
+ case n: Nondeterministic => n.initialize(partitionIndex)
+ case _ =>
+ })
+ }
+
+ /**
+ * Evaluates all expressions against `row` and writes the results into the (reused) result
+ * [[UnsafeRow]]. The returned row is overwritten by the next call.
+ */
+ override def apply(row: InternalRow): UnsafeRow = {
+ // Put the expression results in the intermediate row.
+ var i = 0
+ while (i < numFields) {
+ values(i) = expressions(i).eval(row)
+ i += 1
+ }
+
+ // Write the intermediate row to the unsafe row, starting over in the reused buffer.
+ holder.reset()
+ writer(intermediate)
+ // Record how many bytes of the buffer the freshly written row occupies.
+ result.setTotalSize(holder.totalSize())
+ result
+ }
+}
+
+/**
+ * Helper functions for creating an [[InterpretedUnsafeProjection]].
+ */
+object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
+
+ /**
+ * Returns an [[UnsafeProjection]] for the given sequence of bound expressions.
+ */
+ override protected def createProjection(exprs: Seq[Expression]):
UnsafeProjection = {
+ // We need to make sure that we do not reuse stateful expressions: every [[Stateful]] node
+ // is replaced by a fresh copy so this projection does not share state with other users.
+ val cleanedExpressions = exprs.map(_.transform {
+ case s: Stateful => s.freshCopy()
+ })
+ new InterpretedUnsafeProjection(cleanedExpressions.toArray)
+ }
+
+ /**
+ * Generate a struct writer function. The generated function writes an
[[InternalRow]] to the
+ * given buffer using the given [[UnsafeRowWriter]].
+ */
+ private def generateStructWriter(
+ bufferHolder: BufferHolder,
+ rowWriter: UnsafeRowWriter,
--- End diff --
Yeah, we could do that. I think we need to refactor the writers a little bit
anyway, but I would like to do that in a follow-up.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]