Peng Cheng created MXNET-1437:
---------------------------------

             Summary: Shape-safe Scala/Kotlin/JVM autograd API
                 Key: MXNET-1437
                 URL: https://issues.apache.org/jira/browse/MXNET-1437
             Project: Apache MXNet
          Issue Type: New Feature
          Components: Apache MXNet Scala API, Gluon
            Reporter: Peng Cheng


The purpose of this feature is to shift, as much as possible, the mental burden 
of manually catching invalid tensor operations caused by incompatible tensor 
shapes and data types from programmers to the language compiler. Typical use 
cases and advantages are:

1. Type and shape safety for tensor operands whose shapes may be either fully 
or partially specified, e.g.

 - Tensor[float, (2, 3)] * Tensor[float, (4, 5)] will trigger a pre-execution 
error
 - Tensor[float-cpu, (2, 3)] * Tensor[float-gpu, (3, 4)] will also trigger a 
pre-execution error
 - Tensor[float, (2, 3)] * Tensor[float, (?, 4)] -> Tensor[float, (2, 4)]
 - Tensor[float, (2, 3)] * Tensor[float, (3, ?)] -> Tensor[float, (2, ?)]
 - det(Tensor[float, (3, 3)]) -> Tensor[float, ()]
 - complex dependent-type algebra, such as pooling(Tensor[double, (H, W)]) -> 
Tensor[double, (H/2, W/2)], is out of scope for this feature
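As a rough illustration of the fully specified cases, here is a minimal sketch 
assuming Scala 2.13 literal types. The `Tensor` class, `fill` and `*` below are 
hypothetical and not part of MXNet. The row and column counts are carried as 
literal type parameters, so a mismatched inner dimension fails to compile. 
Partially specified (`?`) dimensions and device tags are not modeled; they 
would need additional type-level evidence.

```scala
// Sketch only: a hypothetical 2-D tensor whose shape lives in its type.
// Assumes Scala 2.13+ (literal types and ValueOf); not an MXNet API.
final class Tensor[R <: Int with Singleton, C <: Int with Singleton] private (
    val rows: Int,
    val cols: Int,
    val data: Vector[Double] // row-major storage
) {
  def shape: (Int, Int) = (rows, cols)

  // Matrix product. The right operand must have exactly C rows; any other
  // literal type fails to type-check, so the error surfaces before execution.
  def *[K <: Int with Singleton](that: Tensor[C, K]): Tensor[R, K] = {
    val out = for {
      i <- 0 until rows
      j <- 0 until that.cols
    } yield (0 until cols).map(k => data(i * cols + k) * that.data(k * that.cols + j)).sum
    new Tensor[R, K](rows, that.cols, out.toVector)
  }
}

object Tensor {
  // ValueOf recovers the runtime value of a literal type, e.g. ValueOf[3] gives 3.
  def fill[R <: Int with Singleton, C <: Int with Singleton](x: Double)(
      implicit r: ValueOf[R], c: ValueOf[C]): Tensor[R, C] =
    new Tensor[R, C](r.value, c.value, Vector.fill(r.value * c.value)(x))
}

object ShapeDemo {
  def main(args: Array[String]): Unit = {
    val a = Tensor.fill[2, 3](1.0)
    val b = Tensor.fill[3, 4](2.0)
    val c: Tensor[2, 4] = a * b // result shape is computed by the type checker
    println(c.shape) // prints (2,4)
    // a * Tensor.fill[4, 5](1.0) // does not compile: inner dimensions 3 and 4 differ
  }
}
```

The `Singleton` upper bound on the type parameters keeps the compiler from 
widening an inferred literal such as `4` to plain `Int`, which is what lets 
`a * b` produce `Tensor[2, 4]` rather than a less precise type.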

2. First-class support for named tensors and structured/product types, e.g.

 - Tensor[int, (1, 3)] * Tensor[int, (2, 3)] will trigger a pre-execution error
 - but Tensor[int, (1, xyz: 3)] * Tensor[int, (proj: 2, xyz: 3)] -> Tensor[int, 
(1, proj: 2)]
 - loss(img: Tensor[int, (64, 64)], label: Tensor[int, ()]) can only be applied 
to inputs of this product type; any other input triggers a pre-execution error
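A minimal sketch of the named-axis contraction, again assuming Scala 2.13 
literal types. The names `Ax`, `AxSize` and `NamedTensor` are hypothetical and 
not MXNet APIs. For simplicity every axis is named here, so the unnamed leading 
axis of size 1 from the example is labelled `"batch"`. Each axis type pairs a 
literal name with a literal size, and the contraction runs over the shared 
trailing axis, which must agree in both.

```scala
// Sketch only: hypothetical named axes (Scala 2.13+), not an MXNet API.
// An axis type pairs a literal name with a literal size; it has no values.
sealed trait Ax[N <: String with Singleton, S <: Int with Singleton]

// Type class that recovers an axis' runtime size from its type.
final case class AxSize[A](value: Int)
object AxSize {
  implicit def fromType[N <: String with Singleton, S <: Int with Singleton](
      implicit s: ValueOf[S]): AxSize[Ax[N, S]] = AxSize[Ax[N, S]](s.value)
}

final class NamedTensor[A, B] private (val rows: Int, val cols: Int, val data: Vector[Double]) {
  // (A, B) * (D, B) -> (A, D): contracts over the shared trailing axis B.
  // B includes the axis name, so same-size but differently named axes are rejected.
  def *[D](that: NamedTensor[D, B]): NamedTensor[A, D] = {
    val out = for {
      i <- 0 until rows
      j <- 0 until that.rows
    } yield (0 until cols).map(k => data(i * cols + k) * that.data(j * that.cols + k)).sum
    new NamedTensor[A, D](rows, that.rows, out.toVector)
  }
}

object NamedTensor {
  def fill[A, B](x: Double)(implicit a: AxSize[A], b: AxSize[B]): NamedTensor[A, B] =
    new NamedTensor[A, B](a.value, b.value, Vector.fill(a.value * b.value)(x))
}

object NamedDemo {
  def main(args: Array[String]): Unit = {
    val img  = NamedTensor.fill[Ax["batch", 1], Ax["xyz", 3]](1.0)
    val proj = NamedTensor.fill[Ax["proj", 2], Ax["xyz", 3]](2.0)
    val out: NamedTensor[Ax["batch", 1], Ax["proj", 2]] = img * proj
    println(out.data) // prints Vector(6.0, 6.0)
    // img * NamedTensor.fill[Ax["proj", 2], Ax["uvw", 3]](2.0) // does not compile
  }
}
```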

3. Matrix or tensor dimensions of length 1 can be implicitly squeezed, e.g.:

 - Tensor[int, (1, 3)] * Tensor[int, (3, 1)] > 0 is a valid condition
 - Tensor[int, (1, 3)] * Tensor[int, (3, 2)] > 0 is not, and will trigger a 
pre-execution error
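Implicit squeezing can be approximated with Scala's implicit conversions. In 
this hypothetical sketch (Scala 2.13, not an MXNet API), only a `Tensor[1, 1]` 
has a conversion to `Double`, so comparing the product of a 1x3 and a 3x1 tensor 
with 0 compiles, while the same comparison on a 1x2 result does not.

```scala
import scala.language.implicitConversions

// Sketch only: hypothetical 2-D tensor with a literal-typed shape (Scala 2.13+).
final class Tensor[R <: Int with Singleton, C <: Int with Singleton] private (
    val rows: Int, val cols: Int, val data: Vector[Double]) {
  // Matrix product; the inner dimension is checked by the type checker.
  def *[K <: Int with Singleton](that: Tensor[C, K]): Tensor[R, K] = {
    val out = for {
      i <- 0 until rows
      j <- 0 until that.cols
    } yield (0 until cols).map(k => data(i * cols + k) * that.data(k * that.cols + j)).sum
    new Tensor[R, K](rows, that.cols, out.toVector)
  }
}

object Tensor {
  def fill[R <: Int with Singleton, C <: Int with Singleton](x: Double)(
      implicit r: ValueOf[R], c: ValueOf[C]): Tensor[R, C] =
    new Tensor[R, C](r.value, c.value, Vector.fill(r.value * c.value)(x))

  // Implicit squeeze: only a 1x1 tensor converts to a scalar. Defined in the
  // companion object, so call sites find it without an extra import.
  implicit def squeeze(t: Tensor[1, 1]): Double = t.data.head
}

object SqueezeDemo {
  def main(args: Array[String]): Unit = {
    val row = Tensor.fill[1, 3](1.0)
    val col = Tensor.fill[3, 1](1.0)
    println(row * col > 0) // prints true: Tensor[1, 1] is squeezed to a Double
    // row * Tensor.fill[3, 2](1.0) > 0 // does not compile: no conversion for Tensor[1, 2]
  }
}
```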

4. Higher-order functions and blocks can be defined with more constraints, e.g.:

 - the last block of a regression model can be typed as Tensor[float, (?, ?)] -> 
Tensor[float, ()], so that only blocks with scalar output are accepted
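A typed higher-order block can be sketched with ordinary generics. In this 
hypothetical example (not an MXNet API), the free type parameters `R` and `C` 
stand in for the `?` dimensions, the scalar output is modeled as a plain 
`Double`, and `squaredError` accepts only models whose last block produces that 
scalar.

```scala
// Sketch only: hypothetical types, not an MXNet API.
object RegressionHead {
  // Phantom shape parameters: R and C play the role of the "?" dimensions.
  final case class Tensor[R, C](data: Vector[Double])

  // Hypothetical dimension labels used in the demo below.
  sealed trait Batch
  sealed trait Feature

  // A block is a typed function; chaining only compiles when the output
  // type of one block matches the input type of the next.
  final case class Block[In, Out](forward: In => Out) {
    def andThen[Next](next: Block[Out, Next]): Block[In, Next] =
      Block(forward.andThen(next.forward))
  }

  // Tensor[float, (?, ?)] -> Tensor[float, ()], with the scalar modeled as Double.
  def meanHead[R, C]: Block[Tensor[R, C], Double] =
    Block(t => t.data.sum / t.data.size)

  // Only a model whose last block emits a scalar is accepted here.
  def squaredError[In](model: Block[In, Double], x: In, target: Double): Double = {
    val d = model.forward(x) - target
    d * d
  }

  def main(args: Array[String]): Unit = {
    val scaleBy2 = Block[Tensor[Batch, Feature], Tensor[Batch, Feature]](t => Tensor(t.data.map(_ * 2)))
    val model = scaleBy2.andThen(meanHead[Batch, Feature])
    println(squaredError(model, Tensor[Batch, Feature](Vector(1.0, 2.0, 3.0)), 4.0)) // prints 0.0
    // squaredError(scaleBy2, ...) would not compile: its output is a Tensor, not a Double
  }
}
```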

These requirements are hard to implement in the existing Python-based API: 
Python is not dependently typed, not functional, and has no native support for 
meta-programming. As a result, enforcing shape safety at compile time is almost 
impossible, and run-time validation requires either compiling Python code into 
an intermediate representation (e.g. TorchScript, Relay IR) via reflection, or 
extensively rewriting a static type checker (e.g. mypy); neither is easy.

It has been demonstrated that some JVM languages, even with incomplete support 
for dependent typing, can approximate some of the above much more easily:

- https://openreview.net/forum?id=SkluMSZ08H
- https://arxiv.org/abs/1803.10228

Specifically:

 - the arity-dependent types in 1 can be constructed by combining the literal 
types of Scala 2.13+ (or shapeless) with recursive abstract types
 - 3 can benefit from Scala's implicit type conversions
 - Kotlin's extension functions and Scala's implicit conversions/magnet pattern 
allow concise and readable definitions of differentiable functions
 - first-class support for dependent types will become available in Scala 3.0

None of the above implementations requires advanced language features such as 
macros, polyglot programming, or compiler hacks. Combined with the JVM's strong 
reflection and mature meta-programming support, this can make a JVM-based API 
more suitable for large-scale engineering (and arguably also for prototyping).



--
This message was sent by Atlassian Jira
(v8.3.4#803005)
