antiagainst updated this revision to Diff 241431.
antiagainst marked 16 inline comments as done.
antiagainst added a comment.

Address comments


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D73437/new/

https://reviews.llvm.org/D73437

Files:
  mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
  mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h
  mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
  mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
  mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
  mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
  mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
  mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
  mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
  mlir/lib/Conversion/CMakeLists.txt
  mlir/lib/Conversion/LinalgToSPIRV/CMakeLists.txt
  mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
  mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
  mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
  mlir/lib/Dialect/Linalg/Utils/Utils.cpp
  mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
  mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
  mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
  mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
  mlir/tools/mlir-opt/CMakeLists.txt

Index: mlir/tools/mlir-opt/CMakeLists.txt
===================================================================
--- mlir/tools/mlir-opt/CMakeLists.txt
+++ mlir/tools/mlir-opt/CMakeLists.txt
@@ -40,6 +40,7 @@
   MLIRQuantOps
   MLIRROCDLIR
   MLIRSPIRV
+  MLIRLinalgToSPIRVTransforms
   MLIRStandardToSPIRVTransforms
   MLIRSPIRVTestPasses
   MLIRSPIRVTransforms
Index: mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
===================================================================
--- /dev/null
+++ mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
@@ -0,0 +1,164 @@
+// RUN: mlir-opt -split-input-file -convert-linalg-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Single workgroup reduction
+//===----------------------------------------------------------------------===//
+
+#single_workgroup_reduction_trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["reduction"],
+  indexing_maps = [
+    affine_map<(i) -> (i)>,
+    affine_map<(i) -> (0)>
+  ]
+}
+
+module attributes {
+  spv.target_env = {
+    version = 3 : i32,
+    extensions = [],
+    capabilities = [1: i32, 63: i32] // Shader, GroupNonUniformArithmetic
+  }
+} {
+
+// CHECK:      spv.globalVariable
+// CHECK-SAME: built_in("LocalInvocationId")
+
+// CHECK:      func @single_workgroup_reduction
+// CHECK-SAME: (%[[INPUT:.+]]: !spv.ptr{{.+}}, %[[OUTPUT:.+]]: !spv.ptr{{.+}})
+
+// CHECK:        %[[ZERO:.+]] = spv.constant 0 : i32
+// CHECK:        %[[ID:.+]] = spv.Load "Input" %{{.+}} : vector<3xi32>
+// CHECK:        %[[X:.+]] = spv.CompositeExtract %[[ID]][0 : i32]
+
+// CHECK:        %[[INPTR:.+]] = spv.AccessChain %[[INPUT]][%[[ZERO]], %[[X]]]
+// CHECK:        %[[VAL:.+]] = spv.Load "StorageBuffer" %[[INPTR]] : i32
+// CHECK:        %[[ADD:.+]] = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %[[VAL]] : i32
+
+// CHECK:        %[[OUTPTR:.+]] = spv.AccessChain %[[OUTPUT]][%[[ZERO]], %[[ZERO]]]
+// CHECK:        %[[ELECT:.+]] = spv.GroupNonUniformElect "Subgroup" : i1
+
+// CHECK:        spv.selection {
+// CHECK:          spv.BranchConditional %[[ELECT]], ^bb1, ^bb2
+// CHECK:        ^bb1:
+// CHECK:          spv.AtomicIAdd "Device" "AcquireRelease" %[[OUTPTR]], %[[ADD]]
+// CHECK:          spv.Branch ^bb2
+// CHECK:        ^bb2:
+// CHECK:          spv._merge
+// CHECK:        }
+// CHECK:        spv.Return
+
+// The local workgroup size (16) matches the input memref size (16), so each
+// invocation loads and reduces exactly one input element.
+func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes {
+  spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}
+} {
+  linalg.generic #single_workgroup_reduction_trait %input, %output {
+    ^bb(%in: i32, %out: i32):
+      %sum = addi %in, %out : i32
+      linalg.yield %sum : i32
+  } : memref<16xi32>, memref<1xi32>
+  spv.Return
+}
+}
+
+// -----
+
+// Missing shader entry point ABI, so the local workgroup size is unknown
+
+#single_workgroup_reduction_trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["reduction"],
+  indexing_maps = [
+    affine_map<(i) -> (i)>,
+    affine_map<(i) -> (0)>
+  ]
+}
+
+module attributes {
+  spv.target_env = {
+    version = 3 : i32,
+    extensions = [],
+    capabilities = [1: i32, 63: i32] // Shader, GroupNonUniformArithmetic
+  }
+} {
+func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) {
+  // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
+  linalg.generic #single_workgroup_reduction_trait %input, %output {
+    ^bb(%in: i32, %out: i32):
+      %sum = addi %in, %out : i32
+      linalg.yield %sum : i32
+  } : memref<16xi32>, memref<1xi32>
+  return
+}
+}
+
+// -----
+
+// Local workgroup size (32) does not match the input memref size (16)
+
+#single_workgroup_reduction_trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["reduction"],
+  indexing_maps = [
+    affine_map<(i) -> (i)>,
+    affine_map<(i) -> (0)>
+  ]
+}
+
+module attributes {
+  spv.target_env = {
+    version = 3 : i32,
+    extensions = [],
+    capabilities = [1: i32, 63: i32] // Shader, GroupNonUniformArithmetic
+  }
+} {
+func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes {
+  spv.entry_point_abi = {local_size = dense<[32, 1, 1]>: vector<3xi32>}
+} {
+  // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
+  linalg.generic #single_workgroup_reduction_trait %input, %output {
+    ^bb(%in: i32, %out: i32):
+      %sum = addi %in, %out : i32
+      linalg.yield %sum : i32
+  } : memref<16xi32>, memref<1xi32>
+  spv.Return
+}
+}
+
+// -----
+
+// Unsupported: the input memref has more than one dimension (16x8)
+
+#single_workgroup_reduction_trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "reduction"],
+  indexing_maps = [
+    affine_map<(i, j) -> (i, j)>,
+    affine_map<(i, j) -> (i)>
+  ]
+}
+
+module attributes {
+  spv.target_env = {
+    version = 3 : i32,
+    extensions = [],
+    capabilities = [1: i32, 63: i32] // Shader, GroupNonUniformArithmetic
+  }
+} {
+func @single_workgroup_reduction(%input: memref<16x8xi32>, %output: memref<16xi32>) attributes {
+  spv.entry_point_abi = {local_size = dense<[16, 8, 1]>: vector<3xi32>}
+} {
+  // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
+  linalg.generic #single_workgroup_reduction_trait %input, %output {
+    ^bb(%in: i32, %out: i32):
+      %sum = addi %in, %out : i32
+      linalg.yield %sum : i32
+  } : memref<16x8xi32>, memref<16xi32>
+  spv.Return
+}
+}
Index: mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
+++ mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/FunctionSupport.h"
 #include "mlir/IR/Operation.h"
 
 using namespace mlir;
