This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new fb4760b add cnn + highway network architecture for Chinese text
classification example (#8030)
fb4760b is described below
commit fb4760b3f51c7dc292caff7d8787a92eef2e1fde
Author: wut0n9 <[email protected]>
AuthorDate: Thu Sep 28 02:39:07 2017 +0800
add cnn + highway network architecture for Chinese text classification
example (#8030)
* add cnn_chinese_text_classification example
* add data files
* add data files
* remove comments
* delete data dir
* add methods for downloading dataset files
* switch download source
* Update data_helpers.py
---
.gitignore | 2 +-
example/cnn_chinese_text_classification/README.md | 28 +++
.../data_helpers.py | 200 +++++++++++++++
.../cnn_chinese_text_classification/text_cnn.py | 268 +++++++++++++++++++++
4 files changed, 497 insertions(+), 1 deletion(-)
diff --git a/.gitignore b/.gitignore
index 4e4ff78..7ca76c9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -147,4 +147,4 @@ target
bin/im2rec
-model/
\ No newline at end of file
+model/
diff --git a/example/cnn_chinese_text_classification/README.md
b/example/cnn_chinese_text_classification/README.md
new file mode 100644
index 0000000..bfb271d
--- /dev/null
+++ b/example/cnn_chinese_text_classification/README.md
@@ -0,0 +1,28 @@
+Implementing CNN + Highway Network for Chinese Text Classification in MXNet
+============
+Sentiment classification forked from [incubator-mxnet/cnn_text_classification/](https://github.com/apache/incubator-mxnet/tree/master/example/cnn_text_classification), to which I've added the [Highway Networks](https://arxiv.org/pdf/1505.00387.pdf) architecture. The final trained model is a CNN + Highway Network structure, and this version achieves a best dev accuracy of 94.75% on the Chinese corpus.
+
+It is a slightly simplified implementation of Kim's [Convolutional Neural
Networks for Sentence Classification](http://arxiv.org/abs/1408.5882) paper in
MXNet.
+
+Recently, I have been learning MXNet for Natural Language Processing (NLP). I followed the nice blog post ["Implementing a CNN for Text Classification in Tensorflow"](http://www.wildml.com/2015/12/implementing-a-cnn-for-text-classification-in-tensorflow/) and reimplemented it with the MXNet framework.
+Data preprocessing code and corpus are directly borrowed from original author
[cnn-text-classification-tf](https://github.com/dennybritz/cnn-text-classification-tf).
+
+## Performance compared to original paper
+I use the same pretrained word2vec
[GoogleNews-vectors-negative300.bin](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing)
in Kim's paper. However, I don't implement L2-normalization of weights on
penultimate layer, but provide a L2-normalization of gradients.
+Finally, I got a best dev accuracy of 80.1%, close to the 81% reported in the original paper.
+
+## Data
+Please download the corpus from this repository
[cnn-text-classification-tf](https://github.com/dennybritz/cnn-text-classification-tf),
:)
+
+'data/rt.vec' was trained on the corpus with the word2vec tool. I recommend using the GoogleNews word2vec instead, which can give better performance, since
+this corpus is small (it contains about 10K sentences).
+
+When using GoogleNews word2vec, this code loads it with gensim tools
[gensim](https://github.com/piskvorky/gensim/tree/develop/gensim/models).
+
+## Remark
+If I have made any mistakes in this CNN implementation in MXNet, please correct me.
+
+## References
+- ["Implementing a CNN for Text Classification in Tensorflow" blog
post.](http://www.wildml.com/2015/12/implementing-a-cnn-for-text-classification-in-tensorflow/)
+- [Convolutional Neural Networks for Sentence
Classification](http://arxiv.org/abs/1408.5882)
+
diff --git a/example/cnn_chinese_text_classification/data_helpers.py
b/example/cnn_chinese_text_classification/data_helpers.py
new file mode 100644
index 0000000..1a5c4ad
--- /dev/null
+++ b/example/cnn_chinese_text_classification/data_helpers.py
@@ -0,0 +1,200 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import codecs
+
+import numpy as np
+import re
+import itertools
+from collections import Counter
+import os
+
+
+# from gensim.models import word2vec
+
def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py

    Contractions are split into separate tokens ("don't" -> "do n't"),
    punctuation is padded with spaces, and (, ) and ? are emitted with a
    leading backslash, matching the original Kim (2014) preprocessing.
    """
    # Raw strings for every pattern/replacement that contains a backslash:
    # the original non-raw literals relied on Python keeping invalid escapes
    # like "\(" verbatim, which raises DeprecationWarning on Python 3.6+.
    # The produced output is byte-identical to the original.
    string = re.sub(r"[^A-Za-z0-9(),!?'`]", " ", string)
    string = re.sub(r"'s", " 's", string)
    string = re.sub(r"'ve", " 've", string)
    string = re.sub(r"n't", " n't", string)
    string = re.sub(r"'re", " 're", string)
    string = re.sub(r"'d", " 'd", string)
    string = re.sub(r"'ll", " 'll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", r" \( ", string)
    string = re.sub(r"\)", r" \) ", string)
    string = re.sub(r"\?", r" \? ", string)
    # Collapse any run of whitespace introduced by the padding above.
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()
+
+
def get_chinese_text():
    """Download and extract the Chinese text corpus into ./data/ if absent.

    Fetches chinese_text.zip from the dmlc/web-data repository and extracts
    it, producing data/pos.txt and data/neg.txt.

    Fixes two issues in the original: it tested for 'data/neg' (without the
    .txt suffix), so the corpus was re-downloaded on every call even when
    both files existed; and it shelled out to wget/unzip while temporarily
    chdir-ing, which breaks on systems without those tools. This version
    uses only the standard library and never changes the working directory.
    """
    import urllib.request
    import zipfile

    if not os.path.isdir("data/"):
        os.makedirs("data/")
    if (not os.path.exists('data/pos.txt')) or \
            (not os.path.exists('data/neg.txt')):
        url = ("https://raw.githubusercontent.com/dmlc/web-data/master/"
               "mxnet/example/chinese_text.zip")
        zip_path = os.path.join("data", "chinese_text.zip")
        urllib.request.urlretrieve(url, zip_path)
        # extractall overwrites existing members, mirroring `unzip -u`.
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall("data")
+
+
def load_data_and_labels():
    """Load the Chinese polarity corpus, split into characters, build labels.

    Ensures the corpus is present locally, reads the positive and negative
    sentence files, keeps only sentences shorter than 100 characters, splits
    each sentence into a list of characters, and produces one-hot labels
    ([0, 1] positive, [1, 0] negative).
    Returns [x_text, y].
    """
    # download dataset
    get_chinese_text()

    def _read_sentences(path):
        # Read, strip, and keep only sentences shorter than 100 characters.
        with codecs.open(path, "r", "utf-8") as fin:
            stripped = [line.strip() for line in fin.readlines()]
        return [sentence for sentence in stripped if len(sentence) < 100]

    positive_examples = _read_sentences("./data/pos.txt")
    negative_examples = _read_sentences("./data/neg.txt")

    # Character-level tokenisation: each sentence becomes a list of chars.
    x_text = [list(sentence) for sentence in positive_examples + negative_examples]

    # One-hot labels, positives first to match x_text ordering.
    positive_labels = [[0, 1] for _ in positive_examples]
    negative_labels = [[1, 0] for _ in negative_examples]
    y = np.concatenate([positive_labels, negative_labels], 0)
    return [x_text, y]
+
+
def pad_sentences(sentences, padding_word="</s>"):
    """Pad every sentence to the length of the longest one.

    Parameters
    ----------
    sentences : list of list of str
        Tokenised sentences.
    padding_word : str
        Token appended to shorter sentences.

    Returns
    -------
    list of list of str
        New lists (inputs are not mutated), all of equal length.
    """
    target_length = max(len(sentence) for sentence in sentences)
    return [
        sentence + [padding_word] * (target_length - len(sentence))
        for sentence in sentences
    ]
+
+
def build_vocab(sentences):
    """Build word<->index mappings from tokenised sentences.

    Words are indexed by descending frequency (ties keep first-seen order,
    per Counter.most_common on Python 3.7+).
    Returns [vocabulary, vocabulary_inv] where vocabulary maps word -> index
    and vocabulary_inv maps index -> word.
    """
    counts = Counter(itertools.chain.from_iterable(sentences))
    # Index -> word, most frequent first.
    vocabulary_inv = [word for word, _ in counts.most_common()]
    # Word -> index, inverse of the list above.
    vocabulary = dict((word, index) for index, word in enumerate(vocabulary_inv))
    return [vocabulary, vocabulary_inv]
+
+
def build_input_data(sentences, labels, vocabulary):
    """Map tokenised sentences and labels to index arrays via the vocabulary.

    Every word must be present in `vocabulary` (KeyError otherwise).
    Returns [x, y] as numpy arrays.
    """
    indexed = []
    for sentence in sentences:
        indexed.append([vocabulary[word] for word in sentence])
    return [np.array(indexed), np.array(labels)]
+
+
def build_input_data_with_word2vec(sentences, labels, word2vec):
    """Map sentences and labels to vectors based on a pretrained word2vec.

    Out-of-vocabulary words fall back to the '</s>' vector, which must be
    present in `word2vec`.
    Returns [x_vec, y_vec] as numpy arrays.
    """
    vectors = [
        [word2vec[word] if word in word2vec else word2vec['</s>']
         for word in sentence]
        for sentence in sentences
    ]
    return [np.array(vectors), np.array(labels)]
+
+
def load_data_with_word2vec(word2vec):
    """Load the corpus, pad it, and embed it with a pretrained word2vec.

    No vocabulary is built: words are mapped straight to pretrained vectors.
    Returns [x_vec, y_vec].
    """
    sentences, labels = load_data_and_labels()
    padded = pad_sentences(sentences)
    return build_input_data_with_word2vec(padded, labels, word2vec)
+
+
def load_data():
    """Load and preprocess the corpus end-to-end.

    Returns [x, y, vocabulary, vocabulary_inv]: index matrix, label matrix,
    word -> index map, and index -> word list.
    """
    sentences, labels = load_data_and_labels()
    padded = pad_sentences(sentences)
    vocabulary, vocabulary_inv = build_vocab(padded)
    x, y = build_input_data(padded, labels, vocabulary)
    return [x, y, vocabulary, vocabulary_inv]
+
+
def batch_iter(data, batch_size, num_epochs):
    """Yield shuffled mini-batches over `data` for `num_epochs` epochs.

    The data is reshuffled at the start of every epoch; the final batch of
    an epoch may be smaller than `batch_size`.

    Fixes an off-by-one in the original (`int(len/batch_size) + 1`) that
    yielded an extra empty batch whenever len(data) was an exact multiple
    of batch_size.
    """
    data = np.array(data)
    data_size = len(data)
    # Ceiling division: exactly enough batches, never an empty one.
    num_batches_per_epoch = (data_size + batch_size - 1) // batch_size
    for _ in range(num_epochs):
        # Shuffle the data at each epoch.
        shuffle_indices = np.random.permutation(np.arange(data_size))
        shuffled_data = data[shuffle_indices]
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min(start_index + batch_size, data_size)
            yield shuffled_data[start_index:end_index]
+
+
def load_pretrained_word2vec(infile):
    """Load a text-format word2vec file into a {word: [float, ...]} dict.

    Parameters
    ----------
    infile : str or iterable of str
        Path to the file, or an already-open file-like object. The first
        line is the "<vocab_size> <dim>" header and is skipped.

    Fixes two issues in the original: the file handle leaked when a path
    was passed, and vectors were stored as lazy `map` objects on Python 3,
    which are exhausted after a single iteration.
    """
    close_after = False
    if isinstance(infile, str):
        infile = open(infile)
        close_after = True
    try:
        word2vec = {}
        for idx, line in enumerate(infile):
            if idx == 0:
                continue  # header line: "<vocab_size> <dim>"
            tks = line.strip().split()
            word2vec[tks[0]] = [float(v) for v in tks[1:]]
    finally:
        # Only close handles this function opened itself.
        if close_after:
            infile.close()
    return word2vec
+
+
def load_google_word2vec(path):
    """Load a binary GoogleNews-style word2vec model via gensim.

    The module-level gensim import was commented out, so calling this
    raised NameError. Import locally instead, so the rest of the module
    stays usable without gensim installed.
    """
    from gensim.models import word2vec  # third-party; only needed here
    return word2vec.Word2Vec.load_word2vec_format(path, binary=True)
diff --git a/example/cnn_chinese_text_classification/text_cnn.py
b/example/cnn_chinese_text_classification/text_cnn.py
new file mode 100644
index 0000000..8fd6b05
--- /dev/null
+++ b/example/cnn_chinese_text_classification/text_cnn.py
@@ -0,0 +1,268 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# -*- coding: utf-8 -*-
+
+import sys
+import os
+import mxnet as mx
+import numpy as np
+import argparse
+import logging
+
+import time
+
+from mxnet import random
+from mxnet.initializer import Xavier, Initializer
+
+import data_helpers
+
# Log format: timestamp, source file and line number, level, message.
fmt = '%(asctime)s:filename %(filename)s: lineno %(lineno)d:%(levelname)s:%(message)s'
# Append all DEBUG-and-above records to a local log file.
logging.basicConfig(format=fmt, filemode='a+', filename='./cnn_text_classification.log', level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Command-line interface for training the text-classification CNN.
parser = argparse.ArgumentParser(description="CNN for text classification",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# NOTE(review): argparse `type=bool` treats any non-empty string as True
# (so "--pretrained-embedding False" enables it) -- confirm intent; a
# store_true action would be the conventional fix.
parser.add_argument('--pretrained-embedding', type=bool, default=False,
                    help='use pre-trained word2vec')
parser.add_argument('--num-embed', type=int, default=300,
                    help='embedding layer size')
parser.add_argument('--gpus', type=str, default='',
                    help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. ')
parser.add_argument('--kv-store', type=str, default='local',
                    help='key-value store type')
parser.add_argument('--num-epochs', type=int, default=150,
                    help='max num of epochs')
parser.add_argument('--batch-size', type=int, default=50,
                    help='the batch size.')
parser.add_argument('--optimizer', type=str, default='rmsprop',
                    help='the optimizer type')
parser.add_argument('--lr', type=float, default=0.0005,
                    help='initial learning rate')
parser.add_argument('--dropout', type=float, default=0.0,
                    help='dropout rate')
parser.add_argument('--disp-batches', type=int, default=50,
                    help='show progress for every n batches')
parser.add_argument('--save-period', type=int, default=10,
                    help='save checkpoint for every n epochs')
+
+
def save_model():
    """Return an epoch-end callback that checkpoints to ./checkpoint/.

    The checkpoint directory is created on first use; checkpoints are
    written every `args.save_period` epochs.
    """
    checkpoint_dir = "checkpoint"
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    return mx.callback.do_checkpoint("checkpoint/checkpoint", args.save_period)
+
+
def highway(data):
    """Highway layer: y = H(x) * T(x) + x * (1 - T(x)).

    H is a 300-wide fully-connected transform with ReLU, and T a 300-wide
    sigmoid gate, matching the pooled-feature dimensionality used below.
    """
    # Transform branch H(x): FC -> ReLU.
    transform = mx.sym.FullyConnected(data=data,
                                      weight=mx.sym.Variable('high_weight'),
                                      bias=mx.sym.Variable('high_bias'),
                                      num_hidden=300, name='high_fc')
    transform = mx.sym.Activation(transform, act_type='relu')

    # Gate branch T(x): FC -> sigmoid.
    gate = mx.sym.FullyConnected(data=data,
                                 weight=mx.sym.Variable('high_trans_weight'),
                                 bias=mx.sym.Variable('high_trans_bias'),
                                 num_hidden=300,
                                 name='high_trans_sigmoid')
    gate = mx.sym.Activation(gate, act_type='sigmoid')

    # Gated combination of the transformed and carried-through input.
    return transform * gate + data * (1 - gate)
+
+
def data_iter(batch_size, num_embed, pre_trained_word2vec=False):
    """Build train/validation NDArrayIters for the text corpus.

    Parameters
    ----------
    batch_size : int
        Mini-batch size for both iterators.
    num_embed : int
        Embedding size used when no pretrained embedding is supplied.
    pre_trained_word2vec : bool
        When True, load 'data/rt.vec' and embed sentences directly; the
        embedding size is then taken from the vectors and vocab_size is -1.

    Returns
    -------
    (train_iter, valid_iter, sentence_size, embed_size, vocab_size)
    """
    logger.info('Loading data...')
    if pre_trained_word2vec:
        word2vec = data_helpers.load_pretrained_word2vec('data/rt.vec')
        x, y = data_helpers.load_data_with_word2vec(word2vec)
        # reshape to (batch, 1, sentence, embed) for the convolution input
        x = np.reshape(x, (x.shape[0], 1, x.shape[1], x.shape[2]))
        embed_size = x.shape[-1]
        sentence_size = x.shape[2]
        # No vocabulary is built in the pretrained path.
        vocab_size = -1
    else:
        x, y, vocab, vocab_inv = data_helpers.load_data()
        embed_size = num_embed
        sentence_size = x.shape[1]
        vocab_size = len(vocab)

    # randomly shuffle data; fixed seed keeps the train/valid split reproducible
    np.random.seed(10)
    shuffle_indices = np.random.permutation(np.arange(len(y)))
    x_shuffled = x[shuffle_indices]
    y_shuffled = y[shuffle_indices]

    # split train/valid set: last 1000 shuffled examples are the validation
    # set. NOTE(review): assumes the corpus has well over 1000 examples --
    # smaller datasets would leave an empty/negative-sized training split.
    x_train, x_dev = x_shuffled[:-1000], x_shuffled[-1000:]
    y_train, y_dev = y_shuffled[:-1000], y_shuffled[-1000:]
    logger.info('Train/Valid split: %d/%d' % (len(y_train), len(y_dev)))
    logger.info('train shape: %(shape)s', {'shape': x_train.shape})
    logger.info('valid shape: %(shape)s', {'shape': x_dev.shape})
    logger.info('sentence max words: %(shape)s', {'shape': sentence_size})
    logger.info('embedding size: %(msg)s', {'msg': embed_size})
    logger.info('vocab size: %(msg)s', {'msg': vocab_size})

    train = mx.io.NDArrayIter(
        x_train, y_train, batch_size, shuffle=True)
    valid = mx.io.NDArrayIter(
        x_dev, y_dev, batch_size)
    return (train, valid, sentence_size, embed_size, vocab_size)
+
+
def sym_gen(batch_size, sentence_size, num_embed, vocab_size,
            num_label=2, filter_list=(3, 4, 5), num_filter=100,
            dropout=0.0, pre_trained_word2vec=False):
    """Build the CNN + highway symbol for text classification.

    Parameters
    ----------
    batch_size : int
        Fixed batch size baked into the reshape ops.
    sentence_size : int
        Tokens per (padded) sentence.
    num_embed : int
        Embedding dimensionality.
    vocab_size : int
        Vocabulary size; ignored when `pre_trained_word2vec` is True.
    num_label : int
        Number of output classes.
    filter_list : sequence of int
        Convolution window heights. Default changed from a mutable list
        to a tuple (same values); callers passing their own list are
        unaffected.
    num_filter : int
        Feature maps per filter size.
    dropout : float
        Dropout probability after the highway layer; 0 disables dropout.
    pre_trained_word2vec : bool
        When True, input is assumed to already be embedded with shape
        (batch, 1, sentence, embed) and the embedding layer is skipped.

    Returns
    -------
    (symbol, data_names, label_names)
    """
    input_x = mx.sym.Variable('data')
    input_y = mx.sym.Variable('softmax_label')

    # Embedding layer: only needed when the input consists of word indices.
    if not pre_trained_word2vec:
        embed_layer = mx.sym.Embedding(data=input_x, input_dim=vocab_size,
                                       output_dim=num_embed, name='vocab_embed')
        conv_input = mx.sym.Reshape(data=embed_layer,
                                    target_shape=(batch_size, 1, sentence_size, num_embed))
    else:
        conv_input = input_x

    # One convolution + ReLU + max-over-time pooling branch per filter size.
    pooled_outputs = []
    for filter_size in filter_list:
        convi = mx.sym.Convolution(data=conv_input,
                                   kernel=(filter_size, num_embed),
                                   num_filter=num_filter)
        relui = mx.sym.Activation(data=convi, act_type='relu')
        pooli = mx.sym.Pooling(data=relui, pool_type='max',
                               kernel=(sentence_size - filter_size + 1, 1),
                               stride=(1, 1))
        pooled_outputs.append(pooli)

    # Concatenate all pooled feature maps into one flat feature vector.
    total_filters = num_filter * len(filter_list)
    concat = mx.sym.Concat(*pooled_outputs, dim=1)
    h_pool = mx.sym.Reshape(data=concat, target_shape=(batch_size, total_filters))

    # Highway network on top of the pooled features.
    h_pool = highway(h_pool)

    # Optional dropout for regularisation.
    if dropout > 0.0:
        h_drop = mx.sym.Dropout(data=h_pool, p=dropout)
    else:
        h_drop = h_pool

    # Final fully connected classifier.
    cls_weight = mx.sym.Variable('cls_weight')
    cls_bias = mx.sym.Variable('cls_bias')
    fc = mx.sym.FullyConnected(data=h_drop, weight=cls_weight, bias=cls_bias,
                               num_hidden=num_label)

    # Softmax output over the class labels.
    sm = mx.sym.SoftmaxOutput(data=fc, label=input_y, name='softmax')

    return sm, ('data',), ('softmax_label',)
+
+
def train(symbol, train_iter, valid_iter, data_names, label_names):
    """Train the CNN on `train_iter`, validating on `valid_iter`.

    Device selection, optimizer, learning rate, batch/report sizes and
    checkpointing all come from the globally parsed `args`.

    Fixes two defects in the original: `args.gpus is ''` compared string
    identity instead of equality, and the `--optimizer` / `--lr` CLI flags
    were silently ignored in favour of hard-coded values.
    """
    # Empty or missing --gpus means CPU; otherwise one context per listed id.
    devs = mx.cpu() if not args.gpus else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]
    module = mx.mod.Module(symbol, data_names=data_names, label_names=label_names, context=devs)

    # Per-parameter initialization spec consumed by CustomInit. The
    # 'costant' spelling must stay as-is: CustomInit._init_bias matches it.
    init_params = {
        'vocab_embed_weight': {'uniform': 0.1},
        'convolution0_weight': {'uniform': 0.1}, 'convolution0_bias': {'costant': 0},
        'convolution1_weight': {'uniform': 0.1}, 'convolution1_bias': {'costant': 0},
        'convolution2_weight': {'uniform': 0.1}, 'convolution2_bias': {'costant': 0},
        'high_weight': {'uniform': 0.1}, 'high_bias': {'costant': 0},
        # Negative gate bias so the highway gate initially carries the input.
        'high_trans_weight': {'uniform': 0.1}, 'high_trans_bias': {'costant': -2},
        'cls_weight': {'uniform': 0.1}, 'cls_bias': {'costant': 0},
    }
    # custom init_params
    module.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
    module.init_params(CustomInit(init_params))
    lr_sch = mx.lr_scheduler.FactorScheduler(step=25000, factor=0.999)
    # Honour the CLI options (the original hard-coded 'rmsprop' and 0.0005).
    module.init_optimizer(
        optimizer=args.optimizer,
        optimizer_params={'learning_rate': args.lr, 'lr_scheduler': lr_sch})

    def norm_stat(d):
        # RMS of each monitored array.
        return mx.nd.norm(d) / np.sqrt(d.size)
    mon = mx.mon.Monitor(25000, norm_stat)

    module.fit(train_data=train_iter,
               eval_data=valid_iter,
               eval_metric='acc',
               kvstore=args.kv_store,
               monitor=mon,
               num_epoch=args.num_epochs,
               batch_end_callback=mx.callback.Speedometer(args.batch_size, args.disp_batches),
               epoch_end_callback=save_model())
+
+
@mx.init.register
class CustomInit(Initializer):
    """Per-parameter initializer driven by a {param_name: {method: value}} spec.

    https://mxnet.incubator.apache.org/api/python/optimization.html#mxnet.initializer.register
    Create and register a custom initializer that
    initializes weights and biases with per-name custom requirements.
    Parameter names absent from the spec are not touched by these
    overrides.
    """
    # Supported method keys; 'costant' spelling matches the spec built in train().
    weightMethods = ["normal", "uniform", "orthogonal", "xavier"]
    biasMethods = ["costant"]

    def __init__(self, kwargs):
        # Keep the raw spec for name lookups in _init_weight/_init_bias.
        self._kwargs = kwargs
        super(CustomInit, self).__init__(**kwargs)

    def _init_weight(self, name, arr):
        # Apply the configured method for this weight name, if any.
        if name in self._kwargs.keys():
            init_params = self._kwargs[name]
            for (k, v) in init_params.items():
                if k.lower() == "normal":
                    # v is the standard deviation of N(0, v).
                    random.normal(0, v, out=arr)
                elif k.lower() == "uniform":
                    # v is the half-width: U(-v, v).
                    random.uniform(-v, v, out=arr)
                elif k.lower() == "orthogonal":
                    raise NotImplementedError("Not support at the moment")
                elif k.lower() == "xavier":
                    # presumably v is (rnd_type, factor_type, magnitude)
                    # matching mx Xavier's positional args -- TODO confirm.
                    xa = Xavier(v[0], v[1], v[2])
                    xa(name, arr)
                else:
                    raise NotImplementedError("Not support")

    def _init_bias(self, name, arr):
        # Fill the bias with a constant ('costant' key kept for compatibility
        # with the spec dictionaries built elsewhere in this file).
        if name in self._kwargs.keys():
            init_params = self._kwargs[name]
            for (k, v) in init_params.items():
                if k.lower() == "costant":
                    arr[:] = v
                else:
                    raise NotImplementedError("Not support")
+
+
if __name__ == '__main__':
    # Parse CLI options (train()/save_model() read `args` as a global).
    args = parser.parse_args()

    # Build the data iterators and record the corpus geometry.
    train_iter, valid_iter, sentence_size, embed_size, vocab_size = data_iter(
        args.batch_size, args.num_embed, args.pretrained_embedding)

    # Construct the CNN + highway network symbol.
    symbol, data_names, label_names = sym_gen(
        args.batch_size, sentence_size, embed_size, vocab_size,
        num_label=2, filter_list=[3, 4, 5], num_filter=100,
        dropout=args.dropout, pre_trained_word2vec=args.pretrained_embedding)

    # Train the CNN model.
    train(symbol, train_iter, valid_iter, data_names, label_names)
--
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].