This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git


The following commit(s) were added to refs/heads/main by this push:
     new 63ef825bc [QDP] Add basis encoding (#839)
63ef825bc is described below

commit 63ef825bc350675d7f5775c58ada3a6166b74b06
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Fri Jan 16 18:30:03 2026 +0800

    [QDP] Add basis encoding (#839)
    
    * Add basis encoding
    
    * Apply review comments
---
 qdp/qdp-core/src/gpu/encodings/basis.rs | 291 ++++++++++++++++++++++++++++++--
 qdp/qdp-kernels/build.rs                |   4 +-
 qdp/qdp-kernels/src/amplitude.cu        |   1 -
 qdp/qdp-kernels/src/basis.cu            | 165 ++++++++++++++++++
 qdp/qdp-kernels/src/lib.rs              |  53 +++++-
 testing/qdp/test_bindings.py            | 107 ++++++++++++
 6 files changed, 607 insertions(+), 14 deletions(-)

diff --git a/qdp/qdp-core/src/gpu/encodings/basis.rs b/qdp/qdp-core/src/gpu/encodings/basis.rs
index fec482174..f93ca5c84 100644
--- a/qdp/qdp-core/src/gpu/encodings/basis.rs
+++ b/qdp/qdp-core/src/gpu/encodings/basis.rs
@@ -14,8 +14,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Basis encoding (placeholder)
-// TODO: Map integers to computational basis states
+// Basis encoding: map integers to computational basis states
 
 use super::QuantumEncoder;
 use crate::error::{MahoutError, Result};
@@ -23,20 +22,227 @@ use crate::gpu::memory::GpuStateVector;
 use cudarc::driver::CudaDevice;
 use std::sync::Arc;
 
-/// Basis encoding (not implemented)
-/// TODO: Map integers to basis states (e.g., 3 → |011⟩)
+#[cfg(target_os = "linux")]
+use crate::gpu::memory::map_allocation_error;
+#[cfg(target_os = "linux")]
+use cudarc::driver::DevicePtr;
+#[cfg(target_os = "linux")]
+use std::ffi::c_void;
+
+/// Basis encoding: maps an integer index to a computational basis state.
+///
+/// For n qubits, maps integer i (0 ≤ i < 2^n) to |i⟩, where:
+/// - state[i] = 1.0 + 0.0i
+/// - state[j] = 0.0 + 0.0i for all j ≠ i
+///
+/// Example: index 3 with 3 qubits → |011⟩ (binary representation of 3)
+///
+/// Input format:
+/// - Single encoding: data = [index] (single f64 representing the basis index)
+/// - Batch encoding: data = [idx0, idx1, ..., idxN] (one index per sample)
 pub struct BasisEncoder;
 
 impl QuantumEncoder for BasisEncoder {
     fn encode(
         &self,
-        _device: &Arc<CudaDevice>,
-        _data: &[f64],
-        _num_qubits: usize,
+        #[cfg(target_os = "linux")] device: &Arc<CudaDevice>,
+        #[cfg(not(target_os = "linux"))] _device: &Arc<CudaDevice>,
+        data: &[f64],
+        num_qubits: usize,
     ) -> Result<GpuStateVector> {
-        Err(MahoutError::InvalidInput(
-            "Basis encoding not yet implemented. Use 'amplitude' encoding for now.".to_string(),
-        ))
+        // Validate basic input constraints
+        self.validate_input(data, num_qubits)?;
+
+        // For basis encoding, we expect exactly one value: the basis index
+        if data.len() != 1 {
+            return Err(MahoutError::InvalidInput(format!(
+                "Basis encoding expects exactly 1 value (the basis index), got {}",
+                data.len()
+            )));
+        }
+
+        let state_len = 1 << num_qubits;
+
+        #[cfg(target_os = "linux")]
+        {
+            // Convert and validate the basis index
+            let basis_index = Self::validate_basis_index(data[0], state_len)?;
+            // Allocate GPU state vector
+            let state_vector = {
+                crate::profile_scope!("GPU::Alloc");
+                GpuStateVector::new(device, num_qubits)?
+            };
+
+            let state_ptr = state_vector.ptr_f64().ok_or_else(|| {
+                MahoutError::InvalidInput(
+                    "State vector precision mismatch (expected float64 buffer)".to_string(),
+                )
+            })?;
+
+            // Launch basis encoding kernel
+            let ret = {
+                crate::profile_scope!("GPU::KernelLaunch");
+                unsafe {
+                    qdp_kernels::launch_basis_encode(
+                        basis_index,
+                        state_ptr as *mut c_void,
+                        state_len,
+                        std::ptr::null_mut(), // default stream
+                    )
+                }
+            };
+
+            if ret != 0 {
+                return Err(MahoutError::KernelLaunch(format!(
+                    "Basis encoding kernel failed with CUDA error code: {} ({})",
+                    ret,
+                    cuda_error_to_string(ret)
+                )));
+            }
+
+            {
+                crate::profile_scope!("GPU::Synchronize");
+                device.synchronize().map_err(|e| {
+                    MahoutError::Cuda(format!("CUDA device synchronize failed: {:?}", e))
+                })?;
+            }
+
+            Ok(state_vector)
+        }
+
+        #[cfg(not(target_os = "linux"))]
+        {
+            Err(MahoutError::Cuda(
+                "CUDA unavailable (non-Linux)".to_string(),
+            ))
+        }
+    }
+
+    /// Encode multiple basis indices in a single GPU allocation and kernel launch
+    #[cfg(target_os = "linux")]
+    fn encode_batch(
+        &self,
+        device: &Arc<CudaDevice>,
+        batch_data: &[f64],
+        num_samples: usize,
+        sample_size: usize,
+        num_qubits: usize,
+    ) -> Result<GpuStateVector> {
+        crate::profile_scope!("BasisEncoder::encode_batch");
+
+        // For basis encoding, each sample should have exactly 1 value (the index)
+        if sample_size != 1 {
+            return Err(MahoutError::InvalidInput(format!(
+                "Basis encoding expects sample_size=1 (one index per sample), got {}",
+                sample_size
+            )));
+        }
+
+        if batch_data.len() != num_samples {
+            return Err(MahoutError::InvalidInput(format!(
+                "Batch data length {} doesn't match num_samples {}",
+                batch_data.len(),
+                num_samples
+            )));
+        }
+
+        if num_qubits == 0 || num_qubits > 30 {
+            return Err(MahoutError::InvalidInput(format!(
+                "Number of qubits {} must be between 1 and 30",
+                num_qubits
+            )));
+        }
+
+        let state_len = 1 << num_qubits;
+
+        // Convert and validate all basis indices
+        let basis_indices: Vec<usize> = batch_data
+            .iter()
+            .enumerate()
+            .map(|(i, &val)| {
+                Self::validate_basis_index(val, state_len)
+                    .map_err(|e| MahoutError::InvalidInput(format!("Sample {}: {}", i, e)))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        // Allocate batch state vector
+        let batch_state_vector = {
+            crate::profile_scope!("GPU::AllocBatch");
+            GpuStateVector::new_batch(device, num_samples, num_qubits)?
+        };
+
+        // Upload basis indices to GPU
+        let indices_gpu = {
+            crate::profile_scope!("GPU::H2D_Indices");
+            device.htod_sync_copy(&basis_indices).map_err(|e| {
+                map_allocation_error(
+                    num_samples * std::mem::size_of::<usize>(),
+                    "basis indices upload",
+                    Some(num_qubits),
+                    e,
+                )
+            })?
+        };
+
+        let state_ptr = batch_state_vector.ptr_f64().ok_or_else(|| {
+            MahoutError::InvalidInput(
+                "Batch state vector precision mismatch (expected float64 buffer)".to_string(),
+            )
+        })?;
+
+        // Launch batch kernel
+        {
+            crate::profile_scope!("GPU::BatchKernelLaunch");
+            let ret = unsafe {
+                qdp_kernels::launch_basis_encode_batch(
+                    *indices_gpu.device_ptr() as *const usize,
+                    state_ptr as *mut c_void,
+                    num_samples,
+                    state_len,
+                    num_qubits as u32,
+                    std::ptr::null_mut(), // default stream
+                )
+            };
+
+            if ret != 0 {
+                return Err(MahoutError::KernelLaunch(format!(
+                    "Batch basis encoding kernel failed: {} ({})",
+                    ret,
+                    cuda_error_to_string(ret)
+                )));
+            }
+        }
+
+        // Synchronize
+        {
+            crate::profile_scope!("GPU::Synchronize");
+            device
+                .synchronize()
+                .map_err(|e| MahoutError::Cuda(format!("Sync failed: {:?}", e)))?;
+        }
+
+        Ok(batch_state_vector)
+    }
+
+    fn validate_input(&self, data: &[f64], num_qubits: usize) -> Result<()> {
+        // Basic validation: qubits and data availability
+        if num_qubits == 0 {
+            return Err(MahoutError::InvalidInput(
+                "Number of qubits must be at least 1".to_string(),
+            ));
+        }
+        if num_qubits > 30 {
+            return Err(MahoutError::InvalidInput(format!(
+                "Number of qubits {} exceeds practical limit of 30",
+                num_qubits
+            )));
+        }
+        if data.is_empty() {
+            return Err(MahoutError::InvalidInput(
+                "Input data cannot be empty".to_string(),
+            ));
+        }
+        Ok(())
     }
 
     fn name(&self) -> &'static str {
@@ -44,6 +250,69 @@ impl QuantumEncoder for BasisEncoder {
     }
 
     fn description(&self) -> &'static str {
-        "Basis encoding (not implemented)"
+        "Basis encoding: maps integers to computational basis states"
+    }
+}
+
+impl BasisEncoder {
+    /// Validate and convert a f64 value to a valid basis index
+    fn validate_basis_index(value: f64, state_len: usize) -> Result<usize> {
+        // Check for non-finite values
+        if !value.is_finite() {
+            return Err(MahoutError::InvalidInput(
+                "Basis index must be a finite number".to_string(),
+            ));
+        }
+
+        // Check for negative values
+        if value < 0.0 {
+            return Err(MahoutError::InvalidInput(format!(
+                "Basis index must be non-negative, got {}",
+                value
+            )));
+        }
+
+        // Check if the value is an integer
+        if value.fract() != 0.0 {
+            return Err(MahoutError::InvalidInput(format!(
+                "Basis index must be an integer, got {} (hint: use .round() if needed)",
+                value
+            )));
+        }
+
+        // Convert to usize
+        let index = value as usize;
+
+        // Check bounds
+        if index >= state_len {
+            return Err(MahoutError::InvalidInput(format!(
+                "Basis index {} exceeds state vector size {} (max index: {})",
+                index,
+                state_len,
+                state_len - 1
+            )));
+        }
+
+        Ok(index)
+    }
+}
+
+/// Convert CUDA error code to human-readable string
+#[cfg(target_os = "linux")]
+fn cuda_error_to_string(code: i32) -> &'static str {
+    match code {
+        0 => "cudaSuccess",
+        1 => "cudaErrorInvalidValue",
+        2 => "cudaErrorMemoryAllocation",
+        3 => "cudaErrorInitializationError",
+        4 => "cudaErrorLaunchFailure",
+        6 => "cudaErrorInvalidDevice",
+        8 => "cudaErrorInvalidConfiguration",
+        11 => "cudaErrorInvalidHostPointer",
+        12 => "cudaErrorInvalidDevicePointer",
+        17 => "cudaErrorInvalidMemcpyDirection",
+        30 => "cudaErrorUnknown",
+        999 => "CUDA unavailable (non-Linux stub)",
+        _ => "Unknown CUDA error",
     }
 }
diff --git a/qdp/qdp-kernels/build.rs b/qdp/qdp-kernels/build.rs
index d25a88d9e..c845580a9 100644
--- a/qdp/qdp-kernels/build.rs
+++ b/qdp/qdp-kernels/build.rs
@@ -27,8 +27,9 @@ use std::env;
 use std::process::Command;
 
 fn main() {
-    // Tell Cargo to rerun this script if the kernel source changes
+    // Tell Cargo to rerun this script if the kernel sources change
     println!("cargo:rerun-if-changed=src/amplitude.cu");
+    println!("cargo:rerun-if-changed=src/basis.cu");
 
     // Check if CUDA is available by looking for nvcc
     let has_cuda = Command::new("nvcc").arg("--version").output().is_ok();
@@ -81,5 +82,6 @@ fn main() {
         // .flag("-gencode")
         // .flag("arch=compute_89,code=sm_89")
         .file("src/amplitude.cu")
+        .file("src/basis.cu")
         .compile("kernels");
 }
diff --git a/qdp/qdp-kernels/src/amplitude.cu b/qdp/qdp-kernels/src/amplitude.cu
index 7cf94ce92..98e96bf32 100644
--- a/qdp/qdp-kernels/src/amplitude.cu
+++ b/qdp/qdp-kernels/src/amplitude.cu
@@ -549,7 +549,6 @@ int convert_state_to_float(
 
 // TODO: Future encoding methods:
 // - launch_angle_encode (angle encoding)
-// - launch_basis_encode (basis encoding)
 // - launch_iqp_encode (IQP encoding)
 
 } // extern "C"
diff --git a/qdp/qdp-kernels/src/basis.cu b/qdp/qdp-kernels/src/basis.cu
new file mode 100644
index 000000000..247bfb1aa
--- /dev/null
+++ b/qdp/qdp-kernels/src/basis.cu
@@ -0,0 +1,165 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Basis Encoding CUDA Kernels
+//
+// Maps integer indices to computational basis states.
+// For index i with n qubits: state[i] = 1.0, all others = 0.0
+// Example: index=3 with 3 qubits → |011⟩ (state[3] = 1.0)
+
+#include <cuda_runtime.h>
+#include <cuComplex.h>
+
+/// Single sample basis encoding kernel
+///
+/// Sets state[basis_index] = 1.0 + 0.0i, all others = 0.0 + 0.0i
+__global__ void basis_encode_kernel(
+    size_t basis_index,
+    cuDoubleComplex* __restrict__ state,
+    size_t state_len
+) {
+    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= state_len) return;
+
+    if (idx == basis_index) {
+        state[idx] = make_cuDoubleComplex(1.0, 0.0);
+    } else {
+        state[idx] = make_cuDoubleComplex(0.0, 0.0);
+    }
+}
+
+/// Batch basis encoding kernel
+///
+/// Each sample has its own basis index, resulting in independent basis states.
+/// Memory layout:
+/// - basis_indices: [idx0, idx1, ..., idxN]
+/// - state_batch: [sample0_state | sample1_state | ... | sampleN_state]
+__global__ void basis_encode_batch_kernel(
+    const size_t* __restrict__ basis_indices,
+    cuDoubleComplex* __restrict__ state_batch,
+    size_t num_samples,
+    size_t state_len,
+    unsigned int num_qubits
+) {
+    // Grid-stride loop over all elements across all samples
+    const size_t total_elements = num_samples * state_len;
+    const size_t stride = gridDim.x * blockDim.x;
+    const size_t state_mask = state_len - 1;
+
+    for (size_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+         global_idx < total_elements;
+         global_idx += stride) {
+        // Decompose into (sample_idx, element_idx)
+        // state_len = 2^num_qubits, so division/modulo can use shift/mask
+        const size_t sample_idx = global_idx >> num_qubits;
+        const size_t element_idx = global_idx & state_mask;
+
+        // Get basis index for this sample
+        const size_t basis_index = basis_indices[sample_idx];
+
+        // Set amplitude: 1.0 at basis_index, 0.0 elsewhere
+        if (element_idx == basis_index) {
+            state_batch[global_idx] = make_cuDoubleComplex(1.0, 0.0);
+        } else {
+            state_batch[global_idx] = make_cuDoubleComplex(0.0, 0.0);
+        }
+    }
+}
+
+extern "C" {
+
+/// Launch basis encoding kernel
+///
+/// # Arguments
+/// * basis_index - The computational basis state index (0 to state_len-1)
+/// * state_d - Device pointer to output state vector
+/// * state_len - Target state vector size (2^num_qubits)
+/// * stream - CUDA stream for async execution (nullptr = default stream)
+///
+/// # Returns
+/// CUDA error code (0 = cudaSuccess)
+int launch_basis_encode(
+    size_t basis_index,
+    void* state_d,
+    size_t state_len,
+    cudaStream_t stream
+) {
+    if (state_len == 0) {
+        return cudaErrorInvalidValue;
+    }
+
+    if (basis_index >= state_len) {
+        return cudaErrorInvalidValue;
+    }
+
+    cuDoubleComplex* state_complex_d = static_cast<cuDoubleComplex*>(state_d);
+
+    const int blockSize = 256;
+    const int gridSize = (state_len + blockSize - 1) / blockSize;
+
+    basis_encode_kernel<<<gridSize, blockSize, 0, stream>>>(
+        basis_index,
+        state_complex_d,
+        state_len
+    );
+
+    return (int)cudaGetLastError();
+}
+
+/// Launch batch basis encoding kernel
+///
+/// # Arguments
+/// * basis_indices_d - Device pointer to array of basis indices (one per sample)
+/// * state_batch_d - Device pointer to output batch state vectors
+/// * num_samples - Number of samples in batch
+/// * state_len - State vector size per sample (2^num_qubits)
+/// * num_qubits - Number of qubits (for bit-shift optimization)
+/// * stream - CUDA stream for async execution
+///
+/// # Returns
+/// CUDA error code (0 = cudaSuccess)
+int launch_basis_encode_batch(
+    const size_t* basis_indices_d,
+    void* state_batch_d,
+    size_t num_samples,
+    size_t state_len,
+    unsigned int num_qubits,
+    cudaStream_t stream
+) {
+    if (num_samples == 0 || state_len == 0) {
+        return cudaErrorInvalidValue;
+    }
+
+    cuDoubleComplex* state_complex_d = static_cast<cuDoubleComplex*>(state_batch_d);
+
+    const int blockSize = 256;
+    const size_t total_elements = num_samples * state_len;
+    const size_t blocks_needed = (total_elements + blockSize - 1) / blockSize;
+    const size_t max_blocks = 2048;
+    const size_t gridSize = (blocks_needed < max_blocks) ? blocks_needed : max_blocks;
+
+    basis_encode_batch_kernel<<<gridSize, blockSize, 0, stream>>>(
+        basis_indices_d,
+        state_complex_d,
+        num_samples,
+        state_len,
+        num_qubits
+    );
+
+    return (int)cudaGetLastError();
+}
+
+} // extern "C"
diff --git a/qdp/qdp-kernels/src/lib.rs b/qdp/qdp-kernels/src/lib.rs
index 4eda08696..536dc1df1 100644
--- a/qdp/qdp-kernels/src/lib.rs
+++ b/qdp/qdp-kernels/src/lib.rs
@@ -135,7 +135,34 @@ unsafe extern "C" {
         stream: *mut c_void,
     ) -> i32;
 
-    // TODO: launch_angle_encode, launch_basis_encode
+    /// Launch basis encoding kernel
+    /// Maps an integer index to a computational basis state.
+    /// Returns CUDA error code (0 = success)
+    ///
+    /// # Safety
+    /// Requires valid GPU pointer, must sync before freeing
+    pub fn launch_basis_encode(
+        basis_index: usize,
+        state_d: *mut c_void,
+        state_len: usize,
+        stream: *mut c_void,
+    ) -> i32;
+
+    /// Launch batch basis encoding kernel
+    /// Returns CUDA error code (0 = success)
+    ///
+    /// # Safety
+    /// Requires valid GPU pointers, must sync before freeing
+    pub fn launch_basis_encode_batch(
+        basis_indices_d: *const usize,
+        state_batch_d: *mut c_void,
+        num_samples: usize,
+        state_len: usize,
+        num_qubits: u32,
+        stream: *mut c_void,
+    ) -> i32;
+
+    // TODO: launch_angle_encode
 }
 
 // Dummy implementation for non-Linux (allows compilation)
@@ -198,3 +225,27 @@ pub extern "C" fn convert_state_to_float(
 ) -> i32 {
     999
 }
+
+#[cfg(not(target_os = "linux"))]
+#[unsafe(no_mangle)]
+pub extern "C" fn launch_basis_encode(
+    _basis_index: usize,
+    _state_d: *mut c_void,
+    _state_len: usize,
+    _stream: *mut c_void,
+) -> i32 {
+    999
+}
+
+#[cfg(not(target_os = "linux"))]
+#[unsafe(no_mangle)]
+pub extern "C" fn launch_basis_encode_batch(
+    _basis_indices_d: *const usize,
+    _state_batch_d: *mut c_void,
+    _num_samples: usize,
+    _state_len: usize,
+    _num_qubits: u32,
+    _stream: *mut c_void,
+) -> i32 {
+    999
+}
diff --git a/testing/qdp/test_bindings.py b/testing/qdp/test_bindings.py
index b4ba087ab..8a3531c3e 100644
--- a/testing/qdp/test_bindings.py
+++ b/testing/qdp/test_bindings.py
@@ -222,3 +222,110 @@ def test_encode_errors():
     gpu_tensor = torch.tensor([1.0, 2.0], device="cuda:0")
     with pytest.raises(RuntimeError, match="Only CPU tensors are currently supported"):
         engine.encode(gpu_tensor, 1, "amplitude")
+
+
+@pytest.mark.gpu
+def test_basis_encode_basic():
+    """Test basic basis encoding (requires GPU)."""
+    pytest.importorskip("torch")
+    import torch
+    from _qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+
+    # Encode basis state |0⟩ (index 0 with 2 qubits)
+    qtensor = engine.encode([0.0], 2, "basis")
+    torch_tensor = torch.from_dlpack(qtensor)
+
+    assert torch_tensor.is_cuda
+    assert torch_tensor.shape == (1, 4)  # 2^2 = 4 amplitudes
+
+    # |0⟩ = [1, 0, 0, 0]
+    expected = torch.tensor([[1.0 + 0j, 0.0 + 0j, 0.0 + 0j, 0.0 + 0j]], device="cuda:0")
+    assert torch.allclose(torch_tensor, expected.to(torch_tensor.dtype))
+
+
+@pytest.mark.gpu
+def test_basis_encode_nonzero_index():
+    """Test basis encoding with non-zero index (requires GPU)."""
+    pytest.importorskip("torch")
+    import torch
+    from _qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+
+    # Encode basis state |3⟩ = |11⟩ (index 3 with 2 qubits)
+    qtensor = engine.encode([3.0], 2, "basis")
+    torch_tensor = torch.from_dlpack(qtensor)
+
+    # |3⟩ = [0, 0, 0, 1]
+    expected = torch.tensor([[0.0 + 0j, 0.0 + 0j, 0.0 + 0j, 1.0 + 0j]], device="cuda:0")
+    assert torch.allclose(torch_tensor, expected.to(torch_tensor.dtype))
+
+
+@pytest.mark.gpu
+def test_basis_encode_3_qubits():
+    """Test basis encoding with 3 qubits (requires GPU)."""
+    pytest.importorskip("torch")
+    import torch
+    from _qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+
+    # Encode basis state |5⟩ = |101⟩ (index 5 with 3 qubits)
+    qtensor = engine.encode([5.0], 3, "basis")
+    torch_tensor = torch.from_dlpack(qtensor)
+
+    assert torch_tensor.shape == (1, 8)  # 2^3 = 8 amplitudes
+
+    # |5⟩ should have amplitude 1 at index 5
+    # Check that only index 5 is non-zero
+    host_tensor = torch_tensor.cpu().squeeze()
+    assert host_tensor[5].real == 1.0
+    assert host_tensor[5].imag == 0.0
+    for i in range(8):
+        if i != 5:
+            assert host_tensor[i].real == 0.0
+            assert host_tensor[i].imag == 0.0
+
+
+@pytest.mark.gpu
+def test_basis_encode_errors():
+    """Test error handling for basis encoding (requires GPU)."""
+    pytest.importorskip("torch")
+    import torch
+    from _qdp import QdpEngine
+
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required for QdpEngine")
+
+    engine = QdpEngine(0)
+
+    # Test index out of bounds (2^2 = 4, so max index is 3)
+    with pytest.raises(RuntimeError, match="exceeds state vector size"):
+        engine.encode([4.0], 2, "basis")
+
+    # Test negative index
+    with pytest.raises(RuntimeError, match="non-negative"):
+        engine.encode([-1.0], 2, "basis")
+
+    # Test non-integer index
+    with pytest.raises(RuntimeError, match="integer"):
+        engine.encode([1.5], 2, "basis")
+
+    # Test empty input
+    with pytest.raises(RuntimeError, match="empty"):
+        engine.encode([], 2, "basis")
+
+    # Test multiple values (basis expects exactly 1)
+    with pytest.raises(RuntimeError, match="expects exactly 1"):
+        engine.encode([0.0, 1.0], 2, "basis")

Reply via email to