This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 4c1277f  DL: modify the logic to get number of output units from model_arch table
4c1277f is described below

commit 4c1277fbb7efc401704cedfc143d01b559e93fc3
Author: Jingyi Mei <[email protected]>
AuthorDate: Tue May 28 12:21:28 2019 -0700

    DL: modify the logic to get number of output units from model_arch table
    
    Previously, we assumed that the last layer in the model architecture
    contains num_classes (units), which is not necessarily true.
    For example:
    
    ```
    ...
    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes))
    model.add(Activation('softmax'))
    
    ```
    
    Here the activation is a separate layer that follows the final dense
    layer, so the last layer has no 'units' entry and get_num_classes()
    fails for such architectures.
    This commit changes get_num_classes() to scan the layers from the end
    of the architecture and use 'units' from the first layer that defines it.
    
    Closes #400
    Co-authored-by: Nandish Jayaram <[email protected]>
---
 src/ports/postgres/modules/utilities/model_arch_info.py_in | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/ports/postgres/modules/utilities/model_arch_info.py_in b/src/ports/postgres/modules/utilities/model_arch_info.py_in
index 765aed7..a03594a 100644
--- a/src/ports/postgres/modules/utilities/model_arch_info.py_in
+++ b/src/ports/postgres/modules/utilities/model_arch_info.py_in
@@ -42,8 +42,11 @@ def get_input_shape(model_arch):
 
 def get_num_classes(model_arch):
     arch_layers = _get_layers(model_arch)
-    if 'units' in arch_layers[-1]['config']:
-        return arch_layers[-1]['config']['units']
+    i = len(arch_layers) - 1
+    while i >= 0:
+        if 'units' in arch_layers[i]['config']:
+            return arch_layers[i]['config']['units']
+        i -= 1
     plpy.error('Unable to get number of classes from model architecture.')
 
 def get_model_arch_layers_str(model_arch):
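
The backward scan in the patch can be tried outside the database with a small standalone sketch. This is not the MADlib code: the helper names (layers_from_json, num_classes_last_layer_only, num_classes_scan_backwards) are hypothetical, the JSON layout is a simplified assumption about what a Keras-style architecture string might look like, and a ValueError stands in for plpy.error, which is only available inside PL/Python.

```python
import json

# Assumed, simplified architecture JSON mirroring the example in the
# commit message, with num_classes = 10. Real model_arch entries may
# carry more fields or a different layout.
SAMPLE_ARCH = json.dumps({
    "class_name": "Sequential",
    "config": {"layers": [
        {"class_name": "Flatten", "config": {}},
        {"class_name": "Dense", "config": {"units": 512}},
        {"class_name": "Activation", "config": {"activation": "relu"}},
        {"class_name": "Dropout", "config": {"rate": 0.5}},
        {"class_name": "Dense", "config": {"units": 10}},
        {"class_name": "Activation", "config": {"activation": "softmax"}},
    ]},
})


def layers_from_json(model_arch):
    """Hypothetical stand-in for _get_layers(): return the list of layers."""
    return json.loads(model_arch)["config"]["layers"]


def num_classes_last_layer_only(model_arch):
    """Old behavior: look at the last layer only; None if it has no 'units'."""
    last = layers_from_json(model_arch)[-1]
    return last["config"].get("units")


def num_classes_scan_backwards(model_arch):
    """New behavior: walk from the last layer towards the first and
    return 'units' from the first layer that defines it."""
    for layer in reversed(layers_from_json(model_arch)):
        if "units" in layer["config"]:
            return layer["config"]["units"]
    # Stand-in for plpy.error() in the real module.
    raise ValueError("Unable to get number of classes from model architecture.")


if __name__ == "__main__":
    print(num_classes_last_layer_only(SAMPLE_ARCH))
    print(num_classes_scan_backwards(SAMPLE_ARCH))
```

With this sample, the last-layer-only lookup finds no 'units' because the final layer is the softmax Activation, while the backward scan stops at the Dense layer just before it. Using reversed() in the sketch is only a stylistic alternative to the index-based while loop in the patch; both visit the layers in the same order.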
