This is an automated email from the ASF dual-hosted git repository.
zhangliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 7c7a2f939a5 Upgrade gen-ut skills (#38428)
7c7a2f939a5 is described below
commit 7c7a2f939a52c8547a913c655aa980a8380954ea
Author: Liang Zhang <[email protected]>
AuthorDate: Sat Mar 14 17:01:41 2026 +0800
Upgrade gen-ut skills (#38428)
* Upgrade gen-ut skills
* Upgrade gen-ut skills
* Upgrade gen-ut skills
* Upgrade gen-ut skills
* Upgrade gen-ut skills
* Upgrade gen-ut skills
---
.codex/skills/gen-ut/SKILL.md | 533 +++----------------
.../gen-ut/scripts/collect_quality_baseline.py | 187 +++++++
.codex/skills/gen-ut/scripts/run_quality_gates.py | 286 ++++++++++
.codex/skills/gen-ut/scripts/scan_quality_rules.py | 587 +++++++++++++++++++++
.../gen-ut/scripts/verification_gate_state.py | 232 ++++++++
.../infra/datanode/DataNodeTest.java | 311 ++++-------
.../infra/hint/HintValueContextTest.java | 107 ++--
7 files changed, 1541 insertions(+), 702 deletions(-)
diff --git a/.codex/skills/gen-ut/SKILL.md b/.codex/skills/gen-ut/SKILL.md
index 8fbd7773108..9e70b00116c 100644
--- a/.codex/skills/gen-ut/SKILL.md
+++ b/.codex/skills/gen-ut/SKILL.md
@@ -37,6 +37,11 @@ Missing input handling:
- `Related test classes`: existing `TargetClassName + Test` classes resolvable within the same module's test scope.
- `Assertion differences`: distinguishable assertions in externally observable results or side effects.
- `Necessity reason tag`: fixed-format tag for retention reasons, using `KEEP:<id>:<reason>`, recorded in the "Implementation and Optimization" section of the delivery report.
+- `Baseline quality summary`: one pre-edit diagnostic run that combines rule scanning, candidate summary, and coverage evidence for the current scope.
+- `Verification snapshot digest`: content hash over `<ResolvedTestFileSet>` used to decide whether a previous green verification result is still reusable.
+- `Gate reuse state`: persisted mapping from logical gate names (for example `target-test`, `coverage`, `rule-scan`) to the latest green digest for that gate.
+- `Latest green target-test digest`: compatibility alias for the `target-test` entry in `Gate reuse state`.
+- `Consolidated hard-gate scan`: one script execution that enforces `R8`, `R14`, and all file-content-based `R15` rules while still reporting results per rule.
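The `Verification snapshot digest` above is only specified as a content hash over `<ResolvedTestFileSet>`. A minimal sketch of one way such a digest could be computed (sorted paths plus file bytes via `hashlib`; the bundled `verification_gate_state.py` may hash differently) is:

```python
import hashlib
from pathlib import Path


def snapshot_digest(paths: list[str]) -> str:
    """Hash file paths and contents so any edit to the test file set changes the digest."""
    digest = hashlib.sha256()
    for path in sorted(paths):
        # Include the path itself plus a separator so renames also change the digest.
        digest.update(path.encode("utf-8"))
        digest.update(b"\0")
        digest.update(Path(path).read_bytes())
        digest.update(b"\0")
    return digest.hexdigest()
```

Sorting the paths makes the digest independent of argument order, so the same file set always maps to the same reuse key.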
Module resolution order:
1. If the user explicitly provides modules, use them first.
@@ -183,14 +188,23 @@ Module resolution order:
- Capture scope baseline once: `git status --porcelain > /tmp/gen-ut-status-before.txt`.
2. Parse target classes, related test classes, and input-blocked state (`R10-INPUT_BLOCKED`).
3. Resolve `<ResolvedTestClass>`, `<ResolvedTestFileSet>`, `<ResolvedTestModules>`, and record `pom.xml` evidence (`R3`).
-4. Decide whether `R12` is triggered; if not, output `R4` branch mapping.
-5. Execute `R8` parameterized optimization analysis, output `R8-CANDIDATES`, and apply required refactoring.
-6. Execute `R9` dead-code checks and record evidence.
-7. Complete test implementation or extension according to `R2-R7`.
-8. Perform necessity trimming and coverage re-verification according to `R13`.
-9. Run verification commands and handle failures by `R11`; execute two `R14` scans and all required `R15` scans.
-10. Decide status by `R10` after verification; if status is `R10-D`, return to Step 5 and continue.
-11. Before final response, run a second `R10` status decision and output `R10=<state>` with rule-to-evidence mapping.
+4. Run a `Baseline quality summary` using the bundled baseline script unless equivalent evidence was just produced in the same turn.
+   - Use the baseline summary to identify current branch-miss lines, existing `R15` risks, and likely `R8-CANDIDATES` before editing.
+5. Decide whether `R12` is triggered; if not, output `R4` branch mapping.
+6. Execute `R8` parameterized optimization analysis, output `R8-CANDIDATES`, and apply required refactoring.
+7. Execute `R9` dead-code checks and record evidence.
+8. Complete test implementation or extension according to `R2-R7`.
+9. Perform necessity trimming and coverage re-verification according to `R13`.
+10. After each edit batch, recompute the `Verification snapshot digest`; during in-scope repair loops, prefer `target test + one consolidated hard-gate scan` as the minimal verification required by `R11`.
+    - After any standalone target-test command succeeds, `SHOULD` persist the digest through `scripts/verification_gate_state.py mark-gate-green --gate target-test`.
+11. Run final verification commands and handle failures by `R11`.
+    - Independent final gates (`coverage`, `checkstyle`, `spotless`, `consolidated hard-gate scan`) `SHOULD` run in parallel when the environment allows; otherwise serialize them.
+    - Prefer the bundled `scripts/run_quality_gates.py` runner so independent gates share one orchestration entry and can reuse gate-level green results from `Gate reuse state`.
+    - If `scripts/verification_gate_state.py match-gate-green --gate target-test` reports a match for the current `<ResolvedTestFileSet>`, and the final coverage command re-executes tests on that same digest, `MAY` skip an extra standalone target-test rerun before delivery.
+    - A previously green `coverage` gate `MAY` be reused for the same digest; `checkstyle` and `spotless` `SHOULD` still execute for the current module scope.
+    - The consolidated hard-gate scan `MUST` be executed twice to satisfy `R14`: once after implementation stabilizes and once immediately before delivery. Only the earlier scan may be reused for diagnostics; the delivery scan must execute again.
+12. Decide status by `R10` after verification; if status is `R10-D`, return to Step 5 and continue.
+13. Before final response, run a second `R10` status decision and output `R10=<state>` with rule-to-evidence mapping.
## Verification and Commands
@@ -203,10 +217,35 @@ Flag presets:
- `<GateModuleFlags>` = `-pl <ResolvedTestModules>`
- `<FallbackGateModuleFlags>` = `<GateModuleFlags> -am` (for troubleshooting missing cross-module dependencies only; does not change `R3` and `R10`).
+0. Baseline quality summary (recommended before editing):
+```bash
+python3 scripts/collect_quality_baseline.py --workdir <RepoRoot> \
+  --coverage-command "./mvnw <GateModuleFlags> -DskipITs -Dsurefire.useManifestOnlyJar=false -Dtest=<ResolvedTestClass> -DfailIfNoTests=true -Dsurefire.failIfNoSpecifiedTests=false -Djacoco.skip=false -Djacoco.append=false -Djacoco.destFile=/tmp/gen-ut-baseline.exec test jacoco:report -Djacoco.dataFile=/tmp/gen-ut-baseline.exec" \
+  --jacoco-xml-path <JacocoXmlPath> \
+  --target-classes <ResolvedTargetClasses> \
+  --baseline-before /tmp/gen-ut-status-before.txt \
+  <ResolvedTestFileSet>
+```
+The baseline script reuses `scan_quality_rules.py` diagnostics and prints current coverage plus branch-miss lines for each target class.
+
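The branch-miss lines the baseline script prints come from JaCoCo's XML report. A hedged, standalone sketch of that extraction (assuming the standard `sourcefile`/`line` nodes with an `mb` missed-branches attribute; this is not the bundled script itself):

```python
import xml.etree.ElementTree as ET


def branch_miss_lines(jacoco_xml: str, source_name: str) -> list[int]:
    """Return line numbers whose missed-branch counter (mb) is non-zero."""
    root = ET.fromstring(jacoco_xml)
    for sourcefile in root.iter("sourcefile"):
        if sourcefile.get("name") == source_name:
            # JaCoCo emits one <line> element per executable source line.
            return [int(line.get("nr")) for line in sourcefile.findall("line")
                    if int(line.get("mb", "0")) > 0]
    return []
```

Feeding these line numbers back into the edit loop is what lets branch gaps be targeted before the first coverage gate runs.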
1. Target tests:
```bash
./mvnw <TestModuleFlags> -DskipITs -Dspotless.skip=true -Dtest=<ResolvedTestClass> -DfailIfNoTests=true -Dsurefire.failIfNoSpecifiedTests=false test
```
+After a green standalone target-test command, record the digest:
+```bash
+python3 scripts/verification_gate_state.py mark-gate-green --state-file /tmp/gen-ut-gate-state.json --gate target-test <ResolvedTestFileSet>
+```
+
+1.1 Verification snapshot digest:
+```bash
+python3 scripts/verification_gate_state.py digest <ResolvedTestFileSet>
+```
+
+1.2 Latest green target-test digest reuse check:
+```bash
+python3 scripts/verification_gate_state.py match-gate-green --state-file /tmp/gen-ut-gate-state.json --gate target-test <ResolvedTestFileSet>
+```
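The `mark-gate-green`/`match-gate-green` pair above amounts to persisting one digest per logical gate name. A minimal illustrative sketch with a plain JSON state file (the on-disk shape of the real `verification_gate_state.py` is an assumption here):

```python
import json
from pathlib import Path


def mark_gate_green(state_file: str, gate: str, digest: str) -> None:
    """Record the latest green digest for a logical gate name."""
    path = Path(state_file)
    state = json.loads(path.read_text()) if path.exists() else {}
    state[gate] = digest
    path.write_text(json.dumps(state))


def match_gate_green(state_file: str, gate: str, digest: str) -> bool:
    """True only when the gate's last green digest equals the current one."""
    path = Path(state_file)
    if not path.exists():
        return False
    return json.loads(path.read_text()).get(gate) == digest
```

A mismatch (or missing entry) simply means the gate must run again; the state file never substitutes for an actual red-to-green verification.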
2. Coverage:
```bash
@@ -216,6 +255,10 @@ If the module does not define `jacoco-check@jacoco-check`:
```bash
./mvnw <GateModuleFlags> -DskipITs -Djacoco.skip=false test jacoco:report
```
+After a green standalone coverage command, the digest may be recorded for reuse:
+```bash
+python3 scripts/verification_gate_state.py mark-gate-green --state-file /tmp/gen-ut-gate-state.json --gate coverage <ResolvedTestFileSet>
+```
2.1 Target-class coverage hard gate (default target 100 unless explicitly lowered, aggregated over `Target-class coverage scope`):
```bash
@@ -275,472 +318,42 @@ PY
```
If missing cross-module dependencies occur, rerun the gate command above once with `<FallbackGateModuleFlags>` and record the trigger reason and result.
-5. `R8` parameterized compliance scan (annotation block parsing):
+4.1 Unified final-gate runner (recommended):
```bash
-bash -lc '
-python3 - <ResolvedTestFileSet> <<'"'"'PY'"'"'
-import re
-import sys
-from pathlib import Path
-
-name_pattern = re.compile(r'name\s*=\s*"\{0\}"')
-token = "@ParameterizedTest"
-violations = []
-for path in (each for each in sys.argv[1:] if each.endswith(".java")):
- source = Path(path).read_text(encoding="utf-8")
- pos = 0
- while True:
- token_pos = source.find(token, pos)
- if token_pos < 0:
- break
- line = source.count("\n", 0, token_pos) + 1
- cursor = token_pos + len(token)
- while cursor < len(source) and source[cursor].isspace():
- cursor += 1
- if cursor >= len(source) or "(" != source[cursor]:
- violations.append(f"{path}:{line}")
- pos = token_pos + len(token)
- continue
- depth = 1
- end = cursor + 1
- while end < len(source) and depth:
- if "(" == source[end]:
- depth += 1
- elif ")" == source[end]:
- depth -= 1
- end += 1
- if depth or not name_pattern.search(source[cursor + 1:end - 1]):
- violations.append(f"{path}:{line}")
- pos = end
-if violations:
- print("[R8] @ParameterizedTest must use name = \"{0}\"")
- for each in violations:
- print(each)
- sys.exit(1)
-PY
-'
+python3 scripts/run_quality_gates.py --workdir <RepoRoot> \
+  --state-file /tmp/gen-ut-gate-state.json \
+  --tracked-path <ResolvedTestFileSet> \
+  --reuse-gate coverage \
+  --record-gate coverage \
+  --record-gate hard-gate=rule-scan \
+  --gate coverage="./mvnw <GateModuleFlags> -DskipITs -Djacoco.skip=false test jacoco:report" \
+  --gate checkstyle="./mvnw <GateModuleFlags> -Pcheck checkstyle:check -DskipTests" \
+  --gate spotless="./mvnw <GateModuleFlags> -Pcheck spotless:check -DskipTests" \
+  --gate hard-gate="python3 scripts/scan_quality_rules.py --baseline-before /tmp/gen-ut-status-before.txt <ResolvedTestFileSet>"
+```
+If the environment cannot or should not parallelize, rerun the same command with `--serial`.
+Coverage remains the authoritative source for target-class counters, and the runner does not relax any gate.
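Running independent gates in parallel, as the runner above recommends, can be sketched with a thread pool over shell commands (illustrative only; the flag handling and digest-reuse logic of the real `run_quality_gates.py` are not reproduced):

```python
import subprocess
from concurrent.futures import ThreadPoolExecutor


def run_gates(gates: dict[str, str], serial: bool = False) -> dict[str, int]:
    """Run each named gate command; return a mapping of gate name -> exit code."""
    def run(item):
        name, command = item
        completed = subprocess.run(command, shell=True, capture_output=True)
        return name, completed.returncode

    if serial:
        # Fallback when the environment cannot or should not parallelize.
        return dict(run(each) for each in gates.items())
    with ThreadPoolExecutor(max_workers=len(gates)) as pool:
        return dict(pool.map(run, gates.items()))
```

Because the gates are independent, a non-zero exit from one gate does not mask the others; the caller still sees every gate's result and can repair all failures in one pass.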
-5.1 `R15-A` high-fit candidate enforcement scan (shape-based):
+5. Consolidated hard-gate scan (`R8`, `R14`, `R15-A/B/C/D/E/F/G/H/I`):
```bash
-bash -lc '
-python3 - <ResolvedTestFileSet> <<'"'"'PY'"'"'
-import re
-import sys
-from pathlib import Path
-from collections import defaultdict
-
-IGNORE = {"assertThat", "assertTrue", "assertFalse", "mock", "when", "verify", "is", "not"}
-
-def extract_block(text, brace_index):
- depth = 0
- i = brace_index
- while i < len(text):
- if "{" == text[i]:
- depth += 1
- elif "}" == text[i]:
- depth -= 1
- if 0 == depth:
- return text[brace_index + 1:i]
- i += 1
- return ""
-
-decl = re.compile(r"(?:@Test|@ParameterizedTest(?:\\([^)]*\\))?)\\s+void\\s+(assert\\w+)\\s*\\([^)]*\\)\\s*\\{", re.S)
-call = re.compile(r"\\b\\w+\\.(\\w+)\\s*\\(")
-param_targets = defaultdict(set)
-plain_target_count = defaultdict(lambda: defaultdict(int))
-for path in (each for each in sys.argv[1:] if each.endswith(".java")):
- source = Path(path).read_text(encoding="utf-8")
- for match in decl.finditer(source):
- brace_index = source.find("{", match.start())
- body = extract_block(source, brace_index)
- methods = [each for each in call.findall(body) if each not in IGNORE]
- if not methods:
- continue
- target = methods[0]
- header = source[max(0, match.start() - 160):match.start()]
- if "@ParameterizedTest" in header:
- param_targets[path].add(target)
- else:
- plain_target_count[path][target] += 1
-violations = []
-for path, each_counter in plain_target_count.items():
- for method_name, count in each_counter.items():
- if count >= 3 and method_name not in param_targets[path]:
- violations.append(f"{path}: method={method_name}
nonParameterizedCount={count}")
-if violations:
- print("[R15-A] high-fit candidate likely exists but no parameterized test
found:")
- for each in violations:
- print(each)
- sys.exit(1)
-PY
-'
+python3 scripts/scan_quality_rules.py --baseline-before /tmp/gen-ut-status-before.txt <ResolvedTestFileSet>
```
-
-5.2 `R15-D` parameterized minimum arguments scan:
+If the user explicitly requested metadata accessor tests in the current turn:
```bash
-bash -lc '
-python3 - <ResolvedTestFileSet> <<'"'"'PY'"'"'
-import re
-import sys
-from pathlib import Path
-
-PARAM_METHOD_PATTERN =
re.compile(r"@ParameterizedTest(?:\\s*\\([^)]*\\))?\\s*((?:@\\w+(?:\\s*\\([^)]*\\))?\\s*)*)void\\s+(assert\\w+)\\s*\\(",
re.S)
-METHOD_SOURCE_PATTERN = re.compile(r"@MethodSource(?:\\s*\\(([^)]*)\\))?")
-METHOD_DECL_PATTERN =
re.compile(r"(?:private|protected|public)?\\s*(?:static\\s+)?[\\w$<>\\[\\],
?]+\\s+(\\w+)\\s*\\([^)]*\\)\\s*\\{", re.S)
-ARGUMENT_ROW_PATTERN = re.compile(r"\\b(?:Arguments\\.of|arguments)\\s*\\(")
-
-def extract_block(text, brace_index):
- depth = 0
- index = brace_index
- while index < len(text):
- if "{" == text[index]:
- depth += 1
- elif "}" == text[index]:
- depth -= 1
- if 0 == depth:
- return text[brace_index + 1:index]
- index += 1
- return ""
-
-def parse_method_sources(method_name, annotation_block):
- resolved = []
- matches = list(METHOD_SOURCE_PATTERN.finditer(annotation_block))
- if not matches:
- return resolved
- for each in matches:
- raw = each.group(1)
- if raw is None or not raw.strip():
- resolved.append(method_name)
- continue
- raw = raw.strip()
- normalized = re.sub(r"\\bvalue\\s*=\\s*", "", raw)
- names = re.findall(r'"([^"]+)"', normalized)
- for name in names:
- # Ignore external references such as "pkg.Class#method"; they are
unresolved in this scan.
- resolved.append(name.split("#", 1)[-1])
- return resolved
-
-violations = []
-for path in (each for each in sys.argv[1:] if each.endswith(".java")):
- source = Path(path).read_text(encoding="utf-8")
- method_bodies = {}
- for match in METHOD_DECL_PATTERN.finditer(source):
- method_name = match.group(1)
- brace_index = source.find("{", match.start())
- if brace_index < 0:
- continue
- method_bodies[method_name] = extract_block(source, brace_index)
- for match in PARAM_METHOD_PATTERN.finditer(source):
- annotation_block = match.group(1)
- method_name = match.group(2)
- line = source.count("\\n", 0, match.start()) + 1
- source_methods = parse_method_sources(method_name, annotation_block)
- if not source_methods:
- violations.append(f"{path}:{line} method={method_name} missing
@MethodSource")
- continue
- total_rows = 0
- unresolved = []
- for provider in source_methods:
- body = method_bodies.get(provider)
- if body is None:
- unresolved.append(provider)
- continue
- total_rows += len(ARGUMENT_ROW_PATTERN.findall(body))
- if unresolved:
- violations.append(f"{path}:{line} method={method_name}
unresolvedProviders={','.join(unresolved)}")
- continue
- if total_rows < 3:
- violations.append(f"{path}:{line} method={method_name}
argumentsRows={total_rows}")
-if violations:
- print("[R15-D] each @ParameterizedTest must have >= 3 Arguments rows from
@MethodSource")
- for each in violations:
- print(each)
- sys.exit(1)
-PY
-'
+python3 scripts/scan_quality_rules.py --allow-metadata-accessor-tests --baseline-before /tmp/gen-ut-status-before.txt <ResolvedTestFileSet>
```
-
-5.3 `R15-E` parameterized first-parameter scan:
+The script consolidates repeated file parsing and git-diff inspection without changing rule accuracy. It also evaluates `R15-C` by comparing the current git status against `/tmp/gen-ut-status-before.txt`.
+For machine-readable automation or quick summaries, the script also supports:
```bash
-bash -lc '
-python3 - <ResolvedTestFileSet> <<'"'"'PY'"'"'
-import re
-import sys
-from pathlib import Path
-
-PARAM_METHOD_PATTERN =
re.compile(r"@ParameterizedTest(?:\\s*\\([^)]*\\))?\\s*(?:@\\w+(?:\\s*\\([^)]*\\))?\\s*)*void\\s+(assert\\w+)\\s*\\(([^)]*)\\)",
re.S)
-violations = []
-for path in (each for each in sys.argv[1:] if each.endswith(".java")):
- source = Path(path).read_text(encoding="utf-8")
- for match in PARAM_METHOD_PATTERN.finditer(source):
- method_name = match.group(1)
- params = match.group(2).strip()
- line = source.count("\\n", 0, match.start()) + 1
- if not params:
- violations.append(f"{path}:{line} method={method_name}
missingParameters")
- continue
- first_param = params.split(",", 1)[0].strip()
- normalized = re.sub(r"\\s+", " ", first_param)
- if "final String name" != normalized:
- violations.append(f"{path}:{line} method={method_name}
firstParam={first_param}")
-if violations:
- print("[R15-E] each @ParameterizedTest method must declare first parameter
as `final String name`")
- for each in violations:
- print(each)
- sys.exit(1)
-PY
-'
+python3 scripts/scan_quality_rules.py --json --baseline-before /tmp/gen-ut-status-before.txt <ResolvedTestFileSet>
+python3 scripts/scan_quality_rules.py --summary-only --baseline-before /tmp/gen-ut-status-before.txt <ResolvedTestFileSet>
```
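The `R15-C` evaluation the consolidated scan performs reduces to a set difference between two `git status --porcelain` captures. A standalone sketch of that comparison (assuming porcelain v1 entries, i.e. two status characters plus a space before the path):

```python
def production_paths_touched(before: str, after: str) -> list[str]:
    """Paths newly modified since the baseline capture that live under src/main."""
    introduced = set(after.splitlines()) - set(before.splitlines())
    violations = []
    for entry in sorted(introduced):
        # Porcelain v1: "XY <path>" -> strip the two status chars and separator.
        path = entry[3:].strip()
        if "/src/main/" in path or path.startswith("src/main/"):
            violations.append(path)
    return violations
```

Any non-empty result means a production path moved out of the test-only scope, which is exactly the hard failure the rule guards against.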
-5.4 `R15-F` parameterized switch ban scan:
-```bash
-bash -lc '
-python3 - <ResolvedTestFileSet> <<'"'"'PY'"'"'
-import re
-import sys
-from pathlib import Path
-
-PARAM_METHOD_PATTERN =
re.compile(r"@ParameterizedTest(?:\\s*\\([^)]*\\))?\\s*(?:@\\w+(?:\\s*\\([^)]*\\))?\\s*)*void\\s+(assert\\w+)\\s*\\([^)]*\\)\\s*\\{",
re.S)
-SWITCH_PATTERN = re.compile(r"\\bswitch\\s*\\(")
-
-def extract_block(text, brace_index):
- depth = 0
- index = brace_index
- while index < len(text):
- if "{" == text[index]:
- depth += 1
- elif "}" == text[index]:
- depth -= 1
- if 0 == depth:
- return text[brace_index + 1:index]
- index += 1
- return ""
-
-violations = []
-for path in (each for each in sys.argv[1:] if each.endswith(".java")):
- source = Path(path).read_text(encoding="utf-8")
- for match in PARAM_METHOD_PATTERN.finditer(source):
- method_name = match.group(1)
- line = source.count("\\n", 0, match.start()) + 1
- brace_index = source.find("{", match.start())
- if brace_index < 0:
- continue
- body = extract_block(source, brace_index)
- if SWITCH_PATTERN.search(body):
- violations.append(f"{path}:{line} method={method_name}")
-if violations:
- print("[R15-F] @ParameterizedTest method body must not contain switch")
- for each in violations:
- print(each)
- sys.exit(1)
-PY
-'
-```
-
-5.5 `R15-G` parameterized nested-type ban scan (diff-based):
-```bash
-bash -lc '
-python3 - <ResolvedTestFileSet> <<'"'"'PY'"'"'
-import re
-import subprocess
-import sys
-from pathlib import Path
-
-TYPE_DECL_PATTERN =
re.compile(r"^\+\s+(?:(?:public|protected|private|static|final|abstract)\s+)*(class|interface|enum|record)\b")
-violations = []
-for path in (each for each in sys.argv[1:] if each.endswith(".java")):
- source = Path(path).read_text(encoding="utf-8")
- if "@ParameterizedTest" not in source:
- continue
- diff = subprocess.run(["git", "diff", "-U0", "--", path], check=True,
capture_output=True, text=True).stdout.splitlines()
- for line in diff:
- if line.startswith("+++") or line.startswith("@@"):
- continue
- if TYPE_DECL_PATTERN.search(line):
- violations.append(f"{path}: {line[1:].strip()}")
-if violations:
- print("[R15-G] parameterized tests must not introduce nested helper type
declarations")
- for each in violations:
- print(each)
- sys.exit(1)
-PY
-'
-```
-
-5.6 `R15-I` parameterized Consumer ban scan:
-```bash
-bash -lc '
-python3 - <ResolvedTestFileSet> <<'"'"'PY'"'"'
-import re
-import sys
-from pathlib import Path
-
-PARAM_METHOD_PATTERN =
re.compile(r"@ParameterizedTest(?:\\s*\\([^)]*\\))?\\s*(?:@\\w+(?:\\s*\\([^)]*\\))?\\s*)*void\\s+(assert\\w+)\\s*\\(([^)]*)\\)",
re.S)
-METHOD_SOURCE_PATTERN = re.compile(r"@MethodSource(?:\\s*\\(([^)]*)\\))?")
-METHOD_DECL_PATTERN =
re.compile(r"(?:private|protected|public)?\\s*(?:static\\s+)?[\\w$<>\\[\\],
?]+\\s+(\\w+)\\s*\\([^)]*\\)\\s*\\{", re.S)
-CONSUMER_TOKEN_PATTERN = re.compile(r"\\bConsumer\\s*(?:<|\\b)")
-
-def extract_block(text, brace_index):
- depth = 0
- index = brace_index
- while index < len(text):
- if "{" == text[index]:
- depth += 1
- elif "}" == text[index]:
- depth -= 1
- if 0 == depth:
- return text[brace_index + 1:index]
- index += 1
- return ""
-
-def parse_method_sources(method_name, source, method_start):
- header = source[max(0, method_start - 320):method_start]
- matches = list(METHOD_SOURCE_PATTERN.finditer(header))
- if not matches:
- return []
- resolved = []
- for each in matches:
- raw = each.group(1)
- if raw is None or not raw.strip():
- resolved.append(method_name)
- continue
- normalized = re.sub(r"\\bvalue\\s*=\\s*", "", raw.strip())
- for name in re.findall(r'"([^"]+)"', normalized):
- resolved.append(name.split("#", 1)[-1])
- return resolved
-
-violations = []
-for path in (each for each in sys.argv[1:] if each.endswith(".java")):
- source = Path(path).read_text(encoding="utf-8")
- if "@ParameterizedTest" not in source:
- continue
- method_bodies = {}
- for match in METHOD_DECL_PATTERN.finditer(source):
- method_name = match.group(1)
- brace_index = source.find("{", match.start())
- if brace_index < 0:
- continue
- method_bodies[method_name] = extract_block(source, brace_index)
- for match in PARAM_METHOD_PATTERN.finditer(source):
- method_name = match.group(1)
- params = match.group(2)
- line = source.count("\\n", 0, match.start()) + 1
- if CONSUMER_TOKEN_PATTERN.search(params):
- violations.append(f"{path}:{line} method={method_name}
reason=consumerInParameterizedMethodSignature")
- provider_names = parse_method_sources(method_name, source,
match.start())
- for each_provider in provider_names:
- body = method_bodies.get(each_provider)
- if body and CONSUMER_TOKEN_PATTERN.search(body):
- violations.append(f"{path}:{line} method={method_name}
provider={each_provider} reason=consumerInMethodSourceArguments")
-if violations:
- print("[R15-I] parameterized tests must not use Consumer in signatures or
@MethodSource argument rows")
- for each in violations:
- print(each)
- sys.exit(1)
-PY
-'
-```
-
-6. `R14` hard-gate scan:
-```bash
-bash -lc '
-BOOLEAN_ASSERTION_BAN_REGEX="assertThat\s*\((?s:.*?)is\s*\(\s*(?:true|false|Boolean\.TRUE|Boolean\.FALSE)\s*\)\s*\)|assertEquals\s*\(\s*(?:true|false|Boolean\.TRUE|Boolean\.FALSE)\s*,"
-BOOLEAN_ASSERTION_BAN_REGEX+="|assertEquals\s*\((?s:.*?),\s*(?:true|false|Boolean\.TRUE|Boolean\.FALSE)\s*\)"
-if rg -n -U --pcre2 "$BOOLEAN_ASSERTION_BAN_REGEX" <ResolvedTestFileSet>; then
- echo "[R14] forbidden boolean assertion found"
- exit 1
-fi'
-```
-
-6.1 `R15-H` boolean control-flow dispatch scan:
-```bash
-bash -lc '
-python3 - <ResolvedTestFileSet> <<'"'"'PY'"'"'
-import re
-import sys
-from pathlib import Path
-
-METHOD_DECL_PATTERN =
re.compile(r"(?:@Test|@ParameterizedTest(?:\\s*\\([^)]*\\))?(?:\\s*@\\w+(?:\\s*\\([^)]*\\))?)*)\\s*void\\s+(assert\\w+)\\s*\\([^)]*\\)\\s*\\{",
re.S)
-IF_ELSE_PATTERN =
re.compile(r"if\\s*\\([^)]*\\)\\s*\\{[\\s\\S]*?assertTrue\\s*\\([^;]+\\)\\s*;[\\s\\S]*?\\}\\s*else\\s*\\{[\\s\\S]*?assertFalse\\s*\\([^;]+\\)\\s*;[\\s\\S]*?\\}|if\\s*\\([^)]*\\)\\s*\\{[\\s\\S]*?assertFalse\\s*\\([^;]+\\)\\s*;[\\s\\S]*?\\}\\s*else\\s*\\{[\\s\\S]*?assertTrue\\s*\\([^;]+\\)\\s*;[\\s\\S]*?\\}",
re.S)
-IF_RETURN_PATTERN =
re.compile(r"if\\s*\\([^)]*\\)\\s*\\{[\\s\\S]*?assertTrue\\s*\\([^;]+\\)\\s*;[\\s\\S]*?return\\s*;[\\s\\S]*?\\}\\s*assertFalse\\s*\\([^;]+\\)\\s*;|if\\s*\\([^)]*\\)\\s*\\{[\\s\\S]*?assertFalse\\s*\\([^;]+\\)\\s*;[\\s\\S]*?return\\s*;[\\s\\S]*?\\}\\s*assertTrue\\s*\\([^;]+\\)\\s*;",
re.S)
-
-def extract_block(text, brace_index):
- depth = 0
- i = brace_index
- while i < len(text):
- if "{" == text[i]:
- depth += 1
- elif "}" == text[i]:
- depth -= 1
- if 0 == depth:
- return text[brace_index + 1:i]
- i += 1
- return ""
-
-violations = []
-for path in (each for each in sys.argv[1:] if each.endswith(".java")):
- source = Path(path).read_text(encoding="utf-8")
- for match in METHOD_DECL_PATTERN.finditer(source):
- method_name = match.group(1)
- line = source.count("\\n", 0, match.start()) + 1
- brace_index = source.find("{", match.start())
- if brace_index < 0:
- continue
- body = extract_block(source, brace_index)
- if IF_ELSE_PATTERN.search(body) or IF_RETURN_PATTERN.search(body):
- violations.append(f"{path}:{line} method={method_name}")
-if violations:
- print("[R15-H] do not dispatch boolean assertions by control flow to
choose assertTrue/assertFalse")
- for each in violations:
- print(each)
- sys.exit(1)
-PY
-'
-```
-
-7. `R15-B` metadata accessor test ban scan (skip only when explicitly
requested by user):
-```bash
-bash -lc '
-if rg -n -U
"@Test(?s:.*?)void\\s+assert\\w*(GetType|GetOrder|GetTypeClass)\\b|assertThat\\((?s:.*?)\\.getType\\(\\)|assertThat\\((?s:.*?)\\.getOrder\\(\\)|assertThat\\((?s:.*?)\\.getTypeClass\\(\\)"
<ResolvedTestFileSet>; then
- echo "[R15-B] metadata accessor test detected without explicit user request"
- exit 1
-fi'
-```
-
-8. Scope validation:
+6. Scope validation:
```bash
git diff --name-only
```
-9. `R15-C` production-path mutation guard (baseline-based):
-```bash
-bash -lc '
-# capture once at task start:
-# git status --porcelain > /tmp/gen-ut-status-before.txt
-git status --porcelain > /tmp/gen-ut-status-after.txt
-python3 - <<'"'"'PY'"'"'
-from pathlib import Path
-
-before_path = Path("/tmp/gen-ut-status-before.txt")
-after_path = Path("/tmp/gen-ut-status-after.txt")
-before = set(before_path.read_text(encoding="utf-8").splitlines()) if
before_path.exists() else set()
-after = set(after_path.read_text(encoding="utf-8").splitlines())
-introduced = sorted(after - before)
-violations = []
-for each in introduced:
- path = each[3:].strip()
- if "/src/main/" in path or path.startswith("src/main/"):
- violations.append(path)
-if violations:
- print("[R15-C] out-of-scope production path modified:")
- for each in violations:
- print(each)
- raise SystemExit(1)
-PY
-'
-```
-
## Final Output Requirements
- `MUST` include a status line `R10=<state>`.
diff --git a/.codex/skills/gen-ut/scripts/collect_quality_baseline.py
b/.codex/skills/gen-ut/scripts/collect_quality_baseline.py
new file mode 100644
index 00000000000..953e6d8f907
--- /dev/null
+++ b/.codex/skills/gen-ut/scripts/collect_quality_baseline.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python3
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Collect a baseline quality summary for gen-ut before editing begins.
+"""
+
+import argparse
+import subprocess
+import sys
+import time
+import xml.etree.ElementTree as ET
+from dataclasses import dataclass
+from pathlib import Path
+
+
+SCRIPT_DIR = Path(__file__).resolve().parent
+if str(SCRIPT_DIR) not in sys.path:
+    sys.path.insert(0, str(SCRIPT_DIR))
+
+import scan_quality_rules as quality_rules
+
+
+@dataclass(frozen=True)
+class CommandResult:
+    command: str
+    returncode: int
+    stdout: str
+    stderr: str
+    duration_seconds: float
+
+
+def run_command(command: str, workdir: Path) -> CommandResult:
+    started = time.monotonic()
+    completed = subprocess.run(command, shell=True, cwd=workdir, capture_output=True, text=True)
+    return CommandResult(
+        command=command,
+        returncode=completed.returncode,
+        stdout=completed.stdout,
+        stderr=completed.stderr,
+        duration_seconds=time.monotonic() - started,
+    )
+
+
+def parse_target_classes(raw: str) -> list[str]:
+    result = [each.strip() for each in raw.split(",") if each.strip()]
+    if result:
+        return result
+    raise ValueError("target-classes must not be empty")
+
+
+def validate_workdir(path: Path) -> Path:
+    workdir = path.resolve()
+    if not workdir.exists():
+        raise ValueError(f"working directory does not exist: {workdir}")
+    if not workdir.is_dir():
+        raise ValueError(f"working directory is not a directory: {workdir}")
+    return workdir
+
+
+def find_sourcefile_node(root: ET.Element, fqcn: str) -> ET.Element | None:
+    package_name, simple_name = fqcn.rsplit(".", 1)
+    package_path = package_name.replace(".", "/")
+    source_name = f"{simple_name}.java"
+    for package in root.findall("package"):
+        if package.get("name") != package_path:
+            continue
+        for sourcefile in package.findall("sourcefile"):
+            if sourcefile.get("name") == source_name:
+                return sourcefile
+    return None
+
+
+def summarize_target_coverage(root: ET.Element, fqcn: str) -> tuple[dict[str, tuple[int, int, float]], list[int]]:
+    class_name = fqcn.replace(".", "/")
+    matched_nodes = [each for each in root.iter("class") if each.get("name") == class_name or each.get("name", "").startswith(class_name + "$")]
+    counters = {}
+    for counter_type in ("CLASS", "LINE", "BRANCH"):
+        covered = 0
+        missed = 0
+        for each in matched_nodes:
+            counter = next((item for item in each.findall("counter") if item.get("type") == counter_type), None)
+            if counter is None:
+                continue
+            covered += int(counter.get("covered"))
+            missed += int(counter.get("missed"))
+        total = covered + missed
+        counters[counter_type] = (covered, missed, 100.0 if 0 == total else covered * 100.0 / total)
+    sourcefile = find_sourcefile_node(root, fqcn)
+    missed_branch_lines = []
+    if sourcefile is not None:
+        missed_branch_lines = [int(each.get("nr")) for each in sourcefile.findall("line") if int(each.get("mb", "0")) > 0]
+    return counters, missed_branch_lines
+
+
+def print_rule_baseline(scan_result: dict) -> None:
+    print(f"[baseline] javaFiles={scan_result['java_file_count']}")
+    if scan_result["candidates"]:
+        print("[R8-CANDIDATES]")
+        for each in scan_result["candidates"]:
+            print(quality_rules.describe_candidate(each))
+    else:
+        print("[R8-CANDIDATES] no candidates")
+    for rule in quality_rules.RULE_ORDER:
+        violations = scan_result["rules"][rule]["violations"]
+        if violations:
+            print(f"[{rule}] {scan_result['rules'][rule]['message']}")
+            for each in violations:
+                print(each)
+        else:
+            print(f"[{rule}] ok")
+
+
+def print_coverage_baseline(jacoco_xml_path: Path, target_classes: list[str]) -> None:
+    root = ET.parse(jacoco_xml_path).getroot()
+    for fqcn in target_classes:
+        counters, missed_branch_lines = summarize_target_coverage(root, fqcn)
+        for counter_type in ("CLASS", "LINE", "BRANCH"):
+            covered, missed, ratio = counters[counter_type]
+            print(f"[baseline] {fqcn} (+inner) {counter_type} covered={covered} missed={missed} ratio={ratio:.2f}%")
+        if missed_branch_lines:
+            line_text = ",".join(str(each) for each in missed_branch_lines)
+            print(f"[baseline] {fqcn} branchMissLines={line_text}")
+        else:
+            print(f"[baseline] {fqcn} branchMissLines=none")
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(description="Collect baseline coverage and quality-rule diagnostics for gen-ut.")
+    parser.add_argument("--workdir", default=".", help="Working directory for the coverage command.")
+    parser.add_argument("--coverage-command", required=True, help="Shell command that generates the jacoco report for the target test scope.")
+    parser.add_argument("--jacoco-xml-path", required=True, help="Path to the jacoco.xml generated by the coverage command.")
+    parser.add_argument("--target-classes", required=True, help="Comma-separated target production classes.")
+    parser.add_argument("--baseline-before", help="Path to the baseline git status captured at task start.")
+    parser.add_argument("--allow-metadata-accessor-tests", action="store_true", help="Allow R15-B when the user explicitly requested metadata accessor tests.")
+    parser.add_argument("paths", nargs="+", help="Resolved test file set.")
+    args = parser.parse_args()
+
+    try:
+        workdir = validate_workdir(Path(args.workdir))
+        target_classes = parse_target_classes(args.target_classes)
+    except ValueError as ex:
+        parser.error(str(ex))
+    java_paths = [Path(each) for each in args.paths if each.endswith(".java")]
+    baseline_path = Path(args.baseline_before) if args.baseline_before else None
+    scan_result = quality_rules.collect_scan_result(java_paths, baseline_path, args.allow_metadata_accessor_tests)
+ print_rule_baseline(scan_result)
+ coverage_result = run_command(args.coverage_command, workdir)
+    print(f"[baseline] coverageCommandExit={coverage_result.returncode} duration={coverage_result.duration_seconds:.2f}s")
+    print(f"[baseline] coverageCommand={coverage_result.command}")
+    if 0 != coverage_result.returncode:
+        if coverage_result.stdout.strip():
+            print("[baseline] coverageStdout:")
+            print(coverage_result.stdout.rstrip())
+        if coverage_result.stderr.strip():
+            print("[baseline] coverageStderr:")
+            print(coverage_result.stderr.rstrip())
+        return coverage_result.returncode
+    jacoco_xml_path = Path(args.jacoco_xml_path)
+    if not jacoco_xml_path.exists():
+        print(f"[baseline] missingJacocoXml={jacoco_xml_path}", file=sys.stderr)
+        return 2
+    try:
+        print_coverage_baseline(jacoco_xml_path, target_classes)
+    except (OSError, ET.ParseError) as ex:
+        print(f"[baseline] invalidJacocoXml={jacoco_xml_path}: {ex}", file=sys.stderr)
+        return 2
+    return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/.codex/skills/gen-ut/scripts/run_quality_gates.py b/.codex/skills/gen-ut/scripts/run_quality_gates.py
new file mode 100644
index 00000000000..e41b4ba5559
--- /dev/null
+++ b/.codex/skills/gen-ut/scripts/run_quality_gates.py
@@ -0,0 +1,286 @@
+#!/usr/bin/env python3
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Run independent quality gates in parallel with optional gate-level reuse.
+"""
+
+import argparse
+import json
+import sys
+import subprocess
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass
+from pathlib import Path
+
+
+SCRIPT_DIR = Path(__file__).resolve().parent
+if str(SCRIPT_DIR) not in sys.path:
+ sys.path.insert(0, str(SCRIPT_DIR))
+
+import verification_gate_state as snapshot_state
+
+
+@dataclass(frozen=True)
+class GateSpec:
+ name: str
+ command: str
+
+
+@dataclass(frozen=True)
+class GateResult:
+ spec: GateSpec
+ status: str
+ returncode: int
+ stdout: str
+ stderr: str
+ duration_seconds: float
+ state_gate: str | None = None
+ digest: str | None = None
+
+
+def parse_gate(raw: str) -> GateSpec:
+ if "=" not in raw:
+ raise ValueError(f"invalid gate definition: {raw!r}")
+ name, command = raw.split("=", 1)
+ name = name.strip()
+ command = command.strip()
+ if not name or not command:
+ raise ValueError(f"invalid gate definition: {raw!r}")
+ return GateSpec(name=name, command=command)
+
+
+def parse_gates(raw_gates: list[str]) -> list[GateSpec]:
+ result = []
+ seen_names = set()
+ for each in raw_gates:
+ spec = parse_gate(each)
+ if spec.name in seen_names:
+ raise ValueError(f"duplicate gate name: {spec.name}")
+ seen_names.add(spec.name)
+ result.append(spec)
+ return result
+
+
+def parse_gate_mapping(raw: str) -> tuple[str, str]:
+ if "=" not in raw:
+ normalized = raw.strip()
+ if not normalized:
+ raise ValueError(f"invalid gate mapping: {raw!r}")
+ return normalized, normalized
+ gate_name, gate_key = raw.split("=", 1)
+ gate_name = gate_name.strip()
+ gate_key = gate_key.strip()
+ if not gate_name or not gate_key:
+ raise ValueError(f"invalid gate mapping: {raw!r}")
+ return gate_name, gate_key
+
+
+def parse_gate_mappings(raw_values: list[str]) -> dict[str, str]:
+ result = {}
+ for each in raw_values:
+ gate_name, gate_key = parse_gate_mapping(each)
+ if gate_name in result and result[gate_name] != gate_key:
+            raise ValueError(f"duplicate gate mapping for {gate_name}: {result[gate_name]} vs {gate_key}")
+ result[gate_name] = gate_key
+ return result
+
+
+def validate_gate_mappings(specs: list[GateSpec], mappings: dict[str, str], option_name: str) -> None:
+    valid_names = {each.name for each in specs}
+    unknown = sorted(each for each in mappings if each not in valid_names)
+    if unknown:
+        raise ValueError(f"{option_name} references unknown gate(s): {', '.join(unknown)}")
+
+
+def flatten_tracked_paths(values: list[list[str]]) -> list[str]:
+ result = []
+ for group in values:
+ result.extend(group)
+ return result
+
+
+def validate_workdir(path: Path) -> Path:
+ workdir = path.resolve()
+ if not workdir.exists():
+ raise ValueError(f"working directory does not exist: {workdir}")
+ if not workdir.is_dir():
+ raise ValueError(f"working directory is not a directory: {workdir}")
+ return workdir
+
+
+def validate_state_inputs(tracked_paths: list[str], reuse_mapping: dict[str, str], record_mapping: dict[str, str], state_file: str | None) -> None:
+    if not state_file and (reuse_mapping or record_mapping):
+        raise ValueError("state-file is required when reuse-gate or record-gate is configured")
+    if state_file and not tracked_paths:
+        raise ValueError("tracked-path is required when state-file is configured")
+
+
+def current_digest(tracked_paths: list[str]) -> tuple[str, int]:
+    normalized_paths = snapshot_state.tracked_paths(tracked_paths)
+    return snapshot_state.calculate_digest(normalized_paths), len(normalized_paths)
+
+
+def gate_digest(base_digest: str, spec: GateSpec, workdir: Path) -> str:
+    return snapshot_state.extend_digest(base_digest, [str(workdir), spec.name, spec.command])
+
+
+def run_gate(spec: GateSpec, workdir: Path) -> GateResult:
+ started = time.monotonic()
+    completed = subprocess.run(spec.command, shell=True, cwd=workdir, capture_output=True, text=True)
+ return GateResult(
+ spec=spec,
+ status="executed",
+ returncode=completed.returncode,
+ stdout=completed.stdout,
+ stderr=completed.stderr,
+ duration_seconds=time.monotonic() - started,
+ )
+
+
+def run_serial(specs: list[GateSpec], workdir: Path) -> list[GateResult]:
+ return [run_gate(spec, workdir) for spec in specs]
+
+
+def run_parallel(specs: list[GateSpec], workdir: Path, max_parallel: int) -> list[GateResult]:
+    ordered_results: list[GateResult | None] = [None] * len(specs)
+    with ThreadPoolExecutor(max_workers=max_parallel) as executor:
+        future_to_index = {executor.submit(run_gate, spec, workdir): index for index, spec in enumerate(specs)}
+        for future in as_completed(future_to_index):
+            ordered_results[future_to_index[future]] = future.result()
+    return [each for each in ordered_results if each is not None]
+
+
+def reusable_result(spec: GateSpec, gate_key: str, digest: str) -> GateResult:
+    return GateResult(spec=spec, status="reused", returncode=0, stdout="", stderr="", duration_seconds=0.0, state_gate=gate_key, digest=digest)
+
+
+def print_result(result: GateResult) -> None:
+ if "reused" == result.status:
+        print(f"[gate:{result.spec.name}] reused gate={result.state_gate} digest={result.digest}")
+        return
+    print(f"[gate:{result.spec.name}] exit={result.returncode} duration={result.duration_seconds:.2f}s")
+ print(f"[gate:{result.spec.name}] cmd={result.spec.command}")
+ if 0 != result.returncode:
+ if result.stdout.strip():
+ print(f"[gate:{result.spec.name}] stdout:")
+ print(result.stdout.rstrip())
+ if result.stderr.strip():
+ print(f"[gate:{result.spec.name}] stderr:")
+ print(result.stderr.rstrip())
+
+
+def load_state(state_file: str | None) -> dict:
+ return snapshot_state.read_state(Path(state_file)) if state_file else {}
+
+
+def gate_digests(specs: list[GateSpec], base_digest: str | None, workdir: Path) -> dict[str, str]:
+    if base_digest is None:
+        return {}
+    return {each.name: gate_digest(base_digest, each, workdir) for each in specs}
+
+
+def partition_specs_for_execution(
+        specs: list[GateSpec], reuse_mapping: dict[str, str], digests: dict[str, str], state: dict,
+) -> tuple[list[GateResult | None], list[GateSpec], list[int]]:
+    indexed_results: list[GateResult | None] = [None] * len(specs)
+    executable_specs = []
+    executable_indices = []
+    for index, spec in enumerate(specs):
+        gate_key = reuse_mapping.get(spec.name)
+        gate_result_digest = digests.get(spec.name)
+        if gate_key and gate_result_digest and snapshot_state.matches_gate_digest(state, gate_key, gate_result_digest):
+            indexed_results[index] = reusable_result(spec, gate_key, gate_result_digest)
+            continue
+        executable_specs.append(spec)
+        executable_indices.append(index)
+    return indexed_results, executable_specs, executable_indices
+
+
+def execute_specs(specs: list[GateSpec], workdir: Path, max_parallel: int) -> list[GateResult]:
+ if 1 == max_parallel:
+ return run_serial(specs, workdir)
+ return run_parallel(specs, workdir, max_parallel)
+
+
+def mark_green_gates(
+        state: dict, record_mapping: dict[str, str], results: list[GateResult], gate_digests: dict[str, str], tracked_file_count: int | None,
+) -> bool:
+    if not gate_digests or tracked_file_count is None:
+        return False
+    updated = False
+    for each in results:
+        if "executed" != each.status or 0 != each.returncode or each.spec.name not in record_mapping:
+            continue
+        snapshot_state.set_gate_digest(state, record_mapping[each.spec.name], gate_digests[each.spec.name], tracked_file_count)
+        updated = True
+    return updated
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(description="Run independent quality gates in parallel.")
+    parser.add_argument("--workdir", default=".", help="Working directory for every gate command.")
+    parser.add_argument("--serial", action="store_true", help="Run gates serially instead of in parallel.")
+    parser.add_argument("--max-parallel", type=int, default=4, help="Maximum parallel gate count.")
+    parser.add_argument("--state-file", help="Optional state file used for gate-level reuse.")
+    parser.add_argument("--tracked-path", nargs="+", action="append", default=[], help="Tracked file paths used to compute the verification digest.")
+    parser.add_argument("--reuse-gate", action="append", default=[], help="Reuse mapping in the form gate-name[=state-gate].")
+    parser.add_argument("--record-gate", action="append", default=[], help="Record mapping in the form gate-name[=state-gate].")
+    parser.add_argument("--gate", action="append", required=True, help="Gate definition in the form name=command.")
+    args = parser.parse_args()
+
+    try:
+        workdir = validate_workdir(Path(args.workdir))
+        specs = parse_gates(args.gate)
+        reuse_mapping = parse_gate_mappings(args.reuse_gate)
+        record_mapping = parse_gate_mappings(args.record_gate)
+        validate_gate_mappings(specs, reuse_mapping, "--reuse-gate")
+        validate_gate_mappings(specs, record_mapping, "--record-gate")
+        tracked_paths = flatten_tracked_paths(args.tracked_path)
+        validate_state_inputs(tracked_paths, reuse_mapping, record_mapping, args.state_file)
+    except ValueError as ex:
+        parser.error(str(ex))
+    try:
+        base_digest = None
+        tracked_file_count = None
+        state = load_state(args.state_file)
+        if tracked_paths:
+            base_digest, tracked_file_count = current_digest(tracked_paths)
+        gate_digest_map = gate_digests(specs, base_digest, workdir)
+        max_parallel = 1 if args.serial else max(1, min(args.max_parallel, len(specs)))
+        print(f"[quality-gates] mode={'serial' if 1 == max_parallel else 'parallel'} gateCount={len(specs)} workdir={workdir}")
+        indexed_results, executable_specs, executable_indices = partition_specs_for_execution(specs, reuse_mapping, gate_digest_map, state)
+        executed_results = execute_specs(executable_specs, workdir, max_parallel)
+        for index, result in zip(executable_indices, executed_results):
+            indexed_results[index] = result
+        results = [each for each in indexed_results if each is not None]
+        failed = False
+        for each in results:
+            print_result(each)
+            failed = failed or 0 != each.returncode
+        if args.state_file and mark_green_gates(state, record_mapping, results, gate_digest_map, tracked_file_count):
+            snapshot_state.write_state(Path(args.state_file), state)
+        return 1 if failed else 0
+ except (ValueError, OSError, json.JSONDecodeError) as ex:
+ print(ex, file=sys.stderr)
+ return 2
+
+
+if __name__ == "__main__":
+ sys.exit(main())
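The `run_parallel` helper above keeps gate output deterministic even though gates finish in arbitrary order: each future is mapped back to its submission index, so results are reassembled in input order. A minimal standalone sketch of that pattern, with trivial callables standing in for gate commands (the `run_all` name is illustrative, not part of the script):

```python
# Sketch of order-preserving fan-out: futures complete in any order,
# but a future-to-index map restores submission order in the result list.
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_all(tasks, max_workers=4):
    results = [None] * len(tasks)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {executor.submit(task): index for index, task in enumerate(tasks)}
        for future in as_completed(future_to_index):
            # Place each result at the slot of its originating task.
            results[future_to_index[future]] = future.result()
    return results


print(run_all([lambda: "lint", lambda: "test", lambda: "coverage"]))  # → ['lint', 'test', 'coverage']
```

In the real script the callables are `subprocess.run` invocations of shell commands, and failed gates dump their captured stdout/stderr after all gates finish.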
diff --git a/.codex/skills/gen-ut/scripts/scan_quality_rules.py b/.codex/skills/gen-ut/scripts/scan_quality_rules.py
new file mode 100644
index 00000000000..d53ae7b0a7b
--- /dev/null
+++ b/.codex/skills/gen-ut/scripts/scan_quality_rules.py
@@ -0,0 +1,587 @@
+#!/usr/bin/env python3
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Consolidated quality-rule scan for gen-ut.
+"""
+
+import argparse
+import json
+import re
+import subprocess
+import sys
+from collections import defaultdict
+from dataclasses import asdict
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass(frozen=True)
+class CandidateSummary:
+ path: str
+ method: str
+ plain_test_count: int
+ parameterized_present: bool
+ high_fit: bool
+
+
+@dataclass(frozen=True)
+class RuleSpec:
+ name: str
+ message: str
+
+
+@dataclass(frozen=True)
+class FileScanContext:
+ path: Path
+ source: str
+ method_bodies: dict[str, str]
+ candidates: list[CandidateSummary]
+
+
+RULE_ORDER = ("R8", "R14", "R15-A", "R15-B", "R15-C", "R15-D", "R15-E", "R15-F", "R15-G", "R15-H", "R15-I")
+RULE_MESSAGES = {
+    "R8": "@ParameterizedTest must use name = \"{0}\"",
+    "R14": "forbidden boolean assertion found",
+    "R15-A": "high-fit candidate likely exists but no parameterized test found",
+    "R15-B": "metadata accessor test detected without explicit user request",
+    "R15-C": "out-of-scope production path modified",
+    "R15-D": "each @ParameterizedTest must have >= 3 Arguments rows from @MethodSource",
+    "R15-E": "each @ParameterizedTest method must declare first parameter as `final String name`",
+    "R15-F": "@ParameterizedTest method body must not contain switch",
+    "R15-G": "parameterized tests must not introduce nested helper type declarations",
+    "R15-H": "do not dispatch boolean assertions by control flow to choose assertTrue/assertFalse",
+    "R15-I": "parameterized tests must not use Consumer in signatures or @MethodSource argument rows",
+}
+RULE_SPECS = tuple(RuleSpec(each, RULE_MESSAGES[each]) for each in RULE_ORDER)
+BOOLEAN_ASSERTION_BAN_PATTERN = re.compile(
+    r"assertThat\s*\((?s:.*?)is\s*\(\s*(?:true|false|Boolean\.TRUE|Boolean\.FALSE)\s*\)\s*\)"
+    r"|assertEquals\s*\(\s*(?:true|false|Boolean\.TRUE|Boolean\.FALSE)\s*,"
+    r"|assertEquals\s*\((?s:.*?),\s*(?:true|false|Boolean\.TRUE|Boolean\.FALSE)\s*\)",
+    re.S,
+)
+CONSUMER_TOKEN_PATTERN = re.compile(r"\bConsumer\s*(?:<|\b)")
+CONSTRUCTOR_CALL_PATTERN = re.compile(r"\bnew\s+(\w+)\s*\(")
+METHOD_DECL_PATTERN = re.compile(r"(?:private|protected|public)?\s*(?:static\s+)?[\w$<>\[\], ?]+\s+(\w+)\s*\([^)]*\)\s*(?:throws [^{]+)?\{", re.S)
+METHOD_SOURCE_PATTERN = re.compile(r"@MethodSource(?:\s*\(([^)]*)\))?")
+PARAM_METHOD_BODY_PATTERN = re.compile(
+    r"@ParameterizedTest(?:\s*\([^)]*\))?\s*(?:@\w+(?:\s*\([^)]*\))?\s*)*void\s+(assert\w+)\s*\([^)]*\)\s*(?:throws [^{]+)?\{",
+    re.S,
+)
+PARAM_METHOD_PATTERN = re.compile(
+    r"@ParameterizedTest(?:\s*\([^)]*\))?\s*((?:@\w+(?:\s*\([^)]*\))?\s*)*)void\s+(assert\w+)\s*\(([^)]*)\)\s*(?:throws [^{]+)?",
+    re.S,
+)
+TEST_METHOD_DECL_PATTERN = re.compile(
+    r"((?:@Test(?:\s*\([^)]*\))?|@ParameterizedTest(?:\s*\([^)]*\))?)\s*(?:@\w+(?:\s*\([^)]*\))?\s*)*)"
+    r"void\s+(assert\w+)\s*\([^)]*\)\s*(?:throws [^{]+)?\{",
+    re.S,
+)
+R15_A_CALL_PATTERN = re.compile(r"\b\w+\.(\w+)\s*\(")
+R15_A_IGNORE = {"assertThat", "assertTrue", "assertFalse", "mock", "when", "verify", "is", "not"}
+R15_G_TYPE_DECL_PATTERN = re.compile(
+    r"^\+\s+(?:(?:public|protected|private|static|final|abstract|sealed|non-sealed)\s+)*(class|interface|enum|record)\b"
+)
+R15_H_IF_ELSE_PATTERN = re.compile(
+    r"if\s*\([^)]*\)\s*\{[\s\S]*?assertTrue\s*\([^;]+\)\s*;[\s\S]*?\}\s*else\s*\{[\s\S]*?assertFalse\s*\([^;]+\)\s*;[\s\S]*?\}"
+    r"|if\s*\([^)]*\)\s*\{[\s\S]*?assertFalse\s*\([^;]+\)\s*;[\s\S]*?\}\s*else\s*\{[\s\S]*?assertTrue\s*\([^;]+\)\s*;[\s\S]*?\}",
+    re.S,
+)
+R15_H_IF_RETURN_PATTERN = re.compile(
+    r"if\s*\([^)]*\)\s*\{[\s\S]*?assertTrue\s*\([^;]+\)\s*;[\s\S]*?return\s*;[\s\S]*?\}\s*assertFalse\s*\([^;]+\)\s*;"
+    r"|if\s*\([^)]*\)\s*\{[\s\S]*?assertFalse\s*\([^;]+\)\s*;[\s\S]*?return\s*;[\s\S]*?\}\s*assertTrue\s*\([^;]+\)\s*;",
+    re.S,
+)
+R15_NAME_PATTERN = re.compile(r'name\s*=\s*"\{0\}"')
+R15_SWITCH_PATTERN = re.compile(r"\bswitch\s*\(")
+R15_B_PATTERN = re.compile(
+    r"@Test(?s:.*?)void\s+assert\w*(GetType|GetOrder|GetTypeClass)\b"
+    r"|assertThat\((?s:.*?)\.getType\(\)"
+    r"|assertThat\((?s:.*?)\.getOrder\(\)"
+    r"|assertThat\((?s:.*?)\.getTypeClass\(\)",
+    re.S,
+)
+TYPE_DECL_LINE_PATTERN = re.compile(
+    r"^\s*(?:(?:public|protected|private|static|final|abstract|sealed|non-sealed)\s+)*(class|interface|enum|record)\s+(\w+)\b"
+)
+UNTRACKED_STATUS_PREFIX = "?? "
+CONSTRUCTOR_TEST_PREFIXES = ("New", "Construct", "Constructor")
+
+
+def line_number(source: str, index: int) -> int:
+ return source.count("\n", 0, index) + 1
+
+
+def opening_brace_index(match: re.Match[str]) -> int:
+ return match.end() - 1
+
+
+def extract_block(text: str, brace_index: int) -> str:
+ depth = 0
+ index = brace_index
+ while index < len(text):
+ if "{" == text[index]:
+ depth += 1
+ elif "}" == text[index]:
+ depth -= 1
+ if 0 == depth:
+ return text[brace_index + 1:index]
+ index += 1
+ return ""
+
+
+def parse_method_sources(method_name: str, annotation_block: str) -> list[str]:
+ result = []
+ matches = list(METHOD_SOURCE_PATTERN.finditer(annotation_block))
+ if not matches:
+ return result
+ for each in matches:
+ raw = each.group(1)
+ if raw is None or not raw.strip():
+ result.append(method_name)
+ continue
+ normalized = re.sub(r"\bvalue\s*=\s*", "", raw.strip())
+ for name in re.findall(r'"([^"]+)"', normalized):
+ result.append(name.split("#", 1)[-1])
+ return result
+
+
+def parse_method_bodies(source: str) -> dict[str, str]:
+ result = {}
+ for match in METHOD_DECL_PATTERN.finditer(source):
+ method_name = match.group(1)
+ brace_index = opening_brace_index(match)
+ if brace_index >= 0:
+ result[method_name] = extract_block(source, brace_index)
+ return result
+
+
+def run_git_command(args: list[str]) -> str:
+    return subprocess.run(args, check=True, capture_output=True, text=True).stdout
+
+
+def get_git_diff_lines(path: Path, *, cached: bool = False) -> list[str]:
+ command = ["git", "diff"]
+ if cached:
+ command.append("--cached")
+ command.extend(["-U0", "--", str(path)])
+ return run_git_command(command).splitlines()
+
+
+def get_status_line_for_path(path: Path) -> str | None:
+ output = run_git_command(["git", "status", "--porcelain", "--", str(path)])
+ lines = [each for each in output.splitlines() if each]
+ return lines[0] if lines else None
+
+
+def get_added_lines_for_path(path: Path) -> list[str]:
+ result = []
+ for cached in (False, True):
+ result.extend(get_git_diff_lines(path, cached=cached))
+ if result:
+ return list(dict.fromkeys(result))
+ status_line = get_status_line_for_path(path)
+ if status_line and status_line.startswith(UNTRACKED_STATUS_PREFIX):
+        return [f"+{each}" for each in path.read_text(encoding="utf-8").splitlines()]
+ return result
+
+
+def get_top_level_class_name(source: str) -> str | None:
+ for line in source.splitlines():
+ match = TYPE_DECL_LINE_PATTERN.match(line)
+ if match:
+ return match.group(2)
+ return None
+
+
+def get_target_type_name(source: str) -> str | None:
+ top_level_class_name = get_top_level_class_name(source)
+ if top_level_class_name and top_level_class_name.endswith("Test"):
+ return top_level_class_name[:-4]
+ return None
+
+
+def get_after_status_lines() -> set[str]:
+ output = run_git_command(["git", "status", "--porcelain"])
+ return set(each for each in output.splitlines() if each)
+
+
+def is_src_main_path(path: str) -> bool:
+ return "/src/main/" in path or path.startswith("src/main/")
+
+
+def normalize_status_path(line: str) -> str:
+ path = line[3:].strip()
+ if " -> " in path:
+ path = path.split(" -> ", 1)[1].strip()
+ return path
+
+
+def list_distinct(values: list[str]) -> list[str]:
+ return list(dict.fromkeys(values))
+
+
+def extract_invoked_methods(body: str) -> list[str]:
+    return list_distinct([each for each in R15_A_CALL_PATTERN.findall(body) if each not in R15_A_IGNORE])
+
+
+def extract_constructed_types(body: str) -> list[str]:
+ return list_distinct(CONSTRUCTOR_CALL_PATTERN.findall(body))
+
+
+def method_name_prefix(method_name: str) -> str:
+    return method_name[0].upper() + method_name[1:] if method_name else method_name
+
+
+def infer_candidate_target(test_method_name: str, invoked_methods: list[str], constructed_types: list[str], target_type_name: str | None) -> str | None:
+    raw_name = test_method_name[6:] if test_method_name.startswith("assert") else test_method_name
+    if target_type_name and target_type_name in constructed_types and (raw_name.startswith(f"New{target_type_name}") or raw_name.startswith(CONSTRUCTOR_TEST_PREFIXES)):
+        return f"constructor:{target_type_name}"
+    for candidate_name in (raw_name, raw_name[3:] if raw_name.startswith("Not") else raw_name):
+        matching_methods = [each for each in invoked_methods if candidate_name.startswith(method_name_prefix(each))]
+        if matching_methods:
+            return max(matching_methods, key=len)
+    if 1 == len(invoked_methods):
+        return invoked_methods[0]
+    return None
+
+
+def analyze_parameterization_candidates(path: Path, source: str) -> list[CandidateSummary]:
+ target_type_name = get_target_type_name(source)
+ statistics = defaultdict(lambda: {"plain": 0, "parameterized": False})
+ for match in TEST_METHOD_DECL_PATTERN.finditer(source):
+ annotation_block = match.group(1)
+ test_method_name = match.group(2)
+ brace_index = opening_brace_index(match)
+ body = extract_block(source, brace_index)
+ invoked_methods = extract_invoked_methods(body)
+ constructed_types = extract_constructed_types(body)
+        target = infer_candidate_target(test_method_name, invoked_methods, constructed_types, target_type_name)
+ if target is None:
+ continue
+ if "@ParameterizedTest" in annotation_block:
+ statistics[target]["parameterized"] = True
+ else:
+ statistics[target]["plain"] += 1
+ result = []
+ for method_name in sorted(statistics):
+ plain_test_count = statistics[method_name]["plain"]
+ parameterized_present = statistics[method_name]["parameterized"]
+ if plain_test_count >= 3 or parameterized_present:
+ result.append(CandidateSummary(
+ path=str(path),
+ method=method_name,
+ plain_test_count=plain_test_count,
+ parameterized_present=parameterized_present,
+ high_fit=plain_test_count >= 3,
+ ))
+ return result
+
+
+def describe_candidate(candidate: dict) -> str:
+    decision = "recommend refactor" if candidate["high_fit"] and not candidate["parameterized_present"] else "already parameterized" if candidate["parameterized_present"] else "observe"
+    return f'{candidate["path"]}: method={candidate["method"]} plainTestCount={candidate["plain_test_count"]} parameterizedPresent={candidate["parameterized_present"]} decision={decision}'
+
+
+def check_parameterized_name(path: Path, source: str) -> list[str]:
+ violations = []
+ token = "@ParameterizedTest"
+ pos = 0
+ while True:
+ token_pos = source.find(token, pos)
+ if token_pos < 0:
+ break
+ line = line_number(source, token_pos)
+ cursor = token_pos + len(token)
+ while cursor < len(source) and source[cursor].isspace():
+ cursor += 1
+ if cursor >= len(source) or "(" != source[cursor]:
+ violations.append(f"{path}:{line}")
+ pos = token_pos + len(token)
+ continue
+ depth = 1
+ end = cursor + 1
+ while end < len(source) and depth:
+ if "(" == source[end]:
+ depth += 1
+ elif ")" == source[end]:
+ depth -= 1
+ end += 1
+ if depth or not R15_NAME_PATTERN.search(source[cursor + 1:end - 1]):
+ violations.append(f"{path}:{line}")
+ pos = end
+ return violations
+
+
+def check_r15_a(candidates: list[CandidateSummary]) -> list[str]:
+ result = []
+ for each in candidates:
+ if each.high_fit and not each.parameterized_present:
+            result.append(f"{each.path}: method={each.method} nonParameterizedCount={each.plain_test_count}")
+ return result
+
+
+def check_r15_d(path: Path, source: str, method_bodies: dict[str, str]) -> list[str]:
+    violations = []
+    for match in PARAM_METHOD_PATTERN.finditer(source):
+        annotation_block = match.group(1)
+        method_name = match.group(2)
+        line = line_number(source, match.start())
+        providers = parse_method_sources(method_name, annotation_block)
+        if not providers:
+            violations.append(f"{path}:{line} method={method_name} missing @MethodSource")
+            continue
+        total_rows = 0
+        unresolved = []
+        for provider in providers:
+            body = method_bodies.get(provider)
+            if body is None:
+                unresolved.append(provider)
+                continue
+            total_rows += len(re.findall(r"\b(?:Arguments\.of|arguments)\s*\(", body))
+        if unresolved:
+            violations.append(f"{path}:{line} method={method_name} unresolvedProviders={','.join(unresolved)}")
+            continue
+        if total_rows < 3:
+            violations.append(f"{path}:{line} method={method_name} argumentsRows={total_rows}")
+    return violations
+
+
+def check_r15_e(path: Path, source: str) -> list[str]:
+ violations = []
+ for match in PARAM_METHOD_PATTERN.finditer(source):
+ method_name = match.group(2)
+ params = match.group(3).strip()
+ line = line_number(source, match.start())
+ if not params:
+            violations.append(f"{path}:{line} method={method_name} missingParameters")
+            continue
+        first_param = params.split(",", 1)[0].strip()
+        normalized = re.sub(r"\s+", " ", first_param)
+        if "final String name" != normalized:
+            violations.append(f"{path}:{line} method={method_name} firstParam={first_param}")
+ return violations
+
+
+def check_r15_f(path: Path, source: str) -> list[str]:
+ violations = []
+ for match in PARAM_METHOD_BODY_PATTERN.finditer(source):
+ method_name = match.group(1)
+ line = line_number(source, match.start())
+ brace_index = opening_brace_index(match)
+ body = extract_block(source, brace_index)
+ if R15_SWITCH_PATTERN.search(body):
+ violations.append(f"{path}:{line} method={method_name}")
+ return violations
+
+
+def check_r15_g(path: Path, source: str) -> list[str]:
+ if "@ParameterizedTest" not in source:
+ return []
+ top_level_class_name = get_top_level_class_name(source)
+ violations = []
+ for line in get_added_lines_for_path(path):
+ if line.startswith("+++") or line.startswith("@@"):
+ continue
+ if not line.startswith("+"):
+ continue
+ if not R15_G_TYPE_DECL_PATTERN.search(line):
+ continue
+ stripped = line[1:].strip()
+ match = TYPE_DECL_LINE_PATTERN.match(stripped)
+ if match and match.group(2) == top_level_class_name:
+ continue
+ violations.append(f"{path}: {stripped}")
+ return violations
+
+
+def check_r15_i(path: Path, source: str, method_bodies: dict[str, str]) -> list[str]:
+    violations = []
+    for match in PARAM_METHOD_PATTERN.finditer(source):
+        annotation_block = match.group(1)
+        method_name = match.group(2)
+        params = match.group(3)
+        line = line_number(source, match.start())
+        if CONSUMER_TOKEN_PATTERN.search(params):
+            violations.append(f"{path}:{line} method={method_name} reason=consumerInParameterizedMethodSignature")
+        for provider in parse_method_sources(method_name, annotation_block):
+            body = method_bodies.get(provider)
+            if body and CONSUMER_TOKEN_PATTERN.search(body):
+                violations.append(f"{path}:{line} method={method_name} provider={provider} reason=consumerInMethodSourceArguments")
+    return violations
+
+
+def check_r14(path: Path, source: str) -> list[str]:
+    return [f"{path}:{line_number(source, match.start())}" for match in BOOLEAN_ASSERTION_BAN_PATTERN.finditer(source)]
+
+
+def check_r15_h(path: Path, source: str) -> list[str]:
+ violations = []
+ for match in TEST_METHOD_DECL_PATTERN.finditer(source):
+ method_name = match.group(2)
+ line = line_number(source, match.start())
+ brace_index = opening_brace_index(match)
+ body = extract_block(source, brace_index)
+        if R15_H_IF_ELSE_PATTERN.search(body) or R15_H_IF_RETURN_PATTERN.search(body):
+            violations.append(f"{path}:{line} method={method_name}")
+ return violations
+
+
+def check_r15_b(path: Path, source: str) -> list[str]:
+    return [f"{path}:{line_number(source, match.start())}" for match in R15_B_PATTERN.finditer(source)]
+
+
+def scan_java_file(path: Path, allow_metadata_accessor_tests: bool) -> tuple[dict[str, list[str]], list[CandidateSummary]]:
+    source = path.read_text(encoding="utf-8")
+    method_bodies = parse_method_bodies(source)
+    candidates = analyze_parameterization_candidates(path, source)
+    violations = defaultdict(list)
+    violations["R8"].extend(check_parameterized_name(path, source))
+    violations["R15-A"].extend(check_r15_a(candidates))
+    violations["R15-D"].extend(check_r15_d(path, source, method_bodies))
+    violations["R15-E"].extend(check_r15_e(path, source))
+    violations["R15-F"].extend(check_r15_f(path, source))
+    violations["R15-G"].extend(check_r15_g(path, source))
+    violations["R15-I"].extend(check_r15_i(path, source, method_bodies))
+    violations["R14"].extend(check_r14(path, source))
+    violations["R15-H"].extend(check_r15_h(path, source))
+    if not allow_metadata_accessor_tests:
+        violations["R15-B"].extend(check_r15_b(path, source))
+    return violations, candidates
+
+
+def check_r15_c(baseline_before: Path | None) -> list[str]:
+ if baseline_before is None:
+ return []
+    before_lines = baseline_before.read_text(encoding="utf-8").splitlines() if baseline_before.exists() else []
+ before_paths = {
+ normalize_status_path(each) for each in before_lines
+ if each and is_src_main_path(normalize_status_path(each))
+ }
+ after_paths = {
+ normalize_status_path(each) for each in get_after_status_lines()
+ if is_src_main_path(normalize_status_path(each))
+ }
+ return sorted(after_paths - before_paths)
+
+
+def create_file_scan_context(path: Path) -> FileScanContext:
+ source = path.read_text(encoding="utf-8")
+ return FileScanContext(
+ path=path,
+ source=source,
+ method_bodies=parse_method_bodies(source),
+ candidates=analyze_parameterization_candidates(path, source),
+ )
+
+
+def file_rule_violations(context: FileScanContext,
allow_metadata_accessor_tests: bool) -> dict[str, list[str]]:
+ violations = defaultdict(list)
+ violations["R8"].extend(check_parameterized_name(context.path,
context.source))
+ violations["R14"].extend(check_r14(context.path, context.source))
+ violations["R15-A"].extend(check_r15_a(context.candidates))
+    violations["R15-D"].extend(check_r15_d(context.path, context.source, context.method_bodies))
+ violations["R15-E"].extend(check_r15_e(context.path, context.source))
+ violations["R15-F"].extend(check_r15_f(context.path, context.source))
+ violations["R15-G"].extend(check_r15_g(context.path, context.source))
+ violations["R15-H"].extend(check_r15_h(context.path, context.source))
+    violations["R15-I"].extend(check_r15_i(context.path, context.source, context.method_bodies))
+ if not allow_metadata_accessor_tests:
+ violations["R15-B"].extend(check_r15_b(context.path, context.source))
+ return violations
+
+
+def build_rule_result(violations_by_rule: dict[str, list[str]]) -> dict[str, dict[str, object]]:
+ return {
+ each.name: {
+ "message": each.message,
+ "violations": violations_by_rule[each.name],
+ }
+ for each in RULE_SPECS
+ }
+
+
+def collect_scan_result(java_paths: list[Path], baseline_before: Path | None, allow_metadata_accessor_tests: bool) -> dict:
+ violations_by_rule = defaultdict(list)
+ contexts = [create_file_scan_context(each) for each in java_paths]
+ for context in contexts:
+        for rule, entries in file_rule_violations(context, allow_metadata_accessor_tests).items():
+ violations_by_rule[rule].extend(entries)
+ violations_by_rule["R15-C"].extend(check_r15_c(baseline_before))
+ return {
+ "rules": build_rule_result(violations_by_rule),
+        "candidates": [asdict(each) for context in contexts for each in context.candidates],
+ "java_file_count": len(contexts),
+ }
+
+
+def failed_rule_names(result: dict) -> list[str]:
+    return [each.name for each in RULE_SPECS if result["rules"][each.name]["violations"]]
+
+
+def print_rule_summary(result: dict) -> int:
+ failed_rules = set(failed_rule_names(result))
+ for each in RULE_SPECS:
+ violations = result["rules"][each.name]["violations"]
+ if each.name in failed_rules:
+ print(f"[{each.name}] {each.message}")
+ for violation in violations:
+ print(violation)
+ continue
+ print(f"[{each.name}] ok")
+ return 1 if failed_rules else 0
+
+
+def print_summary_only(result: dict) -> int:
+ candidates = result["candidates"]
+ if candidates:
+ print("[R8-CANDIDATES]")
+ for each in candidates:
+ print(describe_candidate(each))
+ else:
+ print("[R8-CANDIDATES] no candidates")
+    failed_rules = [f"{each.name}={len(result['rules'][each.name]['violations'])}" for each in RULE_SPECS if result["rules"][each.name]["violations"]]
+ print(f"[summary] javaFiles={result['java_file_count']}")
+ if failed_rules:
+ print(f"[summary] violations={' '.join(failed_rules)}")
+ return 1
+ print("[summary] all rules ok")
+ return 0
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(description="Consolidated quality-rule scan for gen-ut.")
+    parser.add_argument("--baseline-before", help="Path to the baseline git status captured at task start.")
+    parser.add_argument("--allow-metadata-accessor-tests", action="store_true", help="Allow R15-B when user explicitly requested metadata accessor tests.")
+    parser.add_argument("--json", action="store_true", help="Emit JSON output instead of the default text report.")
+    parser.add_argument("--summary-only", action="store_true", help="Emit a compact text summary with candidate information.")
+    parser.add_argument("paths", nargs="+", help="Resolved test file set.")
+    args = parser.parse_args()
+
+    java_paths = [Path(each) for each in args.paths if each.endswith(".java")]
+    baseline_path = Path(args.baseline_before) if args.baseline_before else None
+    result = collect_scan_result(java_paths, baseline_path, args.allow_metadata_accessor_tests)
+ if args.json:
+ print(json.dumps(result, indent=2, sort_keys=True))
+ return 1 if failed_rule_names(result) else 0
+ if args.summary_only:
+ return print_summary_only(result)
+ return print_rule_summary(result)
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/.codex/skills/gen-ut/scripts/verification_gate_state.py b/.codex/skills/gen-ut/scripts/verification_gate_state.py
new file mode 100644
index 00000000000..b16d1c1d39a
--- /dev/null
+++ b/.codex/skills/gen-ut/scripts/verification_gate_state.py
@@ -0,0 +1,232 @@
+#!/usr/bin/env python3
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Verification snapshot digest and reusable gate state for gen-ut.
+"""
+
+import argparse
+import hashlib
+import json
+import sys
+import time
+from pathlib import Path
+
+
+TRACKED_SUFFIXES = {".java", ".xml", ".yaml", ".yml", ".properties"}
+GATES_KEY = "gates"
+TARGET_TEST_GATE = "target-test"
+LEGACY_TARGET_DIGEST_KEY = "latest_green_target_test_digest"
+LEGACY_TARGET_UPDATED_AT_KEY = "latest_green_target_test_updated_at"
+LEGACY_TRACKED_FILE_COUNT_KEY = "tracked_file_count"
+
+
+def normalize_paths(paths: list[str]) -> list[Path]:
+ result = []
+ seen = set()
+ for each in paths:
+ path = Path(each)
+ if path.suffix.lower() not in TRACKED_SUFFIXES:
+ continue
+ resolved = path.resolve()
+ if not resolved.exists():
+ raise ValueError(f"tracked path does not exist: {resolved}")
+ if not resolved.is_file():
+ raise ValueError(f"tracked path is not a file: {resolved}")
+ if resolved in seen:
+ continue
+ seen.add(resolved)
+ result.append(resolved)
+ return sorted(result, key=str)
+
+
+def tracked_paths(paths: list[str]) -> list[Path]:
+ result = normalize_paths(paths)
+ if result:
+ return result
+ raise ValueError("no trackable files found in verification snapshot")
+
+
+def calculate_digest(paths: list[Path]) -> str:
+ digest = hashlib.sha256()
+ for path in paths:
+ digest.update(str(path).encode("utf-8"))
+ digest.update(b"\0")
+ digest.update(path.read_bytes())
+ digest.update(b"\0")
+ return digest.hexdigest()
+
+
+def extend_digest(base_digest: str, tokens: list[str]) -> str:
+ digest = hashlib.sha256()
+ digest.update(base_digest.encode("utf-8"))
+ digest.update(b"\0")
+ for each in tokens:
+ digest.update(each.encode("utf-8"))
+ digest.update(b"\0")
+ return digest.hexdigest()
+
+
+def read_state(state_file: Path) -> dict:
+ if not state_file.exists():
+ return {}
+ result = json.loads(state_file.read_text(encoding="utf-8"))
+ if not isinstance(result, dict):
+        raise ValueError(f"invalid state payload in {state_file}: expected JSON object")
+ return result
+
+
+def write_state(state_file: Path, state: dict) -> None:
+ state_file.parent.mkdir(parents=True, exist_ok=True)
+    state_file.write_text(json.dumps(state, indent=2, sort_keys=True) + "\n", encoding="utf-8")
+
+
+def ensure_gate_map(state: dict) -> dict[str, dict]:
+ gates = state.get(GATES_KEY)
+ if gates is None:
+ gates = {}
+ state[GATES_KEY] = gates
+ if not isinstance(gates, dict):
+        raise ValueError("invalid gate state payload: expected object for `gates`")
+ return gates
+
+
+def legacy_target_entry(state: dict) -> dict | None:
+ digest = state.get(LEGACY_TARGET_DIGEST_KEY)
+ if not digest:
+ return None
+ return {
+ "digest": digest,
+ "updated_at": state.get(LEGACY_TARGET_UPDATED_AT_KEY),
+ "tracked_file_count": state.get(LEGACY_TRACKED_FILE_COUNT_KEY),
+ }
+
+
+def get_gate_entry(state: dict, gate_name: str) -> dict | None:
+ gates = ensure_gate_map(state)
+ result = gates.get(gate_name)
+ if result is not None and not isinstance(result, dict):
+        raise ValueError(f"invalid gate entry for {gate_name}: expected JSON object")
+ if result is None and TARGET_TEST_GATE == gate_name:
+ return legacy_target_entry(state)
+ return result
+
+
+def get_gate_digest(state: dict, gate_name: str) -> str | None:
+ entry = get_gate_entry(state, gate_name)
+ return entry.get("digest") if entry else None
+
+
+def set_gate_digest(state: dict, gate_name: str, digest: str, tracked_file_count: int) -> None:
+ gates = ensure_gate_map(state)
+ updated_at = int(time.time())
+ gates[gate_name] = {
+ "digest": digest,
+ "tracked_file_count": tracked_file_count,
+ "updated_at": updated_at,
+ }
+ if TARGET_TEST_GATE == gate_name:
+ state[LEGACY_TARGET_DIGEST_KEY] = digest
+ state[LEGACY_TARGET_UPDATED_AT_KEY] = updated_at
+ state[LEGACY_TRACKED_FILE_COUNT_KEY] = tracked_file_count
+
+
+def matches_gate_digest(state: dict, gate_name: str, digest: str) -> bool:
+ return get_gate_digest(state, gate_name) == digest
+
+
+def current_digest(paths: list[str]) -> tuple[str, int]:
+ normalized_paths = tracked_paths(paths)
+ return calculate_digest(normalized_paths), len(normalized_paths)
+
+
+def command_digest(paths: list[str]) -> int:
+ digest, _ = current_digest(paths)
+ print(digest)
+ return 0
+
+
+def command_mark_gate_green(state_file: Path, gate_name: str, paths: list[str]) -> int:
+ digest, tracked_file_count = current_digest(paths)
+ state = read_state(state_file)
+ set_gate_digest(state, gate_name, digest, tracked_file_count)
+ write_state(state_file, state)
+ print(digest)
+ return 0
+
+
+def command_match_gate_green(state_file: Path, gate_name: str, paths: list[str]) -> int:
+ digest, _ = current_digest(paths)
+ state = read_state(state_file)
+ latest_green = get_gate_digest(state, gate_name)
+ if latest_green == digest:
+ print(f"MATCH {digest}")
+ return 0
+ if latest_green:
+ print(f"MISMATCH current={digest} recorded={latest_green}")
+ else:
+ print(f"MISSING current={digest}")
+ return 1
+
+
+def execute_command(args: argparse.Namespace) -> int:
+ if "digest" == args.command:
+ return command_digest(args.paths)
+ if "mark-green" == args.command:
+        return command_mark_gate_green(Path(args.state_file), TARGET_TEST_GATE, args.paths)
+    if "match-green" == args.command:
+        return command_match_gate_green(Path(args.state_file), TARGET_TEST_GATE, args.paths)
+    if "mark-gate-green" == args.command:
+        return command_mark_gate_green(Path(args.state_file), args.gate, args.paths)
+    return command_match_gate_green(Path(args.state_file), args.gate, args.paths)
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Verification snapshot digest and gate reuse state.")
+    subparsers = parser.add_subparsers(dest="command", required=True)
+    digest_parser = subparsers.add_parser("digest", help="Print the current verification snapshot digest.")
+    digest_parser.add_argument("paths", nargs="+", help="Resolved test file set.")
+    mark_green_parser = subparsers.add_parser("mark-green", help="Record the current digest as the latest green target-test digest.")
+    mark_green_parser.add_argument("--state-file", required=True, help="State file path.")
+    mark_green_parser.add_argument("paths", nargs="+", help="Resolved test file set.")
+    match_green_parser = subparsers.add_parser("match-green", help="Check whether the current digest matches the latest green target-test digest.")
+    match_green_parser.add_argument("--state-file", required=True, help="State file path.")
+    match_green_parser.add_argument("paths", nargs="+", help="Resolved test file set.")
+    mark_gate_green_parser = subparsers.add_parser("mark-gate-green", help="Record the current digest as the latest green digest for a named gate.")
+    mark_gate_green_parser.add_argument("--state-file", required=True, help="State file path.")
+    mark_gate_green_parser.add_argument("--gate", required=True, help="Logical gate name, for example target-test or coverage.")
+    mark_gate_green_parser.add_argument("paths", nargs="+", help="Resolved test file set.")
+    match_gate_green_parser = subparsers.add_parser("match-gate-green", help="Check whether the current digest matches the latest green digest for a named gate.")
+    match_gate_green_parser.add_argument("--state-file", required=True, help="State file path.")
+    match_gate_green_parser.add_argument("--gate", required=True, help="Logical gate name, for example target-test or coverage.")
+    match_gate_green_parser.add_argument("paths", nargs="+", help="Resolved test file set.")
+    return parser
+
+
+def main() -> int:
+ parser = build_parser()
+ args = parser.parse_args()
+ try:
+ return execute_command(args)
+ except (ValueError, OSError, json.JSONDecodeError) as ex:
+ print(ex, file=sys.stderr)
+ return 2
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/infra/common/src/test/java/org/apache/shardingsphere/infra/datanode/DataNodeTest.java b/infra/common/src/test/java/org/apache/shardingsphere/infra/datanode/DataNodeTest.java
index c96930950d6..df10653d8dd 100644
--- a/infra/common/src/test/java/org/apache/shardingsphere/infra/datanode/DataNodeTest.java
+++ b/infra/common/src/test/java/org/apache/shardingsphere/infra/datanode/DataNodeTest.java
@@ -21,248 +21,155 @@ import org.apache.shardingsphere.database.connector.core.type.DatabaseType;
import org.apache.shardingsphere.infra.exception.kernel.metadata.datanode.InvalidDataNodeFormatException;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.util.stream.Stream;
-import static org.hamcrest.Matchers.is;
-import static org.hamcrest.Matchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.hamcrest.Matchers.is;
+import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
class DataNodeTest {
- @Test
- void assertNewValidDataNode() {
- DataNode dataNode = new DataNode("ds_0.tbl_0");
- assertThat(dataNode.getDataSourceName(), is("ds_0"));
- assertThat(dataNode.getTableName(), is("tbl_0"));
- }
-
- @Test
- void assertNewInValidDataNodeWithoutDelimiter() {
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode("ds_0tbl_0"));
- }
-
- @Test
- void assertNewInValidDataNodeWithTwoDelimiters() {
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode("ds_0.db_0.tbl_0.tbl_1"));
- }
-
- @Test
- void assertNewValidDataNodeWithInvalidDelimiter() {
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode("ds_0,tbl_0"));
- }
-
- @Test
- void assertFormatWithSchema() {
-        assertThat(new DataNode("foo_ds", "foo_schema", "foo_tbl").format(), is("foo_ds.foo_schema.foo_tbl"));
- }
-
- @Test
- void assertFormatWithoutSchema() {
- DataNode dataNode = new DataNode("foo_ds", (String) null, "foo_tbl");
- assertThat(dataNode.format(), is("foo_ds.foo_tbl"));
- }
-
- @SuppressWarnings({"SimplifiableAssertion", "ConstantValue"})
- @Test
- void assertEquals() {
- DataNode dataNode = new DataNode("ds_0.tbl_0");
- assertThat(dataNode, is(dataNode));
- assertThat(dataNode, is(new DataNode("ds_0.tbl_0")));
- assertThat(dataNode, is(new DataNode("DS_0.TBL_0")));
- assertThat(dataNode, not(new DataNode("ds_0.tbl_1")));
- assertFalse(dataNode.equals("ds.tbl"));
- assertFalse(dataNode.equals(null));
- }
-
- @Test
- void assertEqualsWithSchema() {
- DataNode dataNode = new DataNode("ds", "schema1", "tbl");
- assertThat(dataNode, not(new DataNode("ds", "schema2", "tbl")));
- assertThat(dataNode, not(new DataNode("ds", (String) null, "tbl")));
- }
-
- @Test
- void assertHashCode() {
-        assertThat(new DataNode("ds_0.tbl_0").hashCode(), is(new DataNode("ds_0.tbl_0").hashCode()));
-        assertThat(new DataNode("ds_0.tbl_0").hashCode(), is(new DataNode("DS_0.TBL_0").hashCode()));
-        assertThat(new DataNode("ds_0.db_0.tbl_0").hashCode(), is(new DataNode("ds_0.db_0.tbl_0").hashCode()));
-        assertThat(new DataNode("ds_0.db_0.tbl_0").hashCode(), is(new DataNode("DS_0.DB_0.TBL_0").hashCode()));
-        assertThat(new DataNode("DS", "SCHEMA", "TBL").hashCode(), is(new DataNode("ds", "schema", "tbl").hashCode()));
- }
-
- @Test
- void assertToString() {
-        assertThat(new DataNode("ds_0.tbl_0").toString(), is("DataNode(dataSourceName=ds_0, schemaName=null, tableName=tbl_0)"));
-        assertThat(new DataNode("ds", "schema", "tbl").toString(), is("DataNode(dataSourceName=ds, schemaName=schema, tableName=tbl)"));
-        assertThat(new DataNode("ds_0.schema_0.tbl_0").toString(), is("DataNode(dataSourceName=ds_0, schemaName=schema_0, tableName=tbl_0)"));
- }
-
- @Test
- void assertEmptyDataSourceDataNode() {
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode(".tbl_0"));
- }
-
- @Test
- void assertEmptyTableDataNode() {
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode("ds_0."));
- }
-
- @Test
- void assertNewValidDataNodeIncludeInstance() {
- DataNode dataNode = new DataNode("ds_0.schema_0.tbl_0");
- assertThat(dataNode.getDataSourceName(), is("ds_0"));
- assertThat(dataNode.getTableName(), is("tbl_0"));
- assertThat(dataNode.getSchemaName(), is("schema_0"));
- }
+    private static final DatabaseType MYSQL_DATABASE_TYPE = TypedSPILoader.getService(DatabaseType.class, "MySQL");
- @Test
- void assertNewDataNodeWithOnlyOneSegment() {
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode("ds_0"));
- }
+    private static final DatabaseType POSTGRESQL_DATABASE_TYPE = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL");
- @Test
- void assertNewDataNodeWithFourSegments() {
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode("ds_0.db_0.tbl_0.col_0"));
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("dataNodeWithoutSchemaArguments")
+    void assertNewDataNodeWithoutSchema(final String name, final String dataNodeText, final String expectedDataSourceName, final String expectedTableName) {
+ DataNode actual = new DataNode(dataNodeText);
+ assertThat(actual.getDataSourceName(), is(expectedDataSourceName));
+ assertNull(actual.getSchemaName());
+ assertThat(actual.getTableName(), is(expectedTableName));
}
- @Test
- void assertNewDataNodeWithEmptySegments() {
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode("ds_0..tbl_0"));
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("dataNodeWithSchemaArguments")
+    void assertNewDataNodeWithSchema(final String name, final String dataNodeText, final String expectedDataSourceName, final String expectedSchemaName, final String expectedTableName) {
+ DataNode actual = new DataNode(dataNodeText);
+ assertThat(actual.getDataSourceName(), is(expectedDataSourceName));
+ assertThat(actual.getSchemaName(), is(expectedSchemaName));
+ assertThat(actual.getTableName(), is(expectedTableName));
}
- @Test
- void assertNewDataNodeWithLeadingDelimiter() {
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode(".ds_0.tbl_0"));
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("invalidDataNodeArguments")
+    void assertNewDataNodeWithInvalidFormat(final String name, final String dataNodeText) {
+        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode(dataNodeText));
}
- @Test
- void assertNewDataNodeWithTrailingDelimiter() {
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode("ds_0.tbl_0."));
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("databaseTypeDataNodeArguments")
+    void assertNewDataNodeWithDatabaseType(final String name, final String databaseName, final DatabaseType databaseType, final String dataNodeText,
+                                           final String expectedDataSourceName, final String expectedSchemaName, final String expectedTableName) {
+        DataNode actual = new DataNode(databaseName, databaseType, dataNodeText);
+ assertThat(actual.getDataSourceName(), is(expectedDataSourceName));
+ assertThat(actual.getSchemaName(), is(expectedSchemaName));
+ assertThat(actual.getTableName(), is(expectedTableName));
}
@Test
- void assertNewDataNodeWithMultipleConsecutiveDelimiters() {
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode("ds_0..tbl_0"));
+ void assertNewDataNodeWithDatabaseTypeAndInvalidFormat() {
+        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode("test_db", POSTGRESQL_DATABASE_TYPE, "invalid_format_without_delimiter"));
}
- @Test
- void assertNewDataNodeWithWhitespaceInSegments() {
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode("ds_0 . tbl_0"));
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("formatArguments")
+    void assertFormat(final String name, final DataNode dataNode, final String expectedText) {
+ assertThat(dataNode.format(), is(expectedText));
}
- @Test
- void assertNewDataNodeWithSpecialCharacters() {
- DataNode dataNode = new DataNode("ds-0.tbl_0");
- assertThat(dataNode.getDataSourceName(), is("ds-0"));
- assertThat(dataNode.getTableName(), is("tbl_0"));
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("formatWithDatabaseTypeArguments")
+    void assertFormatWithDatabaseType(final String name, final DataNode dataNode, final DatabaseType databaseType, final String expectedText) {
+ assertThat(dataNode.format(databaseType), is(expectedText));
}
- @Test
- void assertNewDataNodeWithUnderscores() {
- DataNode dataNode = new DataNode("data_source_0.table_name_0");
- assertThat(dataNode.getDataSourceName(), is("data_source_0"));
- assertThat(dataNode.getTableName(), is("table_name_0"));
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("equalsArguments")
+    void assertEquals(final String name, final DataNode dataNode, final Object other, final boolean expectedMatched) {
+ assertThat(dataNode.equals(other), is(expectedMatched));
}
- @Test
- void assertNewDataNodeWithNumbers() {
- DataNode dataNode = new DataNode("ds123.tbl456");
- assertThat(dataNode.getDataSourceName(), is("ds123"));
- assertThat(dataNode.getTableName(), is("tbl456"));
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("hashCodeArguments")
+    void assertHashCode(final String name, final DataNode dataNode, final DataNode other) {
+ assertThat(dataNode.hashCode(), is(other.hashCode()));
}
- @Test
- void assertNewDataNodeWithMixedFormat() {
- DataNode dataNode = new DataNode("prod-db-01.schema_01.users");
- assertThat(dataNode.getDataSourceName(), is("prod-db-01"));
- assertThat(dataNode.getTableName(), is("users"));
+ private static Stream<Arguments> dataNodeWithoutSchemaArguments() {
+ return Stream.of(
+ Arguments.of("simple_names", "ds_0.tbl_0", "ds_0", "tbl_0"),
+                Arguments.of("special_characters", "ds-0.tbl_0", "ds-0", "tbl_0"),
+                Arguments.of("underscores", "data_source_0.table_name_0", "data_source_0", "table_name_0"),
+ Arguments.of("numbers", "ds123.tbl456", "ds123", "tbl456"),
+ Arguments.of("single_characters", "a.b", "a", "b"));
}
- @Test
- void assertNewDataNodeWithLongNames() {
-        String longDataSource = "very_long_data_source_name_that_exceeds_normal_length";
- String longTable = "very_long_table_name_that_exceeds_normal_length";
- DataNode dataNode = new DataNode(longDataSource + "." + longTable);
- assertThat(dataNode.getDataSourceName(), is(longDataSource));
- assertThat(dataNode.getTableName(), is(longTable));
+ private static Stream<Arguments> dataNodeWithSchemaArguments() {
+ return Stream.of(
+                Arguments.of("simple_schema", "ds_0.schema_0.tbl_0", "ds_0", "schema_0", "tbl_0"),
+                Arguments.of("mixed_format", "prod-db-01.schema_01.users", "prod-db-01", "schema_01", "users"),
+                Arguments.of("instance_format", "instance1.database1.table1", "instance1", "database1", "table1"),
+                Arguments.of("complex_instance_format", "prod-cluster-01.mysql-master.users", "prod-cluster-01", "mysql-master", "users"));
}
- @Test
- void assertNewDataNodeWithSingleCharacterNames() {
- DataNode dataNode = new DataNode("a.b");
- assertThat(dataNode.getDataSourceName(), is("a"));
- assertThat(dataNode.getTableName(), is("b"));
+ private static Stream<Arguments> invalidDataNodeArguments() {
+ return Stream.of(
+ Arguments.of("without_delimiter", "ds_0tbl_0"),
+ Arguments.of("too_many_segments", "ds_0.db_0.tbl_0.tbl_1"),
+ Arguments.of("invalid_delimiter", "ds_0,tbl_0"),
+ Arguments.of("empty_data_source", ".tbl_0"),
+ Arguments.of("empty_table", "ds_0."),
+ Arguments.of("consecutive_delimiters", "ds_0..tbl_0"),
+ Arguments.of("whitespace_before_delimiter", "ds_0 .tbl_0"),
+ Arguments.of("trailing_delimiter", "ds_0.tbl_0."),
+ Arguments.of("whitespace_after_delimiter", "ds_0. tbl_0"),
+ Arguments.of("blank_segment_with_tab", "ds.\t.tbl"));
}
- @Test
- void assertNewDataNodeWithInstanceFormat() {
- DataNode dataNode = new DataNode("instance1.database1.table1");
- assertThat(dataNode.getDataSourceName(), is("instance1"));
- assertThat(dataNode.getTableName(), is("table1"));
+ private static Stream<Arguments> databaseTypeDataNodeArguments() {
+ return Stream.of(
+                Arguments.of("postgresql_with_schema", "test_db", POSTGRESQL_DATABASE_TYPE, "ds.schema.tbl", "ds", "schema", "tbl"),
+                Arguments.of("postgresql_without_schema_segment", "test_db", POSTGRESQL_DATABASE_TYPE, "ds.tbl", "ds", "*", "tbl"),
+                Arguments.of("mysql_without_schema_support", "test_db", MYSQL_DATABASE_TYPE, "ds.tbl", "ds", "test_db", "tbl"),
+                Arguments.of("mysql_three_segments_kept_as_table_suffix", "test_db", MYSQL_DATABASE_TYPE, "ds.schema.tbl", "ds", "test_db", "schema.tbl"),
+                Arguments.of("postgresql_lowercases_table", "test_db", POSTGRESQL_DATABASE_TYPE, "ds.schema.TABLE", "ds", "schema", "table"));
}
- @Test
- void assertNewDataNodeWithComplexInstanceFormat() {
- DataNode dataNode = new DataNode("prod-cluster-01.mysql-master.users");
- assertThat(dataNode.getDataSourceName(), is("prod-cluster-01"));
- assertThat(dataNode.getTableName(), is("users"));
+ private static Stream<Arguments> formatArguments() {
+ return Stream.of(
+                Arguments.of("with_schema", new DataNode("foo_ds", "foo_schema", "foo_tbl"), "foo_ds.foo_schema.foo_tbl"),
+                Arguments.of("without_schema", new DataNode("foo_ds", (String) null, "foo_tbl"), "foo_ds.foo_tbl"));
}
- @Test
- void assertNewDataNodeWithDatabaseType() {
-        DataNode dataNode = new DataNode("test_db", TypedSPILoader.getService(DatabaseType.class, "PostgreSQL"), "ds.schema.tbl");
- assertThat(dataNode.getDataSourceName(), is("ds"));
- assertThat(dataNode.getSchemaName(), is("schema"));
- assertThat(dataNode.getTableName(), is("tbl"));
- }
-
- @Test
- void assertFormatWithDatabaseType() {
-        assertThat(new DataNode("ds", "schema", "tbl").format(TypedSPILoader.getService(DatabaseType.class, "PostgreSQL")), is("ds.schema.tbl"));
- }
-
- @Test
- void assertFormatWithDatabaseTypeWithoutSchema() {
- DataNode dataNode = new DataNode("ds", (String) null, "tbl");
-        assertThat(dataNode.format(TypedSPILoader.getService(DatabaseType.class, "MySQL")), is("ds.tbl"));
- }
-
- @Test
- void assertNewDataNodeWithDatabaseTypeWithoutSchemaSupport() {
-        DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "MySQL");
- DataNode dataNode = new DataNode("test_db", databaseType, "ds.tbl");
- assertThat(dataNode.getDataSourceName(), is("ds"));
- assertThat(dataNode.getSchemaName(), is("test_db"));
- assertThat(dataNode.getTableName(), is("tbl"));
+ private static Stream<Arguments> formatWithDatabaseTypeArguments() {
+ return Stream.of(
+                Arguments.of("postgresql_with_schema", new DataNode("ds", "schema", "tbl"), POSTGRESQL_DATABASE_TYPE, "ds.schema.tbl"),
+                Arguments.of("mysql_without_schema", new DataNode("ds", (String) null, "tbl"), MYSQL_DATABASE_TYPE, "ds.tbl"),
+                Arguments.of("mysql_ignores_explicit_schema", new DataNode("ds", "schema", "tbl"), MYSQL_DATABASE_TYPE, "ds.tbl"));
}
- @Test
- void assertNewDataNodeWithDatabaseTypeAndInvalidThreeSegment() {
-        DataNode dataNode = new DataNode("test_db", TypedSPILoader.getService(DatabaseType.class, "MySQL"), "ds.schema.tbl");
- assertThat(dataNode.getDataSourceName(), is("ds"));
- assertThat(dataNode.getSchemaName(), is("test_db"));
- assertThat(dataNode.getTableName(), is("schema.tbl"));
- }
-
- @Test
- void assertNewDataNodeWithDatabaseTypeCheckStateException() {
-        DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL");
-        assertThrows(InvalidDataNodeFormatException.class, () -> new DataNode("test_db", databaseType, "invalid_format_without_delimiter"));
+ private static Stream<Arguments> equalsArguments() {
+ final DataNode self = new DataNode("ds_0.tbl_0");
+ return Stream.of(
+ Arguments.of("self", self, self, true),
+                Arguments.of("null_object", new DataNode("ds_0.tbl_0"), null, false),
+                Arguments.of("different_type", new DataNode("ds_0.tbl_0"), "ds.tbl", false),
+                Arguments.of("ignore_case", new DataNode("ds_0.tbl_0"), new DataNode("DS_0.TBL_0"), true),
+                Arguments.of("different_data_source", new DataNode("ds_0.tbl_0"), new DataNode("ds_1.tbl_0"), false),
+                Arguments.of("different_table", new DataNode("ds_0.tbl_0"), new DataNode("ds_0.tbl_1"), false),
+                Arguments.of("different_schema", new DataNode("ds", "schema1", "tbl"), new DataNode("ds", "schema2", "tbl"), false));
}
- @Test
- void assertFormatWithDatabaseTypeAndNullSchema() {
-        DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL");
- DataNode dataNode = new DataNode("ds", (String) null, "tbl");
- assertThat(dataNode.format(databaseType), is("ds.tbl"));
- }
-
- @Test
- void assertFormatMethodWithTableNameLowercasing() {
-        DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "PostgreSQL");
-        DataNode dataNode = new DataNode("test_db", databaseType, "ds.schema.TABLE");
- assertThat(dataNode.getTableName(), is("table"));
- assertThat(dataNode.getSchemaName(), is("schema"));
+ private static Stream<Arguments> hashCodeArguments() {
+ return Stream.of(
+                Arguments.of("without_schema_ignore_case", new DataNode("ds_0.tbl_0"), new DataNode("DS_0.TBL_0")),
+                Arguments.of("with_schema_ignore_case", new DataNode("ds_0.db_0.tbl_0"), new DataNode("DS_0.DB_0.TBL_0")),
+                Arguments.of("manual_constructor_ignore_case", new DataNode("DS", "SCHEMA", "TBL"), new DataNode("ds", "schema", "tbl")));
}
}
diff --git a/infra/common/src/test/java/org/apache/shardingsphere/infra/hint/HintValueContextTest.java b/infra/common/src/test/java/org/apache/shardingsphere/infra/hint/HintValueContextTest.java
index c5649013f2a..0b5715e4d4a 100644
--- a/infra/common/src/test/java/org/apache/shardingsphere/infra/hint/HintValueContextTest.java
+++ b/infra/common/src/test/java/org/apache/shardingsphere/infra/hint/HintValueContextTest.java
@@ -18,12 +18,16 @@
package org.apache.shardingsphere.infra.hint;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
import java.util.Collection;
import java.util.Optional;
+import java.util.stream.Stream;
-import static org.hamcrest.Matchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -43,67 +47,90 @@ class HintValueContextTest {
assertThat(actual.get(), is("foo_ds"));
}
- @Test
- void assertContainsHintShardingDatabaseValue() {
- HintValueContext hintValueContext = new HintValueContext();
-        hintValueContext.getShardingDatabaseValues().put("TABLE.SHARDING_DATABASE_VALUE", "1");
-        assertTrue(hintValueContext.containsHintShardingDatabaseValue("table"));
-        assertTrue(hintValueContext.containsHintShardingDatabaseValue("TABLE"));
-        assertFalse(hintValueContext.containsHintShardingDatabaseValue("other"));
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("containsHintShardingDatabaseValueArguments")
+    void assertContainsHintShardingDatabaseValue(final String name, final HintValueContext hintValueContext, final String tableName, final boolean expectedContainsHintShardingDatabaseValue) {
+        assertThat(hintValueContext.containsHintShardingDatabaseValue(tableName), is(expectedContainsHintShardingDatabaseValue));
}
- @Test
- void assertContainsHintShardingTableValue() {
- HintValueContext hintValueContext = new HintValueContext();
- hintValueContext.getShardingTableValues().put("TABLE.SHARDING_TABLE_VALUE", "1");
- assertTrue(hintValueContext.containsHintShardingTableValue("table"));
- assertTrue(hintValueContext.containsHintShardingTableValue("TABLE"));
- assertFalse(hintValueContext.containsHintShardingTableValue("other"));
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("containsHintShardingTableValueArguments")
+ void assertContainsHintShardingTableValue(final String name, final HintValueContext hintValueContext, final String tableName, final boolean expectedContainsHintShardingTableValue) {
+ assertThat(hintValueContext.containsHintShardingTableValue(tableName), is(expectedContainsHintShardingTableValue));
}
- @Test
- void assertContainsHintShardingValue() {
- HintValueContext hintValueContext = new HintValueContext();
- hintValueContext.getShardingDatabaseValues().put("TABLE.SHARDING_DATABASE_VALUE", "1");
- assertTrue(hintValueContext.containsHintShardingValue("table"));
- hintValueContext.getShardingDatabaseValues().clear();
- hintValueContext.getShardingTableValues().put("OTHER_TABLE.SHARDING_TABLE_VALUE", "1");
- assertFalse(hintValueContext.containsHintShardingValue("table"));
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("containsHintShardingValueArguments")
+ void assertContainsHintShardingValue(final String name, final HintValueContext hintValueContext, final String tableName, final boolean expectedContainsHintShardingValue) {
+ assertThat(hintValueContext.containsHintShardingValue(tableName), is(expectedContainsHintShardingValue));
}
@Test
void assertGetHintShardingTableValueWithTableName() {
HintValueContext hintValueContext = new HintValueContext();
- hintValueContext.getShardingTableValues().put("TABLE.SHARDING_TABLE_VALUE", "1");
- Collection<Comparable<?>> actual = hintValueContext.getHintShardingTableValue("table");
- assertThat(actual.size(), is(1));
- assertThat(actual.iterator().next(), is("1"));
+ hintValueContext.getShardingTableValues().put("FOO_TABLE.SHARDING_TABLE_VALUE", "foo_value");
+ Collection<Comparable<?>> actualHintShardingTableValue = hintValueContext.getHintShardingTableValue("foo_table");
+ assertThat(actualHintShardingTableValue.size(), is(1));
+ assertThat(actualHintShardingTableValue.iterator().next(), is("foo_value"));
}
@Test
- void assertSetHintShardingTableValueWithoutTableName() {
+ void assertGetHintShardingTableValueWithoutTableName() {
HintValueContext hintValueContext = new HintValueContext();
- hintValueContext.getShardingTableValues().put("SHARDING_TABLE_VALUE", "2");
- Collection<Comparable<?>> actual = hintValueContext.getHintShardingTableValue("other_table");
- assertThat(actual.size(), is(1));
- assertThat(actual.iterator().next(), is("2"));
+ hintValueContext.getShardingTableValues().put("SHARDING_TABLE_VALUE", "bar_value");
+ Collection<Comparable<?>> actualHintShardingTableValue = hintValueContext.getHintShardingTableValue("bar_table");
+ assertThat(actualHintShardingTableValue.size(), is(1));
+ assertThat(actualHintShardingTableValue.iterator().next(), is("bar_value"));
}
@Test
void assertGetHintShardingDatabaseValueWithTableName() {
HintValueContext hintValueContext = new HintValueContext();
- hintValueContext.getShardingDatabaseValues().put("TABLE.SHARDING_DATABASE_VALUE", "1");
- Collection<Comparable<?>> actual = hintValueContext.getHintShardingDatabaseValue("table");
- assertThat(actual.size(), is(1));
- assertThat(actual.iterator().next(), is("1"));
+ hintValueContext.getShardingDatabaseValues().put("FOO_TABLE.SHARDING_DATABASE_VALUE", "foo_value");
+ Collection<Comparable<?>> actualHintShardingDatabaseValue = hintValueContext.getHintShardingDatabaseValue("foo_table");
+ assertThat(actualHintShardingDatabaseValue.size(), is(1));
+ assertThat(actualHintShardingDatabaseValue.iterator().next(), is("foo_value"));
}
@Test
void assertGetHintShardingDatabaseValueWithoutTableName() {
HintValueContext hintValueContext = new HintValueContext();
- hintValueContext.getShardingDatabaseValues().put("SHARDING_DATABASE_VALUE", "2");
- Collection<Comparable<?>> actual = hintValueContext.getHintShardingDatabaseValue("other_table");
- assertThat(actual.size(), is(1));
- assertThat(actual.iterator().next(), is("2"));
+ hintValueContext.getShardingDatabaseValues().put("SHARDING_DATABASE_VALUE", "bar_value");
+ Collection<Comparable<?>> actualHintShardingDatabaseValue = hintValueContext.getHintShardingDatabaseValue("bar_table");
+ assertThat(actualHintShardingDatabaseValue.size(), is(1));
+ assertThat(actualHintShardingDatabaseValue.iterator().next(), is("bar_value"));
+ }
+
+ private static Stream<Arguments> containsHintShardingDatabaseValueArguments() {
+ return Stream.of(
+ Arguments.of("table_name_key", createHintValueContextWithShardingDatabaseValue("FOO_TABLE.SHARDING_DATABASE_VALUE"), "foo_table", true),
+ Arguments.of("global_key", createHintValueContextWithShardingDatabaseValue("SHARDING_DATABASE_VALUE"), "bar_table", true),
+ Arguments.of("missing_key", new HintValueContext(), "bar_table", false));
+ }
+
+ private static Stream<Arguments> containsHintShardingTableValueArguments() {
+ return Stream.of(
+ Arguments.of("table_name_key", createHintValueContextWithShardingTableValue("FOO_TABLE.SHARDING_TABLE_VALUE"), "foo_table", true),
+ Arguments.of("global_key", createHintValueContextWithShardingTableValue("SHARDING_TABLE_VALUE"), "bar_table", true),
+ Arguments.of("missing_key", new HintValueContext(), "bar_table", false));
+ }
+
+ private static Stream<Arguments> containsHintShardingValueArguments() {
+ return Stream.of(
+ Arguments.of("database_value", createHintValueContextWithShardingDatabaseValue("FOO_TABLE.SHARDING_DATABASE_VALUE"), "foo_table", true),
+ Arguments.of("table_value", createHintValueContextWithShardingTableValue("FOO_TABLE.SHARDING_TABLE_VALUE"), "foo_table", true),
+ Arguments.of("missing_value", new HintValueContext(), "bar_table", false));
+ }
+
+ private static HintValueContext createHintValueContextWithShardingDatabaseValue(final String key) {
+ HintValueContext result = new HintValueContext();
+ result.getShardingDatabaseValues().put(key, "foo_value");
+ return result;
+ }
+
+ private static HintValueContext createHintValueContextWithShardingTableValue(final String key) {
+ HintValueContext result = new HintValueContext();
+ result.getShardingTableValues().put(key, "foo_value");
+ return result;
}
}
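The parameterized lookup tests above all exercise the same contract: a key scoped as `<TABLE>.SHARDING_DATABASE_VALUE` matches its table name case-insensitively, and a bare `SHARDING_DATABASE_VALUE` key acts as a global fallback for any table. A minimal standalone sketch of that contract follows; the class and field here are hypothetical stand-ins, not ShardingSphere's actual `HintValueContext` implementation:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of the lookup rule the parameterized tests assert;
// not the real HintValueContext.
public final class HintLookupSketch {

    final Map<String, Comparable<?>> shardingDatabaseValues = new HashMap<>();

    // True when either a table-scoped key ("<TABLE>.SHARDING_DATABASE_VALUE",
    // matched case-insensitively) or the global "SHARDING_DATABASE_VALUE" key is present.
    boolean containsHintShardingDatabaseValue(final String tableName) {
        String tableKey = tableName.toUpperCase() + ".SHARDING_DATABASE_VALUE";
        return shardingDatabaseValues.containsKey(tableKey) || shardingDatabaseValues.containsKey("SHARDING_DATABASE_VALUE");
    }

    public static void main(final String[] args) {
        HintLookupSketch context = new HintLookupSketch();
        context.shardingDatabaseValues.put("FOO_TABLE.SHARDING_DATABASE_VALUE", "foo_value");
        System.out.println(context.containsHintShardingDatabaseValue("foo_table")); // true: table-scoped key, case-insensitive
        System.out.println(context.containsHintShardingDatabaseValue("bar_table")); // false: no key, no global fallback
        context.shardingDatabaseValues.put("SHARDING_DATABASE_VALUE", "bar_value");
        System.out.println(context.containsHintShardingDatabaseValue("bar_table")); // true: global fallback
    }
}
```

This mirrors the three argument rows each `@MethodSource` supplies (table-scoped key, global key, missing key), which is why the diff can collapse the earlier one-scenario-per-`@Test` methods into a single parameterized method per operation.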