@@ -62,3 +63,19 @@
     return attr;
   return getDefaultTargetEnv(op->getContext());
 }
+
+/// Finds the closest function-like op that is `op` itself or one of its
+/// ancestors and returns the `local_size` of its entry point ABI attribute.
+/// Returns a null attribute if there is no such op or it has no such ABI.
+DenseIntElementsAttr spirv::lookupLocalWorkGroupSize(Operation *op) {
+  Operation *funcLikeOp = op;
+  for (; funcLikeOp; funcLikeOp = funcLikeOp->getParentOp())
+    if (funcLikeOp->hasTrait<OpTrait::FunctionLike>())
+      break;
+  if (!funcLikeOp)
+    return {};
+
+  auto entryPointABI = funcLikeOp->getAttrOfType<spirv::EntryPointABIAttr>(
+      spirv::getEntryPointABIAttrName());
+  return entryPointABI ? entryPointABI.local_size() : DenseIntElementsAttr();
+}
Index: mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2755,6 +2755,40 @@
   builder.create<spirv::MergeOp>(getLoc());
 }
 
+/// Builds the structured form of `if (condition) { thenBody }`: a header block
+/// that conditionally branches to either the "then" block or the merge block,
+/// and a "then" block that runs `thenBody` and branches to the merge block.
+spirv::SelectionOp spirv::SelectionOp::createIfThen(
+    Location loc, Value condition,
+    function_ref<void(OpBuilder *builder)> thenBody, OpBuilder *builder) {
+  auto noControl = builder->getI32IntegerAttr(
+      static_cast<uint32_t>(spirv::SelectionControl::None));
+  auto selectionOp = builder->create<spirv::SelectionOp>(loc, noControl);
+  selectionOp.addMergeBlock();
+  Block *mergeBlock = selectionOp.getMergeBlock();
+
+  // Fill the "then" block first, placing it right before the merge block. The
+  // guard restores the caller's insertion point when the scope ends.
+  Block *thenBlock = nullptr;
+  {
+    OpBuilder::InsertionGuard thenGuard(*builder);
+    thenBlock = builder->createBlock(mergeBlock);
+    thenBody(builder);
+    builder->create<spirv::BranchOp>(loc, mergeBlock);
+  }
+
+  // Then put the header block in front of the "then" block.
+  {
+    OpBuilder::InsertionGuard headerGuard(*builder);
+    builder->createBlock(thenBlock);
+    builder->create<spirv::BranchConditionalOp>(
+        loc, condition, thenBlock, /*trueArguments=*/ArrayRef<Value>(),
+        mergeBlock, /*falseArguments=*/ArrayRef<Value>());
+  }
+
+  return selectionOp;
+}
+
 namespace {
 // Blocks from the given `spv.selection` operation must satisfy the following
 // layout:
Index: mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -292,6 +292,55 @@
   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
 }
 
