Module: Mesa
Branch: main
Commit: 8dcf7648f155d9c0a3826bfb460916bf6f4a2250
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=8dcf7648f155d9c0a3826bfb460916bf6f4a2250

Author: Alyssa Rosenzweig <[email protected]>
Date:   Thu Nov 24 20:40:50 2022 -0500

agx: Lower VBOs in NIR

Now we support all the vertex formats! This means we don't hit u_vbuf for format
translation, which helps performance in lots of applications. Doing the lowering
in NIR also lets NIR itself optimize the vertex fetch code (e.g. with
nir_opt_algebraic), which can improve generated code quality.

In my first implementation of this, I had a big switch statement mapping format
enums to interchange formats and post-processing code. This ended up being really
unwieldy: the combinatorics of bit packing + conversion + swizzles are
enormous, and for performance we want to support everything (no u_vbuf
fallbacks). To keep the combinatorics in check, we rely on parsing the
util_format_description to separate out the issues of bit packing, conversion,
and swizzling, allowing us to handle bizarro formats like B10G10R10A2_SNORM with
no special casing.

In an effort to support everything in one shot, this handles all the formats
needed for the extensions EXT_vertex_array_bgra, ARB_vertex_type_2_10_10_10_rev,
and ARB_vertex_type_10f_11f_11f_rev.

Passes dEQP-GLES3.functional.vertex_arrays.*

Signed-off-by: Alyssa Rosenzweig <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19996>

---

 docs/features.txt                     |   4 +-
 src/asahi/compiler/agx_compile.c      |  68 +---------
 src/asahi/compiler/agx_compile.h      |  53 --------
 src/asahi/lib/agx_formats.c           |  66 ----------
 src/asahi/lib/agx_formats.h           |   2 -
 src/asahi/lib/agx_nir_lower_vbo.c     | 239 ++++++++++++++++++++++++++++++++++
 src/asahi/lib/agx_nir_lower_vbo.h     |  46 +++++++
 src/asahi/lib/meson.build             |   1 +
 src/gallium/drivers/asahi/agx_pipe.c  |  14 +-
 src/gallium/drivers/asahi/agx_state.c |  38 +++---
 src/gallium/drivers/asahi/agx_state.h |   3 +
 11 files changed, 315 insertions(+), 219 deletions(-)

diff --git a/docs/features.txt b/docs/features.txt
index 965c05a1a79..40ebfd68028 100644
--- a/docs/features.txt
+++ b/docs/features.txt
@@ -108,7 +108,7 @@ GL 3.3, GLSL 3.30 --- all DONE: freedreno, i965, nv50, 
nvc0, r600, radeonsi, llv
   GL_ARB_texture_swizzle                                DONE (v3d, vc4, 
panfrost, lima, asahi)
   GL_ARB_timer_query                                    DONE ()
   GL_ARB_instanced_arrays                               DONE (etnaviv/HALTI2, 
v3d, panfrost)
-  GL_ARB_vertex_type_2_10_10_10_rev                     DONE (v3d, panfrost)
+  GL_ARB_vertex_type_2_10_10_10_rev                     DONE (v3d, panfrost, 
asahi)
 
 
 GL 4.0, GLSL 4.00 --- all DONE: freedreno/a6xx, i965/gen7+, nvc0, r600, 
radeonsi, llvmpipe, virgl, zink, d3d12
@@ -208,7 +208,7 @@ GL 4.4, GLSL 4.40 -- all DONE: freedreno/a6xx, i965/gen8+, 
nvc0, r600, radeonsi,
   GL_ARB_query_buffer_object                            DONE (freedreno/a6xx, 
i965/hsw+, virgl)
   GL_ARB_texture_mirror_clamp_to_edge                   DONE (freedreno, i965, 
nv50, softpipe, virgl, v3d, panfrost)
   GL_ARB_texture_stencil8                               DONE (freedreno, 
i965/hsw+, nv50, softpipe, virgl, v3d, panfrost, d3d12, asahi)
-  GL_ARB_vertex_type_10f_11f_11f_rev                    DONE (freedreno, i965, 
nv50, softpipe, virgl, panfrost, d3d12)
+  GL_ARB_vertex_type_10f_11f_11f_rev                    DONE (freedreno, i965, 
nv50, softpipe, virgl, panfrost, d3d12, asahi)
 
 GL 4.5, GLSL 4.50 -- all DONE: freedreno/a6xx, nvc0, r600, radeonsi, llvmpipe, 
zink
 
