https://gcc.gnu.org/g:0c517ddf9b136c9077b21142ec4118451d394bdb

commit r16-3028-g0c517ddf9b136c9077b21142ec4118451d394bdb
Author: Kito Cheng <kito.ch...@sifive.com>
Date:   Thu Jul 31 16:25:52 2025 +0800

    RISC-V: Read extension data from riscv-ext*.def for arch-canonicalize
    
    Previously, arch-canonicalize used hardcoded data to handle IMPLIED_EXT.
    But this data often got out of sync with the actual C++ implementation.
    Earlier, we introduced riscv-ext.def to keep track of all extension info
    and generate docs. Now, arch-canonicalize also uses this same data to handle
    extension implication rules directly.
    
    One limitation is that conditional implication rules still need to be written
    manually. Luckily, there aren't many of them for now, so it's still manageable.
    I really wanted to avoid writing a C++ + Python binding or trying to parse C++
    logic in Python...
    
    This version also adds a `--selftest` option to run some unit tests.
    
    gcc/ChangeLog:
    
            * config/riscv/arch-canonicalize: Read extension data from
            riscv-ext*.def and adding unittest.

Diff:
---
 gcc/config/riscv/arch-canonicalize | 563 +++++++++++++++++++++++++++++++------
 1 file changed, 482 insertions(+), 81 deletions(-)

diff --git a/gcc/config/riscv/arch-canonicalize b/gcc/config/riscv/arch-canonicalize
index 34dad45233ae..5d24f5eda2fe 100755
--- a/gcc/config/riscv/arch-canonicalize
+++ b/gcc/config/riscv/arch-canonicalize
@@ -20,77 +20,314 @@
 # along with GCC; see the file COPYING3.  If not see
 # <http://www.gnu.org/licenses/>.
 
-# TODO: Extract riscv_subset_t from riscv-common.cc and make it can be compiled
-#       standalone to replace this script, that also prevents us implementing
-#       that twice and keep sync again and again.
-
 from __future__ import print_function
 import sys
 import argparse
 import collections
 import itertools
+import re
+import os
 from functools import reduce
 
 SUPPORTED_ISA_SPEC = ["2.2", "20190608", "20191213"]
 CANONICAL_ORDER = "imafdqlcbkjtpvnh"
 LONG_EXT_PREFIXES = ['z', 's', 'h', 'x']
 