+//===----------------------------------------------------------------------===//
+// Index calculation
+//===----------------------------------------------------------------------===//
+
+/// Linearizes `indices` with the static strides and offset of `baseType`'s
+/// layout and returns a spv.AccessChain to the addressed element of `basePtr`,
+/// prefixed by a 0 index that steps into the wrapping struct. Returns nullptr
+/// if the layout is not strided or has a dynamic stride or offset.
+spirv::AccessChainOp mlir::spirv::getElementPtr(
+    SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
+    ArrayRef<Value> indices, Location loc, OpBuilder &builder) {
+  // Get the strides and offset of the MemRefType and verify they are static.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
+      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
+      offset == MemRefType::getDynamicStrideOrOffset()) {
+    return nullptr;
+  }
+
+  auto indexType = typeConverter.getIndexType(builder.getContext());
+  auto createIndexConstant = [&](int64_t value) -> Value {
+    return builder.create<spirv::ConstantOp>(
+        loc, indexType, IntegerAttr::get(indexType, value));
+  };
+
+  // Linearized element offset: offset + sum(strides[i] * indices[i]). The
+  // offset term is only materialized when it is non-zero or when there is no
+  // index at all (rank-0 memref), so that a Value is always produced.
+  assert(indices.size() == strides.size() &&
+         "must provide one index per memref dimension");
+  Value ptrLoc = nullptr;
+  if (offset != 0 || indices.empty())
+    ptrLoc = createIndexConstant(offset);
+  for (auto index : llvm::enumerate(indices)) {
+    Value strideVal = createIndexConstant(strides[index.index()]);
+    Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+    ptrLoc =
+        (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
+                : update);
+  }
+
+  SmallVector<Value, 2> linearizedIndices;
+  // Add a '0' at the start to index into the struct.
+  linearizedIndices.push_back(createIndexConstant(0));
+  linearizedIndices.push_back(ptrLoc);
+  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
+}
+
 //===----------------------------------------------------------------------===//
 // Set ABI attributes for lowering entry functions.
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/Linalg/Utils/Utils.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -19,6 +19,7 @@
 #include "mlir/EDSC/Helpers.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/STLExtras.h"
@@ -31,6 +32,40 @@
 using namespace mlir::linalg::intrinsics;
 using namespace mlir::loop;
 
