Repository: spark
Updated Branches:
  refs/heads/master 1e6f76059 -> dce1610ae
[SPARK-22514][SQL] move ColumnVector.Array and ColumnarBatch.Row to individual files

## What changes were proposed in this pull request?

Logically, `Array` doesn't belong to `ColumnVector`, and `Row` doesn't belong to `ColumnarBatch`: `ColumnVector` needs to return an `Array` from `getArray` and a `Row` from `getStruct`, and `Array` and `Row` can return each other via their `getArray`/`getStruct` methods. This is also a step toward making `ColumnVector` public; it's cleaner to have `Array` and `Row` as top-level classes.

This PR just moves code around, with two renames to match the new top-level files: `Array` -> `ColumnarArray` and `Row` -> `ColumnarRow`.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <[email protected]>

Closes #19740 from cloud-fan/vector.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dce1610a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dce1610a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dce1610a

Branch: refs/heads/master
Commit: dce1610ae376af00712ba7f4c99bfb4c006dbaec
Parents: 1e6f760
Author: Wenchen Fan <[email protected]>
Authored: Wed Nov 15 14:42:37 2017 +0100
Committer: Wenchen Fan <[email protected]>
Committed: Wed Nov 15 14:42:37 2017 +0100

----------------------------------------------------------------------
 .../execution/vectorized/AggregateHashMap.java  |   2 +-
 .../execution/vectorized/ArrowColumnVector.java |   6 +-
 .../sql/execution/vectorized/ColumnVector.java  | 202 +-----------
 .../execution/vectorized/ColumnVectorUtils.java |   2 +-
 .../sql/execution/vectorized/ColumnarArray.java | 208 ++++++++++++
 .../sql/execution/vectorized/ColumnarBatch.java | 326 +-----------------
 .../sql/execution/vectorized/ColumnarRow.java   | 327 +++++++++++++++++++
 .../vectorized/OffHeapColumnVector.java         |   2 +-
 .../vectorized/OnHeapColumnVector.java          |   2 +-
 .../vectorized/WritableColumnVector.java        |  14 +-
 .../execution/aggregate/HashAggregateExec.scala |  10 +-
 .../aggregate/VectorizedHashMapGenerator.scala  |  12 +-
 .../vectorized/ColumnVectorSuite.scala          |  40 +--
 13 files changed, 597 insertions(+), 556 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
index cb3ad4e..9467435 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
@@ -72,7 +72,7 @@ public class AggregateHashMap {
     this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS);
   }
 
-  public ColumnarBatch.Row findOrInsert(long key) {
+  public ColumnarRow findOrInsert(long key) {
     int idx = find(key);
     if (idx != -1 && buckets[idx] == -1) {
       columnVectors[0].putLong(numRows, key);

http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
index 51ea719..949035b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
@@ -251,7 +251,7 @@ public final class ArrowColumnVector extends ColumnVector {
   }
 
   @Override
-  public void loadBytes(ColumnVector.Array array) {
+  public void loadBytes(ColumnarArray array) {
     throw new UnsupportedOperationException();
   }
 
@@ -330,7 +330,7 @@ public final class ArrowColumnVector extends ColumnVector {
       childColumns = new ArrowColumnVector[1];
       childColumns[0] = new ArrowColumnVector(listVector.getDataVector());
-      resultArray = new ColumnVector.Array(childColumns[0]);
+      resultArray = new ColumnarArray(childColumns[0]);
     } else if (vector instanceof MapVector) {
       MapVector mapVector = (MapVector) vector;
       accessor = new StructAccessor(mapVector);
@@ -339,7 +339,7 @@ public final class ArrowColumnVector extends ColumnVector {
       for (int i = 0; i < childColumns.length; ++i) {
         childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i));
       }
-      resultStruct = new ColumnarBatch.Row(childColumns);
+      resultStruct = new ColumnarRow(childColumns);
     } else {
       throw new UnsupportedOperationException();
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index c4b519f..666fd63 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -16,11 +16,9 @@
  */
 package org.apache.spark.sql.execution.vectorized;
 
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.catalyst.util.MapData;
-import org.apache.spark.sql.types.*;
-import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.unsafe.types.UTF8String;
 
 /**
@@ -42,190 +40,6 @@ import org.apache.spark.unsafe.types.UTF8String;
  * ColumnVectors are intended to be reused.
  */
 public abstract class ColumnVector implements AutoCloseable {
-
-  /**
-   * Holder object to return an array. This object is intended to be reused. Callers should
-   * copy the data out if it needs to be stored.
-   */
-  public static final class Array extends ArrayData {
-    // The data for this array. This array contains elements from
-    // data[offset] to data[offset + length).
-    public final ColumnVector data;
-    public int length;
-    public int offset;
-
-    // Populate if binary data is required for the Array. This is stored here as an optimization
-    // for string data.
-    public byte[] byteArray;
-    public int byteArrayOffset;
-
-    // Reused staging buffer, used for loading from offheap.
-    protected byte[] tmpByteArray = new byte[1];
-
-    protected Array(ColumnVector data) {
-      this.data = data;
-    }
-
-    @Override
-    public int numElements() { return length; }
-
-    @Override
-    public ArrayData copy() {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public boolean[] toBooleanArray() { return data.getBooleans(offset, length); }
-
-    @Override
-    public byte[] toByteArray() { return data.getBytes(offset, length); }
-
-    @Override
-    public short[] toShortArray() { return data.getShorts(offset, length); }
-
-    @Override
-    public int[] toIntArray() { return data.getInts(offset, length); }
-
-    @Override
-    public long[] toLongArray() { return data.getLongs(offset, length); }
-
-    @Override
-    public float[] toFloatArray() { return data.getFloats(offset, length); }
-
-    @Override
-    public double[] toDoubleArray() { return data.getDoubles(offset, length); }
-
-    // TODO: this is extremely expensive.
-    @Override
-    public Object[] array() {
-      DataType dt = data.dataType();
-      Object[] list = new Object[length];
-      try {
-        for (int i = 0; i < length; i++) {
-          if (!data.isNullAt(offset + i)) {
-            list[i] = get(i, dt);
-          }
-        }
-        return list;
-      } catch(Exception e) {
-        throw new RuntimeException("Could not get the array", e);
-      }
-    }
-
-    @Override
-    public boolean isNullAt(int ordinal) { return data.isNullAt(offset + ordinal); }
-
-    @Override
-    public boolean getBoolean(int ordinal) {
-      return data.getBoolean(offset + ordinal);
-    }
-
-    @Override
-    public byte getByte(int ordinal) { return data.getByte(offset + ordinal); }
-
-    @Override
-    public short getShort(int ordinal) {
-      return data.getShort(offset + ordinal);
-    }
-
-    @Override
-    public int getInt(int ordinal) { return data.getInt(offset + ordinal); }
-
-    @Override
-    public long getLong(int ordinal) { return data.getLong(offset + ordinal); }
-
-    @Override
-    public float getFloat(int ordinal) {
-      return data.getFloat(offset + ordinal);
-    }
-
-    @Override
-    public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); }
-
-    @Override
-    public Decimal getDecimal(int ordinal, int precision, int scale) {
-      return data.getDecimal(offset + ordinal, precision, scale);
-    }
-
-    @Override
-    public UTF8String getUTF8String(int ordinal) {
-      return data.getUTF8String(offset + ordinal);
-    }
-
-    @Override
-    public byte[] getBinary(int ordinal) {
-      return data.getBinary(offset + ordinal);
-    }
-
-    @Override
-    public CalendarInterval getInterval(int ordinal) {
-      int month = data.getChildColumn(0).getInt(offset + ordinal);
-      long microseconds = data.getChildColumn(1).getLong(offset + ordinal);
-      return new CalendarInterval(month, microseconds);
-    }
-
-    @Override
-    public InternalRow getStruct(int ordinal, int numFields) {
-      return data.getStruct(offset + ordinal);
-    }
-
-    @Override
-    public ArrayData getArray(int ordinal) {
-      return data.getArray(offset + ordinal);
-    }
-
-    @Override
-    public MapData getMap(int ordinal) {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public Object get(int ordinal, DataType dataType) {
-      if (dataType instanceof BooleanType) {
-        return getBoolean(ordinal);
-      } else if (dataType instanceof ByteType) {
-        return getByte(ordinal);
-      } else if (dataType instanceof ShortType) {
-        return getShort(ordinal);
-      } else if (dataType instanceof IntegerType) {
-        return getInt(ordinal);
-      } else if (dataType instanceof LongType) {
-        return getLong(ordinal);
-      } else if (dataType instanceof FloatType) {
-        return getFloat(ordinal);
-      } else if (dataType instanceof DoubleType) {
-        return getDouble(ordinal);
-      } else if (dataType instanceof StringType) {
-        return getUTF8String(ordinal);
-      } else if (dataType instanceof BinaryType) {
-        return getBinary(ordinal);
-      } else if (dataType instanceof DecimalType) {
-        DecimalType t = (DecimalType) dataType;
-        return getDecimal(ordinal, t.precision(), t.scale());
-      } else if (dataType instanceof DateType) {
-        return getInt(ordinal);
-      } else if (dataType instanceof TimestampType) {
-        return getLong(ordinal);
-      } else if (dataType instanceof ArrayType) {
-        return getArray(ordinal);
-      } else if (dataType instanceof StructType) {
-        return getStruct(ordinal, ((StructType)dataType).fields().length);
-      } else if (dataType instanceof MapType) {
-        return getMap(ordinal);
-      } else if (dataType instanceof CalendarIntervalType) {
-        return getInterval(ordinal);
-      } else {
-        throw new UnsupportedOperationException("Datatype not supported " + dataType);
-      }
-    }
-
-    @Override
-    public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
-
-    @Override
-    public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); }
-  }
-
   /**
    * Returns the data type of this column.
    */
@@ -350,7 +164,7 @@ public abstract class ColumnVector implements AutoCloseable {
   /**
    * Returns a utility object to get structs.
    */
-  public ColumnarBatch.Row getStruct(int rowId) {
+  public ColumnarRow getStruct(int rowId) {
     resultStruct.rowId = rowId;
     return resultStruct;
   }
@@ -359,7 +173,7 @@
    * Returns a utility object to get structs.
    * provided to keep API compatibility with InternalRow for code generation
    */
-  public ColumnarBatch.Row getStruct(int rowId, int size) {
+  public ColumnarRow getStruct(int rowId, int size) {
     resultStruct.rowId = rowId;
     return resultStruct;
   }
@@ -367,7 +181,7 @@
   /**
    * Returns the array at rowid.
   */
-  public final ColumnVector.Array getArray(int rowId) {
+  public final ColumnarArray getArray(int rowId) {
     resultArray.length = getArrayLength(rowId);
     resultArray.offset = getArrayOffset(rowId);
     return resultArray;
@@ -376,7 +190,7 @@
   /**
    * Loads the data into array.byteArray.
   */
-  public abstract void loadBytes(ColumnVector.Array array);
+  public abstract void loadBytes(ColumnarArray array);
 
   /**
    * Returns the value for rowId.
@@ -423,12 +237,12 @@
   /**
    * Reusable Array holder for getArray().
   */
-  protected ColumnVector.Array resultArray;
+  protected ColumnarArray resultArray;
 
   /**
   * Reusable Struct holder for getStruct().
  */
-  protected ColumnarBatch.Row resultStruct;
+  protected ColumnarRow resultStruct;
 
   /**
   * The Dictionary for this column.
http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index adb859e..b4b5f0a 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -98,7 +98,7 @@ public class ColumnVectorUtils {
    * For example, an array of IntegerType will return an int[].
    * Throws exceptions for unhandled schemas.
   */
-  public static Object toPrimitiveJavaArray(ColumnVector.Array array) {
+  public static Object toPrimitiveJavaArray(ColumnarArray array) {
     DataType dt = array.data.dataType();
     if (dt instanceof IntegerType) {
       int[] result = new int[array.length];

http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java
new file mode 100644
index 0000000..5e88ce0
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * Array abstraction in {@link ColumnVector}. The instance of this class is intended
+ * to be reused, callers should copy the data out if it needs to be stored.
+ */
+public final class ColumnarArray extends ArrayData {
+  // The data for this array. This array contains elements from
+  // data[offset] to data[offset + length).
+  public final ColumnVector data;
+  public int length;
+  public int offset;
+
+  // Populate if binary data is required for the Array. This is stored here as an optimization
+  // for string data.
+  public byte[] byteArray;
+  public int byteArrayOffset;
+
+  // Reused staging buffer, used for loading from offheap.
+  protected byte[] tmpByteArray = new byte[1];
+
+  protected ColumnarArray(ColumnVector data) {
+    this.data = data;
+  }
+
+  @Override
+  public int numElements() {
+    return length;
+  }
+
+  @Override
+  public ArrayData copy() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean[] toBooleanArray() { return data.getBooleans(offset, length); }
+
+  @Override
+  public byte[] toByteArray() { return data.getBytes(offset, length); }
+
+  @Override
+  public short[] toShortArray() { return data.getShorts(offset, length); }
+
+  @Override
+  public int[] toIntArray() { return data.getInts(offset, length); }
+
+  @Override
+  public long[] toLongArray() { return data.getLongs(offset, length); }
+
+  @Override
+  public float[] toFloatArray() { return data.getFloats(offset, length); }
+
+  @Override
+  public double[] toDoubleArray() { return data.getDoubles(offset, length); }
+
+  // TODO: this is extremely expensive.
+  @Override
+  public Object[] array() {
+    DataType dt = data.dataType();
+    Object[] list = new Object[length];
+    try {
+      for (int i = 0; i < length; i++) {
+        if (!data.isNullAt(offset + i)) {
+          list[i] = get(i, dt);
+        }
+      }
+      return list;
+    } catch(Exception e) {
+      throw new RuntimeException("Could not get the array", e);
+    }
+  }
+
+  @Override
+  public boolean isNullAt(int ordinal) { return data.isNullAt(offset + ordinal); }
+
+  @Override
+  public boolean getBoolean(int ordinal) {
+    return data.getBoolean(offset + ordinal);
+  }
+
+  @Override
+  public byte getByte(int ordinal) { return data.getByte(offset + ordinal); }
+
+  @Override
+  public short getShort(int ordinal) {
+    return data.getShort(offset + ordinal);
+  }
+
+  @Override
+  public int getInt(int ordinal) { return data.getInt(offset + ordinal); }
+
+  @Override
+  public long getLong(int ordinal) { return data.getLong(offset + ordinal); }
+
+  @Override
+  public float getFloat(int ordinal) {
+    return data.getFloat(offset + ordinal);
+  }
+
+  @Override
+  public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); }
+
+  @Override
+  public Decimal getDecimal(int ordinal, int precision, int scale) {
+    return data.getDecimal(offset + ordinal, precision, scale);
+  }
+
+  @Override
+  public UTF8String getUTF8String(int ordinal) {
+    return data.getUTF8String(offset + ordinal);
+  }
+
+  @Override
+  public byte[] getBinary(int ordinal) {
+    return data.getBinary(offset + ordinal);
+  }
+
+  @Override
+  public CalendarInterval getInterval(int ordinal) {
+    int month = data.getChildColumn(0).getInt(offset + ordinal);
+    long microseconds = data.getChildColumn(1).getLong(offset + ordinal);
+    return new CalendarInterval(month, microseconds);
+  }
+
+  @Override
+  public ColumnarRow getStruct(int ordinal, int numFields) {
+    return data.getStruct(offset + ordinal);
+  }
+
+  @Override
+  public ColumnarArray getArray(int ordinal) {
+    return data.getArray(offset + ordinal);
+  }
+
+  @Override
+  public MapData getMap(int ordinal) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Object get(int ordinal, DataType dataType) {
+    if (dataType instanceof BooleanType) {
+      return getBoolean(ordinal);
+    } else if (dataType instanceof ByteType) {
+      return getByte(ordinal);
+    } else if (dataType instanceof ShortType) {
+      return getShort(ordinal);
+    } else if (dataType instanceof IntegerType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof LongType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof FloatType) {
+      return getFloat(ordinal);
+    } else if (dataType instanceof DoubleType) {
+      return getDouble(ordinal);
+    } else if (dataType instanceof StringType) {
+      return getUTF8String(ordinal);
+    } else if (dataType instanceof BinaryType) {
+      return getBinary(ordinal);
+    } else if (dataType instanceof DecimalType) {
+      DecimalType t = (DecimalType) dataType;
+      return getDecimal(ordinal, t.precision(), t.scale());
+    } else if (dataType instanceof DateType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof TimestampType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof ArrayType) {
+      return getArray(ordinal);
+    } else if (dataType instanceof StructType) {
+      return getStruct(ordinal, ((StructType)dataType).fields().length);
+    } else if (dataType instanceof MapType) {
+      return getMap(ordinal);
+    } else if (dataType instanceof CalendarIntervalType) {
+      return getInterval(ordinal);
+    } else {
+      throw new UnsupportedOperationException("Datatype not supported " + dataType);
+    }
+  }
+
+  @Override
+  public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
+
+  @Override
+  public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index bc546c7..8849a20 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -16,17 +16,9 @@
  */
 package org.apache.spark.sql.execution.vectorized;
 
-import java.math.BigDecimal;
 import java.util.*;
 
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.catalyst.util.ArrayData;
-import org.apache.spark.sql.catalyst.util.MapData;
-import org.apache.spark.sql.types.*;
-import org.apache.spark.unsafe.types.CalendarInterval;
-import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.sql.types.StructType;
 
 /**
  * This class is the in memory representation of rows as they are streamed through operators. It
@@ -48,7 +40,7 @@ public final class ColumnarBatch {
   private final StructType schema;
   private final int capacity;
   private int numRows;
-  private final ColumnVector[] columns;
+  final ColumnVector[] columns;
 
   // True if the row is filtered.
   private final boolean[] filteredRows;
@@ -60,7 +52,7 @@ public final class ColumnarBatch {
   private int numRowsFiltered = 0;
 
   // Staging row returned from getRow.
-  final Row row;
+  final ColumnarRow row;
 
   /**
   * Called to close all the columns in this batch. It is not valid to access the data after
@@ -73,312 +65,12 @@ public final class ColumnarBatch {
   }
 
   /**
-   * Adapter class to interop with existing components that expect internal row. A lot of
-   * performance is lost with this translation.
-   */
-  public static final class Row extends InternalRow {
-    protected int rowId;
-    private final ColumnarBatch parent;
-    private final int fixedLenRowSize;
-    private final ColumnVector[] columns;
-    private final WritableColumnVector[] writableColumns;
-
-    // Ctor used if this is a top level row.
-    private Row(ColumnarBatch parent) {
-      this.parent = parent;
-      this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols());
-      this.columns = parent.columns;
-      this.writableColumns = new WritableColumnVector[this.columns.length];
-      for (int i = 0; i < this.columns.length; i++) {
-        if (this.columns[i] instanceof WritableColumnVector) {
-          this.writableColumns[i] = (WritableColumnVector) this.columns[i];
-        }
-      }
-    }
-
-    // Ctor used if this is a struct.
-    protected Row(ColumnVector[] columns) {
-      this.parent = null;
-      this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length);
-      this.columns = columns;
-      this.writableColumns = new WritableColumnVector[this.columns.length];
-      for (int i = 0; i < this.columns.length; i++) {
-        if (this.columns[i] instanceof WritableColumnVector) {
-          this.writableColumns[i] = (WritableColumnVector) this.columns[i];
-        }
-      }
-    }
-
-    /**
-     * Marks this row as being filtered out. This means a subsequent iteration over the rows
-     * in this batch will not include this row.
-     */
-    public void markFiltered() {
-      parent.markFiltered(rowId);
-    }
-
-    public ColumnVector[] columns() { return columns; }
-
-    @Override
-    public int numFields() { return columns.length; }
-
-    @Override
-    /**
-     * Revisit this. This is expensive. This is currently only used in test paths.
-     */
-    public InternalRow copy() {
-      GenericInternalRow row = new GenericInternalRow(columns.length);
-      for (int i = 0; i < numFields(); i++) {
-        if (isNullAt(i)) {
-          row.setNullAt(i);
-        } else {
-          DataType dt = columns[i].dataType();
-          if (dt instanceof BooleanType) {
-            row.setBoolean(i, getBoolean(i));
-          } else if (dt instanceof ByteType) {
-            row.setByte(i, getByte(i));
-          } else if (dt instanceof ShortType) {
-            row.setShort(i, getShort(i));
-          } else if (dt instanceof IntegerType) {
-            row.setInt(i, getInt(i));
-          } else if (dt instanceof LongType) {
-            row.setLong(i, getLong(i));
-          } else if (dt instanceof FloatType) {
-            row.setFloat(i, getFloat(i));
-          } else if (dt instanceof DoubleType) {
-            row.setDouble(i, getDouble(i));
-          } else if (dt instanceof StringType) {
-            row.update(i, getUTF8String(i).copy());
-          } else if (dt instanceof BinaryType) {
-            row.update(i, getBinary(i));
-          } else if (dt instanceof DecimalType) {
-            DecimalType t = (DecimalType)dt;
-            row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
-          } else if (dt instanceof DateType) {
-            row.setInt(i, getInt(i));
-          } else if (dt instanceof TimestampType) {
-            row.setLong(i, getLong(i));
-          } else {
-            throw new RuntimeException("Not implemented. " + dt);
-          }
-        }
-      }
-      return row;
-    }
-
-    @Override
-    public boolean anyNull() {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); }
-
-    @Override
-    public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
-
-    @Override
-    public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
-
-    @Override
-    public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
-
-    @Override
-    public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
-
-    @Override
-    public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
-
-    @Override
-    public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
-
-    @Override
-    public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
-
-    @Override
-    public Decimal getDecimal(int ordinal, int precision, int scale) {
-      if (columns[ordinal].isNullAt(rowId)) return null;
-      return columns[ordinal].getDecimal(rowId, precision, scale);
-    }
-
-    @Override
-    public UTF8String getUTF8String(int ordinal) {
-      if (columns[ordinal].isNullAt(rowId)) return null;
-      return columns[ordinal].getUTF8String(rowId);
-    }
-
-    @Override
-    public byte[] getBinary(int ordinal) {
-      if (columns[ordinal].isNullAt(rowId)) return null;
-      return columns[ordinal].getBinary(rowId);
-    }
-
-    @Override
-    public CalendarInterval getInterval(int ordinal) {
-      if (columns[ordinal].isNullAt(rowId)) return null;
-      final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
-      final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
-      return new CalendarInterval(months, microseconds);
-    }
-
-    @Override
-    public InternalRow getStruct(int ordinal, int numFields) {
-      if (columns[ordinal].isNullAt(rowId)) return null;
-      return columns[ordinal].getStruct(rowId);
-    }
-
-    @Override
-    public ArrayData getArray(int ordinal) {
-      if (columns[ordinal].isNullAt(rowId)) return null;
-      return columns[ordinal].getArray(rowId);
-    }
-
-    @Override
-    public MapData getMap(int ordinal) {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public Object get(int ordinal, DataType dataType) {
-      if (dataType instanceof BooleanType) {
-        return getBoolean(ordinal);
-      } else if (dataType instanceof ByteType) {
-        return getByte(ordinal);
-      } else if (dataType instanceof ShortType) {
-        return getShort(ordinal);
-      } else if (dataType instanceof IntegerType) {
-        return getInt(ordinal);
-      } else if (dataType instanceof LongType) {
-        return getLong(ordinal);
-      } else if (dataType instanceof FloatType) {
-        return getFloat(ordinal);
-      } else if (dataType instanceof DoubleType) {
-        return getDouble(ordinal);
-      } else if (dataType instanceof StringType) {
-        return getUTF8String(ordinal);
-      } else if (dataType instanceof BinaryType) {
-        return getBinary(ordinal);
-      } else if (dataType instanceof DecimalType) {
-        DecimalType t = (DecimalType) dataType;
-        return getDecimal(ordinal, t.precision(), t.scale());
-      } else if (dataType instanceof DateType) {
-        return getInt(ordinal);
-      } else if (dataType instanceof TimestampType) {
-        return getLong(ordinal);
-      } else if (dataType instanceof ArrayType) {
-        return getArray(ordinal);
-      } else if (dataType instanceof StructType) {
-        return getStruct(ordinal, ((StructType)dataType).fields().length);
-      } else if (dataType instanceof MapType) {
-        return getMap(ordinal);
-      } else {
-        throw new UnsupportedOperationException("Datatype not supported " + dataType);
-      }
-    }
-
-    @Override
-    public void update(int ordinal, Object value) {
-      if (value == null) {
-        setNullAt(ordinal);
-      } else {
-        DataType dt = columns[ordinal].dataType();
-        if (dt instanceof BooleanType) {
-          setBoolean(ordinal, (boolean) value);
-        } else if (dt instanceof IntegerType) {
-          setInt(ordinal, (int) value);
-        } else if (dt instanceof ShortType) {
-          setShort(ordinal, (short) value);
-        } else if (dt instanceof LongType) {
-          setLong(ordinal, (long) value);
-        } else if (dt instanceof FloatType) {
-          setFloat(ordinal, (float) value);
-        } else if (dt instanceof DoubleType) {
-          setDouble(ordinal, (double) value);
-        } else if (dt instanceof DecimalType) {
-          DecimalType t = (DecimalType) dt;
-          setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()),
-            t.precision());
-        } else {
-          throw new UnsupportedOperationException("Datatype not supported " + dt);
-        }
-      }
-    }
-
-    @Override
-    public void setNullAt(int ordinal) {
-      getWritableColumn(ordinal).putNull(rowId);
-    }
-
-    @Override
-    public void setBoolean(int ordinal, boolean value) {
-      WritableColumnVector column = getWritableColumn(ordinal);
-      column.putNotNull(rowId);
-      column.putBoolean(rowId, value);
-    }
-
-    @Override
-    public void setByte(int ordinal, byte value) {
-      WritableColumnVector column = getWritableColumn(ordinal);
-      column.putNotNull(rowId);
-      column.putByte(rowId, value);
-    }
-
-    @Override
-    public void setShort(int ordinal, short value) {
-      WritableColumnVector column = getWritableColumn(ordinal);
-      column.putNotNull(rowId);
-      column.putShort(rowId, value);
-    }
-
-    @Override
-    public void setInt(int ordinal, int value) {
-      WritableColumnVector column = getWritableColumn(ordinal);
-      column.putNotNull(rowId);
-      column.putInt(rowId, value);
-    }
-
-    @Override
-    public void setLong(int ordinal, long value) {
-      WritableColumnVector column = getWritableColumn(ordinal);
-      column.putNotNull(rowId);
-      column.putLong(rowId, value);
-    }
-
-    @Override
-    public void setFloat(int ordinal, float value) {
-      WritableColumnVector column = getWritableColumn(ordinal);
-      column.putNotNull(rowId);
-      column.putFloat(rowId, value);
-    }
-
-    @Override
-    public void setDouble(int ordinal, double value) {
-      WritableColumnVector column = getWritableColumn(ordinal);
-      column.putNotNull(rowId);
-      column.putDouble(rowId, value);
-    }
-
-    @Override
-    public void setDecimal(int ordinal, Decimal value, int precision) {
-      WritableColumnVector column = getWritableColumn(ordinal);
-      column.putNotNull(rowId);
-      column.putDecimal(rowId, value, precision);
-    }
-
-    private WritableColumnVector getWritableColumn(int ordinal) {
-      WritableColumnVector column = writableColumns[ordinal];
-      assert (!column.isConstant);
-      return column;
-    }
-  }
-
-  /**
    * Returns an iterator over the rows in this batch. This skips rows that are filtered out.
    */
-  public Iterator<Row> rowIterator() {
+  public Iterator<ColumnarRow> rowIterator() {
     final int maxRows = ColumnarBatch.this.numRows();
-    final Row row = new Row(this);
-    return new Iterator<Row>() {
+    final ColumnarRow row = new ColumnarRow(this);
+    return new Iterator<ColumnarRow>() {
       int rowId = 0;
 
       @Override
@@ -390,7 +82,7 @@ public final class ColumnarBatch {
       }
 
       @Override
-      public Row next() {
+      public ColumnarRow next() {
         while (rowId < maxRows && ColumnarBatch.this.filteredRows[rowId]) {
           ++rowId;
         }
@@ -491,7 +183,7 @@
   /**
    * Returns the row in this batch at `rowId`. Returned row is reused across calls.
   */
-  public ColumnarBatch.Row getRow(int rowId) {
+  public ColumnarRow getRow(int rowId) {
     assert(rowId >= 0);
     assert(rowId < numRows);
     row.rowId = rowId;
@@ -522,6 +214,6 @@
     this.capacity = capacity;
     this.nullFilteredColumns = new HashSet<>();
     this.filteredRows = new boolean[capacity];
-    this.row = new Row(this);
+    this.row = new ColumnarRow(this);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
new file mode 100644
index 0000000..c75adaf
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import java.math.BigDecimal;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * Row abstraction in {@link ColumnVector}. The instance of this class is intended
+ * to be reused, callers should copy the data out if it needs to be stored.
+ */
+public final class ColumnarRow extends InternalRow {
+  protected int rowId;
+  private final ColumnarBatch parent;
+  private final int fixedLenRowSize;
+  private final ColumnVector[] columns;
+  private final WritableColumnVector[] writableColumns;
+
+  // Ctor used if this is a top level row.
+  ColumnarRow(ColumnarBatch parent) {
+    this.parent = parent;
+    this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols());
+    this.columns = parent.columns;
+    this.writableColumns = new WritableColumnVector[this.columns.length];
+    for (int i = 0; i < this.columns.length; i++) {
+      if (this.columns[i] instanceof WritableColumnVector) {
+        this.writableColumns[i] = (WritableColumnVector) this.columns[i];
+      }
+    }
+  }
+
+  // Ctor used if this is a struct.
+  ColumnarRow(ColumnVector[] columns) {
+    this.parent = null;
+    this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length);
+    this.columns = columns;
+    this.writableColumns = new WritableColumnVector[this.columns.length];
+    for (int i = 0; i < this.columns.length; i++) {
+      if (this.columns[i] instanceof WritableColumnVector) {
+        this.writableColumns[i] = (WritableColumnVector) this.columns[i];
+      }
+    }
+  }
+
+  /**
+   * Marks this row as being filtered out. This means a subsequent iteration over the rows
+   * in this batch will not include this row.
+   */
+  public void markFiltered() {
+    parent.markFiltered(rowId);
+  }
+
+  public ColumnVector[] columns() { return columns; }
+
+  @Override
+  public int numFields() { return columns.length; }
+
+  /**
+   * Revisit this. This is expensive. This is currently only used in test paths.
+   */
+  @Override
+  public InternalRow copy() {
+    GenericInternalRow row = new GenericInternalRow(columns.length);
+    for (int i = 0; i < numFields(); i++) {
+      if (isNullAt(i)) {
+        row.setNullAt(i);
+      } else {
+        DataType dt = columns[i].dataType();
+        if (dt instanceof BooleanType) {
+          row.setBoolean(i, getBoolean(i));
+        } else if (dt instanceof ByteType) {
+          row.setByte(i, getByte(i));
+        } else if (dt instanceof ShortType) {
+          row.setShort(i, getShort(i));
+        } else if (dt instanceof IntegerType) {
+          row.setInt(i, getInt(i));
+        } else if (dt instanceof LongType) {
+          row.setLong(i, getLong(i));
+        } else if (dt instanceof FloatType) {
+          row.setFloat(i, getFloat(i));
+        } else if (dt instanceof DoubleType) {
+          row.setDouble(i, getDouble(i));
+        } else if (dt instanceof StringType) {
+          row.update(i, getUTF8String(i).copy());
+        } else if (dt instanceof BinaryType) {
+          row.update(i, getBinary(i));
+        } else if (dt instanceof DecimalType) {
+          DecimalType t = (DecimalType)dt;
+          row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
+        } else if (dt instanceof DateType) {
+          row.setInt(i, getInt(i));
+        } else if (dt instanceof TimestampType) {
+          row.setLong(i, getLong(i));
+        } else {
+          throw new RuntimeException("Not implemented. " + dt);
+        }
+      }
+    }
+    return row;
+  }
+
+  @Override
+  public boolean anyNull() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); }
+
+  @Override
+  public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
+
+  @Override
+  public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
+
+  @Override
+  public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
+
+  @Override
+  public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
+
+  @Override
+  public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
+
+  @Override
+  public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
+
+  @Override
+  public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
+
+  @Override
+  public Decimal getDecimal(int ordinal, int precision, int scale) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getDecimal(rowId, precision, scale);
+  }
+
+  @Override
+  public UTF8String getUTF8String(int ordinal) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getUTF8String(rowId);
+  }
+
+  @Override
+  public byte[] getBinary(int ordinal) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getBinary(rowId);
+  }
+
+  @Override
+  public CalendarInterval getInterval(int ordinal) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
+    final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+    return new CalendarInterval(months, microseconds);
+  }
+
+  @Override
+  public ColumnarRow getStruct(int ordinal, int numFields) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getStruct(rowId);
+  }
+
+  @Override
+  public ColumnarArray getArray(int ordinal) {
+    if (columns[ordinal].isNullAt(rowId)) return null;
+    return columns[ordinal].getArray(rowId);
+  }
+
+  @Override
+  public MapData getMap(int ordinal) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Object get(int ordinal, DataType dataType) {
+    if (dataType instanceof BooleanType) {
+      return getBoolean(ordinal);
+    } else if (dataType instanceof ByteType) {
+      return getByte(ordinal);
+    } else if (dataType instanceof ShortType) {
+      return getShort(ordinal);
+    } else if (dataType instanceof IntegerType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof LongType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof FloatType) {
+      return getFloat(ordinal);
+    } else if (dataType instanceof DoubleType) {
+      return getDouble(ordinal);
+    } else if (dataType instanceof StringType) {
+      return getUTF8String(ordinal);
+    } else if (dataType instanceof BinaryType) {
+      return getBinary(ordinal);
+    } else if (dataType instanceof DecimalType) {
+      DecimalType t = (DecimalType) dataType;
+      return getDecimal(ordinal, t.precision(), t.scale());
+    } else if (dataType instanceof DateType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof TimestampType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof ArrayType) {
+      return getArray(ordinal);
+    } else if (dataType instanceof StructType) {
+      return getStruct(ordinal, ((StructType)dataType).fields().length);
+    } else if (dataType instanceof MapType) {
+      return getMap(ordinal);
+    } else {
+      throw new UnsupportedOperationException("Datatype not supported " + dataType);
+    }
+  }
+
+  @Override
+  public void update(int ordinal, Object value) {
+    if (value == null) {
+      setNullAt(ordinal);
+    } else {
+      DataType dt = columns[ordinal].dataType();
+      if (dt instanceof BooleanType) {
+        setBoolean(ordinal, (boolean) value);
+      } else if (dt instanceof IntegerType) {
+        setInt(ordinal, (int) value);
+      } else if (dt instanceof ShortType) {
+        setShort(ordinal, (short) value);
+      } else if (dt instanceof LongType) {
+        setLong(ordinal, (long) value);
+      } else if (dt instanceof FloatType) {
+        setFloat(ordinal, (float) value);
+      } else if (dt instanceof DoubleType) {
+        setDouble(ordinal, (double) value);
+      } else if (dt instanceof DecimalType) {
+        DecimalType t = (DecimalType) dt;
+        setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()),
+          t.precision());
+      } else {
+        throw new UnsupportedOperationException("Datatype not supported " + dt);
+      }
+    }
+  }
+
+  @Override
+  public void setNullAt(int ordinal) {
+    getWritableColumn(ordinal).putNull(rowId);
+  }
+
+  @Override
+  public void setBoolean(int ordinal, boolean value) {
+    WritableColumnVector column = getWritableColumn(ordinal);
+    column.putNotNull(rowId);
+    column.putBoolean(rowId, value);
+  }
+
+  @Override
+  public void setByte(int ordinal, byte value) {
+    WritableColumnVector column = getWritableColumn(ordinal);
+    column.putNotNull(rowId);
+    column.putByte(rowId, value);
+  }
+
+  @Override
+  public void setShort(int ordinal, short value) {
+    WritableColumnVector column = getWritableColumn(ordinal);
+    column.putNotNull(rowId);
+    column.putShort(rowId, value);
+  }
+
+  @Override
+  public void setInt(int ordinal, int value) {
+    WritableColumnVector column = getWritableColumn(ordinal);
+    column.putNotNull(rowId);
+    column.putInt(rowId, value);
+  }
+
+  @Override
+  public void setLong(int ordinal, long value) {
+    WritableColumnVector column = getWritableColumn(ordinal);
+    column.putNotNull(rowId);
+    column.putLong(rowId, value);
+  }
+
+  @Override
+  public void setFloat(int ordinal, float value) {
+    WritableColumnVector column = getWritableColumn(ordinal);
+    column.putNotNull(rowId);
+    column.putFloat(rowId, value);
+  }
+
+  @Override
+  public void setDouble(int ordinal, double value) {
+    WritableColumnVector column = getWritableColumn(ordinal);
+    column.putNotNull(rowId);
+    column.putDouble(rowId, value);
+  }
+
+  @Override
+  public void setDecimal(int ordinal, Decimal value, int precision) {
+    WritableColumnVector column = getWritableColumn(ordinal);
+    column.putNotNull(rowId);
+    column.putDecimal(rowId, value, precision);
+  }
+
+  private WritableColumnVector getWritableColumn(int ordinal) {
+    WritableColumnVector column = writableColumns[ordinal];
+    assert (!column.isConstant);
+    return column;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index a7522eb..2bf523b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -523,7 +523,7 @@ public final class OffHeapColumnVector extends WritableColumnVector {
   }
 
   @Override
-  public void loadBytes(ColumnVector.Array array) {
+  public void loadBytes(ColumnarArray array) {
     if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length];
     Platform.copyMemory(
       null, data + array.offset, array.tmpByteArray, Platform.BYTE_ARRAY_OFFSET, array.length);

http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 166a39e..d699d29 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -494,7 +494,7 @@ public final class OnHeapColumnVector extends WritableColumnVector {
   }
 
   @Override
-  public void loadBytes(ColumnVector.Array array) {
+  public void loadBytes(ColumnarArray array) {
     array.byteArray = byteData;
     array.byteArrayOffset = array.offset;
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index d3a14b9..96cfeed 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -283,8 +283,8 @@ public abstract class WritableColumnVector extends ColumnVector {
   /**
    * Returns the value for rowId.
   */
-  private ColumnVector.Array getByteArray(int rowId) {
-    ColumnVector.Array array = getArray(rowId);
+  private ColumnarArray getByteArray(int rowId) {
+    ColumnarArray array = getArray(rowId);
     array.data.loadBytes(array);
     return array;
   }
@@ -324,7 +324,7 @@ public abstract class WritableColumnVector extends ColumnVector {
   @Override
   public UTF8String getUTF8String(int rowId) {
     if (dictionary == null) {
-      ColumnVector.Array a = getByteArray(rowId);
+      ColumnarArray a = getByteArray(rowId);
       return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
     } else {
       byte[] bytes = dictionary.decodeToBinary(dictionaryIds.getDictId(rowId));
@@ -338,7 +338,7 @@ public abstract class WritableColumnVector extends ColumnVector {
   @Override
   public byte[] getBinary(int rowId) {
     if (dictionary == null) {
-      ColumnVector.Array array = getByteArray(rowId);
+      ColumnarArray array = getByteArray(rowId);
       byte[] bytes = new byte[array.length];
       System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
       return bytes;
@@ -685,7 +685,7 @@ public abstract class WritableColumnVector extends ColumnVector {
       }
       this.childColumns = new WritableColumnVector[1];
       this.childColumns[0] = reserveNewColumn(childCapacity, childType);
-      this.resultArray = new ColumnVector.Array(this.childColumns[0]);
+      this.resultArray = new ColumnarArray(this.childColumns[0]);
       this.resultStruct = null;
     } else if (type instanceof StructType) {
       StructType st = (StructType)type;
@@ -694,14 +694,14 @@ public abstract class WritableColumnVector extends ColumnVector {
         this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType());
       }
       this.resultArray = null;
-      this.resultStruct = new ColumnarBatch.Row(this.childColumns);
+      this.resultStruct = new ColumnarRow(this.childColumns);
     } else if (type instanceof CalendarIntervalType) {
       // Two columns. Months as int. Microseconds as Long.
       this.childColumns = new WritableColumnVector[2];
       this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType);
       this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType);
       this.resultArray = null;
-      this.resultStruct = new ColumnarBatch.Row(this.childColumns);
+      this.resultStruct = new ColumnarRow(this.childColumns);
     } else {
       this.childColumns = null;
       this.resultArray = null;

http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 2a208a2..51f7c9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -595,7 +595,7 @@ case class HashAggregateExec(
       ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
         s"$fastHashMapTerm = new $fastHashMapClassName();")
       ctx.addMutableState(
-        "java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row>",
+        "java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarRow>",
         iterTermForFastHashMap, "")
     } else {
       ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
@@ -681,7 +681,7 @@ case class HashAggregateExec(
       """
     }
 
-    // Iterate over the aggregate rows and convert them from ColumnarBatch.Row to UnsafeRow
+    // Iterate over the aggregate rows and convert them from ColumnarRow to UnsafeRow
     def outputFromVectorizedMap: String = {
       val row = ctx.freshName("fastHashMapRow")
       ctx.currentVars = null
@@ -697,8 +697,8 @@ case class HashAggregateExec(
       s"""
        | while ($iterTermForFastHashMap.hasNext()) {
        |   $numOutput.add(1);
-       |   org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
-       |     (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
+       |   org.apache.spark.sql.execution.vectorized.ColumnarRow $row =
+       |     (org.apache.spark.sql.execution.vectorized.ColumnarRow)
        |       $iterTermForFastHashMap.next();
        |   ${generateKeyRow.code}
        |   ${generateBufferRow.code}
@@ -892,7 +892,7 @@ case class HashAggregateExec(
          ${
            if (isVectorizedHashMapEnabled) {
              s"""
-              | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $fastRowBuffer = null;
+              | org.apache.spark.sql.execution.vectorized.ColumnarRow $fastRowBuffer = null;
             """.stripMargin
            } else {
              s"""

http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index 812d405..fd783d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression}
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.types._
 
 /**
@@ -142,14 +142,14 @@ class VectorizedHashMapGenerator(
 
   /**
    * Generates a method that returns a mutable
-   * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the
+   * [[org.apache.spark.sql.execution.vectorized.ColumnarRow]] which keeps track of the
    * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
    * generated method adds the corresponding row in the associated
    * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
    * have 2 long group-by keys, the generated function would be of the form:
    *
    * {{{
-   * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(
+   * public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(
    *     long agg_key, long agg_key1) {
    *   long h = hash(agg_key, agg_key1);
    *   int step = 0;
@@ -189,7 +189,7 @@ class VectorizedHashMapGenerator(
     }
 
     s"""
-       |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${
+       |public org.apache.spark.sql.execution.vectorized.ColumnarRow findOrInsert(${
           groupingKeySignature}) {
        |  long h = hash(${groupingKeys.map(_.name).mkString(", ")});
        |  int step = 0;
@@ -229,7 +229,7 @@ class VectorizedHashMapGenerator(
 
   protected def generateRowIterator(): String = {
     s"""
-       |public java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row>
+       |public java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarRow>
        |    rowIterator() {
        |  return batch.rowIterator();
       |}

http://git-wip-us.apache.org/repos/asf/spark/blob/dce1610a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
index c5c8ae3..3c76ca7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
@@ -57,7 +57,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
       testVector.appendBoolean(i % 2 == 0)
     }
 
-    val array = new ColumnVector.Array(testVector)
+    val array = new ColumnarArray(testVector)
 
     (0 until 10).foreach { i =>
       assert(array.get(i, BooleanType) === (i % 2 == 0))
@@ -69,7 +69,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
       testVector.appendByte(i.toByte)
     }
 
-    val array = new ColumnVector.Array(testVector)
+    val array = new ColumnarArray(testVector)
 
     (0 until 10).foreach { i =>
      assert(array.get(i, ByteType) === i.toByte)
@@ -81,7 +81,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
       testVector.appendShort(i.toShort)
     }
 
-    val array = new ColumnVector.Array(testVector)
+    val array = new ColumnarArray(testVector)
 
     (0 until 10).foreach { i =>
       assert(array.get(i, ShortType) === i.toShort)
@@ -93,7 +93,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
       testVector.appendInt(i)
     }
 
-    val array = new ColumnVector.Array(testVector)
+    val array = new ColumnarArray(testVector)
 
     (0 until 10).foreach { i =>
       assert(array.get(i, IntegerType) === i)
@@ -105,7 +105,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
       testVector.appendLong(i)
     }
 
-    val array = new ColumnVector.Array(testVector)
+    val array = new ColumnarArray(testVector)
 
     (0 until 10).foreach { i =>
       assert(array.get(i, LongType) === i)
@@ -117,7 +117,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
       testVector.appendFloat(i.toFloat)
     }
 
-    val array = new ColumnVector.Array(testVector)
+    val array = new ColumnarArray(testVector)
 
     (0 until 10).foreach { i =>
       assert(array.get(i, FloatType) === i.toFloat)
@@ -129,7 +129,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
       testVector.appendDouble(i.toDouble)
     }
 
-    val array = new ColumnVector.Array(testVector)
+    val array = new ColumnarArray(testVector)
 
     (0 until 10).foreach { i =>
       assert(array.get(i, DoubleType) === i.toDouble)
@@ -142,7 +142,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
       testVector.appendByteArray(utf8, 0, utf8.length)
     }
 
-    val array = new ColumnVector.Array(testVector)
+    val array = new ColumnarArray(testVector)
 
     (0 until 10).foreach { i =>
       assert(array.get(i, StringType) === UTF8String.fromString(s"str$i"))
@@ -155,7 +155,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
       testVector.appendByteArray(utf8, 0, utf8.length)
     }
 
-    val array = new ColumnVector.Array(testVector)
+    val array = new ColumnarArray(testVector)
 
     (0 until 10).foreach { i =>
       val utf8 = s"str$i".getBytes("utf8")
@@ -179,12 +179,12 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
     testVector.putArray(2, 3, 0)
     testVector.putArray(3, 3, 3)
 
-    val array = new ColumnVector.Array(testVector)
+    val array = new ColumnarArray(testVector)
 
-    assert(array.get(0, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(0))
-    assert(array.get(1, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(1, 2))
-    assert(array.get(2, arrayType).asInstanceOf[ArrayData].toIntArray() === Array.empty[Int])
-    assert(array.get(3, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5))
+    assert(array.getArray(0).toIntArray() === Array(0))
+    assert(array.getArray(1).toIntArray() === Array(1, 2))
+    assert(array.getArray(2).toIntArray() === Array.empty[Int])
+    assert(array.getArray(3).toIntArray() === Array(3, 4, 5))
   }
 
   val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType)
@@ -196,12 +196,12 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
     c1.putInt(1, 456)
     c2.putDouble(1, 5.67)
 
-    val array = new ColumnVector.Array(testVector)
+    val array = new ColumnarArray(testVector)
 
-    assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123)
-    assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45)
-    assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456)
-    assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67)
+    assert(array.getStruct(0, structType.length).get(0, IntegerType) === 123)
+    assert(array.getStruct(0, structType.length).get(1, DoubleType) === 3.45)
+    assert(array.getStruct(1, structType.length).get(0, IntegerType) === 456)
+    assert(array.getStruct(1, structType.length).get(1, DoubleType) === 5.67)
  }
 
  test("[SPARK-22092] off-heap column vector reallocation corrupts array data") {
@@ -214,7 +214,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
    testVector.reserve(16)
 
    // Check that none of the values got lost/overwritten.
-    val array = new ColumnVector.Array(testVector)
+    val array = new ColumnarArray(testVector)
     (0 until 8).foreach { i =>
       assert(array.get(i, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(i))
     }
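For call sites migrating off the nested classes, here is a minimal sketch of consuming a batch through the renamed top-level types. It uses only the `rowIterator()`, `isNullAt()`, `getLong()`, and `copy()` API visible in the diff above; the helper class, method name, and the assumption that ordinal 0 holds a long column are illustrative, not part of the commit.

import java.util.Iterator;

import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.execution.vectorized.ColumnarRow;

// Hypothetical caller, not from the commit: sums the long column at ordinal 0.
public class ColumnarRowExample {

  public static long sumFirstColumn(ColumnarBatch batch) {
    long sum = 0L;
    // rowIterator() skips rows that were marked as filtered out.
    Iterator<ColumnarRow> it = batch.rowIterator();
    while (it.hasNext()) {
      // ColumnarRow is a reused holder: the same instance is returned from every
      // next() call, so call copy() before storing a row beyond this iteration.
      ColumnarRow row = it.next();
      if (!row.isNullAt(0)) {
        sum += row.getLong(0);
      }
    }
    return sum;
  }
}

The same caveat applies to `ColumnarArray`: `ColumnVector.getArray(rowId)` returns a reused holder, which is why the test suite above constructs one `ColumnarArray` per vector and copies primitives out with `toIntArray()` and friends.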
