This is an automated email from the ASF dual-hosted git repository.

zhangliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new bd205373d26 add unit test for audit algorithm and checker (#18824)
bd205373d26 is described below

commit bd205373d2621fd4ec22ff23fadf52bb984bf58e
Author: natehuang <[email protected]>
AuthorDate: Tue Jul 5 10:59:01 2022 +0800

    add unit test for audit algorithm and checker (#18824)
---
 ...ardingConditionsShardingAuditAlgorithmTest.java |  99 ++++++++++++++++++++
 .../sharding/checker/ShardingAuditCheckerTest.java | 101 +++++++++++++++++++++
 2 files changed, 200 insertions(+)

diff --git 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/algorithm/audit/DMLShardingConditionsShardingAuditAlgorithmTest.java
 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/algorithm/audit/DMLShardingConditionsShardingAuditAlgorithmTest.java
new file mode 100644
index 00000000000..5b4dff7879c
--- /dev/null
+++ 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/algorithm/audit/DMLShardingConditionsShardingAuditAlgorithmTest.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.sharding.algorithm.audit;
+
+import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
+import org.apache.shardingsphere.infra.check.SQLCheckResult;
+import 
org.apache.shardingsphere.infra.config.algorithm.ShardingSphereAlgorithmConfiguration;
+import 
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
+import org.apache.shardingsphere.infra.metadata.user.Grantee;
+import 
org.apache.shardingsphere.sharding.factory.ShardingAuditAlgorithmFactory;
+import org.apache.shardingsphere.sharding.rule.ShardingRule;
+import org.apache.shardingsphere.sharding.spi.ShardingAuditAlgorithm;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.ddl.DDLStatement;
+import 
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DMLStatement;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.Properties;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertThat;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public final class DMLShardingConditionsShardingAuditAlgorithmTest {
+    
+    private SQLStatementContext sqlStatementContext;
+    
+    private ShardingSphereDatabase database;
+    
+    private ShardingRule rule;
+    
+    private ShardingAuditAlgorithm shardingAuditAlgorithm;
+    
+    @Before
+    public void setUp() {
+        shardingAuditAlgorithm = ShardingAuditAlgorithmFactory.newInstance(new 
ShardingSphereAlgorithmConfiguration("DML_SHARDING_CONDITIONS", new 
Properties()));
+        sqlStatementContext = mock(SQLStatementContext.class, 
RETURNS_DEEP_STUBS);
+        database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
+        rule = mock(ShardingRule.class);
+        
when(sqlStatementContext.getTablesContext().getTableNames()).thenReturn(Collections.singletonList("t_order"));
+    }
+    
+    @Test
+    public void assertNotDMLStatementCheck() {
+        
when(sqlStatementContext.getSqlStatement()).thenReturn(mock(DDLStatement.class));
+        asserCheckResult(shardingAuditAlgorithm.check(sqlStatementContext, 
Collections.emptyList(), mock(Grantee.class), database), true, "");
+        verify(database, times(0)).getRuleMetaData();
+    }
+    
+    @Test
+    public void assertAllBroadcastTablesCheck() {
+        
when(sqlStatementContext.getSqlStatement()).thenReturn(mock(DMLStatement.class));
+        
when(database.getRuleMetaData().findRules(ShardingRule.class)).thenReturn(Collections.singletonList(rule));
+        
when(rule.isAllBroadcastTables(sqlStatementContext.getTablesContext().getTableNames())).thenReturn(true);
+        asserCheckResult(shardingAuditAlgorithm.check(sqlStatementContext, 
Collections.emptyList(), mock(Grantee.class), database), true, "");
+    }
+    
+    @Test
+    public void assertNotAllShardingTablesCheck() {
+        
when(sqlStatementContext.getSqlStatement()).thenReturn(mock(DMLStatement.class));
+        
when(database.getRuleMetaData().findRules(ShardingRule.class)).thenReturn(Collections.singletonList(rule));
+        
when(rule.isAllBroadcastTables(sqlStatementContext.getTablesContext().getTableNames())).thenReturn(false);
+        asserCheckResult(shardingAuditAlgorithm.check(sqlStatementContext, 
Collections.emptyList(), mock(Grantee.class), database), true, "");
+    }
+    
+    @Test
+    public void assertEmptyShardingConditionsCheck() {
+        
when(sqlStatementContext.getSqlStatement()).thenReturn(mock(DMLStatement.class));
+        
when(database.getRuleMetaData().findRules(ShardingRule.class)).thenReturn(Collections.singletonList(rule));
+        
when(rule.isAllBroadcastTables(sqlStatementContext.getTablesContext().getTableNames())).thenReturn(false);
+        when(rule.isShardingTable("t_order")).thenReturn(true);
+        asserCheckResult(shardingAuditAlgorithm.check(sqlStatementContext, 
Collections.emptyList(), mock(Grantee.class), database), false, "Not allow DML 
operation without sharding conditions");
+    }
+    
+    private void asserCheckResult(final SQLCheckResult checkResult, final 
boolean isPassed, final String errorMessage) {
+        assertThat(checkResult.isPassed(), is(isPassed));
+        assertThat(checkResult.getErrorMessage(), is(errorMessage));
+    }
+}
diff --git 
a/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/checker/ShardingAuditCheckerTest.java
 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/checker/ShardingAuditCheckerTest.java
new file mode 100644
index 00000000000..9dbcfb9205c
--- /dev/null
+++ 
b/shardingsphere-features/shardingsphere-sharding/shardingsphere-sharding-core/src/test/java/org/apache/shardingsphere/sharding/checker/ShardingAuditCheckerTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.sharding.checker;
+
+import 
org.apache.shardingsphere.infra.binder.statement.CommonSQLStatementContext;
+import org.apache.shardingsphere.infra.check.SQLCheckResult;
+import 
org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
+import org.apache.shardingsphere.infra.metadata.user.Grantee;
+import org.apache.shardingsphere.sharding.checker.audit.ShardingAuditChecker;
+import org.apache.shardingsphere.sharding.rule.ShardingRule;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Answers;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public final class ShardingAuditCheckerTest {
+    
+    @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+    private ShardingRule rule;
+    
+    @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+    private CommonSQLStatementContext<?> sqlStatementContext;
+    
+    @Mock
+    private Grantee grantee;
+    
+    private final ShardingAuditChecker checker = new ShardingAuditChecker();
+    
+    private final Map<String, ShardingSphereDatabase> databases = new 
LinkedHashMap<>();
+    
+    private final List<Object> parameters = Collections.emptyList();
+    
+    @Before
+    public void setUp() {
+        
when(sqlStatementContext.getSqlHintExtractor().findDisableAuditNames()).thenReturn(new
 HashSet<>(Collections.singletonList("audit_algorithm")));
+        
when(rule.getAuditStrategyConfig().getAuditAlgorithmNames()).thenReturn(Collections.singleton("audit_algorithm"));
+        databases.put("foo_db", mock(ShardingSphereDatabase.class));
+    }
+    
+    @Test
+    public void assertCheckSQLStatementPass() {
+        
when(rule.getAuditAlgorithms().get("audit_algorithm").check(sqlStatementContext,
 parameters, grantee, databases.get("foo_db")))
+                .thenReturn(new SQLCheckResult(true, ""));
+        asserCheckResult(checker.check(sqlStatementContext, 
Collections.emptyList(), grantee, "foo_db", databases, rule), true, "");
+        verify(rule.getAuditAlgorithms().get("audit_algorithm"), 
times(1)).check(sqlStatementContext, parameters, grantee, 
databases.get("foo_db"));
+    }
+    
+    @Test
+    public void assertSQCheckPassByDisableAuditNames() {
+        
when(rule.getAuditAlgorithms().get("audit_algorithm").check(sqlStatementContext,
 parameters, grantee, databases.get("foo_db")))
+                .thenReturn(new SQLCheckResult(false, ""));
+        
when(rule.getAuditStrategyConfig().isAllowHintDisable()).thenReturn(true);
+        asserCheckResult(checker.check(sqlStatementContext, 
Collections.emptyList(), grantee, "foo_db", databases, rule), true, "");
+        
+        verify(rule.getAuditAlgorithms().get("audit_algorithm"), 
times(0)).check(sqlStatementContext, parameters, grantee, 
databases.get("foo_db"));
+    }
+    
+    @Test
+    public void assertSQLCheckNotPass() {
+        
when(rule.getAuditAlgorithms().get("audit_algorithm").check(sqlStatementContext,
 parameters, grantee, databases.get("foo_db")))
+                .thenReturn(new SQLCheckResult(false, "Not allow DML operation 
without sharding conditions"));
+        asserCheckResult(checker.check(sqlStatementContext, 
Collections.emptyList(), grantee, "foo_db", databases, rule), false, "Not allow 
DML operation without sharding conditions");
+        verify(rule.getAuditAlgorithms().get("audit_algorithm"), 
times(1)).check(sqlStatementContext, parameters, grantee, 
databases.get("foo_db"));
+    }
+    
+    private void asserCheckResult(final SQLCheckResult checkResult, final 
boolean isPassed, final String errorMessage) {
+        assertThat(checkResult.isPassed(), is(isPassed));
+        assertThat(checkResult.getErrorMessage(), is(errorMessage));
+    }
+}

Reply via email to