janelu9 commented on issue #8065: Help? when I go on training a language model, 
errors occurred
URL: 
https://github.com/apache/incubator-mxnet/issues/8065#issuecomment-332730117
 
 
   ```
   my iter :
   # -*- coding: utf-8 -*-
   """
   Created on Mon Sep 18 15:11:46 2017
   
   @author: T800
   """
   
   import mxnet as mx
   import numpy as np
     
   def gen_buckets_data(data,batch_size=32):
       """Group sequences by length into buckets sized in multiples of batch_size.

       Empty sequences are dropped.  Buckets are visited in increasing length
       order; every bucket except the longest hands its first
       ``len(bucket) % batch_size`` sequences to the next longer bucket, so
       that each remaining bucket holds a whole number of batches.  Buckets
       that end up empty are discarded.  If the longest bucket still has a
       remainder it is topped up with duplicated sequences taken from the
       second-longest surviving bucket (or from itself when it is the only
       bucket).

       Args:
           data: iterable of sequences; each item must support ``len()``.
           batch_size: number of sequences per batch.

       Returns:
           A list of buckets (lists of sequences) ordered by increasing
           sequence length; an empty list when ``data`` contains no
           non-empty sequence.
       """
       buckets = {}
       for seq in data:
           if len(seq) != 0:
               buckets.setdefault(len(seq), []).append(seq)
       if not buckets:
           return []
       # Walk the lengths in explicit ascending order instead of relying on
       # dict iteration order, which is not sorted in general.
       lengths = sorted(buckets)
       remainder = 0
       for pos, length in enumerate(lengths):
           bucket = buckets[length]
           # len(bucket) already includes sequences carried over from the
           # previous (shorter) bucket.
           remainder = len(bucket) % batch_size
           if remainder and pos + 1 < len(lengths):
               buckets[lengths[pos + 1]].extend(bucket[:remainder])
               del bucket[:remainder]
       # Build a new list rather than popping from the dict while iterating.
       result = [buckets[length] for length in lengths if buckets[length]]
       if remainder > 0:
           # `remainder` now belongs to the longest bucket, which is never
           # drained.  Complete its final batch with duplicates.
           missing = batch_size - remainder
           donor = result[-2] if len(result) > 1 else result[-1]
           repeats = missing // len(donor) + 1
           result[-1].extend((donor * repeats)[:missing])
       return result
   
   class SimpleBatch(object):
       """Plain container for one batch: named data/label arrays plus a bucket key.

       ``provide_data`` and ``provide_label`` pair every name with the
       ``shape`` attribute of the array stored at the same position.
       """

       def __init__(self, data_names, data, label_names, label,bucket_key):
           self.data_names = data_names
           self.data = data
           self.label_names = label_names
           self.label = label
           self.bucket_key = bucket_key

       @property
       def provide_data(self):
           """Return a list of ``(name, shape)`` tuples for the data arrays."""
           described = []
           for name, array in zip(self.data_names, self.data):
               described.append((name, array.shape))
           return described

       @property
       def provide_label(self):
           """Return a list of ``(name, shape)`` tuples for the label arrays."""
           described = []
           for name, array in zip(self.label_names, self.label):
               described.append((name, array.shape))
           return described
       
   class BucketSentenceIter(mx.io.DataIter):
       """Iterate over length-bucketed sequences, yielding SimpleBatch objects.

       The input sequences are bucketed with ``gen_buckets_data``; every
       bucket becomes one zero-padded integer matrix.  Batches are drawn
       bucket by bucket following a shuffled plan.  Each batch's label is
       the data shifted left by one position (last column zero), and the
       data is left-padded with ``N - 1`` zero columns.

       Args:
           data: iterable of integer sequences.
           batch_size: number of sequences per batch.
           N: adds ``N - 1`` zero columns in front of every data batch
               (presumably a context/window size -- confirm with the model).
       """
       def __init__(self,data,batch_size,N):
           super(BucketSentenceIter, self).__init__()
           data=gen_buckets_data(data,batch_size)
           self.data,self.bucket_plan,self.bucket_idx=[],[],[]
           for idx,bucket in enumerate(data):
               # The matrix width is the length of the bucket's first sequence;
               # shorter sequences are right-padded with zeros.
               padded=np.zeros((len(bucket),len(bucket[0])),int)
               for row,seq in enumerate(bucket):
                   padded[row,:len(seq)]=seq
               self.data.append(padded)
               size=len(bucket)
               # Floor division keeps the repeat count an int under true division.
               self.bucket_plan.extend([idx]*(size//batch_size))
               # shuffle() works in place, so it needs a real list, not a range.
               starts=list(range(0,size,batch_size))
               np.random.shuffle(starts)
               self.bucket_idx.append(starts)
           np.random.shuffle(self.bucket_plan)
           self.batch_size=batch_size
           self.bucket_idx_ct=[0]*len(data)
           self.buckets=[len(bucket[0]) for bucket in data]
           self.default_bucket_key = self.buckets[-1]
           self.N=N
           self.provide_data = [('data', (batch_size, self.default_bucket_key+self.N-1))]
           self.provide_label = [('softmax_label', (self.batch_size, self.default_bucket_key))]

       def __iter__(self):
           """Yield one SimpleBatch per entry of the shuffled bucket plan."""
           for i in self.bucket_plan:
               # Next unused batch start offset inside bucket i.
               begin=self.bucket_idx[i][self.bucket_idx_ct[i]]
               rows=self.data[i][begin:begin+self.batch_size]
               self.bucket_idx_ct[i]+=1
               # Label at position t is the data value at position t+1.
               label=np.zeros(rows.shape,int)
               label[:,:-1]=rows[:,1:]
               front_pad=np.zeros([self.batch_size,self.N-1],int)
               data = [mx.nd.array(np.hstack([front_pad,rows]))]
               label = [mx.nd.array(label)]
               data_names = ['data']
               label_names = ['softmax_label']
               yield SimpleBatch(data_names, data, label_names, label,self.buckets[i])

       def reset(self):
           """Rewind the per-bucket batch counters so iteration can start over."""
           self.bucket_idx_ct = [0]*len(self.data)
           
           
           
       
           
   
   
   ```
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to