https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/84376
Replacing an element of an operation range while traversing the range can invalidate the range. Store the operations in a separate list, and traverse the list instead. Additionally, avoid inspecting an operation after it has been replaced. This was detected by AddressSanitizer. >From 58cda6db7030e178fbd861312dcee372e1558611 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek <krzysztof.parzys...@amd.com> Date: Thu, 7 Mar 2024 14:21:45 -0600 Subject: [PATCH] [flang][CodeGen] Fix use-after-free in BoxedProcedurePass Replacing an element of an operation range while traversing the range can invalidate the range. Store the operations in a separate list, and traverse the list instead. Additionally, avoid inspecting an operation after it has been replaced. This was detected by AddressSanitizer. --- .../lib/Optimizer/CodeGen/BoxedProcedure.cpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp index 4cf39716a73755..2e34b0a1b492b1 100644 --- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp +++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp @@ -208,7 +208,12 @@ class BoxedProcedurePass mlir::IRRewriter rewriter(context); BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context)); mlir::Dialect *firDialect = context->getLoadedDialect("fir"); - getModule().walk([&](mlir::Operation *op) { + llvm::SmallVector<mlir::Operation *> operations; + + getModule().walk([&](mlir::Operation *op) { operations.push_back(op); }); + + for (mlir::Operation *op : operations) { + bool opIsValid = true; typeConverter.setLocation(op->getLoc()); if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) { mlir::Type ty = addr.getVal().getType(); @@ -220,6 +225,7 @@ class BoxedProcedurePass rewriter.setInsertionPoint(addr); rewriter.replaceOpWithNewOp<ConvertOp>( addr, typeConverter.convertType(addr.getType()), addr.getVal()); + opIsValid = false; } else if 
(typeConverter.needsConversion(resTy)) { rewriter.startOpModification(op); op->getResult(0).setType(typeConverter.convertType(resTy)); @@ -271,10 +277,12 @@ class BoxedProcedurePass llvm::ArrayRef<mlir::Value>{tramp}); rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, adjustCall.getResult(0)); + opIsValid = false; } else { // Just forward the function as a pointer. rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, embox.getFunc()); + opIsValid = false; } } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) { auto ty = global.getType(); @@ -297,6 +305,7 @@ class BoxedProcedurePass rewriter.replaceOpWithNewOp<AllocaOp>( mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(), mem.getShape()); + opIsValid = false; } } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) { auto ty = mem.getType(); @@ -310,6 +319,7 @@ class BoxedProcedurePass rewriter.replaceOpWithNewOp<AllocMemOp>( mem, toTy, uniqName, bindcName, mem.getTypeparams(), mem.getShape()); + opIsValid = false; } } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) { auto ty = coor.getType(); @@ -321,6 +331,7 @@ class BoxedProcedurePass auto toBaseTy = typeConverter.convertType(baseTy); rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(), coor.getCoor(), toBaseTy); + opIsValid = false; } } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) { auto ty = index.getType(); @@ -332,6 +343,7 @@ class BoxedProcedurePass auto toOnTy = typeConverter.convertType(onTy); rewriter.replaceOpWithNewOp<FieldIndexOp>( index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); + opIsValid = false; } } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) { auto ty = index.getType(); @@ -343,6 +355,7 @@ class BoxedProcedurePass auto toOnTy = typeConverter.convertType(onTy); rewriter.replaceOpWithNewOp<LenParamIndexOp>( index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); + opIsValid = false; } } else if (op->getDialect() == firDialect) { rewriter.startOpModification(op); 
@@ -354,7 +367,7 @@ class BoxedProcedurePass rewriter.finalizeOpModification(op); } // Ensure block arguments are updated if needed. - if (op->getNumRegions() != 0) { + if (opIsValid && op->getNumRegions() != 0) { rewriter.startOpModification(op); for (mlir::Region &region : op->getRegions()) for (mlir::Block &block : region.getBlocks()) @@ -366,7 +379,7 @@ class BoxedProcedurePass } rewriter.finalizeOpModification(op); } - }); + } } } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits