alamb commented on code in PR #9086:
URL: https://github.com/apache/arrow-rs/pull/9086#discussion_r2678929272


##########
arrow-json/src/reader/struct_array.rs:
##########
@@ -21,13 +21,16 @@ use arrow_array::builder::BooleanBufferBuilder;
 use arrow_buffer::buffer::NullBuffer;
 use arrow_data::{ArrayData, ArrayDataBuilder};
 use arrow_schema::{ArrowError, DataType, Fields};
+use std::collections::HashMap;
 
 pub struct StructArrayDecoder {
     data_type: DataType,
     decoders: Vec<Box<dyn ArrayDecoder>>,
     strict_mode: bool,
     is_nullable: bool,
     struct_mode: StructMode,
+    field_name_to_index: Option<HashMap<String, usize>>,
+    child_pos: Vec<u32>,

Review Comment:
   Could you add a comment that explains what `child_pos` is? It isn't clear from 
here (though the idea of caching it rather than recreating it on every call looks good).
   
   Specifically, I think it is important to document what is stored at each 
index (e.g. that index `field_idx * row_count + row` holds the tape position of field `field_idx` for row `row`).



##########
arrow-json/src/reader/struct_array.rs:
##########
@@ -38,131 +41,171 @@ impl StructArrayDecoder {
         is_nullable: bool,
         struct_mode: StructMode,
     ) -> Result<Self, ArrowError> {
-        let decoders = struct_fields(&data_type)
-            .iter()
-            .map(|f| {
-                // If this struct nullable, need to permit nullability in 
child array
-                // StructArrayDecoder::decode verifies that if the child is 
not nullable
-                // it doesn't contain any nulls not masked by its parent
-                let nullable = f.is_nullable() || is_nullable;
-                make_decoder(
-                    f.data_type().clone(),
-                    coerce_primitive,
-                    strict_mode,
-                    nullable,
-                    struct_mode,
-                )
-            })
-            .collect::<Result<Vec<_>, ArrowError>>()?;
+        let (decoders, field_name_to_index) = {
+            let fields = struct_fields(&data_type);
+            let decoders = fields
+                .iter()
+                .map(|f| {
+                    // If this struct nullable, need to permit nullability in 
child array
+                    // StructArrayDecoder::decode verifies that if the child 
is not nullable
+                    // it doesn't contain any nulls not masked by its parent
+                    let nullable = f.is_nullable() || is_nullable;
+                    make_decoder(
+                        f.data_type().clone(),
+                        coerce_primitive,
+                        strict_mode,
+                        nullable,
+                        struct_mode,
+                    )
+                })
+                .collect::<Result<Vec<_>, ArrowError>>()?;
+            let field_name_to_index = if struct_mode == StructMode::ObjectOnly 
{
+                build_field_index(fields)
+            } else {
+                None
+            };
+            (decoders, field_name_to_index)
+        };
 
         Ok(Self {
             data_type,
             decoders,
             strict_mode,
             is_nullable,
             struct_mode,
+            field_name_to_index,
+            child_pos: Vec::new(),
         })
     }
 }
 
 impl ArrayDecoder for StructArrayDecoder {
     fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, 
ArrowError> {
         let fields = struct_fields(&self.data_type);
-        let mut child_pos: Vec<_> = (0..fields.len()).map(|_| vec![0; 
pos.len()]).collect();
-
+        let row_count = pos.len();
+        let field_count = fields.len();
+        let total_len = field_count.checked_mul(row_count).ok_or_else(|| {
+            ArrowError::JsonError(format!(
+                "StructArrayDecoder child position buffer size overflow for 
rows={row_count} fields={field_count}"
+            ))
+        })?;
+        if total_len > self.child_pos.len() {
+            self.child_pos
+                .try_reserve(total_len - self.child_pos.len())
+                .map_err(|_| {
+                    ArrowError::JsonError(format!(
+                        "StructArrayDecoder child position buffer allocation 
failed for rows={row_count} fields={field_count}"
+                    ))
+                })?;
+        }
+        self.child_pos.resize(total_len, 0);

Review Comment:
   This seems like it would set some elements to zero twice (once by `resize` and 
again by `fill(0)`) -- I think you can get the same result without the redundant writes via
   
   ```rust
   self.child_pos.clear();
   self.child_pos.resize(total_len, 0);
   ```
   
   
   Also, I think `resize` reserves capacity internally (it [internally calls 
extend_with](https://doc.rust-lang.org/src/alloc/vec/mod.rs.html#3416), which 
calls `reserve`), so there is no need to also call `child_pos.try_reserve` above.
   
   (Also, the rest of this crate just calls `reserve`, so using `try_reserve` 
only here seems unnecessary.)



##########
arrow-json/src/reader/struct_array.rs:
##########
@@ -38,131 +41,171 @@ impl StructArrayDecoder {
         is_nullable: bool,
         struct_mode: StructMode,
     ) -> Result<Self, ArrowError> {
-        let decoders = struct_fields(&data_type)
-            .iter()
-            .map(|f| {
-                // If this struct nullable, need to permit nullability in 
child array
-                // StructArrayDecoder::decode verifies that if the child is 
not nullable
-                // it doesn't contain any nulls not masked by its parent
-                let nullable = f.is_nullable() || is_nullable;
-                make_decoder(
-                    f.data_type().clone(),
-                    coerce_primitive,
-                    strict_mode,
-                    nullable,
-                    struct_mode,
-                )
-            })
-            .collect::<Result<Vec<_>, ArrowError>>()?;
+        let (decoders, field_name_to_index) = {
+            let fields = struct_fields(&data_type);
+            let decoders = fields
+                .iter()
+                .map(|f| {
+                    // If this struct nullable, need to permit nullability in 
child array
+                    // StructArrayDecoder::decode verifies that if the child 
is not nullable
+                    // it doesn't contain any nulls not masked by its parent
+                    let nullable = f.is_nullable() || is_nullable;
+                    make_decoder(
+                        f.data_type().clone(),
+                        coerce_primitive,
+                        strict_mode,
+                        nullable,
+                        struct_mode,
+                    )
+                })
+                .collect::<Result<Vec<_>, ArrowError>>()?;
+            let field_name_to_index = if struct_mode == StructMode::ObjectOnly 
{
+                build_field_index(fields)
+            } else {
+                None
+            };
+            (decoders, field_name_to_index)
+        };
 
         Ok(Self {
             data_type,
             decoders,
             strict_mode,
             is_nullable,
             struct_mode,
+            field_name_to_index,
+            child_pos: Vec::new(),
         })
     }
 }
 
 impl ArrayDecoder for StructArrayDecoder {
     fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, 
ArrowError> {
         let fields = struct_fields(&self.data_type);
-        let mut child_pos: Vec<_> = (0..fields.len()).map(|_| vec![0; 
pos.len()]).collect();
-
+        let row_count = pos.len();
+        let field_count = fields.len();
+        let total_len = field_count.checked_mul(row_count).ok_or_else(|| {
+            ArrowError::JsonError(format!(
+                "StructArrayDecoder child position buffer size overflow for 
rows={row_count} fields={field_count}"
+            ))
+        })?;
+        if total_len > self.child_pos.len() {
+            self.child_pos
+                .try_reserve(total_len - self.child_pos.len())
+                .map_err(|_| {
+                    ArrowError::JsonError(format!(
+                        "StructArrayDecoder child position buffer allocation 
failed for rows={row_count} fields={field_count}"
+                    ))
+                })?;
+        }
+        self.child_pos.resize(total_len, 0);
+        self.child_pos.fill(0);
         let mut nulls = self
             .is_nullable
             .then(|| BooleanBufferBuilder::new(pos.len()));
 
-        // We avoid having the match on self.struct_mode inside the hot loop 
for performance
-        // TODO: Investigate how to extract duplicated logic.
-        match self.struct_mode {
-            StructMode::ObjectOnly => {
-                for (row, p) in pos.iter().enumerate() {
-                    let end_idx = match (tape.get(*p), nulls.as_mut()) {
-                        (TapeElement::StartObject(end_idx), None) => end_idx,
-                        (TapeElement::StartObject(end_idx), Some(nulls)) => {
-                            nulls.append(true);
-                            end_idx
-                        }
-                        (TapeElement::Null, Some(nulls)) => {
-                            nulls.append(false);
-                            continue;
-                        }
-                        (_, _) => return Err(tape.error(*p, "{")),
-                    };
-
-                    let mut cur_idx = *p + 1;
-                    while cur_idx < end_idx {
-                        // Read field name
-                        let field_name = match tape.get(cur_idx) {
-                            TapeElement::String(s) => tape.get_string(s),
-                            _ => return Err(tape.error(cur_idx, "field name")),
+        {
+            let child_pos = self.child_pos.as_mut_slice();
+            // We avoid having the match on self.struct_mode inside the hot 
loop for performance
+            // TODO: Investigate how to extract duplicated logic.
+            match self.struct_mode {
+                StructMode::ObjectOnly => {
+                    for (row, p) in pos.iter().enumerate() {
+                        let end_idx = match (tape.get(*p), nulls.as_mut()) {
+                            (TapeElement::StartObject(end_idx), None) => 
end_idx,
+                            (TapeElement::StartObject(end_idx), Some(nulls)) 
=> {
+                                nulls.append(true);
+                                end_idx
+                            }
+                            (TapeElement::Null, Some(nulls)) => {
+                                nulls.append(false);
+                                continue;
+                            }
+                            (_, _) => return Err(tape.error(*p, "{")),
                         };
 
-                        // Update child pos if match found
-                        match fields.iter().position(|x| x.name() == 
field_name) {
-                            Some(field_idx) => child_pos[field_idx][row] = 
cur_idx + 1,
-                            None => {
-                                if self.strict_mode {
-                                    return Err(ArrowError::JsonError(format!(
-                                        "column '{field_name}' missing from 
schema",
-                                    )));
+                        let mut cur_idx = *p + 1;
+                        while cur_idx < end_idx {
+                            // Read field name
+                            let field_name = match tape.get(cur_idx) {
+                                TapeElement::String(s) => tape.get_string(s),
+                                _ => return Err(tape.error(cur_idx, "field 
name")),
+                            };
+
+                            // Update child pos if match found
+                            let field_idx = match &self.field_name_to_index {
+                                Some(map) => map.get(field_name).copied(),
+                                None => fields.iter().position(|x| x.name() == 
field_name),
+                            };
+                            match field_idx {
+                                Some(field_idx) => {
+                                    child_pos[field_idx * row_count + row] = 
cur_idx + 1;
+                                }
+                                None => {
+                                    if self.strict_mode {
+                                        return 
Err(ArrowError::JsonError(format!(
+                                            "column '{field_name}' missing 
from schema",
+                                        )));
+                                    }
                                 }
                             }
+                            // Advance to next field
+                            cur_idx = tape.next(cur_idx + 1, "field value")?;
                         }
-                        // Advance to next field
-                        cur_idx = tape.next(cur_idx + 1, "field value")?;
                     }
                 }
-            }
-            StructMode::ListOnly => {
-                for (row, p) in pos.iter().enumerate() {
-                    let end_idx = match (tape.get(*p), nulls.as_mut()) {
-                        (TapeElement::StartList(end_idx), None) => end_idx,
-                        (TapeElement::StartList(end_idx), Some(nulls)) => {
-                            nulls.append(true);
-                            end_idx
-                        }
-                        (TapeElement::Null, Some(nulls)) => {
-                            nulls.append(false);
-                            continue;
-                        }
-                        (_, _) => return Err(tape.error(*p, "[")),
-                    };
+                StructMode::ListOnly => {
+                    for (row, p) in pos.iter().enumerate() {
+                        let end_idx = match (tape.get(*p), nulls.as_mut()) {
+                            (TapeElement::StartList(end_idx), None) => end_idx,
+                            (TapeElement::StartList(end_idx), Some(nulls)) => {
+                                nulls.append(true);
+                                end_idx
+                            }
+                            (TapeElement::Null, Some(nulls)) => {
+                                nulls.append(false);
+                                continue;
+                            }
+                            (_, _) => return Err(tape.error(*p, "[")),
+                        };
 
-                    let mut cur_idx = *p + 1;
-                    let mut entry_idx = 0;
-                    while cur_idx < end_idx {
-                        if entry_idx >= fields.len() {
+                        let mut cur_idx = *p + 1;
+                        let mut entry_idx = 0;
+                        while cur_idx < end_idx {
+                            if entry_idx >= fields.len() {
+                                return Err(ArrowError::JsonError(format!(
+                                    "found extra columns for {} fields",
+                                    fields.len()
+                                )));
+                            }
+                            child_pos[entry_idx * row_count + row] = cur_idx;

Review Comment:
   👍  this is a nice way to avoid allocations 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to