pitrou commented on code in PR #44990:
URL: https://github.com/apache/arrow/pull/44990#discussion_r1880287234


##########
cpp/src/arrow/dataset/file_parquet_encryption_test.cc:
##########
@@ -151,31 +159,56 @@ class DatasetEncryptionTestBase : public ::testing::Test {
     // Create the dataset
     ASSERT_OK_AND_ASSIGN(auto dataset, dataset_factory->Finish());
 
+    std::vector<std::future<Result<std::shared_ptr<Table>>>> threads;
+
+    // Read dataset above multiple times concurrently to see that is 
thread-safe.
     // Reuse the dataset above to scan it twice to make sure decryption works 
correctly.
-    for (size_t i = 0; i < 2; ++i) {
-      // Read dataset into table
-      ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
-      ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
-      ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable());
-
-      // Verify the data was read correctly
-      ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks());
-      // Validate the table
-      ASSERT_OK(combined_table->ValidateFull());
-      AssertTablesEqual(*combined_table, *table_);
+    const size_t attempts = concurrently ? 1000 : 2;
+    for (size_t i = 0; i < attempts; ++i) {
+      if (concurrently) {
+        threads.push_back(std::async(DatasetEncryptionTestBase::read, 
dataset));

Review Comment:
   We don't know how `std::async` is implemented under the hood, or whether it 
   is actually going to use multiple threads.
   
   We do have our own `Future` class that can be used together with a 
   `ThreadPool`. Can we use that instead?



##########
cpp/src/arrow/dataset/file_parquet_encryption_test.cc:
##########
@@ -151,31 +159,56 @@ class DatasetEncryptionTestBase : public ::testing::Test {
     // Create the dataset
     ASSERT_OK_AND_ASSIGN(auto dataset, dataset_factory->Finish());
 
+    std::vector<std::future<Result<std::shared_ptr<Table>>>> threads;
+
+    // Read dataset above multiple times concurrently to see that is 
thread-safe.
     // Reuse the dataset above to scan it twice to make sure decryption works 
correctly.
-    for (size_t i = 0; i < 2; ++i) {
-      // Read dataset into table
-      ASSERT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
-      ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
-      ASSERT_OK_AND_ASSIGN(auto read_table, scanner->ToTable());
-
-      // Verify the data was read correctly
-      ASSERT_OK_AND_ASSIGN(auto combined_table, read_table->CombineChunks());
-      // Validate the table
-      ASSERT_OK(combined_table->ValidateFull());
-      AssertTablesEqual(*combined_table, *table_);
+    const size_t attempts = concurrently ? 1000 : 2;
+    for (size_t i = 0; i < attempts; ++i) {
+      if (concurrently) {
+        threads.push_back(std::async(DatasetEncryptionTestBase::read, 
dataset));
+      } else {
+        ASSERT_OK_AND_ASSIGN(auto read_table, read(dataset));
+        AssertTablesEqual(*read_table, *table_);
+      }
+    }
+    if (concurrently) {
+      for (auto& thread : threads) {
+        ASSERT_OK_AND_ASSIGN(auto read_table, thread.get());
+        AssertTablesEqual(*read_table, *table_);
+      }
     }
   }
 
+  static Result<std::shared_ptr<Table>> read(const std::shared_ptr<Dataset>& 
dataset) {
+    // Read dataset into table
+    ARROW_ASSIGN_OR_RAISE(auto scanner_builder, dataset->NewScan());
+    ARROW_ASSIGN_OR_RAISE(auto scanner, scanner_builder->Finish());
+    ARROW_ASSIGN_OR_RAISE(auto read_table, scanner->ToTable());
+
+    // Verify the data was read correctly
+    ARROW_ASSIGN_OR_RAISE(auto combined_table, read_table->CombineChunks());
+    // Validate the table
+    RETURN_NOT_OK(combined_table->ValidateFull());
+    return combined_table;
+  }
+
  protected:
   std::shared_ptr<fs::FileSystem> file_system_;
   std::shared_ptr<Table> table_;
   std::shared_ptr<Partitioning> partitioning_;
   std::shared_ptr<parquet::encryption::CryptoFactory> crypto_factory_;
   std::shared_ptr<parquet::encryption::KmsConnectionConfig> 
kms_connection_config_;
+
+ private:

Review Comment:
   Nit: these are superfluous
   ```suggestion
   ```



##########
cpp/src/arrow/dataset/file_parquet_encryption_test.cc:
##########
@@ -54,6 +55,9 @@ namespace dataset {
 // Base class to test writing and reading encrypted dataset.
 class DatasetEncryptionTestBase : public ::testing::Test {
  public:
+  explicit DatasetEncryptionTestBase(bool uniform_encryption = false)
+      : uniform_encryption(uniform_encryption) {}

Review Comment:
   Instead of using a constructor argument and explicitly instantiating 
separate subclasses, you could for example use value parameterization: 
https://github.com/google/googletest/blob/main/docs/advanced.md#value-parameterized-tests



##########
cpp/src/parquet/encryption/encryption_internal.cc:
##########
@@ -52,25 +52,85 @@ constexpr int32_t kBufferSizeLength = 4;
     throw ParquetException("Couldn't init ALG decryption");           \
   }
 
-class AesEncryptor::AesEncryptorImpl {
+class AesEncryptionContext {
+ public:
+  AesEncryptionContext(ParquetCipher::type alg_id, int32_t key_len, bool 
metadata,
+                       bool write_length) {
+    openssl::EnsureInitialized();
+
+    length_buffer_length_ = write_length ? kBufferSizeLength : 0;
+    ciphertext_size_delta_ = length_buffer_length_ + kNonceLength;
+    if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) {
+      aes_mode_ = kGcmMode;
+      ciphertext_size_delta_ += kGcmTagLength;
+    } else {
+      aes_mode_ = kCtrMode;
+    }
+
+    if (16 != key_len && 24 != key_len && 32 != key_len) {
+      std::stringstream ss;
+      ss << "Wrong key length: " << key_len;
+      throw ParquetException(ss.str());
+    }
+
+    key_length_ = key_len;
+  };
+
+  virtual ~AesEncryptionContext() = default;
+
+ protected:
+  void InitCipherContext() {
+    if (ctx_) return;
+
+    ctx_ = std::unique_ptr<EVP_CIPHER_CTX, 
decltype(ctxDeleter)>(EVP_CIPHER_CTX_new(),
+                                                                 ctxDeleter);
+    if (!ctx_) throw ParquetException("Couldn't init cipher context");
+    InitCipherContext(ctx_.get());
+  }
+
+  virtual void InitCipherContext(EVP_CIPHER_CTX* ctx) = 0;

Review Comment:
   This sort of indirection to the derived class looks convoluted. Why not 
   reverse the control flow and have the derived class call the base method 
   before doing its own initialization?



##########
cpp/src/parquet/encryption/encryption_internal.cc:
##########
@@ -52,25 +52,85 @@ constexpr int32_t kBufferSizeLength = 4;
     throw ParquetException("Couldn't init ALG decryption");           \
   }
 
-class AesEncryptor::AesEncryptorImpl {
+class AesEncryptionContext {
+ public:
+  AesEncryptionContext(ParquetCipher::type alg_id, int32_t key_len, bool 
metadata,
+                       bool write_length) {
+    openssl::EnsureInitialized();
+
+    length_buffer_length_ = write_length ? kBufferSizeLength : 0;
+    ciphertext_size_delta_ = length_buffer_length_ + kNonceLength;
+    if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) {
+      aes_mode_ = kGcmMode;
+      ciphertext_size_delta_ += kGcmTagLength;
+    } else {
+      aes_mode_ = kCtrMode;
+    }
+
+    if (16 != key_len && 24 != key_len && 32 != key_len) {
+      std::stringstream ss;
+      ss << "Wrong key length: " << key_len;
+      throw ParquetException(ss.str());
+    }
+
+    key_length_ = key_len;
+  };
+
+  virtual ~AesEncryptionContext() = default;
+
+ protected:
+  void InitCipherContext() {
+    if (ctx_) return;
+
+    ctx_ = std::unique_ptr<EVP_CIPHER_CTX, 
decltype(ctxDeleter)>(EVP_CIPHER_CTX_new(),
+                                                                 ctxDeleter);
+    if (!ctx_) throw ParquetException("Couldn't init cipher context");
+    InitCipherContext(ctx_.get());
+  }
+
+  virtual void InitCipherContext(EVP_CIPHER_CTX* ctx) = 0;
+
+  std::function<void(EVP_CIPHER_CTX*)> ctxDeleter = [](EVP_CIPHER_CTX* ctx) {

Review Comment:
   Please let's stick to the coding conventions, e.g.:
   ```suggestion
     static inline std::function<void(EVP_CIPHER_CTX*)> ctx_deleter_ = 
[](EVP_CIPHER_CTX* ctx) {
   ```
   
   (also can make it `static` since it doesn't need to be per instance)



##########
cpp/src/parquet/encryption/internal_file_encryptor.cc:
##########
@@ -50,21 +50,6 @@ 
InternalFileEncryptor::InternalFileEncryptor(FileEncryptionProperties* propertie
   properties_->set_utilized();
 }
 
-void InternalFileEncryptor::WipeOutEncryptionKeys() {
-  properties_->WipeOutEncryptionKeys();

Review Comment:
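   Same here as for `InternalFileDecryptor::WipeOutDecryptionKeys` (see the 
   comment on `internal_file_decryptor.cc`): `properties_->WipeOutEncryptionKeys` 
   should presumably be kept as well.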



##########
cpp/src/parquet/encryption/encryption_internal.cc:
##########
@@ -52,25 +52,85 @@ constexpr int32_t kBufferSizeLength = 4;
     throw ParquetException("Couldn't init ALG decryption");           \
   }
 
-class AesEncryptor::AesEncryptorImpl {
+class AesEncryptionContext {
+ public:
+  AesEncryptionContext(ParquetCipher::type alg_id, int32_t key_len, bool 
metadata,
+                       bool write_length) {
+    openssl::EnsureInitialized();
+
+    length_buffer_length_ = write_length ? kBufferSizeLength : 0;
+    ciphertext_size_delta_ = length_buffer_length_ + kNonceLength;
+    if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) {
+      aes_mode_ = kGcmMode;
+      ciphertext_size_delta_ += kGcmTagLength;
+    } else {
+      aes_mode_ = kCtrMode;
+    }
+
+    if (16 != key_len && 24 != key_len && 32 != key_len) {
+      std::stringstream ss;
+      ss << "Wrong key length: " << key_len;
+      throw ParquetException(ss.str());
+    }
+
+    key_length_ = key_len;
+  };
+
+  virtual ~AesEncryptionContext() = default;
+
+ protected:
+  void InitCipherContext() {

Review Comment:
   Why not do this in the constructor?



##########
cpp/src/arrow/dataset/file_parquet.cc:
##########
@@ -544,7 +545,7 @@ Future<std::shared_ptr<parquet::arrow::FileReader>> 
ParquetFileFormat::GetReader
                       // here we know there are no other waiters on the reader.
                       
std::move(const_cast<std::unique_ptr<parquet::ParquetFileReader>&>(
                           reader)),
-                      std::move(arrow_properties), &arrow_reader));
+                      arrow_properties, &arrow_reader));

Review Comment:
   I think we can keep the `std::move` here.



##########
cpp/src/parquet/encryption/internal_file_decryptor.cc:
##########
@@ -64,16 +64,6 @@ 
InternalFileDecryptor::InternalFileDecryptor(FileDecryptionProperties* propertie
   properties_->set_utilized();
 }
 
-void InternalFileDecryptor::WipeOutDecryptionKeys() {
-  std::lock_guard<std::mutex> lock(mutex_);
-  properties_->WipeOutDecryptionKeys();

Review Comment:
   `properties_->WipeOutDecryptionKeys` will clear the private key, so we 
   should keep it, no?
   (Even though it's currently not reliable: see 
   https://github.com/apache/arrow/issues/31603)



##########
cpp/src/parquet/encryption/encryption_internal.h:
##########
@@ -44,6 +44,8 @@ constexpr int8_t kOffsetIndex = 7;
 constexpr int8_t kBloomFilterHeader = 8;
 constexpr int8_t kBloomFilterBitset = 9;
 
+class AesEncryptionContext;

Review Comment:
   Is this needed? It doesn't seem used below.



##########
cpp/src/parquet/encryption/encryption_internal.cc:
##########
@@ -52,25 +52,85 @@ constexpr int32_t kBufferSizeLength = 4;
     throw ParquetException("Couldn't init ALG decryption");           \
   }
 
-class AesEncryptor::AesEncryptorImpl {
+class AesEncryptionContext {
+ public:
+  AesEncryptionContext(ParquetCipher::type alg_id, int32_t key_len, bool 
metadata,
+                       bool write_length) {
+    openssl::EnsureInitialized();
+
+    length_buffer_length_ = write_length ? kBufferSizeLength : 0;
+    ciphertext_size_delta_ = length_buffer_length_ + kNonceLength;
+    if (metadata || (ParquetCipher::AES_GCM_V1 == alg_id)) {
+      aes_mode_ = kGcmMode;
+      ciphertext_size_delta_ += kGcmTagLength;
+    } else {
+      aes_mode_ = kCtrMode;
+    }
+
+    if (16 != key_len && 24 != key_len && 32 != key_len) {
+      std::stringstream ss;
+      ss << "Wrong key length: " << key_len;
+      throw ParquetException(ss.str());
+    }
+
+    key_length_ = key_len;
+  };
+
+  virtual ~AesEncryptionContext() = default;
+
+ protected:
+  void InitCipherContext() {
+    if (ctx_) return;
+
+    ctx_ = std::unique_ptr<EVP_CIPHER_CTX, 
decltype(ctxDeleter)>(EVP_CIPHER_CTX_new(),
+                                                                 ctxDeleter);
+    if (!ctx_) throw ParquetException("Couldn't init cipher context");
+    InitCipherContext(ctx_.get());
+  }
+
+  virtual void InitCipherContext(EVP_CIPHER_CTX* ctx) = 0;
+
+  std::function<void(EVP_CIPHER_CTX*)> ctxDeleter = [](EVP_CIPHER_CTX* ctx) {
+    EVP_CIPHER_CTX_free(ctx);
+  };
+
+  /// Create a new cipher context that auto-frees
+  /// This duplicates un unused but initialized private context to avoid going 
through
+  /// initialization

Review Comment:
   I agree that ideally we wouldn't do this at all. I'd rather remove any 
   optimization around this first, and re-optimize in later PRs if we find that 
   cipher creation takes a significant amount of time (which seems unlikely).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to