+/// Matches a linalg.generic region made of a single block with two
+/// int-or-float scalar arguments, exactly one non-terminator op combining
+/// them, and a `linalg.yield` of that op's result. Since the recognized ops
+/// are commutative, their operands may appear in either order.
+Optional<RegionMatcher::BinaryOpKind>
+RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
+  auto &region = op.region();
+  if (!has_single_element(region))
+    return llvm::None;
+
+  Block &block = region.front();
+  if (block.getNumArguments() != 2 ||
+      !block.getArgument(0).getType().isIntOrFloat() ||
+      !block.getArgument(1).getType().isIntOrFloat())
+    return llvm::None;
+
+  if (!has_single_element(block.without_terminator()))
+    return llvm::None;
+
+  using mlir::matchers::m_Val;
+  auto a = m_Val(block.getArgument(0));
+  auto b = m_Val(block.getArgument(1));
+  Operation *terminator = &block.back();
+
+  // Integer addition is commutative: accept both `addi %a, %b` and
+  // `addi %b, %a`.
+  auto addPattern = m_Op<linalg::YieldOp>(m_Op<AddIOp>(a, b));
+  auto commutedAddPattern = m_Op<linalg::YieldOp>(m_Op<AddIOp>(b, a));
+  if (addPattern.match(terminator) || commutedAddPattern.match(terminator))
+    return BinaryOpKind::IAdd;
+
+  return llvm::None;
+}
+
 static Value emitOrFoldComposedAffineApply(OpBuilder &b, Location loc,
                                            AffineMap map,
                                            ArrayRef<Value> operandsRef,
Index: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
===================================================================
--- mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -141,48 +141,6 @@
 
 } // namespace
 
-//===----------------------------------------------------------------------===//
-// Utility functions for operation conversion
-//===----------------------------------------------------------------------===//
-
-/// Performs the index computation to get to the element pointed to by
-/// `indices` using the layout map of `baseType`.
-
-// TODO(ravishankarm) : This method assumes that the `origBaseType` is a
-// MemRefType with AffineMap that has static strides. Handle dynamic strides
-static spirv::AccessChainOp getElementPtr(OpBuilder &builder,
-                                          SPIRVTypeConverter &typeConverter,
-                                          Location loc, MemRefType origBaseType,
-                                          Value basePtr,
-                                          ArrayRef<Value> indices) {
-  // Get base and offset of the MemRefType and verify they are static.
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(origBaseType, strides, offset)) ||
-      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
-    return nullptr;
-  }
-
-  auto indexType = typeConverter.getIndexType(builder.getContext());
-
-  Value ptrLoc = nullptr;
-  assert(indices.size() == strides.size());
-  for (auto index : enumerate(indices)) {
-    Value strideVal = builder.create<spirv::ConstantOp>(
-        loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
-    Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
-    ptrLoc =
-        (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
-                : update);
-  }
-  SmallVector<Value, 2> linearizedIndices;
-  // Add a '0' at the start to index into the struct.
-  linearizedIndices.push_back(builder.create<spirv::ConstantOp>(
-      loc, indexType, IntegerAttr::get(indexType, 0)));
-  linearizedIndices.push_back(ptrLoc);
-  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
-}
-
 //===----------------------------------------------------------------------===//
 // ConstantOp with composite type.
 //===----------------------------------------------------------------------===//
@@ -331,9 +289,9 @@
 LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
   LoadOpOperandAdaptor loadOperands(operands);
-  auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(),
-                               loadOp.memref().getType().cast<MemRefType>(),
-                               loadOperands.memref(), loadOperands.indices());
+  auto loadPtr = spirv::getElementPtr(
+      typeConverter, loadOp.memref().getType().cast<MemRefType>(),
+      loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
   return matchSuccess();
 }
@@ -374,10 +332,10 @@
 StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
   StoreOpOperandAdaptor storeOperands(operands);
-  auto storePtr =
-      getElementPtr(rewriter, typeConverter, storeOp.getLoc(),
-                    storeOp.memref().getType().cast<MemRefType>(),
-                    storeOperands.memref(), storeOperands.indices());
+  auto storePtr = spirv::getElementPtr(
+      typeConverter, storeOp.memref().getType().cast<MemRefType>(),
+      storeOperands.memref(), storeOperands.indices(), storeOp.getLoc(),
+      rewriter);
   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
                                               storeOperands.value());
   return matchSuccess();
