Repository: arrow Updated Branches: refs/heads/master 1f81adcc8 -> 282103012
ARROW-508: [C++] Add basic threadsafety to normal files and memory maps. This patch is stacked on ARROW-494, so it will need to be rebased. * Since the naive `ReadAt` implementation involves a Seek and a Read, this locks until the read is completed. * Normal file reads block until completion * File writes block until completion This covers the threadsafety requirements for parquet-cpp at least. For on-disk files, the following methods are now threadsafe: * `ArrowInputFile::Read` and `ArrowInputFile::ReadAt` * `ArrowOutputStream::Write` parquet-cpp calls `Seek` in a couple of places: https://github.com/apache/parquet-cpp/blob/master/src/parquet/file/reader-internal.cc#L257 Strictly speaking, if two threads are trying to read the same file from the same input source, this could have a race condition in esoteric circumstances. I'm going to report a bug to change these to `ReadAt`, which can be more easily made threadsafe. Author: Wes McKinney <[email protected]> Closes #300 from wesm/ARROW-508 and squashes the following commits: e57156c [Wes McKinney] Make base ReadableFileInterface::ReadAt and some file functions threadsafe Project: http://git-wip-us.apache.org/repos/asf/arrow/repo Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/28210301 Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/28210301 Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/28210301 Branch: refs/heads/master Commit: 2821030124eb3e884b0e48f09c38b54f00430b13 Parents: 1f81adc Author: Wes McKinney <[email protected]> Authored: Mon Jan 23 09:11:26 2017 -0500 Committer: Wes McKinney <[email protected]> Committed: Mon Jan 23 09:11:26 2017 -0500 ---------------------------------------------------------------------- cpp/src/arrow/io/file.cc | 10 ++++- cpp/src/arrow/io/file.h | 9 ++++- cpp/src/arrow/io/interfaces.cc | 3 ++ cpp/src/arrow/io/interfaces.h | 12 +++++- cpp/src/arrow/io/io-file-test.cc | 69 +++++++++++++++++++++++++++++++++++ 5 files changed, 98 insertions(+), 5 deletions(-) 
---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/arrow/blob/28210301/cpp/src/arrow/io/file.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/io/file.cc b/cpp/src/arrow/io/file.cc index 3bf8dfa..ff58e53 100644 --- a/cpp/src/arrow/io/file.cc +++ b/cpp/src/arrow/io/file.cc @@ -76,6 +76,7 @@ #include <cstring> #include <iostream> #include <limits> +#include <mutex> #include <sstream> #include <vector> @@ -350,6 +351,7 @@ class OSFile { } Status Read(int64_t nbytes, int64_t* bytes_read, uint8_t* out) { + std::lock_guard<std::mutex> guard(lock_); return FileRead(fd_, out, nbytes, bytes_read); } @@ -361,6 +363,7 @@ class OSFile { Status Tell(int64_t* pos) const { return FileTell(fd_, pos); } Status Write(const uint8_t* data, int64_t length) { + std::lock_guard<std::mutex> guard(lock_); if (length < 0) { return Status::IOError("Length must be non-negative"); } return FileWrite(fd_, data, length); } @@ -377,6 +380,8 @@ class OSFile { protected: std::string path_; + std::mutex lock_; + // File descriptor int fd_; @@ -649,6 +654,8 @@ bool MemoryMappedFile::supports_zero_copy() const { } Status MemoryMappedFile::WriteAt(int64_t position, const uint8_t* data, int64_t nbytes) { + std::lock_guard<std::mutex> guard(lock_); + if (!memory_map_->opened() || !memory_map_->writable()) { return Status::IOError("Unable to write"); } @@ -658,13 +665,14 @@ Status MemoryMappedFile::WriteAt(int64_t position, const uint8_t* data, int64_t } Status MemoryMappedFile::Write(const uint8_t* data, int64_t nbytes) { + std::lock_guard<std::mutex> guard(lock_); + if (!memory_map_->opened() || !memory_map_->writable()) { return Status::IOError("Unable to write"); } if (nbytes + memory_map_->position() > memory_map_->size()) { return Status::Invalid("Cannot write past end of memory map"); } - return WriteInternal(data, nbytes); } 
http://git-wip-us.apache.org/repos/asf/arrow/blob/28210301/cpp/src/arrow/io/file.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/io/file.h b/cpp/src/arrow/io/file.h index 930346b..fe55e96 100644 --- a/cpp/src/arrow/io/file.h +++ b/cpp/src/arrow/io/file.h @@ -50,6 +50,8 @@ class ARROW_EXPORT FileOutputStream : public OutputStream { // OutputStream interface Status Close() override; Status Tell(int64_t* position) override; + + // Write bytes to the stream. Thread-safe Status Write(const uint8_t* data, int64_t nbytes) override; int file_descriptor() const; @@ -76,6 +78,7 @@ class ARROW_EXPORT ReadableFile : public ReadableFileInterface { Status Close() override; Status Tell(int64_t* position) override; + // Read bytes from the file. Thread-safe Status Read(int64_t nbytes, int64_t* bytes_read, uint8_t* buffer) override; Status Read(int64_t nbytes, std::shared_ptr<Buffer>* out) override; @@ -112,16 +115,18 @@ class ARROW_EXPORT MemoryMappedFile : public ReadWriteFileInterface { Status Seek(int64_t position) override; - // Required by ReadableFileInterface, copies memory into out + // Required by ReadableFileInterface, copies memory into out. Not thread-safe Status Read(int64_t nbytes, int64_t* bytes_read, uint8_t* out) override; - // Zero copy read + // Zero copy read. Not thread-safe Status Read(int64_t nbytes, std::shared_ptr<Buffer>* out) override; bool supports_zero_copy() const override; + /// Write data at the current position in the file. Thread-safe Status Write(const uint8_t* data, int64_t nbytes) override; + /// Write data at a particular position in the file. 
Thread-safe Status WriteAt(int64_t position, const uint8_t* data, int64_t nbytes) override; // @return: the size in bytes of the memory source http://git-wip-us.apache.org/repos/asf/arrow/blob/28210301/cpp/src/arrow/io/interfaces.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/io/interfaces.cc b/cpp/src/arrow/io/interfaces.cc index 8040f93..7e78caa 100644 --- a/cpp/src/arrow/io/interfaces.cc +++ b/cpp/src/arrow/io/interfaces.cc @@ -19,6 +19,7 @@ #include <cstdint> #include <memory> +#include <mutex> #include "arrow/buffer.h" #include "arrow/status.h" @@ -34,12 +35,14 @@ ReadableFileInterface::ReadableFileInterface() { Status ReadableFileInterface::ReadAt( int64_t position, int64_t nbytes, int64_t* bytes_read, uint8_t* out) { + std::lock_guard<std::mutex> guard(lock_); RETURN_NOT_OK(Seek(position)); return Read(nbytes, bytes_read, out); } Status ReadableFileInterface::ReadAt( int64_t position, int64_t nbytes, std::shared_ptr<Buffer>* out) { + std::lock_guard<std::mutex> guard(lock_); RETURN_NOT_OK(Seek(position)); return Read(nbytes, out); } http://git-wip-us.apache.org/repos/asf/arrow/blob/28210301/cpp/src/arrow/io/interfaces.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/io/interfaces.h b/cpp/src/arrow/io/interfaces.h index fdb3788..7868090 100644 --- a/cpp/src/arrow/io/interfaces.h +++ b/cpp/src/arrow/io/interfaces.h @@ -20,6 +20,7 @@ #include <cstdint> #include <memory> +#include <mutex> #include <string> #include "arrow/util/macros.h" @@ -99,14 +100,21 @@ class ARROW_EXPORT ReadableFileInterface : public InputStream, public Seekable { virtual bool supports_zero_copy() const = 0; - // Read at position, provide default implementations using Read(...), but can - // be overridden + /// Read at position, provide default implementations using Read(...), but can + /// be overridden + /// + /// Default implementation is thread-safe virtual Status ReadAt( int64_t 
position, int64_t nbytes, int64_t* bytes_read, uint8_t* out); + /// Default implementation is thread-safe virtual Status ReadAt(int64_t position, int64_t nbytes, std::shared_ptr<Buffer>* out); + std::mutex& lock() { return lock_; } + protected: + std::mutex lock_; + ReadableFileInterface(); }; http://git-wip-us.apache.org/repos/asf/arrow/blob/28210301/cpp/src/arrow/io/io-file-test.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/io/io-file-test.cc b/cpp/src/arrow/io/io-file-test.cc index 999b296..86a3287 100644 --- a/cpp/src/arrow/io/io-file-test.cc +++ b/cpp/src/arrow/io/io-file-test.cc @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +#include <atomic> #include <cstdint> #include <cstdio> #include <cstring> @@ -25,6 +26,7 @@ #include <memory> #include <sstream> #include <string> +#include <thread> #include "gtest/gtest.h" @@ -325,6 +327,40 @@ TEST_F(TestReadableFile, CustomMemoryPool) { ASSERT_EQ(2, pool.num_allocations()); } +TEST_F(TestReadableFile, ThreadSafety) { + std::string data = "foobar"; + { + std::ofstream stream; + stream.open(path_.c_str()); + stream << data; + } + + MyMemoryPool pool; + ASSERT_OK(ReadableFile::Open(path_, &pool, &file_)); + + std::atomic<int> correct_count(0); + const int niter = 10000; + + auto ReadData = [&correct_count, &data, niter, this] () { + std::shared_ptr<Buffer> buffer; + + for (int i = 0; i < niter; ++i) { + ASSERT_OK(file_->ReadAt(0, 3, &buffer)); + if (0 == memcmp(data.c_str(), buffer->data(), 3)) { + correct_count += 1; + } + } + }; + + std::thread thread1(ReadData); + std::thread thread2(ReadData); + + thread1.join(); + thread2.join(); + + ASSERT_EQ(niter * 2, correct_count); +} + // ---------------------------------------------------------------------- // Memory map tests @@ -455,5 +491,38 @@ TEST_F(TestMemoryMappedFile, CastableToFileInterface) { std::shared_ptr<FileInterface> file = memory_mapped_file; } 
+TEST_F(TestMemoryMappedFile, ThreadSafety) { + std::string data = "foobar"; + std::string path = "ipc-multithreading-test"; + CreateFile(path, static_cast<int>(data.size())); + + std::shared_ptr<MemoryMappedFile> file; + ASSERT_OK(MemoryMappedFile::Open(path, FileMode::READWRITE, &file)); + ASSERT_OK(file->Write(reinterpret_cast<const uint8_t*>(data.c_str()), + static_cast<int64_t>(data.size()))); + + std::atomic<int> correct_count(0); + const int niter = 10000; + + auto ReadData = [&correct_count, &data, niter, &file] () { + std::shared_ptr<Buffer> buffer; + + for (int i = 0; i < niter; ++i) { + ASSERT_OK(file->ReadAt(0, 3, &buffer)); + if (0 == memcmp(data.c_str(), buffer->data(), 3)) { + correct_count += 1; + } + } + }; + + std::thread thread1(ReadData); + std::thread thread2(ReadData); + + thread1.join(); + thread2.join(); + + ASSERT_EQ(niter * 2, correct_count); +} + } // namespace io } // namespace arrow
