wkcn commented on a change in pull request #18411:
URL: https://github.com/apache/incubator-mxnet/pull/18411#discussion_r430753501
##########
File path: python/mxnet/io/io.py
##########
@@ -828,20 +828,26 @@ def __init__(self, handle, data_name='data',
label_name='softmax_label', **kwarg
from ..numpy.multiarray import _np_ndarray_cls
self._create_ndarray_fn = _np_ndarray_cls if is_np_array() else _ndarray_cls
self.handle = handle
+ self.data_size = 1
self._kwargs = kwargs
+ if isinstance(data_name, list):
+ self.data_size = len(data_name)
# debug option, used to test the speed with io effect eliminated
self._debug_skip_load = False
# load the first batch to get shape information
self.first_batch = None
self.first_batch = self.next()
- data = self.first_batch.data[0]
+ data = self.first_batch.data
label = self.first_batch.label[0]
# properties
- self.provide_data = [DataDesc(data_name, data.shape, data.dtype)]
+ if isinstance(data_name, list):
Review comment:
` if isinstance(data_name, (list, tuple)): `
##########
File path: python/mxnet/io/io.py
##########
@@ -828,20 +828,26 @@ def __init__(self, handle, data_name='data',
label_name='softmax_label', **kwarg
from ..numpy.multiarray import _np_ndarray_cls
self._create_ndarray_fn = _np_ndarray_cls if is_np_array() else _ndarray_cls
self.handle = handle
+ self.data_size = 1
self._kwargs = kwargs
+ if isinstance(data_name, list):
Review comment:
Users may also pass a tuple as `data_name`, so check for both list and tuple:
` if isinstance(data_name, (list, tuple)): `
##########
File path: python/mxnet/io/io.py
##########
@@ -883,9 +889,20 @@ def iter_next(self):
return next_res.value
def getdata(self):
- hdl = NDArrayHandle()
- check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl)))
- return self._create_ndarray_fn(hdl, False)
+ #hdl = NDArrayHandle()
+ #check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl)))
+ #return self._create_ndarray_fn(hdl, False)
+ NDARR = NDArrayHandle *self.data_size
Review comment:
Nit (PEP 8: surround binary operators with single spaces): ` NDARR = NDArrayHandle * self.data_size `
##########
File path: src/c_api/c_api.cc
##########
@@ -2254,9 +2255,11 @@ int MXDataIterGetIndex(DataIterHandle handle, uint64_t
**out_index, uint64_t *ou
int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) {
API_BEGIN();
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
- NDArray* pndarray = new NDArray();
- *pndarray = db.data[0];
- *out = pndarray;
+ for (size_t i = 0 ; i < db.data.size() - 1; ++i) {
Review comment:
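Can db.data.size() ever be 0? `size()` returns an unsigned `size_t`, so when the vector is empty, `db.data.size() - 1` wraps around to the maximum `size_t` value. The loop condition `i < db.data.size() - 1` would then effectively never become false, which gives an endless loop.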
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]