universewill commented on issue #12952: a very interesting problem
URL: https://github.com/apache/incubator-mxnet/issues/12952#issuecomment-432919290

This is my model code:

```
import mxnet as mx

def my_model(num_classes):
    # Load the pretrained ResNet-152 checkpoint and take its flattened feature output.
    sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)
    all_layers = sym.get_internals()
    print(all_layers.list_outputs()[-10:])
    feat_sym = all_layers['flatten0_output']

    # Names used for freezing: everything except the last 60 outputs.
    abandon_para_names = all_layers.list_outputs()[-60:]
    freeze_para_names = [k for k in arg_params.keys() if k not in abandon_para_names]

    # Add extra shared layers for the multi-task head.
    net = mx.sym.FullyConnected(data=feat_sym, num_hidden=1024, name='my_layer_0')
    net = mx.sym.FullyConnected(data=net, num_hidden=512, name='my_layer_1')
    net = mx.sym.FullyConnected(data=net, num_hidden=256, name='my_layer_2')
    net = mx.sym.BatchNorm(net, name='my_layer')
    net = mx.sym.FullyConnected(data=net, num_hidden=num_classes, name='my_layer_3')

    # Classification output.
    loss = mx.sym.SoftmaxOutput(data=net, name='softmax')
    mod = mx.mod.Module(loss, context=mx.gpu(0), fixed_param_names=freeze_para_names)
    # Built but not used below.
    new_args = {k: arg_params[k] for k in arg_params if k not in all_layers.list_outputs()[-60:]}
    return mod, arg_params, aux_params
```

This is my train code:

```
import logging
import mxnet as mx

logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO, filename='train.log')

batch_size = 128
data_root = '/home/universe/jupyter/data/leather/data'

mod, arg_params, aux_params = mx_model.my_model(num_classes)
train_iter = DataIter(image_root=data_root, batch_size=batch_size)

mod.bind(data_shapes=train_iter.provide_data,
         label_shapes=train_iter.provide_label,
         force_rebind=True)

# Compare the shapes the module was bound with against the iterator's actual batch.
curr_data_shapes = tuple(i.shape for i in mod._data_shapes)
print(curr_data_shapes)
new_data_shapes = tuple(i.shape for i in train_iter.next().data)
print(new_data_shapes)

initializer = mx.initializer.Xavier()
mod.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                allow_missing=True, allow_extra=True)
mod.init_optimizer(optimizer='adam', optimizer_params={'learning_rate': 0.0005}, force_init=True)

mod.fit(train_data=train_iter,
        batch_end_callback=mx.callback.Speedometer(batch_size, 4),
        epoch_end_callback=mx.callback.do_checkpoint("./checkpoints/my_model", 50),
        num_epoch=1500)
```

@wkcn
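
Side note on the freeze list: `list_outputs()` returns layer output names such as `conv0_output`, while `arg_params` is keyed by argument names such as `conv0_weight`, so filtering `arg_params.keys()` against output names leaves every pretrained parameter in `freeze_para_names`. Below is a minimal sketch of building the freeze list from argument names instead; the helper name `get_freeze_names` and the `flatten0_output` cut-off are illustrative assumptions, not code from the post above:

```
import mxnet as mx

def get_freeze_names(sym, cutoff='flatten0_output'):
    """Return the weight/bias names of all layers up to `cutoff` (hypothetical helper)."""
    feat_sym = sym.get_internals()[cutoff]
    # list_arguments() gives parameter names like 'conv0_weight',
    # whereas list_outputs() gives layer output names like 'conv0_output'.
    return [name for name in feat_sym.list_arguments() if name != 'data']

sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)
freeze_para_names = get_freeze_names(sym)
print(len(freeze_para_names), freeze_para_names[:5])
```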
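
For reference, `Module.fit` can also take the pretrained weights directly, which folds the explicit `bind` / `init_params` / `init_optimizer` calls into a single call. A minimal sketch, assuming the same `mod`, `arg_params`, `aux_params`, `train_iter`, and `batch_size` as above:

```
import mxnet as mx

mod.fit(train_data=train_iter,
        optimizer='adam',
        optimizer_params={'learning_rate': 0.0005},
        initializer=mx.initializer.Xavier(),
        arg_params=arg_params,        # pretrained ResNet weights
        aux_params=aux_params,
        allow_missing=True,           # new layers fall back to the initializer
        batch_end_callback=mx.callback.Speedometer(batch_size, 4),
        epoch_end_callback=mx.callback.do_checkpoint("./checkpoints/my_model", 50),
        num_epoch=1500)
```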