diff --git a/src/asahi/compiler/agx_compile.c b/src/asahi/compiler/agx_compile.c
index 8deae01f4b3..a27e27671b7 100644
--- a/src/asahi/compiler/agx_compile.c
+++ b/src/asahi/compiler/agx_compile.c
@@ -358,61 +358,6 @@ agx_format_for_pipe(enum pipe_format format)
    unreachable("Invalid format");
 }
 
-/* AGX appears to lack support for vertex attributes. Lower to global loads. */
-static void
-agx_emit_load_attr(agx_builder *b, agx_index dest, nir_intrinsic_instr *instr)
-{
-   nir_src *offset_src = nir_get_io_offset_src(instr);
-   assert(nir_src_is_const(*offset_src) && "no attribute indirects");
-   unsigned index = nir_intrinsic_base(instr) +
-                    nir_src_as_uint(*offset_src);
-
-   struct agx_shader_key *key = b->shader->key;
-   struct agx_attribute attrib = key->vs.attributes[index];
-
-   /* address = base + (stride * vertex_id) + src_offset */
-   unsigned buf = attrib.buf;
-   unsigned stride = key->vs.vbuf_strides[buf];
-   unsigned shift = agx_format_shift(attrib.format);
-
-   agx_index shifted_stride = agx_mov_imm(b, 32, stride >> shift);
-   agx_index src_offset = agx_mov_imm(b, 32, attrib.src_offset);
-
-   /* A nonzero divisor requires dividing the instance ID. A zero divisor
-    * specifies per-instance data. */
-   agx_index element_id = (attrib.divisor == 0) ? agx_vertex_id(b) :
-                          agx_udiv_const(b, agx_instance_id(b), 
attrib.divisor);
-
-   agx_index offset = agx_imad(b, element_id, shifted_stride, src_offset, 0);
-
-   /* Each VBO has a 64-bit = 4 x 16-bit address, lookup the base address as a
-    * sysval.  Mov around the base to handle uniform restrictions, copyprop 
will
-    * usually clean that up.
-    */
-   agx_index base = agx_mov(b, agx_vbo_base(b->shader, buf));
-
-   /* Load the data */
-   assert(instr->num_components <= 4);
-
-   unsigned actual_comps = (attrib.nr_comps_minus_1 + 1);
-   agx_index vec = agx_vec_for_dest(b->shader, &instr->dest);
-   agx_device_load_to(b, vec, base, offset, attrib.format,
-                      BITFIELD_MASK(attrib.nr_comps_minus_1 + 1), 0, 0);
-   agx_wait(b, 0);
-
-   agx_index dests[4] = { agx_null() };
-   agx_emit_split(b, dests, vec, actual_comps);
-
-   agx_index one = agx_mov_imm(b, 32, fui(1.0));
-   agx_index zero = agx_mov_imm(b, 32, 0);
-   agx_index default_value[4] = { zero, zero, zero, one };
-
-   for (unsigned i = actual_comps; i < instr->num_components; ++i)
-      dests[i] = default_value[i];
-
-   agx_emit_collect_to(b, dest, instr->num_components, dests);
-}
-
 static void
 agx_emit_load_vary_flat(agx_builder *b, agx_index dest, nir_intrinsic_instr 
*instr)
 {
@@ -733,13 +678,8 @@ agx_emit_intrinsic(agx_builder *b, nir_intrinsic_instr 
*instr)
      return NULL;
 
   case nir_intrinsic_load_input:
-     if (stage == MESA_SHADER_FRAGMENT)
-        agx_emit_load_vary_flat(b, dst, instr);
-     else if (stage == MESA_SHADER_VERTEX)
-        agx_emit_load_attr(b, dst, instr);
-     else
-        unreachable("Unsupported shader stage");
-
+     assert(stage == MESA_SHADER_FRAGMENT && "vertex loads lowered");
+     agx_emit_load_vary_flat(b, dst, instr);
      return NULL;
 
   case nir_intrinsic_load_global:
@@ -785,6 +725,10 @@ agx_emit_intrinsic(agx_builder *b, nir_intrinsic_instr 
*instr)
               nir_src_as_uint(instr->src[0]) * 4,
               b->shader->nir->info.num_ubos * 4));
 
+  case nir_intrinsic_load_vbo_base_agx:
+     return agx_mov_to(b, dst,
+                       agx_vbo_base(b->shader, 
nir_src_as_uint(instr->src[0])));
+
   case nir_intrinsic_load_vertex_id:
      return agx_mov_to(b, dst, agx_abs(agx_vertex_id(b)));
 
