# janelu9 commented on issue #8065: Help? when I go on training a language
# model, errors occured
# URL: https://github.com/apache/incubator-mxnet/issues/8065#issuecomment-332730117
#
# my iter :

# -*- coding: utf-8 -*-
"""
Created on Mon Sep 18 15:11:46 2017

@author: T800
"""
import itertools

import mxnet as mx
import numpy as np


def gen_buckets_data(data, batch_size=32):
    """Group sentences into length buckets that each hold whole batches.

    Empty sentences are dropped and the rest are grouped by length. Walking
    the buckets from the shortest length to the longest, the
    ``len(bucket) % batch_size`` sentences at the front of each bucket are
    moved into the next longer bucket (where they are later right-padded
    with zeros), so every bucket except the longest ends up holding a
    multiple of ``batch_size`` sentences. The longest bucket is then topped
    up to a whole number of batches by reusing sentences from the
    second-longest bucket, or from itself when it is the only bucket.

    Args:
        data: iterable of sentences; each sentence is a sequence of token
            ids that ``len()`` works on and numpy can assign into a row.
        batch_size: number of sentences per batch.

    Returns:
        A list of non-empty buckets (lists of sentences), ordered from the
        shortest sentence length to the longest. Empty when ``data`` holds
        no non-empty sentence.
    """
    by_length = {}
    for sentence in data:
        if len(sentence):
            by_length.setdefault(len(sentence), []).append(sentence)
    if not by_length:
        return []

    # Leftovers may only move into a LONGER bucket, because a bucket's rows
    # are padded up to its widest sentence. Dict iteration order is not
    # sorted (arbitrary in Python 2, insertion order in Python 3.7+), so the
    # lengths are walked explicitly in ascending order.
    lengths = sorted(by_length)
    for shorter, longer in zip(lengths, lengths[1:]):
        bucket = by_length[shorter]
        leftover = len(bucket) % batch_size
        if leftover:
            by_length[longer].extend(bucket[:leftover])
            del bucket[:leftover]
    buckets = [by_length[n] for n in lengths if by_length[n]]

    leftover = len(buckets[-1]) % batch_size
    if leftover:
        # Every other bucket is non-empty and a multiple of batch_size, so it
        # always has enough rows to donate; a lone bucket recycles itself.
        donor = buckets[-2] if len(buckets) > 1 else buckets[-1]
        missing = batch_size - leftover
        # Materialize first: the donor may be the very list being extended.
        extra = list(itertools.islice(itertools.cycle(donor), missing))
        buckets[-1].extend(extra)
    return buckets


class SimpleBatch(object):
    """Minimal batch object yielded by BucketSentenceIter.

    Holds parallel lists of input/label names and NDArrays plus the bucket
    key (the padded sentence width of the batch). The bucket key is
    presumably what MXNet's bucketing module uses to pick the symbol for the
    batch; verify against the caller.
    """

    def __init__(self, data_names, data, label_names, label, bucket_key):
        self.data = data
        self.label = label
        self.data_names = data_names
        self.label_names = label_names
        self.bucket_key = bucket_key
        # Mirror the remaining mx.io.DataBatch attributes. Every batch from
        # BucketSentenceIter is full-size, so pad is 0.
        # NOTE(review): added because some Module code paths (e.g.
        # prediction) appear to read these; confirm for the installed MXNet.
        self.pad = 0
        self.index = None

    @property
    def provide_data(self):
        """List of ``(name, shape)`` pairs describing ``self.data``."""
        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]

    @property
    def provide_label(self):
        """List of ``(name, shape)`` pairs describing ``self.label``."""
        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]


class BucketSentenceIter(mx.io.DataIter):
    """Bucketed iterator over variable-length sentences.

    Sentences are grouped with gen_buckets_data and each bucket is stored
    as an int matrix, right-padded with 0 to the bucket's widest sentence.
    Each yielded batch prepends N - 1 zero columns to the data (presumably
    left context for an N-gram style model; TODO confirm). Its label is the
    next token: ``label[:, t] = data[:, t + 1]``, with the last column 0.

    Args:
        data: iterable of token-id sentences.
        batch_size: number of sentences per batch.
        N: the data gets N - 1 leading zero columns; must be >= 1.

    Raises:
        ValueError: if ``data`` contains no non-empty sentence.
    """

    def __init__(self, data, batch_size, N):
        super(BucketSentenceIter, self).__init__()
        buckets = gen_buckets_data(data, batch_size)
        if not buckets:
            raise ValueError('no non-empty sentences to iterate over')
        self.data, self.bucket_plan, self.bucket_idx = [], [], []
        self.buckets = []
        for idx, sentences in enumerate(buckets):
            width = max(len(s) for s in sentences)
            padded = np.zeros((len(sentences), width), int)
            for row, sentence in enumerate(sentences):
                padded[row, :len(sentence)] = sentence
            self.data.append(padded)
            self.buckets.append(width)
            # Schedule only full batches so every slice has batch_size rows
            # (np.hstack in __iter__ needs that).
            n_batches = len(sentences) // batch_size
            self.bucket_plan.extend([idx] * n_batches)
            starts = list(range(0, n_batches * batch_size, batch_size))
            np.random.shuffle(starts)
            self.bucket_idx.append(starts)
        np.random.shuffle(self.bucket_plan)
        self.batch_size = batch_size
        self.bucket_idx_ct = [0] * len(buckets)
        # The widest bucket is the default; the shapes below describe it.
        self.default_bucket_key = max(self.buckets)
        self.N = N
        self.provide_data = [
            ('data', (batch_size, self.default_bucket_key + self.N - 1))]
        self.provide_label = [
            ('softmax_label', (self.batch_size, self.default_bucket_key))]

    def __iter__(self):
        """Yield one SimpleBatch per entry of the shuffled bucket plan.

        Per-bucket cursors advance as batches are produced, so call reset()
        before iterating again.
        """
        context = np.zeros((self.batch_size, self.N - 1), int)
        for i in self.bucket_plan:
            begin = self.bucket_idx[i][self.bucket_idx_ct[i]]
            self.bucket_idx_ct[i] += 1
            data = self.data[i][begin:begin + self.batch_size]
            # Next-token targets; the last column stays 0.
            label = np.zeros(data.shape, int)
            label[:, :-1] = data[:, 1:]
            yield SimpleBatch(
                ['data'], [mx.nd.array(np.hstack([context, data]))],
                ['softmax_label'], [mx.nd.array(label)],
                self.buckets[i])

    def reset(self):
        """Rewind the per-bucket cursors (the bucket plan is not reshuffled)."""
        self.bucket_idx_ct = [0] * len(self.data)

# ----------------------------------------------------------------
# This is an automated message from the Apache Git Service.
# To respond to the message, please log on GitHub and use the
# URL above to go to the specific comment.
#
# For queries about this service, please contact Infrastructure at:
# [email protected]
#
# With regards,
# Apache Git Services
