This is an automated email from the ASF dual-hosted git repository.

karan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new d52391eea22 Update PolicyEnforcer to only validate table based segment. (#18018)
d52391eea22 is described below

commit d52391eea22e18916bacfc5e2e398c2e669f2d99
Author: Cece Mei <[email protected]>
AuthorDate: Mon May 19 20:30:21 2025 -0700

    Update PolicyEnforcer to only validate table based segment. (#18018)
---
 .../druid/msq/sql/MSQTaskQueryMakerTest.java       | 73 ++++++++++++++++++++++
 .../apache/druid/query/policy/PolicyEnforcer.java  | 11 +++-
 .../RestrictAllTablesPolicyEnforcerTest.java       | 33 ++++++++++
 3 files changed, 114 insertions(+), 3 deletions(-)

diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java
index 88236aa63e2..4005f08ef2e 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java
@@ -58,7 +58,10 @@ import org.apache.druid.msq.test.MSQTestTaskActionClient;
 import org.apache.druid.query.DruidProcessingConfig;
 import org.apache.druid.query.Druids;
 import org.apache.druid.query.ForwardingQueryProcessingPool;
+import org.apache.druid.query.InlineDataSource;
 import org.apache.druid.query.JoinDataSource;
+import org.apache.druid.query.LookupDataSource;
+import org.apache.druid.query.OrderBy;
 import org.apache.druid.query.Query;
 import org.apache.druid.query.QueryContext;
 import org.apache.druid.query.QueryDataSource;
@@ -389,6 +392,76 @@ public class MSQTaskQueryMakerTest
     );
   }
 
+
+  @Test
+  public void testInlineDataSourcePassedPolicyValidation() throws Exception
+  {
+    // Arrange
+    policyEnforcer = new RestrictAllTablesPolicyEnforcer(null);
+    RowSignature resultSignature = RowSignature.builder()
+                                               .add("EXPR$0", ColumnType.LONG)
+                                               .build();
+    fieldMapping = buildFieldMapping(resultSignature);
+    InlineDataSource inlineDataSource = InlineDataSource.fromIterable(
+        ImmutableList.of(new Object[]{2L}),
+        resultSignature
+    );
+    Query query = new 
Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
+                                               .dataSource(inlineDataSource)
+                                               .eternityInterval()
+                                               
.columns(resultSignature.getColumnNames())
+                                               
.columnTypes(resultSignature.getColumnTypes())
+                                               .build();
+    DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature);
+    // Act
+    msqTaskQueryMaker = getMSQTaskQueryMaker();
+    QueryResponse<Object[]> response = 
msqTaskQueryMaker.runQuery(druidQueryMock);
+    // Assert
+    String taskId = (String) 
Iterables.getOnlyElement(response.getResults().toList())[0];
+    MSQTaskReportPayload payload = (MSQTaskReportPayload) 
fakeOverlordClient.taskReportAsMap(taskId)
+                                                                            
.get()
+                                                                            
.get(MSQTaskReport.REPORT_KEY)
+                                                                            
.getPayload();
+    Assert.assertTrue(payload.getStatus().getStatus().isSuccess());
+    ImmutableList<Object[]> expectedResults = ImmutableList.of(new 
Object[]{2L});
+    assertResultsEquals("select 1 + 1", expectedResults, 
payload.getResults().getResults());
+  }
+
+  @Test
+  public void testLookupDataSourcePassedPolicyValidation() throws Exception
+  {
+    // Arrange
+    policyEnforcer = new RestrictAllTablesPolicyEnforcer(null);
+    final RowSignature resultSignature = RowSignature.builder().add("v", 
ColumnType.STRING).build();
+    fieldMapping = buildFieldMapping(resultSignature);
+    Query query = new 
Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
+                                               .eternityInterval()
+                                               .dataSource(new 
LookupDataSource("lookyloo"))
+                                               
.columns(resultSignature.getColumnNames())
+                                               
.columnTypes(resultSignature.getColumnTypes())
+                                               
.orderBy(ImmutableList.of(OrderBy.ascending("v")))
+                                               .build();
+    DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature);
+    // Act
+    msqTaskQueryMaker = getMSQTaskQueryMaker();
+    QueryResponse<Object[]> response = 
msqTaskQueryMaker.runQuery(druidQueryMock);
+    // Assert
+    String taskId = (String) 
Iterables.getOnlyElement(response.getResults().toList())[0];
+    MSQTaskReportPayload payload = (MSQTaskReportPayload) 
fakeOverlordClient.taskReportAsMap(taskId)
+                                                                            
.get()
+                                                                            
.get(MSQTaskReport.REPORT_KEY)
+                                                                            
.getPayload();
+    // Assert
+    Assert.assertTrue(payload.getStatus().getStatus().isSuccess());
+    ImmutableList<Object[]> expectedResults = ImmutableList.of(
+        new Object[]{"mysteryvalue"},
+        new Object[]{"x6"},
+        new Object[]{"xa"},
+        new Object[]{"xabc"}
+    );
+    assertResultsEquals("select v from lookyloo", expectedResults, 
payload.getResults().getResults());
+  }
+
   @Test
   public void testJoinFailWithPolicyValidationOnLeftChild() throws Exception
   {
diff --git a/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java b/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java
index 4b34b96c23b..f57bbc367d6 100644
--- a/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java
+++ b/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java
@@ -27,6 +27,7 @@ import org.apache.druid.query.DataSource;
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.segment.ReferenceCountingSegment;
 import org.apache.druid.segment.SegmentReference;
+import org.apache.druid.timeline.SegmentId;
 
 /**
  * Interface for enforcing policies on data sources and segments in Druid 
queries.
@@ -77,14 +78,18 @@ public interface PolicyEnforcer
    */
   default void validateOrElseThrow(ReferenceCountingSegment segment, Policy 
policy) throws DruidException
   {
-    // Validation will always fail on lookups, external, and inline segments, 
because they will not have policies applied (except for NoopPolicyEnforcer).
-    // This is a temporary solution since we don't have a perfect way to 
identify segments that are backed by a regular table yet.
+    SegmentId segmentId = segment.getId();
+    // SegmentId is null if the segment is not table based, or is already 
closed
+    if (segmentId == null) {
+      return;
+    }
+
     if (validate(policy)) {
       return;
     }
     throw DruidException.forPersona(DruidException.Persona.OPERATOR)
                         .ofCategory(DruidException.Category.FORBIDDEN)
-                        .build("Failed security validation with segment [%s]", 
segment.getId());
+                        .build("Failed security validation with segment [%s]", 
segmentId);
   }
 
   /**
diff --git a/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java b/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java
index 86c597d6e83..e5a579524c5 100644
--- a/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java
+++ b/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java
@@ -23,13 +23,17 @@ import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
 import org.apache.druid.error.DruidException;
 import org.apache.druid.java.util.common.Intervals;
+import org.apache.druid.java.util.common.guava.Sequences;
+import org.apache.druid.query.InlineDataSource;
 import org.apache.druid.query.RestrictedDataSource;
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.filter.NullFilter;
 import org.apache.druid.segment.ReferenceCountingSegment;
+import org.apache.druid.segment.RowBasedSegment;
 import org.apache.druid.segment.Segment;
 import org.apache.druid.segment.TestHelper;
 import org.apache.druid.segment.TestSegmentUtils.SegmentForTesting;
+import org.apache.druid.segment.column.RowSignature;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -97,6 +101,35 @@ public class RestrictAllTablesPolicyEnforcerTest
     policyEnforcer.validateOrElseThrow(segment, policy);
   }
 
+  @Test
+  public void test_validate_allowNonTableSegments() throws Exception
+  {
+    final RestrictAllTablesPolicyEnforcer policyEnforcer = new 
RestrictAllTablesPolicyEnforcer(null);
+
+    // Test validate segment, success for inline segment
+    final InlineDataSource inlineDataSource = 
InlineDataSource.fromIterable(ImmutableList.of(), RowSignature.empty());
+
+    final Segment inlineSegment = new RowBasedSegment<>(
+        Sequences.simple(inlineDataSource.getRows()),
+        inlineDataSource.rowAdapter(),
+        inlineDataSource.getRowSignature()
+    );
+    ReferenceCountingSegment segment = 
ReferenceCountingSegment.wrapRootGenerationSegment(inlineSegment);
+
+    policyEnforcer.validateOrElseThrow(segment, null);
+  }
+
+  @Test
+  public void test_validate_closedSegment() throws Exception
+  {
+    final RestrictAllTablesPolicyEnforcer policyEnforcer = new 
RestrictAllTablesPolicyEnforcer(null);
+    Segment baseSegment = new SegmentForTesting("table", Intervals.ETERNITY, 
"1");
+    ReferenceCountingSegment segment = 
ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment);
+    segment.close();
+
+    policyEnforcer.validateOrElseThrow(segment, null);
+  }
+
   @Test
   public void test_validate_withAllowedPolicies() throws Exception
   {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to