This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 7e5076f1f7 Replace serde with `serde_core` when possible (#8558)
7e5076f1f7 is described below
commit 7e5076f1f775a6fd08a4d63389e26e2920fe3f6a
Author: Adam Gutglick <[email protected]>
AuthorDate: Sat Oct 11 14:12:02 2025 +0100
Replace serde with `serde_core` when possible (#8558)
# Which issue does this PR close?
- Closes #8451.
With this change, it's possible to compile the core crate without pulling in
`serde_derive`, which will only be required by `arrow-avro` and the
`arrow-schema/serde` feature.
# Rationale for this change
Improve compile times and, in some cases, reduce the number of dependencies
and the binary size.
# What changes are included in this PR?
1. Use `serde_core` when possible
2. Manually implement `Serialize/Deserialize` for canonical extension
type metadata.
# Are these changes tested?
Covered by existing tests
# Are there any user-facing changes?
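No API changes. Note, however, that deserializing `FixedShapeTensor`,
`VariableShapeTensor` and `Opaque` metadata now rejects unknown fields
(the previous derived implementations ignored them), so the error reported
for such malformed metadata changes (see the updated `invalid_metadata` test).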
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
arrow-integration-testing/Cargo.toml | 1 -
arrow-json/Cargo.toml | 2 +-
arrow-json/src/reader/mod.rs | 4 +-
arrow-json/src/reader/serializer.rs | 6 +-
arrow-json/src/reader/tape.rs | 2 +-
arrow-json/src/writer/encoder.rs | 2 +-
arrow-schema/Cargo.toml | 15 ++-
.../src/extension/canonical/fixed_shape_tensor.rs | 149 ++++++++++++++++++++-
arrow-schema/src/extension/canonical/json.rs | 77 ++++++++++-
arrow-schema/src/extension/canonical/opaque.rs | 133 +++++++++++++++++-
.../extension/canonical/variable_shape_tensor.rs | 144 +++++++++++++++++++-
11 files changed, 509 insertions(+), 26 deletions(-)
diff --git a/arrow-integration-testing/Cargo.toml
b/arrow-integration-testing/Cargo.toml
index 35eb47b8d6..ae13d32b57 100644
--- a/arrow-integration-testing/Cargo.toml
+++ b/arrow-integration-testing/Cargo.toml
@@ -40,7 +40,6 @@ arrow-integration-test = { path =
"../arrow-integration-test", default-features
clap = { version = "4", default-features = false, features = ["std", "derive",
"help", "error-context", "usage"] }
futures = { version = "0.3", default-features = false }
prost = { version = "0.14.1", default-features = false }
-serde = { version = "1.0", default-features = false, features = ["rc",
"derive"] }
serde_json = { version = "1.0", default-features = false, features = ["std"] }
tokio = { version = "1.0", default-features = false, features = [
"rt-multi-thread"] }
tonic = { version = "0.14.1", default-features = false }
diff --git a/arrow-json/Cargo.toml b/arrow-json/Cargo.toml
index b7134b170f..2f9e584060 100644
--- a/arrow-json/Cargo.toml
+++ b/arrow-json/Cargo.toml
@@ -44,7 +44,7 @@ arrow-schema = { workspace = true }
half = { version = "2.1", default-features = false }
indexmap = { version = "2.0", default-features = false, features = ["std"] }
num-traits = { version = "0.2.19", default-features = false, features =
["std"] }
-serde = { version = "1.0", default-features = false }
+serde_core = { version = "1.0", default-features = false }
serde_json = { version = "1.0", default-features = false, features = ["std"] }
chrono = { workspace = true }
lexical-core = { version = "1.0", default-features = false}
diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs
index e4658f8653..c47aa65f81 100644
--- a/arrow-json/src/reader/mod.rs
+++ b/arrow-json/src/reader/mod.rs
@@ -138,7 +138,7 @@ use std::io::BufRead;
use std::sync::Arc;
use chrono::Utc;
-use serde::Serialize;
+use serde_core::Serialize;
use arrow_array::timezone::Tz;
use arrow_array::types::*;
@@ -613,6 +613,8 @@ impl Decoder {
/// ```
///
/// Note: this ignores any batch size setting, and always decodes all rows
+ ///
+ /// [serde]: https://docs.rs/serde/latest/serde/
pub fn serialize<S: Serialize>(&mut self, rows: &[S]) -> Result<(),
ArrowError> {
self.tape_decoder.serialize(rows)
}
diff --git a/arrow-json/src/reader/serializer.rs
b/arrow-json/src/reader/serializer.rs
index 95068af678..5d004fbb5c 100644
--- a/arrow-json/src/reader/serializer.rs
+++ b/arrow-json/src/reader/serializer.rs
@@ -17,10 +17,10 @@
use crate::reader::tape::TapeElement;
use lexical_core::FormattedSize;
-use serde::ser::{
+use serde_core::ser::{
Impossible, SerializeMap, SerializeSeq, SerializeStruct, SerializeTuple,
SerializeTupleStruct,
};
-use serde::{Serialize, Serializer};
+use serde_core::{Serialize, Serializer};
#[derive(Debug)]
pub struct SerializerError(String);
@@ -33,7 +33,7 @@ impl std::fmt::Display for SerializerError {
}
}
-impl serde::ser::Error for SerializerError {
+impl serde_core::ser::Error for SerializerError {
fn custom<T>(msg: T) -> Self
where
T: std::fmt::Display,
diff --git a/arrow-json/src/reader/tape.rs b/arrow-json/src/reader/tape.rs
index e3e42ae1cc..89ee3f7787 100644
--- a/arrow-json/src/reader/tape.rs
+++ b/arrow-json/src/reader/tape.rs
@@ -18,7 +18,7 @@
use crate::reader::serializer::TapeSerializer;
use arrow_schema::ArrowError;
use memchr::memchr2;
-use serde::Serialize;
+use serde_core::Serialize;
use std::fmt::Write;
/// We decode JSON to a flattened tape representation,
diff --git a/arrow-json/src/writer/encoder.rs b/arrow-json/src/writer/encoder.rs
index c960da3e07..b562249fc5 100644
--- a/arrow-json/src/writer/encoder.rs
+++ b/arrow-json/src/writer/encoder.rs
@@ -26,7 +26,7 @@ use arrow_cast::display::{ArrayFormatter, FormatOptions};
use arrow_schema::{ArrowError, DataType, FieldRef};
use half::f16;
use lexical_core::FormattedSize;
-use serde::Serializer;
+use serde_core::Serializer;
/// Configuration options for the JSON encoder.
#[derive(Debug, Clone, Default)]
diff --git a/arrow-schema/Cargo.toml b/arrow-schema/Cargo.toml
index e8ca520c3c..99e08e20c2 100644
--- a/arrow-schema/Cargo.toml
+++ b/arrow-schema/Cargo.toml
@@ -33,25 +33,30 @@ name = "arrow_schema"
bench = false
[dependencies]
-serde = { version = "1.0", default-features = false, features = [
- "derive",
+serde_core = { version = "1.0", default-features = false, features = [
"std",
"rc",
], optional = true }
+serde = { version = "1.0", default-features = false, features = [
+ "derive",
+], optional = true }
bitflags = { version = "2.0.0", default-features = false, optional = true }
serde_json = { version = "1.0", optional = true }
[features]
-canonical_extension_types = ["dep:serde", "dep:serde_json"]
+canonical_extension_types = ["dep:serde_core", "dep:serde_json"]
# Enable ffi support
ffi = ["bitflags"]
-serde = ["dep:serde"]
+serde = ["dep:serde_core", "dep:serde"]
[package.metadata.docs.rs]
all-features = true
[dev-dependencies]
-bincode = { version = "2.0.1", default-features = false, features = ["std",
"serde"] }
+bincode = { version = "2.0.1", default-features = false, features = [
+ "std",
+ "serde",
+] }
criterion = { version = "0.5", default-features = false }
insta = "1.43.1"
diff --git a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs
b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs
index 94258123aa..b6bd1c1223 100644
--- a/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs
+++ b/arrow-schema/src/extension/canonical/fixed_shape_tensor.rs
@@ -19,7 +19,10 @@
//!
//!
<https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor>
-use serde::{Deserialize, Serialize};
+use serde_core::de::{self, MapAccess, Visitor};
+use serde_core::ser::SerializeStruct;
+use serde_core::{Deserialize, Deserializer, Serialize, Serializer};
+use std::fmt;
use crate::{ArrowError, DataType, extension::ExtensionType};
@@ -129,7 +132,7 @@ impl FixedShapeTensor {
}
/// Extension type metadata for [`FixedShapeTensor`].
-#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
+#[derive(Debug, Clone, PartialEq)]
pub struct FixedShapeTensorMetadata {
/// The physical shape of the contained tensors.
shape: Vec<usize>,
@@ -141,6 +144,143 @@ pub struct FixedShapeTensorMetadata {
permutations: Option<Vec<usize>>,
}
+/// Hand-written equivalent of `#[derive(Serialize)]`, so this type only needs
+/// `serde_core` rather than `serde_derive`.
+///
+/// Always writes all three fields, in the order `shape`, `dim_names`,
+/// `permutations`; an absent optional field is passed to the serializer as
+/// `None` instead of being skipped.
+impl Serialize for FixedShapeTensorMetadata {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let mut state =
serializer.serialize_struct("FixedShapeTensorMetadata", 3)?;
+        state.serialize_field("shape", &self.shape)?;
+        state.serialize_field("dim_names", &self.dim_names)?;
+        state.serialize_field("permutations", &self.permutations)?;
+        state.end()
+    }
+}
+
+/// Keys recognised when deserializing `FixedShapeTensorMetadata` from a map.
+#[derive(Debug)]
+enum MetadataField {
+    Shape,
+    DimNames,
+    Permutations,
+}
+
+// NOTE(review): the `#[derive(Deserialize)]` this replaces had no
+// `deny_unknown_fields`, so unknown keys used to be ignored; they are now an
+// error.
+/// Turns a string key into a `MetadataField`; any other key produces an
+/// "unknown field" error listing the three accepted names.
+struct MetadataFieldVisitor;
+
+impl<'de> Visitor<'de> for MetadataFieldVisitor {
+    type Value = MetadataField;
+
+    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+        formatter.write_str("`shape`, `dim_names`, or `permutations`")
+    }
+
+    // Only string identifiers are handled; other identifier forms fall back
+    // to the trait's default (an error).
+    fn visit_str<E>(self, value: &str) -> Result<MetadataField, E>
+    where
+        E: de::Error,
+    {
+        match value {
+            "shape" => Ok(MetadataField::Shape),
+            "dim_names" => Ok(MetadataField::DimNames),
+            "permutations" => Ok(MetadataField::Permutations),
+            _ => Err(de::Error::unknown_field(
+                value,
+                &["shape", "dim_names", "permutations"],
+            )),
+        }
+    }
+}
+
+/// Reads a field key via `deserialize_identifier`.
+impl<'de> Deserialize<'de> for MetadataField {
+    fn deserialize<D>(deserializer: D) -> Result<MetadataField, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        deserializer.deserialize_identifier(MetadataFieldVisitor)
+    }
+}
+
+/// Builds a `FixedShapeTensorMetadata` from either a sequence or a map.
+struct FixedShapeTensorMetadataVisitor;
+
+impl<'de> Visitor<'de> for FixedShapeTensorMetadataVisitor {
+    type Value = FixedShapeTensorMetadata;
+
+    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+        formatter.write_str("struct FixedShapeTensorMetadata")
+    }
+
+    /// Reads `shape`, `dim_names` and `permutations`, in that order, from a
+    /// sequence; a short sequence is reported with `invalid_length`.
+    fn visit_seq<V>(self, mut seq: V) -> Result<FixedShapeTensorMetadata, V::Error>
+    where
+        V: de::SeqAccess<'de>,
+    {
+        let shape = seq
+            .next_element()?
+            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
+        let dim_names = seq
+            .next_element()?
+            .ok_or_else(|| de::Error::invalid_length(1, &self))?;
+        let permutations = seq
+            .next_element()?
+            .ok_or_else(|| de::Error::invalid_length(2, &self))?;
+        Ok(FixedShapeTensorMetadata {
+            shape,
+            dim_names,
+            permutations,
+        })
+    }
+
+    /// Reads the metadata from a map keyed by field name.
+    ///
+    /// `shape` is required. `dim_names` and `permutations` are optional: a
+    /// missing key and a key whose value is `null` both give `None`. This is
+    /// what `Serialize` for this type writes for `None`, and it matches the
+    /// `#[derive(Deserialize)]` this impl replaced. Repeated keys are rejected.
+    fn visit_map<V>(self, mut map: V) -> Result<FixedShapeTensorMetadata, V::Error>
+    where
+        V: MapAccess<'de>,
+    {
+        let mut shape = None;
+        // For the optional fields the outer `Option` records whether the key
+        // was seen (for duplicate detection); the inner one is the value,
+        // which may be an explicit `null`.
+        let mut dim_names = None;
+        let mut permutations = None;
+
+        while let Some(key) = map.next_key()? {
+            match key {
+                MetadataField::Shape => {
+                    if shape.is_some() {
+                        return Err(de::Error::duplicate_field("shape"));
+                    }
+                    shape = Some(map.next_value()?);
+                }
+                MetadataField::DimNames => {
+                    if dim_names.is_some() {
+                        return Err(de::Error::duplicate_field("dim_names"));
+                    }
+                    dim_names = Some(map.next_value::<Option<_>>()?);
+                }
+                MetadataField::Permutations => {
+                    if permutations.is_some() {
+                        return Err(de::Error::duplicate_field("permutations"));
+                    }
+                    permutations = Some(map.next_value::<Option<_>>()?);
+                }
+            }
+        }
+
+        let shape = shape.ok_or_else(|| de::Error::missing_field("shape"))?;
+
+        Ok(FixedShapeTensorMetadata {
+            shape,
+            dim_names: dim_names.flatten(),
+            permutations: permutations.flatten(),
+        })
+    }
+}
+
+/// Hand-written equivalent of `#[derive(Deserialize)]`: requests a struct
+/// named `FixedShapeTensorMetadata` with fields `shape`, `dim_names` and
+/// `permutations`, and lets the format supply either a map or a sequence.
+impl<'de> Deserialize<'de> for FixedShapeTensorMetadata {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        deserializer.deserialize_struct(
+            "FixedShapeTensorMetadata",
+            &["shape", "dim_names", "permutations"],
+            FixedShapeTensorMetadataVisitor,
+        )
+    }
+}
+
impl FixedShapeTensorMetadata {
/// Returns metadata for a fixed shape tensor extension type.
///
@@ -377,9 +517,8 @@ mod tests {
}
#[test]
- #[should_panic(
- expected = "FixedShapeTensor metadata deserialization failed: missing
field `shape`"
- )]
+ #[should_panic(expected = "FixedShapeTensor metadata deserialization
failed: \
+ unknown field `not-shape`, expected one of `shape`, `dim_names`,
`permutations`")]
fn invalid_metadata() {
let fixed_shape_tensor =
FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500],
None, None).unwrap();
diff --git a/arrow-schema/src/extension/canonical/json.rs
b/arrow-schema/src/extension/canonical/json.rs
index 3660945104..297a2d99aa 100644
--- a/arrow-schema/src/extension/canonical/json.rs
+++ b/arrow-schema/src/extension/canonical/json.rs
@@ -19,7 +19,10 @@
//!
//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html#json>
-use serde::{Deserialize, Serialize};
+use serde_core::de::{self, MapAccess, Visitor};
+use serde_core::ser::SerializeStruct;
+use serde_core::{Deserialize, Deserializer, Serialize, Serializer};
+use std::fmt;
use crate::{ArrowError, DataType, extension::ExtensionType};
@@ -42,10 +45,78 @@ use crate::{ArrowError, DataType, extension::ExtensionType};
pub struct Json(JsonMetadata);
/// Empty object
-#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize)]
-#[serde(deny_unknown_fields)]
+#[derive(Debug, Clone, Copy, PartialEq)]
struct Empty {}
+/// Serializes `Empty` as a struct with zero fields.
+impl Serialize for Empty {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let state = serializer.serialize_struct("Empty", 0)?;
+        state.end()
+    }
+}
+
+/// Visitor for `Empty`. It replaces the previous
+/// `#[derive(Deserialize)] #[serde(deny_unknown_fields)]`, so any key in a
+/// map is reported as an unknown field.
+struct EmptyVisitor;
+
+impl<'de> Visitor<'de> for EmptyVisitor {
+    type Value = Empty;
+
+    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+        formatter.write_str("struct Empty")
+    }
+
+    // Accepts a sequence without reading any of its elements.
+    fn visit_seq<A>(self, mut _seq: A) -> Result<Self::Value, A::Error>
+    where
+        A: de::SeqAccess<'de>,
+    {
+        Ok(Empty {})
+    }
+
+    // An empty map succeeds; the first key found is rejected because
+    // `Empty` has no fields.
+    fn visit_map<V>(self, mut map: V) -> Result<Empty, V::Error>
+    where
+        V: MapAccess<'de>,
+    {
+        if let Some(key) = map.next_key::<String>()? {
+            return Err(de::Error::unknown_field(&key, EMPTY_FIELDS));
+        }
+        Ok(Empty {})
+    }
+
+    // NOTE(review): these three scalar visitors report the input as an
+    // "unknown field" with an empty name rather than as an invalid type;
+    // confirm that this error wording is intended.
+    fn visit_u64<E>(self, _v: u64) -> Result<Self::Value, E>
+    where
+        E: de::Error,
+    {
+        Err(de::Error::unknown_field("", EMPTY_FIELDS))
+    }
+
+    fn visit_str<E>(self, _v: &str) -> Result<Self::Value, E>
+    where
+        E: de::Error,
+    {
+        Err(de::Error::unknown_field("", EMPTY_FIELDS))
+    }
+
+    fn visit_bytes<E>(self, _v: &[u8]) -> Result<Self::Value, E>
+    where
+        E: de::Error,
+    {
+        Err(de::Error::unknown_field("", EMPTY_FIELDS))
+    }
+}
+
+/// Field list of `Empty`: it has none.
+static EMPTY_FIELDS: &[&str] = &[];
+
+/// Deserializes `Empty` through `deserialize_struct` with no fields.
+impl<'de> Deserialize<'de> for Empty {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        deserializer.deserialize_struct("Empty", EMPTY_FIELDS, EmptyVisitor)
+    }
+}
+
/// Extension type metadata for [`Json`].
#[derive(Debug, Default, Clone, PartialEq)]
pub struct JsonMetadata(Option<Empty>);
diff --git a/arrow-schema/src/extension/canonical/opaque.rs
b/arrow-schema/src/extension/canonical/opaque.rs
index 5aa064e6d3..fceae8d371 100644
--- a/arrow-schema/src/extension/canonical/opaque.rs
+++ b/arrow-schema/src/extension/canonical/opaque.rs
@@ -19,7 +19,11 @@
//!
//! <https://arrow.apache.org/docs/format/CanonicalExtensions.html#opaque>
-use serde::{Deserialize, Serialize};
+use serde_core::ser::SerializeStruct;
+use serde_core::{
+ Deserialize, Deserializer, Serialize, Serializer,
+ de::{MapAccess, Visitor},
+};
use crate::{ArrowError, DataType, extension::ExtensionType};
@@ -61,7 +65,7 @@ impl From<OpaqueMetadata> for Opaque {
}
/// Extension type metadata for [`Opaque`].
-#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
+#[derive(Debug, Clone, PartialEq)]
pub struct OpaqueMetadata {
/// Name of the unknown type in the external system.
type_name: String,
@@ -70,6 +74,131 @@ pub struct OpaqueMetadata {
vendor_name: String,
}
+/// Hand-written equivalent of `#[derive(Serialize)]` (avoids `serde_derive`):
+/// writes `type_name` followed by `vendor_name`.
+impl Serialize for OpaqueMetadata {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let mut state = serializer.serialize_struct("OpaqueMetadata", 2)?;
+        state.serialize_field("type_name", &self.type_name)?;
+        state.serialize_field("vendor_name", &self.vendor_name)?;
+        state.end()
+    }
+}
+
+/// Keys recognised when deserializing `OpaqueMetadata` from a map.
+#[derive(Debug)]
+enum MetadataField {
+    TypeName,
+    VendorName,
+}
+
+// NOTE(review): the `#[derive(Deserialize)]` this replaces had no
+// `deny_unknown_fields`, so unknown keys used to be ignored; they are now an
+// error.
+/// Turns a string key into a `MetadataField`; any other key produces an
+/// "unknown field" error.
+struct MetadataFieldVisitor;
+
+impl<'de> Visitor<'de> for MetadataFieldVisitor {
+    type Value = MetadataField;
+
+    fn expecting(&self, formatter: &mut std::fmt::Formatter) ->
std::fmt::Result {
+        formatter.write_str("`type_name` or `vendor_name`")
+    }
+
+    // Only string identifiers are handled; other identifier forms fall back
+    // to the trait's default (an error).
+    fn visit_str<E>(self, value: &str) -> Result<MetadataField, E>
+    where
+        E: serde_core::de::Error,
+    {
+        match value {
+            "type_name" => Ok(MetadataField::TypeName),
+            "vendor_name" => Ok(MetadataField::VendorName),
+            _ => Err(serde_core::de::Error::unknown_field(
+                value,
+                &["type_name", "vendor_name"],
+            )),
+        }
+    }
+}
+
+/// Reads a field key via `deserialize_identifier`.
+impl<'de> Deserialize<'de> for MetadataField {
+    fn deserialize<D>(deserializer: D) -> Result<MetadataField, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        deserializer.deserialize_identifier(MetadataFieldVisitor)
+    }
+}
+
+/// Builds an `OpaqueMetadata` from either a sequence or a map; both fields
+/// are required.
+struct OpaqueMetadataVisitor;
+
+impl<'de> Visitor<'de> for OpaqueMetadataVisitor {
+    type Value = OpaqueMetadata;
+
+    fn expecting(&self, formatter: &mut std::fmt::Formatter) ->
std::fmt::Result {
+        formatter.write_str("struct OpaqueMetadata")
+    }
+
+    // Sequence form: `type_name` then `vendor_name`; a short sequence is
+    // reported with `invalid_length`.
+    fn visit_seq<V>(self, mut seq: V) -> Result<OpaqueMetadata, V::Error>
+    where
+        V: serde_core::de::SeqAccess<'de>,
+    {
+        let type_name = seq
+            .next_element()?
+            .ok_or_else(|| serde_core::de::Error::invalid_length(0, &self))?;
+        let vendor_name = seq
+            .next_element()?
+            .ok_or_else(|| serde_core::de::Error::invalid_length(1, &self))?;
+        Ok(OpaqueMetadata {
+            type_name,
+            vendor_name,
+        })
+    }
+
+    // Map form: each key may appear at most once, and both must be present.
+    fn visit_map<V>(self, mut map: V) -> Result<OpaqueMetadata, V::Error>
+    where
+        V: MapAccess<'de>,
+    {
+        let mut type_name = None;
+        let mut vendor_name = None;
+
+        while let Some(key) = map.next_key()? {
+            match key {
+                MetadataField::TypeName => {
+                    if type_name.is_some() {
+                        return
Err(serde_core::de::Error::duplicate_field("type_name"));
+                    }
+                    type_name = Some(map.next_value()?);
+                }
+                MetadataField::VendorName => {
+                    if vendor_name.is_some() {
+                        return
Err(serde_core::de::Error::duplicate_field("vendor_name"));
+                    }
+                    vendor_name = Some(map.next_value()?);
+                }
+            }
+        }
+
+        // Both fields are plain `String`s, so a missing key is an error.
+        let type_name =
+            type_name.ok_or_else(||
serde_core::de::Error::missing_field("type_name"))?;
+        let vendor_name =
+            vendor_name.ok_or_else(||
serde_core::de::Error::missing_field("vendor_name"))?;
+
+        Ok(OpaqueMetadata {
+            type_name,
+            vendor_name,
+        })
+    }
+}
+
+/// Hand-written equivalent of `#[derive(Deserialize)]`: requests a struct
+/// named `OpaqueMetadata` with fields `type_name` and `vendor_name`.
+impl<'de> Deserialize<'de> for OpaqueMetadata {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        deserializer.deserialize_struct(
+            "OpaqueMetadata",
+            &["type_name", "vendor_name"],
+            OpaqueMetadataVisitor,
+        )
+    }
+}
+
impl OpaqueMetadata {
/// Returns a new `OpaqueMetadata`.
pub fn new(type_name: impl Into<String>, vendor_name: impl Into<String>)
-> Self {
diff --git a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs
b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs
index 45fde67ee7..b5403dcf68 100644
--- a/arrow-schema/src/extension/canonical/variable_shape_tensor.rs
+++ b/arrow-schema/src/extension/canonical/variable_shape_tensor.rs
@@ -19,7 +19,9 @@
//!
//!
<https://arrow.apache.org/docs/format/CanonicalExtensions.html#variable-shape-tensor>
-use serde::{Deserialize, Serialize};
+use serde_core::de::{self, MapAccess, Visitor};
+use serde_core::{Deserialize, Deserializer, Serialize, Serializer};
+use std::fmt;
use crate::{ArrowError, DataType, Field, extension::ExtensionType};
@@ -140,7 +142,7 @@ impl VariableShapeTensor {
}
/// Extension type metadata for [`VariableShapeTensor`].
-#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
+#[derive(Debug, Clone, PartialEq)]
pub struct VariableShapeTensorMetadata {
/// Explicit names to tensor dimensions.
dim_names: Option<Vec<String>>,
@@ -148,11 +150,147 @@ pub struct VariableShapeTensorMetadata {
/// Indices of the desired ordering of the original dimensions.
permutations: Option<Vec<usize>>,
- /// Sizes of individual tensor’s dimensions which are guaranteed to stay
+ /// Sizes of individual tensor's dimensions which are guaranteed to stay
/// constant in uniform dimensions and can vary in non-uniform dimensions.
uniform_shape: Option<Vec<Option<i32>>>,
}
+/// Hand-written equivalent of `#[derive(Serialize)]` (avoids `serde_derive`).
+///
+/// Always writes all three fields, in the order `dim_names`, `permutations`,
+/// `uniform_shape`; an absent field is passed to the serializer as `None`
+/// instead of being skipped.
+impl Serialize for VariableShapeTensorMetadata {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        use serde_core::ser::SerializeStruct;
+        let mut state =
serializer.serialize_struct("VariableShapeTensorMetadata", 3)?;
+        state.serialize_field("dim_names", &self.dim_names)?;
+        state.serialize_field("permutations", &self.permutations)?;
+        state.serialize_field("uniform_shape", &self.uniform_shape)?;
+        state.end()
+    }
+}
+
+/// Keys recognised when deserializing `VariableShapeTensorMetadata` from a map.
+#[derive(Debug)]
+enum MetadataField {
+    DimNames,
+    Permutations,
+    UniformShape,
+}
+
+// NOTE(review): the `#[derive(Deserialize)]` this replaces had no
+// `deny_unknown_fields`, so unknown keys used to be ignored; they are now an
+// error.
+/// Turns a string key into a `MetadataField`; any other key produces an
+/// "unknown field" error listing the three accepted names.
+struct MetadataFieldVisitor;
+
+impl<'de> Visitor<'de> for MetadataFieldVisitor {
+    type Value = MetadataField;
+
+    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+        formatter.write_str("`dim_names`, `permutations`, or `uniform_shape`")
+    }
+
+    // Only string identifiers are handled; other identifier forms fall back
+    // to the trait's default (an error).
+    fn visit_str<E>(self, value: &str) -> Result<MetadataField, E>
+    where
+        E: de::Error,
+    {
+        match value {
+            "dim_names" => Ok(MetadataField::DimNames),
+            "permutations" => Ok(MetadataField::Permutations),
+            "uniform_shape" => Ok(MetadataField::UniformShape),
+            _ => Err(de::Error::unknown_field(
+                value,
+                &["dim_names", "permutations", "uniform_shape"],
+            )),
+        }
+    }
+}
+
+/// Reads a field key via `deserialize_identifier`.
+impl<'de> Deserialize<'de> for MetadataField {
+    fn deserialize<D>(deserializer: D) -> Result<MetadataField, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        deserializer.deserialize_identifier(MetadataFieldVisitor)
+    }
+}
+
+/// Builds a `VariableShapeTensorMetadata` from either a sequence or a map.
+struct VariableShapeTensorMetadataVisitor;
+
+impl<'de> Visitor<'de> for VariableShapeTensorMetadataVisitor {
+    type Value = VariableShapeTensorMetadata;
+
+    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+        formatter.write_str("struct VariableShapeTensorMetadata")
+    }
+
+    /// Reads `dim_names`, `permutations` and `uniform_shape`, in that order,
+    /// from a sequence; a short sequence is reported with `invalid_length`.
+    fn visit_seq<V>(self, mut seq: V) -> Result<VariableShapeTensorMetadata, V::Error>
+    where
+        V: de::SeqAccess<'de>,
+    {
+        let dim_names = seq
+            .next_element()?
+            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
+        let permutations = seq
+            .next_element()?
+            .ok_or_else(|| de::Error::invalid_length(1, &self))?;
+        let uniform_shape = seq
+            .next_element()?
+            .ok_or_else(|| de::Error::invalid_length(2, &self))?;
+        Ok(VariableShapeTensorMetadata {
+            dim_names,
+            permutations,
+            uniform_shape,
+        })
+    }
+
+    /// Reads the metadata from a map keyed by field name.
+    ///
+    /// Every field is optional: a missing key and a key whose value is `null`
+    /// both give `None`. This is what `Serialize` for this type writes for
+    /// `None`, and it matches the `#[derive(Deserialize)]` this impl replaced.
+    /// Repeated keys are rejected.
+    fn visit_map<V>(self, mut map: V) -> Result<VariableShapeTensorMetadata, V::Error>
+    where
+        V: MapAccess<'de>,
+    {
+        // The outer `Option` records whether the key was seen (for duplicate
+        // detection); the inner one is the value, which may be an explicit
+        // `null`.
+        let mut dim_names = None;
+        let mut permutations = None;
+        let mut uniform_shape = None;
+
+        while let Some(key) = map.next_key()? {
+            match key {
+                MetadataField::DimNames => {
+                    if dim_names.is_some() {
+                        return Err(de::Error::duplicate_field("dim_names"));
+                    }
+                    dim_names = Some(map.next_value::<Option<_>>()?);
+                }
+                MetadataField::Permutations => {
+                    if permutations.is_some() {
+                        return Err(de::Error::duplicate_field("permutations"));
+                    }
+                    permutations = Some(map.next_value::<Option<_>>()?);
+                }
+                MetadataField::UniformShape => {
+                    if uniform_shape.is_some() {
+                        return Err(de::Error::duplicate_field("uniform_shape"));
+                    }
+                    uniform_shape = Some(map.next_value::<Option<_>>()?);
+                }
+            }
+        }
+
+        Ok(VariableShapeTensorMetadata {
+            dim_names: dim_names.flatten(),
+            permutations: permutations.flatten(),
+            uniform_shape: uniform_shape.flatten(),
+        })
+    }
+}
+
+/// Hand-written equivalent of `#[derive(Deserialize)]`: requests a struct
+/// named `VariableShapeTensorMetadata` with fields `dim_names`,
+/// `permutations` and `uniform_shape`.
+impl<'de> Deserialize<'de> for VariableShapeTensorMetadata {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        deserializer.deserialize_struct(
+            "VariableShapeTensorMetadata",
+            &["dim_names", "permutations", "uniform_shape"],
+            VariableShapeTensorMetadataVisitor,
+        )
+    }
+}
+
impl VariableShapeTensorMetadata {
/// Returns metadata for a variable shape tensor extension type.
///