This is an automated email from the ASF dual-hosted git repository.
hcr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/mahout.git
The following commit(s) were added to refs/heads/main by this push:
new 67e48d530 [QDP] PyTorch input format support (#815)
67e48d530 is described below
commit 67e48d530ba107dea7c000daae220988da126c6f
Author: Jie-Kai Chang <[email protected]>
AuthorDate: Sat Jan 17 11:58:13 2026 +0800
[QDP] PyTorch input format support (#815)
* PyTorch input format support
Signed-off-by: 400Ping <[email protected]>
* fix
Signed-off-by: 400Ping <[email protected]>
* update github ci
Signed-off-by: 400Ping <[email protected]>
* fix
Signed-off-by: 400Ping <[email protected]>
* fix
Signed-off-by: 400Ping <[email protected]>
* remove github workflow
Signed-off-by: 400Ping <[email protected]>
* fix conflict
Signed-off-by: 400Ping <[email protected]>
* fix
Signed-off-by: 400Ping <[email protected]>
* fix
Signed-off-by: 400Ping <[email protected]>
* chore: trigger ci
* update cargo.lock
Signed-off-by: 400Ping <[email protected]>
* add test bindings
Signed-off-by: 400Ping <[email protected]>
---------
Signed-off-by: 400Ping <[email protected]>
Signed-off-by: 400Ping <[email protected]>
---
qdp/Cargo.lock | 316 +++++++++++++++++++++++++++++++++++++-
qdp/Cargo.toml | 2 +
qdp/docs/readers/README.md | 16 +-
qdp/qdp-core/Cargo.toml | 2 +
qdp/qdp-core/src/io.rs | 14 ++
qdp/qdp-core/src/lib.rs | 34 ++++
qdp/qdp-core/src/readers/mod.rs | 4 +
qdp/qdp-core/src/readers/torch.rs | 182 ++++++++++++++++++++++
qdp/qdp-core/tests/torch_io.rs | 62 ++++++++
qdp/qdp-python/Cargo.toml | 1 +
qdp/qdp-python/src/lib.rs | 12 +-
testing/qdp/test_bindings.py | 35 +++++
12 files changed, 664 insertions(+), 16 deletions(-)
diff --git a/qdp/Cargo.lock b/qdp/Cargo.lock
index e175e1f60..c78ac09f8 100644
--- a/qdp/Cargo.lock
+++ b/qdp/Cargo.lock
@@ -8,6 +8,17 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
+[[package]]
+name = "aes"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
+dependencies = [
+ "cfg-if",
+ "cipher",
+ "cpufeatures",
+]
+
[[package]]
name = "ahash"
version = "0.8.12"
@@ -308,6 +319,12 @@ version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
+[[package]]
+name = "base64ct"
+version = "1.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06"
+
[[package]]
name = "bitflags"
version = "1.3.2"
@@ -368,6 +385,26 @@ version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3"
+[[package]]
+name = "bzip2"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8"
+dependencies = [
+ "bzip2-sys",
+ "libc",
+]
+
+[[package]]
+name = "bzip2-sys"
+version = "0.1.13+1.0.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14"
+dependencies = [
+ "cc",
+ "pkg-config",
+]
+
[[package]]
name = "cc"
version = "1.2.52"
@@ -397,6 +434,16 @@ dependencies = [
"windows-link",
]
+[[package]]
+name = "cipher"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
+dependencies = [
+ "crypto-common",
+ "inout",
+]
+
[[package]]
name = "const-random"
version = "0.1.18"
@@ -417,6 +464,12 @@ dependencies = [
"tiny-keccak",
]
+[[package]]
+name = "constant_time_eq"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
+
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
@@ -512,6 +565,15 @@ dependencies = [
"libloading",
]
+[[package]]
+name = "deranged"
+version = "0.5.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587"
+dependencies = [
+ "powerfmt",
+]
+
[[package]]
name = "derive_arbitrary"
version = "1.4.2"
@@ -531,6 +593,7 @@ checksum =
"9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
+ "subtle",
]
[[package]]
@@ -667,6 +730,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
+[[package]]
+name = "hmac"
+version = "0.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
+dependencies = [
+ "digest",
+]
+
[[package]]
name = "iana-time-zone"
version = "0.1.64"
@@ -710,6 +782,15 @@ dependencies = [
"rustversion",
]
+[[package]]
+name = "inout"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01"
+dependencies = [
+ "generic-array",
+]
+
[[package]]
name = "integer-encoding"
version = "3.0.4"
@@ -898,6 +979,19 @@ version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084"
+[[package]]
+name = "ndarray"
+version = "0.15.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
+dependencies = [
+ "matrixmultiply",
+ "num-complex",
+ "num-integer",
+ "num-traits",
+ "rawpointer",
+]
+
[[package]]
name = "ndarray"
version = "0.16.1"
@@ -939,7 +1033,7 @@ dependencies = [
"num-complex",
"num-traits",
"py_literal",
- "zip",
+ "zip 2.4.2",
]
[[package]]
@@ -975,6 +1069,12 @@ dependencies = [
"num-traits",
]
+[[package]]
+name = "num-conv"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
+
[[package]]
name = "num-integer"
version = "0.1.46"
@@ -1086,7 +1186,18 @@ dependencies = [
"snap",
"thrift",
"twox-hash 1.6.3",
- "zstd",
+ "zstd 0.13.3",
+]
+
+[[package]]
+name = "password-hash"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700"
+dependencies = [
+ "base64ct",
+ "rand_core",
+ "subtle",
]
[[package]]
@@ -1095,6 +1206,18 @@ version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
+[[package]]
+name = "pbkdf2"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917"
+dependencies = [
+ "digest",
+ "hmac",
+ "password-hash",
+ "sha2",
+]
+
[[package]]
name = "pest"
version = "2.8.5"
@@ -1169,6 +1292,21 @@ dependencies = [
"portable-atomic",
]
+[[package]]
+name = "powerfmt"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
+
+[[package]]
+name = "ppv-lite86"
+version = "0.2.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
+dependencies = [
+ "zerocopy",
+]
+
[[package]]
name = "prettyplease"
version = "0.2.37"
@@ -1396,7 +1534,8 @@ dependencies = [
"protoc-bin-vendored",
"qdp-kernels",
"rayon",
- "thiserror",
+ "tch",
+ "thiserror 2.0.17",
]
[[package]]
@@ -1431,6 +1570,36 @@ version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
+[[package]]
+name = "rand"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
+dependencies = [
+ "libc",
+ "rand_chacha",
+ "rand_core",
+]
+
+[[package]]
+name = "rand_chacha"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
+dependencies = [
+ "ppv-lite86",
+ "rand_core",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.6.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
+dependencies = [
+ "getrandom 0.2.17",
+]
+
[[package]]
name = "rawpointer"
version = "0.2.1"
@@ -1526,6 +1695,16 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984"
+[[package]]
+name = "safetensors"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df"
+dependencies = [
+ "serde",
+ "serde_json",
+]
+
[[package]]
name = "semver"
version = "1.0.27"
@@ -1545,6 +1724,7 @@ source =
"registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
dependencies = [
"serde_core",
+ "serde_derive",
]
[[package]]
@@ -1580,6 +1760,17 @@ dependencies = [
"zmij",
]
+[[package]]
+name = "sha1"
+version = "0.10.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
+dependencies = [
+ "cfg-if",
+ "cpufeatures",
+ "digest",
+]
+
[[package]]
name = "sha2"
version = "0.10.9"
@@ -1621,6 +1812,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
+[[package]]
+name = "subtle"
+version = "2.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
+
[[package]]
name = "syn"
version = "2.0.114"
@@ -1638,6 +1835,23 @@ version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1dd07eb858a2067e2f3c7155d54e929265c264e6f37efe3ee7a8d1b5a1dd0ba"
+[[package]]
+name = "tch"
+version = "0.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7c7cb00bc2770454b515388d45be7097a3ded2eca172f3dcdb7ca4cc06c40bf1"
+dependencies = [
+ "half",
+ "lazy_static",
+ "libc",
+ "ndarray 0.15.6",
+ "rand",
+ "safetensors",
+ "thiserror 1.0.69",
+ "torch-sys",
+ "zip 0.6.6",
+]
+
[[package]]
name = "tempfile"
version = "3.24.0"
@@ -1651,13 +1865,33 @@ dependencies = [
"windows-sys",
]
+[[package]]
+name = "thiserror"
+version = "1.0.69"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
+dependencies = [
+ "thiserror-impl 1.0.69",
+]
+
[[package]]
name = "thiserror"
version = "2.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8"
dependencies = [
- "thiserror-impl",
+ "thiserror-impl 2.0.17",
+]
+
+[[package]]
+name = "thiserror-impl"
+version = "1.0.69"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
]
[[package]]
@@ -1682,6 +1916,25 @@ dependencies = [
"ordered-float",
]
+[[package]]
+name = "time"
+version = "0.3.45"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f9e442fc33d7fdb45aa9bfeb312c095964abdf596f7567261062b2a7107aaabd"
+dependencies = [
+ "deranged",
+ "num-conv",
+ "powerfmt",
+ "serde_core",
+ "time-core",
+]
+
+[[package]]
+name = "time-core"
+version = "0.1.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b36ee98fd31ec7426d599183e8fe26932a8dc1fb76ddb6214d05493377d34ca"
+
[[package]]
name = "tiny-keccak"
version = "2.0.2"
@@ -1691,6 +1944,18 @@ dependencies = [
"crunchy",
]
+[[package]]
+name = "torch-sys"
+version = "0.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "29e0244e5b148a31dd7fe961165037d1927754d024095c1013937532d7e73a22"
+dependencies = [
+ "anyhow",
+ "cc",
+ "libc",
+ "zip 0.6.6",
+]
+
[[package]]
name = "twox-hash"
version = "1.6.3"
@@ -1891,6 +2156,26 @@ dependencies = [
"syn",
]
+[[package]]
+name = "zip"
+version = "0.6.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261"
+dependencies = [
+ "aes",
+ "byteorder",
+ "bzip2",
+ "constant_time_eq",
+ "crc32fast",
+ "crossbeam-utils",
+ "flate2",
+ "hmac",
+ "pbkdf2",
+ "sha1",
+ "time",
+ "zstd 0.11.2+zstd.1.5.2",
+]
+
[[package]]
name = "zip"
version = "2.4.2"
@@ -1904,7 +2189,7 @@ dependencies = [
"flate2",
"indexmap",
"memchr",
- "thiserror",
+ "thiserror 2.0.17",
"zopfli",
]
@@ -1926,13 +2211,32 @@ dependencies = [
"simd-adler32",
]
+[[package]]
+name = "zstd"
+version = "0.11.2+zstd.1.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
+dependencies = [
+ "zstd-safe 5.0.2+zstd.1.5.2",
+]
+
[[package]]
name = "zstd"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a"
dependencies = [
- "zstd-safe",
+ "zstd-safe 7.2.4",
+]
+
+[[package]]
+name = "zstd-safe"
+version = "5.0.2+zstd.1.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db"
+dependencies = [
+ "libc",
+ "zstd-sys",
]
[[package]]
diff --git a/qdp/Cargo.toml b/qdp/Cargo.toml
index 7c9571833..c2acf6ffc 100644
--- a/qdp/Cargo.toml
+++ b/qdp/Cargo.toml
@@ -35,6 +35,8 @@ ndarray-npy = "0.9"
prost = "0.12"
prost-build = "0.12"
bytes = "1.5"
+# PyTorch tensor loading (optional)
+tch = "0.15"
# Optional: vendored protoc to avoid build failures when protoc is missing
protoc-bin-vendored = "3"
diff --git a/qdp/docs/readers/README.md b/qdp/docs/readers/README.md
index e2c263479..1cd3f125b 100644
--- a/qdp/docs/readers/README.md
+++ b/qdp/docs/readers/README.md
@@ -35,8 +35,8 @@ pub trait StreamingDataReader: DataReader {
|--------|--------|-----------|--------|
| Parquet | `ParquetReader` | ✅ `ParquetStreamingReader` | ✅ Complete |
| Arrow IPC | `ArrowIPCReader` | ❌ | ✅ Complete |
-| NumPy | `NumpyReader` | ❌ | ❌ |
-| PyTorch | `TorchReader` | ❌ | ❌ |
+| NumPy | `NumpyReader` | ❌ | ✅ Complete |
+| PyTorch | `TorchReader` | ❌ | ✅ (feature: `pytorch`) |
## Benefits
@@ -123,7 +123,9 @@ fn read_quantum_data(path: &str) -> Result<(Vec<f64>,
usize, usize)> {
} else if path.ends_with(".arrow") {
ArrowIPCReader::new(path)?.read_batch()
} else if path.ends_with(".npy") {
- NumpyReader::new(path)?.read_batch() // When implemented
+ NumpyReader::new(path)?.read_batch()
+ } else if path.ends_with(".pt") || path.ends_with(".pth") {
+ TorchReader::new(path)?.read_batch()
} else {
Err(MahoutError::InvalidInput("Unsupported format".into()))
}
@@ -150,8 +152,8 @@ qdp-core/src/
│ ├── mod.rs # Reader registry
│ ├── parquet.rs # Parquet implementation
│ ├── arrow_ipc.rs # Arrow IPC implementation
-│ ├── numpy.rs # NumPy (placeholder)
-│ └── torch.rs # PyTorch (placeholder)
+│ ├── numpy.rs # NumPy implementation
+│ └── torch.rs # PyTorch (feature-gated)
├── io.rs # Legacy API & helper functions
└── lib.rs # Main library
@@ -207,8 +209,8 @@ let reader = ParquetStreamingReader::new(path, None)?;
## Future Enhancements
Planned format support:
-- **NumPy** (`.npy`): Python ecosystem integration
-- **PyTorch** (`.pt`): Deep learning workflows
+- **NumPy streaming**: Chunked reads for large `.npy` files
+- **PyTorch streaming**: Chunked reads for large `.pt`/`.pth` tensors
- **HDF5** (`.h5`): Scientific data storage
- **JSON**: Human-readable format for small datasets
- **CSV**: Simple tabular data
diff --git a/qdp/qdp-core/Cargo.toml b/qdp/qdp-core/Cargo.toml
index 1c92f0332..0353554b8 100644
--- a/qdp/qdp-core/Cargo.toml
+++ b/qdp/qdp-core/Cargo.toml
@@ -15,6 +15,7 @@ ndarray = { workspace = true }
ndarray-npy = { workspace = true }
prost = { workspace = true }
bytes = { workspace = true }
+tch = { workspace = true, optional = true }
[build-dependencies]
prost-build = { workspace = true }
@@ -26,6 +27,7 @@ name = "qdp_core"
[features]
default = []
observability = ["nvtx"]
+pytorch = ["tch"]
[dev-dependencies]
approx = "0.5.1"
diff --git a/qdp/qdp-core/src/io.rs b/qdp/qdp-core/src/io.rs
index 4e3cbdd07..ab3e903da 100644
--- a/qdp/qdp-core/src/io.rs
+++ b/qdp/qdp-core/src/io.rs
@@ -260,6 +260,20 @@ pub fn read_numpy_batch<P: AsRef<Path>>(path: P) ->
Result<(Vec<f64>, usize, usi
reader.read_batch()
}
+/// Loads a PyTorch `.pt`/`.pth` tensor file as a flat `f64` batch.
+///
+/// The file must hold a single 1D or 2D tensor; a 1D tensor is reported as
+/// one sample. Values come back flattened, ready for batch encoding.
+/// Without the `pytorch` feature, reading returns an error from `TorchReader`.
+///
+/// # Returns
+/// `(flattened_data, num_samples, sample_size)`
+pub fn read_torch_batch<P: AsRef<Path>>(path: P) -> Result<(Vec<f64>, usize, usize)> {
+    use crate::reader::DataReader;
+    crate::readers::TorchReader::new(path)?.read_batch()
+}
+
+
/// Streaming Parquet reader for List<Float64> and FixedSizeList<Float64>
columns
///
/// Reads Parquet files in chunks without loading entire file into memory.
diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs
index cd89b3f94..cb44ef36b 100644
--- a/qdp/qdp-core/src/lib.rs
+++ b/qdp/qdp-core/src/lib.rs
@@ -231,6 +231,40 @@ impl QdpEngine {
)
}
+ /// Load data from a PyTorch .pt/.pth file and encode it into quantum states.
+ ///
+ /// The file must contain a single 1D or 2D tensor: a 1D tensor is treated as
+ /// one sample, a 2D tensor as `[num_samples, sample_size]`. The data is read
+ /// with `crate::io::read_torch_batch` and then handed to `encode_batch`.
+ /// Requires the `pytorch` feature; without it the read step returns an error.
+ ///
+ /// NOTE(review): loading goes through `tch::Tensor::load`; confirm that files
+ /// written by Python's `torch.save` are accepted before documenting that.
+ ///
+ /// # Arguments
+ /// * `path` - Path to PyTorch tensor file (.pt/.pth)
+ /// * `num_qubits` - Number of qubits
+ /// * `encoding_method` - Strategy: "amplitude", "angle", or "basis"
+ ///
+ /// # Returns
+ /// Single DLPack pointer containing all encoded states (shape:
+ /// [num_samples, 2^num_qubits])
+ ///
+ /// # Errors
+ /// Propagates errors from reading the tensor file (e.g. missing file,
+ /// unsupported rank, `pytorch` feature disabled) and from `encode_batch`.
+ pub fn encode_from_torch(
+ &self,
+ path: &str,
+ num_qubits: usize,
+ encoding_method: &str,
+ ) -> Result<*mut DLManagedTensor> {
+ crate::profile_scope!("Mahout::EncodeFromTorch");
+
+ // Inner block confines the IO profiling scope to the file read only
+ // (presumably profile_scope! lasts until the end of its enclosing block).
+ let (batch_data, num_samples, sample_size) = {
+ crate::profile_scope!("IO::ReadTorchBatch");
+ crate::io::read_torch_batch(path)?
+ };
+
+ self.encode_batch(
+ &batch_data,
+ num_samples,
+ sample_size,
+ num_qubits,
+ encoding_method,
+ )
+ }
+
+
/// Load data from TensorFlow TensorProto file and encode into quantum
state
///
/// Supports Float64 tensors with shape [batch_size, feature_size] or [n].
diff --git a/qdp/qdp-core/src/readers/mod.rs b/qdp/qdp-core/src/readers/mod.rs
index c3ffc6efe..841c8059b 100644
--- a/qdp/qdp-core/src/readers/mod.rs
+++ b/qdp/qdp-core/src/readers/mod.rs
@@ -22,14 +22,18 @@
//! # Fully Implemented Formats
//! - **Parquet**: [`ParquetReader`], [`ParquetStreamingReader`]
//! - **Arrow IPC**: [`ArrowIPCReader`]
+//! - **NumPy**: [`NumpyReader`]
//! - **TensorFlow TensorProto**: [`TensorFlowReader`]
+//! - **PyTorch**: [`TorchReader`] (feature: `pytorch`)
pub mod arrow_ipc;
pub mod numpy;
pub mod parquet;
pub mod tensorflow;
+pub mod torch;
pub use arrow_ipc::ArrowIPCReader;
pub use numpy::NumpyReader;
pub use parquet::{ParquetReader, ParquetStreamingReader};
pub use tensorflow::TensorFlowReader;
+pub use torch::TorchReader;
diff --git a/qdp/qdp-core/src/readers/torch.rs
b/qdp/qdp-core/src/readers/torch.rs
new file mode 100644
index 000000000..7c8507968
--- /dev/null
+++ b/qdp/qdp-core/src/readers/torch.rs
@@ -0,0 +1,182 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
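+//! PyTorch tensor reader implementation.
+//!
+//! Supports `.pt`/`.pth` files containing a single tensor, loaded with
+//! `tch::Tensor::load` (intended for files saved with `torch.save`; the tests
+//! currently only round-trip files written by `tch::Tensor::save`).
+//! The tensor must be 1D or 2D and will be converted to `float64`.
+//! Requires the `pytorch` feature to be enabled; without it, reading returns
+//! a "not implemented" error.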
+
+use std::path::Path;
+
+use crate::error::{MahoutError, Result};
+use crate::reader::DataReader;
+
+/// Reader for PyTorch `.pt`/`.pth` tensor files.
+///
+/// Construction only validates that the file exists; the tensor is loaded by
+/// `DataReader::read_batch`, which may be called once per reader.
+pub struct TorchReader {
+    /// Location of the tensor file.
+    // Only read when the `pytorch` feature is enabled. Without the allow,
+    // default-feature builds warn "field `path` is never read", which fails
+    // builds that deny warnings.
+    #[cfg_attr(not(feature = "pytorch"), allow(dead_code))]
+    path: std::path::PathBuf,
+    /// Set as soon as `read_batch` is called, whether or not it succeeds.
+    read: bool,
+    /// Number of samples from the last successful read, if any.
+    num_samples: Option<usize>,
+    /// Features per sample from the last successful read, if any.
+    sample_size: Option<usize>,
+}
+
+impl TorchReader {
+    /// Creates a reader for the given `.pt`/`.pth` path.
+    ///
+    /// Only checks that the file exists; the tensor itself is loaded later by
+    /// `read_batch`.
+    ///
+    /// # Errors
+    /// Returns `MahoutError::Io` when the file is missing or its existence
+    /// cannot be determined.
+    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
+        let path = path.as_ref();
+
+        let exists = path.try_exists().map_err(|e| {
+            MahoutError::Io(format!(
+                "Failed to check if PyTorch file exists at {}: {}",
+                path.display(),
+                e
+            ))
+        })?;
+        if !exists {
+            return Err(MahoutError::Io(format!(
+                "PyTorch file not found: {}",
+                path.display()
+            )));
+        }
+
+        Ok(Self {
+            path: path.to_path_buf(),
+            read: false,
+            num_samples: None,
+            sample_size: None,
+        })
+    }
+}
+
+impl DataReader for TorchReader {
+ /// Loads the whole tensor file and returns `(data, num_samples, sample_size)`.
+ ///
+ /// Single-use: any call after the first returns `InvalidInput`. The consumed
+ /// flag is set before loading, so a failed load still consumes the reader.
+ /// Without the `pytorch` feature the first call returns `NotImplemented`.
+ fn read_batch(&mut self) -> Result<(Vec<f64>, usize, usize)> {
+ if self.read {
+ return Err(MahoutError::InvalidInput(
+ "Reader already consumed".to_string(),
+ ));
+ }
+ self.read = true;
+
+ // Exactly one of the two cfg-gated blocks below is compiled; whichever
+ // survives becomes the function's tail expression.
+ #[cfg(feature = "pytorch")]
+ {
+ let (data, num_samples, sample_size) = read_torch_tensor(&self.path)?;
+ // Remember the shape so the getters below can report it.
+ self.num_samples = Some(num_samples);
+ self.sample_size = Some(sample_size);
+ Ok((data, num_samples, sample_size))
+ }
+
+ #[cfg(not(feature = "pytorch"))]
+ {
+ Err(MahoutError::NotImplemented(
+ "PyTorch reader requires the 'pytorch' feature".to_string(),
+ ))
+ }
+ }
+
+ /// Features per sample from the last successful `read_batch`, else `None`.
+ fn get_sample_size(&self) -> Option<usize> {
+ self.sample_size
+ }
+
+ /// Number of samples from the last successful `read_batch`, else `None`.
+ fn get_num_samples(&self) -> Option<usize> {
+ self.num_samples
+ }
+}
+
+/// Loads the single tensor stored at `path` and returns it as flat `f64` data.
+///
+/// The tensor's shape is validated first, then it is moved to the CPU, cast to
+/// `Kind::Double` and made contiguous before flattening, so values come back in
+/// row-major order. Returns `(data, num_samples, sample_size)`.
+///
+/// NOTE(review): `tch::Tensor::load` appears to use libtorch's C++
+/// serialization format; confirm that files written by Python's `torch.save`
+/// load correctly (the Rust tests only round-trip `tch::Tensor::save` output).
+#[cfg(feature = "pytorch")]
+fn read_torch_tensor(path: &Path) -> Result<(Vec<f64>, usize, usize)> {
+    use tch::{Device, Kind, Tensor};
+
+    let loaded = Tensor::load(path).map_err(|err| {
+        MahoutError::Io(format!(
+            "Failed to load PyTorch tensor from {}: {}",
+            path.display(),
+            err
+        ))
+    })?;
+
+    // Reject unsupported ranks / non-positive dims before any conversion work.
+    let (num_samples, sample_size) = parse_shape(&loaded.size())?;
+    let as_f64 = loaded
+        .to_device(Device::Cpu)
+        .to_kind(Kind::Double)
+        .contiguous();
+
+    let expected_len = num_samples.checked_mul(sample_size).ok_or_else(|| {
+        MahoutError::InvalidInput(format!(
+            "Tensor shape too large: {} * {} would overflow",
+            num_samples, sample_size
+        ))
+    })?;
+
+    let values = Vec::<f64>::try_from(&as_f64.view([-1])).map_err(|err| {
+        MahoutError::InvalidInput(format!(
+            "Failed to read PyTorch tensor data from {}: {}",
+            path.display(),
+            err
+        ))
+    })?;
+
+    // Defensive consistency check between the parsed shape and the element count.
+    if values.len() == expected_len {
+        Ok((values, num_samples, sample_size))
+    } else {
+        Err(MahoutError::InvalidInput(format!(
+            "Tensor data length mismatch: expected {}, got {}",
+            expected_len,
+            values.len()
+        )))
+    }
+}
+
+/// Maps a tensor shape onto `(num_samples, sample_size)`.
+///
+/// A 1D shape `[n]` is one sample of `n` features; a 2D shape `[rows, cols]`
+/// is `rows` samples of `cols` features. Any other rank, or any dimension that
+/// is not strictly positive, is rejected with `InvalidInput`.
+#[cfg(feature = "pytorch")]
+fn parse_shape(sizes: &[i64]) -> Result<(usize, usize)> {
+    match *sizes {
+        [len] => Ok((1, checked_dim(len, "sample")?)),
+        [rows, cols] => {
+            let num_samples = checked_dim(rows, "batch")?;
+            let sample_size = checked_dim(cols, "feature")?;
+            Ok((num_samples, sample_size))
+        }
+        _ => Err(MahoutError::InvalidInput(format!(
+            "Unsupported tensor rank: {} (only 1D and 2D supported)",
+            sizes.len()
+        ))),
+    }
+}
+
+/// Converts one tensor dimension to `usize`, rejecting zero or negative sizes.
+///
+/// `label` names the dimension (callers pass "sample", "batch" or "feature")
+/// and is used only in error messages.
+#[cfg(feature = "pytorch")]
+fn checked_dim(value: i64, label: &str) -> Result<usize> {
+    if value > 0 {
+        usize::try_from(value).map_err(|_| {
+            MahoutError::InvalidInput(format!(
+                "{} dimension too large to fit in usize: {}",
+                label, value
+            ))
+        })
+    } else {
+        Err(MahoutError::InvalidInput(format!(
+            "Invalid {} dimension size: {}",
+            label, value
+        )))
+    }
+}
diff --git a/qdp/qdp-core/tests/torch_io.rs b/qdp/qdp-core/tests/torch_io.rs
new file mode 100644
index 000000000..eea7e5e85
--- /dev/null
+++ b/qdp/qdp-core/tests/torch_io.rs
@@ -0,0 +1,62 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#[cfg(feature = "pytorch")]
+mod pytorch_tests {
+    use qdp_core::io::read_torch_batch;
+    use qdp_core::reader::DataReader;
+    use qdp_core::readers::TorchReader;
+    use std::fs;
+    use std::path::PathBuf;
+    use tch::Tensor;
+
+    /// Builds a per-process scratch path in the OS temp directory.
+    ///
+    /// A hard-coded `/tmp/...` path does not exist on Windows, and a fixed
+    /// file name can collide with a concurrent run of the same test binary.
+    fn temp_file(name: &str) -> PathBuf {
+        std::env::temp_dir().join(format!("{}_{}", std::process::id(), name))
+    }
+
+    /// A 1D tensor round-trips as a single sample.
+    #[test]
+    fn test_torch_reader_basic_1d() {
+        let temp_path = temp_file("test_torch_basic_1d.pt");
+        let sample_size = 12;
+        let data: Vec<f64> = (0..sample_size).map(|i| i as f64).collect();
+
+        let tensor = Tensor::from_slice(&data);
+        tensor.save(&temp_path).unwrap();
+
+        let mut reader = TorchReader::new(&temp_path).unwrap();
+        let (read_data, read_samples, read_size) = reader.read_batch().unwrap();
+
+        assert_eq!(read_samples, 1);
+        assert_eq!(read_size, sample_size);
+        assert_eq!(read_data, data);
+
+        fs::remove_file(&temp_path).unwrap();
+    }
+
+    /// A 2D tensor round-trips through the `read_torch_batch` helper.
+    #[test]
+    fn test_read_torch_batch_function_2d() {
+        let temp_path = temp_file("test_torch_batch_fn.pt");
+        let num_samples = 4;
+        let sample_size = 3;
+        let data: Vec<f64> = (0..num_samples * sample_size).map(|i| i as f64).collect();
+
+        let tensor = Tensor::from_slice(&data).reshape([num_samples as i64, sample_size as i64]);
+        tensor.save(&temp_path).unwrap();
+
+        let (read_data, read_samples, read_size) = read_torch_batch(&temp_path).unwrap();
+
+        assert_eq!(read_samples, num_samples);
+        assert_eq!(read_size, sample_size);
+        assert_eq!(read_data, data);
+
+        fs::remove_file(&temp_path).unwrap();
+    }
+}
diff --git a/qdp/qdp-python/Cargo.toml b/qdp/qdp-python/Cargo.toml
index cfe598f77..6eee0bcf2 100644
--- a/qdp/qdp-python/Cargo.toml
+++ b/qdp/qdp-python/Cargo.toml
@@ -15,3 +15,4 @@ qdp-core = { path = "../qdp-core" }
[features]
default = []
observability = ["qdp-core/observability"]
+pytorch = ["qdp-core/pytorch"]
diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs
index 1bd25d798..0baa4eb56 100644
--- a/qdp/qdp-python/src/lib.rs
+++ b/qdp/qdp-python/src/lib.rs
@@ -217,8 +217,8 @@ impl QdpEngine {
/// data: Input data - supports:
/// - Python list: [1.0, 2.0, 3.0, 4.0]
/// - NumPy array: 1D (single sample) or 2D (batch) array
- /// - PyTorch tensor: CPU float64 tensor (C-contiguous
recommended; converted via NumPy view)
- /// - String path: .parquet, .arrow, .npy file
+ /// - PyTorch tensor: CPU tensor (float64 recommended; will be
copied to GPU)
+ /// - String path: .parquet, .arrow, .npy, .pt, .pth file
/// - pathlib.Path: Path object (converted via os.fspath())
/// num_qubits: Number of qubits for encoding
/// encoding_method: Encoding strategy ("amplitude" default, "angle",
or "basis")
@@ -444,9 +444,15 @@ impl QdpEngine {
.map_err(|e| {
PyRuntimeError::new_err(format!("Encoding from NumPy
failed: {}", e))
})?
+ } else if path.ends_with(".pt") || path.ends_with(".pth") {
+ self.engine
+ .encode_from_torch(path, num_qubits, encoding_method)
+ .map_err(|e| {
+ PyRuntimeError::new_err(format!("Encoding from PyTorch
failed: {}", e))
+ })?
} else {
return Err(PyRuntimeError::new_err(format!(
- "Unsupported file format. Expected .parquet, .arrow, .feather,
or .npy, got: {}",
+ "Unsupported file format. Expected .parquet, .arrow, .feather,
.npy, .pt, or .pth, got: {}",
path
)));
};
diff --git a/testing/qdp/test_bindings.py b/testing/qdp/test_bindings.py
index 8a3531c3e..9c5acfd56 100644
--- a/testing/qdp/test_bindings.py
+++ b/testing/qdp/test_bindings.py
@@ -198,6 +198,41 @@ def test_encode_tensor_batch():
assert torch_tensor.shape == (3, 4), "Batch encoding should preserve batch
size"
[email protected]
+def test_encode_from_tensorflow_binding():
+ """Test TensorFlow TensorProto binding path (requires GPU and
TensorFlow)."""
+ pytest.importorskip("torch")
+ tf = pytest.importorskip("tensorflow")
+ import numpy as np
+ import torch
+ from _qdp import QdpEngine
+ import os
+ import tempfile
+
+ if not torch.cuda.is_available():
+ pytest.skip("GPU required for QdpEngine")
+
+ engine = QdpEngine(0)
+ num_qubits = 2
+ sample_size = 2**num_qubits
+
+ data = np.array([[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]],
dtype=np.float64)
+ tensor_proto = tf.make_tensor_proto(data, dtype=tf.float64)
+
+ with tempfile.NamedTemporaryFile(suffix=".pb", delete=False) as f:
+ pb_path = f.name
+ f.write(tensor_proto.SerializeToString())
+
+ try:
+ qtensor = engine.encode_from_tensorflow(pb_path, num_qubits,
"amplitude")
+ torch_tensor = torch.from_dlpack(qtensor)
+ assert torch_tensor.is_cuda
+ assert torch_tensor.shape == (2, sample_size)
+ finally:
+ if os.path.exists(pb_path):
+ os.remove(pb_path)
+
+
@pytest.mark.gpu
def test_encode_errors():
"""Test error handling for unified encode method."""