Author: Hanhan Wang Date: 2021-01-21T22:20:32-08:00 New Revision: 2cb130f7661176f2c2eaa7554f2a55863cfc0ed3
URL: https://github.com/llvm/llvm-project/commit/2cb130f7661176f2c2eaa7554f2a55863cfc0ed3 DIFF: https://github.com/llvm/llvm-project/commit/2cb130f7661176f2c2eaa7554f2a55863cfc0ed3.diff LOG: [mlir][StandardToSPIRV] Add support for lowering uitofp to SPIR-V - Extend spirv::ConstantOp::getZero/One to handle float, vector of int, and vector of float. - Refactor ZeroExtendI1Pattern to use getZero/One methods. - Add one more test for lowering std.zexti which extends vector<4xi1> to vector<4xi64>. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D95120 Added: Modified: mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir Removed: ################################################################################ diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp index 72b8c5811695..95bb0eca4496 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -481,16 +481,32 @@ class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> { auto dstType = this->getTypeConverter()->convertType(op.getResult().getType()); Location loc = op.getLoc(); - Attribute zeroAttr, oneAttr; - if (auto vectorType = dstType.dyn_cast<VectorType>()) { - zeroAttr = DenseElementsAttr::get(vectorType, 0); - oneAttr = DenseElementsAttr::get(vectorType, 1); - } else { - zeroAttr = IntegerAttr::get(dstType, 0); - oneAttr = IntegerAttr::get(dstType, 1); - } - Value zero = rewriter.create<ConstantOp>(loc, zeroAttr); - Value one = rewriter.create<ConstantOp>(loc, oneAttr); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.template replaceOpWithNewOp<spirv::SelectOp>( + op, dstType, operands.front(), one, zero); + return success(); 
+ } +}; + +/// Converts std.uitofp to spv.Select if the type of source is i1 or vector of +/// i1. +class UIToFPI1Pattern final : public OpConversionPattern<UIToFPOp> { +public: + using OpConversionPattern<UIToFPOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(UIToFPOp op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto srcType = operands.front().getType(); + if (!isBoolScalarOrVector(srcType)) + return failure(); + + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Location loc = op.getLoc(); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.template replaceOpWithNewOp<spirv::SelectOp>( op, dstType, operands.front(), one, zero); return success(); @@ -1098,8 +1114,10 @@ void populateStandardToSPIRVPatterns(MLIRContext *context, ReturnOpPattern, SelectOpPattern, // Type cast patterns - ZeroExtendI1Pattern, TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>, + UIToFPI1Pattern, ZeroExtendI1Pattern, + TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>, TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>, + TypeCastingOpPattern<UIToFPOp, spirv::ConvertUToFOp>, TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>, TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>, TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>, diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index c90895197f43..3d99696d6882 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -25,6 +25,8 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/bit.h" @@ -1581,6 +1583,25 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, return 
builder.create<spirv::ConstantOp>( loc, type, builder.getIntegerAttr(type, APInt(width, 0))); } + if (auto floatType = type.dyn_cast<FloatType>()) { + return builder.create<spirv::ConstantOp>( + loc, type, builder.getFloatAttr(floatType, 0.0)); + } + if (auto vectorType = type.dyn_cast<VectorType>()) { + Type elemType = vectorType.getElementType(); + if (elemType.isa<IntegerType>()) { + return builder.create<spirv::ConstantOp>( + loc, type, + DenseElementsAttr::get(vectorType, + IntegerAttr::get(elemType, 0).getValue())); + } + if (elemType.isa<FloatType>()) { + return builder.create<spirv::ConstantOp>( + loc, type, + DenseFPElementsAttr::get(vectorType, + FloatAttr::get(elemType, 0.0).getValue())); + } + } llvm_unreachable("unimplemented types for ConstantOp::getZero()"); } @@ -1595,6 +1616,25 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, return builder.create<spirv::ConstantOp>( loc, type, builder.getIntegerAttr(type, APInt(width, 1))); } + if (auto floatType = type.dyn_cast<FloatType>()) { + return builder.create<spirv::ConstantOp>( + loc, type, builder.getFloatAttr(floatType, 1.0)); + } + if (auto vectorType = type.dyn_cast<VectorType>()) { + Type elemType = vectorType.getElementType(); + if (elemType.isa<IntegerType>()) { + return builder.create<spirv::ConstantOp>( + loc, type, + DenseElementsAttr::get(vectorType, + IntegerAttr::get(elemType, 1).getValue())); + } + if (elemType.isa<FloatType>()) { + return builder.create<spirv::ConstantOp>( + loc, type, + DenseFPElementsAttr::get(vectorType, + FloatAttr::get(elemType, 1.0).getValue())); + } + } llvm_unreachable("unimplemented types for ConstantOp::getOne()"); } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir index 633fdbc03550..252bc3eb5095 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -568,6 +568,56 @@ 
func @sitofp2(%arg0 : i64) -> f64 { return %0 : f64 } +// CHECK-LABEL: @uitofp_i16_f32 +func @uitofp_i16_f32(%arg0: i16) -> f32 { + // CHECK: spv.ConvertUToF %{{.*}} : i16 to f32 + %0 = std.uitofp %arg0 : i16 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @uitofp_i32_f32 +func @uitofp_i32_f32(%arg0 : i32) -> f32 { + // CHECK: spv.ConvertUToF %{{.*}} : i32 to f32 + %0 = std.uitofp %arg0 : i32 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @uitofp_i1_f32 +func @uitofp_i1_f32(%arg0 : i1) -> f32 { + // CHECK: %[[ZERO:.+]] = spv.constant 0.000000e+00 : f32 + // CHECK: %[[ONE:.+]] = spv.constant 1.000000e+00 : f32 + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f32 + %0 = std.uitofp %arg0 : i1 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @uitofp_i1_f64 +func @uitofp_i1_f64(%arg0 : i1) -> f64 { + // CHECK: %[[ZERO:.+]] = spv.constant 0.000000e+00 : f64 + // CHECK: %[[ONE:.+]] = spv.constant 1.000000e+00 : f64 + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f64 + %0 = std.uitofp %arg0 : i1 to f64 + return %0 : f64 +} + +// CHECK-LABEL: @uitofp_vec_i1_f32 +func @uitofp_vec_i1_f32(%arg0 : vector<4xi1>) -> vector<4xf32> { + // CHECK: %[[ZERO:.+]] = spv.constant dense<0.000000e+00> : vector<4xf32> + // CHECK: %[[ONE:.+]] = spv.constant dense<1.000000e+00> : vector<4xf32> + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xf32> + %0 = std.uitofp %arg0 : vector<4xi1> to vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: @uitofp_vec_i1_f64 +func @uitofp_vec_i1_f64(%arg0 : vector<4xi1>) -> vector<4xf64> { + // CHECK: %[[ZERO:.+]] = spv.constant dense<0.000000e+00> : vector<4xf64> + // CHECK: %[[ONE:.+]] = spv.constant dense<1.000000e+00> : vector<4xf64> + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xf64> + %0 = std.uitofp %arg0 : vector<4xi1> to vector<4xf64> +
return %0 : vector<4xf64> +} + // CHECK-LABEL: @zexti1 func @zexti1(%arg0: i16) -> i64 { // CHECK: spv.UConvert %{{.*}} : i16 to i64 @@ -600,6 +650,15 @@ func @zexti4(%arg0 : vector<4xi1>) -> vector<4xi32> { return %0 : vector<4xi32> } +// CHECK-LABEL: @zexti5 +func @zexti5(%arg0 : vector<4xi1>) -> vector<4xi64> { + // CHECK: %[[ZERO:.+]] = spv.constant dense<0> : vector<4xi64> + // CHECK: %[[ONE:.+]] = spv.constant dense<1> : vector<4xi64> + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xi64> + %0 = std.zexti %arg0 : vector<4xi1> to vector<4xi64> + return %0 : vector<4xi64> +} + // CHECK-LABEL: @trunci1 func @trunci1(%arg0 : i64) -> i16 { // CHECK: spv.SConvert %{{.*}} : i64 to i16 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits