joddiy commented on issue #691:
URL: https://github.com/apache/singa/issues/691#issuecomment-629147694


   ## Conclusion first
   
   Good news: 
   > ONNX can now define the loss and the optimizer within its format. However, the only losses currently available are `NegativeLogLikelihoodLoss` and `SoftmaxCrossEntropyLoss`. Likewise, it can only store a limited set of optimizers: `Adagrad`, `Adam`, and `Momentum` (SGD with standard momentum).
   
   Bad news:
   > We need to upgrade ONNX to 1.7, which was released last week and may not be very stable yet. In this release, ONNX defines a complicated node called `GraphCall` to specify which gradients should be computed and how to update the tensors by using these gradients. Since we already update the weights right after the backward pass, this part may not be useful for us.
   
   ## ONNX Training Preview (TrainingInfoProto)
   
   Last week, the ONNX team released a new version, [1.7.0](https://github.com/onnx/onnx/releases/tag/v1.7.0), which upgrades its opset version to 12. In this new release, they add a new feature called [`TrainingInfoProto`](https://github.com/onnx/onnx/blob/3368834cf0b1f0ab9838cf6bdf78a27299d08187/onnx/onnx.proto3#L211-L316).
 
   
   This new feature describes the training information. It has two main parts: `initialization-step` and `training-algorithm-step`.
   
   ### initialization-step
   
   `initialization-step` means the developer can define an `initialization`. By its type, the `initialization` is a formal ONNX graph. It has no inputs but several outputs. The developer can define some nodes in this graph, such as `RandomNormal` or `RandomUniform`, and in another field called `initialization_binding`, the developer can assign these outputs to specific tensors in the inference graph.
   
   The currently supported random methods are `RandomNormal` and `RandomUniform`.
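
   A minimal Python sketch of how such an initializer could be assembled with the `onnx` helper API (requires `onnx>=1.7`; the weight name `W`, its shape, and the graph/output names are made-up placeholders):

   ```python
   import onnx
   from onnx import helper, TensorProto

   # Hypothetical weight "W" of shape [2, 3] living in the inference graph.
   w_init = helper.make_tensor_value_info("W_init", TensorProto.FLOAT, [2, 3])

   # RandomNormal takes no inputs; it only produces a randomly drawn tensor.
   rand_node = helper.make_node("RandomNormal", inputs=[], outputs=["W_init"],
                                shape=[2, 3], mean=0.0, scale=1.0)

   init_graph = helper.make_graph([rand_node], "init_graph",
                                  inputs=[], outputs=[w_init])

   training_info = onnx.TrainingInfoProto()
   training_info.initialization.CopyFrom(init_graph)

   # initialization_binding assigns the graph output "W_init" to the
   # inference-graph tensor "W".
   binding = training_info.initialization_binding.add()
   binding.key, binding.value = "W", "W_init"
   ```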
   
   ### training-algorithm-step
   
   `training-algorithm-step` defines a field called `algorithm`. It is an inference graph that represents one step of a training algorithm. Given the required inputs, it computes outputs that update tensors in its own graph or in the main computation graph. `update_binding` contains key-value pairs of strings that assign these outputs to specific tensors.
   
   In general, this graph contains a loss node, a gradient node, an optimizer node, an increment of the iteration count, and some calls to the inference graph. The field `algorithm.node` is the only place the user can use the `GraphCall` operator.
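
   As a rough, hedged sketch of the plumbing (the toy graph here only increments the iteration count; a real `algorithm` would also contain the loss, `Gradient`, and optimizer nodes listed below):

   ```python
   import onnx
   from onnx import helper, TensorProto

   # A toy "algorithm" graph that only increments the iteration count T.
   t_in = helper.make_tensor_value_info("T", TensorProto.INT64, [])
   t_new = helper.make_tensor_value_info("T_new", TensorProto.INT64, [])
   one = helper.make_tensor("one", TensorProto.INT64, [], [1])

   step = helper.make_node("Add", ["T", "one"], ["T_new"])
   algorithm = helper.make_graph([step], "training_step",
                                 inputs=[t_in], outputs=[t_new],
                                 initializer=[one])

   training_info = onnx.TrainingInfoProto()
   training_info.algorithm.CopyFrom(algorithm)
   # The TrainingInfoProto is then appended to ModelProto.training_info,
   # a repeated field that only exists from ONNX 1.7 onwards.
   ```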
   
   #### Loss node
   
   - `NegativeLogLikelihoodLoss`
   - `SoftmaxCrossEntropyLoss`
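
   Both are ordinary opset-12 nodes, so a loss can be sketched with the helper API like this (the tensor names `scores`, `labels`, and `loss` are placeholders):

   ```python
   from onnx import helper

   # SoftmaxCrossEntropyLoss consumes raw scores and integer class labels
   # and emits a reduced scalar loss (opset 12).
   loss_node = helper.make_node("SoftmaxCrossEntropyLoss",
                                inputs=["scores", "labels"],
                                outputs=["loss"],
                                reduction="mean")
   ```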
   
   
   #### Optimizer node
   
   - `Adagrad`
   - `Adam`
   - `Momentum`: SGD with standard momentum
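
   These optimizer ops live in the separate `ai.onnx.preview.training` domain. A hedged sketch of a `Momentum` node; the input order (learning rate, iteration count, then weight, gradient, and accumulated momentum) and the attribute set follow my reading of the spec, and all tensor names are placeholders:

   ```python
   from onnx import helper

   # One SGD-with-momentum update step for a single weight tensor "W".
   momentum_node = helper.make_node(
       "Momentum",
       inputs=["lr", "iter_count", "W", "dO_dW", "W_velocity"],
       outputs=["W_new", "W_velocity_new"],
       domain="ai.onnx.preview.training",
       alpha=0.9,             # momentum coefficient
       beta=1.0,              # scale applied to the current gradient
       mode="standard",       # "standard" or "nesterov"
       norm_coefficient=0.0)  # L2 regularization strength
   ```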
   
   #### Gradient node
   
   The Gradient node actually only defines the information necessary to compute the gradients for the whole graph. For example, in the following graph, the Gradient node declares its inputs as `xs` (the intermediate weights) and `zs` (the inputs of the graph), together with `y` (the output of the graph), and its outputs as `dY/dW` and `dY/dZ`, whose order corresponds to the inputs in `xs`.
   
   It doesn't define any logic about how to compute `dY/dW` and `dY/dZ`.
   
   ```
   W --> Conv --> H --> Gemm --> Y
   |      ^              ^
   |      |              |
   |      X              Z
   |      |              |
   |      |   .----------'
   |      |   |  (W/Z/X is the 1st/2nd/3rd input of Gradient as shown in
   |      |   |   "xs" followed by "zs")
   |      v   v
   '---> Gradient(xs=["W", "Z"], zs=["X"], y="Y")
          |   |
          |   '-----------------------------------> dY/dW (1st output of Gradient)
          |
          '---------------------------------------> dY/dZ (2nd output of Gradient)
   ```
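
   In helper-API terms, the `Gradient` node from this diagram could be written roughly as follows (the op lives in the `ai.onnx.preview.training` domain; the output names are placeholders):

   ```python
   from onnx import helper

   # Declare which gradients to compute: dY/dW and dY/dZ, ordered like xs.
   # W, Z and X are passed in as inputs; how the differentiation is done
   # is left entirely to the runtime.
   gradient_node = helper.make_node(
       "Gradient",
       inputs=["W", "Z", "X"],
       outputs=["dY_dW", "dY_dZ"],
       domain="ai.onnx.preview.training",
       xs=["W", "Z"],
       zs=["X"],
       y="Y")
   ```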
   
   #### GraphCall node
   
   The `GraphCall` operator invokes a graph inside the `TrainingInfoProto`'s `algorithm` field. The inputs and outputs of `GraphCall` are bound to those of the invoked graph by position.
   
   Based on the above inference graph, `GraphCall` can be used like this:
   
   ```
   .-------- W (a global and mutable variable from
   |         |  the inference graph)
   |         |
   |   .-----'-----------.
   |   |                 |
   |   |                 v
   |   | .-- X_1 --> GraphCall(graph_name="MyInferenceGraph")
   |   | |            |  |
   |   | |            |  |
   |   | |   Z_1 -----'  |
   |   | |    |          V
   |   | |    |         Y_1 ---> Loss ---> O
   |   | |    |                    ^
   |   | |    |                    |
   |   | `--. |                    C
   |   |    | |                    |
   |   |    | |   .----------------'
   |   |    | |   |
   |   |    v v   v
   |   `--> Gradient(xs=["W"], zs=["X_1", "Z_1", "C"], y="O")
   |        |
   |        v
   |      dO_dW (gradient of W)      1 (a scalar one)
   |        |                        |
   |        V                        v
   |       Div <--- T ------------> Add ---> T_new
   |        |    (T is the number of training iterations.
   |        |     T is also globally visible and mutable.)
   |        v
   `-----> Sub ----> W_new
   ```
   
   The inference graph from the previous section is invoked by `GraphCall(graph_name="MyInferenceGraph")`, using a new batch of inputs (`X_1`, `Z_1`) to compute `Y_1`. 
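
   A hedged sketch of that call with the helper API; the graph name comes from the diagram, the tensor names are placeholders, and the positional binding of inputs and outputs follows my reading of the spec:

   ```python
   from onnx import helper

   # Run the inference graph on a fresh batch. Inputs and outputs are
   # matched to the called graph by position; tensors not passed here
   # (e.g. the weight W) keep their current, globally visible values.
   call_node = helper.make_node(
       "GraphCall",
       inputs=["X_1", "Z_1"],
       outputs=["Y_1"],
       domain="ai.onnx.preview.training",
       graph_name="MyInferenceGraph")
   ```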
   
   `Gradient` defines the gradients the graph should compute; finally, it gets `W_new` and `T_new`.
   
   Then it uses the following `update_binding` to update the tensors:
   
   ```
   update_binding: {"W": "W_new", "T": "T_new"}
   ```
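
   In the proto, `update_binding` is a list of key-value string pairs on `TrainingInfoProto`, so the binding above could be filled in roughly like this (assuming `training_info` is the proto that holds the algorithm graph):

   ```python
   import onnx

   training_info = onnx.TrainingInfoProto()
   # ... algorithm graph attached as in the earlier sketch ...

   # Feed the algorithm outputs back into the stored tensors after each step.
   for key, value in {"W": "W_new", "T": "T_new"}.items():
       kv = training_info.update_binding.add()
       kv.key, kv.value = key, value
   ```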

