This is an automated email from the ASF dual-hosted git repository.
karan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new d52391eea22 Update PolicyEnforcer to only validate table based segment. (#18018)
d52391eea22 is described below
commit d52391eea22e18916bacfc5e2e398c2e669f2d99
Author: Cece Mei <[email protected]>
AuthorDate: Mon May 19 20:30:21 2025 -0700
Update PolicyEnforcer to only validate table based segment. (#18018)
---
.../druid/msq/sql/MSQTaskQueryMakerTest.java | 73 ++++++++++++++++++++++
.../apache/druid/query/policy/PolicyEnforcer.java | 11 +++-
.../RestrictAllTablesPolicyEnforcerTest.java | 33 ++++++++++
3 files changed, 114 insertions(+), 3 deletions(-)
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java
index 88236aa63e2..4005f08ef2e 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java
@@ -58,7 +58,10 @@ import org.apache.druid.msq.test.MSQTestTaskActionClient;
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.Druids;
import org.apache.druid.query.ForwardingQueryProcessingPool;
+import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.JoinDataSource;
+import org.apache.druid.query.LookupDataSource;
+import org.apache.druid.query.OrderBy;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryDataSource;
@@ -389,6 +392,76 @@ public class MSQTaskQueryMakerTest
);
}
+
+ @Test
+ public void testInlineDataSourcePassedPolicyValidation() throws Exception
+ {
+ // Arrange
+ policyEnforcer = new RestrictAllTablesPolicyEnforcer(null);
+ RowSignature resultSignature = RowSignature.builder()
+ .add("EXPR$0", ColumnType.LONG)
+ .build();
+ fieldMapping = buildFieldMapping(resultSignature);
+ InlineDataSource inlineDataSource = InlineDataSource.fromIterable(
+ ImmutableList.of(new Object[]{2L}),
+ resultSignature
+ );
+ Query query = new
Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
+ .dataSource(inlineDataSource)
+ .eternityInterval()
+
.columns(resultSignature.getColumnNames())
+
.columnTypes(resultSignature.getColumnTypes())
+ .build();
+ DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature);
+ // Act
+ msqTaskQueryMaker = getMSQTaskQueryMaker();
+ QueryResponse<Object[]> response =
msqTaskQueryMaker.runQuery(druidQueryMock);
+ // Assert
+ String taskId = (String)
Iterables.getOnlyElement(response.getResults().toList())[0];
+ MSQTaskReportPayload payload = (MSQTaskReportPayload)
fakeOverlordClient.taskReportAsMap(taskId)
+
.get()
+
.get(MSQTaskReport.REPORT_KEY)
+
.getPayload();
+ Assert.assertTrue(payload.getStatus().getStatus().isSuccess());
+ ImmutableList<Object[]> expectedResults = ImmutableList.of(new
Object[]{2L});
+ assertResultsEquals("select 1 + 1", expectedResults,
payload.getResults().getResults());
+ }
+
+ @Test
+ public void testLookupDataSourcePassedPolicyValidation() throws Exception
+ {
+ // Arrange
+ policyEnforcer = new RestrictAllTablesPolicyEnforcer(null);
+ final RowSignature resultSignature = RowSignature.builder().add("v",
ColumnType.STRING).build();
+ fieldMapping = buildFieldMapping(resultSignature);
+ Query query = new
Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
+ .eternityInterval()
+ .dataSource(new
LookupDataSource("lookyloo"))
+
.columns(resultSignature.getColumnNames())
+
.columnTypes(resultSignature.getColumnTypes())
+
.orderBy(ImmutableList.of(OrderBy.ascending("v")))
+ .build();
+ DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature);
+ // Act
+ msqTaskQueryMaker = getMSQTaskQueryMaker();
+ QueryResponse<Object[]> response =
msqTaskQueryMaker.runQuery(druidQueryMock);
+ // Assert
+ String taskId = (String)
Iterables.getOnlyElement(response.getResults().toList())[0];
+ MSQTaskReportPayload payload = (MSQTaskReportPayload)
fakeOverlordClient.taskReportAsMap(taskId)
+
.get()
+
.get(MSQTaskReport.REPORT_KEY)
+
.getPayload();
+ // Assert
+ Assert.assertTrue(payload.getStatus().getStatus().isSuccess());
+ ImmutableList<Object[]> expectedResults = ImmutableList.of(
+ new Object[]{"mysteryvalue"},
+ new Object[]{"x6"},
+ new Object[]{"xa"},
+ new Object[]{"xabc"}
+ );
+ assertResultsEquals("select v from lookyloo", expectedResults,
payload.getResults().getResults());
+ }
+
@Test
public void testJoinFailWithPolicyValidationOnLeftChild() throws Exception
{
diff --git a/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java b/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java
index 4b34b96c23b..f57bbc367d6 100644
--- a/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java
+++ b/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java
@@ -27,6 +27,7 @@ import org.apache.druid.query.DataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.segment.ReferenceCountingSegment;
import org.apache.druid.segment.SegmentReference;
+import org.apache.druid.timeline.SegmentId;
/**
* Interface for enforcing policies on data sources and segments in Druid
queries.
@@ -77,14 +78,18 @@ public interface PolicyEnforcer
*/
default void validateOrElseThrow(ReferenceCountingSegment segment, Policy
policy) throws DruidException
{
- // Validation will always fail on lookups, external, and inline segments,
because they will not have policies applied (except for NoopPolicyEnforcer).
- // This is a temporary solution since we don't have a perfect way to
identify segments that are backed by a regular table yet.
+ SegmentId segmentId = segment.getId();
+ // SegmentId is null if the segment is not table based, or is already
closed
+ if (segmentId == null) {
+ return;
+ }
+
if (validate(policy)) {
return;
}
throw DruidException.forPersona(DruidException.Persona.OPERATOR)
.ofCategory(DruidException.Category.FORBIDDEN)
- .build("Failed security validation with segment [%s]",
segment.getId());
+ .build("Failed security validation with segment [%s]",
segmentId);
}
/**
diff --git a/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java b/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java
index 86c597d6e83..e5a579524c5 100644
--- a/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java
+++ b/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java
@@ -23,13 +23,17 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import org.apache.druid.error.DruidException;
import org.apache.druid.java.util.common.Intervals;
+import org.apache.druid.java.util.common.guava.Sequences;
+import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.RestrictedDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.filter.NullFilter;
import org.apache.druid.segment.ReferenceCountingSegment;
+import org.apache.druid.segment.RowBasedSegment;
import org.apache.druid.segment.Segment;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.TestSegmentUtils.SegmentForTesting;
+import org.apache.druid.segment.column.RowSignature;
import org.junit.Assert;
import org.junit.Test;
@@ -97,6 +101,35 @@ public class RestrictAllTablesPolicyEnforcerTest
policyEnforcer.validateOrElseThrow(segment, policy);
}
+ @Test
+ public void test_validate_allowNonTableSegments() throws Exception
+ {
+ final RestrictAllTablesPolicyEnforcer policyEnforcer = new
RestrictAllTablesPolicyEnforcer(null);
+
+ // Test validate segment, success for inline segment
+ final InlineDataSource inlineDataSource =
InlineDataSource.fromIterable(ImmutableList.of(), RowSignature.empty());
+
+ final Segment inlineSegment = new RowBasedSegment<>(
+ Sequences.simple(inlineDataSource.getRows()),
+ inlineDataSource.rowAdapter(),
+ inlineDataSource.getRowSignature()
+ );
+ ReferenceCountingSegment segment =
ReferenceCountingSegment.wrapRootGenerationSegment(inlineSegment);
+
+ policyEnforcer.validateOrElseThrow(segment, null);
+ }
+
+ @Test
+ public void test_validate_closedSegment() throws Exception
+ {
+ final RestrictAllTablesPolicyEnforcer policyEnforcer = new
RestrictAllTablesPolicyEnforcer(null);
+ Segment baseSegment = new SegmentForTesting("table", Intervals.ETERNITY,
"1");
+ ReferenceCountingSegment segment =
ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment);
+ segment.close();
+
+ policyEnforcer.validateOrElseThrow(segment, null);
+ }
+
@Test
public void test_validate_withAllowedPolicies() throws Exception
{
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]