This is an automated email from the ASF dual-hosted git repository.

totalo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a7868d931d Add unsupported check for combine statement with encrypt columns (#30916)
6a7868d931d is described below

commit 6a7868d931dc12749533fd66dc740cc914384e1f
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Wed Apr 17 08:41:04 2024 +0800

    Add unsupported check for combine statement with encrypt columns (#30916)
---
 .../generator/EncryptProjectionTokenGenerator.java |  4 ++
 .../token/util/EncryptTokenGeneratorUtils.java     | 51 +++++++++++++++++++---
 .../EncryptProjectionTokenGeneratorTest.java       | 42 +++++++++++++++++-
 3 files changed, 90 insertions(+), 7 deletions(-)

diff --git 
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/EncryptProjectionTokenGenerator.java
 
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/EncryptProjectionTokenGenerator.java
index 553c503151f..45f496d8246 100644
--- 
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/EncryptProjectionTokenGenerator.java
+++ 
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/EncryptProjectionTokenGenerator.java
@@ -20,6 +20,7 @@ package 
org.apache.shardingsphere.encrypt.rewrite.token.generator;
 import lombok.Setter;
 import org.apache.shardingsphere.encrypt.rewrite.aware.DatabaseTypeAware;
 import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptRuleAware;
+import 
org.apache.shardingsphere.encrypt.rewrite.token.util.EncryptTokenGeneratorUtils;
 import org.apache.shardingsphere.encrypt.rule.EncryptRule;
 import org.apache.shardingsphere.encrypt.rule.EncryptTable;
 import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
@@ -90,6 +91,9 @@ public final class EncryptProjectionTokenGenerator implements 
CollectionSQLToken
     }
     
     private void addGenerateSQLTokens(final Collection<SQLToken> sqlTokens, 
final SelectStatementContext selectStatementContext) {
+        ShardingSpherePreconditions.checkState(
+                !selectStatementContext.isContainsCombine() || 
!EncryptTokenGeneratorUtils.isContainsEncryptProjectionInCombineStatement(selectStatementContext,
 encryptRule),
+                () -> new UnsupportedSQLOperationException("Can not support 
encrypt projection in combine statement"));
         for (ProjectionSegment each : 
selectStatementContext.getSqlStatement().getProjections().getProjections()) {
             SubqueryType subqueryType = 
selectStatementContext.getSubqueryType();
             if (each instanceof ColumnProjectionSegment) {
diff --git 
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/util/EncryptTokenGeneratorUtils.java
 
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/util/EncryptTokenGeneratorUtils.java
index b3d03354be4..d676858f033 100644
--- 
a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/util/EncryptTokenGeneratorUtils.java
+++ 
b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/util/EncryptTokenGeneratorUtils.java
@@ -25,13 +25,19 @@ import 
org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
 import org.apache.shardingsphere.encrypt.spi.EncryptAlgorithm;
 import 
org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
 import 
org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
+import 
org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
+import 
org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
+import 
org.apache.shardingsphere.infra.exception.generic.UnsupportedSQLOperationException;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.combine.CombineSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.bounded.ColumnSegmentBoundedInfo;
 import 
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
 
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Iterator;
+import java.util.List;
 
 /**
  * Encrypt token generator utils.
@@ -53,21 +59,21 @@ public final class EncryptTokenGeneratorUtils {
             }
             EncryptAlgorithm leftColumnEncryptor = 
getColumnEncryptor(((ColumnSegment) each.getLeft()).getColumnBoundedInfo(), 
encryptRule);
             EncryptAlgorithm rightColumnEncryptor = 
getColumnEncryptor(((ColumnSegment) each.getRight()).getColumnBoundedInfo(), 
encryptRule);
-            if (!isSameEncryptor(leftColumnEncryptor, rightColumnEncryptor)) {
+            if (isDifferentEncryptor(leftColumnEncryptor, 
rightColumnEncryptor)) {
                 return false;
             }
         }
         return true;
     }
     
-    private static boolean isSameEncryptor(final EncryptAlgorithm 
leftColumnEncryptor, final EncryptAlgorithm rightColumnEncryptor) {
+    private static boolean isDifferentEncryptor(final EncryptAlgorithm 
leftColumnEncryptor, final EncryptAlgorithm rightColumnEncryptor) {
         if (null != leftColumnEncryptor && null != rightColumnEncryptor) {
             if 
(!leftColumnEncryptor.getType().equals(rightColumnEncryptor.getType())) {
-                return false;
+                return true;
             }
-            return leftColumnEncryptor.equals(rightColumnEncryptor);
+            return !leftColumnEncryptor.equals(rightColumnEncryptor);
         }
-        return null == leftColumnEncryptor && null == rightColumnEncryptor;
+        return null != leftColumnEncryptor || null != rightColumnEncryptor;
     }
     
     private static EncryptAlgorithm getColumnEncryptor(final 
ColumnSegmentBoundedInfo columnBoundedInfo, final EncryptRule encryptRule) {
@@ -103,10 +109,43 @@ public final class EncryptTokenGeneratorUtils {
                     ? new ColumnSegmentBoundedInfo(null, null, 
((ColumnProjection) projection).getOriginalTable(), ((ColumnProjection) 
projection).getOriginalColumn())
                     : new ColumnSegmentBoundedInfo(new 
IdentifierValue(projection.getColumnLabel()));
             EncryptAlgorithm rightColumnEncryptor = 
getColumnEncryptor(columnBoundedInfo, encryptRule);
-            if (!isSameEncryptor(leftColumnEncryptor, rightColumnEncryptor)) {
+            if (isDifferentEncryptor(leftColumnEncryptor, 
rightColumnEncryptor)) {
                 return false;
             }
         }
         return true;
     }
+    
+    /**
+     * Judge whether contains encrypt projection in combine statement or not.
+     * 
+     * @param selectStatementContext select statement context
+     * @param encryptRule encrypt rule
+     * @return whether contains encrypt projection in combine statement or not
+     */
+    public static boolean isContainsEncryptProjectionInCombineStatement(final 
SelectStatementContext selectStatementContext, final EncryptRule encryptRule) {
+        if 
(!selectStatementContext.getSqlStatement().getCombine().isPresent()) {
+            return false;
+        }
+        CombineSegment combineSegment = 
selectStatementContext.getSqlStatement().getCombine().get();
+        List<Projection> leftProjections = new 
ArrayList<>(selectStatementContext.getSubqueryContexts().get(combineSegment.getLeft().getStartIndex()).getProjectionsContext().getExpandProjections());
+        List<Projection> rightProjections = new 
ArrayList<>(selectStatementContext.getSubqueryContexts().get(combineSegment.getRight().getStartIndex()).getProjectionsContext().getExpandProjections());
+        ShardingSpherePreconditions.checkState(leftProjections.size() == 
rightProjections.size(), () -> new UnsupportedSQLOperationException("Column 
projections must be same for combine statement"));
+        for (int i = 0; i < leftProjections.size(); i++) {
+            Projection leftProjection = leftProjections.get(i);
+            Projection rightProjection = rightProjections.get(i);
+            ColumnSegmentBoundedInfo leftColumnBoundedInfo = leftProjection 
instanceof ColumnProjection
+                    ? new ColumnSegmentBoundedInfo(null, null, 
((ColumnProjection) leftProjection).getOriginalTable(), ((ColumnProjection) 
leftProjection).getOriginalColumn())
+                    : new ColumnSegmentBoundedInfo(new 
IdentifierValue(leftProjection.getColumnLabel()));
+            ColumnSegmentBoundedInfo rightColumnBoundedInfo = rightProjection 
instanceof ColumnProjection
+                    ? new ColumnSegmentBoundedInfo(null, null, 
((ColumnProjection) rightProjection).getOriginalTable(), ((ColumnProjection) 
rightProjection).getOriginalColumn())
+                    : new ColumnSegmentBoundedInfo(new 
IdentifierValue(rightProjection.getColumnLabel()));
+            EncryptAlgorithm leftColumnEncryptor = 
getColumnEncryptor(leftColumnBoundedInfo, encryptRule);
+            EncryptAlgorithm rightColumnEncryptor = 
getColumnEncryptor(rightColumnBoundedInfo, encryptRule);
+            if (null != leftColumnEncryptor || null != rightColumnEncryptor) {
+                return true;
+            }
+        }
+        return false;
+    }
 }
diff --git 
a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/EncryptProjectionTokenGeneratorTest.java
 
b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/EncryptProjectionTokenGeneratorTest.java
index dc51060c6ba..1354fe1224e 100644
--- 
a/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/EncryptProjectionTokenGeneratorTest.java
+++ 
b/features/encrypt/core/src/test/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/EncryptProjectionTokenGeneratorTest.java
@@ -20,15 +20,18 @@ package 
org.apache.shardingsphere.encrypt.rewrite.token.generator;
 import org.apache.shardingsphere.encrypt.rule.EncryptRule;
 import org.apache.shardingsphere.encrypt.rule.EncryptTable;
 import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
+import org.apache.shardingsphere.encrypt.spi.EncryptAlgorithm;
 import 
org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
 import 
org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext;
 import 
org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
 import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
 import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
+import org.apache.shardingsphere.infra.database.mysql.type.MySQLDatabaseType;
 import 
org.apache.shardingsphere.infra.exception.generic.UnsupportedSQLOperationException;
 import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.SQLToken;
 import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.combine.CombineSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ShorthandProjectionSegment;
@@ -44,6 +47,8 @@ import org.junit.jupiter.api.Test;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
 import java.util.Optional;
 
 import static org.hamcrest.CoreMatchers.is;
@@ -66,7 +71,7 @@ class EncryptProjectionTokenGeneratorTest {
     }
     
     private EncryptRule mockEncryptRule() {
-        EncryptRule result = mock(EncryptRule.class);
+        EncryptRule result = mock(EncryptRule.class, RETURNS_DEEP_STUBS);
         EncryptTable encryptTable1 = mock(EncryptTable.class);
         EncryptTable encryptTable2 = mock(EncryptTable.class);
         
when(encryptTable1.getLogicColumns()).thenReturn(Collections.singleton("mobile"));
@@ -77,6 +82,9 @@ class EncryptProjectionTokenGeneratorTest {
         EncryptColumn encryptColumn = mock(EncryptColumn.class, 
RETURNS_DEEP_STUBS);
         when(encryptColumn.getAssistedQuery()).thenReturn(Optional.empty());
         
when(encryptTable1.getEncryptColumn("mobile")).thenReturn(encryptColumn);
+        when(result.findEncryptTable("t_order").isPresent()).thenReturn(true);
+        
when(result.getEncryptTable("t_order").isEncryptColumn("order_id")).thenReturn(true);
+        
when(result.getEncryptTable("t_order").getEncryptColumn("order_id").getCipher().getEncryptor()).thenReturn(mock(EncryptAlgorithm.class));
         return result;
     }
     
@@ -150,4 +158,36 @@ class EncryptProjectionTokenGeneratorTest {
         
when(sqlStatementContext.getSqlStatement().getProjections().getProjections()).thenReturn(Collections.singleton(new
 ShorthandProjectionSegment(0, 0)));
         assertThrows(UnsupportedSQLOperationException.class, () -> 
generator.generateSQLTokens(sqlStatementContext));
     }
+    
+    @Test
+    void assertGenerateSQLTokensWhenCombineStatementContainsEncryptColumn() {
+        SelectStatementContext sqlStatementContext = 
mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
+        when(sqlStatementContext.isContainsCombine()).thenReturn(true);
+        
when(sqlStatementContext.getSqlStatement().getCombine().isPresent()).thenReturn(true);
+        CombineSegment combineSegment = mock(CombineSegment.class, 
RETURNS_DEEP_STUBS);
+        
when(sqlStatementContext.getSqlStatement().getCombine().get()).thenReturn(combineSegment);
+        ColumnProjection orderIdColumn = new ColumnProjection("o", "order_id", 
null, new MySQLDatabaseType());
+        orderIdColumn.setOriginalTable(new IdentifierValue("t_order"));
+        orderIdColumn.setOriginalColumn(new IdentifierValue("order_id"));
+        ColumnProjection userIdColumn = new ColumnProjection("o", "user_id", 
null, new MySQLDatabaseType());
+        userIdColumn.setOriginalTable(new IdentifierValue("t_order"));
+        userIdColumn.setOriginalColumn(new IdentifierValue("user_id"));
+        SelectStatementContext leftSelectStatementContext = 
mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
+        
when(leftSelectStatementContext.getProjectionsContext().getExpandProjections()).thenReturn(Arrays.asList(orderIdColumn,
 userIdColumn));
+        ColumnProjection merchantIdColumn = new ColumnProjection("m", 
"merchant_id", null, new MySQLDatabaseType());
+        merchantIdColumn.setOriginalTable(new IdentifierValue("t_merchant"));
+        merchantIdColumn.setOriginalColumn(new IdentifierValue("merchant_id"));
+        ColumnProjection merchantNameColumn = new ColumnProjection("m", 
"merchant_name", null, new MySQLDatabaseType());
+        merchantNameColumn.setOriginalTable(new IdentifierValue("t_merchant"));
+        merchantNameColumn.setOriginalColumn(new 
IdentifierValue("merchant_name"));
+        SelectStatementContext rightSelectStatementContext = 
mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
+        
when(rightSelectStatementContext.getProjectionsContext().getExpandProjections()).thenReturn(Arrays.asList(merchantIdColumn,
 merchantNameColumn));
+        Map<Integer, SelectStatementContext> subqueryContexts = new 
LinkedHashMap<>();
+        subqueryContexts.put(0, leftSelectStatementContext);
+        subqueryContexts.put(1, rightSelectStatementContext);
+        
when(sqlStatementContext.getSubqueryContexts()).thenReturn(subqueryContexts);
+        when(combineSegment.getLeft().getStartIndex()).thenReturn(0);
+        when(combineSegment.getRight().getStartIndex()).thenReturn(1);
+        assertThrows(UnsupportedSQLOperationException.class, () -> 
generator.generateSQLTokens(sqlStatementContext));
+    }
 }

Reply via email to