Module: Mesa
Branch: main
Commit: 7aa060ec3800c3f6956eb319f7dc7930b473d060
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=7aa060ec3800c3f6956eb319f7dc7930b473d060

Author: Jesse Natalie <[email protected]>
Date:   Mon Apr 19 05:38:42 2021 -0700

microsoft/clc: Add a test for specializing via SPIRV-Tools

Acked-by: Lionel Landwerlin <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10322>

---

 src/microsoft/clc/clc_compiler_test.cpp | 79 +++++++++++++++++++++++++++++++++
 src/microsoft/clc/compute_test.cpp      | 25 +++++++++++
 src/microsoft/clc/compute_test.h        | 30 +++++++++++++
 src/microsoft/clc/meson.build           |  4 +-
 4 files changed, 136 insertions(+), 2 deletions(-)

diff --git a/src/microsoft/clc/clc_compiler_test.cpp b/src/microsoft/clc/clc_compiler_test.cpp
index 2b3a8da5828..f5155c96c49 100644
--- a/src/microsoft/clc/clc_compiler_test.cpp
+++ b/src/microsoft/clc/clc_compiler_test.cpp
@@ -2232,3 +2232,89 @@ TEST_F(ComputeTest, unused_arg)
    for (int i = 0; i < 4; ++i)
       EXPECT_EQ(dest[i], i + 1);
 }
+
+/* The kernel below computes output[id] = output[id] * (id + C), where id is
+ * the x component of the global invocation ID and C is the spec constant
+ * with SpecId 1 (default value 1). It is assembled with SPIRV-Tools and
+ * specialized to C = 5, so each element should become input[i] * (i + 5).
+ */
+TEST_F(ComputeTest, spec_constant)
+{
+   const char *spirv_asm = R"(
+               OpCapability Addresses
+               OpCapability Kernel
+               OpCapability Int64
+          %1 = OpExtInstImport "OpenCL.std"
+               OpMemoryModel Physical64 OpenCL
+               OpEntryPoint Kernel %2 "main_test" %__spirv_BuiltInGlobalInvocationId
+          %4 = OpString "kernel_arg_type.main_test.uint*,"
+               OpSource OpenCL_C 102000
+               OpName %__spirv_BuiltInGlobalInvocationId "__spirv_BuiltInGlobalInvocationId"
+               OpName %output "output"
+               OpName %entry "entry"
+               OpName %output_addr "output.addr"
+               OpName %id "id"
+               OpName %call "call"
+               OpName %conv "conv"
+               OpName %idxprom "idxprom"
+               OpName %arrayidx "arrayidx"
+               OpName %add "add"
+               OpName %mul "mul"
+               OpName %idxprom1 "idxprom1"
+               OpName %arrayidx2 "arrayidx2"
+               OpDecorate %__spirv_BuiltInGlobalInvocationId BuiltIn GlobalInvocationId
+               OpDecorate %__spirv_BuiltInGlobalInvocationId Constant
+               OpDecorate %id Alignment 4
+               OpDecorate %output_addr Alignment 8
+               OpDecorate %uint_1 SpecId 1
+      %ulong = OpTypeInt 64 0
+       %uint = OpTypeInt 32 0
+     %uint_1 = OpSpecConstant %uint 1
+    %v3ulong = OpTypeVector %ulong 3
+%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong
+       %void = OpTypeVoid
+%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
+         %24 = OpTypeFunction %void %_ptr_CrossWorkgroup_uint
+%_ptr_Function__ptr_CrossWorkgroup_uint = OpTypePointer Function %_ptr_CrossWorkgroup_uint
+%_ptr_Function_uint = OpTypePointer Function %uint
+%__spirv_BuiltInGlobalInvocationId = OpVariable %_ptr_Input_v3ulong Input
+          %2 = OpFunction %void DontInline %24
+     %output = OpFunctionParameter %_ptr_CrossWorkgroup_uint
+      %entry = OpLabel
+%output_addr = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uint Function
+         %id = OpVariable %_ptr_Function_uint Function
+               OpStore %output_addr %output Aligned 8
+         %27 = OpLoad %v3ulong %__spirv_BuiltInGlobalInvocationId Aligned 32
+       %call = OpCompositeExtract %ulong %27 0
+       %conv = OpUConvert %uint %call
+               OpStore %id %conv Aligned 4
+         %28 = OpLoad %_ptr_CrossWorkgroup_uint %output_addr Aligned 8
+         %29 = OpLoad %uint %id Aligned 4
+    %idxprom = OpUConvert %ulong %29
+   %arrayidx = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uint %28 %idxprom
+         %30 = OpLoad %uint %arrayidx Aligned 4
+         %31 = OpLoad %uint %id Aligned 4
+        %add = OpIAdd %uint %31 %uint_1
+        %mul = OpIMul %uint %30 %add
+         %32 = OpLoad %_ptr_CrossWorkgroup_uint %output_addr Aligned 8
+         %33 = OpLoad %uint %id Aligned 4
+   %idxprom1 = OpUConvert %ulong %33
+  %arrayidx2 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uint %32 %idxprom1
+               OpStore %arrayidx2 %mul Aligned 4
+               OpReturn
+               OpFunctionEnd)";
+   Shader shader = assemble(spirv_asm);
+   /* Replace the default value (1) of SpecId 1 with 5. */
+   Shader spec_shader = specialize(shader, 1, 5);
+
+   auto inout = ShaderArg<uint32_t>({ 0x00000001, 0x10000001, 0x00020002, 0x04010203 },
+      SHADER_ARG_INOUT);
+   /* expected[i] == input[i] * (i + 5) */
+   const uint32_t expected[] = {
+      0x00000005, 0x60000006, 0x000e000e, 0x20081018
+   };
+   CompileArgs args = { inout.size(), 1, 1 };
+   run_shader(spec_shader, args, inout);
+   for (size_t i = 0; i < inout.size(); ++i)
+      EXPECT_EQ(inout[i], expected[i]);
+}
diff --git a/src/microsoft/clc/compute_test.cpp b/src/microsoft/clc/compute_test.cpp
index dc7d8507198..e5faa4f5dde 100644
--- a/src/microsoft/clc/compute_test.cpp
+++ b/src/microsoft/clc/compute_test.cpp
@@ -35,6 +35,8 @@
 #include "compute_test.h"
 #include "dxcapi.h"
 
+#include <spirv-tools/libspirv.hpp>
+
 using std::runtime_error;
 using Microsoft::WRL::ComPtr;
 
@@ -850,6 +852,48 @@ ComputeTest::link(const std::vector<Shader> &sources,
    return shader;
 }

+/* Assembles the SPIR-V text in `source` with SPIRV-Tools, copies the
+ * resulting words into a clc_binary owned by the returned Shader, and
+ * runs configure() on it. Throws std::runtime_error if assembly fails
+ * (the message includes the assembler's diagnostics) or if the output
+ * buffer cannot be allocated.
+ */
+ComputeTest::Shader
+ComputeTest::assemble(const char *source)
+{
+   /* Route SPIRV-Tools diagnostics into the exception message so that a
+    * typo in a hand-written test kernel is easy to locate. */
+   std::string diagnostics;
+   spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
+   tools.SetMessageConsumer([&diagnostics](spv_message_level_t, const char *,
+                                           const spv_position_t &pos,
+                                           const char *message)
+      {
+         diagnostics += std::to_string(pos.line) + ":" +
+                        std::to_string(pos.column) + ": " + message + "\n";
+      });
+
+   std::vector<uint32_t> binary;
+   if (!tools.Assemble(source, strlen(source), &binary))
+      throw runtime_error("failed to assemble\n" + diagnostics);
+
+   ComputeTest::Shader shader;
+   shader.obj = std::shared_ptr<clc_binary>(new clc_binary{}, [](clc_binary *spirv)
+      {
+         free(spirv->data);
+         delete spirv;
+      });
+   shader.obj->size = binary.size() * sizeof(uint32_t);
+   shader.obj->data = malloc(shader.obj->size);
+   if (!shader.obj->data)
+      throw runtime_error("failed to allocate SPIR-V binary");
+   memcpy(shader.obj->data, binary.data(), shader.obj->size);
+
+   configure(shader, NULL);
+
+   return shader;
+}
+
 void
 ComputeTest::configure(Shader &shader,
                        const struct clc_runtime_kernel_conf *conf)
