piiswrong closed pull request #9460: Data-iterator tutorial made python3 compatible.
URL: https://github.com/apache/incubator-mxnet/pull/9460
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/tutorials/basic/data.md b/docs/tutorials/basic/data.md
index 60a7ec185b..b60626a489 100644
--- a/docs/tutorials/basic/data.md
+++ b/docs/tutorials/basic/data.md
@@ -44,6 +44,7 @@ Before diving into the details let's setup the environment by importing some req
 import mxnet as mx
 %matplotlib inline
 import os
+import sys
 import subprocess
 import numpy as np
 import matplotlib.pyplot as plt
@@ -100,12 +101,11 @@ Thus we can create a new iterator by:
 The example below shows how to create a Simple iterator.
 
 ```python
-
 class SimpleIter(mx.io.DataIter):
     def __init__(self, data_names, data_shapes, data_gen,
                  label_names, label_shapes, label_gen, num_batches=10):
-        self._provide_data = zip(data_names, data_shapes)
-        self._provide_label = zip(label_names, label_shapes)
+        self._provide_data = list(zip(data_names, data_shapes))
+        self._provide_label = list(zip(label_names, label_shapes))
         self.num_batches = num_batches
         self.data_gen = data_gen
         self.label_gen = label_gen
@@ -180,6 +180,30 @@ mod = mx.mod.Module(symbol=net)
 mod.fit(data_iter, num_epoch=5)
 ```
 
+A note on python 3 usage: Lot of the methods in mxnet use string for python2 and bytes for python3.
+In order to keep this tutorial readable, we are going to define a utility function that converts
+string to bytes in python 3 environment
+
+```python
+def str_or_bytes(str):
+    """
+    A utility function for this tutorial that helps us convert string
+    to bytes if we are using python3.
+
+    Parameters
+    ----------
+    str : string
+
+    Returns
+    -------
+    string (python2) or bytes (python3)
+    """
+    if sys.version_info[0] < 3:
+        return str
+    else:
+        return bytes(str, 'utf-8')
+```
+
 ## Record IO
 Record IO is a file format used by MXNet for data IO.
 It compactly packs the data for efficient read and writes from distributed file system like Hadoop HDFS and AWS S3.
@@ -197,7 +221,8 @@ using `MXRecordIO`. The files are named with a `.rec` extension.
 ```python
 record = mx.recordio.MXRecordIO('tmp.rec', 'w')
 for i in range(5):
-    record.write('record_%d'%i)
+    record.write(str_or_bytes('record_%d'%i))
+
 record.close()
 ```
 
@@ -221,7 +246,8 @@ We will create an indexed record file and a corresponding index file as below:
 ```python
 record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
 for i in range(5):
-    record.write_idx(i, 'record_%d'%i)
+    record.write_idx(i, str_or_bytes('record_%d'%i))
+
 record.close()
 ```
 
@@ -255,11 +281,11 @@ The `mx.recordio` package provides a few utility functions for such operations, 
 data = 'data'
 label1 = 1.0
 header1 = mx.recordio.IRHeader(flag=0, label=label1, id=1, id2=0)
-s1 = mx.recordio.pack(header1, data)
+s1 = mx.recordio.pack(header1, str_or_bytes(data))
 
 label2 = [1.0, 2.0, 3.0]
 header2 = mx.recordio.IRHeader(flag=3, label=label2, id=2, id2=0)
-s2 = mx.recordio.pack(header2, data)
+s2 = mx.recordio.pack(header2, str_or_bytes(data))
 ```
 
 ```python

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

With regards,
Apache Git Services
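This diff wraps `zip(...)` in `list(...)` in `SimpleIter`. In Python 2, `zip` returns a list. In Python 3, it returns a lazy iterator that can be consumed only once. A `provide_data` property backed by a bare `zip` object therefore yields its (name, shape) pairs on the first read and nothing on any later read. The plain-Python sketch below reproduces this without MXNet. `ShapeHolder` and its `materialize` flag are hypothetical names invented for this illustration; they are not part of the tutorial or of MXNet.

```python
# Plain-Python sketch (no MXNet required). ShapeHolder is a hypothetical
# stand-in for the tutorial's SimpleIter; only provide_data is modelled.
class ShapeHolder:
    def __init__(self, data_names, data_shapes, materialize):
        pairs = zip(data_names, data_shapes)  # lazy, single-pass in Python 3
        # materialize=True mirrors the patched code: list(zip(...))
        self._provide_data = list(pairs) if materialize else pairs

    @property
    def provide_data(self):
        return self._provide_data


names, shapes = ['data'], [(10, 100)]  # illustrative values only

old = ShapeHolder(names, shapes, materialize=False)  # pre-patch behaviour
print(list(old.provide_data))  # [('data', (10, 100))]
print(list(old.provide_data))  # [] (the zip iterator is already exhausted)

new = ShapeHolder(names, shapes, materialize=True)   # patched behaviour
print(list(new.provide_data))  # [('data', (10, 100))]
print(list(new.provide_data))  # [('data', (10, 100))]
```

Materializing the pairs once in `__init__` is cheap here because the lists of names and shapes are tiny. It also makes the property safe to read any number of times.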
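The `str_or_bytes` helper added by the diff can be exercised on its own, without MXNet. The sketch below keeps the helper's logic and builds the same `'record_%d'` payloads that the tutorial writes. It renames the parameter from `str` to `s`, an editorial change so that the built-in `str` is not shadowed inside the function. The statement that MXNet's record writers expect `bytes` on Python 3 comes from the PR's own note; this sketch does not test it.

```python
import sys


def str_or_bytes(s):
    """Same logic as the diff's helper; the parameter is renamed from
    `str` to `s` so the built-in name is not shadowed."""
    if sys.version_info[0] < 3:
        return s  # Python 2: str is already a byte string
    return bytes(s, 'utf-8')  # Python 3: encode text as UTF-8 bytes


# Payloads shaped like the tutorial's 'record_%d' % i strings.
payloads = [str_or_bytes('record_%d' % i) for i in range(5)]
print(payloads[0])  # b'record_0' on Python 3

# On Python 3, str.encode('utf-8') yields the same bytes objects.
assert all(p == ('record_%d' % i).encode('utf-8')
           for i, p in enumerate(payloads))
```

If the tutorial only had to support Python 3, calling `s.encode('utf-8')` directly would be enough. The version check exists only to keep Python 2 working.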