GitHub user JoshRosen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5725#discussion_r29269216
  
    --- Diff: sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java ---
    @@ -0,0 +1,251 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.catalyst.expressions;
    +
    +import java.util.Arrays;
    +import java.util.Iterator;
    +
    +import org.apache.spark.sql.Row;
    +import org.apache.spark.sql.types.StructField;
    +import org.apache.spark.sql.types.StructType;
    +import org.apache.spark.unsafe.PlatformDependent;
    +import org.apache.spark.unsafe.map.BytesToBytesMap;
    +import org.apache.spark.unsafe.memory.MemoryLocation;
    +import org.apache.spark.unsafe.memory.MemoryManager;
    +
    +/**
    + * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
    + *
    + * This map supports a maximum of 2 billion keys.
    + */
    +public final class UnsafeFixedWidthAggregationMap {
    +
    +  /**
    +   * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
    +   * map, we copy this buffer and use it as the value.
    +   */
    +  private final long[] emptyAggregationBuffer;
    +
    +  private final StructType aggregationBufferSchema;
    +
    +  private final StructType groupingKeySchema;
    +
    +  /**
    +   * Encodes grouping keys as UnsafeRows.
    +   */
    +  private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;
    +
    +  /**
    +   * A hashmap which maps from opaque bytearray keys to bytearray values.
    +   */
    +  private final BytesToBytesMap map;
    +
    +  /**
    +   * Re-used pointer to the current aggregation buffer
    +   */
    +  private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
    +
    +  /**
    +   * Scratch space that is used when encoding grouping keys into UnsafeRow format.
    +   *
    +   * By default, this is a 1 KB array (128 longs), but it will grow as necessary in case larger
    +   * keys are encountered.
    +   */
    +  private long[] groupingKeyConversionScratchSpace = new long[1024 / 8];
    +
    +  private final boolean enablePerfMetrics;
    +
    +  /**
    +   * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
    +   *         false otherwise.
    +   */
    +  public static boolean supportsGroupKeySchema(StructType schema) {
    +    for (StructField field: schema.fields()) {
    +      if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
    +        return false;
    +      }
    +    }
    +    return true;
    +  }
    +
    +  /**
    +   * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
    +   *         schema, false otherwise.
    +   */
    +  public static boolean supportsAggregationBufferSchema(StructType schema) {
    +    for (StructField field: schema.fields()) {
    +      if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
    +        return false;
    +      }
    +    }
    +    return true;
    +  }
    +
    +  /**
    +   * Create a new UnsafeFixedWidthAggregationMap.
    +   *
    +   * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
    +   * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
    +   * @param groupingKeySchema the schema of the grouping key, used for row conversion.
    +   * @param memoryManager the memory manager used to allocate our Unsafe memory structures.
    +   * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
    +   * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
    +   */
    +  public UnsafeFixedWidthAggregationMap(
    +      Row emptyAggregationBuffer,
    +      StructType aggregationBufferSchema,
    +      StructType groupingKeySchema,
    +      MemoryManager memoryManager,
    +      int initialCapacity,
    +      boolean enablePerfMetrics) {
    +    this.emptyAggregationBuffer =
    +      convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
    +    this.aggregationBufferSchema = aggregationBufferSchema;
    +    this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
    +    this.groupingKeySchema = groupingKeySchema;
    +    this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
    +    this.enablePerfMetrics = enablePerfMetrics;
    +  }
    +
    +  /**
    +   * Convert a Java object row into an UnsafeRow, allocating it into a new long array.
    +   */
    +  private static long[] convertToUnsafeRow(Row javaRow, StructType schema) {
    +    final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
    +    final long[] unsafeRow = new long[converter.getSizeRequirement(javaRow)];
    +    final long writtenLength =
    +      converter.writeRow(javaRow, unsafeRow, PlatformDependent.LONG_ARRAY_OFFSET);
    +    assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!";
    +    return unsafeRow;
    +  }
    +
    +  /**
    +   * Return the aggregation buffer for the current group. For efficiency, all calls to this method
    +   * return the same object.
    +   */
    +  public UnsafeRow getAggregationBuffer(Row groupingKey) {
    +    // Zero out the buffer that's used to hold the current row. This is necessary in order
    +    // to ensure that rows hash properly, since garbage data from the previous row could
    +    // otherwise end up as padding in this row.
    +    Arrays.fill(groupingKeyConversionScratchSpace, 0);
    --- End diff --
    
    Yes. I think that `Arrays.fill` is reasonably fast because it can be translated into SIMD operations, but it still doesn't hurt to zero fewer bytes. This could actually matter in cases where we grow the size of `groupingKeyConversionScratchSpace` in response to a single large key, since the current code could wind up zeroing much more space than is actually needed. I'll update this.
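    A minimal standalone sketch of the idea (names like `requiredWords` are hypothetical stand-ins for the size that `getSizeRequirement` would compute in the real code): the ranged overload of `Arrays.fill` zeroes only the prefix the current key will occupy, leaving the rest of a grown scratch array untouched.
    
    ```java
    import java.util.Arrays;
    
    public class ZeroOnlyWhatIsNeeded {
        public static void main(String[] args) {
            // Hypothetical scratch buffer that previously grew to hold one large key.
            long[] scratch = new long[1024];
            Arrays.fill(scratch, 0xDEADBEEFL); // simulate leftover data from a previous row
    
            // Suppose the current grouping key only needs the first 16 words
            // (in the real code this would come from the row's size requirement).
            int requiredWords = 16;
    
            // Ranged overload: zeroes scratch[0..requiredWords), not the whole array.
            Arrays.fill(scratch, 0, requiredWords, 0L);
    
            System.out.println(scratch[0] == 0L);                      // true
            System.out.println(scratch[requiredWords] == 0xDEADBEEFL); // true
        }
    }
    ```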

