This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a1f1321  Text APIs: interface re-design with name changes (#9534)
a1f1321 is described below

commit a1f1321e2167252ee45b61f4f7684fd1a13c027d
Author: Aston Zhang <ast...@amazon.com>
AuthorDate: Wed Jan 24 15:20:30 2018 -0800

    Text APIs: interface re-design with name changes (#9534)
    
    * revise text api names
    
    * update
    
    * Add test cases
    
    * remove unused copy
    
    * update docstrings
    
    * Add API doc
    
    * rebuild
    
    * Fix broken links, shorten func name
    
    * remove redundant underscore
    
    * move api docs to contrib
    
    * naming: contrib package
---
 docs/api/python/contrib/contrib.md                 |  12 +
 docs/api/python/contrib/text.md                    | 401 +++++++++++++++
 docs/api/python/index.md                           |  19 +-
 docs/api/python/text/text.md                       | 455 -----------------
 python/mxnet/contrib/text/__init__.py              |   3 +-
 python/mxnet/contrib/text/embedding.py             | 395 ++++++++++-----
 python/mxnet/contrib/text/glossary.py              | 118 -----
 python/mxnet/contrib/text/{indexer.py => vocab.py} |  16 +-
 tests/python/unittest/test_contrib_text.py         | 552 ++++++++++++---------
 9 files changed, 1031 insertions(+), 940 deletions(-)

diff --git a/docs/api/python/contrib/contrib.md 
b/docs/api/python/contrib/contrib.md
new file mode 100644
index 0000000..66fc391
--- /dev/null
+++ b/docs/api/python/contrib/contrib.md
@@ -0,0 +1,12 @@
+# Contrib Package
+
+## Overview
+
+The `Contrib` APIs, defined in the `mxnet.contrib` package, provide
+many useful experimental APIs for new features.
+This is a place for the community to try out the new features,
+so that feature contributors can receive feedback.
+
+```eval_rst
+.. warning:: This package contains experimental APIs and may change in the 
near future.
+```
diff --git a/docs/api/python/contrib/text.md b/docs/api/python/contrib/text.md
new file mode 100644
index 0000000..f203a11
--- /dev/null
+++ b/docs/api/python/contrib/text.md
@@ -0,0 +1,401 @@
+# Text API
+
+## Overview
+
+The `mxnet.contrib.text` APIs refer to classes and functions related to text data processing,
+such as building indices, loading pre-trained embedding vectors for text tokens, and storing
+them in the `mxnet.ndarray.NDArray` format.
+
+```eval_rst
+.. warning:: This package contains experimental APIs and may change in the 
near future.
+```
+
+This document lists the text APIs in mxnet:
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    mxnet.contrib.text.embedding
+    mxnet.contrib.text.vocab
+    mxnet.contrib.text.utils
+```
+
+All the code demonstrated in this document assumes that the following modules 
or packages are
+imported.
+
+```python
+>>> from mxnet import gluon
+>>> from mxnet import nd
+>>> from mxnet.contrib import text
+>>> import collections
+
+```
+
+### Looking up pre-trained word embeddings for indexed words
+
+As a common use case, let us look up pre-trained word embedding vectors for 
indexed words in just a
+few lines of code. 
+
+To begin with, suppose that we have a simple text data set in the string format. We can count
+word frequency in the data set.
+
+```python
+>>> text_data = " hello world \n hello nice world \n hi world \n"
+>>> counter = text.utils.count_tokens_from_str(text_data)
+
+```
+
+The obtained `counter` has key-value pairs whose keys are words and values are word frequencies.
+Suppose that we want to build indices for all the keys in `counter` and load a fastText word
+embedding for all such indexed words. First, we need a Vocabulary object with `counter` as its
+argument:
+
+```python
+>>> my_vocab = text.vocab.Vocabulary(counter)
+
+```
+
+We can create a fastText word embedding object by specifying the embedding 
name `fasttext` and
+the pre-trained file `wiki.simple.vec`. We also specify that the indexed 
tokens for loading the
+fastText word embedding come from the defined Vocabulary object `my_vocab`.
+
+```python
+>>> my_embedding = text.embedding.create('fasttext', 
pretrained_file_name='wiki.simple.vec',
+...     vocabulary=my_vocab)
+
+```
+
+Now we are ready to look up the fastText word embedding vectors for indexed 
words, such as 'hello'
+and 'world'.
+
+```python
+>>> my_embedding.get_vecs_by_tokens(['hello', 'world'])
+
+[[  3.95669997e-01   2.14540005e-01  -3.53889987e-02  -2.42990002e-01
+    ...
+   -7.54180014e-01  -3.14429998e-01   2.40180008e-02  -7.61009976e-02]
+ [  1.04440004e-01  -1.08580001e-01   2.72119999e-01   1.32990003e-01
+    ...
+   -3.73499990e-01   5.67310005e-02   5.60180008e-01   2.90190000e-02]]
+<NDArray 2x300 @cpu(0)>
+
+```
+
+### Using pre-trained word embeddings in `gluon`
+
+To demonstrate how to use pre-trained word embeddings in the `gluon` package, 
let us first obtain
+indices of the words 'hello' and 'world'.
+
+```python
+>>> my_embedding.to_indices(['hello', 'world'])
+[2, 1]
+
+```
+
+We can obtain the vector representations for the words 'hello' and 'world' by specifying their
+indices (2 and 1) and using `my_embedding.idx_to_vec` as the weight of `mxnet.gluon.nn.Embedding`.
+ 
+```python
+>>> layer = gluon.nn.Embedding(len(my_embedding), my_embedding.vec_len)
+>>> layer.initialize()
+>>> layer.weight.set_data(my_embedding.idx_to_vec)
+>>> layer(nd.array([2, 1]))
+
+[[  3.95669997e-01   2.14540005e-01  -3.53889987e-02  -2.42990002e-01
+    ...
+   -7.54180014e-01  -3.14429998e-01   2.40180008e-02  -7.61009976e-02]
+ [  1.04440004e-01  -1.08580001e-01   2.72119999e-01   1.32990003e-01
+    ...
+   -3.73499990e-01   5.67310005e-02   5.60180008e-01   2.90190000e-02]]
+<NDArray 2x300 @cpu(0)>
+
+```
+
+## Vocabulary
+
+The vocabulary builds indices for text tokens. Such indexed tokens can be used by token
+embedding instances. The input counter, whose keys are the candidate tokens to index, may be
+obtained via [`count_tokens_from_str`](#mxnet.contrib.text.utils.count_tokens_from_str).
+
+
+```eval_rst
+.. currentmodule:: mxnet.contrib.text.vocab
+.. autosummary::
+    :nosignatures:
+
+    Vocabulary
+```
+
+Suppose that we have a simple text data set in the string format. We can count 
word frequency in the
+data set.
+
+```python
+>>> text_data = " hello world \n hello nice world \n hi world \n"
+>>> counter = text.utils.count_tokens_from_str(text_data)
+
+```
+
+The obtained `counter` has key-value pairs whose keys are words and values are 
word frequencies.
+Suppose that we want to build indices for the 2 most frequent keys in 
`counter` with the unknown
+token representation '<UnK>' and a reserved token '<pad>'.
+
+```python
+>>> my_vocab = text.vocab.Vocabulary(counter, most_freq_count=2, 
unknown_token='<UnK>', 
+...     reserved_tokens=['<pad>'])
+
+```
+
+We can access properties such as `token_to_idx` (mapping tokens to indices), `idx_to_token`
+(mapping indices to tokens), `unknown_token` (representation of any unknown token), and
+`reserved_tokens` (reserved tokens that will always be indexed).
+
+
+```python
+>>> my_vocab.token_to_idx
+{'<UnK>': 0, '<pad>': 1, 'world': 2, 'hello': 3}
+>>> my_vocab.idx_to_token
+['<UnK>', '<pad>', 'world', 'hello']
+>>> my_vocab.unknown_token
+'<UnK>'
+>>> my_vocab.reserved_tokens
+['<pad>']
+>>> len(my_vocab)
+4
+```
+
+Besides the specified unknown token '<UnK>' and the reserved token '<pad>', the 2 most frequent
+words 'world' and 'hello' are also indexed.
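+
+The `to_indices` and `to_tokens` methods convert between tokens and indices. A minimal sketch,
+assuming the index mapping shown above:
+
+```python
+>>> my_vocab.to_indices(['hello', 'world'])
+[3, 2]
+>>> my_vocab.to_tokens([3, 2])
+['hello', 'world']
+
+```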
+
+
+
+
+## Text token embedding
+
+To load token embeddings from an externally hosted pre-trained token embedding 
file, such as those
+of GloVe and FastText, use
+[`embedding.create(embedding_name, 
pretrained_file_name)`](#mxnet.contrib.text.embedding.create).
+
+To get all the available `embedding_name` and `pretrained_file_name`, use
+[`embedding.get_pretrained_file_names()`](#mxnet.contrib.text.embedding.get_pretrained_file_names).
+
+```python
+>>> text.embedding.get_pretrained_file_names()
+{'glove': ['glove.42B.300d.txt', 'glove.6B.50d.txt', 'glove.6B.100d.txt', ...],
+'fasttext': ['wiki.en.vec', 'wiki.simple.vec', 'wiki.zh.vec', ...]}
+
+```
+
+Alternatively, to load embedding vectors from a custom pre-trained text token
+embedding file, use 
[`CustomEmbedding`](#mxnet.contrib.text.embedding.CustomEmbedding).
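+
+For example, assuming a hypothetical file `my_pretrain_file.txt` in which each line holds a token
+followed by its embedding vector elements, all separated by `elem_delim`, a minimal sketch is:
+
+```python
+>>> my_custom_embedding = text.embedding.CustomEmbedding('my_pretrain_file.txt',
+...     elem_delim=' ', init_unknown_vec=nd.zeros)
+
+```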
+
+Moreover, to load composite embedding vectors, such as to concatenate 
embedding vectors,
+use [`CompositeEmbedding`](#mxnet.contrib.text.embedding.CompositeEmbedding).
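+
+As a rough sketch (reusing `my_vocab` from the Vocabulary section above; the variable names are
+only illustrative), a composite embedding that concatenates a GloVe and a fastText embedding for
+each token in the vocabulary could be created as follows:
+
+```python
+>>> glove_6b_50d = text.embedding.create('glove', pretrained_file_name='glove.6B.50d.txt')
+>>> fasttext_simple = text.embedding.create('fasttext', pretrained_file_name='wiki.simple.vec')
+>>> my_composite = text.embedding.CompositeEmbedding(my_vocab, [glove_6b_50d, fasttext_simple])
+
+```
+
+Each token in `my_vocab` would then map to a 350-dimensional vector (50 dimensions from GloVe
+plus 300 from fastText).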
+
+The indexed tokens in a text token embedding may come from a vocabulary or from the loaded
+embedding vectors. In the former case, only the indexed tokens in a vocabulary are associated
+with the loaded embedding vectors, such as those loaded from a pre-trained token embedding file.
+In the latter case, all the tokens from the loaded embedding vectors, such as those loaded from a
+pre-trained token embedding file, are taken as the indexed tokens of the embedding.
+
+
+```eval_rst
+.. currentmodule:: mxnet.contrib.text.embedding
+.. autosummary::
+    :nosignatures:
+
+    register
+    create
+    get_pretrained_file_names
+    GloVe
+    FastText
+    CustomEmbedding
+    CompositeEmbedding
+```
+
+
+### Indexed tokens are from a vocabulary
+
+One can specify that only the indexed tokens in a vocabulary are associated with the loaded
+embedding vectors, such as those loaded from a pre-trained token embedding file.
+
+To begin with, suppose that we have a simple text data set in the string 
format. We can count word
+frequency in the data set.
+
+```python
+>>> text_data = " hello world \n hello nice world \n hi world \n"
+>>> counter = text.utils.count_tokens_from_str(text_data)
+
+```
+
+The obtained `counter` has key-value pairs whose keys are words and values are word frequencies.
+Suppose that we want to build indices for the 2 most frequent keys in `counter` and load the
+fastText word embedding with pre-trained file `wiki.simple.vec` for these 2 words.
+
+```python
+>>> my_vocab = text.vocab.Vocabulary(counter, most_freq_count=2)
+>>> my_embedding = text.embedding.create('fasttext', 
pretrained_file_name='wiki.simple.vec',
+...     vocabulary=my_vocab)
+
+```
+
+Now we are ready to look up the fastText word embedding vectors for indexed 
words.
+
+```python
+>>> my_embedding.get_vecs_by_tokens(['hello', 'world'])
+
+[[  3.95669997e-01   2.14540005e-01  -3.53889987e-02  -2.42990002e-01
+    ...
+   -7.54180014e-01  -3.14429998e-01   2.40180008e-02  -7.61009976e-02]
+ [  1.04440004e-01  -1.08580001e-01   2.72119999e-01   1.32990003e-01
+    ...
+   -3.73499990e-01   5.67310005e-02   5.60180008e-01   2.90190000e-02]]
+<NDArray 2x300 @cpu(0)>
+
+```
+
+We can also access properties such as `token_to_idx` (mapping tokens to 
indices), `idx_to_token`
+(mapping indices to tokens), and `vec_len` (length of each embedding vector).
+
+```python
+>>> my_embedding.token_to_idx
+{'<unk>': 0, 'world': 1, 'hello': 2}
+>>> my_embedding.idx_to_token
+['<unk>', 'world', 'hello']
+>>> len(my_embedding)
+3
+>>> my_embedding.vec_len
+300
+
+```
+
+If a token such as 'nice' is unknown to `my_embedding`, its embedding vector is initialized
+according to the default specification of `init_unknown_vec` (all elements are 0).
+
+```python
+
+>>> my_embedding.get_vecs_by_tokens('nice')
+
+[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
+  ...
+  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
+<NDArray 300 @cpu(0)>
+
+```
+
+
+### Indexed tokens are from the loaded embedding vectors
+
+One can also use all the tokens from the loaded embedding vectors, such as those loaded from a
+pre-trained token embedding file, as the indexed tokens of the embedding.
+
+To begin with, we can create a fastText word embedding object by specifying the embedding name
+'fasttext' and the pre-trained file 'wiki.simple.vec'. The argument `init_unknown_vec` specifies
+the default vector representation for any unknown token. To index all the tokens from this
+pre-trained word embedding file, we do not need to specify any vocabulary.
+
+```python
+>>> my_embedding = text.embedding.create('fasttext', 
pretrained_file_name='wiki.simple.vec',
+...     init_unknown_vec=nd.zeros)
+
+```
+
+We can access properties such as `token_to_idx` (mapping tokens to indices), 
`idx_to_token` (mapping
+indices to tokens), `vec_len` (length of each embedding vector), and 
`unknown_token` (representation
+of any unknown token, default value is '<unk>').
+
+```python
+>>> my_embedding.token_to_idx['nice']
+2586
+>>> my_embedding.idx_to_token[2586]
+'nice'
+>>> my_embedding.vec_len
+300
+>>> my_embedding.unknown_token
+'<unk>'
+
+```
+
+For every unknown token, if its representation '<unk>' is encountered in the 
pre-trained token
+embedding file, index 0 of property `idx_to_vec` maps to the pre-trained token 
embedding vector
+loaded from the file; otherwise, index 0 of property `idx_to_vec` maps to the 
default token
+embedding vector specified via `init_unknown_vec` (set to nd.zeros here). 
Since the pre-trained file
+does not have a vector for the token '<unk>', index 0 has to map to an 
additional token '<unk>' and
+the number of tokens in the embedding is 111,052.
+
+
+```python
+>>> len(my_embedding)
+111052
+>>> my_embedding.idx_to_vec[0]
+
+[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
+  ...
+  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
+<NDArray 300 @cpu(0)>
+>>> my_embedding.get_vecs_by_tokens('nice')
+
+[ 0.49397001  0.39996001  0.24000999 -0.15121    -0.087512    0.37114
+  ...
+  0.089521    0.29175001 -0.40917999 -0.089206   -0.1816     -0.36616999]
+<NDArray 300 @cpu(0)>
+>>> my_embedding.get_vecs_by_tokens(['unknownT0kEN', 'unknownT0kEN'])
+
+[[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
+   ...
+   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
+ [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
+   ...
+   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]]
+<NDArray 2x300 @cpu(0)>
+
+```
+
+
+### Implement a new text token embedding
+
+To implement a new text token embedding, create a subclass of
+`mxnet.contrib.text.embedding._TokenEmbedding`.
+Also add ``@mxnet.contrib.text.embedding.register`` before this class. See
+[`embedding.py`](https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/contrib/text/embedding.py)
+for examples.
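+
+A minimal sketch, adapted from the docstring of `register` (the class name, the trivial
+constructor, and the pre-trained file name below are placeholders):
+
+```python
+>>> @text.embedding.register
+... class MyTextEmbed(text.embedding._TokenEmbedding):
+...     def __init__(self, pretrained_file_name='my_pretrain_file'):
+...         pass
+>>> my_embed = text.embedding.create('MyTextEmbed')
+>>> print(type(my_embed))
+<class '__main__.MyTextEmbed'>
+
+```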
+
+
+## Text utilities
+
+The following functions provide utilities for text data processing.
+
+```eval_rst
+.. currentmodule:: mxnet.contrib.text.utils
+.. autosummary::
+    :nosignatures:
+
+    count_tokens_from_str
+```
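+
+As a quick illustration with the sample string used earlier, `count_tokens_from_str` returns a
+counter of token frequencies (assuming a `collections.Counter`-like result; the ordering of ties
+below is only illustrative):
+
+```python
+>>> text.utils.count_tokens_from_str(" hello world \n hello nice world \n hi world \n")
+Counter({'world': 3, 'hello': 2, 'nice': 1, 'hi': 1})
+
+```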
+
+
+## API Reference
+
+<script type="text/javascript" 
src='../../_static/js/auto_module_index.js'></script>
+
+```eval_rst
+
+.. automodule:: mxnet.contrib.text.embedding
+    :members: register, create, get_pretrained_file_names
+.. autoclass:: mxnet.contrib.text.embedding.GloVe
+    :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens
+.. autoclass:: mxnet.contrib.text.embedding.FastText
+    :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens
+.. autoclass:: mxnet.contrib.text.embedding.CustomEmbedding
+    :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens
+.. autoclass:: mxnet.contrib.text.embedding.CompositeEmbedding
+    :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens
+
+.. automodule:: mxnet.contrib.text.vocab
+.. autoclass:: mxnet.contrib.text.vocab.Vocabulary
+    :members: to_indices, to_tokens
+
+.. automodule:: mxnet.contrib.text.utils
+    :members: count_tokens_from_str
+
+```
+<script>auto_index("api-reference");</script>
\ No newline at end of file
diff --git a/docs/api/python/index.md b/docs/api/python/index.md
index 7a3ad7c..f65d3ab 100644
--- a/docs/api/python/index.md
+++ b/docs/api/python/index.md
@@ -98,15 +98,6 @@ imported by running:
    io/io.md
 ```
 
-## Text API
-
-```eval_rst
-.. toctree::
-   :maxdepth: 1
-
-   text/text.md
-```
-
 ## Image API
 
 ```eval_rst
@@ -151,3 +142,13 @@ imported by running:
 
    rtc/rtc.md
 ```
+
+## Contrib Package
+
+```eval_rst
+.. toctree::
+   :maxdepth: 1
+
+   contrib/contrib.md
+   contrib/text.md
+```
diff --git a/docs/api/python/text/text.md b/docs/api/python/text/text.md
deleted file mode 100644
index 3b70b76..0000000
--- a/docs/api/python/text/text.md
+++ /dev/null
@@ -1,455 +0,0 @@
-# Text API
-
-## Overview
-
-The mxnet.contrib.text APIs refer to classes and functions related to text data
-processing, such as bulding indices and loading pre-trained embedding vectors
-for text tokens and storing them in the `mxnet.ndarray.NDArray` format.
-
-```eval_rst
-.. warning:: This package contains experimental APIs and may change in the 
near future.
-```
-
-This document lists the text APIs in mxnet:
-
-```eval_rst
-.. autosummary::
-    :nosignatures:
-
-    mxnet.contrib.text.glossary
-    mxnet.contrib.text.embedding
-    mxnet.contrib.text.indexer
-    mxnet.contrib.text.utils
-```
-
-All the code demonstrated in this document assumes that the following modules
-or packages are imported.
-
-```python
->>> from mxnet import gluon
->>> from mxnet import nd
->>> from mxnet.contrib import text
->>> import collections
-
-```
-
-### Look up pre-trained word embeddings for indexed words
-
-As a common use case, let us look up pre-trained word embedding vectors for
-indexed words in just a few lines of code. To begin with, we can create a
-fastText word embedding object by specifying the embedding name `fasttext` and
-the pre-trained file `wiki.simple.vec`.
-
-```python
->>> fasttext_simple = text.embedding.TokenEmbedding.create('fasttext',
-...     pretrained_file_name='wiki.simple.vec')
-
-```
-
-Suppose that we have a simple text data set in the string format. We can count
-word frequency in the data set.
-
-```python
->>> text_data = " hello world \n hello nice world \n hi world \n"
->>> counter = text.utils.count_tokens_from_str(text_data)
-
-```
-
-The obtained `counter` has key-value pairs whose keys are words and values are
-word frequencies. Suppose that we want to build indices for all the keys in
-`counter` and load the defined fastText word embedding for all such indexed
-words. First, we need a TokenIndexer object with `counter` as its argument
-
-```python
->>> token_indexer = text.indexer.TokenIndexer(counter)
-
-```
-
-Then, we can create a Glossary object by specifying `token_indexer` and 
`fasttext_simple` as its
-arguments.
-
-```python
->>> glossary = text.glossary.Glossary(token_indexer, fasttext_simple)
-
-```
-
-Now we are ready to look up the fastText word embedding vectors for indexed
-words.
-
-```python
->>> glossary.get_vecs_by_tokens(['hello', 'world'])
-
-[[  3.95669997e-01   2.14540005e-01  -3.53889987e-02  -2.42990002e-01
-    ...
-   -7.54180014e-01  -3.14429998e-01   2.40180008e-02  -7.61009976e-02]
- [  1.04440004e-01  -1.08580001e-01   2.72119999e-01   1.32990003e-01
-    ...
-   -3.73499990e-01   5.67310005e-02   5.60180008e-01   2.90190000e-02]]
-<NDArray 2x300 @cpu(0)>
-
-```
-
-### Use `glossary` in `gluon`
-
-To demonstrate how to use a glossary with the loaded word embedding in the
-`gluon` package, let us first obtain indices of the words 'hello' and 'world'.
-
-```python
->>> glossary.to_indices(['hello', 'world'])
-[2, 1]
-
-```
-
-We can obtain the vector representation for the words 'hello' and 'world'
-by specifying their indices (2 and 1) and the `glossary.idx_to_vec` in
-`mxnet.gluon.nn.Embedding`.
- 
-```python
->>> layer = gluon.nn.Embedding(len(glossary), glossary.vec_len)
->>> layer.initialize()
->>> layer.weight.set_data(glossary.idx_to_vec)
->>> layer(nd.array([2, 1]))
-
-[[  3.95669997e-01   2.14540005e-01  -3.53889987e-02  -2.42990002e-01
-    ...
-   -7.54180014e-01  -3.14429998e-01   2.40180008e-02  -7.61009976e-02]
- [  1.04440004e-01  -1.08580001e-01   2.72119999e-01   1.32990003e-01
-    ...
-   -3.73499990e-01   5.67310005e-02   5.60180008e-01   2.90190000e-02]]
-<NDArray 2x300 @cpu(0)>
-
-```
-
-
-## Glossary
-
-The glossary provides indexing and embedding for text tokens in a glossary. For
-each indexed token in a glossary, an embedding vector will be associated with
-it. Such embedding vectors can be loaded from externally hosted or custom
-pre-trained token embedding files, such as via instances of
-[`TokenEmbedding`](#mxnet.contrib.text.embedding.TokenEmbedding). 
-The input counter whose keys are
-candidate indices may be obtained via
-[`count_tokens_from_str`](#mxnet.contrib.text.utils.count_tokens_from_str).
-
-```eval_rst
-.. currentmodule:: mxnet.contrib.text.glossary
-.. autosummary::
-    :nosignatures:
-
-    Glossary
-```
-
-To get all the valid names for pre-trained embeddings and files, we can use
-[`TokenEmbedding.get_embedding_and_pretrained_file_names`](#mxnet.contrib.text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names).
-
-```python
->>> text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names()
-{'glove': ['glove.42B.300d.txt', 'glove.6B.50d.txt', 'glove.6B.100d.txt',
-'glove.6B.200d.txt', 'glove.6B.300d.txt', 'glove.840B.300d.txt',
-'glove.twitter.27B.25d.txt', 'glove.twitter.27B.50d.txt',
-'glove.twitter.27B.100d.txt', 'glove.twitter.27B.200d.txt'],
-'fasttext': ['wiki.en.vec', 'wiki.simple.vec', 'wiki.zh.vec']}
-
-```
-
-To begin with, we can create a fastText word embedding object by specifying the
-embedding name `fasttext` and the pre-trained file `wiki.simple.vec`.
-
-```python
->>> fasttext_simple = text.embedding.TokenEmbedding.create('fasttext',
-...     pretrained_file_name='wiki.simple.vec')
-
-```
-
-Suppose that we have a simple text data set in the string format. We can count
-word frequency in the data set.
-
-```python
->>> text_data = " hello world \n hello nice world \n hi world \n"
->>> counter = text.utils.count_tokens_from_str(text_data)
-
-```
-
-The obtained `counter` has key-value pairs whose keys are words and values are
-word frequencies. Suppose that we want to build indices for the most frequent 2
-keys in `counter` and load the defined fastText word embedding for all these
-2 words. 
-
-```python
->>> token_indexer = text.indexer.TokenIndexer(counter, most_freq_count=2)
->>> glossary = text.glossary.Glossary(token_indexer, fasttext_simple)
-
-```
-
-Now we are ready to look up the fastText word embedding vectors for indexed
-words.
-
-```python
->>> glossary.get_vecs_by_tokens(['hello', 'world'])
-
-[[  3.95669997e-01   2.14540005e-01  -3.53889987e-02  -2.42990002e-01
-    ...
-   -7.54180014e-01  -3.14429998e-01   2.40180008e-02  -7.61009976e-02]
- [  1.04440004e-01  -1.08580001e-01   2.72119999e-01   1.32990003e-01
-    ...
-   -3.73499990e-01   5.67310005e-02   5.60180008e-01   2.90190000e-02]]
-<NDArray 2x300 @cpu(0)>
-
-```
-
-We can also access properties such as `token_to_idx` (mapping tokens to
-indices), `idx_to_token` (mapping indices to tokens), and `vec_len`
-(length of each embedding vector).
-
-```python
->>> glossary.token_to_idx
-{'<unk>': 0, 'world': 1, 'hello': 2, 'hi': 3, 'nice': 4}
->>> glossary.idx_to_token
-['<unk>', 'world', 'hello', 'hi', 'nice']
->>> len(glossary)
-5
->>> glossary.vec_len
-300
-
-```
-
-If a token is unknown to `glossary`, its embedding vector is initialized
-according to the default specification in `fasttext_simple` (all elements are
-0).
-
-```python
-
->>> glossary.get_vecs_by_tokens('unknownT0kEN')
-
-[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
-  ...
-  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
-<NDArray 300 @cpu(0)>
-
-```
-
-## Text token embedding
-
-The text token embedding builds indices for text tokens. Such indexed tokens 
can
-be used by instances of 
[`TokenEmbedding`](#mxnet.contrib.text.embedding.TokenEmbedding)
-and [`Glossary`](#mxnet.contrib.text.glossary.Glossary).
-
-To load token embeddings from an externally hosted pre-trained token embedding
-file, such as those of GloVe and FastText, use
-[`TokenEmbedding.create(embedding_name, 
pretrained_file_name)`](#mxnet.contrib.text.embedding.TokenEmbedding.create).
-To get all the available `embedding_name` and `pretrained_file_name`, use
-[`TokenEmbedding.get_embedding_and_pretrained_file_names()`](#mxnet.contrib.text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names).
-
-Alternatively, to load embedding vectors from a custom pre-trained text token
-embedding file, use 
[`CustomEmbedding`](#mxnet.contrib.text.embedding.CustomEmbedding).
-
-
-```eval_rst
-.. currentmodule:: mxnet.contrib.text.embedding
-.. autosummary::
-    :nosignatures:
-
-    TokenEmbedding
-    GloVe
-    FastText
-    CustomEmbedding
-```
-
-To get all the valid names for pre-trained embeddings and files, we can use
-[`TokenEmbedding.get_embedding_and_pretrained_file_names`](#mxnet.contrib.text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names).
-
-```python
->>> text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names()
-{'glove': ['glove.42B.300d.txt', 'glove.6B.50d.txt', 'glove.6B.100d.txt',
-'glove.6B.200d.txt', 'glove.6B.300d.txt', 'glove.840B.300d.txt',
-'glove.twitter.27B.25d.txt', 'glove.twitter.27B.50d.txt',
-'glove.twitter.27B.100d.txt', 'glove.twitter.27B.200d.txt'],
-'fasttext': ['wiki.en.vec', 'wiki.simple.vec', 'wiki.zh.vec']}
-
-```
-
-To begin with, we can create a GloVe word embedding object by specifying the
-embedding name `glove` and the pre-trained file `glove.6B.50d.txt`. The
-argument `init_unknown_vec` specifies default vector representation for any
-unknown token.
-
-```python
->>> glove_6b_50d = text.embedding.TokenEmbedding.create('glove',
-...     pretrained_file_name='glove.6B.50d.txt', init_unknown_vec=nd.zeros)
-
-```
-
-We can access properties such as `token_to_idx` (mapping tokens to indices),
-`idx_to_token` (mapping indices to tokens), `vec_len` (length of each embedding
-vector), and `unknown_token` (representation of any unknown token, default
-value is '<unk>').
-
-```python
->>> glove_6b_50d.token_to_idx['hi']
-11084
->>> glove_6b_50d.idx_to_token[11084]
-'hi'
->>> glove_6b_50d.vec_len
-50
->>> glove_6b_50d.unknown_token
-'<unk>'
-
-```
-
-For every unknown token, if its representation '<unk>' is encountered in the
-pre-trained token embedding file, index 0 of property `idx_to_vec` maps to the
-pre-trained token embedding vector loaded from the file; otherwise, index 0 of
-property `idx_to_vec` maps to the default token embedding vector specified via
-`init_unknown_vec` (set to nd.zeros here). Since the pre-trained file
-does not have a vector for the token '<unk>', index 0 has to map to an
-additional token '<unk>' and the number of tokens in the embedding is 400,001.
-
-
-```python
->>> len(glove_6b_50d)
-400001
->>> glove_6b_50d.idx_to_vec[0]
-
-[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
-  ...
-  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
-<NDArray 50 @cpu(0)>
->>> glove_6b_50d.get_vecs_by_tokens('unknownT0kEN')
-
-[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
-  ...
-  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
-<NDArray 50 @cpu(0)>
->>> glove_6b_50d.get_vecs_by_tokens(['unknownT0kEN', 'unknownT0kEN'])
-
-[[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
-   ...
-   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
- [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
-   ...
-   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]]
-<NDArray 2x50 @cpu(0)>
-
-```
-
-
-### Implement a new text token embedding
-
-For ``optimizer``, create a subclass of
-[`TokenEmbedding`](#mxnet.contrib.text.embedding.TokenEmbedding).
-Also add ``@TokenEmbedding.register`` before this class. See
-[`embedding.py`](https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/contrib/text/embedding.py)
-for examples.
-
-
-## Text token indexer
-
-The text token indexer builds indices for text tokens. Such indexed tokens can
-be used by instances of 
[`TokenEmbedding`](#mxnet.contrib.text.embedding.TokenEmbedding)
-and [`Glossary`](#mxnet.contrib.text.glossary.Glossary). The input
-counter whose keys are candidate indices may be obtained via
-[`count_tokens_from_str`](#mxnet.contrib.text.utils.count_tokens_from_str).
-
-
-```eval_rst
-.. currentmodule:: mxnet.contrib.text.indexer
-.. autosummary::
-    :nosignatures:
-
-    TokenIndexer
-```
-
-Suppose that we have a simple text data set in the string format. We can count
-word frequency in the data set.
-
-```python
->>> text_data = " hello world \n hello nice world \n hi world \n"
->>> counter = text.utils.count_tokens_from_str(text_data)
-
-```
-
-The obtained `counter` has key-value pairs whose keys are words and values are
-word frequencies. Suppose that we want to build indices for the 2 most frequent
-keys in `counter` with the unknown token representation '<UnK>' and a reserved
-token '<pad>'.
-
-```python
->>> token_indexer = text.indexer.TokenIndexer(counter, most_freq_count=2,
-...     unknown_token='<UnK>', reserved_tokens=['<pad>'])
-
-```
-
-We can access properties such as `token_to_idx` (mapping tokens to indices),
-`idx_to_token` (mapping indices to tokens), `vec_len` (length of each embedding
-vector), and `unknown_token` (representation of any unknown token) and
-`reserved_tokens`.
-
-```python
->>> token_indexer = text.indexer.TokenIndexer(counter, most_freq_count=2,
-...     unknown_token='<UnK>', reserved_tokens=['<pad>'])
-
-```
-
-```python
->>> token_indexer.token_to_idx
-{'<UnK>': 0, '<pad>': 1, 'world': 2, 'hello': 3}
->>> token_indexer.idx_to_token
-['<UnK>', '<pad>', 'world', 'hello']
->>> token_indexer.unknown_token
-'<UnK>'
->>> token_indexer.reserved_tokens
-['<pad>']
->>> len(token_indexer)
-4
-```
-
-Besides the specified unknown token '<UnK>' and reserved_token '<pad>' are
-indexed, the 2 most frequent words 'world' and 'hello' are also indexed.
-
-
-
-## Text utilities
-
-The following functions provide utilities for text data processing.
-
-```eval_rst
-.. currentmodule:: mxnet.contrib.text.utils
-.. autosummary::
-    :nosignatures:
-
-    count_tokens_from_str
-```
-
-
-
-
-## API Reference
-
-<script type="text/javascript" 
src='../../_static/js/auto_module_index.js'></script>
-
-```eval_rst
-
-.. automodule:: mxnet.contrib.text.glossary
-.. autoclass:: mxnet.contrib.text.glossary.Glossary
-    :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens
-
-.. automodule:: mxnet.contrib.text.embedding
-.. autoclass:: mxnet.contrib.text.embedding.TokenEmbedding
-    :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens, 
register, create, get_embedding_and_pretrained_file_names
-.. autoclass:: mxnet.contrib.text.embedding.GloVe
-    :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens
-.. autoclass:: mxnet.contrib.text.embedding.FastText
-    :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens
-.. autoclass:: mxnet.contrib.text.embedding.CustomEmbedding
-    :members: get_vecs_by_tokens, update_token_vectors, to_indices, to_tokens
-
-.. automodule:: mxnet.contrib.text.indexer
-.. autoclass:: mxnet.contrib.text.indexer.TokenIndexer
-    :members: to_indices, to_tokens
-
-.. automodule:: mxnet.contrib.text.utils
-    :members: count_tokens_from_str
-
-```
-<script>auto_index("api-reference");</script>
\ No newline at end of file
diff --git a/python/mxnet/contrib/text/__init__.py 
b/python/mxnet/contrib/text/__init__.py
index fff2b94..0a0f173 100644
--- a/python/mxnet/contrib/text/__init__.py
+++ b/python/mxnet/contrib/text/__init__.py
@@ -19,6 +19,5 @@
 """This module includes utilities for indexing and embedding text."""
 
 from . import utils
-from . import indexer
+from . import vocab
 from . import embedding
-from . import glossary
diff --git a/python/mxnet/contrib/text/embedding.py 
b/python/mxnet/contrib/text/embedding.py
index 54635f1..4fc6aac 100644
--- a/python/mxnet/contrib/text/embedding.py
+++ b/python/mxnet/contrib/text/embedding.py
@@ -17,6 +17,7 @@
 
 # coding: utf-8
 # pylint: disable=consider-iterating-dictionary
+# pylint: disable=super-init-not-called
 
 """Text token embeddings."""
 from __future__ import absolute_import
@@ -30,23 +31,120 @@ import warnings
 import zipfile
 
 from . import _constants as C
-from . import indexer
+from . import vocab
 from ... import ndarray as nd
 from ... import registry
 
 
-class TokenEmbedding(indexer.TokenIndexer):
+def register(embedding_cls):
+    """Registers a new token embedding.
+
+
+    Once an embedding is registered, we can create an instance of this 
embedding with
+    :func:`~mxnet.contrib.text.embedding.create`.
+
+
+    Examples
+    --------
+    >>> @mxnet.contrib.text.embedding.register
+    ... class MyTextEmbed(mxnet.contrib.text.embedding._TokenEmbedding):
+    ...     def __init__(self, pretrained_file_name='my_pretrain_file'):
+    ...         pass
+    >>> embed = mxnet.contrib.text.embedding.create('MyTextEmbed')
+    >>> print(type(embed))
+    <class '__main__.MyTextEmbed'>
+    """
+
+    register_text_embedding = registry.get_register_func(_TokenEmbedding, 
'token embedding')
+    return register_text_embedding(embedding_cls)
+
+
+def create(embedding_name, **kwargs):
+    """Creates an instance of token embedding.
+
+
+    Creates a token embedding instance by loading embedding vectors from an 
externally hosted
+    pre-trained token embedding file, such as those of GloVe and FastText. To 
get all the valid
+    `embedding_name` and `pretrained_file_name`, use
+    `mxnet.contrib.text.embedding.get_pretrained_file_names()`.
+
+
+    Parameters
+    ----------
+    embedding_name : str
+        The token embedding name (case-insensitive).
+
+
+    Returns
+    -------
+    An instance of `mxnet.contrib.text.embedding._TokenEmbedding`:
+        A token embedding instance that loads embedding vectors from an 
externally hosted
+        pre-trained token embedding file.
+    """
+
+    create_text_embedding = registry.get_create_func(_TokenEmbedding, 'token 
embedding')
+    return create_text_embedding(embedding_name, **kwargs)
+
+
+def get_pretrained_file_names(embedding_name=None):
+    """Get valid token embedding names and their pre-trained file names.
+
+
+    To load token embedding vectors from an externally hosted pre-trained 
token embedding file,
+    such as those of GloVe and FastText, one should use
+    `mxnet.contrib.text.embedding.create(embedding_name, 
pretrained_file_name)`.
+    This method returns all the valid names of `pretrained_file_name` for the 
specified
+    `embedding_name`. If `embedding_name` is set to None, this method returns 
all the valid
+    names of `embedding_name` with their associated `pretrained_file_name`.
+
+
+    Parameters
+    ----------
+    embedding_name : str or None, default None
+        The pre-trained token embedding name.
+
+
+    Returns
+    -------
+    dict or list:
+        A list of all the valid pre-trained token embedding file names 
(`pretrained_file_name`)
+        for the specified token embedding name (`embedding_name`). If the text embedding name is
+        set to None, returns a dict mapping each valid token embedding name to 
a list of valid
+        pre-trained files (`pretrained_file_name`). They can be plugged into
+        `mxnet.contrib.text.embedding.create(embedding_name,
+        pretrained_file_name)`.
+    """
+
+    text_embedding_reg = registry.get_registry(_TokenEmbedding)
+
+    if embedding_name is not None:
+        if embedding_name not in text_embedding_reg:
+            raise KeyError('Cannot find `embedding_name` %s. Use '
+                           '`get_pretrained_file_names('
+                           'embedding_name=None).keys()` to get all the valid 
embedding '
+                           'names.' % embedding_name)
+        return 
list(text_embedding_reg[embedding_name].pretrained_file_name_sha1.keys())
+    else:
+        return {embedding_name: 
list(embedding_cls.pretrained_file_name_sha1.keys())
+                for embedding_name, embedding_cls in 
registry.get_registry(_TokenEmbedding).items()}
+
+
+class _TokenEmbedding(vocab.Vocabulary):
     """Token embedding base class.
 
 
     To load token embeddings from an externally hosted pre-trained token 
embedding file, such as
-    those of GloVe and FastText, use `TokenEmbedding.create(embedding_name, 
pretrained_file_name)`.
+    those of GloVe and FastText, use
+    :func:`~mxnet.contrib.text.embedding.create(embedding_name, 
pretrained_file_name)`.
     To get all the available `embedding_name` and `pretrained_file_name`, use
-    `TokenEmbedding.get_embedding_and_pretrained_file_names()`.
+    :func:`~mxnet.contrib.text.embedding.get_pretrained_file_names()`.
 
     Alternatively, to load embedding vectors from a custom pre-trained token 
embedding file, use
     :class:`~mxnet.contrib.text.embedding.CustomEmbedding`.
 
+    Moreover, to load composite embedding vectors, such as to concatenate 
embedding vectors, use
+    :class:`~mxnet.contrib.text.embedding.CompositeEmbedding`.
+
     For every unknown token, if its representation `self.unknown_token` is 
encountered in the
     pre-trained token embedding file, index 0 of `self.idx_to_vec` maps to the 
pre-trained token
     embedding vector loaded from the file; otherwise, index 0 of 
`self.idx_to_vec` maps to the
@@ -55,8 +153,11 @@ class TokenEmbedding(indexer.TokenIndexer):
     If a token is encountered multiple times in the pre-trained token 
embedding file, only the
     first-encountered token embedding vector will be loaded and the rest will 
be skipped.
 
-    For the same token, its index and embedding vector may vary across 
different instances of
-    :class:`~mxnet.contrib.text.embedding.TokenEmbedding`.
+    The indexed tokens in a text token embedding may come from a vocabulary or 
from the loaded
+    embedding vectors. In the former case, only the indexed tokens in a 
vocabulary are associated
+    with the loaded embedding vectors, such as those loaded from a pre-trained token embedding
+    file. In the latter case, all the tokens from the loaded embedding vectors, such as those
+    loaded from a pre-trained token embedding file, are taken as the indexed tokens of the
+    embedding.
 
 
     Properties
@@ -79,7 +180,7 @@ class TokenEmbedding(indexer.TokenIndexer):
     """
 
     def __init__(self, **kwargs):
-        super(TokenEmbedding, self).__init__(**kwargs)
+        super(_TokenEmbedding, self).__init__(**kwargs)
 
     @classmethod
     def _get_download_file_name(cls, pretrained_file_name):
@@ -178,8 +279,8 @@ class TokenEmbedding(indexer.TokenIndexer):
                 else:
                     if vec_len is None:
                         vec_len = len(elems)
-                        # Reserve a vector slot for the unknown token at the
-                        # very beggining because the unknown index is 0.
+                        # Reserve a vector slot for the unknown token at the very beginning
+                        # because the unknown index is 0.
                         all_elems.extend([0] * vec_len)
                     else:
                         assert len(elems) == vec_len, \
@@ -200,6 +301,59 @@ class TokenEmbedding(indexer.TokenIndexer):
         else:
             self._idx_to_vec[C.UNKNOWN_IDX] = nd.array(loaded_unknown_vec)
 
+    def _index_tokens_from_vocabulary(self, vocabulary):
+        self._token_to_idx = vocabulary.token_to_idx.copy() \
+            if vocabulary.token_to_idx is not None else None
+        self._idx_to_token = vocabulary.idx_to_token[:] \
+            if vocabulary.idx_to_token is not None else None
+        self._unknown_token = vocabulary.unknown_token
+        self._reserved_tokens = vocabulary.reserved_tokens[:] \
+            if vocabulary.reserved_tokens is not None else None
+
+    def _set_idx_to_vec_by_embeddings(self, token_embeddings, vocab_len, 
vocab_idx_to_token):
+        """Sets the mapping between token indices and token embedding vectors.
+
+
+        Parameters
+        ----------
+        token_embeddings : instance or list of `mxnet.contrib.text.embedding._TokenEmbedding`
+            One or multiple pre-trained token embeddings to load. If it is a 
list of multiple
+            embeddings, these embedding vectors will be concatenated for each 
token.
+        vocab_len : int
+            Length of vocabulary whose tokens are indexed in the token 
embedding.
+        vocab_idx_to_token: list of str
+            A list of indexed tokens in the vocabulary. These tokens are 
indexed in the token
+            embedding.
+        """
+
+        new_vec_len = sum(embed.vec_len for embed in token_embeddings)
+        new_idx_to_vec = nd.zeros(shape=(vocab_len, new_vec_len))
+
+        col_start = 0
+        # Concatenate all the embedding vectors in token_embeddings.
+        for embed in token_embeddings:
+            col_end = col_start + embed.vec_len
+            # Concatenate vectors of the unknown token.
+            new_idx_to_vec[0, col_start:col_end] = embed.idx_to_vec[0]
+            new_idx_to_vec[1:, col_start:col_end] = 
embed.get_vecs_by_tokens(vocab_idx_to_token[1:])
+            col_start = col_end
+
+        self._vec_len = new_vec_len
+        self._idx_to_vec = new_idx_to_vec
+
+    def _build_embedding_for_vocabulary(self, vocabulary):
+        if vocabulary is not None:
+            assert isinstance(vocabulary, vocab.Vocabulary), \
+                'The argument `vocabulary` must be an instance of ' \
+                'mxnet.contrib.text.vocab.Vocabulary.'
+
+            # Set _idx_to_vec so that indices of tokens from vocabulary are 
associated with the
+            # loaded token embedding vectors.
+            self._set_idx_to_vec_by_embeddings([self], len(vocabulary), 
vocabulary.idx_to_token)
+
+            # Index tokens from vocabulary.
+            self._index_tokens_from_vocabulary(vocabulary)
+
     @property
     def vec_len(self):
         return self._vec_len
@@ -276,9 +430,8 @@ class TokenEmbedding(indexer.TokenIndexer):
             assert isinstance(new_vectors, nd.NDArray) and 
len(new_vectors.shape) == 2, \
                 '`new_vectors` must be a 2-D NDArray if `tokens` is a list of 
multiple strings.'
         assert new_vectors.shape == (len(tokens), self.vec_len), \
-            'The length of new_vectors must be equal to the number of tokens ' 
\
-            'and the width of new_vectors must be equal to the dimension of ' \
-            'embeddings of the glossary.'
+            'The length of new_vectors must be equal to the number of tokens and the width of ' \
+            'new_vectors must be equal to the dimension of the token embedding vectors.'
 
         indices = []
         for token in tokens:
@@ -292,56 +445,6 @@ class TokenEmbedding(indexer.TokenIndexer):
 
         self._idx_to_vec[nd.array(indices)] = new_vectors
 
-    @staticmethod
-    def register(embedding_cls):
-        """Registers a new token embedding.
-
-
-        Once an embedding is registered, we can create an instance of this 
embedding with
-        :func:`~mxnet.contrib.text.embedding.TokenEmbedding.create`.
-
-
-        Examples
-        --------
-        >>> @mxnet.contrib.text.embedding.TokenEmbedding.register
-        ... class MyTextEmbed(mxnet.contrib.text.embedding.TokenEmbedding):
-        ...     def __init__(self, pretrained_file_name='my_pretrain_file'):
-        ...         pass
-        >>> embed = 
mxnet.contrib.text.embedding.TokenEmbedding.create('MyTokenEmbed')
-        >>> print(type(embed))
-        <class '__main__.MyTokenEmbed'>
-        """
-
-        register_text_embedding = registry.get_register_func(TokenEmbedding, 
'token embedding')
-        return register_text_embedding(embedding_cls)
-
-    @staticmethod
-    def create(embedding_name, **kwargs):
-        """Creates an instance of 
:class:`~mxnet.contrib.text.embedding.TokenEmbedding`.
-
-
-        Creates a token embedding instance by loading embedding vectors from 
an externally hosted
-        pre-trained token embedding file, such as those of GloVe and FastText. 
To get all the valid
-        `embedding_name` and `pretrained_file_name`, use
-        
`mxnet.contrib.text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names()`.
-
-
-        Parameters
-        ----------
-        embedding_name : str
-            The token embedding name (case-insensitive).
-
-
-        Returns
-        -------
-        :class:`~mxnet.contrib.text.glossary.TokenEmbedding`:
-            A token embedding instance that loads embedding vectors from an 
externally hosted
-            pre-trained token embedding file.
-        """
-
-        create_text_embedding = registry.get_create_func(TokenEmbedding, 
'token embedding')
-        return create_text_embedding(embedding_name, **kwargs)
-
     @classmethod
     def _check_pretrained_file_names(cls, pretrained_file_name):
         """Checks if a pre-trained token embedding file name is valid.
@@ -360,55 +463,9 @@ class TokenEmbedding(indexer.TokenIndexer):
                            (pretrained_file_name, embedding_name, 
embedding_name,
                             ', '.join(cls.pretrained_file_name_sha1.keys())))
 
-    @staticmethod
-    def get_embedding_and_pretrained_file_names(embedding_name=None):
-        """Get valid token embedding names and their pre-trained file names.
-
 
-        To load token embedding vectors from an externally hosted pre-trained 
token embedding file,
-        such as those of GloVe and FastText, one should use
-        `mxnet.contrib.text.embedding.TokenEmbedding.create(embedding_name, 
pretrained_file_name)`.
-        This method returns all the valid names of `pretrained_file_name` for 
the specified
-        `embedding_name`. If `embedding_name` is set to None, this method 
returns all the valid
-        names of `embedding_name` with associated `pretrained_file_name`.
-
-
-        Parameters
-        ----------
-        embedding_name : str or None, default None
-            The pre-trained token embedding name.
-
-
-        Returns
-        -------
-        dict or list:
-            A list of all the valid pre-trained token embedding file names 
(`pretrained_file_name`)
-            for the specified token embedding name (`embedding_name`). If the 
text embeding name is
-            set to None, returns a dict mapping each valid token embedding 
name to a list of valid
-            pre-trained files (`pretrained_file_name`). They can be plugged 
into
-            `mxnet.contrib.text.embedding.TokenEmbedding.create(embedding_name,
-            pretrained_file_name)`.
-        """
-
-        text_embedding_reg = registry.get_registry(TokenEmbedding)
-
-        if embedding_name is not None:
-            if embedding_name not in text_embedding_reg:
-                raise KeyError('Cannot find `embedding_name` %s. Use '
-                               '`get_embedding_and_pretrained_file_names('
-                               'embedding_name=None).keys()` to get all the 
valid embedding '
-                               'names.' % embedding_name)
-            return list(text_embedding_reg[
-                embedding_name].pretrained_file_name_sha1.keys())
-        else:
-            return {embedding_name: list(
-                embedding_cls.pretrained_file_name_sha1.keys())
-                    for embedding_name, embedding_cls in
-                    registry.get_registry(TokenEmbedding).items()}
-
-
-@TokenEmbedding.register
-class GloVe(TokenEmbedding):
+@register
+class GloVe(_TokenEmbedding):
     """The GloVe word embedding.
 
 
@@ -437,12 +494,17 @@ class GloVe(TokenEmbedding):
 
     Parameters
     ----------
-    pretrain_file : str, default 'glove.840B.300d.txt'
+    pretrained_file_name : str, default 'glove.840B.300d.txt'
         The name of the pre-trained token embedding file.
-    embed_root : str, default os.path.join('~', '.mxnet', 'embeddings')
+    embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
         The root directory for storing embedding-related files.
-    unknown_vec : callback
+    init_unknown_vec : callback
         The callback used to initialize the embedding vector for the unknown 
token.
+    vocabulary : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
+        It contains the tokens to index. Each indexed token will be associated with the loaded
+        embedding vectors, such as those loaded from a pre-trained token embedding file. If None,
+        all the tokens from the loaded embedding vectors, such as those loaded from a pre-trained
+        token embedding file, will be indexed.
 
 
     Properties
@@ -472,7 +534,7 @@ class GloVe(TokenEmbedding):
 
     @classmethod
     def _get_download_file_name(cls, pretrained_file_name):
-        # Map a pretrained embedding file to its archive to download.
+        # Map a pre-trained embedding file to its archive to download.
         src_archive = {archive.split('.')[1]: archive for archive in
                        GloVe.pretrained_archive_name_sha1.keys()}
         archive = src_archive[pretrained_file_name.split('.')[1]]
@@ -480,7 +542,7 @@ class GloVe(TokenEmbedding):
 
     def __init__(self, pretrained_file_name='glove.840B.300d.txt',
                  embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
-                 init_unknown_vec=nd.zeros, **kwargs):
+                 init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
         GloVe._check_pretrained_file_names(pretrained_file_name)
 
         super(GloVe, self).__init__(**kwargs)
@@ -488,9 +550,12 @@ class GloVe(TokenEmbedding):
 
         self._load_embedding(pretrained_file_path, ' ', init_unknown_vec)
 
+        if vocabulary is not None:
+            self._build_embedding_for_vocabulary(vocabulary)
 
-@TokenEmbedding.register
-class FastText(TokenEmbedding):
+
+@register
+class FastText(_TokenEmbedding):
     """The fastText word embedding.
 
 
@@ -527,12 +592,17 @@ class FastText(TokenEmbedding):
 
     Parameters
     ----------
-    pretrain_file : str, default 'wiki.en.vec'
+    pretrained_file_name : str, default 'wiki.en.vec'
         The name of the pre-trained token embedding file.
-    embed_root : str, default os.path.join('~', '.mxnet', 'embeddings')
+    embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
         The root directory for storing embedding-related files.
-    unknown_vec : callback
+    init_unknown_vec : callback
         The callback used to initialize the embedding vector for the unknown 
token.
+    vocabulary : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
+        It contains the tokens to index. Each indexed token will be associated with the loaded
+        embedding vectors, such as those loaded from a pre-trained token embedding file. If None,
+        all the tokens from the loaded embedding vectors, such as those loaded from a pre-trained
+        token embedding file, will be indexed.
 
 
     Properties
@@ -559,7 +629,7 @@ class FastText(TokenEmbedding):
 
     def __init__(self, pretrained_file_name='wiki.simple.vec',
                  embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
-                 init_unknown_vec=nd.zeros, **kwargs):
+                 init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
         FastText._check_pretrained_file_names(pretrained_file_name)
 
         super(FastText, self).__init__(**kwargs)
@@ -567,8 +637,11 @@ class FastText(TokenEmbedding):
 
         self._load_embedding(pretrained_file_path, ' ', init_unknown_vec)
 
+        if vocabulary is not None:
+            self._build_embedding_for_vocabulary(vocabulary)
+
 
-class CustomEmbedding(TokenEmbedding):
+class CustomEmbedding(_TokenEmbedding):
     """User-defined token embedding.
 
     This is to load embedding vectors from a user-defined pre-trained text 
embedding file.
@@ -585,13 +658,20 @@ class CustomEmbedding(TokenEmbedding):
 
     Parameters
     ----------
-    pretrain_file_path : str
+    pretrained_file_path : str
         The path to the custom pre-trained token embedding file.
     elem_delim : str, default ' '
         The delimiter for splitting a token and every embedding vector element 
value on the same
         line of the custom pre-trained token embedding file.
-    unknown_vec : callback
+    encoding : str, default 'utf8'
+        The encoding scheme for reading the custom pre-trained token embedding 
file.
+    init_unknown_vec : callback
         The callback used to initialize the embedding vector for the unknown 
token.
+    vocabulary : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
+        It contains the tokens to index. Each indexed token will be associated with the loaded
+        embedding vectors, such as those loaded from a pre-trained token embedding file. If None,
+        all the tokens from the loaded embedding vectors, such as those loaded from a pre-trained
+        token embedding file, will be indexed.
 
 
     Properties
@@ -614,6 +694,71 @@ class CustomEmbedding(TokenEmbedding):
     """
 
     def __init__(self, pretrained_file_path, elem_delim=' ', encoding='utf8',
-                 init_unknown_vec=nd.zeros, **kwargs):
+                 init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
         super(CustomEmbedding, self).__init__(**kwargs)
         self._load_embedding(pretrained_file_path, elem_delim, 
init_unknown_vec, encoding)
+
+        if vocabulary is not None:
+            self._build_embedding_for_vocabulary(vocabulary)
+
+
+class CompositeEmbedding(_TokenEmbedding):
+    """Composite token embeddings.
+
+
+    For each indexed token in a vocabulary, multiple embedding vectors, such as concatenated
+    embedding vectors, will be associated with it. Such embedding vectors can be loaded from
+    externally hosted or custom pre-trained token embedding files, such as via token embedding
+    instances.
+
+
+    Parameters
+    ----------
+    vocabulary : :class:`~mxnet.contrib.text.vocab.Vocabulary`
+        For each indexed token in a vocabulary, multiple embedding vectors, such as concatenated
+        embedding vectors, will be associated with it.
+    token_embeddings : instance or list of 
`mxnet.contrib.text.embedding._TokenEmbedding`
+        One or multiple pre-trained token embeddings to load. If it is a list 
of multiple
+        embeddings, these embedding vectors will be concatenated for each 
token.
+
+
+    Properties
+    ----------
+    token_to_idx : dict mapping str to int
+        A dict mapping each token to its index integer.
+    idx_to_token : list of strs
+        A list of indexed tokens where the list indices and the token indices 
are aligned.
+    unknown_token : hashable object
+        The representation for any unknown token. In other words, any unknown 
token will be indexed
+        as the same representation.
+    reserved_tokens : list of strs or None
+        A list of reserved tokens that will always be indexed.
+    vec_len : int
+        The length of the embedding vector for each token.
+    idx_to_vec : mxnet.ndarray.NDArray
+        For all the indexed tokens in this embedding, this NDArray maps each 
token's index to an
+        embedding vector. The largest valid index maps to the initialized 
embedding vector for every
+        reserved token, such as an unknown_token token and a padding token.
+    """
+    def __init__(self, vocabulary, token_embeddings):
+
+        # Sanity checks.
+        assert isinstance(vocabulary, vocab.Vocabulary), \
+            'The argument `vocabulary` must be an instance of ' \
+            'mxnet.contrib.text.vocab.Vocabulary.'
+
+        if not isinstance(token_embeddings, list):
+            token_embeddings = [token_embeddings]
+
+        for embed in token_embeddings:
+            assert isinstance(embed, _TokenEmbedding), \
+                'The argument `token_embeddings` must be an instance or a list 
of instances ' \
+                'of `mxnet.contrib.text.embedding._TokenEmbedding` whose embedding vectors will ' \
+                'be loaded or concatenated-then-loaded to map to the indexed tokens.'
+
+        # Index tokens.
+        self._index_tokens_from_vocabulary(vocabulary)
+
+        # Set _idx_to_vec so that indices of tokens from `vocabulary` are associated with token
+        # embedding vectors from `token_embeddings`.
+        self._set_idx_to_vec_by_embeddings(token_embeddings, len(self), 
self.idx_to_token)
diff --git a/python/mxnet/contrib/text/glossary.py 
b/python/mxnet/contrib/text/glossary.py
deleted file mode 100644
index 88517e5..0000000
--- a/python/mxnet/contrib/text/glossary.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# coding: utf-8
-# pylint: disable=super-init-not-called
-
-"""Index text tokens and load their embeddings."""
-from __future__ import absolute_import
-from __future__ import print_function
-
-from . import embedding
-from . import indexer
-from ... import ndarray as nd
-
-
-class Glossary(embedding.TokenEmbedding):
-    """Indexing and embedding for text tokens in a glossary.
-
-
-    For each indexed token in a glossary, an embedding vector will be 
associated with it. Such
-    embedding vectors can be loaded from externally hosted or custom 
pre-trained token embedding
-    files, such as via instances of 
:class:`~mxnet.contrib.text.embedding.TokenEmbedding`.
-
-
-    Parameters
-    ----------
-    token_indexer : :class:`~mxnet.contrib.text.indexer.TokenIndexer`
-        It contains the indexed tokens to load, where each token is associated 
with an index.
-    token_embeddings : instance or list of 
:class:`~mxnet.contrib.text.embedding.TokenEmbedding`
-        One or multiple pre-trained token embeddings to load. If it is a list 
of multiple
-        embeddings, these embedding vectors will be concatenated for each 
token.
-
-
-    Properties
-    ----------
-    token_to_idx : dict mapping str to int
-        A dict mapping each token to its index integer.
-    idx_to_token : list of strs
-        A list of indexed tokens where the list indices and the token indices 
are aligned.
-    unknown_token : hashable object
-        The representation for any unknown token. In other words, any unknown 
token will be indexed
-        as the same representation.
-    reserved_tokens : list of strs or None
-        A list of reserved tokens that will always be indexed.
-    vec_len : int
-        The length of the embedding vector for each token.
-    idx_to_vec : mxnet.ndarray.NDArray
-        For all the indexed tokens in this embedding, this NDArray maps each 
token's index to an
-        embedding vector. The largest valid index maps to the initialized 
embedding vector for every
-        reserved token, such as an unknown_token token and a padding token.
-    """
-    def __init__(self, token_indexer, token_embeddings):
-
-        # Sanity checks.
-        assert isinstance(token_indexer, indexer.TokenIndexer), \
-            'The argument `token_indexer` must be an instance of ' \
-            'mxnet.contrib.text.indexer.TokenIndexer.'
-
-        if not isinstance(token_embeddings, list):
-            token_embeddings = [token_embeddings]
-
-        for embed in token_embeddings:
-            assert isinstance(embed, embedding.TokenEmbedding), \
-                'The argument `token_embeddings` must be an instance or a list 
of instances ' \
-                'of `mxnet.contrib.text.embedding.TextEmbedding` whose 
embedding vectors will be' \
-                'loaded or concatenated-then-loaded to map to the indexed 
tokens.'
-
-        # Index tokens.
-        self._token_to_idx = token_indexer.token_to_idx.copy() \
-            if token_indexer.token_to_idx is not None else None
-        self._idx_to_token = token_indexer.idx_to_token[:] \
-            if token_indexer.idx_to_token is not None else None
-        self._unknown_token = token_indexer.unknown_token
-        self._reserved_tokens = token_indexer.reserved_tokens[:] \
-            if token_indexer.reserved_tokens is not None else None
-
-        # Set _idx_to_vec so that indices of tokens from keys of `counter` are
-        # associated with token embedding vectors from `token_embeddings`.
-        self._set_idx_to_vec_by_embeds(token_embeddings)
-
-    def _set_idx_to_vec_by_embeds(self, token_embeddings):
-        """Sets the mapping between token indices and token embedding vectors.
-
-
-        Parameters
-        ----------
-        token_embeddings : an instance or a list of instances of
-            :class:`~mxnet.contrib.text.embedding.TokenEmbedding`
-            One or multiple pre-trained token embeddings to load. If it is a 
list of multiple
-            embeddings, these embedding vectors will be concatenated for each 
token.
-        """
-
-        self._vec_len = sum(embed.vec_len for embed in token_embeddings)
-        self._idx_to_vec = nd.zeros(shape=(len(self), self.vec_len))
-
-        col_start = 0
-        # Concatenate all the embedding vectors in token_embeddings.
-        for embed in token_embeddings:
-            col_end = col_start + embed.vec_len
-            # Cancatenate vectors of the unknown token.
-            self._idx_to_vec[0, col_start:col_end] = embed.idx_to_vec[0]
-            self._idx_to_vec[1:, col_start:col_end] = embed.get_vecs_by_tokens(
-                self.idx_to_token[1:])
-            col_start = col_end
diff --git a/python/mxnet/contrib/text/indexer.py 
b/python/mxnet/contrib/text/vocab.py
similarity index 94%
rename from python/mxnet/contrib/text/indexer.py
rename to python/mxnet/contrib/text/vocab.py
index 1add7cf..04c3326 100644
--- a/python/mxnet/contrib/text/indexer.py
+++ b/python/mxnet/contrib/text/vocab.py
@@ -27,13 +27,12 @@ import collections
 from . import _constants as C
 
 
-class TokenIndexer(object):
+class Vocabulary(object):
     """Indexing for text tokens.
 
 
     Build indices for the unknown token, reserved tokens, and input counter 
keys. Indexed tokens can
-    be used by instances of 
:class:`~mxnet.contrib.text.embedding.TokenEmbedding`, such as instances
-    of :class:`~mxnet.contrib.text.glossary.Glossary`.
+    be used by token embeddings.
 
 
     Parameters
@@ -69,8 +68,7 @@ class TokenIndexer(object):
     token_to_idx : dict mapping str to int
         A dict mapping each token to its index integer.
     idx_to_token : list of strs
-        A list of indexed tokens where the list indices and the token indices
-        are aligned.
+        A list of indexed tokens where the list indices and the token indices 
are aligned.
     unknown_token : hashable object
         The representation for any unknown token. In other words, any unknown 
token will be indexed
         as the same representation.
@@ -160,7 +158,7 @@ class TokenIndexer(object):
         return self._reserved_tokens
 
     def to_indices(self, tokens):
-        """Converts tokens to indices according to the text indexer.
+        """Converts tokens to indices according to the vocabulary.
 
 
         Parameters
@@ -172,7 +170,7 @@ class TokenIndexer(object):
         Returns
         -------
         int or list of ints
-            A token index or a list of token indices according to the text 
indexer.
+            A token index or a list of token indices according to the 
vocabulary.
         """
 
         to_reduce = False
@@ -186,7 +184,7 @@ class TokenIndexer(object):
         return indices[0] if to_reduce else indices
 
     def to_tokens(self, indices):
-        """Converts token indices to tokens according to the text indexer.
+        """Converts token indices to tokens according to the vocabulary.
 
 
         Parameters
@@ -198,7 +196,7 @@ class TokenIndexer(object):
         Returns
         -------
         str or list of strs
-            A token or a list of tokens according to the text indexer.
+            A token or a list of tokens according to the vocabulary.
         """
 
         to_reduce = False
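
The renamed `Vocabulary` keeps the `TokenIndexer` behavior; a minimal round trip (toy counter,
arguments as in the updated tests below) looks like:

```python
>>> from collections import Counter
>>> from mxnet.contrib import text

>>> counter = Counter(['a', 'b', 'b', 'c', 'c', 'c'])
>>> vocab = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1,
...                               unknown_token='<unk>', reserved_tokens=None)
>>> vocab.to_indices(['c', 'b', 'non-exist'])  # unknown tokens map to index 0
[1, 2, 0]
>>> vocab.to_tokens([1, 2, 0])
['c', 'b', '<unk>']
```
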
diff --git a/tests/python/unittest/test_contrib_text.py 
b/tests/python/unittest/test_contrib_text.py
index dc0e7bc..673f975 100644
--- a/tests/python/unittest/test_contrib_text.py
+++ b/tests/python/unittest/test_contrib_text.py
@@ -74,71 +74,66 @@ def test_count_tokens_from_str():
 def test_tokens_to_indices():
     counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
 
-    indexer = text.indexer.TokenIndexer(counter, most_freq_count=None, 
min_freq=1,
-                                        unknown_token='<unk>', 
reserved_tokens=None)
+    vocab = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unk>',
+                                  reserved_tokens=None)
 
-    i1 = indexer.to_indices('c')
+    i1 = vocab.to_indices('c')
     assert i1 == 1
 
-    i2 = indexer.to_indices(['c'])
+    i2 = vocab.to_indices(['c'])
     assert i2 == [1]
 
-    i3 = indexer.to_indices(['<unk>', 'non-exist'])
+    i3 = vocab.to_indices(['<unk>', 'non-exist'])
     assert i3 == [0, 0]
 
-    i4 = indexer.to_indices(['a', 'non-exist', 'a', 'b'])
+    i4 = vocab.to_indices(['a', 'non-exist', 'a', 'b'])
     assert i4 == [3, 0, 3, 2]
 
 
 def test_indices_to_tokens():
     counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
 
-    indexer = text.indexer.TokenIndexer(counter, most_freq_count=None, 
min_freq=1,
-                                        unknown_token='<unknown>', 
reserved_tokens=None)
-    i1 = indexer.to_tokens(1)
+    vocab = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1,
+                                  unknown_token='<unknown>', 
reserved_tokens=None)
+    i1 = vocab.to_tokens(1)
     assert i1 == 'c'
 
-    i2 = indexer.to_tokens([1])
+    i2 = vocab.to_tokens([1])
     assert i2 == ['c']
 
-    i3 = indexer.to_tokens([0, 0])
+    i3 = vocab.to_tokens([0, 0])
     assert i3 == ['<unknown>', '<unknown>']
 
-    i4 = indexer.to_tokens([3, 0, 3, 2])
+    i4 = vocab.to_tokens([3, 0, 3, 2])
     assert i4 == ['a', '<unknown>', 'a', 'b']
 
-    assertRaises(ValueError, indexer.to_tokens, 100)
+    assertRaises(ValueError, vocab.to_tokens, 100)
 
 
 def test_download_embed():
-    @text.embedding.TokenEmbedding.register
-    class Test(text.embedding.TokenEmbedding):
-        # 33 bytes
+    @text.embedding.register
+    class Test(text.embedding._TokenEmbedding):
+        # 33 bytes.
         pretrained_file_name_sha1 = \
             {'embedding_test.vec': '29b9a6511cf4b5aae293c44a9ec1365b74f2a2f8'}
         namespace = 'test'
 
-        def __init__(self, embedding_root='embeddings',
-                     init_unknown_vec=nd.zeros, **kwargs):
+        def __init__(self, embedding_root='embeddings', 
init_unknown_vec=nd.zeros, **kwargs):
             pretrained_file_name = 'embedding_test.vec'
             Test._check_pretrained_file_names(pretrained_file_name)
 
             super(Test, self).__init__(**kwargs)
 
-            pretrained_file_path = Test._get_pretrained_file(
-                embedding_root, pretrained_file_name)
+            pretrained_file_path = Test._get_pretrained_file(embedding_root, 
pretrained_file_name)
 
             self._load_embedding(pretrained_file_path, ' ', init_unknown_vec)
 
-    test_embed = text.embedding.TokenEmbedding.create('test')
+    test_embed = text.embedding.create('test')
     assert test_embed.token_to_idx['hello'] == 1
     assert test_embed.token_to_idx['world'] == 2
-    assert_almost_equal(
-        test_embed.idx_to_vec[1].asnumpy(), (nd.arange(5) + 1).asnumpy())
-    assert_almost_equal(
-        test_embed.idx_to_vec[2].asnumpy(), (nd.arange(5) + 6).asnumpy())
-    assert_almost_equal(
-        test_embed.idx_to_vec[0].asnumpy(), nd.zeros((5,)).asnumpy())
+    assert_almost_equal(test_embed.idx_to_vec[1].asnumpy(), (nd.arange(5) + 
1).asnumpy())
+    assert_almost_equal(test_embed.idx_to_vec[2].asnumpy(), (nd.arange(5) + 
6).asnumpy())
+    assert_almost_equal(test_embed.idx_to_vec[0].asnumpy(), 
nd.zeros((5,)).asnumpy())
 
 
 def _mk_my_pretrain_file(path, token_delim, pretrain_file):
@@ -243,20 +238,17 @@ def test_custom_embed():
 
     # Test loaded unknown vectors.
     pretrain_file2 = 'my_pretrain_file2.txt'
-    _mk_my_pretrain_file3(os.path.join(embed_root, embed_name), elem_delim,
-                          pretrain_file2)
+    _mk_my_pretrain_file3(os.path.join(embed_root, embed_name), elem_delim, 
pretrain_file2)
     pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file2)
-    my_embed2 = text.embedding.CustomEmbedding(
-        pretrain_file_path, elem_delim, init_unknown_vec=nd.ones,
-        unknown_token='<unk>')
+    my_embed2 = text.embedding.CustomEmbedding(pretrain_file_path, elem_delim,
+                                               init_unknown_vec=nd.ones, 
unknown_token='<unk>')
     unk_vec2 = my_embed2.get_vecs_by_tokens('<unk>')
     assert_almost_equal(unk_vec2.asnumpy(), np.array([1, 1, 1, 1, 1]))
     unk_vec2 = my_embed2.get_vecs_by_tokens('<unk$unk@unk>')
     assert_almost_equal(unk_vec2.asnumpy(), np.array([1, 1, 1, 1, 1]))
 
-    my_embed3 = text.embedding.CustomEmbedding(
-        pretrain_file_path, elem_delim,init_unknown_vec=nd.ones,
-        unknown_token='<unk1>')
+    my_embed3 = text.embedding.CustomEmbedding(pretrain_file_path, elem_delim,
+                                               init_unknown_vec=nd.ones, 
unknown_token='<unk1>')
     unk_vec3 = my_embed3.get_vecs_by_tokens('<unk1>')
     assert_almost_equal(unk_vec3.asnumpy(), np.array([1.1, 1.2, 1.3, 1.4, 
1.5]))
     unk_vec3 = my_embed3.get_vecs_by_tokens('<unk$unk@unk>')
@@ -270,144 +262,263 @@ def test_custom_embed():
     assertRaises(AssertionError, text.embedding.CustomEmbedding, 
pretrain_file_path, elem_delim)
 
     invalid_pretrain_file2 = 'invalid_pretrain_file2.txt'
-    _mk_my_invalid_pretrain_file2(os.path.join(embed_root, embed_name),
-                                  elem_delim, invalid_pretrain_file2)
-    pretrain_file_path = os.path.join(embed_root, embed_name,
-                                      invalid_pretrain_file2)
+    _mk_my_invalid_pretrain_file2(os.path.join(embed_root, embed_name), 
elem_delim,
+                                  invalid_pretrain_file2)
+    pretrain_file_path = os.path.join(embed_root, embed_name, 
invalid_pretrain_file2)
     assertRaises(AssertionError, text.embedding.CustomEmbedding, 
pretrain_file_path, elem_delim)
 
 
-def test_token_indexer():
+def test_vocabulary():
     counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
 
-    i1 = text.indexer.TokenIndexer(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unk>',
-                                   reserved_tokens=None)
-    assert len(i1) == 5
-    assert i1.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3, 
'some_word$': 4}
-    assert i1.idx_to_token[1] == 'c'
-    assert i1.unknown_token == '<unk>'
-    assert i1.reserved_tokens is None
-
-    i2 = text.indexer.TokenIndexer(counter, most_freq_count=None, min_freq=2, 
unknown_token='<unk>',
-                                   reserved_tokens=None)
-    assert len(i2) == 3
-    assert i2.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2}
-    assert i2.idx_to_token[1] == 'c'
-    assert i2.unknown_token == '<unk>'
-    assert i2.reserved_tokens is None
-
-    i3 = text.indexer.TokenIndexer(counter, most_freq_count=None, min_freq=100,
-                                   unknown_token='<unk>', reserved_tokens=None)
-    assert len(i3) == 1
-    assert i3.token_to_idx == {'<unk>': 0}
-    assert i3.idx_to_token[0] == '<unk>'
-    assert i3.unknown_token == '<unk>'
-    assert i3.reserved_tokens is None
-
-    i4 = text.indexer.TokenIndexer(counter, most_freq_count=2, min_freq=1, 
unknown_token='<unk>',
-                                   reserved_tokens=None)
-    assert len(i4) == 3
-    assert i4.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2}
-    assert i4.idx_to_token[1] == 'c'
-    assert i4.unknown_token == '<unk>'
-    assert i4.reserved_tokens is None
-
-    i5 = text.indexer.TokenIndexer(counter, most_freq_count=3, min_freq=1, 
unknown_token='<unk>',
-                                   reserved_tokens=None)
-    assert len(i5) == 4
-    assert i5.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3}
-    assert i5.idx_to_token[1] == 'c'
-    assert i5.unknown_token == '<unk>'
-    assert i5.reserved_tokens is None
-
-    i6 = text.indexer.TokenIndexer(counter, most_freq_count=100, min_freq=1, 
unknown_token='<unk>',
-                                   reserved_tokens=None)
-    assert len(i6) == 5
-    assert i6.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3,
+    v1 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unk>',
+                               reserved_tokens=None)
+    assert len(v1) == 5
+    assert v1.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3, 
'some_word$': 4}
+    assert v1.idx_to_token[1] == 'c'
+    assert v1.unknown_token == '<unk>'
+    assert v1.reserved_tokens is None
+
+    v2 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=2, 
unknown_token='<unk>',
+                               reserved_tokens=None)
+    assert len(v2) == 3
+    assert v2.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2}
+    assert v2.idx_to_token[1] == 'c'
+    assert v2.unknown_token == '<unk>'
+    assert v2.reserved_tokens is None
+
+    v3 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=100, 
unknown_token='<unk>',
+                               reserved_tokens=None)
+    assert len(v3) == 1
+    assert v3.token_to_idx == {'<unk>': 0}
+    assert v3.idx_to_token[0] == '<unk>'
+    assert v3.unknown_token == '<unk>'
+    assert v3.reserved_tokens is None
+
+    v4 = text.vocab.Vocabulary(counter, most_freq_count=2, min_freq=1, 
unknown_token='<unk>',
+                               reserved_tokens=None)
+    assert len(v4) == 3
+    assert v4.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2}
+    assert v4.idx_to_token[1] == 'c'
+    assert v4.unknown_token == '<unk>'
+    assert v4.reserved_tokens is None
+
+    v5 = text.vocab.Vocabulary(counter, most_freq_count=3, min_freq=1, 
unknown_token='<unk>',
+                               reserved_tokens=None)
+    assert len(v5) == 4
+    assert v5.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3}
+    assert v5.idx_to_token[1] == 'c'
+    assert v5.unknown_token == '<unk>'
+    assert v5.reserved_tokens is None
+
+    v6 = text.vocab.Vocabulary(counter, most_freq_count=100, min_freq=1, 
unknown_token='<unk>',
+                               reserved_tokens=None)
+    assert len(v6) == 5
+    assert v6.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3,
                                'some_word$': 4}
-    assert i6.idx_to_token[1] == 'c'
-    assert i6.unknown_token == '<unk>'
-    assert i6.reserved_tokens is None
-
-    i7 = text.indexer.TokenIndexer(counter, most_freq_count=1, min_freq=2, 
unknown_token='<unk>',
-                                   reserved_tokens=None)
-    assert len(i7) == 2
-    assert i7.token_to_idx == {'<unk>': 0, 'c': 1}
-    assert i7.idx_to_token[1] == 'c'
-    assert i7.unknown_token == '<unk>'
-    assert i7.reserved_tokens is None
-
-    assertRaises(AssertionError, text.indexer.TokenIndexer, counter, 
most_freq_count=None,
+    assert v6.idx_to_token[1] == 'c'
+    assert v6.unknown_token == '<unk>'
+    assert v6.reserved_tokens is None
+
+    v7 = text.vocab.Vocabulary(counter, most_freq_count=1, min_freq=2, 
unknown_token='<unk>',
+                               reserved_tokens=None)
+    assert len(v7) == 2
+    assert v7.token_to_idx == {'<unk>': 0, 'c': 1}
+    assert v7.idx_to_token[1] == 'c'
+    assert v7.unknown_token == '<unk>'
+    assert v7.reserved_tokens is None
+
+    assertRaises(AssertionError, text.vocab.Vocabulary, counter, 
most_freq_count=None,
                  min_freq=0, unknown_token='<unknown>', reserved_tokens=['b'])
 
-    assertRaises(AssertionError, text.indexer.TokenIndexer, counter, 
most_freq_count=None,
+    assertRaises(AssertionError, text.vocab.Vocabulary, counter, 
most_freq_count=None,
                  min_freq=1, unknown_token='<unknown>', reserved_tokens=['b', 
'b'])
 
-    assertRaises(AssertionError, text.indexer.TokenIndexer, counter, 
most_freq_count=None,
+    assertRaises(AssertionError, text.vocab.Vocabulary, counter, 
most_freq_count=None,
                  min_freq=1, unknown_token='<unknown>', reserved_tokens=['b', 
'<unknown>'])
 
-    i8 = text.indexer.TokenIndexer(counter, most_freq_count=None, min_freq=1,
-                                   unknown_token='<unknown>', 
reserved_tokens=['b'])
-    assert len(i8) == 5
-    assert i8.token_to_idx == {'<unknown>': 0, 'b': 1, 'c': 2, 'a': 3, 
'some_word$': 4}
-    assert i8.idx_to_token[1] == 'b'
-    assert i8.unknown_token == '<unknown>'
-    assert i8.reserved_tokens == ['b']
-
-    i9 = text.indexer.TokenIndexer(counter, most_freq_count=None, min_freq=2, 
unknown_token='<unk>',
-                                   reserved_tokens=['b', 'a'])
-    assert len(i9) == 4
-    assert i9.token_to_idx == {'<unk>': 0, 'b': 1, 'a': 2, 'c': 3}
-    assert i9.idx_to_token[1] == 'b'
-    assert i9.unknown_token == '<unk>'
-    assert i9.reserved_tokens == ['b', 'a']
-
-    i10 = text.indexer.TokenIndexer(counter, most_freq_count=None, 
min_freq=100,
-                                    unknown_token='<unk>', 
reserved_tokens=['b', 'c'])
-    assert len(i10) == 3
-    assert i10.token_to_idx == {'<unk>': 0, 'b': 1, 'c': 2}
-    assert i10.idx_to_token[1] == 'b'
-    assert i10.unknown_token == '<unk>'
-    assert i10.reserved_tokens == ['b', 'c']
-
-    i11 = text.indexer.TokenIndexer(counter, most_freq_count=1, min_freq=2, 
unknown_token='<unk>',
-                                    reserved_tokens=['<pad>', 'b'])
-    assert len(i11) == 4
-    assert i11.token_to_idx == {'<unk>': 0, '<pad>': 1, 'b': 2, 'c': 3}
-    assert i11.idx_to_token[1] == '<pad>'
-    assert i11.unknown_token == '<unk>'
-    assert i11.reserved_tokens == ['<pad>', 'b']
-
-    i12 = text.indexer.TokenIndexer(counter, most_freq_count=None, min_freq=2, 
unknown_token='b',
-                                    reserved_tokens=['<pad>'])
-    assert len(i12) == 3
-    assert i12.token_to_idx == {'b': 0, '<pad>': 1, 'c': 2}
-    assert i12.idx_to_token[1] == '<pad>'
-    assert i12.unknown_token == 'b'
-    assert i12.reserved_tokens == ['<pad>']
-
-    i13 = text.indexer.TokenIndexer(counter, most_freq_count=None, min_freq=2, 
unknown_token='a',
-                                    reserved_tokens=['<pad>'])
-    assert len(i13) == 4
-    assert i13.token_to_idx == {'a': 0, '<pad>': 1, 'c': 2, 'b': 3}
-    assert i13.idx_to_token[1] == '<pad>'
-    assert i13.unknown_token == 'a'
-    assert i13.reserved_tokens == ['<pad>']
+    v8 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unknown>',
+                               reserved_tokens=['b'])
+    assert len(v8) == 5
+    assert v8.token_to_idx == {'<unknown>': 0, 'b': 1, 'c': 2, 'a': 3, 
'some_word$': 4}
+    assert v8.idx_to_token[1] == 'b'
+    assert v8.unknown_token == '<unknown>'
+    assert v8.reserved_tokens == ['b']
+
+    v9 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=2, 
unknown_token='<unk>',
+                               reserved_tokens=['b', 'a'])
+    assert len(v9) == 4
+    assert v9.token_to_idx == {'<unk>': 0, 'b': 1, 'a': 2, 'c': 3}
+    assert v9.idx_to_token[1] == 'b'
+    assert v9.unknown_token == '<unk>'
+    assert v9.reserved_tokens == ['b', 'a']
+
+    v10 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=100, 
unknown_token='<unk>',
+                                reserved_tokens=['b', 'c'])
+    assert len(v10) == 3
+    assert v10.token_to_idx == {'<unk>': 0, 'b': 1, 'c': 2}
+    assert v10.idx_to_token[1] == 'b'
+    assert v10.unknown_token == '<unk>'
+    assert v10.reserved_tokens == ['b', 'c']
+
+    v11 = text.vocab.Vocabulary(counter, most_freq_count=1, min_freq=2, 
unknown_token='<unk>',
+                                reserved_tokens=['<pad>', 'b'])
+    assert len(v11) == 4
+    assert v11.token_to_idx == {'<unk>': 0, '<pad>': 1, 'b': 2, 'c': 3}
+    assert v11.idx_to_token[1] == '<pad>'
+    assert v11.unknown_token == '<unk>'
+    assert v11.reserved_tokens == ['<pad>', 'b']
+
+    v12 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=2, 
unknown_token='b',
+                                reserved_tokens=['<pad>'])
+    assert len(v12) == 3
+    assert v12.token_to_idx == {'b': 0, '<pad>': 1, 'c': 2}
+    assert v12.idx_to_token[1] == '<pad>'
+    assert v12.unknown_token == 'b'
+    assert v12.reserved_tokens == ['<pad>']
+
+    v13 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=2, 
unknown_token='a',
+                                reserved_tokens=['<pad>'])
+    assert len(v13) == 4
+    assert v13.token_to_idx == {'a': 0, '<pad>': 1, 'c': 2, 'b': 3}
+    assert v13.idx_to_token[1] == '<pad>'
+    assert v13.unknown_token == 'a'
+    assert v13.reserved_tokens == ['<pad>']
 
     counter_tuple = Counter([('a', 'a'), ('b', 'b'), ('b', 'b'), ('c', 'c'), 
('c', 'c'), ('c', 'c'),
                              ('some_word$', 'some_word$')])
 
-    i14 = text.indexer.TokenIndexer(counter_tuple, most_freq_count=None, 
min_freq=1,
-                                    unknown_token=('<unk>', '<unk>'), 
reserved_tokens=None)
-    assert len(i14) == 5
-    assert i14.token_to_idx == {('<unk>', '<unk>'): 0, ('c', 'c'): 1, ('b', 
'b'): 2, ('a', 'a'): 3,
+    v14 = text.vocab.Vocabulary(counter_tuple, most_freq_count=None, 
min_freq=1,
+                                unknown_token=('<unk>', '<unk>'), 
reserved_tokens=None)
+    assert len(v14) == 5
+    assert v14.token_to_idx == {('<unk>', '<unk>'): 0, ('c', 'c'): 1, ('b', 
'b'): 2, ('a', 'a'): 3,
                                 ('some_word$', 'some_word$'): 4}
-    assert i14.idx_to_token[1] == ('c', 'c')
-    assert i14.unknown_token == ('<unk>', '<unk>')
-    assert i14.reserved_tokens is None
+    assert v14.idx_to_token[1] == ('c', 'c')
+    assert v14.unknown_token == ('<unk>', '<unk>')
+    assert v14.reserved_tokens is None
+
+
+def test_custom_embedding_with_vocabulary():
+    embed_root = 'embeddings'
+    embed_name = 'my_embed'
+    elem_delim = '\t'
+    pretrain_file = 'my_pretrain_file1.txt'
+
+    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim, 
pretrain_file)
+
+    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file)
+
+    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    v1 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unk>',
+                               reserved_tokens=['<pad>'])
+
+    e1 = text.embedding.CustomEmbedding(pretrain_file_path, elem_delim, 
init_unknown_vec=nd.ones,
+                                        vocabulary=v1)
+
+    assert e1.token_to_idx == {'<unk>': 0, '<pad>': 1, 'c': 2, 'b': 3, 'a': 4, 
'some_word$': 5}
+    assert e1.idx_to_token == ['<unk>', '<pad>', 'c', 'b', 'a', 'some_word$']
+
+    assert_almost_equal(e1.idx_to_vec.asnumpy(),
+                        np.array([[1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [0.6, 0.7, 0.8, 0.9, 1],
+                                  [0.1, 0.2, 0.3, 0.4, 0.5],
+                                  [1, 1, 1, 1, 1]])
+                        )
+
+    assert e1.vec_len == 5
+    assert e1.reserved_tokens == ['<pad>']
+
+    assert_almost_equal(e1.get_vecs_by_tokens('c').asnumpy(),
+                        np.array([1, 1, 1, 1, 1])
+                        )
+
+    assert_almost_equal(e1.get_vecs_by_tokens(['c']).asnumpy(),
+                        np.array([[1, 1, 1, 1, 1]])
+                        )
+
+    assert_almost_equal(e1.get_vecs_by_tokens(['a', 'not_exist']).asnumpy(),
+                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
+                                  [1, 1, 1, 1, 1]])
+                        )
+
+    assert_almost_equal(e1.get_vecs_by_tokens(['a', 'b']).asnumpy(),
+                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
+                                  [0.6, 0.7, 0.8, 0.9, 1]])
+                        )
+
+    assert_almost_equal(e1.get_vecs_by_tokens(['A', 'b']).asnumpy(),
+                        np.array([[1, 1, 1, 1, 1],
+                                  [0.6, 0.7, 0.8, 0.9, 1]])
+                        )
+
+    assert_almost_equal(e1.get_vecs_by_tokens(['A', 'b'], 
lower_case_backup=True).asnumpy(),
+                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
+                                  [0.6, 0.7, 0.8, 0.9, 1]])
+                        )
 
+    e1.update_token_vectors(['a', 'b'],
+                            nd.array([[2, 2, 2, 2, 2],
+                                      [3, 3, 3, 3, 3]])
+                            )
+
+    assert_almost_equal(e1.idx_to_vec.asnumpy(),
+                        np.array([[1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+
+    assertRaises(ValueError, e1.update_token_vectors, 'unknown$$$', 
nd.array([0, 0, 0, 0, 0]))
 
-def test_glossary_with_one_embed():
+    assertRaises(AssertionError, e1.update_token_vectors, '<unk>',
+                 nd.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]))
+
+    assertRaises(AssertionError, e1.update_token_vectors, '<unk>', 
nd.array([0]))
+
+    e1.update_token_vectors(['<unk>'], nd.array([0, 0, 0, 0, 0]))
+    assert_almost_equal(e1.idx_to_vec.asnumpy(),
+                        np.array([[0, 0, 0, 0, 0],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+    e1.update_token_vectors(['<unk>'], nd.array([[10, 10, 10, 10, 10]]))
+    assert_almost_equal(e1.idx_to_vec.asnumpy(),
+                        np.array([[10, 10, 10, 10, 10],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+    e1.update_token_vectors('<unk>', nd.array([0, 0, 0, 0, 0]))
+    assert_almost_equal(e1.idx_to_vec.asnumpy(),
+                        np.array([[0, 0, 0, 0, 0],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+    e1.update_token_vectors('<unk>', nd.array([[10, 10, 10, 10, 10]]))
+    assert_almost_equal(e1.idx_to_vec.asnumpy(),
+                        np.array([[10, 10, 10, 10, 10],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+
+
+def test_composite_embedding_with_one_embedding():
     embed_root = 'embeddings'
     embed_name = 'my_embed'
     elem_delim = '\t'
@@ -422,14 +533,14 @@ def test_glossary_with_one_embed():
 
     counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
 
-    i1 = text.indexer.TokenIndexer(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unk>',
-                                   reserved_tokens=['<pad>'])
-    g1 = text.glossary.Glossary(i1, my_embed)
+    v1 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unk>',
+                               reserved_tokens=['<pad>'])
+    ce1 = text.embedding.CompositeEmbedding(v1, my_embed)
 
-    assert g1.token_to_idx == {'<unk>': 0, '<pad>': 1, 'c': 2, 'b': 3, 'a': 4, 
'some_word$': 5}
-    assert g1.idx_to_token == ['<unk>', '<pad>', 'c', 'b', 'a', 'some_word$']
+    assert ce1.token_to_idx == {'<unk>': 0, '<pad>': 1, 'c': 2, 'b': 3, 'a': 
4, 'some_word$': 5}
+    assert ce1.idx_to_token == ['<unk>', '<pad>', 'c', 'b', 'a', 'some_word$']
 
-    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                         np.array([[1, 1, 1, 1, 1],
                                   [1, 1, 1, 1, 1],
                                   [1, 1, 1, 1, 1],
@@ -438,43 +549,43 @@ def test_glossary_with_one_embed():
                                   [1, 1, 1, 1, 1]])
                         )
 
-    assert g1.vec_len == 5
-    assert g1.reserved_tokens == ['<pad>']
+    assert ce1.vec_len == 5
+    assert ce1.reserved_tokens == ['<pad>']
 
-    assert_almost_equal(g1.get_vecs_by_tokens('c').asnumpy(),
+    assert_almost_equal(ce1.get_vecs_by_tokens('c').asnumpy(),
                         np.array([1, 1, 1, 1, 1])
                         )
 
-    assert_almost_equal(g1.get_vecs_by_tokens(['c']).asnumpy(),
+    assert_almost_equal(ce1.get_vecs_by_tokens(['c']).asnumpy(),
                         np.array([[1, 1, 1, 1, 1]])
                         )
 
-    assert_almost_equal(g1.get_vecs_by_tokens(['a', 'not_exist']).asnumpy(),
+    assert_almost_equal(ce1.get_vecs_by_tokens(['a', 'not_exist']).asnumpy(),
                         np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                                   [1, 1, 1, 1, 1]])
                         )
 
-    assert_almost_equal(g1.get_vecs_by_tokens(['a', 'b']).asnumpy(),
+    assert_almost_equal(ce1.get_vecs_by_tokens(['a', 'b']).asnumpy(),
                         np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                                   [0.6, 0.7, 0.8, 0.9, 1]])
                         )
 
-    assert_almost_equal(g1.get_vecs_by_tokens(['A', 'b']).asnumpy(),
+    assert_almost_equal(ce1.get_vecs_by_tokens(['A', 'b']).asnumpy(),
                         np.array([[1, 1, 1, 1, 1],
                                   [0.6, 0.7, 0.8, 0.9, 1]])
                         )
 
-    assert_almost_equal(g1.get_vecs_by_tokens(['A', 'b'], 
lower_case_backup=True).asnumpy(),
+    assert_almost_equal(ce1.get_vecs_by_tokens(['A', 'b'], 
lower_case_backup=True).asnumpy(),
                         np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
                                   [0.6, 0.7, 0.8, 0.9, 1]])
                         )
 
-    g1.update_token_vectors(['a', 'b'],
-                            nd.array([[2, 2, 2, 2, 2],
+    ce1.update_token_vectors(['a', 'b'],
+                             nd.array([[2, 2, 2, 2, 2],
                                       [3, 3, 3, 3, 3]])
-                            )
+                             )
 
-    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                         np.array([[1, 1, 1, 1, 1],
                                   [1, 1, 1, 1, 1],
                                   [1, 1, 1, 1, 1],
@@ -483,15 +594,15 @@ def test_glossary_with_one_embed():
                                   [1, 1, 1, 1, 1]])
                         )
 
-    assertRaises(ValueError, g1.update_token_vectors, 'unknown$$$', 
nd.array([0, 0, 0, 0, 0]))
+    assertRaises(ValueError, ce1.update_token_vectors, 'unknown$$$', 
nd.array([0, 0, 0, 0, 0]))
 
-    assertRaises(AssertionError, g1.update_token_vectors, '<unk>',
+    assertRaises(AssertionError, ce1.update_token_vectors, '<unk>',
                  nd.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]))
 
-    assertRaises(AssertionError, g1.update_token_vectors, '<unk>', 
nd.array([0]))
+    assertRaises(AssertionError, ce1.update_token_vectors, '<unk>', 
nd.array([0]))
 
-    g1.update_token_vectors(['<unk>'], nd.array([0, 0, 0, 0, 0]))
-    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+    ce1.update_token_vectors(['<unk>'], nd.array([0, 0, 0, 0, 0]))
+    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                         np.array([[0, 0, 0, 0, 0],
                                   [1, 1, 1, 1, 1],
                                   [1, 1, 1, 1, 1],
@@ -499,8 +610,8 @@ def test_glossary_with_one_embed():
                                   [2, 2, 2, 2, 2],
                                   [1, 1, 1, 1, 1]])
                         )
-    g1.update_token_vectors(['<unk>'], nd.array([[10, 10, 10, 10, 10]]))
-    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+    ce1.update_token_vectors(['<unk>'], nd.array([[10, 10, 10, 10, 10]]))
+    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                         np.array([[10, 10, 10, 10, 10],
                                   [1, 1, 1, 1, 1],
                                   [1, 1, 1, 1, 1],
@@ -508,8 +619,8 @@ def test_glossary_with_one_embed():
                                   [2, 2, 2, 2, 2],
                                   [1, 1, 1, 1, 1]])
                         )
-    g1.update_token_vectors('<unk>', nd.array([0, 0, 0, 0, 0]))
-    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+    ce1.update_token_vectors('<unk>', nd.array([0, 0, 0, 0, 0]))
+    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                         np.array([[0, 0, 0, 0, 0],
                                   [1, 1, 1, 1, 1],
                                   [1, 1, 1, 1, 1],
@@ -517,8 +628,8 @@ def test_glossary_with_one_embed():
                                   [2, 2, 2, 2, 2],
                                   [1, 1, 1, 1, 1]])
                         )
-    g1.update_token_vectors('<unk>', nd.array([[10, 10, 10, 10, 10]]))
-    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+    ce1.update_token_vectors('<unk>', nd.array([[10, 10, 10, 10, 10]]))
+    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                         np.array([[10, 10, 10, 10, 10],
                                   [1, 1, 1, 1, 1],
                                   [1, 1, 1, 1, 1],
@@ -528,7 +639,7 @@ def test_glossary_with_one_embed():
                         )
 
 
-def test_glossary_with_two_embeds():
+def test_composite_embedding_with_two_embeddings():
     embed_root = '.'
     embed_name = 'my_embed'
     elem_delim = '\t'
@@ -547,14 +658,14 @@ def test_glossary_with_two_embeds():
 
     counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
 
-    i1 = text.indexer.TokenIndexer(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unk>',
-                                   reserved_tokens=None)
-    g1 = text.glossary.Glossary(i1, [my_embed1, my_embed2])
+    v1 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unk>',
+                               reserved_tokens=None)
+    ce1 = text.embedding.CompositeEmbedding(v1, [my_embed1, my_embed2])
 
-    assert g1.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3, 
'some_word$': 4}
-    assert g1.idx_to_token == ['<unk>', 'c', 'b', 'a', 'some_word$']
+    assert ce1.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3, 
'some_word$': 4}
+    assert ce1.idx_to_token == ['<unk>', 'c', 'b', 'a', 'some_word$']
 
-    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                         np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                                   [1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1],
                                   [0.6, 0.7, 0.8, 0.9, 1, 0, 0, 0, 0, 0],
@@ -563,22 +674,22 @@ def test_glossary_with_two_embeds():
                                   [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
                         )
 
-    assert g1.vec_len == 10
-    assert g1.reserved_tokens is None
-    assert_almost_equal(g1.get_vecs_by_tokens('c').asnumpy(),
+    assert ce1.vec_len == 10
+    assert ce1.reserved_tokens is None
+    assert_almost_equal(ce1.get_vecs_by_tokens('c').asnumpy(),
                         np.array([1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1])
                         )
 
-    assert_almost_equal(g1.get_vecs_by_tokens(['b', 'not_exist']).asnumpy(),
+    assert_almost_equal(ce1.get_vecs_by_tokens(['b', 'not_exist']).asnumpy(),
                         np.array([[0.6, 0.7, 0.8, 0.9, 1, 0, 0, 0, 0, 0],
                                   [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
                         )
 
-    g1.update_token_vectors(['a', 'b'],
+    ce1.update_token_vectors(['a', 'b'],
                             nd.array([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
                                       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]])
                             )
-    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+    assert_almost_equal(ce1.idx_to_vec.asnumpy(),
                         np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                                   [1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1],
                                   [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
@@ -601,10 +712,10 @@ def test_glossary_with_two_embeds():
     my_embed4 = text.embedding.CustomEmbedding(pretrain_file_path4, elem_delim,
                                                unknown_token='<unk2>')
 
-    i2 = text.indexer.TokenIndexer(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unk>',
-                                   reserved_tokens=None)
-    g2 = text.glossary.Glossary(i2, [my_embed3, my_embed4])
-    assert_almost_equal(g2.idx_to_vec.asnumpy(),
+    v2 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unk>',
+                               reserved_tokens=None)
+    ce2 = text.embedding.CompositeEmbedding(v2, [my_embed3, my_embed4])
+    assert_almost_equal(ce2.idx_to_vec.asnumpy(),
                         np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
                                    0.11, 0.12, 0.13, 0.14, 0.15],
                                   [1.1, 1.2, 1.3, 1.4, 1.5,
@@ -617,10 +728,10 @@ def test_glossary_with_two_embeds():
                                    0.11, 0.12, 0.13, 0.14, 0.15]])
                         )
 
-    i3 = text.indexer.TokenIndexer(counter, most_freq_count=None, min_freq=1,
-                                   unknown_token='<unk1>', 
reserved_tokens=None)
-    g3 = text.glossary.Glossary(i3, [my_embed3, my_embed4])
-    assert_almost_equal(g3.idx_to_vec.asnumpy(),
+    v3 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unk1>',
+                               reserved_tokens=None)
+    ce3 = text.embedding.CompositeEmbedding(v3, [my_embed3, my_embed4])
+    assert_almost_equal(ce3.idx_to_vec.asnumpy(),
                         np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
                                    0.11, 0.12, 0.13, 0.14, 0.15],
                                   [1.1, 1.2, 1.3, 1.4, 1.5,
@@ -633,10 +744,10 @@ def test_glossary_with_two_embeds():
                                    0.11, 0.12, 0.13, 0.14, 0.15]])
                         )
 
-    i4 = text.indexer.TokenIndexer(counter, most_freq_count=None, min_freq=1,
-                                   unknown_token='<unk2>', 
reserved_tokens=None)
-    g4 = text.glossary.Glossary(i4, [my_embed3, my_embed4])
-    assert_almost_equal(g4.idx_to_vec.asnumpy(),
+    v4 = text.vocab.Vocabulary(counter, most_freq_count=None, min_freq=1, 
unknown_token='<unk2>',
+                               reserved_tokens=None)
+    ce4 = text.embedding.CompositeEmbedding(v4, [my_embed3, my_embed4])
+    assert_almost_equal(ce4.idx_to_vec.asnumpy(),
                         np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
                                    0.11, 0.12, 0.13, 0.14, 0.15],
                                   [1.1, 1.2, 1.3, 1.4, 1.5,
@@ -651,12 +762,12 @@ def test_glossary_with_two_embeds():
 
     counter2 = Counter(['b', 'b', 'c', 'c', 'c', 'some_word$'])
 
-    i5 = text.indexer.TokenIndexer(counter2, most_freq_count=None, min_freq=1, 
unknown_token='a',
-                                   reserved_tokens=None)
-    g5 = text.glossary.Glossary(i5, [my_embed3, my_embed4])
-    assert g5.token_to_idx == {'a': 0, 'c': 1, 'b': 2, 'some_word$': 3}
-    assert g5.idx_to_token == ['a', 'c', 'b', 'some_word$']
-    assert_almost_equal(g5.idx_to_vec.asnumpy(),
+    v5 = text.vocab.Vocabulary(counter2, most_freq_count=None, min_freq=1, 
unknown_token='a',
+                               reserved_tokens=None)
+    ce5 = text.embedding.CompositeEmbedding(v5, [my_embed3, my_embed4])
+    assert ce5.token_to_idx == {'a': 0, 'c': 1, 'b': 2, 'some_word$': 3}
+    assert ce5.idx_to_token == ['a', 'c', 'b', 'some_word$']
+    assert_almost_equal(ce5.idx_to_vec.asnumpy(),
                         np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
                                    0.11, 0.12, 0.13, 0.14, 0.15],
                                   [1.1, 1.2, 1.3, 1.4, 1.5,
@@ -668,21 +779,18 @@ def test_glossary_with_two_embeds():
                         )
 
 
-def test_get_embedding_names_and_pretrain_files():
-    assert 
len(text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names(
+def test_get_and_pretrain_file_names():
+    assert len(text.embedding.get_pretrained_file_names(
         embedding_name='fasttext')) == 294
 
-    assert 
len(text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names(
-        embedding_name='glove')) == 10
+    assert 
len(text.embedding.get_pretrained_file_names(embedding_name='glove')) == 10
 
-    reg = 
text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names(
-        embedding_name=None)
+    reg = text.embedding.get_pretrained_file_names(embedding_name=None)
 
     assert len(reg['glove']) == 10
     assert len(reg['fasttext']) == 294
 
-    assertRaises(KeyError, 
text.embedding.TokenEmbedding.get_embedding_and_pretrained_file_names,
-                 'unknown$$')
+    assertRaises(KeyError, text.embedding.get_pretrained_file_names, 
'unknown$$')
 
 
 if __name__ == '__main__':
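
Taken together, the renamed module-level helpers can be exercised along these lines (the GloVe
pre-trained file name is illustrative and assumes it is among the registered files):

```python
>>> from mxnet.contrib import text

>>> # List the registered pre-trained file names for an embedding, e.g. GloVe.
>>> pretrained = text.embedding.get_pretrained_file_names(embedding_name='glove')
>>> glove = text.embedding.create('glove', pretrained_file_name='glove.6B.50d.txt')
>>> glove.get_vecs_by_tokens(['hello', 'world']).shape
(2, 50)
```
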

-- 
To stop receiving notification emails like this one, please contact
zhash...@apache.org.
