soeque1 commented on issue #12796: Add embedding to print_summary
URL: https://github.com/apache/incubator-mxnet/pull/12796#issuecomment-429038277

The post (#12778) does not include enough information to run the code, so I filled in the missing pieces of the sample code:

```
from mxnet.gluon import HybridBlock, nn
import mxnet as mx


class SentClassificationModel(HybridBlock):
    def __init__(self, vocab_size, num_embed, **kwargs):
        super(SentClassificationModel, self).__init__(**kwargs)
        with self.name_scope():
            # Maps token ids to dense vectors of size num_embed
            self.embed = nn.Embedding(input_dim=vocab_size, output_dim=num_embed)
            self.drop = nn.Dropout(0.3)
            self.fc = nn.Dense(100, activation='relu')
            self.out = nn.Dense(2)  # two output classes

    def hybrid_forward(self, F, inputs):
        em_out = self.drop(self.embed(inputs))
        fc_out = self.fc(em_out)
        return self.out(fc_out)


ctx = mx.gpu()  # requires a CUDA-enabled MXNet build; use mx.cpu() otherwise
model = SentClassificationModel(vocab_size=20, num_embed=50)
model.initialize(mx.init.Xavier(), ctx=ctx)
model.hybridize()

# Calling the block on a Symbol returns a Symbol that print_summary can inspect
mx.viz.print_summary(
    model(mx.sym.var('data')),
    shape={'data': (1, 30)},  # (batch_size, sequence_length); set your shape here
)
```

The output is below:

```
________________________________________________________________________________________________________________________
Layer (type)                                        Output Shape            Param #     Previous Layer
========================================================================================================================
data(null)                                          30                      0
________________________________________________________________________________________________________________________
sentclassificationmodel2_embedding0_fwd(Embedding)  30x50                   1000        data
________________________________________________________________________________________________________________________
sentclassificationmodel2_dropout0_fwd(Dropout)      30x50                   0           sentclassificationmodel2_embeddi
________________________________________________________________________________________________________________________
sentclassificationmodel2_dense0_fwd(FullyConnected) 100                     3100        sentclassificationmodel2_dropout
________________________________________________________________________________________________________________________
sentclassificationmodel2_dense0_relu_fwd(Activation)100                     0           sentclassificationmodel2_dense0_
________________________________________________________________________________________________________________________
sentclassificationmodel2_dense1_fwd(FullyConnected) 2                       202         sentclassificationmodel2_dense0_
========================================================================================================================
Total params: 4302
________________________________________________________________________________________________________________________
```
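As a sanity check on the Output Shape and Param # columns, the same numbers can be traced by hand. This is a minimal plain-Python sketch (no MXNet needed) that assumes Gluon's defaults for `nn.Dense` (`flatten=True`, `use_bias=True`). Under that assumption the 30x50 embedding output reaches the first Dense layer as 1500 features. The helpers `embedding_params`, `dense_params` and `trace` are hypothetical names made up for this illustration. They are not part of MXNet.

```python
# Hand trace of layer shapes and parameter counts for SentClassificationModel
# (batch dimension omitted, as in print_summary).
# Assumption: Gluon defaults, i.e. nn.Dense uses flatten=True and use_bias=True.
from math import prod


def embedding_params(input_dim, output_dim):
    # One weight row of length output_dim per vocabulary entry; no bias.
    return input_dim * output_dim


def dense_params(in_units, units):
    # Weight matrix (units x in_units) plus one bias per output unit.
    return in_units * units + units


def trace(seq_len=30, vocab_size=20, num_embed=50, hidden=100, classes=2):
    rows = []
    emb_shape = (seq_len, num_embed)
    rows.append(("embedding0", emb_shape, embedding_params(vocab_size, num_embed)))
    rows.append(("dropout0", emb_shape, 0))  # Dropout has no parameters
    # flatten=True: the (seq_len, num_embed) input becomes seq_len * num_embed features
    flat_features = prod(emb_shape)
    rows.append(("dense0", (hidden,), dense_params(flat_features, hidden)))
    rows.append(("dense1", (classes,), dense_params(hidden, classes)))
    return rows


if __name__ == "__main__":
    total = 0
    for name, shape, params in trace():
        total += params
        print(f"{name:<12} {'x'.join(map(str, shape)):<8} {params}")
    print("Total params:", total)
```

Under these assumptions, the hand trace matches the printed table for the Embedding row (1000) and the last Dense row (202). For `dense0` it gives 150100 rather than the 3100 shown above, so that row of the summary may be worth a second look.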
