andersonm-ibm commented on a change in pull request #10450: URL: https://github.com/apache/arrow/pull/10450#discussion_r700784235
########## File path: python/pyarrow/tests/parquet/test_parquet_encryption.py ########## @@ -0,0 +1,388 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pyarrow as pa +import pytest + +try: + import pyarrow.parquet as pq +except ImportError: + pq = None + +import base64 +from cryptography.fernet import Fernet +from cryptography.fernet import InvalidToken +from collections import OrderedDict +from datetime import timedelta + +DATA_TABLE = pa.Table.from_pydict( + OrderedDict([ + ('a', pa.array([1, 2, 3])), + ('b', pa.array(['a', 'b', 'c'])), + ('c', pa.array(['x', 'y', 'z'])) + ]) +) +PARQUET_NAME = 'encrypted_table.in_mem.parquet' +FOOTER_KEY = Fernet.generate_key() +FOOTER_KEY_NAME = "footer_key" +COL_KEY = Fernet.generate_key() +COL_KEY_NAME = "col_key" +BASIC_ENCRYPTION_CONFIG = pq.EncryptionConfiguration( + footer_key=FOOTER_KEY_NAME, + column_keys={ + COL_KEY_NAME: ["a", "b"], + }, +) + + +class InMemoryKmsClient(pq.KmsClient): + """This is a mock class implementation of KmsClient, built for testing only. + """ + + def __init__(self, config): + """Create an InMemoryKmsClient instance.""" + pq.KmsClient.__init__(self) + self.master_keys_map = config.custom_kms_conf + + def wrap_key(self, key_bytes, master_key_identifier): + """Wrap key key_bytes with key identified by master_key_identifier. + The result contains nonce concatenated before the encrypted key.""" + master_key = self.master_keys_map[master_key_identifier] + # Create a cipher object to encrypt data + cipher = Fernet(master_key.encode('utf-8')) + encrypted_key = cipher.encrypt(key_bytes) + result = base64.b64encode(encrypted_key) + return result + + def unwrap_key(self, wrapped_key, master_key_identifier): + """Unwrap wrapped_key with key identified by master_key_identifier""" + master_key = self.master_keys_map[master_key_identifier] + decoded_wrapped_key = base64.b64decode(wrapped_key) + # Create a cipher object to decrypt data + cipher = Fernet(master_key.encode('utf-8')) + decrypted_key = cipher.decrypt(decoded_wrapped_key) + return decrypted_key + + +def verify_file_encrypted(path): + """Verify that the file is encrypted by looking at its first 4 bytes. + If it's the magic string PARE + then this is a parquet with encrypted footer.""" + with open(path, "rb") as file: + magic_str = file.read(4) + # Verify magic string for parquet with encrypted footer is PARE + assert(magic_str == b'PARE') + + [email protected] +def test_encrypted_parquet_write_read(tempdir): + """Write an encrypted parquet, verify it's encrypted, and then read it.""" + path = tempdir / PARQUET_NAME + table = DATA_TABLE + + # Encrypt the footer with the footer key, + # encrypt column `a` and column `b` with another key, + # keep `c` plaintext + encryption_config = pq.EncryptionConfiguration( + footer_key=FOOTER_KEY_NAME, + column_keys={ + COL_KEY_NAME: ["a", "b"], + }, + encryption_algorithm="AES_GCM_V1", + cache_lifetime=timedelta(minutes=5.0), + data_key_length_bits=256) + + kms_connection_config = pq.KmsConnectionConfig( + custom_kms_conf={ + FOOTER_KEY_NAME: FOOTER_KEY.decode("UTF-8"), + COL_KEY_NAME: COL_KEY.decode("UTF-8"), + } + ) + + def kms_factory(kms_connection_configuration): + return InMemoryKmsClient(kms_connection_configuration) + + crypto_factory = pq.CryptoFactory(kms_factory) + # Write with encryption properties + write_encrypted_parquet(path, table, encryption_config, + kms_connection_config, crypto_factory) + verify_file_encrypted(path) + + # Read with decryption properties + decryption_config = pq.DecryptionConfiguration( + cache_lifetime=timedelta(minutes=5.0)) + result_table = read_encrypted_parquet( + path, decryption_config, kms_connection_config, crypto_factory) + assert table.equals(result_table) + + +def write_encrypted_parquet(path, table, encryption_config, + kms_connection_config, crypto_factory): + file_encryption_properties = crypto_factory.file_encryption_properties( + kms_connection_config, encryption_config) + assert(file_encryption_properties is not None) + with pq.ParquetWriter( + path, table.schema, + encryption_properties=file_encryption_properties) as writer: + writer.write_table(table) + + +def read_encrypted_parquet(path, decryption_config, + kms_connection_config, crypto_factory): + file_decryption_properties = crypto_factory.file_decryption_properties( + kms_connection_config, decryption_config) + assert(file_decryption_properties is not None) + meta = pq.read_metadata( + path, decryption_properties=file_decryption_properties) + assert(meta.num_columns == 3) + schema = pq.read_schema( + path, decryption_properties=file_decryption_properties) + assert(len(schema.names) == 3) + + result = pq.ParquetFile( + path, decryption_properties=file_decryption_properties) + return result.read(use_threads=False) Review comment: Yes, there is separate work on fixing that. Added a special test for multithreaded reads, will enable it after the fix. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