diff --git a/src/asahi/compiler/agx_compile.h b/src/asahi/compiler/agx_compile.h
index 00cdf0a83b0..e549ee181d7 100644
--- a/src/asahi/compiler/agx_compile.h
+++ b/src/asahi/compiler/agx_compile.h
@@ -182,8 +182,6 @@ struct agx_shader_info {
 };
 
 #define AGX_MAX_RTS (8)
-#define AGX_MAX_ATTRIBS (16)
-#define AGX_MAX_VBUFS (16)
 
 enum agx_format {
    AGX_FORMAT_I8 = 0,
@@ -203,56 +201,6 @@ enum agx_format {
    AGX_NUM_FORMATS,
 };
 
-/* Returns the number of bits at the bottom of the address required to be zero.
- * That is, returns the base-2 logarithm of the minimum alignment for an
- * agx_format, where the minimum alignment is 2^n where n is the result of this
- * function. The offset argument to device_load is left-shifted by this amount
- * in the hardware */
-
-static inline unsigned
-agx_format_shift(enum agx_format format)
-{
-   switch (format) {
-   case AGX_FORMAT_I8:
-   case AGX_FORMAT_U8NORM:
-   case AGX_FORMAT_S8NORM:
-   case AGX_FORMAT_SRGBA8:
-      return 0;
-
-   case AGX_FORMAT_I16:
-   case AGX_FORMAT_F16:
-   case AGX_FORMAT_U16NORM:
-   case AGX_FORMAT_S16NORM:
-      return 1;
-
-   case AGX_FORMAT_I32:
-   case AGX_FORMAT_RGB10A2:
-   case AGX_FORMAT_RG11B10F:
-   case AGX_FORMAT_RGB9E5:
-      return 2;
-
-   default:
-      unreachable("invalid format");
-   }
-}
-
-struct agx_attribute {
-   uint32_t divisor;
-
-   unsigned buf : 5;
-   unsigned src_offset : 16;
-   unsigned nr_comps_minus_1 : 2;
-   enum agx_format format : 4;
-   unsigned padding : 5;
-};
-
-struct agx_vs_shader_key {
-   unsigned num_vbufs;
-   unsigned vbuf_strides[AGX_MAX_VBUFS];
-
-   struct agx_attribute attributes[AGX_MAX_ATTRIBS];
-};
-
 struct agx_fs_shader_key {
    /* Normally, access to the tilebuffer must be guarded by appropriate fencing
     * instructions to ensure correct results in the presence of out-of-order
@@ -269,7 +217,6 @@ struct agx_fs_shader_key {
 
 struct agx_shader_key {
    union {
-      struct agx_vs_shader_key vs;
       struct agx_fs_shader_key fs;
    };
 };
diff --git a/src/asahi/lib/agx_formats.c b/src/asahi/lib/agx_formats.c
index 5f5923cf452..55f592d2f9c 100644
--- a/src/asahi/lib/agx_formats.c
+++ b/src/asahi/lib/agx_formats.c
@@ -190,69 +190,3 @@ const struct agx_pixel_format_entry 
agx_pixel_format[PIPE_FORMAT_COUNT] = {
    AGX_FMT(BPTC_RGBA_UNORM,         BC7,           UNORM,  F, _),
    AGX_FMT(BPTC_SRGBA,              BC7,           UNORM,  F, _),
 };
-
-const enum agx_format
-agx_vertex_format[PIPE_FORMAT_COUNT] = {
-   [PIPE_FORMAT_R32_FLOAT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32_SINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32_UINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32_FLOAT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32_SINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32_UINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32B32_FLOAT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32B32_UINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32B32_SINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32B32A32_FLOAT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32B32A32_UINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32B32A32_SINT] = AGX_FORMAT_I32,
-
-   [PIPE_FORMAT_R8_UNORM] = AGX_FORMAT_U8NORM,
-   [PIPE_FORMAT_R8G8_UNORM] = AGX_FORMAT_U8NORM,
-   [PIPE_FORMAT_R8G8B8_UNORM] = AGX_FORMAT_U8NORM,
-   [PIPE_FORMAT_R8G8B8A8_UNORM] = AGX_FORMAT_U8NORM,
-
-   [PIPE_FORMAT_R8_SNORM] = AGX_FORMAT_S8NORM,
-   [PIPE_FORMAT_R8G8_SNORM] = AGX_FORMAT_S8NORM,
-   [PIPE_FORMAT_R8G8B8_SNORM] = AGX_FORMAT_S8NORM,
-   [PIPE_FORMAT_R8G8B8A8_SNORM] = AGX_FORMAT_S8NORM,
-
-   [PIPE_FORMAT_R16_UNORM] = AGX_FORMAT_U16NORM,
-   [PIPE_FORMAT_R16G16_UNORM] = AGX_FORMAT_U16NORM,
-   [PIPE_FORMAT_R16G16B16_UNORM] = AGX_FORMAT_U16NORM,
-   [PIPE_FORMAT_R16G16B16A16_UNORM] = AGX_FORMAT_U16NORM,
-
-   [PIPE_FORMAT_R16_SNORM] = AGX_FORMAT_S16NORM,
-   [PIPE_FORMAT_R16G16_SNORM] = AGX_FORMAT_S16NORM,
-   [PIPE_FORMAT_R16G16B16_SNORM] = AGX_FORMAT_S16NORM,
-   [PIPE_FORMAT_R16G16B16A16_SNORM] = AGX_FORMAT_S16NORM,
-
-   [PIPE_FORMAT_R8_UINT] = AGX_FORMAT_I8,
-   [PIPE_FORMAT_R8G8_UINT] = AGX_FORMAT_I8,
-   [PIPE_FORMAT_R8G8B8_UINT] = AGX_FORMAT_I8,
-   [PIPE_FORMAT_R8G8B8A8_UINT] = AGX_FORMAT_I8,
-
-   [PIPE_FORMAT_R8_SINT] = AGX_FORMAT_I8,
-   [PIPE_FORMAT_R8G8_SINT] = AGX_FORMAT_I8,
-   [PIPE_FORMAT_R8G8B8_SINT] = AGX_FORMAT_I8,
-   [PIPE_FORMAT_R8G8B8A8_SINT] = AGX_FORMAT_I8,
-
-   [PIPE_FORMAT_R16_UINT] = AGX_FORMAT_I16,
-   [PIPE_FORMAT_R16G16_UINT] = AGX_FORMAT_I16,
-   [PIPE_FORMAT_R16G16B16_UINT] = AGX_FORMAT_I16,
-   [PIPE_FORMAT_R16G16B16A16_UINT] = AGX_FORMAT_I16,
-
-   [PIPE_FORMAT_R16_SINT] = AGX_FORMAT_I16,
-   [PIPE_FORMAT_R16G16_SINT] = AGX_FORMAT_I16,
-   [PIPE_FORMAT_R16G16B16_SINT] = AGX_FORMAT_I16,
-   [PIPE_FORMAT_R16G16B16A16_SINT] = AGX_FORMAT_I16,
-
-   [PIPE_FORMAT_R32_UINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32_UINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32B32_UINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32B32A32_UINT] = AGX_FORMAT_I32,
-
-   [PIPE_FORMAT_R32_SINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32_SINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32B32_SINT] = AGX_FORMAT_I32,
-   [PIPE_FORMAT_R32G32B32A32_SINT] = AGX_FORMAT_I32,
-};
diff --git a/src/asahi/lib/agx_formats.h b/src/asahi/lib/agx_formats.h
index c1e33fbc1d8..b463418b0e6 100644
--- a/src/asahi/lib/agx_formats.h
+++ b/src/asahi/lib/agx_formats.h
@@ -26,7 +26,6 @@
 #define __AGX_FORMATS_H_
 
 #include "util/format/u_format.h"
-#include "asahi/compiler/agx_compile.h"
 
 struct agx_pixel_format_entry {
    uint8_t channels;
@@ -36,7 +35,6 @@ struct agx_pixel_format_entry {
 };
 
 extern const struct agx_pixel_format_entry agx_pixel_format[PIPE_FORMAT_COUNT];
-extern const enum agx_format agx_vertex_format[PIPE_FORMAT_COUNT];
 
 /* N.b. hardware=0 corresponds to R8 UNORM, which is renderable. So a zero
  * entry indicates an invalid format. */
diff --git a/src/asahi/lib/agx_nir_lower_vbo.c 
b/src/asahi/lib/agx_nir_lower_vbo.c
new file mode 100644
index 00000000000..d796f2d5489
--- /dev/null
+++ b/src/asahi/lib/agx_nir_lower_vbo.c
@@ -0,0 +1,239 @@
+/*
+ * Copyright 2022 Alyssa Rosenzweig
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "agx_nir_lower_vbo.h"
+#include "compiler/nir/nir_builder.h"
+#include "compiler/nir/nir_format_convert.h"
+#include "util/u_math.h"
+
+static bool
+is_rgb10_a2(const struct util_format_description *desc)
+{
+   return desc->channel[0].shift ==  0 && desc->channel[0].size == 10 &&
+          desc->channel[1].shift == 10 && desc->channel[1].size == 10 &&
+          desc->channel[2].shift == 20 && desc->channel[2].size == 10 &&
+          desc->channel[3].shift == 30 && desc->channel[3].size == 2;
+}
+
+static enum pipe_format
+agx_vbo_internal_format(enum pipe_format format)
+{
+   const struct util_format_description *desc = 
util_format_description(format);
+
+   /* RGB10A2 formats are native for UNORM and unpacked otherwise */
+   if (is_rgb10_a2(desc)) {
+      if (desc->is_unorm)
+         return PIPE_FORMAT_R10G10B10A2_UNORM;
+      else
+         return PIPE_FORMAT_R32_UINT;
+   }
+
+   /* R11G11B10F is native and special */
+   if (format == PIPE_FORMAT_R11G11B10_FLOAT)
+      return format;
+
+   /* No other non-array formats handled */
+   if (!desc->is_array)
+      return PIPE_FORMAT_NONE;
+
+   /* Otherwise look at one (any) channel */
+   int idx = util_format_get_first_non_void_channel(format);
+   if (idx < 0)
+      return PIPE_FORMAT_NONE;
+
+   /* We only handle RGB formats (we could do SRGB if we wanted though?) */
+   if ((desc->colorspace != UTIL_FORMAT_COLORSPACE_RGB) ||
+       (desc->layout != UTIL_FORMAT_LAYOUT_PLAIN))
+      return PIPE_FORMAT_NONE;
+
+   /* We have native 8-bit and 16-bit normalized formats */
+   struct util_format_channel_description chan = desc->channel[idx];
+
+   if (chan.normalized) {
+      if (chan.size == 8)
+         return desc->is_unorm ? PIPE_FORMAT_R8_UNORM : PIPE_FORMAT_R8_SNORM;
+      else if (chan.size == 16)
+         return desc->is_unorm ? PIPE_FORMAT_R16_UNORM : PIPE_FORMAT_R16_SNORM;
+   }
+
+   /* Otherwise map to the corresponding integer format */
+   switch (chan.size) {
+   case 32: return PIPE_FORMAT_R32_UINT;
+   case 16: return PIPE_FORMAT_R16_UINT;
+   case  8: return PIPE_FORMAT_R8_UINT;
+   default: return PIPE_FORMAT_NONE;
+   }
+}
+
+bool
+agx_vbo_supports_format(enum pipe_format format)
+{
+   return agx_vbo_internal_format(format) != PIPE_FORMAT_NONE;
+}
+
+static nir_ssa_def *
+apply_swizzle_channel(nir_builder *b, nir_ssa_def *vec,
+                      unsigned swizzle, bool is_int)
+{
+   switch (swizzle) {
+   case PIPE_SWIZZLE_X: return nir_channel(b, vec, 0);
+   case PIPE_SWIZZLE_Y: return nir_channel(b, vec, 1);
+   case PIPE_SWIZZLE_Z: return nir_channel(b, vec, 2);
+   case PIPE_SWIZZLE_W: return nir_channel(b, vec, 3);
+   case PIPE_SWIZZLE_0: return nir_imm_intN_t(b, 0, vec->bit_size);
+   case PIPE_SWIZZLE_1: return is_int ? nir_imm_intN_t(b, 1, vec->bit_size) :
+                                        nir_imm_floatN_t(b, 1.0, 
vec->bit_size);
+   default: unreachable("Invalid swizzle channel");
+   }
+}
+
+static bool
+pass(struct nir_builder *b, nir_instr *instr, void *data)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+   if (intr->intrinsic != nir_intrinsic_load_input)
+      return false;
+
+   struct agx_vbufs *vbufs = data;
+   b->cursor = nir_before_instr(instr);
+
+   nir_src *offset_src = nir_get_io_offset_src(intr);
+   assert(nir_src_is_const(*offset_src) && "no attribute indirects");
+   unsigned index = nir_intrinsic_base(intr) + nir_src_as_uint(*offset_src);
+
+   struct agx_attribute attrib = vbufs->attributes[index];
+   uint32_t stride = vbufs->strides[attrib.buf];
+   uint16_t offset = attrib.src_offset;
+
+   const struct util_format_description *desc =
+      util_format_description(attrib.format);
+   int chan = util_format_get_first_non_void_channel(attrib.format);
+   assert(chan >= 0);
+
+   bool is_float    = desc->channel[chan].type == UTIL_FORMAT_TYPE_FLOAT;
+   bool is_unsigned = desc->channel[chan].type == UTIL_FORMAT_TYPE_UNSIGNED;
+   bool is_signed   = desc->channel[chan].type == UTIL_FORMAT_TYPE_SIGNED;
+   bool is_fixed    = desc->channel[chan].type == UTIL_FORMAT_TYPE_FIXED;
+   bool is_int      = util_format_is_pure_integer(attrib.format);
+
+   assert((is_float ^ is_unsigned ^ is_signed ^ is_fixed) && "Invalid format");
+
+   enum pipe_format interchange_format = 
agx_vbo_internal_format(attrib.format);
+   assert(interchange_format != PIPE_FORMAT_NONE);
+
+   unsigned interchange_align = util_format_get_blocksize(interchange_format);
+   unsigned interchange_comps = util_format_get_nr_components(attrib.format);
+
+   /* In the hardware, uint formats zero-extend and float formats convert.
+    * However, non-uint formats using a uint interchange format shouldn't be
+    * zero extended.
+    */
+   unsigned interchange_register_size =
+      util_format_is_pure_uint(interchange_format) && 
!util_format_is_pure_uint(attrib.format) ?
+      (interchange_align * 8):
+      nir_dest_bit_size(intr->dest);
+
+   /* Non-UNORM R10G10B10A2 loaded as a scalar and unpacked */
+   if (interchange_format == PIPE_FORMAT_R32_UINT && !desc->is_array)
+      interchange_comps = 1;
+
+   /* Calculate the element to fetch the vertex for. Divide the instance ID by
+    * the divisor for per-instance data. Divisor=0 specifies per-vertex data.
+    */
+   nir_ssa_def *el = (attrib.divisor == 0) ?
+                     nir_load_vertex_id(b) :
+                     nir_udiv_imm(b, nir_load_instance_id(b), attrib.divisor);
+
+   nir_ssa_def *base = nir_load_vbo_base_agx(b, nir_imm_int(b, attrib.buf));
+
+   assert((stride % interchange_align) == 0 && "must be aligned");
+   assert((offset % interchange_align) == 0 && "must be aligned");
+
+   unsigned stride_el = stride / interchange_align;
+   unsigned offset_el = offset / interchange_align;
+
+   nir_ssa_def *stride_offset_el =
+      nir_iadd_imm(b, nir_imul_imm(b, el, stride_el), offset_el);
+
+   /* Load the raw vector */
+   nir_ssa_def *memory =
+      nir_load_constant_agx(b, interchange_comps,
+                            interchange_register_size,
+                            base,
+                            stride_offset_el,
+                            .format = interchange_format);
+
+   unsigned dest_size = nir_dest_bit_size(intr->dest);
+
+   /* Unpack but do not convert non-native non-array formats */
+   if (is_rgb10_a2(desc) && interchange_format == PIPE_FORMAT_R32_UINT) {
+      unsigned bits[] = { 10, 10, 10, 2 };
+
+      if (is_signed)
+         memory = nir_format_unpack_sint(b, memory, bits, 4);
+      else
+         memory = nir_format_unpack_uint(b, memory, bits, 4);
+   }
+
+   if (desc->channel[chan].normalized) {
+      /* 8/16-bit normalized formats are native, others converted here */
+      if (is_rgb10_a2(desc) && is_signed) {
+         unsigned bits[] = { 10, 10, 10, 2 };
+         memory = nir_format_snorm_to_float(b, memory, bits);
+      } else if (desc->channel[chan].size == 32) {
+         assert(desc->is_array && "no non-array 32-bit norm formats");
+         unsigned bits[] = { 32, 32, 32, 32 };
+
+         if (is_signed)
+            memory = nir_format_snorm_to_float(b, memory, bits);
+         else
+            memory = nir_format_unorm_to_float(b, memory, bits);
+      }
+   } else if (desc->channel[chan].pure_integer) {
+      /* Zero-extension is native, may need to sign extend */
+      if (is_signed)
+         memory = nir_i2iN(b, memory, dest_size);
+   } else {
+      if (is_unsigned)
+         memory = nir_u2fN(b, memory, dest_size);
+      else if (is_signed || is_fixed)
+         memory = nir_i2fN(b, memory, dest_size);
+      else
+         memory = nir_f2fN(b, memory, dest_size);
+
+      /* 16.16 fixed-point weirdo GL formats need to be scaled */
+      if (is_fixed) {
+         assert(desc->is_array && desc->channel[chan].size == 32);
+         assert(dest_size == 32 && "overflow if smaller");
+         memory = nir_fmul_imm(b, memory, 1.0 / 65536.0);
+      }
+   }
+
+   /* We now have a properly formatted vector of the components in memory. 
Apply
+    * the format swizzle forwards to trim/pad/reorder as needed.
+    */
+   nir_ssa_def *channels[4] = { NULL };
+   assert(nir_intrinsic_component(intr) == 0 && "unimplemented");
+
+   for (unsigned i = 0; i < intr->num_components; ++i)
+      channels[i] = apply_swizzle_channel(b, memory, desc->swizzle[i], is_int);
+
+   nir_ssa_def *logical = nir_vec(b, channels, intr->num_components);
+   nir_ssa_def_rewrite_uses(&intr->dest.ssa, logical);
+   return true;
+}
+
+bool
+agx_nir_lower_vbo(nir_shader *shader, struct agx_vbufs *vbufs)
+{
+   assert(shader->info.stage == MESA_SHADER_VERTEX);
+   return nir_shader_instructions_pass(shader, pass,
+                                       nir_metadata_block_index |
+                                       nir_metadata_dominance,
+                                       vbufs);
+}
diff --git a/src/asahi/lib/agx_nir_lower_vbo.h 
b/src/asahi/lib/agx_nir_lower_vbo.h
new file mode 100644
index 00000000000..ab014707d57
--- /dev/null
+++ b/src/asahi/lib/agx_nir_lower_vbo.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2022 Alyssa Rosenzweig
+ * SPDX-License-Identifier: MIT
+ */
+
+#ifndef __AGX_NIR_LOWER_VBO_H
+#define __AGX_NIR_LOWER_VBO_H
+
+#include <stdint.h>
+#include <stdbool.h>
+#include "nir.h"
+#include "util/format/u_formats.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define AGX_MAX_ATTRIBS (16)
+#define AGX_MAX_VBUFS (16)
+
+/* See pipe_vertex_element for justification on the sizes. This structure 
should
+ * be small so it can be embedded into a shader key.
+ */
+struct agx_attribute {
+   uint32_t divisor;
+   uint16_t src_offset;
+   uint8_t buf;
+
+   /* pipe_format, all vertex formats should be <= 255 */
+   uint8_t format;
+};
+
+struct agx_vbufs {
+   unsigned count;
+   uint32_t strides[AGX_MAX_VBUFS];
+   struct agx_attribute attributes[AGX_MAX_ATTRIBS];
+};
+
+bool agx_nir_lower_vbo(nir_shader *shader, struct agx_vbufs *vbufs);
+bool agx_vbo_supports_format(enum pipe_format format);
+
+#ifdef __cplusplus
+} /* extern C */
+#endif
+
+#endif
diff --git a/src/asahi/lib/meson.build b/src/asahi/lib/meson.build
index d449d7bff68..5bff002b442 100644
--- a/src/asahi/lib/meson.build
+++ b/src/asahi/lib/meson.build
@@ -27,6 +27,7 @@ libasahi_lib_files = files(
   'agx_meta.c',
   'agx_tilebuffer.c',
   'agx_nir_lower_tilebuffer.c',
+  'agx_nir_lower_vbo.c',
   'agx_ppp.h',
   'pool.c',
 )