Index: mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
@@ -0,0 +1,53 @@
+//===- LinalgToSPIRVPass.cpp - Linalg to SPIR-V conversion pass -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
+#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Module pass that converts Linalg ops (and builtin functions) into SPIR-V
+/// ops, using a conversion target built from the module's SPIR-V target
+/// environment (or the default one when the module has none).
+struct LinalgToSPIRVPass : public ModulePass<LinalgToSPIRVPass> {
+  void runOnModule() override;
+};
+} // namespace
+
+void LinalgToSPIRVPass::runOnModule() {
+  ModuleOp module = getModule();
+  MLIRContext *context = &getContext();
+  // Linalg conversion patterns plus the builtin func conversion patterns.
+  SPIRVTypeConverter typeConverter;
+  OwningRewritePatternList patterns;
+  populateLinalgToSPIRVPatterns(context, typeConverter, patterns);
+  populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
+  // Build the conversion target from the module's target environment.
+  std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
+      spirv::lookupTargetEnvOrDefault(module), context);
+  // The module itself stays legal; a function is legal once the type
+  // converter considers its signature legal.
+  target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
+  target->addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
+    return typeConverter.isSignatureLegal(funcOp.getType());
+  });
+  // Full conversion: the pass fails if any op cannot be made legal.
+  if (failed(applyFullConversion(module, *target, patterns)))
+    signalPassFailure();
+}
+
+std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLinalgToSPIRVPass() {
+  return std::make_unique<LinalgToSPIRVPass>();
+}
+// Registers the pass so tools like mlir-opt expose `-convert-linalg-to-spirv`.
+static PassRegistration<LinalgToSPIRVPass>
+    pass("convert-linalg-to-spirv", "Convert Linalg ops to SPIR-V ops");
Index: mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -0,0 +1,211 @@
+//===- LinalgToSPIRV.cpp - Linalg to SPIR-V dialect conversion ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+/// Returns a `Value` holding the `dim`-th component of the SPIR-V local
+/// invocation ID, i.e. this invocation's ID along that dimension (not a size,
+/// despite the function name). The operations needed to read the builtin are
+/// created with `builder` at the proper region containing `op`.
+static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc,
+                                       OpBuilder *builder) {
+  assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
+  Value localInvocationId = spirv::getBuiltinVariableValue(
+      op, spirv::BuiltIn::LocalInvocationId, *builder);
+  auto idVectorType = localInvocationId.getType().cast<ShapedType>();
+  return builder->create<spirv::CompositeExtractOp>(
+      loc, idVectorType.getElementType(), localInvocationId,
+      builder->getI32ArrayAttr({dim}));
+}
+
+//===----------------------------------------------------------------------===//
+// Reduction (single workgroup)
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// A pattern to convert a linalg.generic op to SPIR-V ops under the condition
+/// that the linalg.generic op is performing reduction with a workload size that
+/// can fit in one workgroup.
+class SingleWorkgroupReduction final
+    : public SPIRVOpLowering<linalg::GenericOp> {
+public:
+  using SPIRVOpLowering<linalg::GenericOp>::SPIRVOpLowering;
+
+  /// Matches the given linalg.generic op as performing reduction and returns
+  /// the binary op kind if successful.
+  static Optional<linalg::RegionMatcher::BinaryOpKind>
+  matchAsPerformingReduction(linalg::GenericOp genericOp);
+
+  PatternMatchResult
+  matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+} // namespace
+
+Optional<linalg::RegionMatcher::BinaryOpKind>
+SingleWorkgroupReduction::matchAsPerformingReduction(
+    linalg::GenericOp genericOp) {
+  Operation *op = genericOp.getOperation();
+
+  // Make sure the linalg.generic is working on memrefs.
+  if (!genericOp.hasBufferSemantics())
+    return llvm::None;
+
+  // Make sure this is a reduction with one input and one output.
+  if (genericOp.args_in().getZExtValue() != 1 ||
+      genericOp.args_out().getZExtValue() != 1)
+    return llvm::None;
+
+  auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
+  auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
+
+  // Make sure the original input is statically shaped and has one dimension.
+  if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1)
+    return llvm::None;
+  // Make sure the original output is statically shaped and has one element.
+  if (!originalOutputType.hasStaticShape() ||
+      originalOutputType.getNumElements() != 1)
+    return llvm::None;
+
+  if (!genericOp.hasSingleReductionLoop())
+    return llvm::None;
+
+  if (genericOp.indexing_maps().getValue().size() != 2)
+    return llvm::None;
+
+  // TODO(nicolasvasilache): create utility functions for these checks in Linalg
+  // and use them.
+  auto inputMap = genericOp.indexing_maps().getValue()[0].cast<AffineMapAttr>();
+  auto outputMap =
+      genericOp.indexing_maps().getValue()[1].cast<AffineMapAttr>();
+  // The indexing map for the input should be `(i) -> (i)`.
+  if (inputMap.getValue() !=
+      AffineMap::get(1, 0, {getAffineDimExpr(0, op->getContext())}))
+    return llvm::None;
+  // The indexing map for the output should be `(i) -> (0)`.
+  if (outputMap.getValue() !=
+      AffineMap::get(1, 0, {getAffineConstantExpr(0, op->getContext())}))
+    return llvm::None;
+
+  return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);
+}
+
+PatternMatchResult SingleWorkgroupReduction::matchAndRewrite(
+    linalg::GenericOp genericOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // Match first: the memref casts below are only valid on a matched op.
+  auto binaryOpKind = matchAsPerformingReduction(genericOp);
+  if (!binaryOpKind)
+    return matchFailure();
+  Operation *op = genericOp.getOperation();
+  auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
+  auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
+
+  // Query the shader interface for local workgroup size to make sure the
+  // invocation configuration fits with the input memref's shape.
+  DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp);
+  if (!localSize)
+    return matchFailure();
+
+  if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
+    return matchFailure();
+  if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
+                   [](const APInt &size) { return !size.isOneValue(); }))
+    return matchFailure();
+
+  // TODO(antiagainst): Query the target environment to make sure the current
+  // workload fits in a local workgroup.
+
+  Value convertedInput = operands[0], convertedOutput = operands[1];
+  Location loc = genericOp.getLoc();
+  // Get the invocation ID.
+  Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter);
+
+  // TODO(antiagainst): Load to Workgroup storage class first.
+  // Get the input element accessed by this invocation.
+  auto inputElementPtr = spirv::getElementPtr(
+      typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
+  if (!inputElementPtr)
+    return matchFailure();
+  Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);
+
+  // Perform the group reduction operation.
+  Value groupOperation;
+#define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp)                         \
+  case linalg::RegionMatcher::BinaryOpKind::opKind: {                          \
+    groupOperation = rewriter.create<spirv::spvOp>(                            \
+        loc, originalInputType.getElementType(), spirv::Scope::Subgroup,       \
+        spirv::GroupOperation::Reduce, inputElement,                           \
+        /*cluster_size=*/ArrayRef<Value>());                                   \
+  } break
+  switch (*binaryOpKind) {
+    CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp);
+  }
+#undef CREATE_GROUP_NON_UNIFORM_BIN_OP
+
+  // Get the output element accessed by this reduction.
+  Value zero = spirv::ConstantOp::getZero(
+      typeConverter.getIndexType(rewriter.getContext()), loc, &rewriter);
+  SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
+  auto outputElementPtr =
+      spirv::getElementPtr(typeConverter, originalOutputType, convertedOutput,
+                           zeroIndices, loc, rewriter);
+  if (!outputElementPtr)
+    return matchFailure();
+
+  // Write out the final reduction result. This should be only conducted by one
+  // invocation. We use spv.GroupNonUniformElect to find the invocation with the
+  // lowest ID.
+  //
+  // ```
+  // if (spv.GroupNonUniformElect) { output = ... }
+  // ```
+  Value condition = rewriter.create<spirv::GroupNonUniformElectOp>(
+      loc, spirv::Scope::Subgroup);
+  auto createAtomicOp = [&](OpBuilder *builder) {
+#define CREATE_ATOMIC_BIN_OP(opKind, spvOp)                                    \
+  case linalg::RegionMatcher::BinaryOpKind::opKind: {                          \
+    builder->create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device, \
+                                  spirv::MemorySemantics::AcquireRelease,      \
+                                  groupOperation);                             \
+  } break
+    switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); }
+#undef CREATE_ATOMIC_BIN_OP
+  };
+
+  spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, &rewriter);
+
+  rewriter.eraseOp(genericOp);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+/// Adds all Linalg to SPIR-V conversion patterns in this file to `patterns`.
+void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context,
+                                         SPIRVTypeConverter &typeConverter,
+                                         OwningRewritePatternList &patterns) {
+  patterns.insert<SingleWorkgroupReduction>(context, typeConverter);
+}
Index: mlir/lib/Conversion/LinalgToSPIRV/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_llvm_library(MLIRLinalgToSPIRVTransforms
+  LinalgToSPIRV.cpp
+  LinalgToSPIRVPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+  )
+
+# MLIRLinalgUtils provides linalg::RegionMatcher, used by LinalgToSPIRV.cpp.
+target_link_libraries(MLIRLinalgToSPIRVTransforms
+  MLIRIR
+  MLIRLinalgOps
+  MLIRLinalgUtils
+  MLIRPass
+  MLIRSPIRV
+  MLIRSupport
+  )
Index: mlir/lib/Conversion/CMakeLists.txt
===================================================================
--- mlir/lib/Conversion/CMakeLists.txt
+++ mlir/lib/Conversion/CMakeLists.txt
@@ -4,6 +4,7 @@
 add_subdirectory(GPUToROCDL)
 add_subdirectory(GPUToSPIRV)
 add_subdirectory(LinalgToLLVM)
