https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/164332

>From 3769e8450d3e46a299e8c03f26b5caaf22fbf74a Mon Sep 17 00:00:00 2001
From: svkeerthy <[email protected]>
Date: Mon, 20 Oct 2025 23:03:17 +0000
Subject: [PATCH] Support MIR vocabulary training in generateTriplets.py

---
 .../mlgo-utils/IR2Vec/generateTriplets.py     | 165 ++++++++++++++----
 1 file changed, 131 insertions(+), 34 deletions(-)

diff --git a/llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py b/llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py
index 80ac4c61c7871..dba9e2c137586 100644
--- a/llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py
+++ b/llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py
@@ -1,14 +1,19 @@
 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-"""IR2Vec Triplet Generator
+"""IR2Vec/MIR2Vec Triplet Generator
 
-Generates IR2Vec triplets by applying random optimization levels to LLVM IR 
files
-and extracting triplets using llvm-ir2vec. Automatically generates preprocessed
-files: entity2id.txt, relation2id.txt, and train2id.txt.
+Generates IR2Vec or MIR2Vec triplets by applying random optimization levels to
+LLVM IR files (or processing MIR files) and extracting triplets using 
llvm-ir2vec.
+Automatically generates preprocessed files (entity2id.txt, relation2id.txt, and
+train2id.txt) necessary for training IR2Vec or MIR2Vec vocabularies.
 
 Usage:
