gianm closed pull request #4949: Add limit to query result buffering queue
URL: https://github.com/apache/incubator-druid/pull/4949
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/docs/content/configuration/broker.md
b/docs/content/configuration/broker.md
index af97d2dffa5..c34de70c6b3 100644
--- a/docs/content/configuration/broker.md
+++ b/docs/content/configuration/broker.md
@@ -44,6 +44,8 @@ Druid uses Jetty to serve HTTP requests.
|`druid.broker.http.numConnections`|Size of connection pool for the Broker to
connect to historical and real-time processes. If there are more queries than
this number that all need to speak to the same node, then they will queue
up.|20|
|`druid.broker.http.compressionCodec`|Compression codec the Broker uses to
communicate with historical and real-time processes. May be "gzip" or
"identity".|gzip|
|`druid.broker.http.readTimeout`|The timeout for data reads from historical
and real-time processes.|PT15M|
+|`druid.broker.http.maxBufferSizeBytes`|The maximum number of bytes collected
from query nodes to buffer at the broker to execute a query. Compared to
`maxScatterGatherBytes`, this advanced configuration protects the broker from
OOMs without aborting the query immediately when the maximum buffer size
limit is reached. This limit can be further reduced at query time using
`maxBufferSizeBytes` in the context. |Long.MAX_VALUE|
+|`druid.broker.http.queryBufferingTimeout`|The timeout in milliseconds, beyond
which the broker will stop waiting for buffer space to accept data from query
nodes and abort the query. This timeout can be further adjusted at query time
using `queryBufferingTimeout` in the context. |300000|
#### Retry Policy
diff --git a/processing/src/main/java/io/druid/query/QueryContexts.java
b/processing/src/main/java/io/druid/query/QueryContexts.java
index b56812d8b0e..e85e3dea4d2 100644
--- a/processing/src/main/java/io/druid/query/QueryContexts.java
+++ b/processing/src/main/java/io/druid/query/QueryContexts.java
@@ -33,6 +33,8 @@
public static final String MAX_SCATTER_GATHER_BYTES_KEY =
"maxScatterGatherBytes";
public static final String DEFAULT_TIMEOUT_KEY = "defaultTimeout";
public static final String CHUNK_PERIOD_KEY = "chunkPeriod";
+ public static final String MAX_BUFFER_SIZE_BYTES = "maxBufferSizeBytes";
+ public static final String QUERY_BUFFERING_TIMEOUT_KEY =
"queryBufferingTimeout";
public static final boolean DEFAULT_BY_SEGMENT = false;
public static final boolean DEFAULT_POPULATE_CACHE = true;
@@ -112,31 +114,52 @@
return query.getContextValue(CHUNK_PERIOD_KEY, "P0D");
}
- public static <T> Query<T> withMaxScatterGatherBytes(Query<T> query, long
maxScatterGatherBytesLimit)
+ private static <T> Query<T> withCustomizedLimit(Query<T> query, String key,
long maxLimit)
{
- Object obj = query.getContextValue(MAX_SCATTER_GATHER_BYTES_KEY);
+ Object obj = query.getContextValue(key);
if (obj == null) {
- return
query.withOverriddenContext(ImmutableMap.of(MAX_SCATTER_GATHER_BYTES_KEY,
maxScatterGatherBytesLimit));
+ return query.withOverriddenContext(ImmutableMap.of(key, maxLimit));
} else {
long curr = ((Number) obj).longValue();
- if (curr > maxScatterGatherBytesLimit) {
- throw new IAE(
- "configured [%s = %s] is more than enforced limit of [%s].",
- MAX_SCATTER_GATHER_BYTES_KEY,
- curr,
- maxScatterGatherBytesLimit
- );
+ if (curr > maxLimit) {
+ String err = "configured [%s = %s] is more than enforced limit of
[%s].";
+ throw new IAE(err, key, curr, maxLimit);
} else {
return query;
}
}
}
+ public static <T> Query<T> withMaxScatterGatherBytes(Query<T> query, long
maxScatterGatherBytesLimit)
+ {
+ return withCustomizedLimit(query, MAX_SCATTER_GATHER_BYTES_KEY,
maxScatterGatherBytesLimit);
+ }
+
public static <T> long getMaxScatterGatherBytes(Query<T> query)
{
return parseLong(query, MAX_SCATTER_GATHER_BYTES_KEY, Long.MAX_VALUE);
}
+ public static <T> Query<T> withMaxBufferSizeBytes(Query<T> query, long
maxBufferSizeBytes)
+ {
+ return withCustomizedLimit(query, MAX_BUFFER_SIZE_BYTES,
maxBufferSizeBytes);
+ }
+
+ public static <T> long getMaxBufferSizeBytes(Query<T> query)
+ {
+ return parseLong(query, MAX_BUFFER_SIZE_BYTES, Long.MAX_VALUE);
+ }
+
+ public static <T> Query<T> withQueryBufferingTimeout(Query<T> query, long
bufferingTimeout)
+ {
+ return
query.withOverriddenContext(ImmutableMap.of(QUERY_BUFFERING_TIMEOUT_KEY,
bufferingTimeout));
+ }
+
+ public static <T> long getQueryBufferingTimeout(Query<T> query)
+ {
+ return parseLong(query, QUERY_BUFFERING_TIMEOUT_KEY,
DEFAULT_TIMEOUT_MILLIS);
+ }
+
public static <T> boolean hasTimeout(Query<T> query)
{
return getTimeout(query) != NO_TIMEOUT;
diff --git a/server/src/main/java/io/druid/client/DirectDruidClient.java
b/server/src/main/java/io/druid/client/DirectDruidClient.java
index 16d01a6be6b..f543dfcd13a 100644
--- a/server/src/main/java/io/druid/client/DirectDruidClient.java
+++ b/server/src/main/java/io/druid/client/DirectDruidClient.java
@@ -123,7 +123,9 @@
{
return (QueryType) QueryContexts.withMaxScatterGatherBytes(
QueryContexts.withDefaultTimeout(
- (Query) query,
+ QueryContexts.withQueryBufferingTimeout(
+ QueryContexts.withMaxBufferSizeBytes((Query) query,
serverConfig.getMaxBufferSizeBytes()),
+ serverConfig.getQueryBufferingTimeout()),
serverConfig.getDefaultQueryTimeout()
),
serverConfig.getMaxScatterGatherBytes()
@@ -214,14 +216,17 @@ public int getNumOpenConnections()
final long requestStartTimeNs = System.nanoTime();
- long timeoutAt = ((Long) context.get(QUERY_FAIL_TIME)).longValue();
- long maxScatterGatherBytes =
QueryContexts.getMaxScatterGatherBytes(query);
+ final long timeoutAt = ((Long) context.get(QUERY_FAIL_TIME)).longValue();
+ final long maxScatterGatherBytes =
QueryContexts.getMaxScatterGatherBytes(query);
+ final long maxBufferSizeBytes =
QueryContexts.getMaxBufferSizeBytes(query);
+ final long queryBufferingTimeout =
QueryContexts.getQueryBufferingTimeout(query);
AtomicLong totalBytesGathered = (AtomicLong)
context.get(QUERY_TOTAL_BYTES_GATHERED);
+ AtomicLong totalBufferedBytes = new AtomicLong(0);
final HttpResponseHandler<InputStream, InputStream> responseHandler =
new HttpResponseHandler<InputStream, InputStream>()
{
private final AtomicLong byteCount = new AtomicLong(0);
- private final BlockingQueue<InputStream> queue = new
LinkedBlockingQueue<>();
+ private final BlockingQueue<BufferedStream> queue = new
LinkedBlockingQueue<>();
private final AtomicBoolean done = new AtomicBoolean(false);
private final AtomicReference<String> fail = new AtomicReference<>();
@@ -241,7 +246,9 @@ public int getNumOpenConnections()
public ClientResponse<InputStream> handleResponse(HttpResponse
response)
{
checkQueryTimeout();
- checkTotalBytesLimit(response.getContent().readableBytes());
+
+ final long totalBytes = response.getContent().readableBytes();
+ checkTotalBytesLimit(totalBytes);
log.debug("Initial response from url[%s] for queryId[%s]", url,
query.getId());
responseStartTimeNs = System.nanoTime();
@@ -257,7 +264,9 @@ public int getNumOpenConnections()
)
);
}
- queue.put(new ChannelBufferInputStream(response.getContent()));
+
+ checkBufferCapacity(totalBytes, queryBufferingTimeout);
+ queue.put(new BufferedStream(new
ChannelBufferInputStream(response.getContent()), totalBytes));
}
catch (final IOException e) {
log.error(e, "Error parsing response context from url [%s]", url);
@@ -277,7 +286,7 @@ public int read() throws IOException
Thread.currentThread().interrupt();
throw Throwables.propagate(e);
}
- byteCount.addAndGet(response.getContent().readableBytes());
+ byteCount.addAndGet(totalBytes);
return ClientResponse.<InputStream>finished(
new SequenceInputStream(
new Enumeration<InputStream>()
@@ -305,9 +314,10 @@ public InputStream nextElement()
}
try {
- InputStream is = queue.poll(checkQueryTimeout(),
TimeUnit.MILLISECONDS);
- if (is != null) {
- return is;
+ BufferedStream stream =
queue.poll(checkQueryTimeout(), TimeUnit.MILLISECONDS);
+ totalBufferedBytes.addAndGet(-stream.getBytes());
+ if (stream.getInputStream() != null) {
+ return stream.getInputStream();
} else {
throw new RE("Query[%s] url[%s] timed out.",
query.getId(), url);
}
@@ -331,12 +341,12 @@ public InputStream nextElement()
final ChannelBuffer channelBuffer = chunk.getContent();
final int bytes = channelBuffer.readableBytes();
-
checkTotalBytesLimit(bytes);
if (bytes > 0) {
try {
- queue.put(new ChannelBufferInputStream(channelBuffer));
+ checkBufferCapacity(bytes, queryBufferingTimeout);
+ queue.put(new BufferedStream(new
ChannelBufferInputStream(channelBuffer), bytes));
}
catch (InterruptedException e) {
log.error(e, "Unable to put finalizing input stream into
Sequence queue for url [%s]", url);
@@ -370,7 +380,7 @@ public InputStream nextElement()
try {
// An empty byte array is put at the end to give the
SequenceInputStream.close() as something to close out
// after done is set to true, regardless of the rest of the
stream's state.
- queue.put(ByteSource.empty().openStream());
+ queue.put(new BufferedStream(ByteSource.empty().openStream(),
0));
}
catch (InterruptedException e) {
log.error(e, "Unable to put finalizing input stream into
Sequence queue for url [%s]", url);
@@ -404,7 +414,8 @@ private void setupResponseReadFailure(String msg, Throwable
th)
{
fail.set(msg);
queue.clear();
- queue.offer(new InputStream()
+ totalBufferedBytes.set(0);
+ queue.offer(new BufferedStream(new InputStream()
{
@Override
public int read() throws IOException
@@ -415,7 +426,7 @@ public int read() throws IOException
throw new IOException(msg);
}
}
- });
+ }, 0));
}
@@ -444,6 +455,28 @@ private void checkTotalBytesLimit(long bytes)
throw new RE(msg);
}
}
+
+ private long getBufferCapacity()
+ {
+ return maxBufferSizeBytes - totalBufferedBytes.get();
+ }
+
+ private void checkBufferCapacity(long requestedBytes, long timeoutMs)
+ {
+ final long startTimeMs = System.currentTimeMillis();
+ while (maxBufferSizeBytes < Long.MAX_VALUE && getBufferCapacity() <
requestedBytes) {
+ if (System.currentTimeMillis() - startTimeMs > timeoutMs) {
+ String msg = StringUtils.format(
+ "Query[%s] url[%s] max buffer limit reached: waiting for
free buffer timeout",
+ query.getId(),
+ url
+ );
+ setupResponseReadFailure(msg, null);
+ throw new RE(msg);
+ }
+ }
+ totalBufferedBytes.addAndGet(requestedBytes);
+ }
};
long timeLeft = timeoutAt - System.currentTimeMillis();
@@ -658,6 +691,28 @@ public void close() throws IOException
}
}
+ private static class BufferedStream
+ {
+ private final InputStream inputStream;
+ private final long bytes;
+
+ public BufferedStream(InputStream is, long size)
+ {
+ inputStream = is;
+ bytes = size;
+ }
+
+ public InputStream getInputStream()
+ {
+ return inputStream;
+ }
+
+ public long getBytes()
+ {
+ return bytes;
+ }
+ }
+
@Override
public String toString()
{
diff --git
a/server/src/main/java/io/druid/server/initialization/ServerConfig.java
b/server/src/main/java/io/druid/server/initialization/ServerConfig.java
index d5db2cd2c53..3fa157b0811 100644
--- a/server/src/main/java/io/druid/server/initialization/ServerConfig.java
+++ b/server/src/main/java/io/druid/server/initialization/ServerConfig.java
@@ -53,6 +53,14 @@
@Min(1)
private long maxScatterGatherBytes = Long.MAX_VALUE;
+ @JsonProperty
+ @Min(1)
+ private long maxBufferSizeBytes = Long.MAX_VALUE;
+
+ @JsonProperty
+ @Min(1)
+ private long queryBufferingTimeout = 300_000;
+
public int getNumThreads()
{
return numThreads;
@@ -83,6 +91,16 @@ public long getMaxScatterGatherBytes()
return maxScatterGatherBytes;
}
+ public long getMaxBufferSizeBytes()
+ {
+ return maxBufferSizeBytes;
+ }
+
+ public long getQueryBufferingTimeout()
+ {
+ return queryBufferingTimeout;
+ }
+
@Override
public boolean equals(Object o)
{
@@ -98,6 +116,8 @@ public boolean equals(Object o)
enableRequestLimit == that.enableRequestLimit &&
defaultQueryTimeout == that.defaultQueryTimeout &&
maxScatterGatherBytes == that.maxScatterGatherBytes &&
+ maxBufferSizeBytes == that.maxBufferSizeBytes &&
+ queryBufferingTimeout == that.queryBufferingTimeout &&
Objects.equals(maxIdleTime, that.maxIdleTime);
}
@@ -105,12 +125,25 @@ public boolean equals(Object o)
public int hashCode()
{
return Objects.hash(
- numThreads,
- queueSize,
- enableRequestLimit,
- maxIdleTime,
- defaultQueryTimeout,
- maxScatterGatherBytes
+ numThreads,
+ maxIdleTime,
+ defaultQueryTimeout,
+ maxScatterGatherBytes,
+ maxBufferSizeBytes,
+ queryBufferingTimeout
);
}
+
+ @Override
+ public String toString()
+ {
+ return "ServerConfig{" +
+ "numThreads=" + numThreads +
+ ", maxIdleTime=" + maxIdleTime +
+ ", defaultQueryTimeout=" + defaultQueryTimeout +
+ ", maxScatterGatherBytes=" + maxScatterGatherBytes +
+ ", maxBufferSizeBytes=" + maxBufferSizeBytes +
+ ", queryBufferingTimeout=" + queryBufferingTimeout +
+ '}';
+ }
}
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]