kohr-h commented on a change in pull request #13094: WIP: Simplifications and
some fun stuff for the MNIST Gluon tutorial
URL: https://github.com/apache/incubator-mxnet/pull/13094#discussion_r231315017
##########
File path: docs/tutorials/gluon/mnist.md
##########
@@ -126,71 +125,116 @@ training scope which is defined by `autograd.record()`.
```python
%%time
-epoch = 10
-# Use Accuracy as the evaluation metric.
+
+num_epochs = 10
metric = mx.metric.Accuracy()
softmax_cross_entropy_loss = gluon.loss.SoftmaxCrossEntropyLoss()
-for i in range(epoch):
- # Reset the train data iterator.
+for epoch in range(num_epochs):
+ # Restart the training data iterator at the beginning of each epoch
train_data.reset()
- # Loop over the train data iterator.
+
for batch in train_data:
- # Splits train data into multiple slices along batch_axis
- # and copy each slice into a context.
- data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
- # Splits train labels into multiple slices along batch_axis
- # and copy each slice into a context.
- label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
- outputs = []
- # Inside training scope
+ # Possibly copy data and labels to the GPU
+ data = batch.data[0].copyto(ctx)
+ labels = batch.label[0].copyto(ctx)
+
+ # The forward pass and the loss computation need to be wrapped
+ # in an `ag.record()` scope to indicate that the results will
+ # be needed in the backward pass (gradient computation).
with ag.record():
- for x, y in zip(data, label):
- z = net(x)
- # Computes softmax cross entropy loss.
- loss = softmax_cross_entropy_loss(z, y)
- # Backpropagate the error for one iteration.
- loss.backward()
- outputs.append(z)
- # Updates internal evaluation
- metric.update(label, outputs)
- # Make one step of parameter update. Trainer needs to know the
- # batch size of data to normalize the gradient by 1/batch_size.
- trainer.step(batch.data[0].shape[0])
- # Gets the evaluation result.
+ out = net(data)
+ loss = softmax_cross_entropy_loss(out, labels)
+
+ # Compute gradients by backpropagation and update the evaluation
+ # metric
+ loss.backward()
+ metric.update(labels, out)
+
+ # Update the parameters by stepping the trainer; the batch size
+ # is required to normalize the gradients by `1 / batch_size`.
+ trainer.step(batch_size=batch.data[0].shape[0])
+
+ # Print the evaluation metric and reset it for the next epoch
name, acc = metric.get()
- # Reset evaluation result to initial state.
+ print('After epoch {}: {} = {}'.format(epoch + 1, name, acc))
metric.reset()
- print('training acc at epoch %d: %s=%f'%(i, name, acc))
```
#### Prediction
After the above training completes, we can evaluate the trained model by
running predictions on the validation dataset. Since the dataset also has labels
for all test images, we can compute the accuracy metric over validation data as
follows:
```python
-# Use Accuracy as the evaluation metric.
metric = mx.metric.Accuracy()
-# Reset the validation data iterator.
val_data.reset()
-# Loop over the validation data iterator.
for batch in val_data:
- # Splits validation data into multiple slices along batch_axis
- # and copy each slice into a context.
- data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
- # Splits validation label into multiple slices along batch_axis
- # and copy each slice into a context.
- label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
- outputs = []
- for x in data:
- outputs.append(net(x))
- # Updates internal evaluation
- metric.update(label, outputs)
-print('validation acc: %s=%f'%metric.get())
+ # Possibly copy data and labels to the GPU
+ data = batch.data[0].copyto(ctx)
+ labels = batch.label[0].copyto(ctx)
+ metric.update(labels, net(data))
+print('Validation: {} = {}'.format(*metric.get()))
assert metric.get()[1] > 0.94
```
If everything went well, we should see an accuracy value that is around 0.96,
which means that we are able to accurately predict the digit in 96% of test
images. This is a pretty good result. But as we will see in the next part of
this tutorial, we can do a lot better than that.
+That said, a single number gives us only very limited information on the performance of our neural network. It is always a good idea to actually look at the images on which the network performed poorly, and check for clues on how to improve the performance. We do that with the help of a small function that produces a list of the images where the network got it wrong, together with the predicted and true labels:
+
+```python
+def get_mislabelled(it):
+ """Return list of ``(input, pred_lbl, true_lbl)`` for mislabelled samples."""
+ mislabelled = []
+ it.reset()
+ for batch in it:
+ data = batch.data[0].copyto(ctx)
+ labels = batch.label[0].copyto(ctx)
+ out = net(data)
+ # Predicted label is the index where the output is maximal
+ preds = nd.argmax(out, axis=1)
+ for d, p, l in zip(data, preds, labels):
+ if p != l:
+ mislabelled.append(
+ (d.asnumpy(), int(p.asnumpy()), int(l.asnumpy()))
+ )
+ return mislabelled
+```
+
+We can now get the mislabelled images in the training and validation sets and plot a selection of them:
+
+```python
+import numpy as np
+
+sample_size = 8
+wrong_train = get_mislabelled(train_data)
+wrong_val = get_mislabelled(val_data)
+wrong_train_sample = [wrong_train[i] for i in np.random.randint(0, len(wrong_train), size=sample_size)]
+wrong_val_sample = [wrong_val[i] for i in np.random.randint(0, len(wrong_val), size=sample_size)]
+
+import matplotlib.pyplot as plt
+
+fig, axs = plt.subplots(ncols=sample_size)
+for ax, (img, pred, lbl) in zip(axs, wrong_train_sample):
+ fig.set_size_inches(18, 4)
+ fig.suptitle("Sample of wrong predictions in the training set", fontsize=20)
+ ax.imshow(img[0], cmap="gray")
+ ax.set_title("Predicted: {}\nActual: {}".format(pred, lbl))
+ ax.xaxis.set_visible(False)
+ ax.yaxis.set_visible(False)
+
+fig, axs = plt.subplots(ncols=sample_size)
+for ax, (img, pred, lbl) in zip(axs, wrong_val_sample):
+ fig.set_size_inches(18, 4)
+ fig.suptitle("Sample of wrong predictions in the validation set", fontsize=20)
+ ax.imshow(img[0], cmap="gray")
+ ax.set_title("Predicted: {}\nActual: {}".format(pred, lbl))
+ ax.xaxis.set_visible(False)
+ ax.yaxis.set_visible(False)
+```
+
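A side note on the sampling step in the hunk above: `np.random.randint(0, n, size=k)` draws indices with replacement, so the same mislabelled image can show up more than once in a plot, and the call raises a `ValueError` when the list is empty (`n == 0`). The following is a minimal sketch of one possible alternative using only the standard library; the helper name `sample_without_replacement` and the stand-in data are hypothetical and not part of the pull request.

```python
import random


def sample_without_replacement(items, sample_size, seed=None):
    """Return up to ``sample_size`` distinct entries of ``items``."""
    rng = random.Random(seed)
    # Cap the sample size so that a short (or empty) list does not raise
    k = min(sample_size, len(items))
    return rng.sample(items, k)


# Hypothetical stand-in for the output of `get_mislabelled`:
# tuples of (image, predicted label, true label)
wrong = [("img{}".format(i), i % 10, (i + 1) % 10) for i in range(5)]

# Only 5 entries are available, so 5 distinct ones are returned
picked = sample_without_replacement(wrong, sample_size=8, seed=0)
# An empty input yields an empty sample instead of an exception
empty = sample_without_replacement([], sample_size=8)
```

With this approach the number of subplot columns would have to follow `len(picked)` rather than `sample_size`, since fewer samples than requested may come back.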
Review comment:
Okay, will do.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services