diff --git a/src/microsoft/clc/compute_test.h b/src/microsoft/clc/compute_test.h
index c915e74d554..9ac741fc16f 100644
--- a/src/microsoft/clc/compute_test.h
+++ b/src/microsoft/clc/compute_test.h
@@ -167,6 +167,11 @@ protected:
    link(const std::vector<Shader> &sources,
         bool create_library = false);

+   /* Assembles SPIR-V text with SPIRV-Tools and configures the result.
+    * Throws std::runtime_error if the text fails to assemble. */
+   Shader
+   assemble(const char *source);
+
    void
    configure(Shader &shader,
              const struct clc_runtime_kernel_conf *conf);
@@ -174,6 +179,45 @@ protected:
    void
    validate(Shader &shader);

+   /* Returns a new Shader whose SPIR-V is `shader` with the spec constant
+    * SpecId `id` set to `val`, and configures it. The bytes of `val` are
+    * copied into the start of the specialization value; any remaining
+    * bytes are zero. If `shader` has no metadata yet it is configured
+    * first, which is why it is taken by non-const reference. Throws
+    * std::runtime_error if clc_specialize_spirv() fails.
+    */
+   template <typename T>
+   Shader
+   specialize(Shader &shader, uint32_t id, T const& val)
+   {
+      Shader new_shader;
+      new_shader.obj = std::shared_ptr<clc_binary>(new clc_binary{}, [](clc_binary *spirv)
+         {
+            clc_free_spirv(spirv);
+            delete spirv;
+         });
+      if (!shader.metadata)
+         configure(shader, NULL);
+
+      /* Value-initialize so bytes of spec.value not covered by `val`
+       * (e.g. when T is narrower than the constant's type) are zero
+       * rather than indeterminate. */
+      clc_spirv_specialization spec{};
+      static_assert(sizeof(val) <= sizeof(spec.value),
+                    "specialization value type is too large");
+      spec.id = id;
+      memcpy(&spec.value, &val, sizeof(val));
+      clc_spirv_specialization_consts consts{};
+      consts.specializations = &spec;
+      consts.num_specializations = 1;
+      if (!clc_specialize_spirv(shader.obj.get(), shader.metadata.get(), &consts, new_shader.obj.get()))
+         throw std::runtime_error("failed to specialize");
+
+      configure(new_shader, NULL);
+
+      return new_shader;
+   }
+
    enum ShaderArgDirection {
       SHADER_ARG_INPUT = 1,
       SHADER_ARG_OUTPUT = 2,
diff --git a/src/microsoft/clc/meson.build b/src/microsoft/clc/meson.build
index 35a88baffeb..05b48b6a813 100644
--- a/src/microsoft/clc/meson.build
+++ b/src/microsoft/clc/meson.build
@@ -63,8 +63,12 @@ if dep_dxheaders.found()
   clc_compiler_test = executable('clc_compiler_test',
     ['clc_compiler_test.cpp', 'compute_test.cpp'],
     link_with : [libclc_compiler],
-    dependencies : [idep_gtest, idep_mesautil, idep_libdxil_compiler, dep_dxheaders],
-    include_directories : [inc_include, inc_src])
+    # SPIRV-Tools (dep_spirv_tools, inc_spirv) is used by ComputeTest::assemble().
+    dependencies : [
+      idep_gtest, idep_mesautil, idep_libdxil_compiler, dep_dxheaders,
+      dep_spirv_tools,
+    ],
+    include_directories : [inc_include, inc_src, inc_spirv])

   test('clc_compiler_test', clc_compiler_test, timeout: 180)
 endif

Reply via email to