Module: Mesa
Branch: main
Commit: 068c6b5a37a96106da817447e3cef706fc8e90d0
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=068c6b5a37a96106da817447e3cef706fc8e90d0

Author: Jesse Natalie <[email protected]>
Date:   Mon Apr 19 07:14:24 2021 -0700

microsoft/clc: Add API to independently specialize SPIR-V

We need the ability to specialize unlinked SPIR-V, so use SPIR-V tools
to specialize prior to linking.

Acked-by: Lionel Landwerlin <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10322>

---

 src/microsoft/clc/clc_compiler.c     | 15 ++++++++
 src/microsoft/clc/clc_compiler.h     |  7 ++++
 src/microsoft/clc/clc_helpers.cpp    | 68 ++++++++++++++++++++++++++++++++++--
 src/microsoft/clc/clc_helpers.h      |  6 ++++
 src/microsoft/clc/clon12compiler.def |  1 +
 5 files changed, 95 insertions(+), 2 deletions(-)

diff --git a/src/microsoft/clc/clc_compiler.c b/src/microsoft/clc/clc_compiler.c
index 6bf4e7a5785..93bde0afb98 100644
--- a/src/microsoft/clc/clc_compiler.c
+++ b/src/microsoft/clc/clc_compiler.c
@@ -661,6 +661,21 @@ void clc_free_parsed_spirv(struct clc_parsed_spirv *data)
    clc_free_kernels_info(data->kernels, data->num_kernels);
 }
 
+/* Specialize in_spirv with consts and write the result to out_spirv; the
+ * specialized SPIR-V is dumped to stdout when CLC_DEBUG_DUMP_SPIRV is set. */
+bool
+clc_specialize_spirv(const struct clc_binary *in_spirv,
+                     const struct clc_parsed_spirv *parsed_data,
+                     const struct clc_spirv_specialization_consts *consts,
+                     struct clc_binary *out_spirv)
+{
+   bool specialized = clc_spirv_specialize(in_spirv, parsed_data, consts, out_spirv);
+   if (specialized && (debug_get_option_debug_clc() & CLC_DEBUG_DUMP_SPIRV))
+      clc_dump_spirv(out_spirv, stdout);
+
+   return specialized;
+}
+
 static nir_variable *
 add_kernel_inputs_var(struct clc_dxil_object *dxil, nir_shader *nir,
                       unsigned *cbv_id)
diff --git a/src/microsoft/clc/clc_compiler.h b/src/microsoft/clc/clc_compiler.h
index 301a0166983..9029e6e640d 100644
--- a/src/microsoft/clc/clc_compiler.h
+++ b/src/microsoft/clc/clc_compiler.h
@@ -302,6 +302,21 @@ struct clc_spirv_specialization_consts {
    const struct clc_spirv_specialization *specializations;
    unsigned num_specializations;
 };
