csigg updated this revision to Diff 434142.
csigg added a comment.

Fix.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D126158/new/

https://reviews.llvm.org/D126158

Files:
  mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
  mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h
  mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
  mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
  mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
  mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
  mlir/test/Dialect/LLVMIR/nvvm.mlir
  mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
  mlir/test/Target/LLVMIR/nvvmir.mlir
  utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Index: utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
===================================================================
--- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3379,7 +3379,9 @@
         ":IR",
         ":LLVMDialect",
         ":LLVMPassIncGen",
+        ":NVVMDialect",
         ":Pass",
+        ":Transforms",
     ],
 )
 
Index: mlir/test/Target/LLVMIR/nvvmir.mlir
===================================================================
--- mlir/test/Target/LLVMIR/nvvmir.mlir
+++ mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -33,6 +33,13 @@
   llvm.return %1 : i32
 }
 
+// CHECK-LABEL: @nvvm_rcp
+llvm.func @nvvm_rcp(%arg: f32) -> f32 {
+  // CHECK: call float @llvm.nvvm.rcp.approx.ftz.f
+  %res = nvvm.rcp.approx.ftz.f %arg : f32
+  llvm.return %res : f32
+}
+
 // CHECK-LABEL: @llvm_nvvm_barrier0
 llvm.func @llvm_nvvm_barrier0() {
   // CHECK: call void @llvm.nvvm.barrier0()
Index: mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -llvm-optimize-for-nvvm-target | FileCheck %s
+
+// CHECK-LABEL: llvm.func @fdiv_fp16
+llvm.func @fdiv_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
+  // CHECK-DAG: %[[c0:.*]]      = llvm.mlir.constant(0 : ui32) : i32
+  // CHECK-DAG: %[[mask:.*]]    = llvm.mlir.constant(2139095040 : ui32) : i32
+  // CHECK-DAG: %[[lhs:.*]]     = llvm.fpext %arg0 : f16 to f32
+  // CHECK-DAG: %[[rhs:.*]]     = llvm.fpext %arg1 : f16 to f32
+  // CHECK-DAG: %[[rcp:.*]]     = nvvm.rcp.approx.ftz.f %[[rhs]] : f32
+  // CHECK-DAG: %[[approx:.*]]  = llvm.fmul %[[lhs]], %[[rcp]] : f32
+  // CHECK-DAG: %[[neg:.*]]     = llvm.fneg %[[rhs]] : f32
+  // CHECK-DAG: %[[err:.*]]     = "llvm.intr.fma"(%[[approx]], %[[neg]], %[[lhs]]) : (f32, f32, f32) -> f32
+  // CHECK-DAG: %[[refined:.*]] = "llvm.intr.fma"(%[[err]], %[[rcp]], %[[approx]]) : (f32, f32, f32) -> f32
+  // CHECK-DAG: %[[cast:.*]]    = llvm.bitcast %[[approx]] : f32 to i32
+  // CHECK-DAG: %[[exp:.*]]     = llvm.and %[[cast]], %[[mask]] : i32
+  // CHECK-DAG: %[[is_zero:.*]] = llvm.icmp "eq" %[[exp]], %[[c0]] : i32
+  // CHECK-DAG: %[[is_mask:.*]] = llvm.icmp "eq" %[[exp]], %[[mask]] : i32
+  // CHECK-DAG: %[[pred:.*]]    = llvm.or %[[is_zero]], %[[is_mask]] : i1
+  // CHECK-DAG: %[[select:.*]]  = llvm.select %[[pred]], %[[approx]], %[[refined]] : i1, f32
+  // CHECK-DAG: %[[result:.*]]  = llvm.fptrunc %[[select]] : f32 to f16
+  %result = llvm.fdiv %arg0, %arg1 : f16
+  // CHECK: llvm.return %[[result]] : f16
+  llvm.return %result : f16
+}
+
+// fdiv on a type other than f16 must be left untouched by the pass.
+// CHECK-LABEL: llvm.func @fdiv_fp32
+llvm.func @fdiv_fp32(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK-NOT: nvvm.rcp.approx.ftz.f
+  // CHECK: %[[result:.*]] = llvm.fdiv %arg0, %arg1 : f32
+  %result = llvm.fdiv %arg0, %arg1 : f32
+  // CHECK: llvm.return %[[result]] : f32
+  llvm.return %result : f32
+}
Index: mlir/test/Dialect/LLVMIR/nvvm.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -29,6 +29,13 @@
   llvm.return %0 : i32
 }
 