+def parse_define_riscv_ext(content):
+  """Parse DEFINE_RISCV_EXT macros using position-based parsing."""
+  extensions = []
+
+  # Find all DEFINE_RISCV_EXT blocks
+  pattern = r'DEFINE_RISCV_EXT\s*\('
+  matches = []
+
+  pos = 0
+  while True:
+    match = re.search(pattern, content[pos:])
+    if not match:
+      break
+
+    start_pos = pos + match.start()
+    paren_count = 0
+    current_pos = pos + match.end() - 1  # Start at the opening parenthesis
+
+    # Find the matching closing parenthesis
+    while current_pos < len(content):
+      if content[current_pos] == '(':
+        paren_count += 1
+      elif content[current_pos] == ')':
+        paren_count -= 1
+        if paren_count == 0:
+          break
+      current_pos += 1
+
+    if paren_count == 0:
+      # Extract the content inside parentheses
+      macro_content = content[pos + match.end():current_pos]
+      ext_data = parse_macro_arguments(macro_content)
+      if ext_data:
+        extensions.append(ext_data)
+
+    pos = current_pos + 1
+
+  return extensions
+
+def parse_macro_arguments(macro_content):
+  """Parse the arguments of a DEFINE_RISCV_EXT macro."""
+  # Remove comments /* ... */
+  cleaned_content = re.sub(r'/\*[^*]*\*/', '', macro_content)
+
+  # Split arguments by comma, but respect nested structures
+  args = []
+  current_arg = ""
+  paren_count = 0
+  brace_count = 0
+  in_string = False
+  escape_next = False
+
+  for char in cleaned_content:
+    if escape_next:
+      current_arg += char
+      escape_next = False
+      continue
+
+    if char == '\\':
+      escape_next = True
+      current_arg += char
+      continue
+
+    if char == '"' and not escape_next:
+      in_string = not in_string
+      current_arg += char
+      continue
+
+    if in_string:
+      current_arg += char
+      continue
+
+    if char == '(':
+      paren_count += 1
+    elif char == ')':
+      paren_count -= 1
+    elif char == '{':
+      brace_count += 1
+    elif char == '}':
+      brace_count -= 1
+    elif char == ',' and paren_count == 0 and brace_count == 0:
+      args.append(current_arg.strip())
+      current_arg = ""
+      continue
+
+    current_arg += char
+
+  # Add the last argument
+  if current_arg.strip():
+    args.append(current_arg.strip())
+
+  # We need at least 6 arguments to get DEP_EXTS (position 5)
+  if len(args) < 6:
+    return None
+
+  ext_name = args[0].strip()
+  dep_exts_arg = args[5].strip()  # DEP_EXTS is at position 5
+
+  # Parse dependency extensions from the DEP_EXTS argument
+  deps = parse_dep_exts(dep_exts_arg)
+
+  return {
+    'name': ext_name,
+    'dep_exts': deps
+  }
+
+def parse_dep_exts(dep_exts_str):
+  """Parse the DEP_EXTS argument to extract dependency list with conditions."""
+  # Remove outer parentheses if present
+  dep_exts_str = dep_exts_str.strip()
+  if dep_exts_str.startswith('(') and dep_exts_str.endswith(')'):
+    dep_exts_str = dep_exts_str[1:-1].strip()
+
+  # Remove outer braces if present
+  if dep_exts_str.startswith('{') and dep_exts_str.endswith('}'):
+    dep_exts_str = dep_exts_str[1:-1].strip()
+
+  if not dep_exts_str:
+    return []
+
+  deps = []
+
+  # First, find and process conditional dependencies
+  conditional_pattern = r'\{\s*"([^"]+)"\s*,\s*(\[.*?\]\s*\([^)]*\)\s*->\s*bool.*?)\}'
+  conditional_matches = []
+
+  for match in re.finditer(conditional_pattern, dep_exts_str, re.DOTALL):
+    ext_name = match.group(1)
+    condition_code = match.group(2)
+    deps.append({'ext': ext_name, 'type': 'conditional', 'condition': condition_code})
+    conditional_matches.append((match.start(), match.end()))
+
+  # Remove conditional dependency blocks from the string
+  remaining_str = dep_exts_str
+  for start, end in reversed(conditional_matches):  # Reverse order to maintain indices
+    remaining_str = remaining_str[:start] + remaining_str[end:]
+
+  # Now handle simple quoted strings in the remaining text
+  for match in re.finditer(r'"([^"]+)"', remaining_str):
+    deps.append({'ext': match.group(1), 'type': 'simple'})
+
+  # Remove duplicates while preserving order
+  seen = set()
+  unique_deps = []
+  for dep in deps:
+    key = (dep['ext'], dep['type'])
+    if key not in seen:
+      seen.add(key)
+      unique_deps.append(dep)
+
+  return unique_deps
+
+def evaluate_conditional_dependency(ext, dep, xlen, current_exts):
+  """Evaluate whether a conditional dependency should be included."""
+  ext_name = dep['ext']
+  condition = dep['condition']
+  # Parse the condition based on known patterns
+  if ext_name == 'zcf' and ext in ['zca', 'c', 'zce']:
+    # zcf depends on RV32 and F extension
+    return xlen == 32 and 'f' in current_exts
+  elif ext_name == 'zcd' and ext in ['zca', 'c']:
+    # zcd depends on D extension
+    return 'd' in current_exts
+  elif ext_name == 'c' and ext in ['zca']:
+    # Special case for zca -> c conditional dependency
+    if xlen == 32:
+      if 'd' in current_exts:
+        return 'zcf' in current_exts and 'zcd' in current_exts
+      elif 'f' in current_exts:
+        return 'zcf' in current_exts
+      else:
+        return True
+    elif xlen == 64:
+      if 'd' in current_exts:
+        return 'zcd' in current_exts
+      else:
+        return True
+    return False
+  else:
+    # Report error for unhandled conditional dependencies
+    import sys
+    print(f"ERROR: Unhandled conditional dependency: '{ext_name}' with condition:", file=sys.stderr)
+    print(f"  Condition code: {condition[:100]}...", file=sys.stderr)
+    print(f"  Current context: xlen={xlen}, exts={sorted(current_exts)}", file=sys.stderr)
+    # For now, return False to be safe
+    return False
+
+def resolve_dependencies(arch_parts, xlen):
+  """Resolve all dependencies including conditional ones."""
+  current_exts = set(arch_parts)
+  implied_deps = set()
+
+  # Keep resolving until no new dependencies are found
+  changed = True
+  while changed:
+    changed = False
+    new_deps = set()
+
+    for ext in current_exts | implied_deps:
+      if ext in IMPLIED_EXT:
+        for dep in IMPLIED_EXT[ext]:
+          if dep['type'] == 'simple':
+            if dep['ext'] not in current_exts and dep['ext'] not in implied_deps:
+              new_deps.add(dep['ext'])
+              changed = True
+          elif dep['type'] == 'conditional':
+            should_include = evaluate_conditional_dependency(ext, dep, xlen, current_exts | implied_deps)
+            if should_include:
+              if dep['ext'] not in current_exts and dep['ext'] not in implied_deps:
+                new_deps.add(dep['ext'])
+                changed = True
+
+    implied_deps.update(new_deps)
+
+  return implied_deps
+
+def parse_def_file(file_path, script_dir, processed_files=None, collect_all=False):
+  """Parse a single .def file and recursively process #include directives."""
+  if processed_files is None:
+    processed_files = set()
+
+  # Avoid infinite recursion
+  if file_path in processed_files:
+    return ({}, set()) if collect_all else {}
+  processed_files.add(file_path)
+
+  implied_ext = {}
+  all_extensions = set() if collect_all else None
+
+  if not os.path.exists(file_path):
+    return (implied_ext, all_extensions) if collect_all else implied_ext
+
+  with open(file_path, 'r') as f:
+    content = f.read()
+
+  # Process #include directives first
+  include_pattern = r'#include\s+"([^"]+)"'
+  includes = re.findall(include_pattern, content)
+
+  for include_file in includes:
+    include_path = os.path.join(script_dir, include_file)
+    if collect_all:
+      included_ext, included_all = parse_def_file(include_path, script_dir, processed_files, collect_all)
+      implied_ext.update(included_ext)
+      all_extensions.update(included_all)
+    else:
+      included_ext = parse_def_file(include_path, script_dir, processed_files, collect_all)
+      implied_ext.update(included_ext)
+
+  # Parse DEFINE_RISCV_EXT blocks using position-based parsing
+  parsed_exts = parse_define_riscv_ext(content)
+
+  for ext_data in parsed_exts:
+    ext_name = ext_data['name']
+    deps = ext_data['dep_exts']
+
+    if collect_all:
+      all_extensions.add(ext_name)
+
+    if deps:
+      implied_ext[ext_name] = deps
+
+  return (implied_ext, all_extensions) if collect_all else implied_ext
+
+def parse_def_files():
+  """Parse RISC-V extension definition files starting from riscv-ext.def."""
+  # Get directory containing this script
+  try:
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+  except NameError:
+    # When __file__ is not defined (e.g., interactive mode)
+    script_dir = os.getcwd()
+
+  # Start with the main definition file
+  main_def_file = os.path.join(script_dir, 'riscv-ext.def')
+  return parse_def_file(main_def_file, script_dir)
+
+def get_all_extensions():
+  """Get all supported extensions and their implied extensions."""
+  # Get directory containing this script
+  try:
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+  except NameError:
+    # When __file__ is not defined (e.g., interactive mode)
+    script_dir = os.getcwd()
+
+  # Start with the main definition file
+  main_def_file = os.path.join(script_dir, 'riscv-ext.def')
+  return parse_def_file(main_def_file, script_dir, collect_all=True)
+
 #
 # IMPLIED_EXT(ext) -> implied extension list.
+# This is loaded dynamically from .def files
 #
-IMPLIED_EXT = {
-  "d" : ["f", "zicsr"],
-
-  "a" : ["zaamo", "zalrsc"],
-  "zabha" : ["zaamo"],
-  "zacas" : ["zaamo"],
-
-  "f" : ["zicsr"],
-  "b" : ["zba", "zbb", "zbs"],
-  "zdinx" : ["zfinx", "zicsr"],
-  "zfinx" : ["zicsr"],
-  "zhinx" : ["zhinxmin", "zfinx", "zicsr"],
-  "zhinxmin" : ["zfinx", "zicsr"],
-
-  "zk" : ["zkn", "zkr", "zkt"],
-  "zkn" : ["zbkb", "zbkc", "zbkx", "zkne", "zknd", "zknh"],
-  "zks" : ["zbkb", "zbkc", "zbkx", "zksed", "zksh"],
-
-  "v" : ["zvl128b", "zve64d"],
-  "zve32x" : ["zvl32b"],
-  "zve64x" : ["zve32x", "zvl64b"],
-  "zve32f" : ["f", "zve32x"],
-  "zve64f" : ["f", "zve32f", "zve64x"],
-  "zve64d" : ["d", "zve64f"],
-
-  "zvl64b" : ["zvl32b"],
-  "zvl128b" : ["zvl64b"],
-  "zvl256b" : ["zvl128b"],
-  "zvl512b" : ["zvl256b"],
-  "zvl1024b" : ["zvl512b"],
-  "zvl2048b" : ["zvl1024b"],
-  "zvl4096b" : ["zvl2048b"],
-  "zvl8192b" : ["zvl4096b"],
-  "zvl16384b" : ["zvl8192b"],
-  "zvl32768b" : ["zvl16384b"],
-  "zvl65536b" : ["zvl32768b"],
-
-  "zvkn"   : ["zvkned", "zvknhb", "zvkb", "zvkt"],
-  "zvknc"  : ["zvkn", "zvbc"],
-  "zvkng"  : ["zvkn", "zvkg"],
-  "zvks"   : ["zvksed", "zvksh", "zvkb", "zvkt"],
-  "zvksc"  : ["zvks", "zvbc"],
-  "zvksg"  : ["zvks", "zvkg"],
-  "zvbb"   : ["zvkb"],
-  "zvbc"   : ["zve64x"],
-  "zvkb"   : ["zve32x"],
-  "zvkg"   : ["zve32x"],
-  "zvkned" : ["zve32x"],
-  "zvknha" : ["zve32x"],
-  "zvknhb" : ["zve64x"],
-  "zvksed" : ["zve32x"],
-  "zvksh"  : ["zve32x"],
-}
+IMPLIED_EXT = parse_def_files()
 
 def arch_canonicalize(arch, isa_spec):
   # TODO: Support extension version.
@@ -123,21 +360,31 @@ def arch_canonicalize(arch, isa_spec):
   long_exts += extra_long_ext
 
   #
-  # Handle implied extensions.
+  # Handle implied extensions using new conditional logic.
   #
-  any_change = True
-  while any_change:
-    any_change = False
-    for ext in std_exts + long_exts:
-      if ext in IMPLIED_EXT:
-        implied_exts = IMPLIED_EXT[ext]
-        for implied_ext in implied_exts:
-          if implied_ext == 'zicsr' and is_isa_spec_2p2:
-              continue
+  # Extract xlen from architecture string
+  # TODO: We should support profile here.
+  if arch.startswith('rv32'):
+    xlen = 32
+  elif arch.startswith('rv64'):
+    xlen = 64
+  else:
+    raise Exception("Unsupported prefix `%s`" % arch)
 
-          if implied_ext not in std_exts + long_exts:
-            long_exts.append(implied_ext)
-            any_change = True
+  # Get all current extensions
+  current_exts = std_exts + long_exts
+
+  # Resolve dependencies
+  implied_deps = resolve_dependencies(current_exts, xlen)
+
+  # Filter out zicsr for ISA spec 2.2
+  if is_isa_spec_2p2:
+    implied_deps.discard('zicsr')
+
+  # Add implied dependencies to long_exts
+  for dep in implied_deps:
+    if dep not in current_exts:
+      long_exts.append(dep)
 
   # Single letter extension might appear in the long_exts list,
   # because we just append extensions list to the arch string.
@@ -179,17 +426,171 @@ def arch_canonicalize(arch, isa_spec):
 
   return new_arch
 
-if len(sys.argv) < 2:
-  print ("Usage: %s <arch_str> [<arch_str>*]" % sys.argv)
-  sys.exit(1)
+def dump_all_extensions():
+  """Dump all extensions and their implied extensions."""
+  implied_ext, all_extensions = get_all_extensions()
 
-parser = argparse.ArgumentParser()
-parser.add_argument('-misa-spec', type=str,
-                    default='20191213',
-                    choices=SUPPORTED_ISA_SPEC)
-parser.add_argument('arch_strs', nargs=argparse.REMAINDER)
+  print("All supported RISC-V extensions:")
+  print("=" * 60)
 
-args = parser.parse_args()
+  if not all_extensions:
+    print("No extensions found.")
+    return
 
-for arch in args.arch_strs:
-  print (arch_canonicalize(arch, args.misa_spec))
+  # Sort all extensions for consistent output
+  sorted_all = sorted(all_extensions)
+
+  # Print all extensions with their dependencies (if any)
+  for ext_name in sorted_all:
+    if ext_name in implied_ext:
+      deps = implied_ext[ext_name]
+      dep_strs = []
+      for dep in deps:
+        if dep['type'] == 'simple':
+          dep_strs.append(dep['ext'])
+        else:
+          dep_strs.append(f"{dep['ext']}*")  # Mark conditional deps with *
+      print(f"{ext_name:15} -> {', '.join(dep_strs)}")
+    else:
+      print(f"{ext_name:15} -> (no dependencies)")
+
+  print(f"\nTotal extensions: {len(all_extensions)}")
+  print(f"Extensions with dependencies: {len(implied_ext)}")
+  print(f"Extensions without dependencies: {len(all_extensions) - len(implied_ext)}")
+
+def run_unit_tests():
+  """Run unit tests using pytest dynamically imported."""
+  try:
+    import pytest
+  except ImportError:
+    print("Error: pytest is required for running unit tests.")
+    print("Please install pytest: pip install pytest")
+    return 1
+
+  # Define test functions
+  def test_basic_arch_parsing():
+    """Test basic architecture string parsing."""
+    result = arch_canonicalize("rv64i", "20191213")
+    assert result == "rv64i"
+
+  def test_simple_extensions():
+    """Test simple extension handling."""
+    result = arch_canonicalize("rv64im", "20191213")
+    assert "zmmul" in result
+
+  def test_implied_extensions():
+    """Test implied extension resolution."""
+    result = arch_canonicalize("rv64imaf", "20191213")
+    assert "zicsr" in result
+
+  def test_conditional_dependencies():
+    """Test conditional dependency evaluation."""
+    # Test RV32 with F extension should include zcf when c is present
+    result = arch_canonicalize("rv32ifc", "20191213")
+    parts = result.split("_")
+    if "c" in parts:
+      assert "zca" in parts
+      if "f" in parts:
+        assert "zcf" in parts
+
+  def test_parse_dep_exts():
+    """Test dependency parsing function."""
+    # Test simple dependency
+    deps = parse_dep_exts('{"ext1", "ext2"}')
+    assert len(deps) == 2
+    assert deps[0]['ext'] == 'ext1'
+    assert deps[0]['type'] == 'simple'
+
+  def test_evaluate_conditional_dependency():
+    """Test conditional dependency evaluation."""
+    # Test zcf condition for RV32 with F
+    dep = {'ext': 'zcf', 'type': 'conditional', 'condition': 'test'}
+    result = evaluate_conditional_dependency('zce', dep, 32, {'f'})
+    assert result == True
+
+    # Test zcf condition for RV64 with F (should be False)
+    result = evaluate_conditional_dependency('zce', dep, 64, {'f'})
+    assert result == False
+
+  def test_parse_define_riscv_ext():
+    """Test DEFINE_RISCV_EXT parsing."""
+    content = '''
+    DEFINE_RISCV_EXT(
+      /* NAME */ test,
+      /* UPPERCASE_NAME */ TEST,
+      /* FULL_NAME */ "Test extension",
+      /* DESC */ "",
+      /* URL */ ,
+      /* DEP_EXTS */ ({"dep1", "dep2"}),
+      /* SUPPORTED_VERSIONS */ ({{1, 0}}),
+      /* FLAG_GROUP */ test,
+      /* BITMASK_GROUP_ID */ 0,
+      /* BITMASK_BIT_POSITION*/ 0,
+      /* EXTRA_EXTENSION_FLAGS */ 0)
+    '''
+
+    extensions = parse_define_riscv_ext(content)
+    assert len(extensions) == 1
+    assert extensions[0]['name'] == 'test'
+    assert len(extensions[0]['dep_exts']) == 2
+
+  # Collect test functions
+  test_functions = [
+    test_basic_arch_parsing,
+    test_simple_extensions,
+    test_implied_extensions,
+    test_conditional_dependencies,
+    test_parse_dep_exts,
+    test_evaluate_conditional_dependency,
+    test_parse_define_riscv_ext
+  ]
+
+  # Run tests manually first, then optionally with pytest
+  print("Running unit tests...")
+
+  passed = 0
+  failed = 0
+
+  for i, test_func in enumerate(test_functions):
+    try:
+      print(f"  Running {test_func.__name__}...", end=" ")
+      test_func()
+      print("PASSED")
+      passed += 1
+    except Exception as e:
+      print(f"FAILED: {e}")
+      failed += 1
+
+  print(f"\nTest Summary: {passed} passed, {failed} failed")
+
+  if failed == 0:
+    print("\nAll tests passed!")
+    return 0
+  else:
+    print(f"\n{failed} test(s) failed!")
+    return 1
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser()
+  parser.add_argument('-misa-spec', type=str,
+                      default='20191213',
+                      choices=SUPPORTED_ISA_SPEC)
+  parser.add_argument('--dump-all', action='store_true',
+                      help='Dump all extensions and their implied extensions')
+  parser.add_argument('--selftest', action='store_true',
+                      help='Run unit tests using pytest')
+  parser.add_argument('arch_strs', nargs='*',
+                      help='Architecture strings to canonicalize')
+
+  args = parser.parse_args()
+
+  if args.dump_all:
+    dump_all_extensions()
+  elif args.selftest:
+    sys.exit(run_unit_tests())
+  elif args.arch_strs:
+    for arch in args.arch_strs:
+      print (arch_canonicalize(arch, args.misa_spec))
+  else:
+    parser.print_help()
+    sys.exit(1)

Reply via email to