piiswrong closed pull request #10565: [MXNET-307] Add utility to get im2rec.py path
URL: https://github.com/apache/incubator-mxnet/pull/10565
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/docs/tutorials/basic/data.md b/docs/tutorials/basic/data.md
index fdd50150f33..0a5dd59c1ce 100644
--- a/docs/tutorials/basic/data.md
+++ b/docs/tutorials/basic/data.md
@@ -391,8 +391,7 @@ Now let's convert them into record io format using the `im2rec.py` utility scrip
 First, we need to make a list that contains all the image files and their categories:
 
 ```python
-mxnet_path = os.path.dirname(mx.__file__)
-im2rec_path = os.path.join(mxnet_path, 'tools','im2rec.py')
+im2rec_path = mx.test_utils.get_im2rec_path()
 data_path = os.path.join('data','101_ObjectCategories')
 prefix_path = os.path.join('data','caltech')
 
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index d4d8ad29e24..aa388c14ea1 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1732,6 +1732,36 @@ def mean_check(generator, mu, sigma, nsamples=1000000):
           (sample_mean < mu + 3 * sigma / np.sqrt(nsamples))
     return ret
 
+def get_im2rec_path(home_env="MXNET_HOME"):
+    """Get path to the im2rec.py tool
+
+    Parameters
+    ----------
+
+    home_env : str
+        Env variable that holds the path to the MXNET folder
+
+    Returns
+    -------
+    str
+        The path to im2rec.py
+    """
+    # Check first if the path to MXNET is passed as an env variable
+    if home_env in os.environ:
+        mxnet_path = os.environ[home_env]
+    else:
+        # Else use currently imported mxnet as reference
+        mxnet_path = os.path.dirname(mx.__file__)
+    # If MXNet was installed through pip, im2rec.py is in <mxnet_path>/tools
+    im2rec_path = os.path.join(mxnet_path, 'tools', 'im2rec.py')
+    if os.path.isfile(im2rec_path):
+        return im2rec_path
+    # If MXNet has been built locally
+    im2rec_path = os.path.join(mxnet_path, '..', '..', 'tools', 'im2rec.py')
+    if os.path.isfile(im2rec_path):
+        return im2rec_path
+    raise IOError('Could not find path to tools/im2rec.py')
+
 def var_check(generator, sigma, nsamples=1000000):
     """Test the generator by matching the variance.
     It will need a large number of samples and is not recommended to use
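
For readers who want to check the search order without installing MXNet, here is a minimal standalone sketch of the lookup that get_im2rec_path performs. The function name find_im2rec and its package_dir argument are hypothetical and not part of MXNet; package_dir stands in for os.path.dirname(mx.__file__). The sketch assumes the same two layouts the diff checks: tools/im2rec.py inside the installed package (pip), and tools/im2rec.py two directories above python/mxnet (local build).

```python
import os
import tempfile


def find_im2rec(package_dir, home_env="MXNET_HOME"):
    """Hypothetical mirror of the search order used by get_im2rec_path.

    package_dir stands in for os.path.dirname(mx.__file__), so the
    logic can be exercised without importing MXNet.
    """
    # An explicit root in the environment variable replaces package_dir.
    root = os.environ.get(home_env, package_dir)
    candidates = [
        # pip install: tools/ ships inside the mxnet package directory.
        os.path.join(root, "tools", "im2rec.py"),
        # Local build: <repo>/python/mxnet -> <repo>/tools.
        os.path.join(root, "..", "..", "tools", "im2rec.py"),
    ]
    for path in candidates:
        if os.path.isfile(path):
            return path
    raise IOError("Could not find path to tools/im2rec.py")


if __name__ == "__main__":
    # Fake a source checkout: <repo>/python/mxnet and <repo>/tools/im2rec.py.
    with tempfile.TemporaryDirectory() as repo:
        pkg = os.path.join(repo, "python", "mxnet")
        os.makedirs(pkg)
        os.makedirs(os.path.join(repo, "tools"))
        open(os.path.join(repo, "tools", "im2rec.py"), "w").close()
        os.environ.pop("MXNET_HOME", None)  # make sure the env override is off
        found = find_im2rec(pkg)
        print(os.path.normpath(found) == os.path.join(repo, "tools", "im2rec.py"))  # prints True
```

Because the pip layout is tried first, a tools/ directory inside the package directory wins over one two levels up. Setting MXNET_HOME changes the root that both candidate paths are built from, so it overrides the location of the imported package.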


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
