This is an automated email from the ASF dual-hosted git repository.

ryankert01 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/main by this push:
     new 4f21c1076 [QDP] Close AMD-vs-CUDA encoder parity: add iqp, iqp-z, phase (#1292)
4f21c1076 is described below

commit 4f21c1076f435d381b8a664f1f3055b60fc92e3d
Author: Ryan Huang <[email protected]>
AuthorDate: Mon May 11 01:04:54 2026 +0800

    [QDP] Close AMD-vs-CUDA encoder parity: add iqp, iqp-z, phase (#1292)
    
    * [QDP] Close AMD-vs-CUDA encoder parity gap: add iqp, iqp-z, phase
    
    CUDA QdpEngine accepts amplitude, angle, basis, iqp, iqp-z, and phase.
    The Triton AMD path only implemented the first three, so AMD users hit
    a hard error on the IQP- and phase-family encodings (e.g. SVHN-IQP).
    
    This adds vectorized PyTorch implementations for the missing methods on
    TritonAmdEngine, dispatched through the same ``encode(method=...)``
    contract:
    
    - ``iqp`` — full ZZ entanglement: phase = Σ x_i·data_i + Σ_{i<j} x_i x_j·data_ij,
      followed by an n-stage Walsh-Hadamard butterfly and 1/2^n scaling.
    - ``iqp-z`` — Z-only diagonal: same FWT path with no ZZ pairs.
    - ``phase`` — per-qubit product state (1/√2^n)·exp(i·Σ_k phases_k·b_k).
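
For reference, the three definitions above can be written out densely in plain Python. This is a minimal O(4^n) sketch for small n, not the engine's code. `iqp_reference` and `phase_reference` are hypothetical names. The sketch assumes that bit k of the basis index is qubit k (the LSB-first convention of the engine's bits table) and that ZZ angles follow lexicographic pair order, which is the order `torch.combinations` produces.

```python
import cmath
import math
from itertools import combinations


def iqp_reference(params, n, enable_zz=True):
    """Dense IQP sketch: amp[y] = 2^-n * sum_x (-1)^popcount(x & y) * exp(i*theta(x))."""
    pairs = list(combinations(range(n), 2)) if enable_zz else []
    if len(params) != n + len(pairs):
        raise ValueError(f"expected {n + len(pairs)} parameters, got {len(params)}")
    size = 1 << n

    def theta(x):
        bits = [(x >> k) & 1 for k in range(n)]  # bit k of x is qubit k
        t = sum(params[k] * bits[k] for k in range(n))  # Z terms
        # ZZ terms: one angle per pair (i, j), i < j, in lexicographic order
        t += sum(params[n + p] * bits[i] * bits[j] for p, (i, j) in enumerate(pairs))
        return t

    f = [cmath.exp(1j * theta(x)) for x in range(size)]
    # Unnormalized Walsh-Hadamard transform, followed by the 1/2^n scaling.
    return [
        sum((-1) ** bin(x & y).count("1") * f[x] for x in range(size)) / size
        for y in range(size)
    ]


def phase_reference(phases, n):
    """amp[b] = (1/sqrt(2^n)) * exp(i * sum_k phases[k] * bit_k(b))."""
    if len(phases) != n:
        raise ValueError(f"expected {n} phases, got {len(phases)}")
    norm = 1.0 / math.sqrt(1 << n)
    return [
        norm * cmath.exp(1j * sum(phases[k] * ((b >> k) & 1) for k in range(n)))
        for b in range(1 << n)
    ]


# Usage: n=2 with ZZ takes 2 Z angles + 1 ZZ angle.
state = iqp_reference([0.1, 0.2, 0.3], 2)
print(sum(abs(a) ** 2 for a in state))  # close to 1.0 up to rounding
```

With all parameters set to zero, `iqp_reference` returns the basis state |0…0⟩. That makes a quick sanity check for the 1/2^n scaling.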
    
    Parity tests added against ``qumat_qdp.torch_ref.iqp_encode`` (which is
    already validated against CUDA upstream) and a local pure-torch phase
    reference. Also added unit-norm structural checks, param-count
    validation, float64 precision contract, and a router test that the
    public ``QdpEngine(backend="amd")`` accepts the new methods.
    
    Verified on AMD Instinct MI300X (ROCm 7.2 / torch 2.9.0+rocm6.4 /
    triton 3.5.0): the full triton_amd test file reports 18 passed, 2 skipped
    (NVIDIA CUDA-only references).
    
    * Optimize TritonAmdEngine encoders + Triton @jit phase kernel
    
    Addresses Copilot review on PR #1292 and pushes general kernel-level
    optimization across all six AMD encoders.
    
    PR review responses:
    
    - Drop the unreachable `test_triton_amd_iqp_cuda_reference_optional`
      (decorator required `torch.version.cuda` while body required
      `is_triton_amd_available()` → mutually exclusive). Replace with a
      meaningful float64 IQP precision contract test that actually runs.
    - Qualify README about the CUDA-tensor `phase` limitation: the Python
      extension's CUDA-tensor allowlist (`CUDA_ENCODING_METHODS`) does not
      yet include `phase`, so cuda-resident torch tensors must `.cpu()`
      first. Tracked as a follow-up.
    - The pair-matrix-rewrite suggestion (per-pair Python loop) is
      rejected — n² tiny kernel launches lose to one matmul on every
      modern GPU; the current path matches `torch_ref.iqp_encode` and the
      CUDA FWT phase kernel. Add a `_IQP_PAIR_MATRIX_MAX_N` guard that
      *does* fall back to a pair loop past n=20 (where the (2^n × n_pairs)
      workspace dominates HBM), so the OOM scenario is bounded.
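
The plain-Python sketch below shows why the pair-matrix path and the per-pair fallback produce the same ZZ phase, and where the ~760 MiB figure behind the n=20 cutoff comes from. The function names are hypothetical. `bits` is assumed to be the (2^n, n) table with `bits[b][k] = (b >> k) & 1`.

```python
from itertools import combinations


def zz_phase_pair_matrix(zz, bits):
    """theta_zz[b] = sum_p zz[p] * M[b][p], with M[b][p] = bits[b][i_p] * bits[b][j_p]."""
    n = len(bits[0])
    pairs = list(combinations(range(n), 2))  # (0,1), (0,2), ..., (n-2,n-1)
    # Dense (2^n, n_pairs) workspace: the shape the matmul path materializes.
    m = [[row[i] * row[j] for (i, j) in pairs] for row in bits]
    return [sum(z * v for z, v in zip(zz, mrow)) for mrow in m]


def zz_phase_pair_loop(zz, bits):
    """Same sum, accumulated one pair at a time; no (2^n, n_pairs) workspace."""
    n = len(bits[0])
    theta = [0.0] * len(bits)
    p = 0
    for i in range(n - 1):
        for j in range(i + 1, n):
            for b, row in enumerate(bits):
                theta[b] += zz[p] * row[i] * row[j]
            p += 1
    return theta


def pair_matrix_bytes(n, bytes_per_elem=4):
    """Dense pair-matrix size: 2^n rows x n(n-1)/2 pairs x element size."""
    return (1 << n) * (n * (n - 1) // 2) * bytes_per_elem


print(pair_matrix_bytes(20) / 2**20)  # MiB of float32 at n=20 -> 760.0
```

The loop trades one large matmul for n(n-1)/2 small passes. That is why it only makes sense once the dense workspace becomes the memory bottleneck.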
    
    Encoder optimizations (speedup over `qumat_qdp.torch_ref`, measured on
    MI300X with batch=64 and fp32 input; q = number of qubits):
    
    | encoder   | q=8   | q=12  | q=16  |
    |-----------|-------|-------|-------|
    | amplitude | 0.95× | 0.95× | 1.00× |
    | angle     | 1.57× | 1.37× | 1.04× |
    | basis     | 2.18× | 2.10× | 2.14× |
    | iqp(ZZ)   | 1.96× | 1.81× | 1.14× |
    | iqp-z     | 1.35× | 1.32× | 0.91× |
    | **phase** | **5.29×** | **5.39×** | **5.30×** |
    
    What changed:
    
    - **Real `@triton.jit` phase kernel** (fp32 / n ≤ 32). One HIP kernel
      fuses bit-pattern materialization + θ(b) accumulation + cos/sin +
      1/√2^n scaling + complex-pack, writing the output buffer interleaved
      via `view_as_real`. The PyTorch fallback path (used at fp64 or n > 32)
      was making 5 intermediate (B, S) allocations; the kernel makes one.
    - **Per-engine bits-table cache** (`_bits_cache`): the
      `((idx >> arange(n)) & 1).to(real)` table was being rebuilt on every
      call by `angle`/`iqp`/`phase`. Now cached per (n, dtype). At n=16
      (2^16 × 16 entries) that saves a ~8 MiB int64 table plus a ~4 MiB
      float32 copy per call.
    - **Pair-index cache** (`_pair_cache`): `torch.combinations(arange(n))`
      cached per n.
    - **`encode_amplitude`**: replaced
      `torch.complex(amp, zeros_like(amp)).to(complex_dtype)` with a single
      `amp.to(complex_dtype)` (writes (real, 0) interleaved in one kernel
      vs. zeros_like + complex_pack + cast = three).
    - **`encode_angle`**: collapsed the n-step Python product loop (which
      reallocated a (B, S) tensor per qubit) into a single
      `where(bits, sin, cos).prod(dim=2)` reduction.
    - **`encode_iqp`**: in-place n-stage Walsh-Hadamard butterfly using a
      single `(B, S/2)` scratch buffer. The previous
      `cat([lo+hi, lo-hi], dim=2)` allocated a fresh (B, S) tensor every
      stage; now `sub(out=scratch); a.add_(b); b.copy_(scratch)` reuses one
      workspace across all n stages. Also packs `f` via `torch.complex(cos,
      sin)` in one shot rather than writing to strided `.real`/`.imag`.
    - **`_to_2d` fast path**: skip `as_tensor` + `.contiguous()` work when
      the caller already supplies a 2-D, contiguous, on-device,
      correctly-typed torch tensor (the common case for benchmarks).
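
The in-place butterfly described above can be sketched in plain Python. `fwht_inplace` is a hypothetical name. It walks each stage element by element, whereas the engine issues one whole-stage `sub`/`add_`/`copy_` per stage. The two are equivalent because the (a, b) pairs within a stage are disjoint. `fwht_dense` is the O(4^n) definition, included only for comparison.

```python
def fwht_inplace(f):
    """Unnormalized Walsh-Hadamard transform of a length-2^n list, in place.

    At stage s (stride = 2^s), a = f[idx] with bit s clear is paired with
    b = f[idx + stride] with bit s set, matching the (K, 2, stride) view.
    A single len/2 scratch list is reused across all stages.
    """
    size = len(f)
    if size == 0 or (size & (size - 1)) != 0:
        raise ValueError("length must be a power of two")
    scratch = [0j] * (size // 2)
    stride = 1
    while stride < size:
        for k in range(size // (2 * stride)):
            for r in range(stride):
                ia = k * 2 * stride + r  # 'a' element: bit s == 0
                ib = ia + stride  # 'b' partner: bit s == 1
                t = k * stride + r  # slot in the len/2 scratch buffer
                scratch[t] = f[ia] - f[ib]  # scratch <- a - b
                f[ia] = f[ia] + f[ib]  # a <- a + b
                f[ib] = scratch[t]  # b <- saved a - b
        stride *= 2
    return f


def fwht_dense(f):
    """O(4^n) definition: out[y] = sum_x (-1)^popcount(x & y) * f[x]."""
    size = len(f)
    return [
        sum((-1) ** bin(x & y).count("1") * f[x] for x in range(size))
        for y in range(size)
    ]


vals = [complex(x % 3, -(x % 5)) for x in range(16)]
print(max(abs(g - r) for g, r in zip(fwht_inplace(list(vals)), fwht_dense(vals))))
```

Applying the unnormalized transform twice scales the input by its length. This is why the encoder applies a single 1/2^n factor at the end rather than a per-stage 1/√2.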
    
    Test results: 19 passed, 1 skipped. The skip is
    `test_triton_amd_cuda_reference_optional`, the pre-existing amplitude
    cross-reference, which has the same skipif problem as the iqp test deleted
    above; fixing it is out of scope here.
    
    Numerical parity: all encoders still match `torch_ref` /
    `_torch_phase_ref` within float-rounding tolerance; the IQP fp64
    contract test confirms `atol=1e-12`.
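
The parity claim for the fused phase kernel depends on its index arithmetic: one program per (sample, tile), a masked tail tile, and `base`/`base + 1` stores into the float view of the complex64 output. The sketch below replays that arithmetic in plain Python and can be compared against the closed-form phase state. `phase_kernel_sim` is a hypothetical helper. It mimics only the kernel's math, not Triton's parallel execution or float32 rounding.

```python
import math


def phase_kernel_sim(phases, n, block=256):
    """Replays the phase kernel's index math on a flat float list (no Triton)."""
    batch = len(phases)
    state_len = 1 << n
    norm = math.pow(math.sqrt(0.5), n)  # 1/sqrt(2^n)
    flat = [0.0] * (batch * state_len * 2)  # stands in for view_as_real(out)
    tiles = (state_len + block - 1) // block  # second grid dimension
    for pid_b in range(batch):
        for pid_s in range(tiles):
            for lane in range(block):
                s = pid_s * block + lane
                if s >= state_len:  # masked-off lane in the last tile
                    continue
                phi = sum(phases[pid_b][k] * ((s >> k) & 1) for k in range(n))
                base = pid_b * state_len * 2 + s * 2
                flat[base] = math.cos(phi) * norm  # real part
                flat[base + 1] = math.sin(phi) * norm  # imaginary part
    # Re-pair adjacent floats, mirroring the interleaved complex64 layout.
    return [
        [complex(flat[(b * state_len + s) * 2], flat[(b * state_len + s) * 2 + 1])
         for s in range(state_len)]
        for b in range(batch)
    ]


out = phase_kernel_sim([[0.3, -1.1, 2.0]], 3, block=4)  # 2 tiles of 4 lanes
print(len(out[0]))  # 8
```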
---
 qdp/qdp-python/README.md                        |  13 +-
 qdp/qdp-python/TRITON_AMD_BACKEND.md            |   9 +-
 qdp/qdp-python/qumat_qdp/triton_amd.py          | 339 +++++++++++++++++++++---
 qdp/qdp-python/tests/test_triton_amd_backend.py | 175 ++++++++++++
 4 files changed, 499 insertions(+), 37 deletions(-)

diff --git a/qdp/qdp-python/README.md b/qdp/qdp-python/README.md
index 8b3d964ac..cacfa0417 100644
--- a/qdp/qdp-python/README.md
+++ b/qdp/qdp-python/README.md
@@ -73,11 +73,18 @@ See `qdp/qdp-python/TRITON_AMD_BACKEND.md` for Triton AMD setup and validation d
 | `amplitude` | Normalize input as quantum amplitudes |
 | `angle` | Map values to rotation angles (one per qubit) |
 | `basis` | Encode integer as computational basis state |
-| `iqp` | IQP-style encoding with entanglement |
+| `iqp` | IQP-style encoding with full ZZ entanglement |
+| `iqp-z` | IQP encoding with Z-only diagonal (no ZZ pairs) |
+| `phase` | Per-qubit phase product state via H⊗P(x_k) |
 
 Backend support boundary:
-- CUDA (`QdpEngine`): `amplitude`, `angle`, `basis`, `iqp`
-- AMD (`QdpEngine(..., backend="amd")`): `amplitude`, `angle`, `basis` (no `iqp` yet)
+- CUDA (`QdpEngine`): `amplitude`, `angle`, `basis`, `iqp`, `iqp-z`, `phase`
+  - `phase` is currently only reachable on the CUDA path via host inputs
+    (Python list / NumPy / file / CPU torch tensor). The Python extension's
+    CUDA-tensor validation does not yet allowlist `phase`; cuda-resident
+    torch tensors must use `.cpu()` first when targeting `phase`. Tracked as
+    a follow-up.
+- AMD (`QdpEngine(..., backend="amd")`): `amplitude`, `angle`, `basis`, `iqp`, `iqp-z`, `phase`
 
 ## Input Sources
 
diff --git a/qdp/qdp-python/TRITON_AMD_BACKEND.md b/qdp/qdp-python/TRITON_AMD_BACKEND.md
index 93e01d466..3b3196187 100644
--- a/qdp/qdp-python/TRITON_AMD_BACKEND.md
+++ b/qdp/qdp-python/TRITON_AMD_BACKEND.md
@@ -65,9 +65,9 @@ Supported methods:
 - `amplitude`
 - `angle`
 - `basis`
-
-Not supported in the AMD route yet:
-- `iqp` (currently CUDA backend only)
+- `iqp` (full, with ZZ entanglement)
+- `iqp-z` (Z-only diagonal, no ZZ pairs)
+- `phase`
 
 ## Correctness tests
 
@@ -79,7 +79,8 @@ uv run --project qdp/qdp-python pytest -m rocm qdp/qdp-python/tests -q
 ```
 
 Tests include:
-- parity against Torch reference outputs (amplitude/angle/basis)
+- parity against Torch reference outputs (amplitude/angle/basis/iqp)
+- structural checks for `phase` (output is a unit-norm product state)
 - optional parity against CUDA backend reference (when NVIDIA CUDA path is present)
 
 ## Baseline benchmark
diff --git a/qdp/qdp-python/qumat_qdp/triton_amd.py b/qdp/qdp-python/qumat_qdp/triton_amd.py
index 1a531c4e7..8bcbd5d02 100644
--- a/qdp/qdp-python/qumat_qdp/triton_amd.py
+++ b/qdp/qdp-python/qumat_qdp/triton_amd.py
@@ -18,7 +18,8 @@
 
 from __future__ import annotations
 
-from dataclasses import dataclass
+import math
+from dataclasses import dataclass, field
 from importlib import import_module
 from typing import Any
 
@@ -34,6 +35,7 @@ def _load_optional_module(name: str) -> Any | None:
 
 torch_mod = _load_optional_module("torch")
 triton_mod = _load_optional_module("triton")
+triton_lang = _load_optional_module("triton.language")
 
 
 def _is_rocm_runtime() -> bool:
@@ -54,13 +56,85 @@ def is_triton_amd_available() -> bool:
         return True
 
 
+# ---------------------------------------------------------------------------
+# Triton kernel: fused phase encoder (real-only path).
+#
+# Each program covers BLOCK output basis states of a single sample, fusing
+# bit-pattern materialization + θ(b) accumulation + sin/cos + 1/√2^n
+# scaling + complex packing. The PyTorch path below allocates 5
+# intermediates of size O(B · S); this kernel writes the output in a single
+# pass.
+#
+# The kernel stores (re, im) pairs interleaved into a float32 view of the
+# complex64 output (``torch.view_as_real``), so no final stitch is needed.
+# This avoids needing complex-typed pointers in Triton, which the HIP backend
+# does not support directly.
+#
+# Limitations: float32 + n_qubits ≤ 32 (single int32 bit packing). For n > 32
+# or float64 the engine falls back to the vectorized PyTorch path, which is
+# already memory-bound, not compute-bound.
+# ---------------------------------------------------------------------------
+
+if triton_mod is not None and triton_lang is not None:
+    tl = triton_lang
+
+    @triton_mod.jit
+    def _phase_encode_kernel(
+        phases_ptr,  # *fp32, shape (B, n_qubits)
+        out_ptr,  # *fp32, view-as-real of complex64 output: (B, 2·state_len)
+        n_qubits,
+        state_len,
+        norm_factor,  # 1/√2^n
+        BLOCK: tl.constexpr,
+    ):
+        pid_b = tl.program_id(0)
+        pid_s = tl.program_id(1)
+
+        s_offsets = pid_s * BLOCK + tl.arange(0, BLOCK)
+        s_mask = s_offsets < state_len
+
+        # φ(b) = Σ_k phases[k] · ((b >> k) & 1) — fused bit unpack + accumulate.
+        phi = tl.zeros([BLOCK], dtype=tl.float32)
+        for k in range(0, n_qubits):
+            bit_k = ((s_offsets >> k) & 1).to(tl.float32)
+            phase_k = tl.load(phases_ptr + pid_b * n_qubits + k)
+            phi += phase_k * bit_k
+
+        re = tl.cos(phi) * norm_factor
+        im = tl.sin(phi) * norm_factor
+
+        # Write interleaved into the complex64 buffer's float view: each
+        # output element occupies two adjacent floats (re, im). One kernel,
+        # one allocation; no separate planes that would need a final stitch.
+        base = pid_b * state_len * 2 + s_offsets * 2
+        tl.store(out_ptr + base, re, mask=s_mask)
+        tl.store(out_ptr + base + 1, im, mask=s_mask)
+
+else:  # pragma: no cover - non-Triton hosts use the PyTorch fallback
+    _phase_encode_kernel = None
+
+
+# Largest n for which the ZZ path materializes the dense (2^n, n_pairs) pair
+# matrix; above it, ``_iqp_phase`` falls back to a bounded-memory per-pair
+# loop. At n=20 the pair matrix is 2^20 rows × 190 pairs × 4 B ≈ 760 MiB in
+# float32, versus 16 MiB for one complex128 state vector, so this is the
+# cutoff before the pair matrix dominates the AMD HBM budget.
+_IQP_PAIR_MATRIX_MAX_N = 20
+
+
 @dataclass
 class TritonAmdEngine:
-    """AMD backend implementing amplitude/angle/basis encoders."""
+    """AMD backend implementing amplitude/angle/basis/iqp/iqp-z/phase encoders."""
 
     device_id: int = 0
     precision: str = "float32"
 
+    # Per-engine cache of (n_qubits → bits table) keyed by (n, real_dtype).
+    # Avoids regenerating the (state_len, n_qubits) bit pattern on every call;
+    # the table is reused across batches for any encoder that needs it.
+    _bits_cache: dict = field(default_factory=dict, repr=False, compare=False)
+    # Cache of (n → upper-triangular pair index) for IQP-ZZ.
+    _pair_cache: dict = field(default_factory=dict, repr=False, compare=False)
+
     def __post_init__(self) -> None:
         p = self.precision.lower()
         if p in ("float32", "f32", "float"):
@@ -105,6 +179,18 @@ class TritonAmdEngine:
 
     def _to_2d(self, data: Any, *, dtype: Any) -> Any:
         torch_mod = self._require_torch()
+        # Fast path: caller already supplies a 2-D, contiguous, on-device,
+        # correctly-typed torch tensor (the common case for benchmarks and
+        # downstream pipelines). Skip ``as_tensor`` + ``contiguous`` work.
+        if (
+            isinstance(data, torch_mod.Tensor)
+            and data.ndim == 2
+            and data.dtype is dtype
+            and data.is_contiguous()
+            and data.device.type == "cuda"
+            and data.device.index == self.device_id
+        ):
+            return data
         x = torch_mod.as_tensor(data, device=self._device(), dtype=dtype)
         if x.ndim == 1:
             x = x.unsqueeze(0)
@@ -112,6 +198,37 @@ class TritonAmdEngine:
             raise ValueError(f"Expected 1D or 2D input, got {x.ndim}D.")
         return x.contiguous()
 
+    def _bits_table(self, num_qubits: int, real_dtype: Any) -> Any:
+        """Cached ``bits[b, k] = (b >> k) & 1`` table cast to ``real_dtype``.
+
+        Returned shape is ``(2^num_qubits, num_qubits)``. The same table is
+        reused by ``encode_angle``/``encode_iqp``/``encode_phase`` across
+        successive batches at the same ``num_qubits``.
+        """
+        torch_mod = self._require_torch()
+        key = (num_qubits, real_dtype)
+        cached = self._bits_cache.get(key)
+        if cached is not None:
+            return cached
+        device = torch_mod.device(self._device())
+        state_len = 1 << num_qubits
+        b_idx = torch_mod.arange(state_len, device=device, dtype=torch_mod.int64)
+        k_idx = torch_mod.arange(num_qubits, device=device, dtype=torch_mod.int64)
+        bits = ((b_idx.unsqueeze(1) >> k_idx) & 1).to(real_dtype).contiguous()
+        self._bits_cache[key] = bits
+        return bits
+
+    def _pair_indices(self, num_qubits: int) -> Any:
+        """Cached ``(n*(n-1)/2, 2)`` table of upper-triangular qubit pairs."""
+        torch_mod = self._require_torch()
+        cached = self._pair_cache.get(num_qubits)
+        if cached is not None:
+            return cached
+        device = torch_mod.device(self._device())
+        pairs = torch_mod.combinations(torch_mod.arange(num_qubits, device=device), r=2)
+        self._pair_cache[num_qubits] = pairs
+        return pairs
+
     def encode_amplitude(self, data: Any, num_qubits: int) -> Any:
         torch_mod = self._require_torch()
         x = self._to_2d(data, dtype=self._real_dtype())
@@ -125,13 +242,12 @@ class TritonAmdEngine:
         norms = torch_mod.linalg.vector_norm(x, dim=1, keepdim=True).clamp_min(1e-12)
         amp = x / norms
         if sample_size < state_len:
-            pad = torch_mod.zeros(
-                (batch, state_len - sample_size), device=amp.device, dtype=amp.dtype
-            )
-            amp = torch_mod.cat([amp, pad], dim=1)
-        return torch_mod.complex(amp, torch_mod.zeros_like(amp)).to(
-            self._complex_dtype()
-        )
+            # F.pad is a single fused op vs a separate zeros + cat.
+            amp = torch_mod.nn.functional.pad(amp, (0, state_len - sample_size))
+        # ``.to(complex_dtype)`` from a real tensor is one kernel that writes
+        # (real, 0) interleaved — strictly better than building a separate
+        # zeros tensor and combining via ``torch.complex(real, zeros)``.
+        return amp.to(self._complex_dtype())
 
     def encode_angle(self, data: Any, num_qubits: int) -> Any:
         torch_mod = self._require_torch()
@@ -143,21 +259,18 @@ class TritonAmdEngine:
                 f"Angle encoding expects sample size {num_qubits} (=num_qubits), got {width}."
             )
 
-        state_len = 1 << num_qubits
-        idx = torch_mod.arange(state_len, device=angles.device).reshape(1, state_len)
-        amp = torch_mod.ones((batch, state_len), device=angles.device, dtype=real_dtype)
-        for bit in range(num_qubits):
-            col = angles[:, bit].unsqueeze(1)
-            factor = torch_mod.where(
-                ((idx >> bit) & 1) == 1,
-                torch_mod.sin(col),
-                torch_mod.cos(col),
-            )
-            amp = amp * factor
+        bits = self._bits_table(num_qubits, real_dtype)  # (S, n) cached
 
-        return torch_mod.complex(amp, torch_mod.zeros_like(amp)).to(
-            self._complex_dtype()
-        )
+        # amp[batch, b] = prod_k (sin(θ_k) if bit_k else cos(θ_k))
+        # Closed-form vectorization: broadcast (B, 1, n) sin/cos against
+        # (1, S, n) bit pattern, gather via where, reduce-product over k.
+        # One allocation for the (B, S, n) workspace; the previous Python-level
+        # n-step loop allocated a fresh (B, S) tensor per iteration.
+        sin = torch_mod.sin(angles).unsqueeze(1)
+        cos = torch_mod.cos(angles).unsqueeze(1)
+        factor = torch_mod.where(bits.unsqueeze(0) > 0.5, sin, cos)
+        amp = factor.prod(dim=2)
+        return amp.to(self._complex_dtype())
 
     def encode_basis(self, data: Any, num_qubits: int) -> Any:
         torch_mod = self._require_torch()
@@ -179,22 +292,181 @@ class TritonAmdEngine:
             )
 
         batch = int(idx.numel())
+        complex_dtype = self._complex_dtype()
         out = torch_mod.zeros(
             (batch, state_len),
             device=idx.device,
-            dtype=self._complex_dtype(),
+            dtype=complex_dtype,
         )
         out.scatter_(
             1,
             idx.reshape(batch, 1),
-            torch_mod.ones(
-                (batch, 1),
-                device=idx.device,
-                dtype=self._complex_dtype(),
-            ),
+            torch_mod.ones((batch, 1), device=idx.device, dtype=complex_dtype),
         )
         return out
 
+    def _iqp_phase(
+        self,
+        params: Any,
+        num_qubits: int,
+        bits: Any,
+        *,
+        enable_zz: bool,
+    ) -> Any:
+        """Compute θ(x) = Σ x_i·data_i (+ Σ_{i<j} x_i x_j data_ij if ZZ).
+
+        Returns shape ``(batch, 2**num_qubits)`` in the real dtype.
+        """
+        torch_mod = self._require_torch()
+        n = num_qubits
+        z_params = params[:, :n]
+        # phase = z_params @ bits.T : (B, S)
+        phase = torch_mod.matmul(z_params, bits.T)
+        if enable_zz and n >= 2:
+            if n > _IQP_PAIR_MATRIX_MAX_N:
+                # Pair matrix is (S, n_pairs); at n=20 that's already ~760 MiB
+                # in float32. Past this size, fall back to a per-pair loop:
+                # slower, but memory stays bounded. The workload is also heavy
+                # here: each complex64 state vector is ≥ 16 MiB, so batched
+                # outputs quickly reach GiB scale.
+                pair_idx = n
+                zz_params = params
+                for i in range(n - 1):
+                    bi = bits[:, i]
+                    for j in range(i + 1, n):
+                        bj = bits[:, j]
+                        phase = phase + zz_params[:, pair_idx : pair_idx + 1] * (
+                            bi * bj
+                        ).unsqueeze(0)
+                        pair_idx += 1
+            else:
+                zz_params = params[:, n:]
+                pairs = self._pair_indices(n)
+                pair_matrix = bits[:, pairs[:, 0]] * bits[:, pairs[:, 1]]
+                phase = phase + torch_mod.matmul(zz_params, pair_matrix.T)
+        return phase
+
+    def encode_iqp(
+        self,
+        data: Any,
+        num_qubits: int,
+        *,
+        enable_zz: bool = True,
+    ) -> Any:
+        torch_mod = self._require_torch()
+        real_dtype = self._real_dtype()
+        params = self._to_2d(data, dtype=real_dtype)
+        batch, width = params.shape
+
+        n = num_qubits
+        expected = n + n * (n - 1) // 2 if enable_zz else n
+        if width != expected:
+            variant = "ZZ" if enable_zz else "Z-only"
+            raise ValueError(
+                f"IQP encoding ({variant}) expects {expected} parameters for {n} qubits, got {width}."
+            )
+
+        state_len = 1 << n
+        bits = self._bits_table(n, real_dtype)
+        phase = self._iqp_phase(params, n, bits, enable_zz=enable_zz)
+
+        # f[x] = exp(i·θ(x)). ``torch.complex(cos, sin)`` allocates a single
+        # contiguous complex tensor and is faster than writing into strided
+        # ``.real``/``.imag`` views of a separately-allocated complex buffer.
+        f = torch_mod.complex(torch_mod.cos(phase), torch_mod.sin(phase)).to(
+            self._complex_dtype()
+        )
+
+        # In-place n-stage Walsh-Hadamard butterfly. View ``f`` as
+        # (B, K, 2, stride) per stage and do (a, b) ← (a + b, a - b) using a
+        # single ``state_len/2``-sized scratch buffer instead of allocating
+        # two (lo+hi, lo-hi) buffers and concatenating them every stage.
+        if n > 0:
+            scratch = torch_mod.empty(
+                (batch, state_len // 2), device=f.device, dtype=f.dtype
+            )
+            for s in range(n):
+                stride = 1 << s
+                view = f.view(batch, state_len // (stride * 2), 2, stride)
+                a = view.select(2, 0)
+                b = view.select(2, 1)
+                scratch_view = scratch.view(batch, state_len // (stride * 2), stride)
+                torch_mod.sub(a, b, out=scratch_view)  # scratch ← a − b
+                a.add_(b)  # a ← a + b (in-place)
+                b.copy_(scratch_view)  # b ← (a − b) from scratch
+            f = f.view(batch, state_len)
+
+        f.mul_(1.0 / float(state_len))
+        return f
+
+    def _can_use_triton_phase_kernel(self, num_qubits: int) -> bool:
+        return (
+            _phase_encode_kernel is not None
+            and self.precision == "float32"
+            and 1 <= num_qubits <= 32
+        )
+
+    def _encode_phase_triton(self, phases: Any, num_qubits: int) -> Any:
+        """Triton-fused phase encoder for float32 / n ≤ 32.
+
+        One HIP kernel launch over a (batch, ceil(state_len / BLOCK)) grid;
+        each program handles one (sample, output-tile) pair and fuses the
+        bit-table materialization + θ(b) accumulate + cos/sin + 1/√2^n scale
+        + complex-pack into a single pass that writes the output buffer
+        interleaved (re, im, re, im, …), the native complex64 layout.
+        """
+        torch_mod = self._require_torch()
+        # ``_can_use_triton_phase_kernel`` already guards on Triton being
+        # available; this assertion narrows the type for the type checker.
+        assert _phase_encode_kernel is not None
+        batch = phases.shape[0]
+        state_len = 1 << num_qubits
+
+        # Allocate the complex output once; pass its real-view as a flat
+        # (B, 2·S) float32 buffer to the kernel for direct interleaved writes.
+        out = torch_mod.empty(
+            (batch, state_len),
+            device=phases.device,
+            dtype=torch_mod.complex64,
+        )
+        out_real_view = torch_mod.view_as_real(out).view(batch, state_len * 2)
+
+        norm = math.pow(math.sqrt(0.5), num_qubits)
+        BLOCK = 256
+        grid = (batch, (state_len + BLOCK - 1) // BLOCK)
+        _phase_encode_kernel[grid](
+            phases,
+            out_real_view,
+            num_qubits,
+            state_len,
+            norm,
+            BLOCK=BLOCK,
+        )
+        return out
+
+    def encode_phase(self, data: Any, num_qubits: int) -> Any:
+        torch_mod = self._require_torch()
+        real_dtype = self._real_dtype()
+        phases = self._to_2d(data, dtype=real_dtype)
+        batch, width = phases.shape
+        if width != num_qubits:
+            raise ValueError(
+                f"Phase encoding expects sample size {num_qubits} (=num_qubits), got {width}."
+            )
+
+        if self._can_use_triton_phase_kernel(num_qubits):
+            return self._encode_phase_triton(phases, num_qubits)
+
+        # Fallback: vectorized PyTorch path (float64, n > 32, or no Triton kernel).
+        bits = self._bits_table(num_qubits, real_dtype)
+        phi = torch_mod.matmul(phases, bits.T)
+        norm = math.pow(math.sqrt(0.5), num_qubits)
+        # ``torch.complex(re, im)`` writes a contiguous interleaved buffer in
+        # one allocation — faster than ``empty(complex)`` followed by strided
+        # writes into ``.real``/``.imag``.
+        return torch_mod.complex(
+            torch_mod.cos(phi).mul_(norm),
+            torch_mod.sin(phi).mul_(norm),
+        ).to(self._complex_dtype())
+
     def encode(
         self,
         data: Any,
@@ -210,6 +482,13 @@ class TritonAmdEngine:
             return self.encode_angle(data, num_qubits)
         if method == "basis":
             return self.encode_basis(data, num_qubits)
+        if method == "iqp":
+            return self.encode_iqp(data, num_qubits, enable_zz=True)
+        if method == "iqp-z":
+            return self.encode_iqp(data, num_qubits, enable_zz=False)
+        if method == "phase":
+            return self.encode_phase(data, num_qubits)
         raise ValueError(
-            f"Unsupported encoding '{encoding_method}'. triton_amd supports amplitude, angle, basis."
+            f"Unsupported encoding '{encoding_method}'. "
+            "triton_amd supports amplitude, angle, basis, iqp, iqp-z, phase."
         )
diff --git a/qdp/qdp-python/tests/test_triton_amd_backend.py b/qdp/qdp-python/tests/test_triton_amd_backend.py
index 1263f65e9..ff3341568 100644
--- a/qdp/qdp-python/tests/test_triton_amd_backend.py
+++ b/qdp/qdp-python/tests/test_triton_amd_backend.py
@@ -14,9 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
+
 import pytest
 import torch
 from qumat_qdp import QdpEngine, is_triton_amd_available
+from qumat_qdp.torch_ref import iqp_encode as _torch_ref_iqp
 from qumat_qdp.triton_amd import TritonAmdEngine
 
 
@@ -50,6 +53,21 @@ def _torch_angle_ref(angles: torch.Tensor, num_qubits: int) -> torch.Tensor:
     return torch.complex(amp, torch.zeros_like(amp))
 
 
+def _torch_phase_ref(phases: torch.Tensor, num_qubits: int) -> torch.Tensor:
+    real_dtype = phases.dtype
+    batch = phases.shape[0]
+    state_len = 1 << num_qubits
+    idx = torch.arange(state_len, device=phases.device, dtype=torch.int64)
+    bits = (
+        (idx.unsqueeze(1) >> torch.arange(num_qubits, device=phases.device)) & 1
+    ).to(real_dtype)
+    phi = phases @ bits.T
+    norm = math.pow(math.sqrt(0.5), num_qubits)
+    out = torch.complex(torch.cos(phi) * norm, torch.sin(phi) * norm)
+    assert out.shape == (batch, state_len)
+    return out
+
+
 def _torch_basis_ref(idx: torch.Tensor, num_qubits: int) -> torch.Tensor:
     idx = idx.to(torch.int64)
     batch = idx.numel()
@@ -187,3 +205,160 @@ def test_unified_router_contract_returns_torch_tensor() -> None:
     assert isinstance(qt, torch.Tensor)
     assert qt.shape == (2, 4)
     assert qt.dtype == torch.complex64
+
+
+@pytest.mark.skipif(
+    not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_iqp_full_parity_with_torch_ref() -> None:
+    n = 4
+    engine = TritonAmdEngine(device_id=0, precision="float32")
+    data = torch.randn(3, n + n * (n - 1) // 2, device="cuda", dtype=torch.float32)
+    got = _as_torch(engine.encode(data, n, "iqp"))
+    ref = _torch_ref_iqp(data, n, enable_zz=True)
+    assert got.shape == ref.shape
+    assert got.dtype == torch.complex64
+    assert torch.allclose(got, ref, atol=2e-5, rtol=2e-5)
+
+
+@pytest.mark.skipif(
+    not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_iqp_z_only_parity_with_torch_ref() -> None:
+    n = 5
+    engine = TritonAmdEngine(device_id=0, precision="float32")
+    data = torch.randn(2, n, device="cuda", dtype=torch.float32)
+    got = _as_torch(engine.encode(data, n, "iqp-z"))
+    ref = _torch_ref_iqp(data, n, enable_zz=False)
+    assert got.shape == ref.shape
+    assert torch.allclose(got, ref, atol=2e-5, rtol=2e-5)
+
+
+@pytest.mark.skipif(
+    not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_iqp_param_count_validation() -> None:
+    engine = TritonAmdEngine(device_id=0, precision="float32")
+    # ZZ variant for n=4 expects 4 + 6 = 10 params; pass 9.
+    bad = torch.randn(2, 9, device="cuda", dtype=torch.float32)
+    with pytest.raises(ValueError, match="expects 10 parameters"):
+        engine.encode(bad, 4, "iqp")
+    # Z-only variant for n=4 expects 4 params; pass 5.
+    bad_z = torch.randn(2, 5, device="cuda", dtype=torch.float32)
+    with pytest.raises(ValueError, match="expects 4 parameters"):
+        engine.encode(bad_z, 4, "iqp-z")
+
+
+@pytest.mark.skipif(
+    not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_iqp_normalization_unit_norm() -> None:
+    """IQP output is a normalized state vector: Σ|amp|² ≈ 1."""
+    engine = TritonAmdEngine(device_id=0, precision="float32")
+    n = 6
+    data = torch.randn(4, n + n * (n - 1) // 2, device="cuda", dtype=torch.float32)
+    got = _as_torch(engine.encode(data, n, "iqp"))
+    norms = (got.abs() ** 2).sum(dim=1)
+    assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4, rtol=1e-4)
+
+
+@pytest.mark.skipif(
+    not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_phase_parity() -> None:
+    engine = TritonAmdEngine(device_id=0, precision="float32")
+    phases = torch.randn(3, 5, device="cuda", dtype=torch.float32)
+    got = _as_torch(engine.encode(phases, 5, "phase"))
+    ref = _torch_phase_ref(phases, 5)
+    assert got.shape == ref.shape
+    assert got.dtype == torch.complex64
+    assert torch.allclose(got, ref, atol=1e-5, rtol=1e-5)
+
+
+@pytest.mark.skipif(
+    not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_phase_normalization_unit_norm() -> None:
+    """Phase output is a uniform-magnitude product state: Σ|amp|² ≈ 1."""
+    engine = TritonAmdEngine(device_id=0, precision="float32")
+    n = 6
+    phases = torch.randn(4, n, device="cuda", dtype=torch.float32)
+    got = _as_torch(engine.encode(phases, n, "phase"))
+    norms = (got.abs() ** 2).sum(dim=1)
+    assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4, rtol=1e-4)
+
+
+@pytest.mark.skipif(
+    not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_phase_param_count_validation() -> None:
+    engine = TritonAmdEngine(device_id=0, precision="float32")
+    bad = torch.randn(2, 3, device="cuda", dtype=torch.float32)
+    with pytest.raises(ValueError, match="sample size 4"):
+        engine.encode(bad, 4, "phase")
+
+
+@pytest.mark.skipif(
+    not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_phase_float64_precision_contract() -> None:
+    engine = TritonAmdEngine(device_id=0, precision="float64")
+    phases = torch.randn(2, 4, device="cuda", dtype=torch.float64)
+    got = _as_torch(engine.encode(phases, 4, "phase"))
+    ref = _torch_phase_ref(phases, 4).to(torch.complex128)
+    assert got.dtype == torch.complex128
+    assert torch.allclose(got, ref, atol=1e-12, rtol=1e-12)
+
+
+@pytest.mark.skipif(
+    not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_iqp_float64_precision_contract() -> None:
+    """Float64 IQP matches torch_ref to within 1e-12 (covers the dtype contract)."""
+    engine = TritonAmdEngine(device_id=0, precision="float64")
+    n = 4
+    data = torch.randn(3, n + n * (n - 1) // 2, device="cuda", dtype=torch.float64)
+    got = _as_torch(engine.encode(data, n, "iqp"))
+    ref = _torch_ref_iqp(data, n, enable_zz=True).to(torch.complex128)
+    assert got.dtype == torch.complex128
+    assert torch.allclose(got, ref, atol=1e-12, rtol=1e-12)
+
+
+@pytest.mark.skipif(
+    not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_triton_amd_unsupported_method_message_lists_all() -> None:
+    engine = TritonAmdEngine(device_id=0, precision="float32")
+    with pytest.raises(ValueError) as excinfo:
+        engine.encode(torch.zeros(1, 4, device="cuda"), 2, "no-such-method")
+    msg = str(excinfo.value)
+    for name in ("amplitude", "angle", "basis", "iqp", "iqp-z", "phase"):
+        assert name in msg
+
+
+@pytest.mark.skipif(
+    not is_triton_amd_available(), reason="Triton AMD backend unavailable"
+)
+@pytest.mark.rocm
+def test_unified_router_iqp_and_phase_routes() -> None:
+    """The public QdpEngine(backend='amd') router accepts iqp/iqp-z/phase too."""
+    router = QdpEngine(backend="amd", device_id=0, precision="float32")
+    n = 3
+    data_iqp = torch.randn(2, n + n * (n - 1) // 2, device="cuda", dtype=torch.float32)
+    qt = router.encode(data_iqp, n, "iqp")
+    assert isinstance(qt, torch.Tensor)
+    assert qt.shape == (2, 1 << n)
+    qt_z = router.encode(torch.randn(2, n, device="cuda"), n, "iqp-z")
+    assert qt_z.shape == (2, 1 << n)
+    qt_p = router.encode(torch.randn(2, n, device="cuda"), n, "phase")
+    assert qt_p.shape == (2, 1 << n)