diff --git a/src/gallium/drivers/asahi/agx_pipe.c 
b/src/gallium/drivers/asahi/agx_pipe.c
index 3d4b621fa6e..fccf7b1c8cd 100644
--- a/src/gallium/drivers/asahi/agx_pipe.c
+++ b/src/gallium/drivers/asahi/agx_pipe.c
@@ -1573,18 +1573,8 @@ agx_is_format_supported(struct pipe_screen* pscreen,
          return false;
    }
 
-   /* TODO: formats */
-   if (usage & PIPE_BIND_VERTEX_BUFFER) {
-      switch (format) {
-      case PIPE_FORMAT_R32_FLOAT:
-      case PIPE_FORMAT_R32G32_FLOAT:
-      case PIPE_FORMAT_R32G32B32_FLOAT:
-      case PIPE_FORMAT_R32G32B32A32_FLOAT:
-         break;
-      default:
-         return false;
-      }
-   }
+   if ((usage & PIPE_BIND_VERTEX_BUFFER) && !agx_vbo_supports_format(format))
+      return false;
 
    if (usage & PIPE_BIND_DEPTH_STENCIL) {
       switch (format) {
diff --git a/src/gallium/drivers/asahi/agx_state.c 
b/src/gallium/drivers/asahi/agx_state.c
index 90f7b129453..889b08bf8a8 100644
--- a/src/gallium/drivers/asahi/agx_state.c
+++ b/src/gallium/drivers/asahi/agx_state.c
@@ -987,18 +987,13 @@ agx_create_vertex_elements(struct pipe_context *ctx,
 
       const struct util_format_description *desc =
          util_format_description(ve.src_format);
-
       unsigned chan_size = desc->channel[0].size / 8;
-
-      assert(chan_size == 1 || chan_size == 2 || chan_size == 4);
-      assert(desc->nr_channels >= 1 && desc->nr_channels <= 4);
       assert((ve.src_offset & (chan_size - 1)) == 0);
 
       attribs[i] = (struct agx_attribute) {
          .buf = ve.vertex_buffer_index,
-         .src_offset = ve.src_offset / chan_size,
-         .nr_comps_minus_1 = desc->nr_channels - 1,
-         .format = agx_vertex_format[ve.src_format],
+         .src_offset = ve.src_offset,
+         .format = ve.src_format,
          .divisor = ve.instance_divisor
       };
    }
@@ -1184,7 +1179,9 @@ agx_compile_variant(struct agx_device *dev,
 
    agx_preprocess_nir(nir);
 
-   if (nir->info.stage == MESA_SHADER_FRAGMENT) {
+   if (nir->info.stage == MESA_SHADER_VERTEX) {
+      NIR_PASS_V(nir, agx_nir_lower_vbo, &key->vbuf);
+   } else {
       struct agx_tilebuffer_layout tib =
          agx_build_tilebuffer_layout(key->rt_formats, key->nr_cbufs, 1);
 
@@ -1243,13 +1240,12 @@ agx_create_shader_state(struct pipe_context *pctx,
       switch (so->nir->info.stage) {
       case MESA_SHADER_VERTEX:
       {
-         key.base.vs.num_vbufs = AGX_MAX_VBUFS;
+         key.vbuf.count = AGX_MAX_VBUFS;
          for (unsigned i = 0; i < AGX_MAX_VBUFS; ++i) {
-            key.base.vs.vbuf_strides[i] = 16;
-            key.base.vs.attributes[i] = (struct agx_attribute) {
+            key.vbuf.strides[i] = 16;
+            key.vbuf.attributes[i] = (struct agx_attribute) {
                .buf = i,
-               .nr_comps_minus_1 = 4 - 1,
-               .format = AGX_FORMAT_I32
+               .format = PIPE_FORMAT_R32G32B32A32_FLOAT
             };
          }
 
@@ -1295,20 +1291,18 @@ agx_update_shader(struct agx_context *ctx, struct 
agx_compiled_shader **out,
 static bool
 agx_update_vs(struct agx_context *ctx)
 {
-   struct agx_vs_shader_key key = { 0 };
+   struct asahi_shader_key key = {
+      .vbuf.count = util_last_bit(ctx->vb_mask),
+   };
 
-   memcpy(key.attributes, ctx->attributes,
-          sizeof(key.attributes[0]) * AGX_MAX_ATTRIBS);
+   memcpy(key.vbuf.attributes, ctx->attributes,
+          sizeof(key.vbuf.attributes[0]) * AGX_MAX_ATTRIBS);
 
    u_foreach_bit(i, ctx->vb_mask) {
-      key.vbuf_strides[i] = ctx->vertex_buffers[i].stride;
+      key.vbuf.strides[i] = ctx->vertex_buffers[i].stride;
    }
 
-   struct asahi_shader_key akey = {
-      .base.vs = key
-   };
-
-   return agx_update_shader(ctx, &ctx->vs, PIPE_SHADER_VERTEX, &akey);
+   return agx_update_shader(ctx, &ctx->vs, PIPE_SHADER_VERTEX, &key);
 }
 
 static bool
diff --git a/src/gallium/drivers/asahi/agx_state.h 
b/src/gallium/drivers/asahi/agx_state.h
index ea1c3d65721..cdd70260355 100644
--- a/src/gallium/drivers/asahi/agx_state.h
+++ b/src/gallium/drivers/asahi/agx_state.h
@@ -34,6 +34,7 @@
 #include "asahi/lib/agx_device.h"
 #include "asahi/lib/pool.h"
 #include "asahi/lib/agx_tilebuffer.h"
+#include "asahi/lib/agx_nir_lower_vbo.h"
 #include "asahi/compiler/agx_compile.h"
 #include "asahi/layout/layout.h"
 #include "compiler/nir/nir_lower_blend.h"
@@ -142,6 +143,8 @@ struct agx_blend {
 
 struct asahi_shader_key {
    struct agx_shader_key base;
+   struct agx_vbufs vbuf;
+
    struct agx_blend blend;
    unsigned nr_cbufs;
 

Reply via email to