+// CHECK-LABEL: @nvvm_rcp
+func.func @nvvm_rcp(%arg0: f32) -> f32 {
+  // CHECK: nvvm.rcp.approx.ftz.f %arg0 : f32
+  %rcp = nvvm.rcp.approx.ftz.f %arg0 : f32
+  llvm.return %rcp : f32
+}
+
 // CHECK-LABEL: @llvm_nvvm_barrier0
 func.func @llvm_nvvm_barrier0() {
   // CHECK: nvvm.barrier0
Index: mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
@@ -0,0 +1,106 @@
+//===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
+#include "PassDetail.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+// Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one
+// (conditional) Newton iteration.
+//
+// This is as accurate as promoting the division to fp32 in the NVPTX backend,
+// but faster because it performs fewer Newton iterations, avoids the slow path
+// for e.g. denormals, and allows reuse of the reciprocal for multiple divisions
+// by the same divisor.
+struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
+  using OpRewritePattern<LLVM::FDivOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LLVM::FDivOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+// Pass that applies ExpandDivF16 to the operation it runs on.
+struct NVVMOptimizeForTarget
+    : public NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
+  void runOnOperation() override;
+
+  // ExpandDivF16 creates NVVM ops, so the NVVM dialect must be loaded.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<NVVM::NVVMDialect>();
+  }
+};
+} // namespace
+
+LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
+                                            PatternRewriter &rewriter) const {
+  if (!op.getType().isF16())
+    return rewriter.notifyMatchFailure(op, "not f16");
+  Location loc = op.getLoc();
+
+  Type f32Type = rewriter.getF32Type();
+  Type i32Type = rewriter.getI32Type();
+
+  // Extend lhs and rhs to fp32.
+  Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs());
+  Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs());
+
+  // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
+  Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
+  Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
+
+  // Refine the approximation with one Newton iteration:
+  // float refined = approx + (lhs - approx * rhs) * rcp;
+  Value negRhs = rewriter.create<LLVM::FNegOp>(loc, rhs);
+  Value err = rewriter.create<LLVM::FMAOp>(loc, approx, negRhs, lhs);
+  Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
+
+  // Use the refined value only if approx is normal, i.e. its exponent bits are
+  // neither all 0 (zero/denormal) nor all 1 (inf/NaN). Otherwise keep approx:
+  // refining gains nothing there and can even produce NaN (e.g. for x / 0 with
+  // x != 0, approx is inf and the first FMA computes inf * 0).
+  Value mask = rewriter.create<LLVM::ConstantOp>(
+      loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
+  Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
+  Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
+  Value zero = rewriter.create<LLVM::ConstantOp>(
+      loc, i32Type, rewriter.getUI32IntegerAttr(0));
+  // Build the comparisons in separate statements so that the order of the
+  // emitted ops does not depend on the unspecified order in which C++
+  // evaluates function arguments.
+  Value isZero =
+      rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero);
+  Value isMask =
+      rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask);
+  Value pred = rewriter.create<LLVM::OrOp>(loc, isZero, isMask);
+  Value result =
+      rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
+
+  // Replace with truncation back to fp16.
+  rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
+
+  return success();
+}
+
+void NVVMOptimizeForTarget::runOnOperation() {
+  MLIRContext *ctx = getOperation()->getContext();
+  RewritePatternSet patterns(ctx);
+  patterns.add<ExpandDivF16>(ctx);
+  // A greedy driver that does not converge is reported as a pass failure.
+  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+    return signalPassFailure();
+}
+
+std::unique_ptr<Pass> NVVM::createOptimizeForTargetPass() {
+  return std::make_unique<NVVMOptimizeForTarget>();
+}
Index: mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRLLVMIRTransforms
   LegalizeForExport.cpp
+  OptimizeForNVVM.cpp
 
   DEPENDS
   MLIRLLVMPassIncGen
Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
+++ mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
@@ -16,4 +16,9 @@
   let constructor = "mlir::LLVM::createLegalizeForExportPass()";
 }
 
+def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> {
+  let constructor = "mlir::NVVM::createOptimizeForTargetPass()";
+  let summary = "Optimize NVVM IR";
+}
+
 #endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES
Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
+++ mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H
 
 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
+#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h
@@ -0,0 +1,25 @@
+//===- OptimizeForNVVM.h - Optimize LLVM IR for NVVM ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZEFORNVVM_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZEFORNVVM_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+namespace NVVM {
+
+/// Creates a pass that optimizes LLVM IR for the NVVM target.
+std::unique_ptr<Pass> createOptimizeForTargetPass();
+
+} // namespace NVVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZEFORNVVM_H
Index: mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -51,21 +51,21 @@
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
 
-class NVVM_IntrOp<string mnem, list<int> overloadedResults,
-                  list<int> overloadedOperands, list<Trait> traits,
+class NVVM_IntrOp<string mnem, list<Trait> traits,
                   int numResults>
   : LLVM_IntrOpBase<NVVM_Dialect, mnem, "nvvm_" # !subst(".", "_", mnem),
-                    overloadedResults, overloadedOperands, traits, numResults>;
+                    /*list<int> overloadedResults=*/[],
+                    /*list<int> overloadedOperands=*/[],
+                    traits, numResults>; // No type overloads.
 
 
 //===----------------------------------------------------------------------===//
 // NVVM special register op definitions
 //===----------------------------------------------------------------------===//
 
-class NVVM_SpecialRegisterOp<string mnemonic,
-    list<Trait> traits = []> :
-  NVVM_IntrOp<mnemonic, [], [], !listconcat(traits, [NoSideEffect]), 1>,
-  Arguments<(ins)> {
+class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_IntrOp<mnemonic, !listconcat(traits, [NoSideEffect]), 1> {
+  let arguments = (ins); // Reading a special register takes no operands.
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
@@ -92,6 +92,18 @@
 def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
 def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
 
+//===----------------------------------------------------------------------===//
+// NVVM approximate op definitions
+//===----------------------------------------------------------------------===//
+
+def NVVM_RcpApproxFtzF32Op :
+  NVVM_IntrOp<"rcp.approx.ftz.f", [NoSideEffect], 1> {
+  let summary = "Approximate f32 reciprocal with subnormals flushed to zero";
+  let arguments = (ins F32:$arg);
+  let results = (outs F32:$res);
+  let assemblyFormat = "$arg attr-dict `:` type($res)";
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM synchronization op definitions
 //===----------------------------------------------------------------------===//
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to