https://github.com/mmha created https://github.com/llvm/llvm-project/pull/138317
This patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect. >From 2b6ecd77c4fac0a2982172294d12ae858f0a2b34 Mon Sep 17 00:00:00 2001 From: Morris Hafner <mhaf...@nvidia.com> Date: Fri, 2 May 2025 20:05:40 +0200 Subject: [PATCH] [CIR] Add cir-simplify pass This patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect. --- clang/include/clang/CIR/CIRToCIRPasses.h | 3 +- .../clang/CIR/Dialect/IR/CIRDialect.td | 2 + clang/include/clang/CIR/Dialect/IR/CIROps.td | 2 + clang/include/clang/CIR/Dialect/Passes.h | 1 + clang/include/clang/CIR/Dialect/Passes.td | 14 ++ .../clang/CIR/FrontendAction/CIRGenAction.h | 2 +- clang/include/clang/CIR/MissingFeatures.h | 1 - clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 30 +++ .../Dialect/Transforms/CIRCanonicalize.cpp | 3 +- .../CIR/Dialect/Transforms/CIRSimplify.cpp | 184 ++++++++++++++++++ .../lib/CIR/Dialect/Transforms/CMakeLists.txt | 1 + clang/lib/CIR/FrontendAction/CIRGenAction.cpp | 21 +- clang/lib/CIR/Lowering/CIRPasses.cpp | 6 +- clang/test/CIR/Transforms/select.cir | 60 ++++++ clang/test/CIR/Transforms/ternary-fold.cir | 60 ++++++ clang/tools/cir-opt/cir-opt.cpp | 3 + 16 files changed, 378 insertions(+), 15 deletions(-) create mode 100644 clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp create mode 100644 clang/test/CIR/Transforms/select.cir create mode 100644 clang/test/CIR/Transforms/ternary-fold.cir diff --git a/clang/include/clang/CIR/CIRToCIRPasses.h b/clang/include/clang/CIR/CIRToCIRPasses.h index 361ebb9e9b840..4a23790ee8b76 100644 --- a/clang/include/clang/CIR/CIRToCIRPasses.h +++ b/clang/include/clang/CIR/CIRToCIRPasses.h @@ -32,7 +32,8 @@ namespace cir { mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule, mlir::MLIRContext &mlirCtx, clang::ASTContext &astCtx, - bool enableVerifier); + bool enableVerifier, + bool 
enableCIRSimplify); } // namespace cir diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td index 73759cfa9c3c9..818a605ab74d3 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td @@ -27,6 +27,8 @@ def CIR_Dialect : Dialect { let useDefaultAttributePrinterParser = 0; let useDefaultTypePrinterParser = 0; + let hasConstantMaterializer = 1; + let extraClassDeclaration = [{ static llvm::StringRef getTripleAttrName() { return "cir.triple"; } diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 9215543ab67e6..8205718e0fc30 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -1464,6 +1464,8 @@ def SelectOp : CIR_Op<"select", [Pure, qualified(type($false_value)) `)` `->` qualified(type($result)) attr-dict }]; + + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/clang/include/clang/CIR/Dialect/Passes.h b/clang/include/clang/CIR/Dialect/Passes.h index 133eb462dcf1f..dbecf81acf7bb 100644 --- a/clang/include/clang/CIR/Dialect/Passes.h +++ b/clang/include/clang/CIR/Dialect/Passes.h @@ -22,6 +22,7 @@ namespace mlir { std::unique_ptr<Pass> createCIRCanonicalizePass(); std::unique_ptr<Pass> createCIRFlattenCFGPass(); +std::unique_ptr<Pass> createCIRSimplifyPass(); std::unique_ptr<Pass> createHoistAllocasPass(); void populateCIRPreLoweringPasses(mlir::OpPassManager &pm); diff --git a/clang/include/clang/CIR/Dialect/Passes.td b/clang/include/clang/CIR/Dialect/Passes.td index 74c255861c879..46fa97da04ca1 100644 --- a/clang/include/clang/CIR/Dialect/Passes.td +++ b/clang/include/clang/CIR/Dialect/Passes.td @@ -29,6 +29,20 @@ def CIRCanonicalize : Pass<"cir-canonicalize"> { let dependentDialects = ["cir::CIRDialect"]; } +def CIRSimplify : Pass<"cir-simplify"> { + let summary = 
"Performs CIR simplification and code optimization"; + let description = [{ + The pass performs code simplification and optimization on CIR. + + Unlike the `cir-canonicalize` pass, this pass contains more aggressive code + transformations that could significantly affect CIR-to-source fidelity. + Example transformations performed in this pass include ternary folding, + code hoisting, etc. + }]; + let constructor = "mlir::createCIRSimplifyPass()"; + let dependentDialects = ["cir::CIRDialect"]; +} + def HoistAllocas : Pass<"cir-hoist-allocas"> { let summary = "Hoist allocas to the entry of the function"; let description = [{ diff --git a/clang/include/clang/CIR/FrontendAction/CIRGenAction.h b/clang/include/clang/CIR/FrontendAction/CIRGenAction.h index 99495f4718c5f..b52166b58b882 100644 --- a/clang/include/clang/CIR/FrontendAction/CIRGenAction.h +++ b/clang/include/clang/CIR/FrontendAction/CIRGenAction.h @@ -49,7 +49,7 @@ class CIRGenAction : public clang::ASTFrontendAction { public: ~CIRGenAction() override; - OutputType Action; + OutputType action; }; class EmitCIRAction : public CIRGenAction { diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h index 3db13278261e6..b26144095792d 100644 --- a/clang/include/clang/CIR/MissingFeatures.h +++ b/clang/include/clang/CIR/MissingFeatures.h @@ -199,7 +199,6 @@ struct MissingFeatures { static bool labelOp() { return false; } static bool ptrDiffOp() { return false; } static bool ptrStrideOp() { return false; } - static bool selectOp() { return false; } static bool switchOp() { return false; } static bool ternaryOp() { return false; } static bool tryOp() { return false; } diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index f5d6a424a71f6..5356630ece196 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -79,6 +79,14 @@ void cir::CIRDialect::initialize() {
addInterfaces<CIROpAsmDialectInterface>(); } +Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder, + mlir::Attribute value, + mlir::Type type, + mlir::Location loc) { + return builder.create<cir::ConstantOp>(loc, type, + mlir::cast<mlir::TypedAttr>(value)); +} + //===----------------------------------------------------------------------===// // Helpers //===----------------------------------------------------------------------===// @@ -1261,6 +1269,28 @@ void cir::TernaryOp::build( result.addTypes(TypeRange{yield.getOperandTypes().front()}); } +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) { + mlir::Attribute condition = adaptor.getCondition(); + if (condition) { + bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue(); + return conditionValue ? getTrueValue() : getFalseValue(); + } + + // cir.select if %0 then x else x -> x (attrs are null if x is non-const) + mlir::Attribute trueValue = adaptor.getTrueValue(); + mlir::Attribute falseValue = adaptor.getFalseValue(); + if (trueValue && trueValue == falseValue) + return trueValue; + if (getTrueValue() == getFalseValue()) + return getTrueValue(); + + return {}; +} + //===----------------------------------------------------------------------===// // ShiftOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp index cdac69e66dba3..3b4c7bc613133 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp @@ -121,14 +121,13 @@ void CIRCanonicalizePass::runOnOperation() { getOperation()->walk([&](Operation *op) { assert(!cir::MissingFeatures::switchOp()); assert(!cir::MissingFeatures::tryOp()); - 
assert(!cir::MissingFeatures::selectOp());
assert(!cir::MissingFeatures::complexCreateOp()); assert(!cir::MissingFeatures::complexRealOp()); assert(!cir::MissingFeatures::complexImagOp()); assert(!cir::MissingFeatures::callOp()); // CastOp and UnaryOp are here to perform a manual `fold` in // applyOpPatternsGreedily. - if (isa<BrOp, BrCondOp, ScopeOp, CastOp, UnaryOp>(op)) + if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op)) ops.push_back(op); }); diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp new file mode 100644 index 0000000000000..442801d062638 --- /dev/null +++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp @@ -0,0 +1,184 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "clang/CIR/Dialect/IR/CIRDialect.h" +#include "clang/CIR/Dialect/Passes.h" +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; +using namespace cir; + +//===----------------------------------------------------------------------===// +// Rewrite patterns +//===----------------------------------------------------------------------===// + +namespace { + +/// Simplify suitable ternary operations into select operations. +/// +/// For now we only simplify those ternary operations whose true and false +/// branches directly yield a value or a constant. 
That is, both of the true and +/// the false branch must either contain a cir.yield operation as the only +/// operation in the branch, or contain a cir.const operation followed by a +/// cir.yield operation that yields the constant value. +/// +/// For example, we will simplify the following ternary operation: +/// +/// %0 = cir.ternary (%condition, true { +/// %1 = cir.const ... +/// cir.yield %1 +/// } false { +/// cir.yield %2 +/// }) +/// +/// into the following sequence of operations: +/// +/// %1 = cir.const ... +/// %0 = cir.select if %condition then %1 else %2 +struct SimplifyTernary final : public OpRewritePattern<TernaryOp> { + using OpRewritePattern<TernaryOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(TernaryOp op, + PatternRewriter &rewriter) const override { + if (op->getNumResults() != 1) + return mlir::failure(); + + if (!isSimpleTernaryBranch(op.getTrueRegion()) || + !isSimpleTernaryBranch(op.getFalseRegion())) + return mlir::failure(); + + cir::YieldOp trueBranchYieldOp = + mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator()); + cir::YieldOp falseBranchYieldOp = + mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator()); + mlir::Value trueValue = trueBranchYieldOp.getArgs()[0]; + mlir::Value falseValue = falseBranchYieldOp.getArgs()[0]; + + rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op); + rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op); + rewriter.eraseOp(trueBranchYieldOp); + rewriter.eraseOp(falseBranchYieldOp); + rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue, + falseValue); + + return mlir::success(); + } + +private: + bool isSimpleTernaryBranch(mlir::Region ®ion) const { + if (!region.hasOneBlock()) + return false; + + mlir::Block &onlyBlock = region.front(); + mlir::Block::OpListType &ops = onlyBlock.getOperations(); + + // The region/block could only contain at most 2 operations. 
+ if (ops.size() > 2) + return false; + + if (ops.size() == 1) { + // The region/block only contains a cir.yield operation. + return true; + } + + // Check whether the region/block contains a cir.const followed by a + // cir.yield that yields the value. + auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator()); + auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>( + yieldOp.getArgs()[0].getDefiningOp()); + return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock; + } +}; + +struct SimplifySelect : public OpRewritePattern<SelectOp> { + using OpRewritePattern<SelectOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(SelectOp op, + PatternRewriter &rewriter) const final { + mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp(); + mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp(); + auto trueValueConstOp = + mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp); + auto falseValueConstOp = + mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp); + if (!trueValueConstOp || !falseValueConstOp) + return mlir::failure(); + + auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue()); + auto falseValue = + mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue()); + if (!trueValue || !falseValue) + return mlir::failure(); + + // cir.select if %0 then #true else #false -> %0 + if (trueValue.getValue() && !falseValue.getValue()) { + rewriter.replaceAllUsesWith(op, op.getCondition()); + rewriter.eraseOp(op); + return mlir::success(); + } + + // cir.select if %0 then #false else #true -> cir.unary not %0 + if (!trueValue.getValue() && falseValue.getValue()) { + rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not, + op.getCondition()); + return mlir::success(); + } + + return mlir::failure(); + } +}; + +//===----------------------------------------------------------------------===// +// CIRSimplifyPass
+//===----------------------------------------------------------------------===// + +struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> { + using CIRSimplifyBase::CIRSimplifyBase; + + void runOnOperation() override; +}; + +void populateMergeCleanupPatterns(RewritePatternSet &patterns) { + // clang-format off + patterns.add< + SimplifyTernary, + SimplifySelect + >(patterns.getContext()); + // clang-format on +} + +void CIRSimplifyPass::runOnOperation() { + // Collect rewrite patterns. + RewritePatternSet patterns(&getContext()); + populateMergeCleanupPatterns(patterns); + + // Collect operations to apply patterns. + llvm::SmallVector<Operation *, 16> ops; + getOperation()->walk([&](Operation *op) { + if (isa<TernaryOp, SelectOp>(op)) + ops.push_back(op); + }); + + // Apply patterns. + if (applyOpPatternsGreedily(ops, std::move(patterns)).failed()) + signalPassFailure(); +} + +} // namespace + +std::unique_ptr<Pass> mlir::createCIRSimplifyPass() { + return std::make_unique<CIRSimplifyPass>(); +} diff --git a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt index 4678435b54c79..4dece5b57e450 100644 --- a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt +++ b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_clang_library(MLIRCIRTransforms CIRCanonicalize.cpp + CIRSimplify.cpp FlattenCFG.cpp HoistAllocas.cpp diff --git a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp index a32e6a7584774..570403dda9d9f 100644 --- a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp +++ b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp @@ -62,15 +62,17 @@ class CIRGenConsumer : public clang::ASTConsumer { IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS; std::unique_ptr<CIRGenerator> Gen; const FrontendOptions &FEOptions; + CodeGenOptions &codeGenOptions; public: CIRGenConsumer(CIRGenAction::OutputType Action, CompilerInstance &CI, + CodeGenOptions &codeGenOptions, 
std::unique_ptr<raw_pwrite_stream> OS) : Action(Action), CI(CI), OutputStream(std::move(OS)), FS(&CI.getVirtualFileSystem()), Gen(std::make_unique<CIRGenerator>(CI.getDiagnostics(), std::move(FS), CI.getCodeGenOpts())), - FEOptions(CI.getFrontendOpts()) {} + FEOptions(CI.getFrontendOpts()), codeGenOptions(codeGenOptions) {} void Initialize(ASTContext &Ctx) override { assert(!Context && "initialized multiple times"); @@ -102,7 +104,8 @@ class CIRGenConsumer : public clang::ASTConsumer { if (!FEOptions.ClangIRDisablePasses) { // Setup and run CIR pipeline. if (runCIRToCIRPasses(MlirModule, MlirCtx, C, - !FEOptions.ClangIRDisableCIRVerifier) + !FEOptions.ClangIRDisableCIRVerifier, + codeGenOptions.OptimizationLevel > 0) .failed()) { CI.getDiagnostics().Report(diag::err_cir_to_cir_transform_failed); return; @@ -139,7 +142,7 @@ class CIRGenConsumer : public clang::ASTConsumer { void CIRGenConsumer::anchor() {} CIRGenAction::CIRGenAction(OutputType Act, mlir::MLIRContext *MLIRCtx) - : MLIRCtx(MLIRCtx ? MLIRCtx : new mlir::MLIRContext), Action(Act) {} + : MLIRCtx(MLIRCtx ? 
MLIRCtx : new mlir::MLIRContext), action(Act) {} CIRGenAction::~CIRGenAction() { MLIRMod.release(); } @@ -162,14 +165,14 @@ getOutputStream(CompilerInstance &CI, StringRef InFile, } std::unique_ptr<ASTConsumer> -CIRGenAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) { - std::unique_ptr<llvm::raw_pwrite_stream> Out = CI.takeOutputStream(); +CIRGenAction::CreateASTConsumer(CompilerInstance &ci, StringRef inFile) { + std::unique_ptr<llvm::raw_pwrite_stream> out = ci.takeOutputStream(); - if (!Out) - Out = getOutputStream(CI, InFile, Action); + if (!out) + out = getOutputStream(ci, inFile, action); - auto Result = - std::make_unique<cir::CIRGenConsumer>(Action, CI, std::move(Out)); + auto Result = std::make_unique<cir::CIRGenConsumer>( + action, ci, ci.getCodeGenOpts(), std::move(out)); return Result; } diff --git a/clang/lib/CIR/Lowering/CIRPasses.cpp b/clang/lib/CIR/Lowering/CIRPasses.cpp index a37a0480a56ac..7a581939580a9 100644 --- a/clang/lib/CIR/Lowering/CIRPasses.cpp +++ b/clang/lib/CIR/Lowering/CIRPasses.cpp @@ -20,13 +20,17 @@ namespace cir { mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule, mlir::MLIRContext &mlirContext, clang::ASTContext &astContext, - bool enableVerifier) { + bool enableVerifier, + bool enableCIRSimplify) { llvm::TimeTraceScope scope("CIR To CIR Passes"); mlir::PassManager pm(&mlirContext); pm.addPass(mlir::createCIRCanonicalizePass()); + if (enableCIRSimplify) + pm.addPass(mlir::createCIRSimplifyPass()); + pm.enableVerifier(enableVerifier); (void)mlir::applyPassManagerCLOptions(pm); return pm.run(theModule); diff --git a/clang/test/CIR/Transforms/select.cir b/clang/test/CIR/Transforms/select.cir new file mode 100644 index 0000000000000..29a5d1ed1ddeb --- /dev/null +++ b/clang/test/CIR/Transforms/select.cir @@ -0,0 +1,60 @@ +// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s +// RUN: FileCheck --input-file=%t.cir %s + +!s32i = !cir.int<s, 32> + +module { + cir.func @fold_true(%arg0 : !s32i, %arg1 : 
!s32i) -> !s32i { + %0 = cir.const #cir.bool<true> : !cir.bool + %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i + cir.return %1 : !s32i + } + + // CHECK: cir.func @fold_true(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i { + // CHECK-NEXT: cir.return %[[ARG0]] : !s32i + // CHECK-NEXT: } + + cir.func @fold_false(%arg0 : !s32i, %arg1 : !s32i) -> !s32i { + %0 = cir.const #cir.bool<false> : !cir.bool + %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i + cir.return %1 : !s32i + } + + // CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i { + // CHECK-NEXT: cir.return %[[ARG1]] : !s32i + // CHECK-NEXT: } + + cir.func @fold_to_const(%arg0 : !cir.bool) -> !s32i { + %0 = cir.const #cir.int<42> : !s32i + %1 = cir.select if %arg0 then %0 else %0 : (!cir.bool, !s32i, !s32i) -> !s32i + cir.return %1 : !s32i + } + + // CHECK: cir.func @fold_to_const(%{{.+}}: !cir.bool) -> !s32i { + // CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i + // CHECK-NEXT: cir.return %[[#A]] : !s32i + // CHECK-NEXT: } + + cir.func @simplify_1(%arg0 : !cir.bool) -> !cir.bool { + %0 = cir.const #cir.bool<true> : !cir.bool + %1 = cir.const #cir.bool<false> : !cir.bool + %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool + cir.return %2 : !cir.bool + } + + // CHECK: cir.func @simplify_1(%[[ARG0:.+]]: !cir.bool) -> !cir.bool { + // CHECK-NEXT: cir.return %[[ARG0]] : !cir.bool + // CHECK-NEXT: } + + cir.func @simplify_2(%arg0 : !cir.bool) -> !cir.bool { + %0 = cir.const #cir.bool<false> : !cir.bool + %1 = cir.const #cir.bool<true> : !cir.bool + %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool + cir.return %2 : !cir.bool + } + + // CHECK: cir.func @simplify_2(%[[ARG0:.+]]: !cir.bool) -> !cir.bool { + // CHECK-NEXT: %[[#A:]] = cir.unary(not, %[[ARG0]]) : !cir.bool, !cir.bool + // CHECK-NEXT: cir.return %[[#A]] : !cir.bool + // 
CHECK-NEXT: } +} diff --git a/clang/test/CIR/Transforms/ternary-fold.cir b/clang/test/CIR/Transforms/ternary-fold.cir new file mode 100644 index 0000000000000..72ba4815b2db2 --- /dev/null +++ b/clang/test/CIR/Transforms/ternary-fold.cir @@ -0,0 +1,60 @@ +// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s +// RUN: FileCheck --input-file=%t.cir %s + +!s32i = !cir.int<s, 32> + +module { + cir.func @fold_ternary(%arg0: !s32i, %arg1: !s32i) -> !s32i { + %0 = cir.const #cir.bool<false> : !cir.bool + %1 = cir.ternary (%0, true { + cir.yield %arg0 : !s32i + }, false { + cir.yield %arg1 : !s32i + }) : (!cir.bool) -> !s32i + cir.return %1 : !s32i + } + + // CHECK: cir.func @fold_ternary(%{{.+}}: !s32i, %[[ARG:.+]]: !s32i) -> !s32i { + // CHECK-NEXT: cir.return %[[ARG]] : !s32i + // CHECK-NEXT: } + + cir.func @simplify_ternary(%arg0 : !cir.bool, %arg1 : !s32i) -> !s32i { + %0 = cir.ternary (%arg0, true { + %1 = cir.const #cir.int<42> : !s32i + cir.yield %1 : !s32i + }, false { + cir.yield %arg1 : !s32i + }) : (!cir.bool) -> !s32i + cir.return %0 : !s32i + } + + // CHECK: cir.func @simplify_ternary(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !s32i) -> !s32i { + // CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i + // CHECK-NEXT: %[[#B:]] = cir.select if %[[ARG0]] then %[[#A]] else %[[ARG1]] : (!cir.bool, !s32i, !s32i) -> !s32i + // CHECK-NEXT: cir.return %[[#B]] : !s32i + // CHECK-NEXT: } + + cir.func @non_simplifiable_ternary(%arg0 : !cir.bool) -> !s32i { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] + %1 = cir.ternary (%arg0, true { + %2 = cir.const #cir.int<42> : !s32i + cir.yield %2 : !s32i + }, false { + %3 = cir.load %0 : !cir.ptr<!s32i>, !s32i + cir.yield %3 : !s32i + }) : (!cir.bool) -> !s32i + cir.return %1 : !s32i + } + + // CHECK: cir.func @non_simplifiable_ternary(%[[ARG0:.+]]: !cir.bool) -> !s32i { + // CHECK-NEXT: %[[#A:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] + // CHECK-NEXT: %[[#B:]] = cir.ternary(%[[ARG0]], true { + // CHECK-NEXT: 
%[[#C:]] = cir.const #cir.int<42> : !s32i + // CHECK-NEXT: cir.yield %[[#C]] : !s32i + // CHECK-NEXT: }, false { + // CHECK-NEXT: %[[#D:]] = cir.load %[[#A]] : !cir.ptr<!s32i>, !s32i + // CHECK-NEXT: cir.yield %[[#D]] : !s32i + // CHECK-NEXT: }) : (!cir.bool) -> !s32i + // CHECK-NEXT: cir.return %[[#B]] : !s32i + // CHECK-NEXT: } +} diff --git a/clang/tools/cir-opt/cir-opt.cpp b/clang/tools/cir-opt/cir-opt.cpp index e50fa70582966..0e20b97feced8 100644 --- a/clang/tools/cir-opt/cir-opt.cpp +++ b/clang/tools/cir-opt/cir-opt.cpp @@ -37,6 +37,9 @@ int main(int argc, char **argv) { ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { return mlir::createCIRCanonicalizePass(); }); + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return mlir::createCIRSimplifyPass(); + }); mlir::PassPipelineRegistration<CIRToLLVMPipelineOptions> pipeline( "cir-to-llvm", "", _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits