Peng Cheng created MXNET-1437:
---------------------------------
Summary: Shape-safe Scala/Kotlin/JVM autograd API
Key: MXNET-1437
URL: https://issues.apache.org/jira/browse/MXNET-1437
Project: Apache MXNet
Issue Type: New Feature
Components: Apache MXNet Scala API, Gluon
Reporter: Peng Cheng
The purpose of this feature is to shift the burden of catching invalid tensor
operations, caused by incompatible tensor shapes & data types, from programmers
to the compiler as much as possible. Typical use cases and advantages are:
1. Type & shape safety for tensor operands that may have either fully or
partially specified shapes, e.g.
- Tensor[float, (2, 3)] * Tensor[float, (4, 5)] will trigger a pre-execution
error
- Tensor[float-cpu, (2, 3)] * Tensor[float-gpu, (3, 4)] will also trigger a
pre-execution error
- Tensor[float, (2, 3)] * Tensor[float, (?, 4)] -> Tensor[float, (2, 4)]
- Tensor[float, (2, 3)] * Tensor[float, (3, ?)] -> Tensor[float, (2, ?)]
- det(Tensor[float, (3,3)]) -> Tensor[float, ()]
- note: complex dependent type algebra such as pooling(Tensor[double, (H, W)]) ->
Tensor[double, (H/2, W/2)] is outside the scope of this feature
2. First-class support for named tensors and structured/product types, e.g.
- Tensor[int, (1, 3)] * Tensor[int, (2, 3)] will trigger a pre-execution error
- but Tensor[int, (1, xyz: 3)] * Tensor[int, (proj: 2, xyz: 3)] -> Tensor[int,
(1, proj: 2)]
- loss(img: Tensor[int, (64, 64)], label: Tensor[int, ()]) can only be applied
to input of this product type; anything else will trigger a pre-execution error
3. A matrix or tensor dimension of length 1 can be implicitly squeezed, e.g.:
- Tensor[int, (1, 3)] * Tensor[int, (3, 1)] > 0 is a valid condition
- Tensor[int, (1, 3)] * Tensor[int, (3, 2)] > 0 is not and will trigger an
pre-execution error
4. Higher-order functions and blocks can be defined with more constraints, e.g.:
- the last block of regression can be defined as Tensor[float, (?, ?)] ->
Tensor[float, ()], which accepts only blocks with scalar output
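As a rough illustration of use cases 1 and 3, the sketch below encodes a
matrix shape in Scala 2.13 literal types and lets only a 1 x 1 result convert
to a scalar. The names Matrix, fill and squeeze are hypothetical and are not
part of the existing MXNet Scala API; partially specified shapes, named
dimensions and device types are not covered here.

```scala
import scala.language.implicitConversions

// Sketch only: `Matrix` is a hypothetical type, not part of the MXNet Scala API.
// Written against Scala 2.13 (literal types and ValueOf).
object ShapeSafeSketch {

  /** A dense matrix whose row count R and column count C are literal types. */
  final class Matrix[R <: Int with Singleton, C <: Int with Singleton] private (
      val data: Vector[Vector[Float]]) {

    /** Matrix product; the argument must have exactly C rows or the call does not compile. */
    def *[K <: Int with Singleton](that: Matrix[C, K]): Matrix[R, K] = {
      val columns = that.data.transpose
      new Matrix[R, K](
        data.map(row => columns.map(col => row.zip(col).map { case (x, y) => x * y }.sum)))
    }
  }

  object Matrix {
    /** Builds an R x C matrix filled with `value`; the sizes are read from the literal types. */
    def fill[R <: Int with Singleton, C <: Int with Singleton](value: Float)(
        implicit rows: ValueOf[R], cols: ValueOf[C]): Matrix[R, C] =
      new Matrix[R, C](Vector.fill(rows.value, cols.value)(value))

    /** Use case 3: only a 1 x 1 matrix may be squeezed into a scalar. */
    implicit def squeeze(m: Matrix[1, 1]): Float = m.data.head.head
  }

  val a: Matrix[2, 3] = Matrix.fill[2, 3](1.0f)
  val b: Matrix[3, 4] = Matrix.fill[3, 4](2.0f)
  val product: Matrix[2, 4] = a * b // compiles: both inner dimensions are 3

  val row: Matrix[1, 3] = Matrix.fill[1, 3](1.0f)
  val col: Matrix[3, 1] = Matrix.fill[3, 1](2.0f)
  val positive: Boolean = (row * col) > 0 // the Matrix[1, 1] result is squeezed to a Float

  // Neither line below is expected to compile if uncommented:
  // val bad1 = a * Matrix.fill[4, 5](1.0f)          // inner dimensions 3 and 4 differ
  // val bad2 = (row * Matrix.fill[3, 2](1.0f)) > 0  // Matrix[1, 2] has no scalar conversion
}
```

The mismatches are rejected by the type checker rather than at run time, which
is the "pre-execution error" meant above. The Singleton bound is one way to
keep the compiler from widening a literal type such as 3 to Int during
inference.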
These requirements are hard to implement on the existing Python-based API:
Python is not dependently typed, is not a functional language, and has no
native support for meta-programming. As a result, enforcing shape safety at
compile time is almost impossible, and run-time validation can only be done by
either compiling Python code into an intermediate representation (e.g.
TorchScript, Relay IR) using reflection, or extensively rewriting a static type
checker (e.g. mypy); neither is easy.
It has been demonstrated that some JVM languages, even with incomplete support
for dependent typing, can approximate some of the above much more easily:
- https://openreview.net/forum?id=SkluMSZ08H
- https://arxiv.org/abs/1803.10228
Specifically:
- types that depend on arity, as in 1, can be constructed by combining the
literal types of Scala 2.13+ or shapeless with recursive abstract types
- 3 can benefit from the implicit type conversions of Scala
- the class extensions of Kotlin and the implicit type conversions/magnet
pattern of Scala allow concise and readable definitions of differentiable
functions
- first-class support for dependent types will become available in Scala 3.0
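To make the point about implicit conversions and differentiable functions
concrete, here is a minimal sketch of forward-mode differentiation with dual
numbers. It is not taken from the cited papers or from MXNet; Dual, lift and
grad are hypothetical names chosen for illustration, and a real autograd API
would operate on tensors rather than on Double.

```scala
import scala.language.implicitConversions

// Sketch only: `Dual`, `lift` and `grad` are hypothetical names, not MXNet APIs.
object DiffSketch {

  /** A value together with its derivative with respect to one input variable. */
  final case class Dual(value: Double, deriv: Double) {
    def +(that: Dual): Dual = Dual(value + that.value, deriv + that.deriv)

    // Product rule: (uv)' = u'v + uv'
    def *(that: Dual): Dual =
      Dual(value * that.value, deriv * that.value + value * that.deriv)
  }

  /** Implicit conversion: a plain constant is lifted to a Dual with zero derivative. */
  implicit def lift(x: Double): Dual = Dual(x, 0.0)

  /** Evaluates the derivative of `fn` at `x` by seeding the input derivative with 1. */
  def grad(fn: Dual => Dual)(x: Double): Double = fn(Dual(x, 1.0)).deriv

  // f(x) = 3x^2 + x; the literal 3.0 is lifted by the implicit conversion.
  def f(x: Dual): Dual = x * x * 3.0 + x

  val slope: Double = grad(f)(2.0) // f'(x) = 6x + 1, evaluated at x = 2
}
```

Because of the conversion, f reads like ordinary arithmetic while still
carrying derivative information through every operation.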
None of the implementations cited above requires advanced language features
such as macros, polyglot programming or compiler hacks. Combined with the fact
that the JVM is highly reflective and has mature support for meta-programming,
this can make JVM languages a more suitable API for large-scale engineering
(and arguably also for prototyping).