+
+/* Apply the specialization constants in consts to in_spirv, which does not
+ * need to be linked yet, and store the specialized module in out_spirv.
+ *
+ * parsed_data supplies the type of each specialization constant; every id
+ * in consts must be present in parsed_data->spec_constants.
+ * On success, out_spirv->data is heap-allocated and owned by the caller.
+ * Returns false on failure.
+ */
+bool
+clc_specialize_spirv(const struct clc_binary *in_spirv,
+                     const struct clc_parsed_spirv *parsed_data,
+                     const struct clc_spirv_specialization_consts *consts,
+                     struct clc_binary *out_spirv);
+
 bool
 clc_spirv_to_dxil(struct clc_libclc *lib,
                   const struct clc_binary *linked_spirv,
diff --git a/src/microsoft/clc/clc_helpers.cpp b/src/microsoft/clc/clc_helpers.cpp
index 68591bcf1a0..8ea2ef10a86 100644
--- a/src/microsoft/clc/clc_helpers.cpp
+++ b/src/microsoft/clc/clc_helpers.cpp
@@ -46,6 +46,7 @@
 
 #include <spirv-tools/libspirv.hpp>
 #include <spirv-tools/linker.hpp>
+#include <spirv-tools/optimizer.hpp>
 
 #include "util/macros.h"
 #include "glsl_types.h"
@@ -58,6 +59,8 @@
 #include "opencl-c.h.h"
 #include "opencl-c-base.h.h"
 
+constexpr spv_target_env spirv_target = SPV_ENV_UNIVERSAL_1_0;
+
 using ::llvm::Function;
 using ::llvm::LLVMContext;
 using ::llvm::Module;
@@ -955,7 +958,7 @@ clc_link_spirv_binaries(const struct clc_linker_args *args,
    }
 
    SPIRVMessageConsumer msgconsumer(logger);
-   spvtools::Context context(SPV_ENV_UNIVERSAL_1_0);
+   spvtools::Context context(spirv_target);
    context.SetMessageConsumer(msgconsumer);
    spvtools::LinkerOptions options;
    options.SetAllowPartialLinkage(args->create_library);
@@ -973,10 +976,100 @@ clc_link_spirv_binaries(const struct clc_linker_args *args,
    return 0;
 }
 
+/* Apply the specialization constants in consts to in_spirv, which may still
+ * be unlinked, by running the SPIR-V tools SetSpecConstantDefaultValue pass.
+ * parsed_data supplies the type of each constant, which determines how its
+ * value is encoded as SPIR-V literal words.
+ *
+ * On success, out_spirv->data is malloc()ed and owned by the caller, and
+ * non-zero is returned. Returns zero (leaving out_spirv untouched) if a
+ * constant id is not declared in parsed_data, if the optimizer fails, or if
+ * allocation fails.
+ */
+int
+clc_spirv_specialize(const struct clc_binary *in_spirv,
+                     const struct clc_parsed_spirv *parsed_data,
+                     const struct clc_spirv_specialization_consts *consts,
+                     struct clc_binary *out_spirv)
+{
+   const clc_parsed_spec_constant *parsed_begin = parsed_data->spec_constants;
+   const clc_parsed_spec_constant *parsed_end =
+      parsed_begin + parsed_data->num_spec_constants;
+
+   std::unordered_map<uint32_t, std::vector<uint32_t>> spec_const_map;
+   for (unsigned i = 0; i < consts->num_specializations; ++i) {
+      const auto &spec = consts->specializations[i];
+      const unsigned id = spec.id;
+      auto parsed_spec_const = std::find_if(parsed_begin, parsed_end,
+         [id](const clc_parsed_spec_constant &c) { return c.id == id; });
+      /* An unknown id is a caller error; fail cleanly in release builds
+       * instead of dereferencing the end pointer below. */
+      assert(parsed_spec_const != parsed_end);
+      if (parsed_spec_const == parsed_end)
+         return false;
+
+      /* SPIR-V literals: types narrower than 32 bits are sign- or
+       * zero-extended to one word; 64-bit types take two words, low-order
+       * word first. */
+      std::vector<uint32_t> words;
+      switch (parsed_spec_const->type) {
+      case CLC_SPEC_CONSTANT_BOOL:
+         words.push_back(spec.value.b);
+         break;
+      case CLC_SPEC_CONSTANT_INT32:
+      case CLC_SPEC_CONSTANT_UINT32:
+      case CLC_SPEC_CONSTANT_FLOAT:
+         words.push_back(spec.value.u32);
+         break;
+      case CLC_SPEC_CONSTANT_INT16:
+         words.push_back((uint32_t)(int32_t)spec.value.i16);
+         break;
+      case CLC_SPEC_CONSTANT_INT8:
+         words.push_back((uint32_t)(int32_t)spec.value.i8);
+         break;
+      case CLC_SPEC_CONSTANT_UINT16:
+         words.push_back((uint32_t)spec.value.u16);
+         break;
+      case CLC_SPEC_CONSTANT_UINT8:
+         words.push_back((uint32_t)spec.value.u8);
+         break;
+      case CLC_SPEC_CONSTANT_DOUBLE:
+      case CLC_SPEC_CONSTANT_INT64:
+      case CLC_SPEC_CONSTANT_UINT64:
+         /* Split explicitly so the word order does not depend on host
+          * endianness, as a raw memcpy of the 64-bit value would. */
+         words.push_back((uint32_t)spec.value.u64);
+         words.push_back((uint32_t)(spec.value.u64 >> 32));
+         break;
+      }
+
+      ASSERTED auto ret = spec_const_map.emplace(id, std::move(words));
+      assert(ret.second);
+   }
+
+   spvtools::Optimizer opt(spirv_target);
+   opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(std::move(spec_const_map)));
+
+   std::vector<uint32_t> result;
+   if (!opt.Run(static_cast<const uint32_t *>(in_spirv->data),
+                in_spirv->size / sizeof(uint32_t), &result))
+      return false;
+
+   const size_t out_size = result.size() * sizeof(uint32_t);
+   void *out_data = malloc(out_size);
+   if (!out_data)
+      return false;
+   memcpy(out_data, result.data(), out_size);
+
+   out_spirv->data = out_data;
+   out_spirv->size = out_size;
+   return true;
+}
+
 void
 clc_dump_spirv(const struct clc_binary *spvbin, FILE *f)
 {
-   spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
+   spvtools::SpirvTools tools(spirv_target);
    const uint32_t *data = static_cast<const uint32_t *>(spvbin->data);
    std::vector<uint32_t> bin(data, data + (spvbin->size / 4));
    std::string out;
diff --git a/src/microsoft/clc/clc_helpers.h b/src/microsoft/clc/clc_helpers.h
index 7a5cf0c9fa9..6edf261c3bc 100644
--- a/src/microsoft/clc/clc_helpers.h
+++ b/src/microsoft/clc/clc_helpers.h
@@ -70,6 +70,16 @@ clc_link_spirv_binaries(const struct clc_linker_args *args,
                         const struct clc_logger *logger,
                         struct clc_binary *out_spirv);
 
+/* Specialize in_spirv (which may still be unlinked) with the constants in
+ * consts, taking each constant's type from parsed_data. Returns non-zero on
+ * success, in which case out_spirv->data is malloc()ed and owned by the
+ * caller; returns zero on failure.
+ */
+int
+clc_spirv_specialize(const struct clc_binary *in_spirv,
+                     const struct clc_parsed_spirv *parsed_data,
+                     const struct clc_spirv_specialization_consts *consts,
+                     struct clc_binary *out_spirv);
+
 void
 clc_dump_spirv(const struct clc_binary *spvbin, FILE *f);
 
diff --git a/src/microsoft/clc/clon12compiler.def b/src/microsoft/clc/clon12compiler.def
index f0c8a2268dd..bb93d4b8c9e 100644
--- a/src/microsoft/clc/clon12compiler.def
+++ b/src/microsoft/clc/clon12compiler.def
@@ -12,6 +12,7 @@ EXPORTS
     clc_link_spirv
     clc_parse_spirv
     clc_free_parsed_spirv
+    clc_specialize_spirv
     clc_spirv_to_dxil
     clc_free_dxil_object
     clc_compiler_get_version

Reply via email to