-    python generateTriplets.py <llvm_build_dir> <num_optimizations> 
<ll_file_list> <output_dir>
+    For LLVM IR:
+        python generateTriplets.py <llvm_build_dir> <num_optimizations> 
<ll_file_list> <output_dir>
+
+    For Machine IR:
+        python generateTriplets.py --mode=mir <llvm_build_dir> <mir_file_list> 
<output_dir>
 """
 
 import argparse
@@ -41,7 +46,7 @@ def __init__(self, triplets: Set[str], max_relation: int):
 
 
 class IR2VecTripletGenerator:
-    """Main class for generating IR2Vec triplets"""
+    """Main class for generating IR2Vec or MIR2Vec triplets"""
 
     def __init__(
         self,
@@ -49,11 +54,13 @@ def __init__(
         num_optimizations: int,
         output_dir: Path,
         max_workers: int = DEFAULT_MAX_WORKERS,
+        mode: str = "llvm",
     ):
         self.llvm_build_dir = llvm_build_dir
         self.num_optimizations = num_optimizations
         self.output_dir = output_dir
         self.max_workers = max_workers
+        self.mode = mode  # "llvm" or "mir"
 
         # Tool paths
         self.opt_binary = os.path.join(llvm_build_dir, "bin", "opt")
@@ -85,7 +92,11 @@ def _validate_setup(self):
                 f"llvm-ir2vec binary not found or not executable: 
{self.ir2vec_binary}"
             )
 
-        if not (1 <= self.num_optimizations <= len(OPT_LEVELS)):
+        if self.mode not in ["llvm", "mir"]:
+            raise ValueError(f"Mode must be 'llvm' or 'mir', got: {self.mode}")
+
+        # For LLVM IR mode, validate optimization count
+        if self.mode == "llvm" and not (1 <= self.num_optimizations <= 
len(OPT_LEVELS)):
             raise ValueError(
                 f"Number of optimizations must be between 1-{len(OPT_LEVELS)}"
             )
@@ -95,19 +106,28 @@ def _select_optimization_levels(self) -> List[str]:
         return random.sample(OPT_LEVELS, self.num_optimizations)
 
     def _process_single_file(self, input_file: Path) -> TripletResult:
-        """Process a single LLVM IR file with multiple optimization levels"""
+        """Process a single LLVM IR or MIR file"""
         all_triplets = set()
         max_relation = 1
-        opt_levels = self._select_optimization_levels()
 
-        for opt_level in opt_levels:
-            triplets, file_max_relation = self._run_pipeline(input_file, 
opt_level)
+        if self.mode == "mir":
+            # For MIR files, process directly without optimization
+            triplets, file_max_relation = self._run_mir_pipeline(input_file)
             if triplets:
                 all_triplets.update(triplets)
                 max_relation = max(max_relation, file_max_relation)
-                logger.debug(
-                    f"Generated {len(triplets)} triplets for {input_file} with 
{opt_level}"
-                )
+                logger.debug(f"Generated {len(triplets)} triplets for 
{input_file}")
+        else:
+            # For LLVM IR files, apply multiple optimization levels
+            opt_levels = self._select_optimization_levels()
+            for opt_level in opt_levels:
+                triplets, file_max_relation = self._run_pipeline(input_file, 
opt_level)
+                if triplets:
+                    all_triplets.update(triplets)
+                    max_relation = max(max_relation, file_max_relation)
+                    logger.debug(
+                        f"Generated {len(triplets)} triplets for {input_file} 
with {opt_level}"
+                    )
 
         return TripletResult(all_triplets, max_relation)
 
@@ -124,7 +144,7 @@ def _run_pipeline(self, input_file: Path, opt_level: str) -> Tuple[Set[str], int
 
             # Run llvm-ir2vec with opt's output as input
             ir2vec_proc = subprocess.Popen(
-                [self.ir2vec_binary, "triplets", "-", "-o", "-"],
+                [self.ir2vec_binary, "triplets", "--mode=llvm", "-", "-o", 
"-"],
                 stdin=opt_proc.stdout,
                 stdout=subprocess.PIPE,
                 stderr=subprocess.PIPE,
@@ -143,6 +163,32 @@ def _run_pipeline(self, input_file: Path, opt_level: str) -> Tuple[Set[str], int
         except (subprocess.SubprocessError, OSError):
             return set(), 1
 
+    def _run_mir_pipeline(self, input_file: Path) -> Tuple[Set[str], int]:
+        """Run llvm-ir2vec pipeline for MIR files."""
+        try:
+            # Run llvm-ir2vec directly on MIR file
+            result = subprocess.run(
+                [
+                    self.ir2vec_binary,
+                    "triplets",
+                    "--mode=mir",
+                    str(input_file),
+                    "-o",
+                    "-",
+                ],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                text=True,
+                check=False,
+            )
+
+            if result.returncode != 0:
+                return set(), 1
+
+            return self._parse_triplet_output(result.stdout)
+        except (subprocess.SubprocessError, OSError):
+            return set(), 1
+
     def _parse_triplet_output(self, output: str) -> Tuple[Set[str], int]:
         """Parse triplet output and extract max relation"""
         if not output.strip():
@@ -160,12 +206,21 @@ def _parse_triplet_output(self, output: str) -> Tuple[Set[str], int]:
         return set(lines), max_relation
 
     def generate_triplets(self, file_list: Path) -> None:
-        """Main method to generate triplets from a list of LLVM IR files"""
+        """Main method to generate triplets from a list of LLVM IR or MIR 
files"""
+        # Store file_list_path for later use in entity generation
+        self.file_list_path = file_list
+
         input_files = self._read_file_list(file_list)
-        logger.info(
-            f"Processing {len(input_files)} files with 
{self.num_optimizations} "
-            f"optimization levels using {self.max_workers} workers"
-        )
+
+        if self.mode == "mir":
+            logger.info(
+                f"Processing {len(input_files)} MIR files using 
{self.max_workers} workers"
+            )
+        else:
+            logger.info(
+                f"Processing {len(input_files)} files with 
{self.num_optimizations} "
+                f"optimization levels using {self.max_workers} workers"
+            )
 
         all_triplets = set()
         global_max_relation = 1
@@ -222,28 +277,60 @@ def _generate_output_files(self, all_triplets: Set[str], max_relation: int) -> N
 
     def _generate_entity2id(self, output_file: Path) -> None:
         """Generate entity2id.txt using llvm-ir2vec"""
-        subprocess.run(
-            [str(self.ir2vec_binary), "entities", "-o", str(output_file)],
-            check=True,
-            capture_output=True,
-        )
+        if self.mode == "mir":
+            # For MIR mode, we need to provide a sample MIR file to determine 
target
+            # Use the first file from the processed list
+            input_files = self._read_file_list(self.file_list_path)
+            if not input_files:
+                raise ValueError("No input files available for entity 
generation")
+
+            subprocess.run(
+                [
+                    str(self.ir2vec_binary),
+                    "entities",
+                    "--mode=mir",
+                    str(input_files[0]),
+                    "-o",
+                    str(output_file),
+                ],
+                check=True,
+                capture_output=True,
+            )
+        else:
+            subprocess.run(
+                [
+                    str(self.ir2vec_binary),
+                    "entities",
+                    "--mode=llvm",
+                    "-o",
+                    str(output_file),
+                ],
+                check=True,
+                capture_output=True,
+            )
 
     def _generate_relation2id(self, output_file: Path, max_relation: int) -> 
None:
         """Generate relation2id.txt from max relation"""
-        max_relation = max(max_relation, 1)  # At least Type and Next relations
+        max_relation = max(max_relation, 1)  # At least Next relation
         num_relations = max_relation + 1
 
         with open(output_file, "w") as f:
             f.write(f"{num_relations}\n")
-            f.write("Type\t0\n")
-            f.write("Next\t1\n")
-            f.writelines(f"Arg{i-2}\t{i}\n" for i in range(2, num_relations))
+            if self.mode == "llvm":
+                # LLVM IR has Type relation at 0
+                f.write("Type\t0\n")
+                f.write("Next\t1\n")
+                f.writelines(f"Arg{i-2}\t{i}\n" for i in range(2, 
num_relations))
+            else:
+                # MIR doesn't have Type relation, starts with Next at 0
+                f.write("Next\t0\n")
+                f.writelines(f"Arg{i-1}\t{i}\n" for i in range(1, 
num_relations))
 
 
 def main():
     """Main entry point"""
     parser = argparse.ArgumentParser(
-        description="Generate IR2Vec triplets from LLVM IR files",
+        description="Generate IR2Vec or MIR2Vec triplets from LLVM IR or 
Machine IR files",
         formatter_class=argparse.RawDescriptionHelpFormatter,
     )
 
@@ -253,16 +340,25 @@ def main():
     parser.add_argument(
         "num_optimizations",
         type=int,
-        help="Number of optimization levels to apply (1-6)",
+        nargs="?",
+        default=1,
+        help="Number of optimization levels to apply (1-6) for LLVM IR mode",
     )
     parser.add_argument(
-        "ll_file_list",
+        "input_file_list",
         type=Path,
-        help="File containing list of LLVM IR files to process",
+        help="File containing list of LLVM IR or MIR files to process",
     )
     parser.add_argument(
         "output_dir", type=Path, help="Output directory for generated files"
     )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["llvm", "mir"],
+        default="llvm",
+        help="Operation mode: 'llvm' for LLVM IR (default) or 'mir' for 
Machine IR",
+    )
     parser.add_argument(
         "-j",
         "--max-workers",
@@ -296,8 +392,9 @@ def main():
         args.num_optimizations,
         args.output_dir,
         args.max_workers,
+        args.mode,
     )
-    generator.generate_triplets(args.ll_file_list)
+    generator.generate_triplets(args.input_file_list)
 
 
 if __name__ == "__main__":

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to