This is an automated email from the ASF dual-hosted git repository.

raulcd pushed a commit to branch maint-16.x.x
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 81940c69e6843c22c2fa61ad05c668c3241b3016
Author: Gang Wu <[email protected]>
AuthorDate: Wed May 8 09:52:57 2024 +0800

    GH-41431: [C++][Parquet][Dataset] Fix repeated scan on encrypted dataset 
(#41550)
    
    ### Rationale for this change
    
    When a Parquet dataset is reused to create multiple scanners, `FileMetaData` 
objects are cached to avoid parsing them again. However, this caused issues with 
encrypted files, because the internal file decryptor was not created when a 
cached `FileMetaData` object was reused.
    
    ### What changes are included in this PR?
    
    Expose file_decryptor from FileMetaData and set it properly.
    
    ### Are these changes tested?
    
    Yes, the test was modified to reproduce the issue and verify the fix.
    
    ### Are there any user-facing changes?
    
    No.
    * GitHub Issue: #41431
    
    Authored-by: Gang Wu <[email protected]>
    Signed-off-by: Gang Wu <[email protected]>
---
 .../arrow/dataset/file_parquet_encryption_test.cc  | 25 ++++---
 cpp/src/parquet/file_reader.cc                     | 83 ++++++++++++----------
 cpp/src/parquet/metadata.cc                        |  8 +++
 cpp/src/parquet/metadata.h                         |  2 +
 4 files changed, 70 insertions(+), 48 deletions(-)

diff --git a/cpp/src/arrow/dataset/file_parquet_encryption_test.cc 
b/cpp/src/arrow/dataset/file_parquet_encryption_test.cc
index 307017fd67..0287d593d1 100644
--- a/cpp/src/arrow/dataset/file_parquet_encryption_test.cc
+++ b/cpp/src/arrow/dataset/file_parquet_encryption_test.cc
@@ -148,17 +148,22 @@ class DatasetEncryptionTestBase : public ::testing::Test {
                          FileSystemDatasetFactory::Make(file_system_, selector,
                                                         file_format, 
factory_options));
 
-    // Read dataset into table
+    // Create the dataset
     ASSERT_OK_AND_ASSIGN(auto dataset, dataset_factory->Finish());
-    ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
-    ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
-    ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable());
-
-    // Verify the data was read correctly
-    ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks());
-    // Validate the table
-    ASSERT_OK(combined_table->ValidateFull());
-    AssertTablesEqual(*combined_table, *table_);
+
+    // Reuse the dataset above to scan it twice to make sure decryption works 
correctly.
+    for (size_t i = 0; i < 2; ++i) {
+      // Read dataset into table
+      ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
+      ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
+      ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable());
+
+      // Verify the data was read correctly
+      ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks());
+      // Validate the table
+      ASSERT_OK(combined_table->ValidateFull());
+      AssertTablesEqual(*combined_table, *table_);
+    }
   }
 
  protected:
diff --git a/cpp/src/parquet/file_reader.cc b/cpp/src/parquet/file_reader.cc
index b3dd1d6054..8fcb0870ce 100644
--- a/cpp/src/parquet/file_reader.cc
+++ b/cpp/src/parquet/file_reader.cc
@@ -215,16 +215,14 @@ class SerializedRowGroup : public 
RowGroupReader::Contents {
                      std::shared_ptr<::arrow::io::internal::ReadRangeCache> 
cached_source,
                      int64_t source_size, FileMetaData* file_metadata,
                      int row_group_number, ReaderProperties props,
-                     std::shared_ptr<Buffer> prebuffered_column_chunks_bitmap,
-                     std::shared_ptr<InternalFileDecryptor> file_decryptor = 
nullptr)
+                     std::shared_ptr<Buffer> prebuffered_column_chunks_bitmap)
       : source_(std::move(source)),
         cached_source_(std::move(cached_source)),
         source_size_(source_size),
         file_metadata_(file_metadata),
         properties_(std::move(props)),
         row_group_ordinal_(row_group_number),
