================ @@ -3529,6 +3549,86 @@ static void genMapInfos(llvm::IRBuilderBase &builder, } } +static llvm::Expected<llvm::Function *> +emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation); + +static llvm::Expected<llvm::Function *> +getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto declMapperOp = cast<omp::DeclareMapperOp>(op); + std::string mapperFuncName = + moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName( + {"omp_mapper", declMapperOp.getSymName()}); + if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName)) + return lookupFunc; + + llvm::Expected<llvm::Function *> mapperFunc = + emitUserDefinedMapper(declMapperOp, builder, moduleTranslation); + if (!mapperFunc) + return mapperFunc.takeError(); + moduleTranslation.mapFunction(mapperFuncName, *mapperFunc); + return mapperFunc; +} + +static llvm::Expected<llvm::Function *> +emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto declMapperOp = cast<omp::DeclareMapperOp>(op); + auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo(); + DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>()); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType()); + std::string mapperName = ompBuilder->createPlatformSpecificName( + {"omp_mapper", declMapperOp.getSymName()}); + SmallVector<Value> mapVars = declMapperInfoOp.getMapVars(); + + using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; + + // Fill up the arrays with all the mapped variables. 
+ MapInfosTy combinedInfo; + auto genMapInfoCB = + [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI, + llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy { + builder.restoreIP(codeGenIP); + moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI); + moduleTranslation.mapBlock(&declMapperOp.getRegion().front(), + builder.GetInsertBlock()); + if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(), + /*ignoreArguments=*/true, + builder))) + return llvm::make_error<PreviouslyReportedError>(); + MapInfoData mapData; + collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl, + builder); + genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData); + + // Drop the mapping that is no longer necessary so that the same region can + // be processed multiple times. + moduleTranslation.forgetMapping(declMapperOp.getRegion()); + return combinedInfo; + }; + + auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> { + llvm::Function *mapperFunc = nullptr; + if (combinedInfo.Mappers[i]) { + // Call the corresponding mapper function. + llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc( + combinedInfo.Mappers[i], builder, moduleTranslation); + if (!newFn) + return newFn.takeError(); + mapperFunc = *newFn; + } + return mapperFunc; ---------------- skatrak wrote:
```suggestion
    if (!combinedInfo.Mappers[i])
      return nullptr;
    // Call the corresponding mapper function.
    return getOrCreateUserDefinedMapperFunc(
        combinedInfo.Mappers[i], builder, moduleTranslation);
```

https://github.com/llvm/llvm-project/pull/124746
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits