This is an automated email from the ASF dual-hosted git repository.
zhangliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 19b4f727f33 Add test cases of PreviewExecutor (#37999)
19b4f727f33 is described below
commit 19b4f727f3345cca38720f8dbb9c995476b0791e
Author: Liang Zhang
AuthorDate: Wed Feb 11 01:26:13 2026 +0800
Add test cases of PreviewExecutor (#37999)
* Add test cases of PreviewExecutor
---
.codex/skills/gen-ut/SKILL.md | 95 +++++--
CODE_OF_CONDUCT.md | 1 +
docs/community/content/involved/conduct/code.cn.md | 1 +
docs/community/content/involved/conduct/code.en.md | 1 +
.../handler/distsql/rul/PreviewExecutorTest.java | 286 +++++++++++++++++++++
5 files changed, 364 insertions(+), 20 deletions(-)
diff --git a/.codex/skills/gen-ut/SKILL.md b/.codex/skills/gen-ut/SKILL.md
index 318aaa6a941..1ce065eff1c 100644
--- a/.codex/skills/gen-ut/SKILL.md
+++ b/.codex/skills/gen-ut/SKILL.md
@@ -17,6 +17,9 @@ Optional inputs:
- Module name (limits Maven command scope).
- Test class list (for targeted execution only; does not limit in-place updates for related test classes).
+Default completion level:
+- Unless the user explicitly waives or lowers the target, requests such as "add tests" remain bound to `R10-A` completion criteria (including default coverage and quality gates).
+
Missing input handling:
- Note: this section only describes entry handling; final decisions follow `R7`/`R10`.
- Missing target classes: enter `R10-INPUT_BLOCKED`.
@@ -29,6 +32,8 @@ Missing input handling:
- `<ResolvedTestClass>`: one fully-qualified test class or a comma-separated list of test classes.
- `<ResolvedTestFileSet>`: editable file set (space-separated in shell commands), containing only related test files and required test resources.
- `<ResolvedTestModules>`: comma-separated Maven module list used by scoped verification commands.
+- `<ResolvedTargetClasses>`: one fully-qualified production class or a comma-separated list of target classes from user input.
+- `Target-class coverage scope`: for each target class, aggregate coverage for the target binary class and all binary classes whose names start with `<targetBinaryName>$` (including member/anonymous/local classes).
- `Related test classes`: existing `TargetClassName + Test` classes resolvable within the same module's test scope.
- `Assertion differences`: distinguishable assertions in externally observable results or side effects.
- `Necessity reason tag`: fixed-format tag for retention reasons, using `KEEP:<id>:<reason>`, recorded in the "Implementation and Optimization" section of the delivery report.
@@ -73,7 +78,8 @@ Module resolution order:
- Dedicated test targets `MUST` follow the `R4` branch-mapping exclusion scope.
- `R6`: SPI, Mock, and reflection
- - If the class under test can be obtained via SPI, `MUST` instantiate by default with `TypedSPILoader`/`OrderedSPILoader` (or database-specific loaders).
+ - If the class under test can be obtained via SPI, `MUST` instantiate by default with `TypedSPILoader`/`OrderedSPILoader` (or database-specific loaders), and `MUST` keep the resolved instance as a test-class-level field (global variable) by default.
+ - SPI metadata accessor methods `TypedSPI#getType`, `OrderedSPI#getOrder`, and `getTypeClass` `MUST` be treated as no-test-required targets by default, unless the user explicitly requests tests for them.
- If not instantiated via SPI, `MUST` record the reason before implementation.
- Test dependencies `SHOULD` use Mockito mocks by default.
- Reflection access `MUST` use `Plugins.getMemberAccessor()`, and field access only.
@@ -107,11 +113,14 @@ Module resolution order:
- scope satisfies `R3`;
- target test command succeeds, and surefire report has `Tests run > 0` (recommended to also satisfy `Tests run - Skipped > 0`);
- coverage evidence satisfies the target (default class/line/branch 100%, unless explicitly lowered by the user);
+ - each class in `<ResolvedTargetClasses>` has explicit aggregated class-level coverage evidence (CLASS/LINE/BRANCH counters with covered/missed/ratio) over the `Target-class coverage scope`, and all ratios satisfy the declared target;
- Checkstyle, Spotless, and two `R14` scans all pass;
- `R8` analysis and compliance evidence are complete.
- `R10-B` (blocked): under the "production code cannot be changed" constraint, dead code blocks coverage targets, and evidence satisfies `R9`.
- `R10-C` (blocked): failure occurs outside `R3` scope, and evidence satisfies `R11`.
- - Decision priority: `R10-INPUT_BLOCKED` > `R10-B` > `R10-C` > `R10-A`.
+ - `R10-D` (in-progress): none of `R10-INPUT_BLOCKED/R10-B/R10-C/R10-A` is satisfied yet.
+ - Decision priority: `R10-INPUT_BLOCKED` > `R10-B` > `R10-C` > `R10-A` > `R10-D`.
+ - `MUST NOT` conclude the task as completed while in `R10-D`; continue implementation and verification loops until reaching a terminal state.
- `R11`: failure handling
- If failure is within `R3` scope: `MUST` fix within `<ResolvedTestFileSet>` and rerun minimal verification.
@@ -152,7 +161,8 @@ Module resolution order:
7. Complete test implementation or extension according to `R2-R7`.
8. Perform necessity trimming and coverage re-verification according to `R13`.
9. Run verification commands and handle failures by `R11`; execute two `R14` scans.
-10. Decide status by `R10` and output rule-to-evidence mapping.
+10. Decide status by `R10` after verification; if status is `R10-D`, return to Step 5 and continue.
+11. Before final response, run a second `R10` status decision and output `R10=<state>` with rule-to-evidence mapping.
## Verification and Commands
@@ -179,6 +189,53 @@ If the module does not define `jacoco-check@jacoco-check`:
./mvnw <GateModuleFlags> -DskipITs -Djacoco.skip=false test jacoco:report
```
+2.1 Target-class coverage hard gate (default target 100 unless explicitly lowered, aggregated over `Target-class coverage scope`):
+```bash
+bash -lc '
+python3 - <JacocoXmlPath> <TargetRatioPercent> <ResolvedTargetClasses> <<'"'"'PY'"'"'
+import sys
+import xml.etree.ElementTree as ET
+xml_path, target = sys.argv[1], float(sys.argv[2])
+target_classes = [each.strip() for each in sys.argv[3].split(",") if each.strip()]
+if not target_classes:
+    print("[R10] empty target class list")
+    sys.exit(1)
+all_classes = list(ET.parse(xml_path).getroot().iter("class"))
+all_ok = True
+for fqcn in target_classes:
+    class_name = fqcn.replace(".", "/")
+    matched_nodes = [each for each in all_classes if each.get("name") == class_name or each.get("name", "").startswith(class_name + "$")]
+    if not matched_nodes:
+        print(f"[R10] class not found in jacoco.xml: {fqcn}")
+        all_ok = False
+        continue
+    for counter_type in ("CLASS", "LINE", "BRANCH"):
+        covered = 0
+        missed = 0
+        found_counter = False
+        for each in matched_nodes:
+            counter = next((c for c in each.findall("counter") if c.get("type") == counter_type), None)
+            if counter is None:
+                continue
+            found_counter = True
+            covered += int(counter.get("covered"))
+            missed += int(counter.get("missed"))
+        if not found_counter:
+            print(f"[R10] missing {counter_type} counter for {fqcn}")
+            all_ok = False
+            continue
+        total = covered + missed
+        ratio = 100.0 if total == 0 else covered * 100.0 / total
+        print(f"[R10] {fqcn} (+inner) {counter_type} covered={covered} missed={missed} ratio={ratio:.2f}%")
+        if ratio + 1e-9 < target:
+            print(f"[R10] {fqcn} (+inner) {counter_type} ratio {ratio:.2f}% < target {target:.2f}%")
+            all_ok = False
+if not all_ok:
+    sys.exit(1)
+PY
+'
+```
+
+
3. Checkstyle:
```bash
./mvnw <GateModuleFlags> -Pcheck checkstyle:check -DskipTests
@@ -200,43 +257,33 @@ from pathlib import Path
 name_pattern = re.compile(r'name\s*=\s*"\{0\}"')
 token = "@ParameterizedTest"
-
-def collect_violations(path):
+violations = []
+for path in (each for each in sys.argv[1:] if each.endswith(".java")):
     source = Path(path).read_text(encoding="utf-8")
-    violations = []
     pos = 0
     while True:
         token_pos = source.find(token, pos)
         if token_pos < 0:
-            return violations
+            break
         line = source.count("\n", 0, token_pos) + 1
         cursor = token_pos + len(token)
         while cursor < len(source) and source[cursor].isspace():
             cursor += 1
-        if cursor >= len(source) or source[cursor] != "(":
+        if cursor >= len(source) or "(" != source[cursor]:
             violations.append(f"{path}:{line}")
             pos = token_pos + len(token)
             continue
         depth = 1
         end = cursor + 1
         while end < len(source) and depth:
-            if source[end] == "(":
+            if "(" == source[end]:
                 depth += 1
-            elif source[end] == ")":
+            elif ")" == source[end]:
                 depth -= 1
             end += 1
-        if depth:
-            violations.append(f"{path}:{line}")
-            return violations
-        if not name_pattern.search(source[cursor + 1:end - 1]):
+        if depth or not name_pattern.search(source[cursor + 1:end - 1]):
             violations.append(f"{path}:{line}")
         pos = end
-
-violations = []
-for each in sys.argv[1:]:
-    if each.endswith(".java"):
-        violations.extend(collect_violations(each))
-
 if violations:
     print("[R8] @ParameterizedTest must use name = \"{0}\"")
     for each in violations:
@@ -260,3 +307,11 @@ fi'
```bash
git diff --name-only
```
+
+## Final Output Requirements
+
+- `MUST` include a status line `R10=<state>`.
+- `MUST` include aggregated class-level coverage evidence for each class in `<ResolvedTargetClasses>` over the `Target-class coverage scope` (CLASS/LINE/BRANCH counters and ratios).
+- `MUST` include executed commands and exit codes.
+- If `R10` is not `R10-A`, `MUST` explicitly mark the task as not completed and provide blocking reason plus next action.
+- `MUST NOT` use completion wording when `R10` is `R10-D`.
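The `R10` rules in this SKILL.md change amount to a first-match decision over the terminal states, with `R10-D` as the fallback that forces another implementation and verification round. The following is a minimal Python sketch of that logic, assuming the caller has already evaluated which terminal conditions hold. `R10State`, `decide_r10`, `may_report_completed`, and `run_until_terminal` are hypothetical names used only for illustration; they are not part of the skill.

```python
from enum import Enum


class R10State(Enum):
    """Hypothetical model of the R10 states named in SKILL.md."""
    INPUT_BLOCKED = "R10-INPUT_BLOCKED"
    B = "R10-B"  # blocked: dead code under "production code cannot be changed"
    C = "R10-C"  # blocked: failure outside R3 scope
    A = "R10-A"  # completed: all completion gates satisfied
    D = "R10-D"  # in-progress: no other state satisfied yet


# Decision priority: R10-INPUT_BLOCKED > R10-B > R10-C > R10-A > R10-D.
TERMINAL_PRIORITY = (R10State.INPUT_BLOCKED, R10State.B, R10State.C, R10State.A)


def decide_r10(satisfied: set) -> R10State:
    """Return the highest-priority satisfied terminal state, else R10-D."""
    for state in TERMINAL_PRIORITY:
        if state in satisfied:
            return state
    return R10State.D


def may_report_completed(state: R10State) -> bool:
    """Completion wording is only allowed for R10-A."""
    return state is R10State.A


def run_until_terminal(verify_round, max_rounds: int = 10) -> R10State:
    """Repeat rounds while the decided state is R10-D.

    `verify_round` is a hypothetical callable that performs one round
    (the skill's Step 5 onward) and returns the set of satisfied terminal
    states. `max_rounds` is a safety bound added for this sketch only.
    """
    state = R10State.D
    for _ in range(max_rounds):
        state = decide_r10(verify_round())
        if state is not R10State.D:
            break
    return state
```

The `max_rounds` bound is an assumption of the sketch; the skill text itself says to keep looping until a terminal state is reached.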
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index 9b4ff2e1c41..0ae53c72852 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -59,6 +59,7 @@ The following code of conduct is based on full compliance with the [Apache Softw
- Access control for classes and methods should be minimal.
- Private methods used by a method should immediately follow that method. If there are multiple private methods, they should be written in the same order as they appear in the original method.
- Method parameters and return values are not allowed to be `null`.
+- Method parameters must not use `Optional`; pass plain values (nullable when needed).
- Prefer using lombok instead of constructors, getter, setter methods and log variables.
- Do not leave fully-qualified class names inline; add import statements instead.
- Consider using `LinkedList` first, only use `ArrayList` when you need to get element values from the collection by index.
diff --git a/docs/community/content/involved/conduct/code.cn.md b/docs/community/content/involved/conduct/code.cn.md
index 9503778bac4..965cecab520 100644
--- a/docs/community/content/involved/conduct/code.cn.md
+++ b/docs/community/content/involved/conduct/code.cn.md
@@ -63,6 +63,7 @@ chapter = true
- 类和方法的访问权限控制为最小。
- 方法所用到的私有方法应紧跟该方法,如果有多个私有方法,书写私有方法应与私有方法在原方法的出现顺序相同。
- 方法入参和返回值不允许为 `null`。
+ - 方法入参禁止使用 `Optional`;应传递普通值(必要时允许为 `null`)。
- 优先使用 lombok 代替构造器,getter, setter 方法和 log 变量。
- 禁止内联全限定类名,必须通过 import 引入。
- 优先考虑使用 `LinkedList`,只有在需要通过下标获取集合中元素值时再使用 `ArrayList`。
diff --git a/docs/community/content/involved/conduct/code.en.md b/docs/community/content/involved/conduct/code.en.md
index c6ed3c74809..5ef3b827c9d 100644
--- a/docs/community/content/involved/conduct/code.en.md
+++ b/docs/community/content/involved/conduct/code.en.md
@@ -63,6 +63,7 @@ The following code of conduct is based on full compliance with the [Apache Softw
- Access control for classes and methods should be minimal.
- Private methods used by a method should immediately follow that method. If there are multiple private methods, they should be written in the same order as they appear in the original method.
- Method parameters and return values are not allowed to be `null`.
+- Method parameters must not use `Optional`; pass plain values (nullable when needed).
- Prefer using lombok instead of constructors, getter, setter methods and log variables.
- Do not leave fully-qualified class names inline; add import statements instead.
- Consider using `LinkedList` first, only use `ArrayList` when you need to get element values from the collection by index.
diff --git a/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/distsql/rul/PreviewExecutorTest.java b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/distsql/rul/PreviewExecutorTest.java
new file mode 100644
index 00000000000..1a1b958dbd6
--- /dev/null
+++ b/proxy/backend/core/src/test/java/org/apache/shardingsphere/proxy/backend/handler/distsql/rul/PreviewExecutorTest.java
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.proxy.backend.handler.distsql.rul;
+
+import org.apache.shardingsphere.database.connector.core.type.DatabaseType;
+import org.apache.shardingsphere.database.exception.core.exception.syntax.sql.DialectSQLParsingException;
+import org.apache.shardingsphere.distsql.handler.engine.DistSQLConnectionContext;
+import org.apache.shardingsphere.distsql.handler.engine.query.DistSQLQueryExecutor;
+import org.apache.shardingsphere.distsql.statement.type.rul.sql.PreviewStatement;
+import org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext;
+import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
+import org.apache.shardingsphere.infra.binder.context.statement.type.ddl.CursorHeldSQLStatementContext;
+import org.apache.shardingsphere.infra.binder.context.statement.type.ddl.CursorStatementContext;
+import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
+import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
+import org.apache.shardingsphere.infra.connection.kernel.KernelProcessor;
+import org.apache.shardingsphere.infra.datasource.pool.props.domain.DataSourcePoolProperties;
+import org.apache.shardingsphere.infra.exception.kernel.metadata.rule.EmptyRuleException;
+import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
+import org.apache.shardingsphere.infra.executor.sql.context.ExecutionUnit;
+import org.apache.shardingsphere.infra.executor.sql.context.SQLUnit;
+import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
+import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
+import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor;
+import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback;
+import org.apache.shardingsphere.infra.executor.sql.execute.result.ExecuteResult;
+import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DatabaseConnectionManager;
+import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DriverExecutionPrepareEngine;
+import org.apache.shardingsphere.infra.executor.sql.prepare.driver.ExecutorStatementManager;
+import org.apache.shardingsphere.infra.hint.HintValueContext;
+import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
+import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
+import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
+import org.apache.shardingsphere.infra.metadata.database.resource.node.StorageNode;
+import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
+import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
+import org.apache.shardingsphere.infra.route.context.RouteContext;
+import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
+import org.apache.shardingsphere.infra.session.query.QueryContext;
+import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
+import org.apache.shardingsphere.mode.manager.ContextManager;
+import org.apache.shardingsphere.parser.rule.SQLParserRule;
+import org.apache.shardingsphere.parser.rule.builder.DefaultSQLParserRuleConfigurationBuilder;
+import org.apache.shardingsphere.proxy.backend.connector.ProxyDatabaseConnectionManager;
+import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
+import org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.cursor.CursorNameSegment;
+import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;
+import org.apache.shardingsphere.sql.parser.statement.core.statement.attribute.SQLStatementAttributes;
+import org.apache.shardingsphere.sql.parser.statement.core.statement.attribute.type.CursorSQLStatementAttribute;
+import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
+import org.apache.shardingsphere.sqlfederation.context.SQLFederationContext;
+import org.apache.shardingsphere.sqlfederation.engine.SQLFederationEngine;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.MockedConstruction;
+import org.mockito.invocation.InvocationOnMock;
+
+import javax.sql.DataSource;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockConstruction;
+import static org.mockito.Mockito.when;
+
+class PreviewExecutorTest {
+
+ private final PreviewExecutor executor = (PreviewExecutor) TypedSPILoader.getService(DistSQLQueryExecutor.class, PreviewStatement.class);
+
+ private final DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "MySQL");
+
+ private ContextManager contextManager;
+
+ @BeforeEach
+ void setUp() {
+ contextManager = mockContextManager();
+ ProxyContext.init(contextManager);
+ }
+
+ private ContextManager mockContextManager() {
+ ContextManager result = mock(ContextManager.class, RETURNS_DEEP_STUBS);
+ when(result.getMetaDataContexts().getMetaData()).thenReturn(new ShardingSphereMetaData(Collections.emptyList(), mock(),
+ new RuleMetaData(Collections.singleton(new SQLParserRule(new DefaultSQLParserRuleConfigurationBuilder().build()))), new ConfigurationProperties(new Properties())));
+ return result;
+ }
+
+ @Test
+ void assertGetColumnNames() {
+ assertThat(executor.getColumnNames(new PreviewStatement("SELECT 1")), is(Arrays.asList("data_source_name", "actual_sql")));
+ }
+
+ @Test
+ void assertGetRowsWithInvalidSQL() {
+ ShardingSphereDatabase database = mock(ShardingSphereDatabase.class);
+ when(database.getProtocolType()).thenReturn(databaseType);
+ executor.setDatabase(database);
+ assertThrows(DialectSQLParsingException.class, () -> executor.getRows(new PreviewStatement("invalid sql"), contextManager));
+ }
+
+ @Test
+ void assertGetRowsWithIncompleteDatabase() {
+ ShardingSphereDatabase database = mock(ShardingSphereDatabase.class);
+ when(database.getProtocolType()).thenReturn(databaseType);
+ when(database.getName()).thenReturn("foo_db");
+ executor.setDatabase(database);
+ HintValueContext hintValueContext = new HintValueContext();
+ executor.setConnectionContext(mockConnectionContext(hintValueContext, new ConnectionContext(Collections::emptyList), mock(DatabaseConnectionManager.class)));
+ assertThrows(EmptyRuleException.class, () -> executor.getRows(new PreviewStatement("SELECT 1"), contextManager));
+ assertTrue(hintValueContext.isSkipMetadataValidate());
+ }
+
+ @Test
+ void assertGetRowsWithCursorAttributeAndNotCursorHeld() {
+ HintValueContext hintValueContext = new HintValueContext();
+ executor.setDatabase(mockCompleteDatabase());
+ executor.setConnectionContext(mockConnectionContext(hintValueContext, new ConnectionContext(Collections::emptyList), mock(DatabaseConnectionManager.class)));
+ SQLStatement sqlStatement = mockSQLStatement(new CursorSQLStatementAttribute(null));
+ SQLStatementContext sqlStatementContext = mock(SQLStatementContext.class, RETURNS_DEEP_STUBS);
+ when(sqlStatementContext.getSqlStatement()).thenReturn(sqlStatement);
+ when(sqlStatementContext.getTablesContext().getSchemaName()).thenReturn(Optional.empty());
+ ExecutionContext executionContext = new ExecutionContext(mock(QueryContext.class), Collections.singletonList(createExecutionUnit("foo_ds", "SELECT 1")), mock(RouteContext.class));
+ try (
+ MockedConstruction<SQLBindEngine> ignoredBindEngine =
+ mockConstruction(SQLBindEngine.class, (mock, context) -> when(mock.bind(any(SQLStatement.class))).thenReturn(sqlStatementContext));
+ MockedConstruction<JDBCExecutor> ignoredJDBCExecutor = mockConstruction(JDBCExecutor.class);
+ MockedConstruction<SQLFederationEngine> ignoredFederationEngine =
+ mockConstruction(SQLFederationEngine.class, (mock, context) -> when(mock.decide(any(QueryContext.class), any(RuleMetaData.class))).thenReturn(false));
+ MockedConstruction<KernelProcessor> ignoredKernelProcessor = mockConstruction(KernelProcessor.class,
+ (mock, context) -> when(mock.generateExecutionContext(any(QueryContext.class), any(RuleMetaData.class), any(ConfigurationProperties.class))).thenReturn(executionContext))) {
+ assertThat(executor.getRows(new PreviewStatement("SELECT 1"), contextManager).iterator().next().getCell(1), is("foo_ds"));
+ assertTrue(hintValueContext.isSkipMetadataValidate());
+ }
+ }
+
+ @Test
+ void assertGetRowsWithCursorHeldAndNoCursorName() {
+ HintValueContext hintValueContext = new HintValueContext();
+ executor.setDatabase(mockCompleteDatabase());
+ executor.setConnectionContext(mockConnectionContext(hintValueContext, new ConnectionContext(Collections::emptyList), mock(DatabaseConnectionManager.class)));
+ SQLStatement sqlStatement = mockSQLStatement(new CursorSQLStatementAttribute(null));
+ CursorHeldSQLStatementContext cursorHeldSQLStatementContext = new CursorHeldSQLStatementContext(sqlStatement);
+ ExecutionContext executionContext = new ExecutionContext(mock(QueryContext.class), Collections.singletonList(createExecutionUnit("foo_ds", "SELECT 1")), mock(RouteContext.class));
+ try (
+ MockedConstruction<SQLBindEngine> ignoredBindEngine =
+ mockConstruction(SQLBindEngine.class, (mock, context) -> when(mock.bind(any(SQLStatement.class))).thenReturn(cursorHeldSQLStatementContext));
+ MockedConstruction<JDBCExecutor> ignoredJDBCExecutor = mockConstruction(JDBCExecutor.class);
+ MockedConstruction<SQLFederationEngine> ignoredFederationEngine =
+ mockConstruction(SQLFederationEngine.class, (mock, context) -> when(mock.decide(any(QueryContext.class), any(RuleMetaData.class))).thenReturn(false));
+ MockedConstruction<KernelProcessor> ignoredKernelProcessor = mockConstruction(KernelProcessor.class,
+ (mock, context) -> when(mock.generateExecutionContext(any(QueryContext.class), any(RuleMetaData.class), any(ConfigurationProperties.class))).thenReturn(executionContext))) {
+ assertThat(executor.getRows(new PreviewStatement("SELECT 1"), contextManager).iterator().next().getCell(2), is("SELECT 1"));
+ assertNull(cursorHeldSQLStatementContext.getCursorStatementContext());
+ }
+ }
+
+ @Test
+ void assertGetRowsWithCursorHeldAndCursorNameWithFederation() {
+ ConnectionContext connectionContext = new ConnectionContext(Collections::emptyList);
+ CursorStatementContext cursorStatementContext = mock(CursorStatementContext.class);
+ TablesContext tablesContext = mock(TablesContext.class);
+ when(tablesContext.getSchemaName()).thenReturn(Optional.of("foo_schema"));
+ when(cursorStatementContext.getTablesContext()).thenReturn(tablesContext);
+ connectionContext.getCursorContext().getCursorStatementContexts().put("foo_cursor", cursorStatementContext);
+ ProxyDatabaseConnectionManager databaseConnectionManager = mock(ProxyDatabaseConnectionManager.class, RETURNS_DEEP_STUBS);
+ when(databaseConnectionManager.getConnectionSession().getProcessId()).thenReturn("process_id");
+ HintValueContext hintValueContext = new HintValueContext();
+ executor.setDatabase(mockCompleteDatabase());
+ executor.setConnectionContext(mockConnectionContext(hintValueContext, connectionContext, databaseConnectionManager));
+ SQLStatement sqlStatement = mockSQLStatement(new CursorSQLStatementAttribute(new CursorNameSegment(0, 0, new IdentifierValue("FOO_CURSOR"))));
+ CursorHeldSQLStatementContext cursorHeldSQLStatementContext = new CursorHeldSQLStatementContext(sqlStatement);
+ AtomicReference<String> actualSchemaName = new AtomicReference<>();
+ try (
+ MockedConstruction<SQLBindEngine> ignoredBindEngine =
+ mockConstruction(SQLBindEngine.class, (mock, context) -> when(mock.bind(any(SQLStatement.class))).thenReturn(cursorHeldSQLStatementContext));
+ MockedConstruction<JDBCExecutor> ignoredJDBCExecutor = mockConstruction(JDBCExecutor.class);
+ MockedConstruction<DriverExecutionPrepareEngine> ignoredPrepareEngine = mockConstruction(DriverExecutionPrepareEngine.class);
+ MockedConstruction<SQLFederationEngine> ignored4 = mockConstruction(SQLFederationEngine.class, (mock, context) -> configureSQLFederationEngine(mock, context, actualSchemaName))) {
+ assertThat(executor.getRows(new PreviewStatement("SELECT 1"), contextManager).iterator().next().getCell(1), is("bar_ds"));
+ assertThat(actualSchemaName.get(), is("foo_schema"));
+ assertThat(cursorHeldSQLStatementContext.getCursorStatementContext(), is(cursorStatementContext));
+ }
+ }
+
+ private ShardingSphereDatabase mockCompleteDatabase() {
+ ShardingSphereDatabase result = mock(ShardingSphereDatabase.class);
+ when(result.getProtocolType()).thenReturn(databaseType);
+ when(result.getName()).thenReturn("foo_db");
+ when(result.isComplete()).thenReturn(true);
+ when(result.getResourceMetaData()).thenReturn(createResourceMetaData());
+ when(result.getRuleMetaData()).thenReturn(new RuleMetaData(Collections.emptyList()));
+ return result;
+ }
+
+ private ResourceMetaData createResourceMetaData() {
+ Map<String, StorageUnit> storageUnits = new LinkedHashMap<>(2, 1F);
+ storageUnits.put("foo_ds", createStorageUnit("foo_ds", "jdbc:mysql://localhost:3306/foo_db"));
+ storageUnits.put("bar_ds", createStorageUnit("bar_ds", "jdbc:postgresql://localhost:5432/bar_db"));
+ return new ResourceMetaData(Collections.emptyMap(), storageUnits);
+ }
+
+ private StorageUnit createStorageUnit(final String name, final String url) {
+ Map<String, Object> props = new LinkedHashMap<>(2, 1F);
+ props.put("url", url);
+ props.put("username", "root");
+ return new StorageUnit(new StorageNode(name), new DataSourcePoolProperties("com.zaxxer.hikari.HikariDataSource", props), mock(DataSource.class));
+ }
+
+ @SuppressWarnings("rawtypes")
+ private DistSQLConnectionContext mockConnectionContext(final HintValueContext hintValueContext, final ConnectionContext connectionContext,
+ final DatabaseConnectionManager databaseConnectionManager) {
+ QueryContext queryContext = mock(QueryContext.class);
+ when(queryContext.getHintValueContext()).thenReturn(hintValueContext);
+ when(queryContext.getConnectionContext()).thenReturn(connectionContext);
+ return new DistSQLConnectionContext(queryContext, 1, databaseType, databaseConnectionManager, mock(ExecutorStatementManager.class));
+ }
+
+ private SQLStatement mockSQLStatement(final CursorSQLStatementAttribute cursorSQLStatementAttribute) {
+ SQLStatement result = mock(SQLStatement.class);
+ when(result.getDatabaseType()).thenReturn(databaseType);
+ when(result.getAttributes()).thenReturn(new SQLStatementAttributes(cursorSQLStatementAttribute));
+ return result;
+ }
+
+ private ExecutionUnit createExecutionUnit(final String dataSourceName, final String sql) {
+ return new ExecutionUnit(dataSourceName, new SQLUnit(sql, Collections.emptyList()));
+ }
+
+ private JDBCExecutionUnit createJDBCExecutionUnit(final String dataSourceName, final String sql, final boolean isThrowSQLException) throws SQLException {
+ Statement statement = mock(Statement.class);
+ if (isThrowSQLException) {
+ when(statement.executeQuery(sql)).thenThrow(new SQLException("mock exception"));
+ } else {
+ when(statement.executeQuery(sql)).thenReturn(mock(ResultSet.class));
+ }
+ return new JDBCExecutionUnit(createExecutionUnit(dataSourceName, sql), ConnectionMode.MEMORY_STRICTLY, statement);
+ }
+
+ private void configureSQLFederationEngine(final SQLFederationEngine federationEngine, final MockedConstruction.Context context, final AtomicReference<String> actualSchemaName) {
+ actualSchemaName.set(context.arguments().get(1).toString());
+ when(federationEngine.decide(any(QueryContext.class), any(RuleMetaData.class))).thenReturn(true);
+ doAnswer(this::executePreviewCallback).when(federationEngine).executeQuery(any(), any(), any(SQLFederationContext.class));
+ }
+
+ private Object executePreviewCallback(final InvocationOnMock invocation) throws SQLException {
+ SQLFederationContext sqlFederationContext = invocation.getArgument(2);
+ JDBCExecutorCallback<? extends ExecuteResult> callback = invocation.getArgument(1);
+ callback.execute(Collections.singletonList(createJDBCExecutionUnit("foo_ds", "SELECT 2", false)), true, "process_id");
+ try {
+ callback.execute(Collections.singletonList(createJDBCExecutionUnit("bar_ds", "SELECT 3", true)), true, "process_id");
+ } catch (final SQLException ignored) {
+ }
+ sqlFederationContext.getPreviewExecutionUnits().add(createExecutionUnit("bar_ds", "SELECT 2"));
+ return null;
+ }
+}