This is an automated email from the ASF dual-hosted git repository.
njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
The following commit(s) were added to refs/heads/master by this push:
new 4c1277f DL: modify the logic to get number of output units from
model_arch table
4c1277f is described below
commit 4c1277fbb7efc401704cedfc143d01b559e93fc3
Author: Jingyi Mei <[email protected]>
AuthorDate: Tue May 28 12:21:28 2019 -0700
DL: modify the logic to get number of output units from model_arch table
Previously, we assumed that the last layer in the model architecture
contains the num_classes (units), which is not necessarily true.
For example:
```
...
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))
```
Here the softmax Activation layer comes after the final Dense layer, so
the last layer has no 'units' entry in its config and get_num_classes()
fails to find the number of classes for such architectures.
This commit changes get_num_classes() to search backward from the end of
the architecture and use 'units' from the first layer that defines it.
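The difference between the two lookups can be sketched in standalone Python.
This is only an illustration under stated assumptions: MODEL_ARCH is a
hypothetical, heavily trimmed stand-in for the JSON kept in the model_arch
table, and get_layers() plus both num_classes_* helpers are made-up names,
not the module's actual functions.
```python
import json

# Hypothetical, trimmed architecture JSON (assumption: a flat list of layers,
# each with a 'config' dict, mirroring the example model above with 10 classes).
MODEL_ARCH = json.dumps({
    "class_name": "Sequential",
    "config": [
        {"class_name": "Flatten", "config": {}},
        {"class_name": "Dense", "config": {"units": 512}},
        {"class_name": "Activation", "config": {"activation": "relu"}},
        {"class_name": "Dropout", "config": {"rate": 0.5}},
        {"class_name": "Dense", "config": {"units": 10}},
        {"class_name": "Activation", "config": {"activation": "softmax"}},
    ],
})


def get_layers(model_arch):
    """Return the list of layer dicts (illustrative helper, not MADlib's)."""
    return json.loads(model_arch)["config"]


def num_classes_last_layer_only(model_arch):
    """Old behaviour: inspect only the last layer; None if it has no 'units'."""
    return get_layers(model_arch)[-1]["config"].get("units")


def num_classes_last_with_units(model_arch):
    """New behaviour: walk backward and return the first 'units' found."""
    for layer in reversed(get_layers(model_arch)):
        if "units" in layer["config"]:
            return layer["config"]["units"]
    raise ValueError("Unable to get number of classes from model architecture.")


print(num_classes_last_layer_only(MODEL_ARCH))
print(num_classes_last_with_units(MODEL_ARCH))
```
With this sample input the old lookup finds nothing in the trailing
Activation layer, while the backward search stops at the final Dense layer.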
Closes #400
Co-authored-by: Nandish Jayaram <[email protected]>
---
src/ports/postgres/modules/utilities/model_arch_info.py_in | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/src/ports/postgres/modules/utilities/model_arch_info.py_in b/src/ports/postgres/modules/utilities/model_arch_info.py_in
index 765aed7..a03594a 100644
--- a/src/ports/postgres/modules/utilities/model_arch_info.py_in
+++ b/src/ports/postgres/modules/utilities/model_arch_info.py_in
@@ -42,8 +42,11 @@ def get_input_shape(model_arch):
 def get_num_classes(model_arch):
     arch_layers = _get_layers(model_arch)
-    if 'units' in arch_layers[-1]['config']:
-        return arch_layers[-1]['config']['units']
+    i = len(arch_layers) - 1
+    while i >= 0:
+        if 'units' in arch_layers[i]['config']:
+            return arch_layers[i]['config']['units']
+        i -= 1
     plpy.error('Unable to get number of classes from model architecture.')
 def get_model_arch_layers_str(model_arch):