This is an automated email from the ASF dual-hosted git repository.
cwylie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new f524c68f08 Add mechanism for 'safe' memory reads for complex types (#13361)
f524c68f08 is described below
commit f524c68f08eb58047064e8234d320b51e83499ce
Author: Clint Wylie <[email protected]>
AuthorDate: Wed Nov 23 00:25:22 2022 -0800
Add mechanism for 'safe' memory reads for complex types (#13361)
* we can read where we want to
we can leave your bounds behind
'cause if the memory is not there
we really don't care
and we'll crash this process of mine
---
.../hll/HllSketchMergeComplexMetricSerde.java | 15 +-
.../datasketches/hll/HllSketchObjectStrategy.java | 10 +
.../kll/KllDoublesSketchComplexMetricSerde.java | 2 +-
.../kll/KllDoublesSketchObjectStrategy.java | 13 +
.../kll/KllDoublesSketchOperations.java | 21 +
.../kll/KllFloatsSketchComplexMetricSerde.java | 2 +-
.../kll/KllFloatsSketchObjectStrategy.java | 13 +
.../kll/KllFloatsSketchOperations.java | 21 +
.../quantiles/DoublesSketchComplexMetricSerde.java | 2 +-
.../quantiles/DoublesSketchObjectStrategy.java | 13 +
.../quantiles/DoublesSketchOperations.java | 20 +
.../theta/SketchConstantPostAggregator.java | 2 +-
.../datasketches/theta/SketchHolder.java | 22 +
.../theta/SketchHolderObjectStrategy.java | 14 +
.../theta/SketchMergeComplexMetricSerde.java | 2 +-
...rrayOfDoublesSketchMergeComplexMetricSerde.java | 2 +-
.../tuple/ArrayOfDoublesSketchObjectStrategy.java | 13 +-
.../tuple/ArrayOfDoublesSketchOperations.java | 24 +-
.../hll/HllSketchObjectStrategyTest.java | 77 ++++
.../KllDoublesSketchComplexMetricSerdeTest.java | 44 ++
.../kll/KllDoublesSketchOperationsTest.java | 51 +++
.../kll/KllFloatsSketchComplexMetricSerdeTest.java | 44 ++
.../kll/KllFloatsSketchOperationsTest.java | 51 +++
.../DoublesSketchComplexMetricSerdeTest.java | 43 ++
.../quantiles/DoublesSketchOperationsTest.java | 50 ++
.../theta/SketchHolderObjectStrategyTest.java | 79 ++++
.../datasketches/theta/SketchHolderTest.java | 52 +++
.../ArrayOfDoublesSketchObjectStrategyTest.java | 70 +++
.../tuple/ArrayOfDoublesSketchOperationsTest.java | 55 +++
.../column/ObjectStrategyComplexTypeStrategy.java | 2 +-
.../apache/druid/segment/data/ObjectStrategy.java | 27 ++
.../druid/segment/data/SafeWritableBase.java | 450 ++++++++++++++++++
.../druid/segment/data/SafeWritableBuffer.java | 501 +++++++++++++++++++++
.../druid/segment/data/SafeWritableMemory.java | 417 +++++++++++++++++
.../druid/segment/data/SafeWritableBufferTest.java | 224 +++++++++
.../druid/segment/data/SafeWritableMemoryTest.java | 359 +++++++++++++++
36 files changed, 2796 insertions(+), 11 deletions(-)
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeComplexMetricSerde.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeComplexMetricSerde.java
index c8ac48ab18..1063bbdfec 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeComplexMetricSerde.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchMergeComplexMetricSerde.java
@@ -28,6 +28,7 @@ import org.apache.druid.segment.GenericColumnSerializer;
import org.apache.druid.segment.column.ColumnBuilder;
import org.apache.druid.segment.data.GenericIndexed;
import org.apache.druid.segment.data.ObjectStrategy;
+import org.apache.druid.segment.data.SafeWritableMemory;
import org.apache.druid.segment.serde.ComplexColumnPartSupplier;
import org.apache.druid.segment.serde.ComplexMetricExtractor;
import org.apache.druid.segment.serde.ComplexMetricSerde;
@@ -70,7 +71,7 @@ public class HllSketchMergeComplexMetricSerde extends
ComplexMetricSerde
if (object == null) {
return null;
}
- return deserializeSketch(object);
+ return deserializeSketchSafe(object);
}
};
}
@@ -98,6 +99,18 @@ public class HllSketchMergeComplexMetricSerde extends
ComplexMetricSerde
throw new IAE("Object is not of a type that can be deserialized to an
HllSketch:" + object.getClass().getName());
}
+ static HllSketch deserializeSketchSafe(final Object object)
+ {
+ if (object instanceof String) {
+ return
HllSketch.wrap(SafeWritableMemory.wrap(StringUtils.decodeBase64(((String)
object).getBytes(StandardCharsets.UTF_8))));
+ } else if (object instanceof byte[]) {
+ return HllSketch.wrap(SafeWritableMemory.wrap((byte[]) object));
+ } else if (object instanceof HllSketch) {
+ return (HllSketch) object;
+ }
+ throw new IAE("Object is not of a type that can be deserialized to an
HllSketch:" + object.getClass().getName());
+ }
+
// support large columns
@Override
public GenericColumnSerializer getSerializer(final SegmentWriteOutMedium
segmentWriteOutMedium, final String column)
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchObjectStrategy.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchObjectStrategy.java
index 34145863fd..65257b22b7 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchObjectStrategy.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchObjectStrategy.java
@@ -22,7 +22,9 @@ package org.apache.druid.query.aggregation.datasketches.hll;
import org.apache.datasketches.hll.HllSketch;
import org.apache.datasketches.memory.Memory;
import org.apache.druid.segment.data.ObjectStrategy;
+import org.apache.druid.segment.data.SafeWritableMemory;
+import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -55,4 +57,12 @@ public class HllSketchObjectStrategy implements
ObjectStrategy<HllSketch>
return sketch.toCompactByteArray();
}
+ @Nullable
+ @Override
+ public HllSketch fromByteBufferSafe(ByteBuffer buffer, int numBytes)
+ {
+ return HllSketch.wrap(
+ SafeWritableMemory.wrap(buffer,
ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
+ );
+ }
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerde.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerde.java
index 4c18a97856..e5249853ac 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerde.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerde.java
@@ -91,7 +91,7 @@ public class KllDoublesSketchComplexMetricSerde extends
ComplexMetricSerde
if (object == null || object instanceof KllDoublesSketch || object
instanceof Memory) {
return object;
}
- return KllDoublesSketchOperations.deserialize(object);
+ return KllDoublesSketchOperations.deserializeSafe(object);
}
};
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchObjectStrategy.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchObjectStrategy.java
index 97e670a625..17cb94e2fc 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchObjectStrategy.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchObjectStrategy.java
@@ -23,7 +23,9 @@ import it.unimi.dsi.fastutil.bytes.ByteArrays;
import org.apache.datasketches.kll.KllDoublesSketch;
import org.apache.datasketches.memory.Memory;
import org.apache.druid.segment.data.ObjectStrategy;
+import org.apache.druid.segment.data.SafeWritableMemory;
+import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -60,4 +62,15 @@ public class KllDoublesSketchObjectStrategy implements
ObjectStrategy<KllDoubles
return sketch.toByteArray();
}
+ @Nullable
+ @Override
+ public KllDoublesSketch fromByteBufferSafe(ByteBuffer buffer, int numBytes)
+ {
+ if (numBytes == 0) {
+ return KllDoublesSketchOperations.EMPTY_SKETCH;
+ }
+ return KllDoublesSketch.wrap(
+ SafeWritableMemory.wrap(buffer,
ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
+ );
+ }
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchOperations.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchOperations.java
index 57cb517471..6da454d7f8 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchOperations.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchOperations.java
@@ -23,6 +23,7 @@ import org.apache.datasketches.kll.KllDoublesSketch;
import org.apache.datasketches.memory.Memory;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.segment.data.SafeWritableMemory;
import java.nio.charset.StandardCharsets;
@@ -46,6 +47,16 @@ public class KllDoublesSketchOperations
);
}
+ public static KllDoublesSketch deserializeSafe(final Object serializedSketch)
+ {
+ if (serializedSketch instanceof String) {
+ return deserializeFromBase64EncodedStringSafe((String) serializedSketch);
+ } else if (serializedSketch instanceof byte[]) {
+ return deserializeFromByteArraySafe((byte[]) serializedSketch);
+ }
+ return deserialize(serializedSketch);
+ }
+
public static KllDoublesSketch deserializeFromBase64EncodedString(final
String str)
{
return
deserializeFromByteArray(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
@@ -56,4 +67,14 @@ public class KllDoublesSketchOperations
return KllDoublesSketch.wrap(Memory.wrap(data));
}
+ public static KllDoublesSketch deserializeFromBase64EncodedStringSafe(final
String str)
+ {
+ return
deserializeFromByteArraySafe(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
+ }
+
+ public static KllDoublesSketch deserializeFromByteArraySafe(final byte[]
data)
+ {
+ return KllDoublesSketch.wrap(SafeWritableMemory.wrap(data));
+ }
+
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchComplexMetricSerde.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchComplexMetricSerde.java
index 4a71befe0c..175b307ec3 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchComplexMetricSerde.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchComplexMetricSerde.java
@@ -91,7 +91,7 @@ public class KllFloatsSketchComplexMetricSerde extends
ComplexMetricSerde
if (object == null || object instanceof KllFloatsSketch || object
instanceof Memory) {
return object;
}
- return KllFloatsSketchOperations.deserialize(object);
+ return KllFloatsSketchOperations.deserializeSafe(object);
}
};
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchObjectStrategy.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchObjectStrategy.java
index ff177a2f54..93ff0a7dba 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchObjectStrategy.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchObjectStrategy.java
@@ -23,7 +23,9 @@ import it.unimi.dsi.fastutil.bytes.ByteArrays;
import org.apache.datasketches.kll.KllFloatsSketch;
import org.apache.datasketches.memory.Memory;
import org.apache.druid.segment.data.ObjectStrategy;
+import org.apache.druid.segment.data.SafeWritableMemory;
+import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -60,4 +62,15 @@ public class KllFloatsSketchObjectStrategy implements
ObjectStrategy<KllFloatsSk
return sketch.toByteArray();
}
+ @Nullable
+ @Override
+ public KllFloatsSketch fromByteBufferSafe(ByteBuffer buffer, int numBytes)
+ {
+ if (numBytes == 0) {
+ return KllFloatsSketchOperations.EMPTY_SKETCH;
+ }
+ return KllFloatsSketch.wrap(
+ SafeWritableMemory.wrap(buffer,
ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
+ );
+ }
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchOperations.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchOperations.java
index e32b67b254..02fb615da4 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchOperations.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchOperations.java
@@ -23,6 +23,7 @@ import org.apache.datasketches.kll.KllFloatsSketch;
import org.apache.datasketches.memory.Memory;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.segment.data.SafeWritableMemory;
import java.nio.charset.StandardCharsets;
@@ -46,6 +47,16 @@ public class KllFloatsSketchOperations
);
}
+ public static KllFloatsSketch deserializeSafe(final Object serializedSketch)
+ {
+ if (serializedSketch instanceof String) {
+ return deserializeFromBase64EncodedStringSafe((String) serializedSketch);
+ } else if (serializedSketch instanceof byte[]) {
+ return deserializeFromByteArraySafe((byte[]) serializedSketch);
+ }
+ return deserialize(serializedSketch);
+ }
+
public static KllFloatsSketch deserializeFromBase64EncodedString(final
String str)
{
return
deserializeFromByteArray(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
@@ -56,4 +67,14 @@ public class KllFloatsSketchOperations
return KllFloatsSketch.wrap(Memory.wrap(data));
}
+ public static KllFloatsSketch deserializeFromBase64EncodedStringSafe(final
String str)
+ {
+ return
deserializeFromByteArraySafe(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
+ }
+
+ public static KllFloatsSketch deserializeFromByteArraySafe(final byte[] data)
+ {
+ return KllFloatsSketch.wrap(SafeWritableMemory.wrap(data));
+ }
+
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchComplexMetricSerde.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchComplexMetricSerde.java
index d97b5f8c6d..3614f214c7 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchComplexMetricSerde.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchComplexMetricSerde.java
@@ -92,7 +92,7 @@ public class DoublesSketchComplexMetricSerde extends
ComplexMetricSerde
if (object == null || object instanceof DoublesSketch || object
instanceof Memory) {
return object;
}
- return DoublesSketchOperations.deserialize(object);
+ return DoublesSketchOperations.deserializeSafe(object);
}
};
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchObjectStrategy.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchObjectStrategy.java
index 826de9378f..569b60bf03 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchObjectStrategy.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchObjectStrategy.java
@@ -23,7 +23,9 @@ import it.unimi.dsi.fastutil.bytes.ByteArrays;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.quantiles.DoublesSketch;
import org.apache.druid.segment.data.ObjectStrategy;
+import org.apache.druid.segment.data.SafeWritableMemory;
+import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -60,4 +62,15 @@ public class DoublesSketchObjectStrategy implements
ObjectStrategy<DoublesSketch
return sketch.toByteArray(true);
}
+ @Nullable
+ @Override
+ public DoublesSketch fromByteBufferSafe(ByteBuffer buffer, int numBytes)
+ {
+ if (numBytes == 0) {
+ return DoublesSketchOperations.EMPTY_SKETCH;
+ }
+ return DoublesSketch.wrap(
+ SafeWritableMemory.wrap(buffer,
ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
+ );
+ }
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchOperations.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchOperations.java
index e30fb9bdae..a2ca197c11 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchOperations.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchOperations.java
@@ -23,6 +23,7 @@ import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.quantiles.DoublesSketch;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.segment.data.SafeWritableMemory;
import java.nio.charset.StandardCharsets;
@@ -46,6 +47,16 @@ public class DoublesSketchOperations
);
}
+ public static DoublesSketch deserializeSafe(final Object serializedSketch)
+ {
+ if (serializedSketch instanceof String) {
+ return deserializeFromBase64EncodedStringSafe((String) serializedSketch);
+ } else if (serializedSketch instanceof byte[]) {
+ return deserializeFromByteArraySafe((byte[]) serializedSketch);
+ }
+ return deserialize(serializedSketch);
+ }
+
public static DoublesSketch deserializeFromBase64EncodedString(final String
str)
{
return
deserializeFromByteArray(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
@@ -56,4 +67,13 @@ public class DoublesSketchOperations
return DoublesSketch.wrap(Memory.wrap(data));
}
+ public static DoublesSketch deserializeFromBase64EncodedStringSafe(final
String str)
+ {
+ return
deserializeFromByteArraySafe(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
+ }
+
+ public static DoublesSketch deserializeFromByteArraySafe(final byte[] data)
+ {
+ return DoublesSketch.wrap(SafeWritableMemory.wrap(data));
+ }
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchConstantPostAggregator.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchConstantPostAggregator.java
index b3541bd506..64c182a0d6 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchConstantPostAggregator.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchConstantPostAggregator.java
@@ -51,7 +51,7 @@ public class SketchConstantPostAggregator implements
PostAggregator
Preconditions.checkArgument(value != null && !value.isEmpty(),
"Constant value cannot be null or empty, expecting base64 encoded
sketch string");
this.value = value;
- this.sketchValue = SketchHolder.deserialize(value);
+ this.sketchValue = SketchHolder.deserializeSafe(value);
}
@Override
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolder.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolder.java
index 59ca453bb2..838b4ae91f 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolder.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolder.java
@@ -34,6 +34,7 @@ import org.apache.datasketches.theta.Union;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.segment.data.SafeWritableMemory;
import javax.annotation.Nullable;
@@ -224,6 +225,17 @@ public class SketchHolder
);
}
+ public static SketchHolder deserializeSafe(Object serializedSketch)
+ {
+ if (serializedSketch instanceof String) {
+ return SketchHolder.of(deserializeFromBase64EncodedStringSafe((String)
serializedSketch));
+ } else if (serializedSketch instanceof byte[]) {
+ return SketchHolder.of(deserializeFromByteArraySafe((byte[])
serializedSketch));
+ }
+
+ return deserialize(serializedSketch);
+ }
+
private static Sketch deserializeFromBase64EncodedString(String str)
{
return
deserializeFromByteArray(StringUtils.decodeBase64(StringUtils.toUtf8(str)));
@@ -234,6 +246,16 @@ public class SketchHolder
return deserializeFromMemory(Memory.wrap(data));
}
+ private static Sketch deserializeFromBase64EncodedStringSafe(String str)
+ {
+ return
deserializeFromByteArraySafe(StringUtils.decodeBase64(StringUtils.toUtf8(str)));
+ }
+
+ private static Sketch deserializeFromByteArraySafe(byte[] data)
+ {
+ return deserializeFromMemory(SafeWritableMemory.wrap(data));
+ }
+
private static Sketch deserializeFromMemory(Memory mem)
{
if (Sketch.getSerializationVersion(mem) < 3) {
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderObjectStrategy.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderObjectStrategy.java
index e98bc3d95a..96fafe8262 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderObjectStrategy.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderObjectStrategy.java
@@ -23,6 +23,7 @@ import it.unimi.dsi.fastutil.bytes.ByteArrays;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.theta.Sketch;
import org.apache.druid.segment.data.ObjectStrategy;
+import org.apache.druid.segment.data.SafeWritableMemory;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@@ -66,4 +67,17 @@ public class SketchHolderObjectStrategy implements
ObjectStrategy<SketchHolder>
return ByteArrays.EMPTY_ARRAY;
}
}
+
+ @Nullable
+ @Override
+ public SketchHolder fromByteBufferSafe(ByteBuffer buffer, int numBytes)
+ {
+ if (numBytes == 0) {
+ return SketchHolder.EMPTY;
+ }
+
+ return SketchHolder.of(
+ SafeWritableMemory.wrap(buffer,
ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
+ );
+ }
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchMergeComplexMetricSerde.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchMergeComplexMetricSerde.java
index a824312c0e..4f3ecfae29 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchMergeComplexMetricSerde.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/SketchMergeComplexMetricSerde.java
@@ -59,7 +59,7 @@ public class SketchMergeComplexMetricSerde extends
ComplexMetricSerde
public SketchHolder extractValue(InputRow inputRow, String metricName)
{
final Object object = inputRow.getRaw(metricName);
- return object == null ? null : SketchHolder.deserialize(object);
+ return object == null ? null : SketchHolder.deserializeSafe(object);
}
};
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchMergeComplexMetricSerde.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchMergeComplexMetricSerde.java
index 19c8da292b..028bcdc354 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchMergeComplexMetricSerde.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchMergeComplexMetricSerde.java
@@ -60,7 +60,7 @@ public class ArrayOfDoublesSketchMergeComplexMetricSerde
extends ComplexMetricSe
if (object == null || object instanceof ArrayOfDoublesSketch) {
return object;
}
- return ArrayOfDoublesSketchOperations.deserialize(object);
+ return ArrayOfDoublesSketchOperations.deserializeSafe(object);
}
};
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategy.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategy.java
index 1ae950e068..f893c83b57 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategy.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategy.java
@@ -23,6 +23,7 @@ import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesSketch;
import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesSketches;
import org.apache.druid.segment.data.ObjectStrategy;
+import org.apache.druid.segment.data.SafeWritableMemory;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@@ -48,7 +49,9 @@ public class ArrayOfDoublesSketchObjectStrategy implements
ObjectStrategy<ArrayO
@Override
public ArrayOfDoublesSketch fromByteBuffer(final ByteBuffer buffer, final
int numBytes)
{
- return ArrayOfDoublesSketches.wrapSketch(Memory.wrap(buffer,
ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes));
+ return ArrayOfDoublesSketches.wrapSketch(
+ Memory.wrap(buffer, ByteOrder.LITTLE_ENDIAN).region(buffer.position(),
numBytes)
+ );
}
@Override
@@ -61,4 +64,12 @@ public class ArrayOfDoublesSketchObjectStrategy implements
ObjectStrategy<ArrayO
return sketch.toByteArray();
}
+ @Nullable
+ @Override
+ public ArrayOfDoublesSketch fromByteBufferSafe(ByteBuffer buffer, int
numBytes)
+ {
+ return ArrayOfDoublesSketches.wrapSketch(
+ SafeWritableMemory.wrap(buffer,
ByteOrder.LITTLE_ENDIAN).region(buffer.position(), numBytes)
+ );
+ }
}
diff --git
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchOperations.java
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchOperations.java
index b1658a9957..2768858ffe 100644
---
a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchOperations.java
+++
b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchOperations.java
@@ -30,6 +30,7 @@ import
org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUnion;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.segment.data.SafeWritableMemory;
import java.nio.charset.StandardCharsets;
@@ -115,6 +116,17 @@ public class ArrayOfDoublesSketchOperations
throw new ISE("Object is not of a type that can deserialize to sketch:
%s", serializedSketch.getClass());
}
+ public static ArrayOfDoublesSketch deserializeSafe(final Object
serializedSketch)
+ {
+ if (serializedSketch instanceof String) {
+ return deserializeFromBase64EncodedStringSafe((String) serializedSketch);
+ } else if (serializedSketch instanceof byte[]) {
+ return deserializeFromByteArraySafe((byte[]) serializedSketch);
+ }
+
+ return deserialize(serializedSketch);
+ }
+
public static ArrayOfDoublesSketch deserializeFromBase64EncodedString(final
String str)
{
return
deserializeFromByteArray(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
@@ -122,8 +134,16 @@ public class ArrayOfDoublesSketchOperations
public static ArrayOfDoublesSketch deserializeFromByteArray(final byte[]
data)
{
- final Memory mem = Memory.wrap(data);
- return ArrayOfDoublesSketches.wrapSketch(mem);
+ return ArrayOfDoublesSketches.wrapSketch(Memory.wrap(data));
+ }
+
+ public static ArrayOfDoublesSketch
deserializeFromBase64EncodedStringSafe(final String str)
+ {
+ return
deserializeFromByteArraySafe(StringUtils.decodeBase64(str.getBytes(StandardCharsets.UTF_8)));
}
+ public static ArrayOfDoublesSketch deserializeFromByteArraySafe(final byte[]
data)
+ {
+ return ArrayOfDoublesSketches.wrapSketch(SafeWritableMemory.wrap(data));
+ }
}
diff --git
a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchObjectStrategyTest.java
b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchObjectStrategyTest.java
new file mode 100644
index 0000000000..ff1eb94740
--- /dev/null
+++
b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/HllSketchObjectStrategyTest.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.datasketches.hll;
+
+import org.apache.datasketches.SketchesArgumentException;
+import org.apache.datasketches.hll.HllSketch;
+import org.apache.druid.java.util.common.StringUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+public class HllSketchObjectStrategyTest
+{
+ @Test
+ public void testSafeRead()
+ {
+ HllSketch sketch = new HllSketch();
+ sketch.update(new int[]{1, 2, 3});
+
+ final byte[] bytes = sketch.toCompactByteArray();
+
+ ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
+ HllSketchObjectStrategy objectStrategy = new HllSketchObjectStrategy();
+
+ // valid sketch should not explode when copied, which reads the memory
+ objectStrategy.fromByteBufferSafe(buf, bytes.length).copy();
+
+ // corrupted sketch should fail with a regular java buffer exception
+ for (int subset = 3; subset < bytes.length - 1; subset++) {
+ final byte[] garbage2 = new byte[subset];
+ for (int i = 0; i < garbage2.length; i++) {
+ garbage2[i] = buf.get(i);
+ }
+
+ final ByteBuffer buf2 =
ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ IndexOutOfBoundsException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf2, garbage2.length).copy()
+ );
+ }
+
+ // non sketch that is too short to contain header should fail with regular
java buffer exception
+ final byte[] garbage = new byte[]{0x01, 0x02};
+ final ByteBuffer buf3 =
ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ IndexOutOfBoundsException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf3, garbage.length).copy()
+ );
+
+ // non sketch that is long enough to check (this one doesn't actually need
'safe' read)
+ final byte[] garbageLonger = StringUtils.toUtf8("notasketch");
+ final ByteBuffer buf4 =
ByteBuffer.wrap(garbageLonger).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ SketchesArgumentException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf4,
garbageLonger.length).copy()
+ );
+ }
+}
diff --git
a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerdeTest.java
b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerdeTest.java
index 3628c5e621..0ae46bef49 100644
---
a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerdeTest.java
+++
b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchComplexMetricSerdeTest.java
@@ -23,10 +23,14 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.datasketches.kll.KllDoublesSketch;
import org.apache.druid.data.input.MapBasedInputRow;
+import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.serde.ComplexMetricExtractor;
import org.junit.Assert;
import org.junit.Test;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
public class KllDoublesSketchComplexMetricSerdeTest
{
@Test
@@ -92,4 +96,44 @@ public class KllDoublesSketchComplexMetricSerdeTest
Assert.assertEquals(1, sketch.getNumRetained());
Assert.assertEquals(0.1d, sketch.getMaxValue(), 0.01d);
}
+
+ @Test
+ public void testSafeRead()
+ {
+ final KllDoublesSketchComplexMetricSerde serde = new
KllDoublesSketchComplexMetricSerde();
+ final ObjectStrategy<KllDoublesSketch> objectStrategy =
serde.getObjectStrategy();
+
+ KllDoublesSketch sketch = KllDoublesSketch.newHeapInstance();
+ sketch.update(1.1);
+ sketch.update(1.2);
+ final byte[] bytes = sketch.toByteArray();
+
+ ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
+
+ // a valid sketch must read back cleanly: converting it to a byte array reads the memory and must not throw
+ objectStrategy.fromByteBufferSafe(buf, bytes.length).toByteArray();
+
+ // truncated copies of a valid sketch should fail with a regular java buffer exception; not every truncation
+ // length fails with IndexOutOfBoundsException, but every length from 3 to 23 bytes does
+ for (int subset = 3; subset < 24; subset++) {
+ final byte[] garbage2 = new byte[subset];
+ for (int i = 0; i < garbage2.length; i++) {
+ garbage2[i] = buf.get(i);
+ }
+
+ final ByteBuffer buf2 =
ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ IndexOutOfBoundsException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf2,
garbage2.length).toByteArray()
+ );
+ }
+
+ // input too short to even contain a sketch header should also fail with a regular java buffer exception
+ final byte[] garbage = new byte[]{0x01, 0x02};
+ final ByteBuffer buf3 =
ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ IndexOutOfBoundsException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf3,
garbage.length).toByteArray()
+ );
+ }
}
diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchOperationsTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchOperationsTest.java
new file mode 100644
index 0000000000..d2b0e38398
--- /dev/null
+++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllDoublesSketchOperationsTest.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.datasketches.kll;
+
+import org.apache.datasketches.kll.KllDoublesSketch;
+import org.apache.druid.java.util.common.StringUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+public class KllDoublesSketchOperationsTest
+{
+  @Test
+  public void testDeserializeSafe()
+  {
+    KllDoublesSketch sketch = KllDoublesSketch.newHeapInstance();
+    sketch.update(1.1);
+    sketch.update(1.2);
+    final byte[] bytes = sketch.toByteArray();
+    final String base64 = StringUtils.encodeBase64String(bytes);
+
+    // a sketch object, its raw bytes and its base64 string must all deserialize to the same sketch
+    Assert.assertArrayEquals(bytes, KllDoublesSketchOperations.deserializeSafe(sketch).toByteArray());
+    Assert.assertArrayEquals(bytes, KllDoublesSketchOperations.deserializeSafe(bytes).toByteArray());
+    Assert.assertArrayEquals(bytes, KllDoublesSketchOperations.deserializeSafe(base64).toByteArray());
+
+    // truncated input, as raw bytes or as base64, must fail with a regular java buffer exception
+    final byte[] truncated = Arrays.copyOfRange(bytes, 0, 20);
+    Assert.assertThrows(IndexOutOfBoundsException.class, () -> KllDoublesSketchOperations.deserializeSafe(truncated));
+    Assert.assertThrows(
+        IndexOutOfBoundsException.class,
+        () -> KllDoublesSketchOperations.deserializeSafe(StringUtils.encodeBase64String(truncated))
+    );
+  }
+}
diff --git
a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchComplexMetricSerdeTest.java
b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchComplexMetricSerdeTest.java
index 5ff441df1c..c6b8c31022 100644
---
a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchComplexMetricSerdeTest.java
+++
b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchComplexMetricSerdeTest.java
@@ -23,10 +23,14 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.datasketches.kll.KllFloatsSketch;
import org.apache.druid.data.input.MapBasedInputRow;
+import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.serde.ComplexMetricExtractor;
import org.junit.Assert;
import org.junit.Test;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
public class KllFloatsSketchComplexMetricSerdeTest
{
@Test
@@ -92,4 +96,44 @@ public class KllFloatsSketchComplexMetricSerdeTest
Assert.assertEquals(1, sketch.getNumRetained());
Assert.assertEquals(0.1d, sketch.getMaxValue(), 0.01d);
}
+
+ @Test
+ public void testSafeRead()
+ {
+ final KllFloatsSketchComplexMetricSerde serde = new
KllFloatsSketchComplexMetricSerde();
+ final ObjectStrategy<KllFloatsSketch> objectStrategy =
serde.getObjectStrategy();
+
+ KllFloatsSketch sketch = KllFloatsSketch.newHeapInstance();
+ sketch.update(1.1f);
+ sketch.update(1.2f);
+ final byte[] bytes = sketch.toByteArray();
+
+ ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
+
+ // a valid sketch must read back cleanly: converting it to a byte array reads the memory and must not throw
+ objectStrategy.fromByteBufferSafe(buf, bytes.length).toByteArray();
+
+ // truncated copies of a valid sketch should fail with a regular java buffer exception; not every truncation
+ // length fails with IndexOutOfBoundsException, but every length from 3 to 23 bytes does
+ for (int subset = 3; subset < 24; subset++) {
+ final byte[] garbage2 = new byte[subset];
+ for (int i = 0; i < garbage2.length; i++) {
+ garbage2[i] = buf.get(i);
+ }
+
+ final ByteBuffer buf2 =
ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ IndexOutOfBoundsException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf2,
garbage2.length).toByteArray()
+ );
+ }
+
+ // input too short to even contain a sketch header should also fail with a regular java buffer exception
+ final byte[] garbage = new byte[]{0x01, 0x02};
+ final ByteBuffer buf3 =
ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ IndexOutOfBoundsException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf3,
garbage.length).toByteArray()
+ );
+ }
}
diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchOperationsTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchOperationsTest.java
new file mode 100644
index 0000000000..613b38c660
--- /dev/null
+++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/kll/KllFloatsSketchOperationsTest.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.datasketches.kll;
+
+import org.apache.datasketches.kll.KllFloatsSketch;
+import org.apache.druid.java.util.common.StringUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+public class KllFloatsSketchOperationsTest
+{
+  @Test
+  public void testDeserializeSafe()
+  {
+    KllFloatsSketch sketch = KllFloatsSketch.newHeapInstance();
+    sketch.update(1.1f);
+    sketch.update(1.2f);
+    final byte[] bytes = sketch.toByteArray();
+    final String base64 = StringUtils.encodeBase64String(bytes);
+
+    // a sketch object, its raw bytes and its base64 string must all deserialize to the same sketch
+    Assert.assertArrayEquals(bytes, KllFloatsSketchOperations.deserializeSafe(sketch).toByteArray());
+    Assert.assertArrayEquals(bytes, KllFloatsSketchOperations.deserializeSafe(bytes).toByteArray());
+    Assert.assertArrayEquals(bytes, KllFloatsSketchOperations.deserializeSafe(base64).toByteArray());
+
+    // truncated input, as raw bytes or as base64, must fail with a regular java buffer exception
+    final byte[] truncated = Arrays.copyOfRange(bytes, 0, 20);
+    Assert.assertThrows(IndexOutOfBoundsException.class, () -> KllFloatsSketchOperations.deserializeSafe(truncated));
+    Assert.assertThrows(
+        IndexOutOfBoundsException.class,
+        () -> KllFloatsSketchOperations.deserializeSafe(StringUtils.encodeBase64String(truncated))
+    );
+  }
+}
diff --git
a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchComplexMetricSerdeTest.java
b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchComplexMetricSerdeTest.java
index e198c77042..7dc82baee9 100644
---
a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchComplexMetricSerdeTest.java
+++
b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchComplexMetricSerdeTest.java
@@ -22,11 +22,16 @@ package
org.apache.druid.query.aggregation.datasketches.quantiles;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.datasketches.quantiles.DoublesSketch;
+import org.apache.datasketches.quantiles.DoublesUnion;
import org.apache.druid.data.input.MapBasedInputRow;
+import org.apache.druid.segment.data.ObjectStrategy;
import org.apache.druid.segment.serde.ComplexMetricExtractor;
import org.junit.Assert;
import org.junit.Test;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
public class DoublesSketchComplexMetricSerdeTest
{
@Test
@@ -92,4 +97,42 @@ public class DoublesSketchComplexMetricSerdeTest
Assert.assertEquals(1, sketch.getRetainedItems());
Assert.assertEquals(0.1d, sketch.getMaxValue(), 0.01d);
}
+
+ @Test
+ public void testSafeRead()
+ {
+ final DoublesSketchComplexMetricSerde serde = new
DoublesSketchComplexMetricSerde();
+ DoublesUnion union = DoublesUnion.builder().setMaxK(1024).build();
+ union.update(1.1);
+ final byte[] bytes = union.toByteArray();
+
+ ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
+ ObjectStrategy<DoublesSketch> objectStrategy = serde.getObjectStrategy();
+
+ // a valid sketch must read back cleanly: converting it with toByteArray(true) reads the memory and must not throw
+ objectStrategy.fromByteBufferSafe(buf, bytes.length).toByteArray(true);
+
+ // truncated copies (3 to 14 bytes) of a valid sketch should fail with a regular java buffer exception; the assertion message carries the length
+ for (int subset = 3; subset < 15; subset++) {
+ final byte[] garbage2 = new byte[subset];
+ for (int i = 0; i < garbage2.length; i++) {
+ garbage2[i] = buf.get(i);
+ }
+
+ final ByteBuffer buf2 =
ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ "i " + subset,
+ IndexOutOfBoundsException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf2,
garbage2.length).toByteArray(true)
+ );
+ }
+
+ // input too short to even contain a sketch header should also fail with a regular java buffer exception
+ final byte[] garbage = new byte[]{0x01, 0x02};
+ final ByteBuffer buf3 =
ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ IndexOutOfBoundsException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf3,
garbage.length).toByteArray(true)
+ );
+ }
}
diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchOperationsTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchOperationsTest.java
new file mode 100644
index 0000000000..38e5d39a91
--- /dev/null
+++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/DoublesSketchOperationsTest.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.datasketches.quantiles;
+
+import org.apache.datasketches.quantiles.DoublesUnion;
+import org.apache.druid.java.util.common.StringUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+public class DoublesSketchOperationsTest
+{
+  @Test
+  public void testDeserializeSafe()
+  {
+    DoublesUnion union = DoublesUnion.builder().setMaxK(1024).build();
+    union.update(1.1);
+    final byte[] bytes = union.getResult().toByteArray();
+    final String base64 = StringUtils.encodeBase64String(bytes);
+
+    // a sketch object, its raw bytes and its base64 string must all deserialize to the same sketch
+    Assert.assertArrayEquals(bytes, DoublesSketchOperations.deserializeSafe(union.getResult()).toByteArray());
+    Assert.assertArrayEquals(bytes, DoublesSketchOperations.deserializeSafe(bytes).toByteArray());
+    Assert.assertArrayEquals(bytes, DoublesSketchOperations.deserializeSafe(base64).toByteArray());
+
+    // truncated input must fail with a regular java buffer exception on both the byte[] and the base64 String path;
+    // encodeBase64String (not encodeBase64, which returns byte[]) is required to actually exercise the String path
+    final byte[] truncated = Arrays.copyOfRange(bytes, 0, 4);
+    Assert.assertThrows(IndexOutOfBoundsException.class, () -> DoublesSketchOperations.deserializeSafe(truncated));
+    Assert.assertThrows(
+        IndexOutOfBoundsException.class,
+        () -> DoublesSketchOperations.deserializeSafe(StringUtils.encodeBase64String(truncated))
+    );
+  }
+}
diff --git
a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderObjectStrategyTest.java
b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderObjectStrategyTest.java
new file mode 100644
index 0000000000..5619facd5f
--- /dev/null
+++
b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderObjectStrategyTest.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.datasketches.theta;
+
+import org.apache.datasketches.Family;
+import org.apache.datasketches.SketchesArgumentException;
+import org.apache.datasketches.theta.SetOperation;
+import org.apache.datasketches.theta.Union;
+import org.apache.druid.java.util.common.StringUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+public class SketchHolderObjectStrategyTest
+{
+ @Test
+ public void testSafeRead()
+ {
+ SketchHolderObjectStrategy objectStrategy = new
SketchHolderObjectStrategy();
+ Union union = (Union)
SetOperation.builder().setNominalEntries(1024).build(Family.UNION);
+ union.update(1234L);
+
+ final byte[] bytes = union.getResult().toByteArray();
+
+ ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
+
+ // a valid sketch must read back cleanly: compacting it to bytes reads the memory and must not throw
+ objectStrategy.fromByteBufferSafe(buf,
bytes.length).getSketch().compact().getCompactBytes();
+
+ // every truncated copy of a valid sketch (3 to bytes.length - 2 bytes) should fail with a regular java buffer exception
+ for (int subset = 3; subset < bytes.length - 1; subset++) {
+ final byte[] garbage2 = new byte[subset];
+ for (int i = 0; i < garbage2.length; i++) {
+ garbage2[i] = buf.get(i);
+ }
+
+ final ByteBuffer buf2 =
ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ IndexOutOfBoundsException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf2,
garbage2.length).getSketch().compact().getCompactBytes()
+ );
+ }
+
+ // input too short to even contain a sketch header should also fail with a regular java buffer exception
+ final byte[] garbage = new byte[]{0x01, 0x02};
+ final ByteBuffer buf3 =
ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ IndexOutOfBoundsException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf3,
garbage.length).getSketch().compact().getCompactBytes()
+ );
+
+ // input long enough to hold a header but not a sketch is rejected by header validation (no 'safe' read needed)
+ final byte[] garbageLonger = StringUtils.toUtf8("notasketch");
+ final ByteBuffer buf4 =
ByteBuffer.wrap(garbageLonger).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ SketchesArgumentException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf4,
garbageLonger.length).getSketch().compact().getCompactBytes()
+ );
+ }
+}
diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderTest.java
new file mode 100644
index 0000000000..ef68fdeb8c
--- /dev/null
+++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/SketchHolderTest.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.datasketches.theta;
+
+import org.apache.datasketches.Family;
+import org.apache.datasketches.theta.SetOperation;
+import org.apache.datasketches.theta.Union;
+import org.apache.druid.java.util.common.StringUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+public class SketchHolderTest
+{
+  @Test
+  public void testDeserializeSafe()
+  {
+    Union union = (Union) SetOperation.builder().setNominalEntries(1024).build(Family.UNION);
+    union.update(1234L);
+    final byte[] bytes = union.getResult().toByteArray();
+    final String base64 = StringUtils.encodeBase64String(bytes);
+
+    // a sketch object, its raw bytes and its base64 string must all deserialize to the same sketch
+    Assert.assertArrayEquals(bytes, SketchHolder.deserializeSafe(union.getResult()).getSketch().toByteArray());
+    Assert.assertArrayEquals(bytes, SketchHolder.deserializeSafe(bytes).getSketch().toByteArray());
+    Assert.assertArrayEquals(bytes, SketchHolder.deserializeSafe(base64).getSketch().toByteArray());
+
+    // truncated input, as raw bytes or as base64, must fail with a regular java buffer exception
+    final byte[] truncated = Arrays.copyOfRange(bytes, 0, 10);
+    Assert.assertThrows(IndexOutOfBoundsException.class, () -> SketchHolder.deserializeSafe(truncated));
+    Assert.assertThrows(
+        IndexOutOfBoundsException.class,
+        () -> SketchHolder.deserializeSafe(StringUtils.encodeBase64String(truncated))
+    );
+  }
+}
diff --git
a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategyTest.java
b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategyTest.java
new file mode 100644
index 0000000000..ee59ddf576
--- /dev/null
+++
b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchObjectStrategyTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.datasketches.tuple;
+
+import
org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUpdatableSketch;
+import
org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUpdatableSketchBuilder;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+public class ArrayOfDoublesSketchObjectStrategyTest
+{
+ @Test
+ public void testSafeRead()
+ {
+ ArrayOfDoublesSketchObjectStrategy objectStrategy = new
ArrayOfDoublesSketchObjectStrategy();
+ ArrayOfDoublesUpdatableSketch sketch = new
ArrayOfDoublesUpdatableSketchBuilder().setNominalEntries(1024)
+
.setNumberOfValues(4)
+
.build();
+ sketch.update(1L, new double[]{1.0, 2.0, 3.0, 4.0});
+
+ final byte[] bytes = sketch.compact().toByteArray();
+
+ ByteBuffer buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
+
+ // a valid sketch must read back cleanly: compacting and serializing it reads the memory and must not throw
+ objectStrategy.fromByteBufferSafe(buf,
bytes.length).compact().toByteArray();
+
+ // every truncated copy of a valid sketch (3 to bytes.length - 2 bytes) should fail with a regular java buffer exception
+ for (int subset = 3; subset < bytes.length - 1; subset++) {
+ final byte[] garbage2 = new byte[subset];
+ for (int i = 0; i < garbage2.length; i++) {
+ garbage2[i] = buf.get(i);
+ }
+
+ final ByteBuffer buf2 =
ByteBuffer.wrap(garbage2).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ IndexOutOfBoundsException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf2,
garbage2.length).compact().toByteArray()
+ );
+ }
+
+ // input too short to even contain a sketch header should also fail with a regular java buffer exception
+ final byte[] garbage = new byte[]{0x01, 0x02};
+ final ByteBuffer buf3 =
ByteBuffer.wrap(garbage).order(ByteOrder.LITTLE_ENDIAN);
+ Assert.assertThrows(
+ IndexOutOfBoundsException.class,
+ () -> objectStrategy.fromByteBufferSafe(buf3,
garbage.length).compact().toByteArray()
+ );
+ }
+}
diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchOperationsTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchOperationsTest.java
new file mode 100644
index 0000000000..415f3acab9
--- /dev/null
+++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/tuple/ArrayOfDoublesSketchOperationsTest.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.datasketches.tuple;
+
+import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUpdatableSketch;
+import org.apache.datasketches.tuple.arrayofdoubles.ArrayOfDoublesUpdatableSketchBuilder;
+import org.apache.druid.java.util.common.StringUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+public class ArrayOfDoublesSketchOperationsTest
+{
+  @Test
+  public void testDeserializeSafe()
+  {
+    ArrayOfDoublesUpdatableSketch sketch = new ArrayOfDoublesUpdatableSketchBuilder().setNominalEntries(1024)
+        .setNumberOfValues(4)
+        .build();
+    sketch.update(1L, new double[]{1.0, 2.0, 3.0, 4.0});
+
+    final byte[] bytes = sketch.toByteArray();
+    final String base64 = StringUtils.encodeBase64String(bytes);
+
+    // a sketch object, its raw bytes and its base64 string must all deserialize to the same sketch
+    Assert.assertArrayEquals(bytes, ArrayOfDoublesSketchOperations.deserializeSafe(sketch).toByteArray());
+    Assert.assertArrayEquals(bytes, ArrayOfDoublesSketchOperations.deserializeSafe(bytes).toByteArray());
+    Assert.assertArrayEquals(bytes, ArrayOfDoublesSketchOperations.deserializeSafe(base64).toByteArray());
+
+    // truncated input, as raw bytes or as base64, must fail with a regular java buffer exception
+    final byte[] truncated = Arrays.copyOfRange(bytes, 0, 10);
+    Assert.assertThrows(
+        IndexOutOfBoundsException.class,
+        () -> ArrayOfDoublesSketchOperations.deserializeSafe(truncated)
+    );
+    Assert.assertThrows(
+        IndexOutOfBoundsException.class,
+        () -> ArrayOfDoublesSketchOperations.deserializeSafe(StringUtils.encodeBase64String(truncated))
+    );
+  }
+}
diff --git
a/processing/src/main/java/org/apache/druid/segment/column/ObjectStrategyComplexTypeStrategy.java
b/processing/src/main/java/org/apache/druid/segment/column/ObjectStrategyComplexTypeStrategy.java
index 351f2665d0..d05ba20858 100644
---
a/processing/src/main/java/org/apache/druid/segment/column/ObjectStrategyComplexTypeStrategy.java
+++
b/processing/src/main/java/org/apache/druid/segment/column/ObjectStrategyComplexTypeStrategy.java
@@ -90,6 +90,6 @@ public class ObjectStrategyComplexTypeStrategy<T> implements
TypeStrategy<T>
@Override
public T fromBytes(byte[] value)
{
- return objectStrategy.fromByteBuffer(ByteBuffer.wrap(value), value.length);
+ return objectStrategy.fromByteBufferSafe(ByteBuffer.wrap(value),
value.length);
}
}
diff --git
a/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java
b/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java
index 8a53fc57a7..eba97d04bb 100644
--- a/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java
+++ b/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java
@@ -79,4 +79,31 @@ public interface ObjectStrategy<T> extends Comparator<T>
out.write(bytes);
}
}
+
+ /**
+ * Convert values from their underlying byte representation, when the underlying bytes might be corrupted or
+ * maliciously constructed.
+ *
+ * Implementations of this method <i>absolutely must never</i> perform any sun.misc.Unsafe based memory read or write
+ * operations from instructions contained in the data read from this buffer without first validating the data. If the
+ * data cannot be validated, all read and write operations from instructions in this data must be done directly with
+ * the {@link ByteBuffer} methods, or using {@link SafeWritableMemory} if
+ * {@link org.apache.datasketches.memory.Memory} is employed to materialize the value.
+ *
+ * Implementations of this method <i>may</i> change the given buffer's mark, limit, and position.
+ *
+ * Implementations of this method <i>must not</i> store the given buffer in a field of the "deserialized" object;
+ * store a {@link ByteBuffer#slice()}, {@link ByteBuffer#asReadOnlyBuffer()} or {@link ByteBuffer#duplicate()} instead.
+ *
+ * The default implementation delegates to {@link #fromByteBuffer(ByteBuffer, int)} and adds no validation of its own.
+ *
+ * @param buffer buffer to read value from
+ * @param numBytes number of bytes used to store the value, starting at buffer.position()
+ * @return an object created from the given byte buffer representation
+ */
+ @Nullable
+ default T fromByteBufferSafe(ByteBuffer buffer, int numBytes)
+ {
+ return fromByteBuffer(buffer, numBytes);
+ }
}
diff --git
a/processing/src/main/java/org/apache/druid/segment/data/SafeWritableBase.java
b/processing/src/main/java/org/apache/druid/segment/data/SafeWritableBase.java
new file mode 100644
index 0000000000..df2fc14d05
--- /dev/null
+++
b/processing/src/main/java/org/apache/druid/segment/data/SafeWritableBase.java
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.segment.data;
+
+import com.google.common.base.Preconditions;
+import com.google.common.primitives.Ints;
+import org.apache.datasketches.memory.BaseState;
+import org.apache.datasketches.memory.MemoryRequestServer;
+import org.apache.datasketches.memory.WritableMemory;
+import org.apache.datasketches.memory.internal.BaseStateImpl;
+import org.apache.datasketches.memory.internal.UnsafeUtil;
+import org.apache.datasketches.memory.internal.XxHash64;
+import org.apache.druid.java.util.common.StringUtils;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+/**
+ * Base class for making a regular {@link ByteBuffer} look like a {@link
org.apache.datasketches.memory.Memory} or
+ * {@link org.apache.datasketches.memory.Buffer}. All methods delegate
directly to the {@link ByteBuffer} rather
+ * than using 'unsafe' reads.
+ *
+ * @see SafeWritableMemory
+ * @see SafeWritableBuffer
+ */
+
+@SuppressWarnings("unused")
+public abstract class SafeWritableBase implements BaseState
+{
+ static final MemoryRequestServer SAFE_HEAP_REQUEST_SERVER = new
HeapByteBufferMemoryRequestServer();
+
+ final ByteBuffer buffer;
+
+ public SafeWritableBase(ByteBuffer buffer)
+ {
+ this.buffer = buffer;
+ }
+
+ public MemoryRequestServer getMemoryRequestServer()
+ {
+ return SAFE_HEAP_REQUEST_SERVER;
+ }
+
+ public boolean getBoolean(long offsetBytes)
+ {
+ return getByte(Ints.checkedCast(offsetBytes)) != 0;
+ }
+
+ public byte getByte(long offsetBytes)
+ {
+ return buffer.get(Ints.checkedCast(offsetBytes));
+ }
+
+ public char getChar(long offsetBytes)
+ {
+ return buffer.getChar(Ints.checkedCast(offsetBytes));
+ }
+
+ public double getDouble(long offsetBytes)
+ {
+ return buffer.getDouble(Ints.checkedCast(offsetBytes));
+ }
+
+ public float getFloat(long offsetBytes)
+ {
+ return buffer.getFloat(Ints.checkedCast(offsetBytes));
+ }
+
+ public int getInt(long offsetBytes)
+ {
+ return buffer.getInt(Ints.checkedCast(offsetBytes));
+ }
+
+ public long getLong(long offsetBytes)
+ {
+ return buffer.getLong(Ints.checkedCast(offsetBytes));
+ }
+
+ public short getShort(long offsetBytes)
+ {
+ return buffer.getShort(Ints.checkedCast(offsetBytes));
+ }
+
+ public void putBoolean(long offsetBytes, boolean value)
+ {
+ buffer.put(Ints.checkedCast(offsetBytes), (byte) (value ? 1 : 0));
+ }
+
+ public void putByte(long offsetBytes, byte value)
+ {
+ buffer.put(Ints.checkedCast(offsetBytes), value);
+ }
+
+ public void putChar(long offsetBytes, char value)
+ {
+ buffer.putChar(Ints.checkedCast(offsetBytes), value);
+ }
+
+ public void putDouble(long offsetBytes, double value)
+ {
+ buffer.putDouble(Ints.checkedCast(offsetBytes), value);
+ }
+
+ public void putFloat(long offsetBytes, float value)
+ {
+ buffer.putFloat(Ints.checkedCast(offsetBytes), value);
+ }
+
+ public void putInt(long offsetBytes, int value)
+ {
+ buffer.putInt(Ints.checkedCast(offsetBytes), value);
+ }
+
+ public void putLong(long offsetBytes, long value)
+ {
+ buffer.putLong(Ints.checkedCast(offsetBytes), value);
+ }
+
+ public void putShort(long offsetBytes, short value)
+ {
+ buffer.putShort(Ints.checkedCast(offsetBytes), value);
+ }
+
+ @Override
+ public ByteOrder getTypeByteOrder()
+ {
+ return buffer.order();
+ }
+
+ @Override
+ public boolean isByteOrderCompatible(ByteOrder byteOrder)
+ {
+ return buffer.order().equals(byteOrder);
+ }
+
+ @Override
+ public ByteBuffer getByteBuffer()
+ {
+ return buffer;
+ }
+
+ @Override
+ public long getCapacity()
+ {
+ return buffer.capacity();
+ }
+
+ @Override
+ public long getCumulativeOffset()
+ {
+ return 0;
+ }
+
+ @Override
+ public long getCumulativeOffset(long offsetBytes)
+ {
+ return offsetBytes;
+ }
+
+ @Override
+ public long getRegionOffset()
+ {
+ return 0;
+ }
+
+ @Override
+ public long getRegionOffset(long offsetBytes)
+ {
+ return offsetBytes;
+ }
+
+ @Override
+ public boolean hasArray()
+ {
+ return false;
+ }
+
+ @Override
+ public long xxHash64(long offsetBytes, long lengthBytes, long seed)
+ {
+ return hash(buffer, offsetBytes, lengthBytes, seed);
+ }
+
+ @Override
+ public long xxHash64(long in, long seed)
+ {
+ return XxHash64.hash(in, seed);
+ }
+
+ @Override
+ public boolean hasByteBuffer()
+ {
+ return true;
+ }
+
+ @Override
+ public boolean isDirect()
+ {
+ return false;
+ }
+
+ @Override
+ public boolean isReadOnly()
+ {
+ return false;
+ }
+
+ @Override
+ public boolean isSameResource(Object that)
+ {
+ return this.equals(that);
+ }
+
+ @Override
+ public boolean isValid()
+ {
+ return true;
+ }
+
+ @Override
+ public void checkValidAndBounds(long offsetBytes, long lengthBytes)
+ {
+ Preconditions.checkArgument(
+ Ints.checkedCast(offsetBytes) < buffer.limit(),
+ "start offset %s is greater than buffer limit %s",
+ offsetBytes,
+ buffer.limit()
+ );
+ Preconditions.checkArgument(
+ Ints.checkedCast(offsetBytes + lengthBytes) < buffer.limit(),
+ "end offset %s is greater than buffer limit %s",
+ offsetBytes + lengthBytes,
+ buffer.limit()
+ );
+ }
+
+ /**
+ * Adapted from {@link BaseStateImpl#toHexString(String, long, int)}
+ */
+ @Override
+ public String toHexString(String header, long offsetBytes, int lengthBytes)
+ {
+ final String klass = this.getClass().getSimpleName();
+ final String s1 = StringUtils.format("(..., %d, %d)", offsetBytes,
lengthBytes);
+ final long hcode = hashCode() & 0XFFFFFFFFL;
+ final String call = ".toHexString" + s1 + ", hashCode: " + hcode;
+ String sb = "### " + klass + " SUMMARY ###" + UnsafeUtil.LS
+ + "Header Comment : " + header + UnsafeUtil.LS
+ + "Call Parameters : " + call;
+ return toHex(this, sb, offsetBytes, lengthBytes);
+ }
+
+ /**
+ * Adapted from {@link BaseStateImpl#toHex(BaseStateImpl, String, long, int)}
+ */
+ static String toHex(
+ final SafeWritableBase state,
+ final String preamble,
+ final long offsetBytes,
+ final int lengthBytes
+ )
+ {
+ final String lineSeparator = UnsafeUtil.LS;
+ final long capacity = state.getCapacity();
+ UnsafeUtil.checkBounds(offsetBytes, lengthBytes, capacity);
+ final StringBuilder sb = new StringBuilder();
+ final String uObjStr;
+ final long uObjHeader;
+ uObjStr = "null";
+ uObjHeader = 0;
+ final ByteBuffer bb = state.getByteBuffer();
+ final String bbStr = bb == null ? "null"
+ : bb.getClass().getSimpleName() + ", " +
(bb.hashCode() & 0XFFFFFFFFL);
+ final MemoryRequestServer memReqSvr = state.getMemoryRequestServer();
+ final String memReqStr = memReqSvr != null
+ ? memReqSvr.getClass().getSimpleName() + ", " +
(memReqSvr.hashCode() & 0XFFFFFFFFL)
+ : "null";
+ final long cumBaseOffset = state.getCumulativeOffset();
+ sb.append(preamble).append(lineSeparator);
+ sb.append("UnsafeObj, hashCode : ").append(uObjStr).append(lineSeparator);
+ sb.append("UnsafeObjHeader :
").append(uObjHeader).append(lineSeparator);
+ sb.append("ByteBuf, hashCode : ").append(bbStr).append(lineSeparator);
+ sb.append("RegionOffset :
").append(state.getRegionOffset()).append(lineSeparator);
+ sb.append("Capacity : ").append(capacity).append(lineSeparator);
+ sb.append("CumBaseOffset :
").append(cumBaseOffset).append(lineSeparator);
+ sb.append("MemReq, hashCode :
").append(memReqStr).append(lineSeparator);
+ sb.append("Valid :
").append(state.isValid()).append(lineSeparator);
+ sb.append("Read Only :
").append(state.isReadOnly()).append(lineSeparator);
+ sb.append("Type Byte Order :
").append(state.getTypeByteOrder()).append(lineSeparator);
+ sb.append("Native Byte Order :
").append(ByteOrder.nativeOrder()).append(lineSeparator);
+ sb.append("JDK Runtime Version :
").append(UnsafeUtil.JDK).append(lineSeparator);
+ //Data detail
+ sb.append("Data, littleEndian : 0 1 2 3 4 5 6 7");
+
+ for (long i = 0; i < lengthBytes; i++) {
+ final int b = state.getByte(cumBaseOffset + offsetBytes + i) & 0XFF;
+ if (i % 8 == 0) { //row header
+ sb.append(StringUtils.format("%n%20s: ", offsetBytes + i));
+ }
+ sb.append(StringUtils.format("%02x ", b));
+ }
+ sb.append(lineSeparator);
+
+ return sb.toString();
+ }
+
+ // copied from datasketches-memory XxHash64.java
+ private static final long P1 = -7046029288634856825L;
+ private static final long P2 = -4417276706812531889L;
+ private static final long P3 = 1609587929392839161L;
+ private static final long P4 = -8796714831421723037L;
+ private static final long P5 = 2870177450012600261L;
+
+ /**
+ * Adapted from {@link XxHash64#hash(Object, long, long, long)} to work with
{@link ByteBuffer}
+ */
+ static long hash(ByteBuffer memory, long cumOffsetBytes, final long
lengthBytes, final long seed)
+ {
+ long hash;
+ long remaining = lengthBytes;
+ int offset = Ints.checkedCast(cumOffsetBytes);
+
+ if (remaining >= 32) {
+ long v1 = seed + P1 + P2;
+ long v2 = seed + P2;
+ long v3 = seed;
+ long v4 = seed - P1;
+
+ do {
+ v1 += memory.getLong(offset) * P2;
+ v1 = Long.rotateLeft(v1, 31);
+ v1 *= P1;
+
+ v2 += memory.getLong(offset + 8) * P2;
+ v2 = Long.rotateLeft(v2, 31);
+ v2 *= P1;
+
+ v3 += memory.getLong(offset + 16) * P2;
+ v3 = Long.rotateLeft(v3, 31);
+ v3 *= P1;
+
+ v4 += memory.getLong(offset + 24) * P2;
+ v4 = Long.rotateLeft(v4, 31);
+ v4 *= P1;
+
+ offset += 32;
+ remaining -= 32;
+ } while (remaining >= 32);
+
+ hash = Long.rotateLeft(v1, 1)
+ + Long.rotateLeft(v2, 7)
+ + Long.rotateLeft(v3, 12)
+ + Long.rotateLeft(v4, 18);
+
+ v1 *= P2;
+ v1 = Long.rotateLeft(v1, 31);
+ v1 *= P1;
+ hash ^= v1;
+ hash = (hash * P1) + P4;
+
+ v2 *= P2;
+ v2 = Long.rotateLeft(v2, 31);
+ v2 *= P1;
+ hash ^= v2;
+ hash = (hash * P1) + P4;
+
+ v3 *= P2;
+ v3 = Long.rotateLeft(v3, 31);
+ v3 *= P1;
+ hash ^= v3;
+ hash = (hash * P1) + P4;
+
+ v4 *= P2;
+ v4 = Long.rotateLeft(v4, 31);
+ v4 *= P1;
+ hash ^= v4;
+ hash = (hash * P1) + P4;
+ } else { //end remaining >= 32
+ hash = seed + P5;
+ }
+
+ hash += lengthBytes;
+
+ while (remaining >= 8) {
+ long k1 = memory.getLong(offset);
+ k1 *= P2;
+ k1 = Long.rotateLeft(k1, 31);
+ k1 *= P1;
+ hash ^= k1;
+ hash = (Long.rotateLeft(hash, 27) * P1) + P4;
+ offset += 8;
+ remaining -= 8;
+ }
+
+ if (remaining >= 4) { //treat as unsigned ints
+ hash ^= (memory.getInt(offset) & 0XFFFF_FFFFL) * P1;
+ hash = (Long.rotateLeft(hash, 23) * P2) + P3;
+ offset += 4;
+ remaining -= 4;
+ }
+
+ while (remaining != 0) { //treat as unsigned bytes
+ hash ^= (memory.get(offset) & 0XFFL) * P5;
+ hash = Long.rotateLeft(hash, 11) * P1;
+ --remaining;
+ ++offset;
+ }
+
+ hash ^= hash >>> 33;
+ hash *= P2;
+ hash ^= hash >>> 29;
+ hash *= P3;
+ hash ^= hash >>> 32;
+ return hash;
+ }
+
+ private static class HeapByteBufferMemoryRequestServer implements
MemoryRequestServer
+ {
+ @Override
+ public WritableMemory request(WritableMemory currentWritableMemory, long
capacityBytes)
+ {
+ ByteBuffer newBuffer =
ByteBuffer.allocate(Ints.checkedCast(capacityBytes));
+ newBuffer.order(currentWritableMemory.getTypeByteOrder());
+ return new SafeWritableMemory(newBuffer);
+ }
+
+ @Override
+ public void requestClose(WritableMemory memToClose, WritableMemory
newMemory)
+ {
+ // do nothing
+ }
+ }
+}
diff --git
a/processing/src/main/java/org/apache/druid/segment/data/SafeWritableBuffer.java
b/processing/src/main/java/org/apache/druid/segment/data/SafeWritableBuffer.java
new file mode 100644
index 0000000000..3da7e70b45
--- /dev/null
+++
b/processing/src/main/java/org/apache/druid/segment/data/SafeWritableBuffer.java
@@ -0,0 +1,501 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.segment.data;
+
+import com.google.common.primitives.Ints;
+import org.apache.datasketches.memory.BaseBuffer;
+import org.apache.datasketches.memory.Buffer;
+import org.apache.datasketches.memory.Memory;
+import org.apache.datasketches.memory.WritableBuffer;
+import org.apache.datasketches.memory.WritableMemory;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+/**
+ * Safety first! Don't trust something whose contents you locations to read
and write stuff to, but need a
+ * {@link Buffer} or {@link WritableBuffer}? use this!
+ * <p>
+ * Delegates everything to an underlying {@link ByteBuffer} so all read and
write operations will have bounds checks
+ * built in rather than using 'unsafe'.
+ */
+public class SafeWritableBuffer extends SafeWritableBase implements
WritableBuffer
+{
+ private int start;
+ private int end;
+
+ public SafeWritableBuffer(ByteBuffer buffer)
+ {
+ super(buffer);
+ this.start = 0;
+ this.buffer.position(0);
+ this.end = buffer.capacity();
+ }
+
+ @Override
+ public WritableBuffer writableDuplicate()
+ {
+ return writableDuplicate(buffer.order());
+ }
+
+ @Override
+ public WritableBuffer writableDuplicate(ByteOrder byteOrder)
+ {
+ ByteBuffer dupe = buffer.duplicate();
+ dupe.order(byteOrder);
+ WritableBuffer duplicate = new SafeWritableBuffer(dupe);
+ duplicate.setStartPositionEnd(start, buffer.position(), end);
+ return duplicate;
+ }
+
+ @Override
+ public WritableBuffer writableRegion()
+ {
+ ByteBuffer dupe = buffer.duplicate().order(buffer.order());
+ dupe.position(start);
+ dupe.limit(end);
+ ByteBuffer remaining = buffer.slice();
+ remaining.order(dupe.order());
+ return new SafeWritableBuffer(remaining);
+ }
+
+ @Override
+ public WritableBuffer writableRegion(long offsetBytes, long capacityBytes,
ByteOrder byteOrder)
+ {
+ ByteBuffer dupe = buffer.duplicate();
+ dupe.position(Ints.checkedCast(offsetBytes));
+ dupe.limit(dupe.position() + Ints.checkedCast(capacityBytes));
+ return new SafeWritableBuffer(dupe.slice().order(byteOrder));
+ }
+
+ @Override
+ public WritableMemory asWritableMemory(ByteOrder byteOrder)
+ {
+ ByteBuffer dupe = buffer.duplicate();
+ dupe.order(byteOrder);
+ return new SafeWritableMemory(dupe);
+ }
+
+ @Override
+ public void putBoolean(boolean value)
+ {
+ buffer.put((byte) (value ? 1 : 0));
+ }
+
+ @Override
+ public void putBooleanArray(boolean[] srcArray, int srcOffsetBooleans, int
lengthBooleans)
+ {
+ for (int i = 0; i < lengthBooleans; i++) {
+ putBoolean(srcArray[srcOffsetBooleans + i]);
+ }
+ }
+
+ @Override
+ public void putByte(byte value)
+ {
+ buffer.put(value);
+ }
+
+ @Override
+ public void putByteArray(byte[] srcArray, int srcOffsetBytes, int
lengthBytes)
+ {
+ buffer.put(srcArray, srcOffsetBytes, lengthBytes);
+ }
+
+ @Override
+ public void putChar(char value)
+ {
+ buffer.putChar(value);
+ }
+
+ @Override
+ public void putCharArray(char[] srcArray, int srcOffsetChars, int
lengthChars)
+ {
+ for (int i = 0; i < lengthChars; i++) {
+ buffer.putChar(srcArray[srcOffsetChars + i]);
+ }
+ }
+
+ @Override
+ public void putDouble(double value)
+ {
+ buffer.putDouble(value);
+ }
+
+ @Override
+ public void putDoubleArray(double[] srcArray, int srcOffsetDoubles, int
lengthDoubles)
+ {
+ for (int i = 0; i < lengthDoubles; i++) {
+ buffer.putDouble(srcArray[srcOffsetDoubles + i]);
+ }
+ }
+
+ @Override
+ public void putFloat(float value)
+ {
+ buffer.putFloat(value);
+ }
+
+ @Override
+ public void putFloatArray(float[] srcArray, int srcOffsetFloats, int
lengthFloats)
+ {
+ for (int i = 0; i < lengthFloats; i++) {
+ buffer.putFloat(srcArray[srcOffsetFloats + i]);
+ }
+ }
+
+ @Override
+ public void putInt(int value)
+ {
+ buffer.putInt(value);
+ }
+
+ @Override
+ public void putIntArray(int[] srcArray, int srcOffsetInts, int lengthInts)
+ {
+ for (int i = 0; i < lengthInts; i++) {
+ buffer.putInt(srcArray[srcOffsetInts + i]);
+ }
+ }
+
+ @Override
+ public void putLong(long value)
+ {
+ buffer.putLong(value);
+ }
+
+ @Override
+ public void putLongArray(long[] srcArray, int srcOffsetLongs, int
lengthLongs)
+ {
+ for (int i = 0; i < lengthLongs; i++) {
+ buffer.putLong(srcArray[srcOffsetLongs + i]);
+ }
+ }
+
+ @Override
+ public void putShort(short value)
+ {
+ buffer.putShort(value);
+ }
+
+ @Override
+ public void putShortArray(short[] srcArray, int srcOffsetShorts, int
lengthShorts)
+ {
+ for (int i = 0; i < lengthShorts; i++) {
+ buffer.putShort(srcArray[srcOffsetShorts + i]);
+ }
+ }
+
+ @Override
+ public Object getArray()
+ {
+ return null;
+ }
+
+ @Override
+ public void clear()
+ {
+ fill((byte) 0);
+ }
+
+ @Override
+ public void fill(byte value)
+ {
+ while (buffer.hasRemaining() && buffer.position() < end) {
+ buffer.put(value);
+ }
+ }
+
+ @Override
+ public Buffer duplicate()
+ {
+ return writableDuplicate();
+ }
+
+ @Override
+ public Buffer duplicate(ByteOrder byteOrder)
+ {
+ return writableDuplicate(byteOrder);
+ }
+
+ @Override
+ public Buffer region()
+ {
+ return writableRegion();
+ }
+
+ @Override
+ public Buffer region(long offsetBytes, long capacityBytes, ByteOrder
byteOrder)
+ {
+ return writableRegion(offsetBytes, capacityBytes, byteOrder);
+ }
+
+ @Override
+ public Memory asMemory(ByteOrder byteOrder)
+ {
+ return asWritableMemory(byteOrder);
+ }
+
+ @Override
+ public boolean getBoolean()
+ {
+ return buffer.get() == 0 ? false : true;
+ }
+
+ @Override
+ public void getBooleanArray(boolean[] dstArray, int dstOffsetBooleans, int
lengthBooleans)
+ {
+ for (int i = 0; i < lengthBooleans; i++) {
+ dstArray[dstOffsetBooleans + i] = getBoolean();
+ }
+ }
+
+ @Override
+ public byte getByte()
+ {
+ return buffer.get();
+ }
+
+ @Override
+ public void getByteArray(byte[] dstArray, int dstOffsetBytes, int
lengthBytes)
+ {
+ for (int i = 0; i < lengthBytes; i++) {
+ dstArray[dstOffsetBytes + i] = buffer.get();
+ }
+ }
+
+ @Override
+ public char getChar()
+ {
+ return buffer.getChar();
+ }
+
+ @Override
+ public void getCharArray(char[] dstArray, int dstOffsetChars, int
lengthChars)
+ {
+ for (int i = 0; i < lengthChars; i++) {
+ dstArray[dstOffsetChars + i] = buffer.getChar();
+ }
+ }
+
+ @Override
+ public double getDouble()
+ {
+ return buffer.getDouble();
+ }
+
+ @Override
+ public void getDoubleArray(double[] dstArray, int dstOffsetDoubles, int
lengthDoubles)
+ {
+ for (int i = 0; i < lengthDoubles; i++) {
+ dstArray[dstOffsetDoubles + i] = buffer.getDouble();
+ }
+ }
+
+ @Override
+ public float getFloat()
+ {
+ return buffer.getFloat();
+ }
+
+ @Override
+ public void getFloatArray(float[] dstArray, int dstOffsetFloats, int
lengthFloats)
+ {
+ for (int i = 0; i < lengthFloats; i++) {
+ dstArray[dstOffsetFloats + i] = buffer.getFloat();
+ }
+ }
+
+ @Override
+ public int getInt()
+ {
+ return buffer.getInt();
+ }
+
+ @Override
+ public void getIntArray(int[] dstArray, int dstOffsetInts, int lengthInts)
+ {
+ for (int i = 0; i < lengthInts; i++) {
+ dstArray[dstOffsetInts + i] = buffer.getInt();
+ }
+ }
+
+ @Override
+ public long getLong()
+ {
+ return buffer.getLong();
+ }
+
+ @Override
+ public void getLongArray(long[] dstArray, int dstOffsetLongs, int
lengthLongs)
+ {
+ for (int i = 0; i < lengthLongs; i++) {
+ dstArray[dstOffsetLongs + i] = buffer.getLong();
+ }
+ }
+
+ @Override
+ public short getShort()
+ {
+ return buffer.getShort();
+ }
+
+ @Override
+ public void getShortArray(short[] dstArray, int dstOffsetShorts, int
lengthShorts)
+ {
+ for (int i = 0; i < lengthShorts; i++) {
+ dstArray[dstOffsetShorts + i] = buffer.getShort();
+ }
+ }
+
+ @Override
+ public int compareTo(
+ long thisOffsetBytes,
+ long thisLengthBytes,
+ Buffer that,
+ long thatOffsetBytes,
+ long thatLengthBytes
+ )
+ {
+ final int thisLength = Ints.checkedCast(thisLengthBytes);
+ final int thatLength = Ints.checkedCast(thatLengthBytes);
+
+ final int commonLength = Math.min(thisLength, thatLength);
+
+ for (int i = 0; i < commonLength; i++) {
+ final int cmp = Byte.compare(getByte(thisOffsetBytes + i),
that.getByte(thatOffsetBytes + i));
+ if (cmp != 0) {
+ return cmp;
+ }
+ }
+
+ return Integer.compare(thisLength, thatLength);
+ }
+
+ @Override
+ public BaseBuffer incrementPosition(long increment)
+ {
+ buffer.position(buffer.position() + Ints.checkedCast(increment));
+ return this;
+ }
+
+ @Override
+ public BaseBuffer incrementAndCheckPosition(long increment)
+ {
+ checkInvariants(start, buffer.position() + increment, end,
buffer.capacity());
+ return incrementPosition(increment);
+ }
+
+ @Override
+ public long getEnd()
+ {
+ return end;
+ }
+
+ @Override
+ public long getPosition()
+ {
+ return buffer.position();
+ }
+
+ @Override
+ public long getStart()
+ {
+ return start;
+ }
+
+ @Override
+ public long getRemaining()
+ {
+ return buffer.remaining();
+ }
+
+ @Override
+ public boolean hasRemaining()
+ {
+ return buffer.hasRemaining();
+ }
+
+ @Override
+ public BaseBuffer resetPosition()
+ {
+ buffer.position(start);
+ return this;
+ }
+
+ @Override
+ public BaseBuffer setPosition(long position)
+ {
+ buffer.position(Ints.checkedCast(position));
+ return this;
+ }
+
+ @Override
+ public BaseBuffer setAndCheckPosition(long position)
+ {
+ checkInvariants(start, position, end, buffer.capacity());
+ return setPosition(position);
+ }
+
+ @Override
+ public BaseBuffer setStartPositionEnd(long start, long position, long end)
+ {
+ this.start = Ints.checkedCast(start);
+ this.end = Ints.checkedCast(end);
+ buffer.position(Ints.checkedCast(position));
+ buffer.limit(this.end);
+ return this;
+ }
+
+ @Override
+ public BaseBuffer setAndCheckStartPositionEnd(long start, long position,
long end)
+ {
+ checkInvariants(start, position, end, buffer.capacity());
+ return setStartPositionEnd(start, position, end);
+ }
+
+ @Override
+ public boolean equalTo(long thisOffsetBytes, Object that, long
thatOffsetBytes, long lengthBytes)
+ {
+ if (!(that instanceof SafeWritableBuffer)) {
+ return false;
+ }
+ return compareTo(thisOffsetBytes, lengthBytes, (SafeWritableBuffer) that,
thatOffsetBytes, lengthBytes) == 0;
+ }
+
+ /**
+ * Adapted from {@link
org.apache.datasketches.memory.internal.BaseBufferImpl#checkInvariants(long,
long, long, long)}
+ */
+ static void checkInvariants(final long start, final long pos, final long
end, final long cap)
+ {
+ if ((start | pos | end | cap | (pos - start) | (end - pos) | (cap - end))
< 0L) {
+ throw new IllegalArgumentException(
+ "Violation of Invariants: "
+ + "start: " + start
+ + " <= pos: " + pos
+ + " <= end: " + end
+ + " <= cap: " + cap
+ + "; (pos - start): " + (pos - start)
+ + ", (end - pos): " + (end - pos)
+ + ", (cap - end): " + (cap - end)
+ );
+ }
+ }
+}
diff --git
a/processing/src/main/java/org/apache/druid/segment/data/SafeWritableMemory.java
b/processing/src/main/java/org/apache/druid/segment/data/SafeWritableMemory.java
new file mode 100644
index 0000000000..9006ac5cec
--- /dev/null
+++
b/processing/src/main/java/org/apache/druid/segment/data/SafeWritableMemory.java
@@ -0,0 +1,417 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.segment.data;
+
+import com.google.common.primitives.Ints;
+import org.apache.datasketches.memory.Buffer;
+import org.apache.datasketches.memory.Memory;
+import org.apache.datasketches.memory.Utf8CodingException;
+import org.apache.datasketches.memory.WritableBuffer;
+import org.apache.datasketches.memory.WritableMemory;
+import org.apache.druid.java.util.common.StringUtils;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.WritableByteChannel;
+
/**
 * Safety first! Don't trust the locations that some data's contents tell you to read from and write to, but still
 * need a {@link Memory} or {@link WritableMemory}? Use this!
 * <p>
 * Delegates everything to an underlying {@link ByteBuffer} so all read and write operations will have bounds checks
 * built in rather than using 'unsafe'.
 */
+public class SafeWritableMemory extends SafeWritableBase implements
WritableMemory
+{
+ public static SafeWritableMemory wrap(byte[] bytes)
+ {
+ return wrap(ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()), 0,
bytes.length);
+ }
+
+ public static SafeWritableMemory wrap(ByteBuffer buffer)
+ {
+ return wrap(buffer.duplicate().order(buffer.order()), 0,
buffer.capacity());
+ }
+
+ public static SafeWritableMemory wrap(ByteBuffer buffer, ByteOrder byteOrder)
+ {
+ return wrap(buffer.duplicate().order(byteOrder), 0, buffer.capacity());
+ }
+
+ public static SafeWritableMemory wrap(ByteBuffer buffer, int offset, int
size)
+ {
+ final ByteBuffer dupe = buffer.duplicate().order(buffer.order());
+ dupe.position(offset);
+ dupe.limit(offset + size);
+ return new SafeWritableMemory(dupe.slice().order(buffer.order()));
+ }
+
/**
 * Uses the given buffer directly (no copy or duplicate); absolute offsets of this memory are absolute indexes of
 * the buffer.
 */
public SafeWritableMemory(ByteBuffer buffer)
{
  super(buffer);
}
+
+ @Override
+ public Memory region(long offsetBytes, long capacityBytes, ByteOrder
byteOrder)
+ {
+ return writableRegion(offsetBytes, capacityBytes, byteOrder);
+ }
+
+ @Override
+ public Buffer asBuffer(ByteOrder byteOrder)
+ {
+ return asWritableBuffer(byteOrder);
+ }
+
+ @Override
+ public void getBooleanArray(long offsetBytes, boolean[] dstArray, int
dstOffsetBooleans, int lengthBooleans)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int j = 0; j < lengthBooleans; j++) {
+ dstArray[dstOffsetBooleans + j] = buffer.get(offset + j) != 0;
+ }
+ }
+
+ @Override
+ public void getByteArray(long offsetBytes, byte[] dstArray, int
dstOffsetBytes, int lengthBytes)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int j = 0; j < lengthBytes; j++) {
+ dstArray[dstOffsetBytes + j] = buffer.get(offset + j);
+ }
+ }
+
+ @Override
+ public void getCharArray(long offsetBytes, char[] dstArray, int
dstOffsetChars, int lengthChars)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int j = 0; j < lengthChars; j++) {
+ dstArray[dstOffsetChars + j] = buffer.getChar(offset + (j *
Character.BYTES));
+ }
+ }
+
+ @Override
+ public int getCharsFromUtf8(long offsetBytes, int utf8LengthBytes,
Appendable dst)
+ throws IOException, Utf8CodingException
+ {
+ ByteBuffer dupe = buffer.asReadOnlyBuffer().order(buffer.order());
+ dupe.position(Ints.checkedCast(offsetBytes));
+ String s = StringUtils.fromUtf8(dupe, utf8LengthBytes);
+ dst.append(s);
+ return s.length();
+ }
+
+ @Override
+ public int getCharsFromUtf8(long offsetBytes, int utf8LengthBytes,
StringBuilder dst) throws Utf8CodingException
+ {
+ ByteBuffer dupe = buffer.asReadOnlyBuffer().order(buffer.order());
+ dupe.position(Ints.checkedCast(offsetBytes));
+ String s = StringUtils.fromUtf8(dupe, utf8LengthBytes);
+ dst.append(s);
+ return s.length();
+ }
+
+ @Override
+ public void getDoubleArray(long offsetBytes, double[] dstArray, int
dstOffsetDoubles, int lengthDoubles)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int j = 0; j < lengthDoubles; j++) {
+ dstArray[dstOffsetDoubles + j] = buffer.getDouble(offset + (j *
Double.BYTES));
+ }
+ }
+
+ @Override
+ public void getFloatArray(long offsetBytes, float[] dstArray, int
dstOffsetFloats, int lengthFloats)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int j = 0; j < lengthFloats; j++) {
+ dstArray[dstOffsetFloats + j] = buffer.getFloat(offset + (j *
Float.BYTES));
+ }
+ }
+
+ @Override
+ public void getIntArray(long offsetBytes, int[] dstArray, int dstOffsetInts,
int lengthInts)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int j = 0; j < lengthInts; j++) {
+ dstArray[dstOffsetInts + j] = buffer.getInt(offset + (j *
Integer.BYTES));
+ }
+ }
+
+ @Override
+ public void getLongArray(long offsetBytes, long[] dstArray, int
dstOffsetLongs, int lengthLongs)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int j = 0; j < lengthLongs; j++) {
+ dstArray[dstOffsetLongs + j] = buffer.getLong(offset + (j * Long.BYTES));
+ }
+ }
+
+ @Override
+ public void getShortArray(long offsetBytes, short[] dstArray, int
dstOffsetShorts, int lengthShorts)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int j = 0; j < lengthShorts; j++) {
+ dstArray[dstOffsetShorts + j] = buffer.getShort(offset + (j *
Short.BYTES));
+ }
+ }
+
+ @Override
+ public int compareTo(
+ long thisOffsetBytes,
+ long thisLengthBytes,
+ Memory that,
+ long thatOffsetBytes,
+ long thatLengthBytes
+ )
+ {
+ final int thisLength = Ints.checkedCast(thisLengthBytes);
+ final int thatLength = Ints.checkedCast(thatLengthBytes);
+
+ final int commonLength = Math.min(thisLength, thatLength);
+
+ for (int i = 0; i < commonLength; i++) {
+ final int cmp = Byte.compare(getByte(thisOffsetBytes + i),
that.getByte(thatOffsetBytes + i));
+ if (cmp != 0) {
+ return cmp;
+ }
+ }
+
+ return Integer.compare(thisLength, thatLength);
+ }
+
+ @Override
+ public void copyTo(long srcOffsetBytes, WritableMemory destination, long
dstOffsetBytes, long lengthBytes)
+ {
+ int offset = Ints.checkedCast(srcOffsetBytes);
+ for (int i = 0; i < lengthBytes; i++) {
+ destination.putByte(dstOffsetBytes + i, buffer.get(offset + i));
+ }
+ }
+
+ @Override
+ public void writeTo(long offsetBytes, long lengthBytes, WritableByteChannel
out) throws IOException
+ {
+ ByteBuffer dupe = buffer.duplicate();
+ dupe.position(Ints.checkedCast(offsetBytes));
+ dupe.limit(dupe.position() + Ints.checkedCast(lengthBytes));
+ ByteBuffer view = dupe.slice();
+ view.order(buffer.order());
+ out.write(view);
+ }
+
+ @Override
+ public boolean equalTo(long thisOffsetBytes, Object that, long
thatOffsetBytes, long lengthBytes)
+ {
+ if (!(that instanceof SafeWritableMemory)) {
+ return false;
+ }
+ return compareTo(thisOffsetBytes, lengthBytes, (SafeWritableMemory) that,
thatOffsetBytes, lengthBytes) == 0;
+ }
+
+
+ @Override
+ public WritableMemory writableRegion(long offsetBytes, long capacityBytes,
ByteOrder byteOrder)
+ {
+ final ByteBuffer dupe = buffer.duplicate().order(buffer.order());
+ final int sizeBytes = Ints.checkedCast(capacityBytes);
+ dupe.position(Ints.checkedCast(offsetBytes));
+ dupe.limit(dupe.position() + sizeBytes);
+ final ByteBuffer view = dupe.slice();
+ view.order(byteOrder);
+ return new SafeWritableMemory(view);
+ }
+
+ @Override
+ public WritableBuffer asWritableBuffer(ByteOrder byteOrder)
+ {
+ return new SafeWritableBuffer(buffer.duplicate().order(byteOrder));
+ }
+
+ @Override
+ public void putBooleanArray(long offsetBytes, boolean[] srcArray, int
srcOffsetBooleans, int lengthBooleans)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int i = 0; i < lengthBooleans; i++) {
+ buffer.put(offset + i, (byte) (srcArray[i + srcOffsetBooleans] ? 1 : 0));
+ }
+ }
+
+ @Override
+ public void putByteArray(long offsetBytes, byte[] srcArray, int
srcOffsetBytes, int lengthBytes)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int i = 0; i < lengthBytes; i++) {
+ buffer.put(offset + i, srcArray[srcOffsetBytes + i]);
+ }
+ }
+
+ @Override
+ public void putCharArray(long offsetBytes, char[] srcArray, int
srcOffsetChars, int lengthChars)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int i = 0; i < lengthChars; i++) {
+ buffer.putChar(offset + (i * Character.BYTES), srcArray[srcOffsetChars +
i]);
+ }
+ }
+
+ @Override
+ public long putCharsToUtf8(long offsetBytes, CharSequence src)
+ {
+ final byte[] bytes = StringUtils.toUtf8(src.toString());
+ putByteArray(offsetBytes, bytes, 0, bytes.length);
+ return bytes.length;
+ }
+
+ @Override
+ public void putDoubleArray(long offsetBytes, double[] srcArray, int
srcOffsetDoubles, int lengthDoubles)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int i = 0; i < lengthDoubles; i++) {
+ buffer.putDouble(offset + (i * Double.BYTES), srcArray[srcOffsetDoubles
+ i]);
+ }
+ }
+
+ @Override
+ public void putFloatArray(long offsetBytes, float[] srcArray, int
srcOffsetFloats, int lengthFloats)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int i = 0; i < lengthFloats; i++) {
+ buffer.putFloat(offset + (i * Float.BYTES), srcArray[srcOffsetFloats +
i]);
+ }
+ }
+
+ @Override
+ public void putIntArray(long offsetBytes, int[] srcArray, int srcOffsetInts,
int lengthInts)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int i = 0; i < lengthInts; i++) {
+ buffer.putInt(offset + (i * Integer.BYTES), srcArray[srcOffsetInts + i]);
+ }
+ }
+
+ @Override
+ public void putLongArray(long offsetBytes, long[] srcArray, int
srcOffsetLongs, int lengthLongs)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int i = 0; i < lengthLongs; i++) {
+ buffer.putLong(offset + (i * Long.BYTES), srcArray[srcOffsetLongs + i]);
+ }
+ }
+
+ @Override
+ public void putShortArray(long offsetBytes, short[] srcArray, int
srcOffsetShorts, int lengthShorts)
+ {
+ final int offset = Ints.checkedCast(offsetBytes);
+ for (int i = 0; i < lengthShorts; i++) {
+ buffer.putShort(offset + (i * Short.BYTES), srcArray[srcOffsetShorts +
i]);
+ }
+ }
+
+  /**
+   * Adds {@code delta} to the long stored at {@code offsetBytes} and returns the value it held before the addition.
+   * The read-modify-write runs while holding the backing buffer's monitor, so it is only atomic with respect to
+   * other code that synchronizes on that same buffer object.
+   */
+  @Override
+  public long getAndAddLong(long offsetBytes, long delta)
+  {
+    final int index = Ints.checkedCast(offsetBytes);
+    synchronized (buffer) {
+      final long previous = buffer.getLong(index);
+      buffer.putLong(index, previous + delta);
+      return previous;
+    }
+  }
+
+  /**
+   * Replaces the long at {@code offsetBytes} with {@code update} only if it currently equals {@code expect}. The
+   * check and the write happen while holding the backing buffer's monitor.
+   *
+   * @return true if the stored value matched {@code expect} and was replaced, false otherwise
+   */
+  @Override
+  public boolean compareAndSwapLong(long offsetBytes, long expect, long update)
+  {
+    final int index = Ints.checkedCast(offsetBytes);
+    synchronized (buffer) {
+      if (buffer.getLong(index) != expect) {
+        return false;
+      }
+      buffer.putLong(index, update);
+      return true;
+    }
+  }
+
+  /**
+   * Stores {@code newValue} at {@code offsetBytes} and returns the long that was there before, performing the read
+   * and write while holding the backing buffer's monitor.
+   */
+  @Override
+  public long getAndSetLong(long offsetBytes, long newValue)
+  {
+    final int index = Ints.checkedCast(offsetBytes);
+    final long previous;
+    synchronized (buffer) {
+      previous = buffer.getLong(index);
+      buffer.putLong(index, newValue);
+    }
+    return previous;
+  }
+
+  /**
+   * This implementation never exposes a backing array to callers.
+   *
+   * @return always {@code null}
+   */
+  @Override
+  public Object getArray()
+  {
+    return null;
+  }
+
+  /**
+   * Zeroes this memory by delegating to {@link #fill(byte)} with a zero byte.
+   */
+  @Override
+  public void clear()
+  {
+    final byte zero = 0;
+    fill(zero);
+  }
+
+  /**
+   * Zeroes {@code lengthBytes} bytes starting at {@code offsetBytes} by delegating to
+   * {@link #fill(long, long, byte)}.
+   */
+  @Override
+  public void clear(long offsetBytes, long lengthBytes)
+  {
+    final byte zero = 0;
+    fill(offsetBytes, lengthBytes, zero);
+  }
+
+  /**
+   * Clears, in the byte at {@code offsetBytes}, every bit that is set in {@code bitMask}; other bits are kept.
+   */
+  @Override
+  public void clearBits(long offsetBytes, byte bitMask)
+  {
+    final int index = Ints.checkedCast(offsetBytes);
+    final int current = buffer.get(index) & 0XFF;
+    buffer.put(index, (byte) (current & ~bitMask));
+  }
+
+  /**
+   * Writes {@code value} into every byte from index 0 up to (excluding) the backing buffer's capacity.
+   */
+  @Override
+  public void fill(byte value)
+  {
+    final int capacity = buffer.capacity();
+    for (int index = 0; index < capacity; index++) {
+      buffer.put(index, value);
+    }
+  }
+
+  /**
+   * Writes {@code value} into {@code lengthBytes} consecutive bytes starting at {@code offsetBytes}. Both arguments
+   * must fit in an int ({@code Ints.checkedCast} rejects anything larger).
+   */
+  @Override
+  public void fill(long offsetBytes, long lengthBytes, byte value)
+  {
+    int index = Ints.checkedCast(offsetBytes);
+    int remaining = Ints.checkedCast(lengthBytes);
+    while (remaining > 0) {
+      buffer.put(index, value);
+      index++;
+      remaining--;
+    }
+  }
+
+  /**
+   * Sets, in the byte at {@code offsetBytes}, every bit that is set in {@code bitMask}; other bits are kept.
+   */
+  @Override
+  public void setBits(long offsetBytes, byte bitMask)
+  {
+    final int index = Ints.checkedCast(offsetBytes);
+    final byte current = buffer.get(index);
+    buffer.put(index, (byte) (current | bitMask));
+  }
+}
diff --git
a/processing/src/test/java/org/apache/druid/segment/data/SafeWritableBufferTest.java
b/processing/src/test/java/org/apache/druid/segment/data/SafeWritableBufferTest.java
new file mode 100644
index 0000000000..f432b7c167
--- /dev/null
+++
b/processing/src/test/java/org/apache/druid/segment/data/SafeWritableBufferTest.java
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.segment.data;
+
+import org.apache.datasketches.memory.Buffer;
+import org.apache.datasketches.memory.WritableBuffer;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+/**
+ * Tests for {@link SafeWritableBuffer} wrapping a little-endian heap {@link ByteBuffer}.
+ */
+public class SafeWritableBufferTest
+{
+  private static final int CAPACITY = 1024;
+
+  /**
+   * Each relative put should advance the position by the width of the written type, and relative gets after
+   * {@code resetPosition()} should return the written values in the same order.
+   */
+  @Test
+  public void testPutAndGet()
+  {
+    WritableBuffer buffer = getBuffer();
+    Assert.assertEquals(0, buffer.getPosition());
+    buffer.putByte((byte) 0x01);
+    Assert.assertEquals(1, buffer.getPosition());
+    buffer.putBoolean(true);
+    Assert.assertEquals(2, buffer.getPosition());
+    buffer.putBoolean(false);
+    Assert.assertEquals(3, buffer.getPosition());
+    buffer.putChar('c');
+    Assert.assertEquals(5, buffer.getPosition());
+    buffer.putDouble(1.1);
+    Assert.assertEquals(13, buffer.getPosition());
+    buffer.putFloat(1.1f);
+    Assert.assertEquals(17, buffer.getPosition());
+    buffer.putInt(100);
+    Assert.assertEquals(21, buffer.getPosition());
+    buffer.putLong(1000L);
+    Assert.assertEquals(29, buffer.getPosition());
+    buffer.putShort((short) 15);
+    Assert.assertEquals(31, buffer.getPosition());
+    buffer.resetPosition();
+
+    Assert.assertEquals(0x01, buffer.getByte());
+    Assert.assertTrue(buffer.getBoolean());
+    Assert.assertFalse(buffer.getBoolean());
+    Assert.assertEquals('c', buffer.getChar());
+    Assert.assertEquals(1.1, buffer.getDouble(), 0.0);
+    Assert.assertEquals(1.1f, buffer.getFloat(), 0.0);
+    Assert.assertEquals(100, buffer.getInt());
+    Assert.assertEquals(1000L, buffer.getLong());
+    Assert.assertEquals(15, buffer.getShort());
+  }
+
+  /**
+   * Relative array puts of every primitive type, written back to back, should read back identically after
+   * {@code resetPosition()}, leaving the position where the writes ended.
+   */
+  @Test
+  public void testPutAndGetArrays()
+  {
+    WritableBuffer buffer = getBuffer();
+    final byte[] b1 = new byte[]{0x01, 0x02, 0x08, 0x08};
+    final byte[] b2 = new byte[b1.length];
+
+    final boolean[] bool1 = new boolean[]{true, false, false, true};
+    final boolean[] bool2 = new boolean[bool1.length];
+
+    final char[] chars1 = new char[]{'a', 'b', 'c', 'd'};
+    final char[] chars2 = new char[chars1.length];
+
+    final double[] double1 = new double[]{1.1, -2.2, 3.3, 4.4};
+    final double[] double2 = new double[double1.length];
+
+    final float[] float1 = new float[]{1.1f, 2.2f, -3.3f, 4.4f};
+    final float[] float2 = new float[float1.length];
+
+    final int[] ints1 = new int[]{1, 2, -3, 4};
+    final int[] ints2 = new int[ints1.length];
+
+    final long[] longs1 = new long[]{1L, -2L, 3L, -14L};
+    final long[] longs2 = new long[longs1.length];
+
+    final short[] shorts1 = new short[]{1, -2, 3, -14};
+    final short[] shorts2 = new short[shorts1.length];
+
+    // the byte array is deliberately written in two pieces to exercise srcOffset handling
+    buffer.putByteArray(b1, 0, 2);
+    buffer.putByteArray(b1, 2, b1.length - 2);
+    buffer.putBooleanArray(bool1, 0, bool1.length);
+    buffer.putCharArray(chars1, 0, chars1.length);
+    buffer.putDoubleArray(double1, 0, double1.length);
+    buffer.putFloatArray(float1, 0, float1.length);
+    buffer.putIntArray(ints1, 0, ints1.length);
+    buffer.putLongArray(longs1, 0, longs1.length);
+    buffer.putShortArray(shorts1, 0, shorts1.length);
+    final long endOfWrites = buffer.getPosition();
+    buffer.resetPosition();
+    buffer.getByteArray(b2, 0, b1.length);
+    buffer.getBooleanArray(bool2, 0, bool1.length);
+    buffer.getCharArray(chars2, 0, chars1.length);
+    buffer.getDoubleArray(double2, 0, double1.length);
+    buffer.getFloatArray(float2, 0, float1.length);
+    buffer.getIntArray(ints2, 0, ints1.length);
+    buffer.getLongArray(longs2, 0, longs1.length);
+    buffer.getShortArray(shorts2, 0, shorts1.length);
+
+    Assert.assertArrayEquals(b1, b2);
+    Assert.assertArrayEquals(bool1, bool2);
+    Assert.assertArrayEquals(chars1, chars2);
+    Assert.assertArrayEquals(double1, double2, 0.0);
+    Assert.assertArrayEquals(float1, float2, 0.0f);
+    Assert.assertArrayEquals(ints1, ints2);
+    Assert.assertArrayEquals(longs1, longs2);
+    Assert.assertArrayEquals(shorts1, shorts2);
+
+    // reading everything back should end exactly where writing ended
+    Assert.assertEquals(endOfWrites, buffer.getPosition());
+  }
+
+  /**
+   * Exercises start/position/end bookkeeping, and checks that {@code duplicate()} and {@code region(...)} see the
+   * expected bytes after two fills with different start/position settings.
+   */
+  @Test
+  public void testStartEndRegionAndDuplicate()
+  {
+    WritableBuffer buffer = getBuffer();
+    Assert.assertEquals(0, buffer.getPosition());
+    Assert.assertEquals(0, buffer.getStart());
+    Assert.assertEquals(CAPACITY, buffer.getEnd());
+    Assert.assertEquals(CAPACITY, buffer.getRemaining());
+    Assert.assertEquals(CAPACITY, buffer.getCapacity());
+    Assert.assertTrue(buffer.hasRemaining());
+    buffer.fill((byte) 0x07);
+    buffer.setAndCheckStartPositionEnd(10L, 15L, 100L);
+    Assert.assertEquals(15L, buffer.getPosition());
+    Assert.assertEquals(10L, buffer.getStart());
+    Assert.assertEquals(100L, buffer.getEnd());
+    Assert.assertEquals(85L, buffer.getRemaining());
+    Assert.assertEquals(CAPACITY, buffer.getCapacity());
+    // per the assertions below, this second fill only touches bytes from the position (15) up to the end (100)
+    buffer.fill((byte) 0x70);
+    buffer.resetPosition();
+    Assert.assertEquals(10L, buffer.getPosition());
+    for (int i = 0; i < 90; i++) {
+      if (i < 5) {
+        Assert.assertEquals(0x07, buffer.getByte());
+      } else {
+        Assert.assertEquals(0x70, buffer.getByte());
+      }
+    }
+    buffer.setAndCheckPosition(50);
+
+    Buffer duplicate = buffer.duplicate();
+    Assert.assertEquals(buffer.getStart(), duplicate.getStart());
+    Assert.assertEquals(buffer.getPosition(), duplicate.getPosition());
+    Assert.assertEquals(buffer.getEnd(), duplicate.getEnd());
+    Assert.assertEquals(buffer.getRemaining(), duplicate.getRemaining());
+    Assert.assertEquals(buffer.getCapacity(), duplicate.getCapacity());
+
+    duplicate.resetPosition();
+    for (int i = 0; i < 90; i++) {
+      if (i < 5) {
+        Assert.assertEquals(0x07, duplicate.getByte());
+      } else {
+        Assert.assertEquals(0x70, duplicate.getByte());
+      }
+    }
+
+    // expected layout: 10 bytes of 0x07, then the 85 bytes of 0x70, then 0x07 again
+    Buffer region = buffer.region(5L, 105L, buffer.getTypeByteOrder());
+    Assert.assertEquals(0, region.getStart());
+    Assert.assertEquals(0, region.getPosition());
+    Assert.assertEquals(105L, region.getEnd());
+    Assert.assertEquals(105L, region.getRemaining());
+    Assert.assertEquals(105L, region.getCapacity());
+
+    for (int i = 0; i < 105; i++) {
+      if (i < 10) {
+        Assert.assertEquals(0x07, region.getByte());
+      } else if (i < 95) {
+        Assert.assertEquals(0x70, region.getByte());
+      } else {
+        Assert.assertEquals(0x07, region.getByte());
+      }
+    }
+  }
+
+  /**
+   * Two identically filled buffers compare equal; after {@code clear()} at position 100, only the first 100 bytes
+   * are expected to still match.
+   */
+  @Test
+  public void testFill()
+  {
+    WritableBuffer buffer = getBuffer();
+    WritableBuffer anotherBuffer = getBuffer();
+
+    buffer.fill((byte) 0x0F);
+    anotherBuffer.fill((byte) 0x0F);
+    Assert.assertTrue(buffer.equalTo(0L, anotherBuffer, 0L, CAPACITY));
+
+    anotherBuffer.setPosition(100);
+    anotherBuffer.clear();
+    Assert.assertFalse(buffer.equalTo(0L, anotherBuffer, 0L, CAPACITY));
+    Assert.assertTrue(buffer.equalTo(0L, anotherBuffer, 0L, 100L));
+  }
+
+  private WritableBuffer getBuffer()
+  {
+    return getBuffer(CAPACITY);
+  }
+
+  /**
+   * Wraps a freshly allocated (zeroed) little-endian heap buffer of {@code capacity} bytes.
+   */
+  private WritableBuffer getBuffer(int capacity)
+  {
+    final ByteBuffer backing = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN);
+    return new SafeWritableBuffer(backing);
+  }
+}
diff --git
a/processing/src/test/java/org/apache/druid/segment/data/SafeWritableMemoryTest.java
b/processing/src/test/java/org/apache/druid/segment/data/SafeWritableMemoryTest.java
new file mode 100644
index 0000000000..786443f43e
--- /dev/null
+++
b/processing/src/test/java/org/apache/druid/segment/data/SafeWritableMemoryTest.java
@@ -0,0 +1,359 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.segment.data;
+
+import org.apache.datasketches.memory.Memory;
+import org.apache.datasketches.memory.WritableMemory;
+import org.apache.datasketches.memory.internal.UnsafeUtil;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.CharArrayWriter;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+public class SafeWritableMemoryTest
+{
+ private static final int CAPACITY = 1024;
+
+ @Test
+ public void testPutAndGet()
+ {
+ final WritableMemory memory = getMemory();
+ memory.putByte(3L, (byte) 0x01);
+ Assert.assertEquals(memory.getByte(3L), 0x01);
+
+ memory.putBoolean(1L, true);
+ Assert.assertTrue(memory.getBoolean(1L));
+ memory.putBoolean(1L, false);
+ Assert.assertFalse(memory.getBoolean(1L));
+
+ memory.putChar(10L, 'c');
+ Assert.assertEquals('c', memory.getChar(10L));
+
+ memory.putDouble(14L, 3.3);
+ Assert.assertEquals(3.3, memory.getDouble(14L), 0.0);
+
+ memory.putFloat(27L, 3.3f);
+ Assert.assertEquals(3.3f, memory.getFloat(27L), 0.0);
+
+ memory.putInt(11L, 1234);
+ Assert.assertEquals(1234, memory.getInt(11L));
+
+ memory.putLong(500L, 500L);
+ Assert.assertEquals(500L, memory.getLong(500L));
+
+ memory.putShort(11L, (short) 15);
+ Assert.assertEquals(15, memory.getShort(11L));
+
+ long l = memory.getAndSetLong(900L, 10L);
+ Assert.assertEquals(0L, l);
+ l = memory.getAndSetLong(900L, 100L);
+ Assert.assertEquals(10L, l);
+ l = memory.getAndAddLong(900L, 10L);
+ Assert.assertEquals(100L, l);
+ Assert.assertEquals(110L, memory.getLong(900L));
+ Assert.assertTrue(memory.compareAndSwapLong(900L, 110L, 120L));
+ Assert.assertFalse(memory.compareAndSwapLong(900L, 110L, 120L));
+ Assert.assertEquals(120L, memory.getLong(900L));
+ }
+
+ @Test
+ public void testPutAndGetArrays()
+ {
+ final WritableMemory memory = getMemory();
+ final byte[] b1 = new byte[]{0x01, 0x02, 0x08, 0x08};
+ final byte[] b2 = new byte[b1.length];
+ memory.putByteArray(12L, b1, 0, 3);
+ memory.putByteArray(15L, b1, 3, 1);
+ memory.getByteArray(12L, b2, 0, 3);
+ memory.getByteArray(15L, b2, 3, 1);
+ Assert.assertArrayEquals(b1, b2);
+
+ final boolean[] bool1 = new boolean[]{true, false, false, true};
+ final boolean[] bool2 = new boolean[bool1.length];
+ memory.putBooleanArray(100L, bool1, 0, 2);
+ memory.putBooleanArray(102L, bool1, 2, 2);
+ memory.getBooleanArray(100L, bool2, 0, 2);
+ memory.getBooleanArray(102L, bool2, 2, 2);
+ Assert.assertArrayEquals(bool1, bool2);
+
+ final char[] chars1 = new char[]{'a', 'b', 'c', 'd'};
+ final char[] chars2 = new char[chars1.length];
+ memory.putCharArray(10L, chars1, 0, 4);
+ memory.getCharArray(10L, chars2, 0, chars1.length);
+ Assert.assertArrayEquals(chars1, chars2);
+
+ final double[] double1 = new double[]{1.1, -2.2, 3.3, 4.4};
+ final double[] double2 = new double[double1.length];
+ memory.putDoubleArray(100L, double1, 0, 1);
+ memory.putDoubleArray(100L + Double.BYTES, double1, 1, 3);
+ memory.getDoubleArray(100L, double2, 0, 2);
+ memory.getDoubleArray(100L + (2 * Double.BYTES), double2, 2, 2);
+ for (int i = 0; i < double1.length; i++) {
+ Assert.assertEquals(double1[i], double2[i], 0.0);
+ }
+
+ final float[] float1 = new float[]{1.1f, 2.2f, -3.3f, 4.4f};
+ final float[] float2 = new float[float1.length];
+ memory.putFloatArray(100L, float1, 0, 1);
+ memory.putFloatArray(100L + Float.BYTES, float1, 1, 3);
+ memory.getFloatArray(100L, float2, 0, 2);
+ memory.getFloatArray(100L + (2 * Float.BYTES), float2, 2, 2);
+ for (int i = 0; i < float1.length; i++) {
+ Assert.assertEquals(float1[i], float2[i], 0.0);
+ }
+
+ final int[] ints1 = new int[]{1, 2, -3, 4};
+ final int[] ints2 = new int[ints1.length];
+ memory.putIntArray(100L, ints1, 0, 1);
+ memory.putIntArray(100L + Integer.BYTES, ints1, 1, 3);
+ memory.getIntArray(100L, ints2, 0, 2);
+ memory.getIntArray(100L + (2 * Integer.BYTES), ints2, 2, 2);
+ Assert.assertArrayEquals(ints1, ints2);
+
+ final long[] longs1 = new long[]{1L, -2L, 3L, -14L};
+ final long[] longs2 = new long[ints1.length];
+ memory.putLongArray(100L, longs1, 0, 1);
+ memory.putLongArray(100L + Long.BYTES, longs1, 1, 3);
+ memory.getLongArray(100L, longs2, 0, 2);
+ memory.getLongArray(100L + (2 * Long.BYTES), longs2, 2, 2);
+ Assert.assertArrayEquals(longs1, longs2);
+
+ final short[] shorts1 = new short[]{1, -2, 3, -14};
+ final short[] shorts2 = new short[ints1.length];
+ memory.putShortArray(100L, shorts1, 0, 1);
+ memory.putShortArray(100L + Short.BYTES, shorts1, 1, 3);
+ memory.getShortArray(100L, shorts2, 0, 2);
+ memory.getShortArray(100L + (2 * Short.BYTES), shorts2, 2, 2);
+ Assert.assertArrayEquals(shorts1, shorts2);
+ }
+
+ @Test
+ public void testFill()
+ {
+ final byte theByte = 0x01;
+ final byte anotherByte = 0x02;
+ final WritableMemory memory = getMemory();
+ final int halfWay = (int) (memory.getCapacity() / 2);
+
+ memory.fill(theByte);
+ for (int i = 0; i < memory.getCapacity(); i++) {
+ Assert.assertEquals(theByte, memory.getByte(i));
+ }
+
+ memory.fill(halfWay, memory.getCapacity() - halfWay, anotherByte);
+ for (int i = 0; i < memory.getCapacity(); i++) {
+ if (i < halfWay) {
+ Assert.assertEquals(theByte, memory.getByte(i));
+ } else {
+ Assert.assertEquals(anotherByte, memory.getByte(i));
+ }
+ }
+
+ memory.clear(halfWay, memory.getCapacity() - halfWay);
+ for (int i = 0; i < memory.getCapacity(); i++) {
+ if (i < halfWay) {
+ Assert.assertEquals(theByte, memory.getByte(i));
+ } else {
+ Assert.assertEquals(0, memory.getByte(i));
+ }
+ }
+
+ memory.setBits(halfWay - 1, anotherByte);
+ Assert.assertEquals(0x03, memory.getByte(halfWay - 1));
+ memory.clearBits(halfWay - 1, theByte);
+ Assert.assertEquals(anotherByte, memory.getByte(halfWay - 1));
+
+ memory.clear();
+ for (int i = 0; i < memory.getCapacity(); i++) {
+ Assert.assertEquals(0, memory.getByte(i));
+ }
+ }
+
+ @Test
+ public void testStringStuff() throws IOException
+ {
+ WritableMemory memory = getMemory();
+ String s1 = "hello ";
+ memory.putCharsToUtf8(10L, s1);
+
+ StringBuilder builder = new StringBuilder();
+ memory.getCharsFromUtf8(10L, s1.length(), builder);
+ Assert.assertEquals(s1, builder.toString());
+
+ CharArrayWriter someAppendable = new CharArrayWriter();
+ memory.getCharsFromUtf8(10L, s1.length(), someAppendable);
+ Assert.assertEquals(s1, someAppendable.toString());
+ }
+
+ @Test
+ public void testRegion()
+ {
+ WritableMemory memory = getMemory();
+ Assert.assertEquals(CAPACITY, memory.getCapacity());
+ Assert.assertEquals(0, memory.getCumulativeOffset());
+ Assert.assertEquals(10L, memory.getCumulativeOffset(10L));
+ Assert.assertThrows(
+ IllegalArgumentException.class,
+ () -> memory.checkValidAndBounds(CAPACITY - 10, 11L)
+ );
+
+ final byte[] someBytes = new byte[]{0x01, 0x02, 0x03, 0x04};
+ memory.putByteArray(10L, someBytes, 0, someBytes.length);
+
+ Memory region = memory.region(10L, someBytes.length);
+ Assert.assertEquals(someBytes.length, region.getCapacity());
+ Assert.assertEquals(0, region.getCumulativeOffset());
+ Assert.assertEquals(2L, region.getCumulativeOffset(2L));
+ Assert.assertThrows(
+ IllegalArgumentException.class,
+ () -> region.checkValidAndBounds(2L, 4L)
+ );
+
+ final byte[] andBack = new byte[someBytes.length];
+ region.getByteArray(0L, andBack, 0, someBytes.length);
+ Assert.assertArrayEquals(someBytes, andBack);
+
+ Memory differentOrderRegion = memory.region(10L, someBytes.length,
ByteOrder.BIG_ENDIAN);
+ // different order
+
Assert.assertFalse(region.isByteOrderCompatible(differentOrderRegion.getTypeByteOrder()));
+ // contents are equal tho
+ Assert.assertTrue(region.equalTo(0L, differentOrderRegion, 0L,
someBytes.length));
+ }
+
+ @Test
+ public void testCompareAndEquals()
+ {
+ WritableMemory memory = getMemory();
+ final byte[] someBytes = new byte[]{0x01, 0x02, 0x03, 0x04};
+ final byte[] shorterSameBytes = new byte[]{0x01, 0x02, 0x03};
+ final byte[] differentBytes = new byte[]{0x02, 0x02, 0x03, 0x04};
+ memory.putByteArray(10L, someBytes, 0, someBytes.length);
+ memory.putByteArray(400L, someBytes, 0, someBytes.length);
+ memory.putByteArray(200L, shorterSameBytes, 0, shorterSameBytes.length);
+ memory.putByteArray(500L, differentBytes, 0, differentBytes.length);
+
+ Assert.assertEquals(0, memory.compareTo(10L, someBytes.length, memory,
400L, someBytes.length));
+ Assert.assertEquals(4, memory.compareTo(10L, someBytes.length, memory,
200L, someBytes.length));
+ Assert.assertEquals(-1, memory.compareTo(10L, someBytes.length, memory,
500L, differentBytes.length));
+
+ WritableMemory memory2 = getMemory();
+ memory2.putByteArray(0L, someBytes, 0, someBytes.length);
+
+ Assert.assertEquals(0, memory.compareTo(10L, someBytes.length, memory2,
0L, someBytes.length));
+
+ Assert.assertTrue(memory.equalTo(10L, memory2, 0L, someBytes.length));
+
+ WritableMemory memory3 = getMemory();
+ memory2.copyTo(0L, memory3, 0L, CAPACITY);
+ Assert.assertTrue(memory2.equalTo(0L, memory3, 0L, CAPACITY));
+ }
+
+ @Test
+ public void testHash()
+ {
+ WritableMemory memory = getMemory();
+ final long[] someLongs = new long[]{1L, 10L, 100L, 1000L, 10000L};
+ final int[] someInts = new int[]{1, 2, 3};
+ final byte[] someBytes = new byte[]{0x01, 0x02, 0x03};
+ final int longsLength = Long.BYTES * someLongs.length;
+ final int someIntsLength = Integer.BYTES * someInts.length;
+ final int totalLength = longsLength + someIntsLength + someBytes.length;
+ memory.putLongArray(2L, someLongs, 0, someLongs.length);
+ memory.putIntArray(2L + longsLength, someInts, 0, someInts.length);
+ memory.putByteArray(2L + longsLength + someIntsLength, someBytes, 0,
someBytes.length);
+ Memory memory2 = Memory.wrap(memory.getByteBuffer(),
ByteOrder.LITTLE_ENDIAN);
+ Assert.assertEquals(
+ memory2.xxHash64(2L, totalLength, 0),
+ memory.xxHash64(2L, totalLength, 0)
+ );
+
+ Assert.assertEquals(
+ memory2.xxHash64(2L, 0),
+ memory.xxHash64(2L, 0)
+ );
+ }
+
+ @Test
+ public void testToHexString()
+ {
+
+ final byte[] bytes = new byte[]{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
0x07};
+ final WritableMemory memory = getMemory(bytes.length);
+ memory.putByteArray(0L, bytes, 0, bytes.length);
+ final long hcode = memory.hashCode() & 0XFFFFFFFFL;
+ final long bufferhcode = memory.getByteBuffer().hashCode() & 0XFFFFFFFFL;
+ final long reqhcode = memory.getMemoryRequestServer().hashCode() &
0XFFFFFFFFL;
+ Assert.assertEquals(
+ "### SafeWritableMemory SUMMARY ###\n"
+ + "Header Comment : test memory dump\n"
+ + "Call Parameters : .toHexString(..., 0, 8), hashCode: " + hcode
+ "\n"
+ + "UnsafeObj, hashCode : null\n"
+ + "UnsafeObjHeader : 0\n"
+ + "ByteBuf, hashCode : HeapByteBuffer, " + bufferhcode + "\n"
+ + "RegionOffset : 0\n"
+ + "Capacity : 8\n"
+ + "CumBaseOffset : 0\n"
+ + "MemReq, hashCode : HeapByteBufferMemoryRequestServer, " +
reqhcode + "\n"
+ + "Valid : true\n"
+ + "Read Only : false\n"
+ + "Type Byte Order : LITTLE_ENDIAN\n"
+ + "Native Byte Order : LITTLE_ENDIAN\n"
+ + "JDK Runtime Version : " + UnsafeUtil.JDK + "\n"
+ + "Data, littleEndian : 0 1 2 3 4 5 6 7\n"
+ + " 0: 00 01 02 03 04 05 06 07 \n",
+ memory.toHexString("test memory dump", 0, bytes.length)
+ );
+ }
+
+ @Test
+ public void testMisc()
+ {
+ WritableMemory memory = getMemory(10);
+ WritableMemory memory2 = memory.getMemoryRequestServer().request(memory,
20);
+ Assert.assertEquals(20, memory2.getCapacity());
+
+ Assert.assertFalse(memory2.hasArray());
+
+ Assert.assertFalse(memory2.isReadOnly());
+ Assert.assertFalse(memory2.isDirect());
+ Assert.assertTrue(memory2.isValid());
+ Assert.assertTrue(memory2.hasByteBuffer());
+
+ Assert.assertFalse(memory2.isSameResource(memory));
+ Assert.assertTrue(memory2.isSameResource(memory2));
+
+ // does nothing
+ memory.getMemoryRequestServer().requestClose(memory, memory2);
+ }
+
+ private WritableMemory getMemory()
+ {
+ return getMemory(CAPACITY);
+ }
+
+ private WritableMemory getMemory(int capacity)
+ {
+ final ByteBuffer aBuffer =
ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN);
+ return SafeWritableMemory.wrap(aBuffer);
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]