-        
prebuffered_column_chunks_bitmap_(std::move(prebuffered_column_chunks_bitmap)),
-        file_decryptor_(std::move(file_decryptor)) {
+        
prebuffered_column_chunks_bitmap_(std::move(prebuffered_column_chunks_bitmap)) {
     row_group_metadata_ = file_metadata->RowGroup(row_group_number);
   }
 
@@ -263,10 +261,10 @@ class SerializedRowGroup : public 
RowGroupReader::Contents {
     }
 
     // The column is encrypted
-    std::shared_ptr<Decryptor> meta_decryptor =
-        GetColumnMetaDecryptor(crypto_metadata.get(), file_decryptor_.get());
-    std::shared_ptr<Decryptor> data_decryptor =
-        GetColumnDataDecryptor(crypto_metadata.get(), file_decryptor_.get());
+    std::shared_ptr<Decryptor> meta_decryptor = GetColumnMetaDecryptor(
+        crypto_metadata.get(), file_metadata_->file_decryptor().get());
+    std::shared_ptr<Decryptor> data_decryptor = GetColumnDataDecryptor(
+        crypto_metadata.get(), file_metadata_->file_decryptor().get());
     ARROW_DCHECK_NE(meta_decryptor, nullptr);
     ARROW_DCHECK_NE(data_decryptor, nullptr);
 
@@ -291,7 +289,6 @@ class SerializedRowGroup : public RowGroupReader::Contents {
   ReaderProperties properties_;
   int row_group_ordinal_;
   const std::shared_ptr<const Buffer> prebuffered_column_chunks_bitmap_;
-  std::shared_ptr<InternalFileDecryptor> file_decryptor_;
 };
 
 // ----------------------------------------------------------------------
@@ -316,7 +313,9 @@ class SerializedFile : public ParquetFileReader::Contents {
   }
 
   void Close() override {
-    if (file_decryptor_) file_decryptor_->WipeOutDecryptionKeys();
+    if (file_metadata_ && file_metadata_->file_decryptor()) {
+      file_metadata_->file_decryptor()->WipeOutDecryptionKeys();
+    }
   }
 
   std::shared_ptr<RowGroupReader> GetRowGroup(int i) override {
@@ -330,7 +329,7 @@ class SerializedFile : public ParquetFileReader::Contents {
 
     std::unique_ptr<SerializedRowGroup> contents = 
std::make_unique<SerializedRowGroup>(
         source_, cached_source_, source_size_, file_metadata_.get(), i, 
properties_,
-        std::move(prebuffered_column_chunks_bitmap), file_decryptor_);
+        std::move(prebuffered_column_chunks_bitmap));
     return std::make_shared<RowGroupReader>(std::move(contents));
   }
 
@@ -346,8 +345,9 @@ class SerializedFile : public ParquetFileReader::Contents {
           "forget to call ParquetFileReader::Open() first?");
     }
     if (!page_index_reader_) {
-      page_index_reader_ = PageIndexReader::Make(source_.get(), file_metadata_,
-                                                 properties_, 
file_decryptor_.get());
+      page_index_reader_ =
+          PageIndexReader::Make(source_.get(), file_metadata_, properties_,
+                                file_metadata_->file_decryptor().get());
     }
     return page_index_reader_;
   }
@@ -362,8 +362,8 @@ class SerializedFile : public ParquetFileReader::Contents {
           "forget to call ParquetFileReader::Open() first?");
     }
     if (!bloom_filter_reader_) {
-      bloom_filter_reader_ =
-          BloomFilterReader::Make(source_, file_metadata_, properties_, 
file_decryptor_);
+      bloom_filter_reader_ = BloomFilterReader::Make(source_, file_metadata_, 
properties_,
+                                                     
file_metadata_->file_decryptor());
       if (bloom_filter_reader_ == nullptr) {
         throw ParquetException("Cannot create BloomFilterReader");
       }
@@ -441,10 +441,12 @@ class SerializedFile : public ParquetFileReader::Contents 
{
     // Parse the footer depending on encryption type
     const bool is_encrypted_footer =
         memcmp(footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 
4) == 0;
+    std::shared_ptr<InternalFileDecryptor> file_decryptor;
     if (is_encrypted_footer) {
       // Encrypted file with Encrypted footer.
       const std::pair<int64_t, uint32_t> read_size =
-          ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, 
metadata_len);
+          ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, 
metadata_len,
+                                                          &file_decryptor);
       // Read the actual footer
       metadata_start = read_size.first;
       metadata_len = read_size.second;
@@ -453,8 +455,8 @@ class SerializedFile : public ParquetFileReader::Contents {
       // Fall through
     }
 
-    const uint32_t read_metadata_len =
-        ParseUnencryptedFileMetadata(metadata_buffer, metadata_len);
+    const uint32_t read_metadata_len = ParseUnencryptedFileMetadata(
+        metadata_buffer, metadata_len, std::move(file_decryptor));
     auto file_decryption_properties = 
properties_.file_decryption_properties().get();
     if (is_encrypted_footer) {
       // Nothing else to do here.
@@ -550,34 +552,37 @@ class SerializedFile : public ParquetFileReader::Contents 
{
     // Parse the footer depending on encryption type
     const bool is_encrypted_footer =
         memcmp(footer_buffer->data() + footer_read_size - 4, kParquetEMagic, 
4) == 0;
+    std::shared_ptr<InternalFileDecryptor> file_decryptor;
     if (is_encrypted_footer) {
       // Encrypted file with Encrypted footer.
       std::pair<int64_t, uint32_t> read_size;
       BEGIN_PARQUET_CATCH_EXCEPTIONS
-      read_size =
-          ParseMetaDataOfEncryptedFileWithEncryptedFooter(metadata_buffer, 
metadata_len);
+      read_size = ParseMetaDataOfEncryptedFileWithEncryptedFooter(
+          metadata_buffer, metadata_len, &file_decryptor);
       END_PARQUET_CATCH_EXCEPTIONS
       // Read the actual footer
       int64_t metadata_start = read_size.first;
       metadata_len = read_size.second;
       return source_->ReadAsync(metadata_start, metadata_len)
-          .Then([this, metadata_len, is_encrypted_footer](
+          .Then([this, metadata_len, is_encrypted_footer, file_decryptor](
                     const std::shared_ptr<::arrow::Buffer>& metadata_buffer) {
             // Continue and read the file footer
-            return ParseMetaDataFinal(metadata_buffer, metadata_len, 
is_encrypted_footer);
+            return ParseMetaDataFinal(metadata_buffer, metadata_len, 
is_encrypted_footer,
+                                      file_decryptor);
           });
     }
     return ParseMetaDataFinal(std::move(metadata_buffer), metadata_len,
-                              is_encrypted_footer);
+                              is_encrypted_footer, std::move(file_decryptor));
   }
 
   // Continuation
-  ::arrow::Status ParseMetaDataFinal(std::shared_ptr<::arrow::Buffer> 
metadata_buffer,
-                                     uint32_t metadata_len,
-                                     const bool is_encrypted_footer) {
+  ::arrow::Status ParseMetaDataFinal(
+      std::shared_ptr<::arrow::Buffer> metadata_buffer, uint32_t metadata_len,
+      const bool is_encrypted_footer,
+      std::shared_ptr<InternalFileDecryptor> file_decryptor) {
     BEGIN_PARQUET_CATCH_EXCEPTIONS
-    const uint32_t read_metadata_len =
-        ParseUnencryptedFileMetadata(metadata_buffer, metadata_len);
+    const uint32_t read_metadata_len = ParseUnencryptedFileMetadata(
+        metadata_buffer, metadata_len, std::move(file_decryptor));
     auto file_decryption_properties = 
properties_.file_decryption_properties().get();
     if (is_encrypted_footer) {
       // Nothing else to do here.
@@ -608,11 +613,11 @@ class SerializedFile : public ParquetFileReader::Contents 
{
   // Maps row group ordinal and prebuffer status of its column chunks in the 
form of a
   // bitmap buffer.
   std::unordered_map<int, std::shared_ptr<Buffer>> prebuffered_column_chunks_;
-  std::shared_ptr<InternalFileDecryptor> file_decryptor_;
 
   // \return The true length of the metadata in bytes
-  uint32_t ParseUnencryptedFileMetadata(const std::shared_ptr<Buffer>& 
footer_buffer,
-                                        const uint32_t metadata_len);
+  uint32_t ParseUnencryptedFileMetadata(
+      const std::shared_ptr<Buffer>& footer_buffer, const uint32_t 
metadata_len,
+      std::shared_ptr<InternalFileDecryptor> file_decryptor);
 
   std::string HandleAadPrefix(FileDecryptionProperties* 
file_decryption_properties,
                               EncryptionAlgorithm& algo);
@@ -624,11 +629,13 @@ class SerializedFile : public ParquetFileReader::Contents 
{
 
   // \return The position and size of the actual footer
   std::pair<int64_t, uint32_t> ParseMetaDataOfEncryptedFileWithEncryptedFooter(
-      const std::shared_ptr<Buffer>& crypto_metadata_buffer, uint32_t 
footer_len);
+      const std::shared_ptr<Buffer>& crypto_metadata_buffer, uint32_t 
footer_len,
+      std::shared_ptr<InternalFileDecryptor>* file_decryptor);
 };
 
 uint32_t SerializedFile::ParseUnencryptedFileMetadata(
-    const std::shared_ptr<Buffer>& metadata_buffer, const uint32_t 
metadata_len) {
+    const std::shared_ptr<Buffer>& metadata_buffer, const uint32_t 
metadata_len,
+    std::shared_ptr<InternalFileDecryptor> file_decryptor) {
   if (metadata_buffer->size() != metadata_len) {
     throw ParquetException("Failed reading metadata buffer (requested " +
                            std::to_string(metadata_len) + " bytes but got " +
@@ -637,7 +644,7 @@ uint32_t SerializedFile::ParseUnencryptedFileMetadata(
   uint32_t read_metadata_len = metadata_len;
   // The encrypted read path falls through to here, so pass in the decryptor
   file_metadata_ = FileMetaData::Make(metadata_buffer->data(), 
&read_metadata_len,
-                                      properties_, file_decryptor_);
+                                      properties_, std::move(file_decryptor));
   return read_metadata_len;
 }
 
@@ -645,7 +652,7 @@ std::pair<int64_t, uint32_t>
 SerializedFile::ParseMetaDataOfEncryptedFileWithEncryptedFooter(
     const std::shared_ptr<::arrow::Buffer>& crypto_metadata_buffer,
     // both metadata & crypto metadata length
-    const uint32_t footer_len) {
+    const uint32_t footer_len, std::shared_ptr<InternalFileDecryptor>* 
file_decryptor) {
   // encryption with encrypted footer
   // Check if the footer_buffer contains the entire metadata
   if (crypto_metadata_buffer->size() != footer_len) {
@@ -664,7 +671,7 @@ 
SerializedFile::ParseMetaDataOfEncryptedFileWithEncryptedFooter(
   // Handle AAD prefix
   EncryptionAlgorithm algo = file_crypto_metadata->encryption_algorithm();
   std::string file_aad = HandleAadPrefix(file_decryption_properties, algo);
-  file_decryptor_ = std::make_shared<InternalFileDecryptor>(
+  *file_decryptor = std::make_shared<InternalFileDecryptor>(
       file_decryption_properties, file_aad, algo.algorithm,
       file_crypto_metadata->key_metadata(), properties_.memory_pool());
 
@@ -683,12 +690,12 @@ void 
SerializedFile::ParseMetaDataOfEncryptedFileWithPlaintextFooter(
     EncryptionAlgorithm algo = file_metadata_->encryption_algorithm();
     // Handle AAD prefix
     std::string file_aad = HandleAadPrefix(file_decryption_properties, algo);
-    file_decryptor_ = std::make_shared<InternalFileDecryptor>(
+    auto file_decryptor = std::make_shared<InternalFileDecryptor>(
         file_decryption_properties, file_aad, algo.algorithm,
         file_metadata_->footer_signing_key_metadata(), 
properties_.memory_pool());
     // set the InternalFileDecryptor in the metadata as well, as it's used
     // for signature verification and for ColumnChunkMetaData creation.
-    file_metadata_->set_file_decryptor(file_decryptor_);
+    file_metadata_->set_file_decryptor(std::move(file_decryptor));
 
     if (file_decryption_properties->check_plaintext_footer_integrity()) {
       if (metadata_len - read_metadata_len !=
diff --git a/cpp/src/parquet/metadata.cc b/cpp/src/parquet/metadata.cc
index 3f101b5ae3..b24883cdc1 100644
--- a/cpp/src/parquet/metadata.cc
+++ b/cpp/src/parquet/metadata.cc
@@ -826,6 +826,10 @@ class FileMetaData::FileMetaDataImpl {
     file_decryptor_ = std::move(file_decryptor);
   }
 
+  const std::shared_ptr<InternalFileDecryptor>& file_decryptor() const {
+    return file_decryptor_;
+  }
+
  private:
   friend FileMetaDataBuilder;
   uint32_t metadata_len_ = 0;
@@ -947,6 +951,10 @@ void FileMetaData::set_file_decryptor(
   impl_->set_file_decryptor(std::move(file_decryptor));
 }
 
+const std::shared_ptr<InternalFileDecryptor>& FileMetaData::file_decryptor() 
const {
+  return impl_->file_decryptor();
+}
+
 ParquetVersion::type FileMetaData::version() const {
   switch (impl_->version()) {
     case 1:
diff --git a/cpp/src/parquet/metadata.h b/cpp/src/parquet/metadata.h
index 640b898024..9fc30df58e 100644
--- a/cpp/src/parquet/metadata.h
+++ b/cpp/src/parquet/metadata.h
@@ -399,12 +399,14 @@ class PARQUET_EXPORT FileMetaData {
  private:
   friend FileMetaDataBuilder;
   friend class SerializedFile;
+  friend class SerializedRowGroup;
 
   explicit FileMetaData(const void* serialized_metadata, uint32_t* 
metadata_len,
                         const ReaderProperties& properties,
                         std::shared_ptr<InternalFileDecryptor> file_decryptor 
= NULLPTR);
 
   void set_file_decryptor(std::shared_ptr<InternalFileDecryptor> 
file_decryptor);
+  const std::shared_ptr<InternalFileDecryptor>& file_decryptor() const;
 
   // PIMPL Idiom
   FileMetaData();

Reply via email to