leezu commented on a change in pull request #17841: Gluon data 2.0: c++ dataloader and built-in image/bbox transforms URL: https://github.com/apache/incubator-mxnet/pull/17841#discussion_r395925754
########## File path: src/io/dataset.cc ########## @@ -0,0 +1,696 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file dataset.cc + * \brief High performance datasets implementation + */ +#include <string> +#include <vector> +#include <algorithm> + +#include <dmlc/parameter.h> +#include <dmlc/recordio.h> +#include <dmlc/io.h> +#include <mxnet/io.h> +#include <mxnet/ndarray.h> +#include <mxnet/tensor_blob.h> + +#include "../imperative/cached_op.h" +#include "../imperative/naive_cached_op.h" +#include "../ndarray/ndarray_function.h" + +#if MXNET_USE_OPENCV +#include <opencv2/opencv.hpp> +#include "./opencv_compatibility.h" +#endif // MXNET_USE_OPENCV + +namespace mxnet { +namespace io { + +struct RecordFileDatasetParam : public dmlc::Parameter<RecordFileDatasetParam> { + std::string rec_file; + std::string idx_file; + // declare parameters + DMLC_DECLARE_PARAMETER(RecordFileDatasetParam) { + DMLC_DECLARE_FIELD(rec_file) + .describe("The absolute path of record file."); + DMLC_DECLARE_FIELD(idx_file) + .describe("The path of the idx file."); + } +}; // struct RecordFileDatasetParam + +DMLC_REGISTER_PARAMETER(RecordFileDatasetParam); + +class 
RecordFileDataset : public Dataset { + public: + RecordFileDataset* Clone(void) const { + auto other = new RecordFileDataset(); + other->param_ = param_; + other->idx_ = idx_; + // do not share the pointer since it's not threadsafe to seek simultaneously + if (reader_ && stream_) { + dmlc::Stream *stream = dmlc::Stream::Create(param_.rec_file.c_str(), "r"); + other->reader_ = std::make_shared<dmlc::RecordIOReader>(stream); + other->stream_.reset(stream); + } + return other; + } + + void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) { + std::vector<std::pair<std::string, std::string> > kwargs_left; + param_.InitAllowUnknown(kwargs); + // open record file for read + dmlc::Stream *stream = dmlc::Stream::Create(param_.rec_file.c_str(), "r"); + reader_ = std::make_shared<dmlc::RecordIOReader>(stream); + stream_.reset(stream); + // read and process idx file + dmlc::Stream *idx_stream = dmlc::Stream::Create(param_.idx_file.c_str(), "r"); + dmlc::istream is(idx_stream); + size_t key, idx; + while (is >> key >> idx) { + idx_[key] = idx; + } + delete idx_stream; + } + + uint64_t GetLen() const { + return idx_.size(); + } + + bool GetItem(uint64_t idx, std::vector<NDArray>& ret) { + ret.resize(1); + auto& out = ret[0]; + size_t pos = idx_[static_cast<size_t>(idx)]; + { + std::lock_guard<std::mutex> lck(mutex_); + reader_->Seek(pos); + if (reader_->NextRecord(&read_buff_)) { + const char *buf = read_buff_.c_str(); + const size_t size = read_buff_.size(); + out = NDArray(TShape({static_cast<dim_t>(size)}), Context::CPU(), false, mshadow::kInt8); + TBlob dst = out.data(); + RunContext rctx{Context::CPU(), nullptr, nullptr, false}; + mxnet::ndarray::Copy<cpu, cpu>( + TBlob((void*)buf, out.shape(), cpu::kDevMask, out.dtype(), 0), + &dst, Context::CPU(), Context::CPU(), rctx); + } + } + return true; + } + + private: + /*! \brief parameters */ + RecordFileDatasetParam param_; + /*! 
\brief recordIO context */ + std::shared_ptr<dmlc::RecordIOReader> reader_; + std::shared_ptr<dmlc::Stream> stream_; + std::string read_buff_; + std::mutex mutex_; + /*! \brief indices */ + std::unordered_map<size_t, size_t> idx_; +}; + +MXNET_REGISTER_IO_DATASET(RecordFileDataset) + .describe("MXNet Record File Dataset") + .add_arguments(RecordFileDatasetParam::__FIELDS__()) + .set_body([]() { + return new RecordFileDataset(); +}); + +struct ImageRecordFileDatasetParam : public dmlc::Parameter<ImageRecordFileDatasetParam> { + std::string rec_file; + std::string idx_file; + int flag; + // declare parameters + DMLC_DECLARE_PARAMETER(ImageRecordFileDatasetParam) { + DMLC_DECLARE_FIELD(rec_file) + .describe("The absolute path of record file."); + DMLC_DECLARE_FIELD(idx_file) + .describe("The path of the idx file."); + DMLC_DECLARE_FIELD(flag).set_default(1) + .describe("If 1, always convert to colored, if 0 always convert to grayscale."); + } +}; // struct ImageRecordFileDatasetParam + +DMLC_REGISTER_PARAMETER(ImageRecordFileDatasetParam); + +#if MXNET_USE_OPENCV +template<int n_channels> +void SwapImageChannels(cv::Mat &img, NDArray& arr) { + int swap_indices[n_channels]; // NOLINT(*) + if (n_channels == 1) { + swap_indices[0] = 0; + } else if (n_channels == 3) { + swap_indices[0] = 2; + swap_indices[1] = 1; + swap_indices[2] = 0; + } else if (n_channels == 4) { + swap_indices[0] = 2; + swap_indices[1] = 1; + swap_indices[2] = 0; + swap_indices[3] = 3; + } + + TShape arr_shape = TShape({img.rows, img.cols, n_channels}); + if (arr.is_none() || arr.shape() != arr_shape || arr.ctx() != mxnet::Context::CPU(0) || + arr.dtype() != mshadow::kUint8 || arr.storage_type() != kDefaultStorage) { + arr = NDArray(arr_shape, mxnet::Context::CPU(0), false, mshadow::kUint8); + } + auto ptr = static_cast<uint8_t*>(arr.data().dptr_); + + // swap channels while copying elements into buffer + for (int i = 0; i < img.rows; ++i) { + const uint8_t* im_data = img.ptr<uint8_t>(i); + uint8_t* 
buffer_data = ptr + i * img.cols * n_channels; + for (int j = 0; j < img.cols; ++j) { + for (int k = 0; k < n_channels; ++k) { + buffer_data[k] = im_data[swap_indices[k]]; + } + im_data += n_channels; + buffer_data += n_channels; + } + } +} +#endif + +/*! \brief Struct for unpack recordio header */ +#pragma pack(1) +struct IRHeader { + uint32_t flag; + float label; + uint64_t id; + uint64_t id2; +}; // struct IRHeader + +class ImageRecordFileDataset : public Dataset { + public: + ImageRecordFileDataset* Clone(void) const { + auto other = new ImageRecordFileDataset(); + other->param_ = param_; + other->base_.reset(base_->Clone()); + return other; + } + + void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) { + std::vector<std::pair<std::string, std::string> > kwargs_left; + param_.InitAllowUnknown(kwargs); + base_ = std::make_shared<RecordFileDataset>(); + base_->Init(kwargs); + } + + uint64_t GetLen() const { + return base_->GetLen(); + } + + bool GetItem(uint64_t idx, std::vector<NDArray>& ret) { + CHECK_LT(idx, GetLen()); + std::vector<NDArray> raw; + if (!base_->GetItem(idx, raw)) return false; + CHECK_EQ(raw.size(), 1U) << "RecordFileDataset should return size 1 NDArray vector"; + uint8_t *s = reinterpret_cast<uint8_t*>(raw[0].data().dptr_); + size_t size = raw[0].shape().Size(); + CHECK_GT(size, sizeof(IRHeader)) << "Invalid size of bytes from Record File"; + IRHeader header; + std::memcpy(&header, s, sizeof(header)); + size -= sizeof(header); + s += sizeof(header); + NDArray label = NDArray(Context::CPU(), mshadow::default_type_flag); + RunContext rctx{Context::CPU(), nullptr, nullptr, false}; + if (header.flag > 0) { + auto label_shape = header.flag <= 1 ? 
TShape(0, 1) : TShape({header.flag}); + label.ReshapeAndAlloc(label_shape); + TBlob dst = label.data(); + mxnet::ndarray::Copy<cpu, cpu>( + TBlob((void*)s, label.shape(), cpu::kDevMask, label.dtype(), 0), + &dst, Context::CPU(), Context::CPU(), rctx); + s += sizeof(float) * header.flag; + size -= sizeof(float) * header.flag; + } else { + // label is a scalar with ndim() == 0 + label.ReshapeAndAlloc(TShape(0, 1)); + TBlob dst = label.data(); + *(dst.dptr<float>()) = header.label; + } + ret.resize(2); + ret[1] = label; +#if MXNET_USE_OPENCV + cv::Mat buf(1, size, CV_8U, s); + cv::Mat res = cv::imdecode(buf, param_.flag); + CHECK(!res.empty()) << "Decoding failed. Invalid image file."; + const int n_channels = res.channels(); + if (n_channels == 1) { + SwapImageChannels<1>(res, ret[0]); + } else if (n_channels == 3) { + SwapImageChannels<3>(res, ret[0]); + } else if (n_channels == 4) { + SwapImageChannels<4>(res, ret[0]); + } + return true; +#else + LOG(FATAL) << "Opencv is needed for image decoding."; +#endif + return false; // should not reach here + }; + + private: + /*! \brief parameters */ + ImageRecordFileDatasetParam param_; + /*! \brief base recordIO reader */ + std::shared_ptr<RecordFileDataset> base_; +}; + +MXNET_REGISTER_IO_DATASET(ImageRecordFileDataset) + .describe("MXNet Image Record File Dataset") + .add_arguments(ImageRecordFileDatasetParam::__FIELDS__()) + .set_body([]() { + return new ImageRecordFileDataset(); +}); + +struct ImageSequenceDatasetParam : public dmlc::Parameter<ImageSequenceDatasetParam> { + /*! \brief the list of absolute image paths, separated by \0 characters */ + std::string img_list; + /*! \brief the path separator character, by default it's ; */ + char path_sep; + /*! \brief If flag is 0, always convert to grayscale(1 channel). + * If flag is 1, always convert to colored (3 channels). + * If flag is -1, keep channels unchanged. 
+ */ + int flag; + // declare parameters + DMLC_DECLARE_PARAMETER(ImageSequenceDatasetParam) { + DMLC_DECLARE_FIELD(img_list) + .describe("The list of image absolute paths."); + DMLC_DECLARE_FIELD(path_sep).set_default('|') + .describe("The path separator for joined image paths."); + DMLC_DECLARE_FIELD(flag).set_default(1) + .describe("If 1, always convert to colored, if 0 always convert to grayscale."); + } +}; // struct ImageSequenceDatasetParam + +DMLC_REGISTER_PARAMETER(ImageSequenceDatasetParam); + +class ImageSequenceDataset : public Dataset { + public: + ImageSequenceDataset* Clone(void) const { + return new ImageSequenceDataset(*this); + } + + void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) { + std::vector<std::pair<std::string, std::string> > kwargs_left; + param_.InitAllowUnknown(kwargs); + img_list_ = dmlc::Split(param_.img_list, param_.path_sep); + } + + uint64_t GetLen() const { + return img_list_.size(); + } + + bool GetItem(uint64_t idx, std::vector<NDArray>& ret) { +#if MXNET_USE_OPENCV + CHECK_LT(idx, img_list_.size()) + << "GetItem index: " << idx << " out of bound: " << img_list_.size(); + cv::Mat res = cv::imread(img_list_[idx], param_.flag); + CHECK(!res.empty()) << "Decoding failed. Invalid image file."; + const int n_channels = res.channels(); + ret.resize(1); + if (n_channels == 1) { + SwapImageChannels<1>(res, ret[0]); + } else if (n_channels == 3) { + SwapImageChannels<3>(res, ret[0]); + } else if (n_channels == 4) { + SwapImageChannels<4>(res, ret[0]); + } + return true; +#else + LOG(FATAL) << "Opencv is needed for image decoding."; +#endif + return false; + }; + + private: + /*! \brief parameters */ + ImageSequenceDatasetParam param_; + /*! 
\brief image list */ + std::vector<std::string> img_list_; +}; + +MXNET_REGISTER_IO_DATASET(ImageSequenceDataset) + .describe("Image Sequence Dataset") + .add_arguments(ImageSequenceDatasetParam::__FIELDS__()) + .set_body([]() { + return new ImageSequenceDataset(); +}); + +struct NDArrayDatasetParam : public dmlc::Parameter<NDArrayDatasetParam> { + /*! \brief the source ndarray */ + std::intptr_t arr; + // declare parameters + DMLC_DECLARE_PARAMETER(NDArrayDatasetParam) { + DMLC_DECLARE_FIELD(arr) + .describe("Pointer to NDArray."); + } +}; // struct NDArrayDatasetParam + +DMLC_REGISTER_PARAMETER(NDArrayDatasetParam); + +class NDArrayDataset : public Dataset { Review comment: Should all datasets be specified as `final`? This will allow the compiler to perform further optimization. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
