majia-yu opened a new issue #10455: Bug of group2ctxs for model parallelism
URL: https://github.com/apache/incubator-mxnet/issues/10455
 
 
## Description
I use mx.mod.Module with group2ctxs to split a very large softmax layer across multiple GPUs. With fewer than 8 GPUs everything works: for the same initialization and inputs, the input_grads coming back from the softmax are essentially identical no matter how many GPUs are used. With 8 GPUs, however, the input_grads suddenly change to a wrong value. I believe this is a bug. (A minimal reproduction script is provided below.)
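
For context, the reproduction relies on MXNet's ctx_group / group2ctxs placement mechanism: symbols built inside mx.AttrScope(ctx_group=...) are assigned to whatever contexts group2ctxs maps that group name to when the module is bound. Below is a minimal sketch of just that pattern (the group name part0 and the tiny symbol are made up for illustration, and it assumes a GPU is available):

```python
import mxnet as mx

data = mx.sym.Variable('data')
with mx.AttrScope(ctx_group='part0'):        # ops built here are tagged with group 'part0'
    w = mx.sym.Variable('fc_weight')
    out = mx.sym.FullyConnected(data=data, weight=w, no_bias=True, num_hidden=10)
loss = mx.sym.make_loss(mx.sym.sum(out))

mod = mx.mod.Module(
    symbol=loss,
    context=[mx.cpu(0)],                     # default context for untagged ops
    label_names=None,
    group2ctxs={'part0': [mx.gpu(0)]},       # ops tagged 'part0' are placed on gpu(0)
)
mod.bind(data_shapes=[('data', (4, 8))], inputs_need_grad=True)
```

The script below uses the same pattern with one group dev_i per GPU, each holding one shard of the softmax weight.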
   
## Steps to reproduce

Run the script below. The results on my machine are:

python magic.py --group2ctxs 1 --num_gpus 8
loss: 22.12132263183593750000000000000000
input_grads: 0.00453650718554854393005371093750   # <-- unexpected value

python magic.py --group2ctxs 1 --num_gpus 7
loss: 22.12132072448730468750000000000000
input_grads: 0.00030235946178436279296875000000

python magic.py --group2ctxs 1 --num_gpus 4
loss: 22.12132072448730468750000000000000
input_grads: 0.00030235989834181964397430419922

python magic.py --group2ctxs 1 --num_gpus 1
loss: 22.12132072448730468750000000000000
input_grads: 0.0003023596364073455333709716796

python magic.py --group2ctxs 0 --num_gpus 1
loss: 22.12132072448730468750000000000000
input_grads: 0.00030235992744565010070800781250
   
   
```python
import mxnet as mx
import math
from easydict import EasyDict
import numpy as np
import argparse


parser = argparse.ArgumentParser()
parser.add_argument('--num_gpus', type=int)
parser.add_argument('--group2ctxs', type=int)
args = parser.parse_args()


np.random.seed(0)


def softmax(input_shape, num_classes, gpus, s=64.0, m=0.5):
    embeddings = mx.sym.Variable(name='data', shape=input_shape) * s
    labels = mx.symbol.Variable('softmax_label')
    # shard the classes across the ctx groups (one shard per GPU)
    num_classes_gpu = [num_classes // len(gpus)] * len(gpus)
    num_classes_gpu[0] += num_classes - sum(num_classes_gpu)
    embedding_size = input_shape[1]
    numerators = []
    numerators_index = []
    numerators_mask = []
    denominators = []
    logits = []
    for i, gpu in enumerate(gpus):
        with mx.AttrScope(ctx_group='dev_%d' % i):
            # per-shard classifier weight and logits, placed on ctx group dev_i
            local_weights = mx.symbol.Variable("softmax_%d_weight" % i, shape=(num_classes_gpu[i], embedding_size), lr_mult=1.0)
            local_weights = mx.symbol.L2Normalization(local_weights, mode='instance')
            local_logits = mx.sym.FullyConnected(data=embeddings, weight=local_weights, no_bias=True, num_hidden=num_classes_gpu[i])
            # samples whose ground-truth class falls into this shard
            local_numerators_mask = (labels >= sum(num_classes_gpu[: i])) * (labels < sum(num_classes_gpu[: i + 1]))
            local_labels = mx.sym.where(local_numerators_mask, labels - sum(num_classes_gpu[: i]), mx.sym.zeros_like(labels))
            local_numerators_index = mx.sym.stack(mx.sym.arange(0, input_shape[0]), local_labels, axis=0)
            # additive angular margin on the target logit: cos(theta + m)
            cos_t = mx.sym.pick(local_logits, local_labels, axis=1) / s
            sin_t = mx.sym.sqrt(1 - mx.sym.square(cos_t))
            cos_m_t = cos_t * math.cos(m) - sin_t * math.sin(m)
            cos_m_t = mx.sym.where(cos_t > math.cos(math.pi - m), cos_m_t, cos_t - math.sin(math.pi - m) * m)
            cos_m_t_s = mx.sym.where(local_numerators_mask, cos_m_t * s, cos_t * s)
            local_numerators = mx.sym.where(local_numerators_mask, mx.sym.exp(cos_m_t_s), mx.sym.zeros_like(cos_m_t))
            local_logits = local_logits + mx.sym.scatter_nd(cos_m_t_s - cos_t * s, local_numerators_index, (input_shape[0], num_classes_gpu[i]))
            # partial softmax denominator for this shard
            local_denominators = mx.sym.sum(mx.sym.exp(local_logits), axis=1)
            numerators.append(local_numerators)
            numerators_index.append(local_numerators_index)
            numerators_mask.append(local_numerators_mask)
            denominators.append(local_denominators)
            logits.append(local_logits)
    numerators = mx.sym.add_n(*numerators)
    # aggregate the partial denominators across all shards
    denominators = mx.sym.add_n(*denominators)
    fake_loss = []
    for i, gpu in enumerate(gpus):
        with mx.AttrScope(ctx_group='dev_%d' % i):
            # "fake loss": compute the softmax gradient explicitly, block it, and re-inject it
            local_logits = logits[i]
            local_numerators_mask = numerators_mask[i]
            local_numerators_index = numerators_index[i]
            local_logits_grad = mx.sym.broadcast_div(mx.sym.exp(local_logits), mx.sym.expand_dims(denominators, axis=1))
            local_logits_grad = local_logits_grad - mx.sym.scatter_nd(local_numerators_mask, local_numerators_index, (input_shape[0], num_classes_gpu[i]))
            local_logits_grad = mx.symbol.BlockGrad(local_logits_grad)
            local_fake_loss = mx.sym.sum(local_logits_grad * local_logits) / float(input_shape[0])
            fake_loss.append(local_fake_loss)
    fake_loss = mx.sym.add_n(*fake_loss)
    return mx.sym.make_loss(fake_loss)


gpus = list(range(args.num_gpus))
graph = mx.mod.Module(
    context=[mx.cpu(0)],
    symbol=softmax((512, 256), 500, gpus),
    group2ctxs={'dev_%d' % i: [mx.gpu(gpu)] for i, gpu in enumerate(gpus)} if args.group2ctxs else None
)
graph.bind(data_shapes=[('data', (512, 256))], label_shapes=[('softmax_label', (512,))], inputs_need_grad=True)
initializer = mx.init.Constant(1)
graph.init_params(initializer=initializer, arg_params=None, aux_params=None)
optimizer = mx.optimizer.SGD(learning_rate=0.1, momentum=0.9, wd=0.0005, rescale_grad=1)
graph.init_optimizer(optimizer=optimizer)


data = mx.nd.array(np.random.rand(512, 256))
data = mx.nd.L2Normalization(data, mode='instance')
label = mx.nd.array(np.random.randint(0, 500, (512,)))
batch = EasyDict()
batch.data = [data]
batch.label = [label]


graph.forward(batch, is_train=True)
graph.backward()
input_grads = graph.get_input_grads()[0].asnumpy()
loss = graph.get_outputs()[0].asnumpy()
print('loss: %.32f' % float(loss))
print('input_grads: %.32f' % np.std(input_grads))
```
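
For anyone reading the script, here is my reading of what it computes (a sketch derived from the code above, not an official formulation): each ctx group i holds a shard C_i of the classes, applies the additive angular margin only to the target logit, and produces a partial softmax denominator; add_n then sums the partial denominators across shards, and the per-logit softmax gradient (scaled by 1/batch-size in the fake loss) is re-injected as a constant through BlockGrad.

```latex
% shard i holds classes C_i; \hat{x} and \hat{w}_j are L2-normalized, z_j = s\,\hat{w}_j^{\top}\hat{x}
z_y \leftarrow s\cos(\theta_y + m) \quad\text{(margin on the target class only)}
\qquad
d_i = \sum_{j \in C_i} e^{z_j}, \qquad d = \sum_i d_i \quad\text{(the add\_n across shards)}
\qquad
\frac{\partial L}{\partial z_j} = \frac{e^{z_j}}{d} - \mathbb{1}[j = y]
```

Because d already aggregates every shard's contribution, the result should not depend on how the classes are split, i.e. input_grads should agree (up to floating-point noise) for every value of --num_gpus, which is why the 8-GPU number above stands out.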
   
   
## Environment info (Required)
----------Python Info----------
Version      : 3.5.2
Compiler     : GCC 5.4.0 20160609
Build        : ('default', 'Sep 14 2017 22:51:06')
Arch         : ('64bit', 'ELF')
------------Pip Info-----------
Version      : 9.0.3
Directory    : /mnt/ficusjordan/dfyu/Projects/SoftmaxFace/python/python3.5_ubuntu16.04/SoftmaxFace/lib/python3.5/site-packages/pip
----------MXNet Info-----------
Version      : 1.1.0
Directory    : /mnt/ficusjordan/dfyu/Projects/SoftmaxFace/python/python3.5_ubuntu16.04/SoftmaxFace/lib/python3.5/site-packages/mxnet
Commit Hash   : 07a83a0325a3d782513a04f47d711710972cb144
----------System Info----------
Platform     : Linux-4.13.0-36-generic-x86_64-with-Ubuntu-16.04-xenial
system       : Linux
node         : WXRG0094
release      : 4.13.0-36-generic
version      : #40~16.04.1-Ubuntu SMP Fri Feb 16 23:25:58 UTC 2018
----------Hardware Info----------
machine      : x86_64
processor    : x86_64
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                56
On-line CPU(s) list:   0-55
Thread(s) per core:    2
Core(s) per socket:    14
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 79
Model name:            Intel(R) Xeon(R) CPU E5-2680 v4 @ 2.40GHz
Stepping:              1
CPU MHz:               2400.215
CPU max MHz:           3300.0000
CPU min MHz:           1200.0000
BogoMIPS:              4800.43
Virtualization:        VT-x
L1d cache:             32K
L1i cache:             32K
L2 cache:              256K
L3 cache:              35840K
NUMA node0 CPU(s):     0-13,28-41
NUMA node1 CPU(s):     14-27,42-55
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti retpoline intel_ppin intel_pt tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a rdseed adx smap xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts
----------Network Test----------
Setting timeout: 10
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.1864 sec, LOAD: 1.3491 sec.
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0119 sec, LOAD: 3.8785 sec.
Timing for Conda: https://repo.continuum.io/pkgs/free/, DNS: 0.0096 sec, LOAD: 0.4216 sec.
Timing for Gluon Tutorial(cn): https://zh.gluon.ai, DNS: 0.1880 sec, LOAD: 0.8841 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.2153 sec, LOAD: 0.4721 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0118 sec, LOAD: 0.9239 sec.
   
   
