jecsand838 commented on code in PR #8123:
URL: https://github.com/apache/arrow-rs/pull/8123#discussion_r2277212769
##########
arrow-avro/src/writer/mod.rs:
##########

@@ -0,0 +1,350 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Avro writer implementation for the `arrow-avro` crate.
+//!
+//! # Overview
+//!
+//! * Use **`AvroWriter`** (Object Container File) when you want a
+//!   self‑contained Avro file with header, schema JSON, optional compression,
+//!   blocks, and sync markers.
+//! * Use **`AvroStreamWriter`** (raw binary stream) when you already know the
+//!   schema out‑of‑band (e.g., via a schema registry) and need a stream
+//!   of Avro‑encoded records with minimal framing.
+
+/// Encodes `RecordBatch` into the Avro binary format.
+pub mod encoder;
+/// Logic for different Avro container file formats.
+pub mod format;
+
+use crate::compression::CompressionCodec;
+use crate::schema::AvroSchema;
+use crate::writer::encoder::{encode_record_batch, write_long};
+use crate::writer::format::{AvroBinaryFormat, AvroFormat, AvroOcfFormat};
+use arrow_array::RecordBatch;
+use arrow_schema::{ArrowError, Schema};
+use std::io::{self, Write};
+use std::sync::Arc;
+
+/// Builder to configure and create a `Writer`.
+#[derive(Debug, Clone)]
+pub struct WriterBuilder {
+    schema: Schema,
+    codec: Option<CompressionCodec>,
+}
+
+impl WriterBuilder {
+    /// Create a new builder with default settings.
+    pub fn new(schema: Schema) -> Self {
+        Self {
+            schema,
+            codec: None,
+        }
+    }
+
+    /// Change the compression codec.
+    pub fn with_compression(mut self, codec: Option<CompressionCodec>) -> Self {
+        self.codec = codec;
+        self
+    }
+
+    /// Create a new `Writer` with the specified `AvroFormat` and builder options.
+    pub fn build<W, F>(self, writer: W) -> Writer<W, F>
+    where
+        W: Write,
+        F: AvroFormat,
+    {
+        Writer {
+            writer,
+            schema: Arc::from(self.schema),
+            format: F::default(),
+            compression: self.codec,
+            started: false,
+        }
+    }
+}
+
+/// Generic Avro writer.
+#[derive(Debug)]
+pub struct Writer<W: Write, F: AvroFormat> {
+    writer: W,
+    schema: Arc<Schema>,
+    format: F,
+    compression: Option<CompressionCodec>,
+    started: bool,
+}
+
+/// Alias for an Avro **Object Container File** writer.
+pub type AvroWriter<W> = Writer<W, AvroOcfFormat>;
+/// Alias for a raw Avro **binary stream** writer.
+pub type AvroStreamWriter<W> = Writer<W, AvroBinaryFormat>;
+
+impl<W: Write> Writer<W, AvroOcfFormat> {
+    /// Convenience constructor – same as `WriterBuilder::new(schema).build(writer)`.
+    pub fn new(writer: W, schema: Schema) -> Result<Self, ArrowError> {
+        Ok(WriterBuilder::new(schema).build::<W, AvroOcfFormat>(writer))
+    }
+
+    /// Change the compression codec after construction.
+    pub fn with_compression(mut self, codec: Option<CompressionCodec>) -> Self {
+        self.compression = codec;
+        self
+    }
+
+    /// Return a reference to the 16‑byte sync marker generated for this file.
+    pub fn sync_marker(&self) -> Option<&[u8; 16]> {
+        self.format.sync_marker()
+    }
+}
+
+impl<W: Write> Writer<W, AvroBinaryFormat> {
+    /// Convenience constructor to create a new [`AvroStreamWriter`].
+    pub fn new(writer: W, schema: Schema) -> Result<Self, ArrowError> {
+        Ok(WriterBuilder::new(schema).build::<W, AvroBinaryFormat>(writer))
+    }
+}
+
+impl<W: Write, F: AvroFormat> Writer<W, F> {
+    /// Serialize one [`RecordBatch`] to the output.
+    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
+        if !self.started {
+            self.format
+                .start_stream(&mut self.writer, &self.schema, self.compression)?;
+            self.started = true;
+        }
+        if batch.schema() != self.schema {
+            return Err(ArrowError::SchemaError(
+                "Schema of RecordBatch differs from Writer schema".to_string(),
+            ));
+        }
+        match self.format.sync_marker() {
+            Some(&sync) => self.write_ocf_block(batch, &sync),
+            None => self.write_stream(batch),
+        }
+    }
+
+    /// A convenience method to write a slice of [`RecordBatch`].
+    ///
+    /// This is equivalent to calling `write` for each batch in the slice.
+    pub fn write_batches(&mut self, batches: &[&RecordBatch]) -> Result<(), ArrowError> {
+        for b in batches {
+            self.write(b)?;
+        }
+        Ok(())
+    }
+
+    /// Flush remaining buffered data and (for OCF) ensure the header is present.
+    pub fn finish(&mut self) -> Result<(), ArrowError> {
+        if !self.started {
+            self.format
+                .start_stream(&mut self.writer, &self.schema, self.compression)?;
+            self.started = true;
+        }
+        self.writer
+            .flush()
+            .map_err(|e| ArrowError::IoError(format!("Error flushing writer: {e}"), e))
+    }
+
+    /// Consume the writer, returning the underlying output object.
+    pub fn into_inner(self) -> W {
+        self.writer
+    }
+
+    fn write_ocf_block(&mut self, batch: &RecordBatch, sync: &[u8; 16]) -> Result<(), ArrowError> {
+        let mut buf = Vec::<u8>::with_capacity(1024);
+        encode_record_batch(batch, &mut buf)?;
+        let encoded = match self.compression {
+            Some(codec) => codec.compress(&buf)?,
+            None => buf,
+        };
+        write_long(&mut self.writer, batch.num_rows() as i64)?;
+        write_long(&mut self.writer, encoded.len() as i64)?;
+        self.writer
+            .write_all(&encoded)
+            .map_err(|e| ArrowError::IoError(format!("Error writing Avro block: {e}"), e))?;
+        self.writer
+            .write_all(sync)
+            .map_err(|e| ArrowError::IoError(format!("Error writing Avro sync: {e}"), e))?;
+        Ok(())
+    }
+
+    fn write_stream(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
+        encode_record_batch(batch, &mut self.writer)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::reader::ReaderBuilder;
+    use crate::test_util::arrow_test_data;
+    use arrow_array::{ArrayRef, BinaryArray, Int32Array, RecordBatch, StringArray};
+    use arrow_schema::{DataType, Field, Schema};
+    use std::fs::{remove_file, File};
+    use std::io::BufReader;
+    use std::sync::Arc;
+
+    fn make_schema() -> Schema {
+        Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Binary, false),
+        ])
+    }
+
+    fn make_batch() -> RecordBatch {
+        let ids = Int32Array::from(vec![1, 2, 3]);
+        let names = BinaryArray::from_vec(vec![b"a".as_ref(), b"b".as_ref(), b"c".as_ref()]);
+        RecordBatch::try_new(
+            Arc::new(make_schema()),
+            vec![Arc::new(ids) as ArrayRef, Arc::new(names) as ArrayRef],
+        )
+        .expect("failed to build test RecordBatch")
+    }
+
+    fn contains_ascii(haystack: &[u8], needle: &[u8]) -> bool {
+        haystack.windows(needle.len()).any(|w| w == needle)
+    }
+
+    fn unique_temp_path(prefix: &str) -> std::path::PathBuf {

Review Comment:
I'll make that change in my next PR. Thanks for calling that out.
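
[Editor's note] For reviewers who want to try the OCF path, here is a minimal sketch against the API shown in this diff. It assumes the writer is reachable at `arrow_avro::writer::AvroWriter` (the crate may re-export it under a different path); the schema and batch setup mirrors the `make_batch` test helper above:

    use std::sync::Arc;

    use arrow_array::{ArrayRef, Int32Array, RecordBatch};
    use arrow_avro::writer::AvroWriter; // assumed public path for this new module
    use arrow_schema::{DataType, Field, Schema};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
        )?;

        // Write a self-contained Object Container File into an in-memory buffer.
        // Per the diff, the header is emitted lazily on the first `write` (or by
        // `finish` if nothing was written), and each batch becomes one OCF block
        // framed as: row count, byte length, payload, 16-byte sync marker.
        let mut writer = AvroWriter::new(Vec::<u8>::new(), schema)?;
        writer.write(&batch)?;
        writer.finish()?;

        let bytes = writer.into_inner();
        assert!(!bytes.is_empty());
        Ok(())
    }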
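And the stream-writer counterpart, under the same path assumption. Per the module docs, no OCF header, blocks, or sync markers are produced, so the consumer must already know the writer schema out-of-band (e.g., from a schema registry):

    use std::sync::Arc;

    use arrow_array::{ArrayRef, Int32Array, RecordBatch};
    use arrow_avro::writer::AvroStreamWriter; // assumed public path, as above
    use arrow_schema::{DataType, Field, Schema};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
        )?;

        // Raw Avro binary stream with minimal framing: `write` encodes the
        // records directly to the output instead of wrapping them in OCF blocks.
        let mut writer = AvroStreamWriter::new(Vec::<u8>::new(), schema)?;
        writer.write(&batch)?;
        writer.finish()?;

        let bytes = writer.into_inner();
        println!("wrote {} bytes of raw Avro records", bytes.len());
        Ok(())
    }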