This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3cebaa2  Update multi-task learning example (#12964)
3cebaa2 is described below

commit 3cebaa252bc9c6de20fec32fa7f20e5427b2a70a
Author: Thomas Delteil <[email protected]>
AuthorDate: Mon Nov 12 16:32:33 2018 -0800

    Update multi-task learning example (#12964)
    
    * Update multi task learning example
    
    * Updating README.md
---
 example/multi-task/README.md                 |  13 +-
 example/multi-task/example_multi_task.py     | 159 ----------
 example/multi-task/multi-task-learning.ipynb | 454 +++++++++++++++++++++++++++
 3 files changed, 462 insertions(+), 164 deletions(-)

diff --git a/example/multi-task/README.md b/example/multi-task/README.md
index 9034814..b7756fe 100644
--- a/example/multi-task/README.md
+++ b/example/multi-task/README.md
@@ -1,10 +1,13 @@
 # Mulit-task learning example
  
-This is a simple example to show how to use mxnet for multi-task learning. It 
uses MNIST as an example and mocks up the multi-label task.
+This is a simple example showing how to use MXNet for multi-task learning. It 
uses MNIST as an example, jointly predicting the digit and whether that digit is 
odd or even.
 
-## Usage
-First, you need to write a multi-task iterator on your own. The iterator needs 
to generate multiple labels according to your applications, and the label names 
should be specified in the `provide_label` function, which needs to be consist 
with the names of output layers. 
+For example:
 
-Then, if you want to show metrics of different tasks separately, you need to 
write your own metric class and specify the `num` parameter. In the `update` 
function of metric, calculate the metrics separately for different tasks.
+![](https://camo.githubusercontent.com/ed3cf256f47713335dc288f32f9b0b60bf1028b7/68747470733a2f2f7777772e636c61737365732e63732e756368696361676f2e6564752f617263686976652f323031332f737072696e672f31323330302d312f70612f7061312f64696769742e706e67)
 
-The example script uses gpu as device by default, if gpu is not available for 
your environment, you can change `device` to be `mx.cpu()`.
+This image should be jointly classified as the digit 4 and as even.
+
+In this example we don't expect the tasks to contribute much to each other, 
but multi-task learning has been applied successfully to other domains, such as 
image captioning. In [A Multi-task Learning Approach for Image 
Captioning](https://www.ijcai.org/proceedings/2018/0168.pdf), Wei Zhao, 
Benyou Wang, Jianbo Ye, Min Yang, Zhou Zhao, Ruotian Luo and Yu Qiao train a 
network to jointly classify images and generate text captions.
+
+Please refer to the notebook for a fully worked example.
diff --git a/example/multi-task/example_multi_task.py 
b/example/multi-task/example_multi_task.py
deleted file mode 100644
index 9e89849..0000000
--- a/example/multi-task/example_multi_task.py
+++ /dev/null
@@ -1,159 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: skip-file
-import mxnet as mx
-from mxnet.test_utils import get_mnist_iterator
-import numpy as np
-import logging
-import time
-
-logging.basicConfig(level=logging.DEBUG)
-
-def build_network():
-    data = mx.symbol.Variable('data')
-    fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
-    act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-    fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
-    act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-    fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
-    sm1 = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax1')
-    sm2 = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax2')
-
-    softmax = mx.symbol.Group([sm1, sm2])
-
-    return softmax
-
-class Multi_mnist_iterator(mx.io.DataIter):
-    '''multi label mnist iterator'''
-
-    def __init__(self, data_iter):
-        super(Multi_mnist_iterator, self).__init__()
-        self.data_iter = data_iter
-        self.batch_size = self.data_iter.batch_size
-
-    @property
-    def provide_data(self):
-        return self.data_iter.provide_data
-
-    @property
-    def provide_label(self):
-        provide_label = self.data_iter.provide_label[0]
-        # Different labels should be used here for actual application
-        return [('softmax1_label', provide_label[1]), \
-                ('softmax2_label', provide_label[1])]
-
-    def hard_reset(self):
-        self.data_iter.hard_reset()
-
-    def reset(self):
-        self.data_iter.reset()
-
-    def next(self):
-        batch = self.data_iter.next()
-        label = batch.label[0]
-
-        return mx.io.DataBatch(data=batch.data, label=[label, label], \
-                pad=batch.pad, index=batch.index)
-
-class Multi_Accuracy(mx.metric.EvalMetric):
-    """Calculate accuracies of multi label"""
-
-    def __init__(self, num=None):
-        self.num = num
-        super(Multi_Accuracy, self).__init__('multi-accuracy')
-
-    def reset(self):
-        """Resets the internal evaluation result to initial state."""
-        self.num_inst = 0 if self.num is None else [0] * self.num
-        self.sum_metric = 0.0 if self.num is None else [0.0] * self.num
-
-    def update(self, labels, preds):
-        mx.metric.check_label_shapes(labels, preds)
-
-        if self.num is not None:
-            assert len(labels) == self.num
-
-        for i in range(len(labels)):
-            pred_label = 
mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
-            label = labels[i].asnumpy().astype('int32')
-
-            mx.metric.check_label_shapes(label, pred_label)
-
-            if self.num is None:
-                self.sum_metric += (pred_label.flat == label.flat).sum()
-                self.num_inst += len(pred_label.flat)
-            else:
-                self.sum_metric[i] += (pred_label.flat == label.flat).sum()
-                self.num_inst[i] += len(pred_label.flat)
-
-    def get(self):
-        """Gets the current evaluation result.
-
-        Returns
-        -------
-        names : list of str
-           Name of the metrics.
-        values : list of float
-           Value of the evaluations.
-        """
-        if self.num is None:
-            return super(Multi_Accuracy, self).get()
-        else:
-            return zip(*(('%s-task%d'%(self.name, i), float('nan') if 
self.num_inst[i] == 0
-                                                      else self.sum_metric[i] 
/ self.num_inst[i])
-                       for i in range(self.num)))
-
-    def get_name_value(self):
-        """Returns zipped name and value pairs.
-
-        Returns
-        -------
-        list of tuples
-            A (name, value) tuple list.
-        """
-        if self.num is None:
-            return super(Multi_Accuracy, self).get_name_value()
-        name, value = self.get()
-        return list(zip(name, value))
-
-
-batch_size=100
-num_epochs=100
-device = mx.gpu(0)
-lr = 0.01
-
-network = build_network()
-train, val = get_mnist_iterator(batch_size=batch_size, input_shape = (784,))
-train = Multi_mnist_iterator(train)
-val = Multi_mnist_iterator(val)
-
-
-model = mx.mod.Module(
-    context            = device,
-    symbol             = network,
-    label_names        = ('softmax1_label', 'softmax2_label'))
-
-model.fit(
-    train_data         = train,
-    eval_data          = val,
-    eval_metric        = Multi_Accuracy(num=2),
-    num_epoch          = num_epochs,
-    optimizer_params   = (('learning_rate', lr), ('momentum', 0.9), ('wd', 
0.00001)),
-    initializer        = mx.init.Xavier(factor_type="in", magnitude=2.34),
-    batch_end_callback = mx.callback.Speedometer(batch_size, 50))
-
diff --git a/example/multi-task/multi-task-learning.ipynb 
b/example/multi-task/multi-task-learning.ipynb
new file mode 100644
index 0000000..6e03e2b
--- /dev/null
+++ b/example/multi-task/multi-task-learning.ipynb
@@ -0,0 +1,454 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Multi-Task Learning Example"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This is a simple example to show how to use mxnet for multi-task 
learning.\n",
+    "\n",
+    "The network will jointly learn to recognize the digit and to tell whether 
it is odd or even.\n",
+    "\n",
+    "\n",
+    "For example\n",
+    "\n",
+    "- 1 : 1 and odd\n",
+    "- 2 : 2 and even\n",
+    "- 3 : 3 and odd\n",
+    "\n",
+    "etc\n",
+    "\n",
+    "In this example we don't expect the tasks to contribute much to each other, 
but multi-task learning has been applied successfully to other domains, such as 
image captioning. In [A Multi-task Learning Approach for Image 
Captioning](https://www.ijcai.org/proceedings/2018/0168.pdf), Wei Zhao, 
Benyou Wang, Jianbo Ye, Min Yang, Zhou Zhao, Ruotian Luo and Yu Qiao train a 
network to jointly classify images and generate text captions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "import random\n",
+    "import time\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import mxnet as mx\n",
+    "from mxnet import gluon, nd, autograd\n",
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Parameters"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 99,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "batch_size = 128\n",
+    "epochs = 5\n",
+    "ctx = mx.gpu() if len(mx.test_utils.list_gpus()) > 0 else mx.cpu()\n",
+    "lr = 0.01"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Data\n",
+    "\n",
+    "We load the traditional MNIST dataset and add a second label to the existing 
one: for each digit we also return whether it is odd (1) or even (0)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    
"![](https://upload.wikimedia.org/wikipedia/commons/2/27/MnistExamples.png)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_dataset = gluon.data.vision.MNIST(train=True)\n",
+    "test_dataset = gluon.data.vision.MNIST(train=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def transform(x,y):\n",
+    "    x = x.transpose((2,0,1)).astype('float32')/255.\n",
+    "    y1 = y\n",
+    "    y2 = y % 2 #odd or even\n",
+    "    return x, np.float32(y1), np.float32(y2)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We assign the transform to the original dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_dataset_t = train_dataset.transform(transform)\n",
+    "test_dataset_t = test_dataset.transform(transform)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We wrap the transformed datasets in DataLoaders."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_data = gluon.data.DataLoader(train_dataset_t, shuffle=True, 
last_batch='rollover', batch_size=batch_size, num_workers=5)\n",
+    "test_data = gluon.data.DataLoader(test_dataset_t, shuffle=False, 
last_batch='keep', batch_size=batch_size, num_workers=5)  # 'keep': evaluate every test sample exactly once"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Input shape: (28, 28, 1), Target Labels: (5.0, 1.0)\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"Input shape: {}, Target Labels: 
{}\".format(train_dataset[0][0].shape, train_dataset_t[0][1:]))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Multi-task Network\n",
+    "\n",
+    "The output of the shared featurization layers is passed to two different output layers, one per task."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 135,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class MultiTaskNetwork(gluon.HybridBlock):\n",
+    "    \n",
+    "    def __init__(self):\n",
+    "        super(MultiTaskNetwork, self).__init__()\n",
+    "        \n",
+    "        self.shared = gluon.nn.HybridSequential()\n",
+    "        with self.shared.name_scope():\n",
+    "            self.shared.add(\n",
+    "                gluon.nn.Dense(128, activation='relu'),\n",
+    "                gluon.nn.Dense(64, activation='relu'),\n",
+    "                gluon.nn.Dense(10, activation='relu')\n",
+    "            )\n",
+    "        self.output1 = gluon.nn.Dense(10) # Digit recognition\n",
+    "        self.output2 = gluon.nn.Dense(1) # odd or even\n",
+    "\n",
+    "        \n",
+    "    def hybrid_forward(self, F, x):\n",
+    "        y = self.shared(x)\n",
+    "        output1 = self.output1(y)\n",
+    "        output2 = self.output2(y)\n",
+    "        return output1, output2"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We use a different loss for each output: softmax cross-entropy for the digit and sigmoid binary cross-entropy for odd/even."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 136,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss_digits = gluon.loss.SoftmaxCELoss()\n",
+    "loss_odd_even = gluon.loss.SigmoidBCELoss()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We create and initialize the network"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 137,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mx.random.seed(42)\n",
+    "random.seed(42)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 138,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net = MultiTaskNetwork()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 139,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net.initialize(mx.init.Xavier(), ctx=ctx)\n",
+    "net.hybridize() # hybridize for speed"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 140,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer = gluon.Trainer(net.collect_params(), 'adam', 
{'learning_rate':lr})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Evaluate Accuracy\n",
+    "We need to evaluate the accuracy of each task separately."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 141,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def evaluate_accuracy(net, data_iterator):\n",
+    "    acc_digits = mx.metric.Accuracy(name='digits')\n",
+    "    acc_odd_even = mx.metric.Accuracy(name='odd_even')\n",
+    "    \n",
+    "    for i, (data, label_digit, label_odd_even) in 
enumerate(data_iterator):\n",
+    "        data = data.as_in_context(ctx)\n",
+    "        label_digit = label_digit.as_in_context(ctx)\n",
+    "        label_odd_even = 
label_odd_even.as_in_context(ctx).reshape(-1,1)\n",
+    "\n",
+    "        output_digit, output_odd_even = net(data)\n",
+    "        \n",
+    "        acc_digits.update(label_digit, output_digit.softmax())\n",
+    "        acc_odd_even.update(label_odd_even, output_odd_even.sigmoid() > 
0.5)\n",
+    "    return acc_digits.get(), acc_odd_even.get()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Training Loop"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We need to balance the contribution of each loss to the overall training. 
We do so with the `alpha` parameter in [0, 1]: the combined loss is `(1 - alpha) * digit_loss + alpha * odd_even_loss`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 142,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "alpha = 0.5 # Combine losses factor"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 143,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch [0], Acc Digits   0.8945 Loss Digits   0.3409\n",
+      "Epoch [0], Acc Odd/Even 0.9561 Loss Odd/Even 0.1152\n",
+      "Epoch [0], Testing Accuracies (('digits', 0.9487179487179487), 
('odd_even', 0.9770633012820513))\n",
+      "Epoch [1], Acc Digits   0.9576 Loss Digits   0.1475\n",
+      "Epoch [1], Acc Odd/Even 0.9804 Loss Odd/Even 0.0559\n",
+      "Epoch [1], Testing Accuracies (('digits', 0.9642427884615384), 
('odd_even', 0.9826722756410257))\n",
+      "Epoch [2], Acc Digits   0.9681 Loss Digits   0.1124\n",
+      "Epoch [2], Acc Odd/Even 0.9852 Loss Odd/Even 0.0418\n",
+      "Epoch [2], Testing Accuracies (('digits', 0.9580328525641025), 
('odd_even', 0.9846754807692307))\n",
+      "Epoch [3], Acc Digits   0.9734 Loss Digits   0.0961\n",
+      "Epoch [3], Acc Odd/Even 0.9884 Loss Odd/Even 0.0340\n",
+      "Epoch [3], Testing Accuracies (('digits', 0.9670472756410257), 
('odd_even', 0.9839743589743589))\n",
+      "Epoch [4], Acc Digits   0.9762 Loss Digits   0.0848\n",
+      "Epoch [4], Acc Odd/Even 0.9894 Loss Odd/Even 0.0310\n",
+      "Epoch [4], Testing Accuracies (('digits', 0.9652887658227848), 
('odd_even', 0.9858583860759493))\n"
+     ]
+    }
+   ],
+   "source": [
+    "for e in range(epochs):\n",
+    "    # Accuracies for each task\n",
+    "    acc_digits = mx.metric.Accuracy(name='digits')\n",
+    "    acc_odd_even = mx.metric.Accuracy(name='odd_even')\n",
+    "    # Accumulative losses\n",
+    "    l_digits_ = 0.\n",
+    "    l_odd_even_ = 0. \n",
+    "    \n",
+    "    for i, (data, label_digit, label_odd_even) in 
enumerate(train_data):\n",
+    "        data = data.as_in_context(ctx)\n",
+    "        label_digit = label_digit.as_in_context(ctx)\n",
+    "        label_odd_even = 
label_odd_even.as_in_context(ctx).reshape(-1,1)\n",
+    "        \n",
+    "        with autograd.record():\n",
+    "            output_digit, output_odd_even = net(data)\n",
+    "            l_digits = loss_digits(output_digit, label_digit)\n",
+    "            l_odd_even = loss_odd_even(output_odd_even, 
label_odd_even)\n",
+    "\n",
+    "            # Combine the loss of each task\n",
+    "            l_combined = (1-alpha)*l_digits + alpha*l_odd_even\n",
+    "            \n",
+    "        l_combined.backward()\n",
+    "        trainer.step(data.shape[0])\n",
+    "        \n",
+    "        l_digits_ += l_digits.mean()\n",
+    "        l_odd_even_ += l_odd_even.mean()\n",
+    "        acc_digits.update(label_digit, output_digit.softmax())\n",
+    "        acc_odd_even.update(label_odd_even, output_odd_even.sigmoid() > 
0.5)\n",
+    "        \n",
+    "    print(\"Epoch [{}], Acc Digits   {:.4f} Loss Digits   
{:.4f}\".format(\n",
+    "        e, acc_digits.get()[1], l_digits_.asscalar()/(i+1)))\n",
+    "    print(\"Epoch [{}], Acc Odd/Even {:.4f} Loss Odd/Even 
{:.4f}\".format(\n",
+    "        e, acc_odd_even.get()[1], l_odd_even_.asscalar()/(i+1)))\n",
+    "    print(\"Epoch [{}], Testing Accuracies {}\".format(e, 
evaluate_accuracy(net, test_data)))\n",
+    "        "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Testing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 144,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_random_data():\n",
+    "    idx = random.randrange(len(test_dataset))  # upper bound exclusive: always a valid index\n",
+    "\n",
+    "    img = test_dataset[idx][0]\n",
+    "    data, _, _ = test_dataset_t[idx]\n",
+    "    data = data.as_in_context(ctx).expand_dims(axis=0)\n",
+    "\n",
+    "    plt.imshow(img.squeeze().asnumpy(), cmap='gray')\n",
+    "    \n",
+    "    return data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 152,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Predicted digit: [9.], odd: [1.]\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADeVJREFUeJzt3X+MFPX9x/HXG6QGAQ3aiBdLpd9Ga6pBak5joqk01caaRuAfUhMbjE2viTUpEVFCNT31Dxu1rdWYJldLCk2/QhUb+KPWWuKP1jQNIKiotFJC00OEkjNBEiNyvPvHzdlTbz6zzs7uzPF+PpLL7e57Z+ad5V7M7H5m9mPuLgDxTKq7AQD1IPxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4I6oZsbMzNOJwQ6zN2tlee1tec3s6vM7O9mtsvMVrSzLgDdZWXP7TezyZL+
 [...]
+      "text/plain": [
+       "<Figure size 432x288 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "data = get_random_data()\n",
+    "\n",
+    "digit, odd_even = net(data)\n",
+    "\n",
+    "digit = digit.argmax(axis=1)[0].asnumpy()\n",
+    "odd_even = (odd_even.sigmoid()[0] > 0.5).asnumpy()\n",
+    "\n",
+    "print(\"Predicted digit: {}, odd: {}\".format(digit, odd_even))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

Reply via email to