+add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LoopsToGPU)
 add_subdirectory(LoopToStandard)
 add_subdirectory(StandardToLLVM)
Index: mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
+++ mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
@@ -54,6 +54,12 @@
 /// target environment (SPIR-V 1.0 with Shader capability and no extra
 /// extensions) if not provided.
 TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
+
+/// Queries the local workgroup size from the entry point ABI attribute of the
+/// closest function-like op that is `op` or one of its ancestors. Returns a
+/// null attribute if there is no such op or it has no entry point ABI.
+DenseIntElementsAttr lookupLocalWorkGroupSize(Operation *op);
+
 } // namespace spirv
 } // namespace mlir
 
Index: mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -58,6 +58,8 @@
                                         OwningRewritePatternList &patterns);
 
 namespace spirv {
+class AccessChainOp;
+
 class SPIRVConversionTarget : public ConversionTarget {
 public:
   /// Creates a SPIR-V conversion target for the given target environment.
@@ -90,6 +92,16 @@
 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
                               OpBuilder &builder);
 
+// TODO(ravishankarm) : This method assumes that the `baseType` is a MemRefType
+// with AffineMap that has static strides. Extend to handle dynamic strides.
+/// Performs the index computation to get to the element at `indices` of the
+/// memory pointed to by `basePtr`, using the layout map of `baseType`. Returns
+/// nullptr if that layout is not strided or has dynamic strides.
+spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
+                                   MemRefType baseType, Value basePtr,
+                                   ArrayRef<Value> indices, Location loc,
+                                   OpBuilder &builder);
+
 /// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its
 /// arguments.
 LogicalResult setABIAttrs(FuncOp funcOp, EntryPointABIAttr entryPointInfo,
Index: mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -446,14 +446,22 @@
   let regions = (region AnyRegion:$body);
 
   let extraClassDeclaration = [{
-    // Returns the selection header block.
+    /// Returns the selection header block.
     Block *getHeaderBlock();
 
-    // Returns the selection merge block.
+    /// Returns the selection merge block.
     Block *getMergeBlock();
 
-    // Adds a selection merge block containing one spv._merge op.
+    /// Adds a selection merge block containing one spv._merge op.
     void addMergeBlock();
+
+    /// Creates a spv.selection op for `if (<condition>) then { <thenBody> }`
+    /// with `builder`. On return, `builder`'s insertion point is right after
+    /// the newly created spv.selection op.
+    static SelectionOp createIfThen(
+        Location loc, Value condition,
+        function_ref<void(OpBuilder *builder)> thenBody,
+        OpBuilder *builder);
   }];
 
   let hasOpcode = 0;
Index: mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
+++ mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
@@ -25,6 +25,7 @@
     SPV_ScopeAttr:$memory_scope,
     SPV_MemorySemanticsAttr:$semantics
   );
+
   let results = (outs
     SPV_Integer:$result
   );
@@ -42,9 +43,19 @@
     SPV_MemorySemanticsAttr:$semantics,
     SPV_Integer:$value
   );
+
   let results = (outs
     SPV_Integer:$result
   );
+
+  let builders = [
+    // Uses the type of `value` as the op's result type.
+    OpBuilder<[{Builder *builder, OperationState &state, Value pointer,
+      ::mlir::spirv::Scope scope, ::mlir::spirv::MemorySemantics semantics,
+      Value value}], [{
+      build(builder, state, value.getType(), pointer, scope, semantics, value);
+    }]>
+  ];
 }
 
 // -----
Index: mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
===================================================================
--- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -28,6 +28,26 @@
   LinalgOp fusedProducer;
 };
 
+/// Matchers that inspect the body region of linalg ops.
+struct RegionMatcher {
+  /// Binary operations that matchAsScalarBinaryOp can recognize.
+  enum class BinaryOpKind { IAdd };
+
+  /// Checks whether the body of `op` performs a single binary operation on
+  /// int or float scalar values. On a match, returns the kind of that op;
+  /// otherwise returns an empty Optional.
+  ///
+  /// The linalg op's region is expected to look like:
+  /// ```
+  /// {
+  ///   ^bb(%a: <scalar-type>, %b: <scalar-type>):
+  ///     %0 = <binary-op> %a, %b: <scalar-type>
+  ///     linalg.yield %0: <scalar-type>
+  /// }
+  /// ```
+  static Optional<BinaryOpKind> matchAsScalarBinaryOp(GenericOp op);
+};
+
 /// Checks whether the specific `producer` is the last write to exactly the
 /// whole `consumedView`. This checks structural dominance, that the dependence
 /// is a RAW without any interleaved write to any piece of `consumedView`.
Index: mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
===================================================================
--- mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -93,6 +93,12 @@
         cast<ConcreteType>(this->getOperation()).iterator_types());
   }
 
+  bool hasSingleReductionLoop() {
+    auto types = cast<ConcreteType>(this->getOperation()).iterator_types();
+    return types.size() == 1 &&
+           getNumIterators(getReductionIteratorTypeName(), types) == 1;
+  }
+
   //==========================================================================//
   // Input arguments handling.
   //==========================================================================//
Index: mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
===================================================================
--- mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -54,6 +54,11 @@
       "Query the number of loops within the current operation.",
       "unsigned", "getNumLoops">,
 
+    InterfaceMethod<
+      [{Returns true if the current operation has only one loop and it's a
+        reduction loop.}],
+      "bool", "hasSingleReductionLoop">,
+
     //========================================================================//
     // Input arguments handling.
     //========================================================================//
Index: mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h
@@ -0,0 +1,25 @@
+//===- LinalgToSPIRVPass.h - Linalg to SPIR-V conversion pass ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a pass for Linalg to SPIR-V dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRVPASS_H
+#define MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRVPASS_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+/// Creates and returns a pass to convert Linalg ops to SPIR-V ops.
+std::unique_ptr<OpPassBase<ModuleOp>> createLinalgToSPIRVPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRVPASS_H
Index: mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
@@ -0,0 +1,29 @@
+//===- LinalgToSPIRV.h - Linalg to SPIR-V dialect conversion ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides patterns for Linalg to SPIR-V dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRV_H
+#define MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRV_H
+
+namespace mlir {
+
+class MLIRContext;
+class OwningRewritePatternList;
+class SPIRVTypeConverter;
+
+/// Appends onto `patterns` the patterns translating Linalg ops to SPIR-V ops.
+void populateLinalgToSPIRVPatterns(
+    MLIRContext *context, SPIRVTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRV_H
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to