leezu commented on a change in pull request #17841: URL: https://github.com/apache/incubator-mxnet/pull/17841#discussion_r416971594
########## File path: src/io/dataset.cc ########## @@ -0,0 +1,697 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2020 by Contributors + * \file dataset.cc + * \brief High performance datasets implementation + */ +#include <dmlc/parameter.h> +#include <dmlc/recordio.h> +#include <dmlc/io.h> +#include <mxnet/io.h> +#include <mxnet/ndarray.h> +#include <mxnet/tensor_blob.h> + +#include <string> +#include <vector> +#include <algorithm> + +#include "../imperative/cached_op.h" +#include "../imperative/naive_cached_op.h" +#include "../ndarray/ndarray_function.h" + +#if MXNET_USE_OPENCV +#include <opencv2/opencv.hpp> +#include "./opencv_compatibility.h" +#endif // MXNET_USE_OPENCV + +namespace mxnet { +namespace io { + +struct RecordFileDatasetParam : public dmlc::Parameter<RecordFileDatasetParam> { + std::string rec_file; + std::string idx_file; + // declare parameters + DMLC_DECLARE_PARAMETER(RecordFileDatasetParam) { + DMLC_DECLARE_FIELD(rec_file) + .describe("The absolute path of record file."); + DMLC_DECLARE_FIELD(idx_file) + .describe("The path of the idx file."); + } +}; // struct RecordFileDatasetParam + +DMLC_REGISTER_PARAMETER(RecordFileDatasetParam); + +class 
RecordFileDataset final : public Dataset { + public: + explicit RecordFileDataset(const std::vector<std::pair<std::string, std::string> >& kwargs) { + std::vector<std::pair<std::string, std::string> > kwargs_left; + param_.InitAllowUnknown(kwargs); + // open record file for read + dmlc::Stream *stream = dmlc::Stream::Create(param_.rec_file.c_str(), "r"); + reader_ = std::make_shared<dmlc::RecordIOReader>(stream); + stream_.reset(stream); + // read and process idx file + dmlc::Stream *idx_stream = dmlc::Stream::Create(param_.idx_file.c_str(), "r"); + dmlc::istream is(idx_stream); + size_t key, idx; + while (is >> key >> idx) { + idx_[key] = idx; + } + delete idx_stream; + } + + RecordFileDataset* Clone(void) const { + auto other = new RecordFileDataset(std::vector<std::pair<std::string, std::string> >()); + other->param_ = param_; + other->idx_ = idx_; + // do not share the pointer since it's not threadsafe to seek simultaneously + if (reader_ && stream_) { + dmlc::Stream *stream = dmlc::Stream::Create(param_.rec_file.c_str(), "r"); + other->reader_ = std::make_shared<dmlc::RecordIOReader>(stream); + other->stream_.reset(stream); + } + return other; + } + + uint64_t GetLen() const { + return idx_.size(); + } + + bool GetItem(uint64_t idx, std::vector<NDArray>* ret) { + ret->resize(1); + auto& out = (*ret)[0]; + size_t pos = idx_[static_cast<size_t>(idx)]; + { + std::lock_guard<std::mutex> lck(mutex_); + reader_->Seek(pos); + if (reader_->NextRecord(&read_buff_)) { + const char *buf = read_buff_.c_str(); + const size_t size = read_buff_.size(); + out = NDArray(TShape({static_cast<dim_t>(size)}), Context::CPU(), false, mshadow::kInt8); + TBlob dst = out.data(); + RunContext rctx{Context::CPU(), nullptr, nullptr, false}; + mxnet::ndarray::Copy<cpu, cpu>( + TBlob(const_cast<void*>(reinterpret_cast<const void*>(buf)), + out.shape(), cpu::kDevMask, out.dtype(), 0), + &dst, Context::CPU(), Context::CPU(), rctx); + } + } + return true; + } + + private: + /*! 
\brief parameters */ + RecordFileDatasetParam param_; + /*! \brief recordIO context */ + std::shared_ptr<dmlc::RecordIOReader> reader_; Review comment: Why not make this `thread_local`, given the thread-safety issues? Then you may not need to copy datasets around at the level of the sampler. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected]
