For machine learning, and deep learning in particular, this is how I do it:

1\. I use inheritance, but only to be able to store the objects in the same 
container, **never for dispatch**. Here is an example with a neural network layer, called 
a `Gate`, re-using the terminology from Andrej Karpathy's [Hacker's Guide to 
Neural Networks](http://karpathy.github.io/neuralnets/):

[https://github.com/mratsim/Arraymancer/blob/5b24877b/src/autograd/autograd_common.nim#L72-L85](https://github.com/mratsim/Arraymancer/blob/5b24877b/src/autograd/autograd_common.nim#L72-L85)
    
    
    type
      Gate*[TT] = ref object of RootObj # {.acyclic.}
        ## Base operator or layer. You can describe your custom operations or layers
        ## by inheriting from Gate and adding a forward and optionally a backward method.
        ## Each operation should set the number of gradients produced during backpropagation.
        ## Additional fields specific to the operation, like weights or an input cache, should be added too.

      PayloadKind* = enum
        pkVar, pkSeq
      Payload*[TT] = object
        case kind*: PayloadKind
        of pkVar: variable*: Variable[TT]
        of pkSeq: sequence*: seq[Variable[TT]]

      Backward*[TT] = proc(self: Gate[TT], payload: Payload[TT]): SmallDiffs[TT] {.nimcall.}

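To illustrate point 1 with a minimal standalone sketch (simplified stand-in types, not Arraymancer's): inheritance is only there so that different gates can live in the same container; no methods or dynamic dispatch are involved.

    type
      DemoGate = ref object of RootObj   # plays the role of Gate[TT]
      ReluGate = ref object of DemoGate
      AddGate = ref object of DemoGate
        bias: float

    # A heterogeneous "tape" of gates: inheritance buys the shared storage type...
    var tape: seq[DemoGate]
    tape.add ReluGate()
    tape.add AddGate(bias: 0.5)
    echo tape.len   # 2 -- ...and nothing else; dispatch is handled separately (point 2)
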
2\. For dispatching, I associate a compile-time proc with each kind of `Gate` I 
have. The `Gate` carries the state I need to pass. The compile-time proc just 
unwraps it and passes it to the _real_ proc, which has the proper signature.

For example, for a MaxPool layer, the backward proc should have the following signature: 
[https://github.com/mratsim/Arraymancer/blob/5b24877b/src/nn_primitives/nnp_maxpooling.nim#L70-L74](https://github.com/mratsim/Arraymancer/blob/5b24877b/src/nn_primitives/nnp_maxpooling.nim#L70-L74)
So how do we map it to `proc(self: Gate[TT], payload: Payload[TT]): SmallDiffs[TT] {.nimcall.}`?

The `MaxPool2DGate` type and its backward shim do the trick 
([https://github.com/mratsim/Arraymancer/blob/5b24877b/src/nn/layers/maxpool2D.nim#L20-L46](https://github.com/mratsim/Arraymancer/blob/5b24877b/src/nn/layers/maxpool2D.nim#L20-L46));
the `_ag` suffix in `maxpool2D_backward_ag` stands for the autograd version:
    
    
    type MaxPool2DGate*[TT] {.final.} = ref object of Gate[TT]
      cached_input_shape: MetadataArray
      cached_max_indices: Tensor[int]
      kernel, padding, stride: Size2D

    proc maxpool2D_backward_ag[TT](self: MaxPool2DGate[TT], payload: Payload[TT]): SmallDiffs[TT] =
      let gradient = payload.variable.grad
      result = newDiffs[TT](1)
      result[0] = maxpool2d_backward(
        self.cached_input_shape,
        self.cached_max_indices,
        gradient
      )

So I get compile-time dispatch on arbitrary state behind a fixed interface. And 
even though the wrapper proc does not do much, two static function calls are always 
faster than a closure or a method call.
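
To make the wiring concrete, here is a standalone sketch of the pattern with simplified stand-in types (the names and the node layout below are illustrative assumptions, not Arraymancer's actual API): the node stores a single `{.nimcall.}` proc with the fixed base signature, and each shim downcasts to its concrete gate, unwraps the state and forwards to the real kernel.

    type
      ToyGate = ref object of RootObj
      SquareGate = ref object of ToyGate
        cached_input: float
      ToyBackward = proc(self: ToyGate, grad: float): float {.nimcall.}
      ToyNode = object
        gate: ToyGate
        backward: ToyBackward

    proc square_backward(input, grad: float): float =
      ## The "real" kernel, with its natural signature.
      result = 2.0 * input * grad

    proc square_backward_shim(self: ToyGate, grad: float): float =
      ## The shim: recover the concrete gate's state, then forward to the real kernel.
      let gate = SquareGate(self)
      result = square_backward(gate.cached_input, grad)

    let node = ToyNode(gate: SquareGate(cached_input: 3.0),
                       backward: square_backward_shim)
    echo node.backward(node.gate, 1.0)   # 6.0 -- no methods, no closures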

Another example: convolution. This may seem more complex, as the backward 
operation has a lot of inputs: 
[https://github.com/mratsim/Arraymancer/blob/5b24877b/src/nn_primitives/nnp_convolution.nim#L65-L70](https://github.com/mratsim/Arraymancer/blob/5b24877b/src/nn_primitives/nnp_convolution.nim#L65-L70)
    
    
    proc conv2d_backward*[T](input, weight, bias: Tensor[T],
                             padding: Size2D,
                             stride: Size2D,
                             grad_output: Tensor[T],
                             grad_input, grad_weight, grad_bias: var Tensor[T],
                             algorithm = Conv2DAlgorithm.Im2ColGEMM)
    
    

But it's actually straightforward as well:
    
    
    type Conv2DGate*[TT]{.final.} = ref object of Gate[TT]
      cached_input: Variable[TT]
      weight, bias: Variable[TT]
      padding, stride: Size2D
      # TODO: store the algorithm (NNPACK / im2col)

    proc conv2d_backward_ag[TT](self: Conv2DGate[TT], payload: Payload[TT]): SmallDiffs[TT] =
      let gradient = payload.variable.grad
      if self.bias.isNil:
        result = newDiffs[TT](2)
      else:
        result = newDiffs[TT](3)
      conv2d_backward(
        self.cached_input.value,
        self.weight.value, self.bias.value,
        self.padding, self.stride,
        gradient,
        result[0], result[1], result[2]
      )

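The same fixed interface also accommodates a variable number of gradients, since only the shim body changes, exactly as `conv2d_backward_ag` does with its 2-vs-3 diffs. A standalone sketch of that aspect, again with simplified stand-in types (illustrative assumptions, not Arraymancer's):

    type
      MiniGate = ref object of RootObj
      AffineGate = ref object of MiniGate
        cached_input, weight: float
        has_bias: bool
      MiniBackward = proc(self: MiniGate, grad: float): seq[float] {.nimcall.}

    proc affine_backward_shim(self: MiniGate, grad: float): seq[float] =
      # Downcast, then emit 2 or 3 diffs depending on the gate's state.
      let gate = AffineGate(self)
      result = @[grad * gate.weight,        # gradient w.r.t. the input
                 grad * gate.cached_input]  # gradient w.r.t. the weight
      if gate.has_bias:
        result.add grad                     # gradient w.r.t. the bias, only when present

    let backward: MiniBackward = affine_backward_shim
    echo backward(AffineGate(cached_input: 2.0, weight: 3.0, has_bias: true), 1.0)
    # prints @[3.0, 2.0, 1.0]: three diffs with a bias, two without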