electriclilies opened a new pull request #10352:
URL: https://github.com/apache/tvm/pull/10352
This PR is just one part of the design I outline below.
# Background
Currently, users can write virtual devices into the program by specifying
the `function_result_virtual_device`, and the `function_param_virtual_devices`
in the function attributes. Users can also introduce annotations using the
`on_device` op, mostly on function arguments, but sometimes on let-bound
variables.
Device planning solves the constraints created by the user-specified virtual
devices: for each sub-expression, it determines a unique and consistent virtual
device. We’ll call this the **complete representation.**
# Motivation
Currently, the function parameter virtual devices and the function result
virtual device are represented through attributes on the function. Since there
is already a text format for attributes on functions, we can represent the
function parameter virtual devices and the function result virtual device in
the text format by just putting them in the attributes, so we get the text
representation automatically.
However, once we move the function virtual devices out of the attributes, we
no longer get the text representation for free. The most immediate challenge is
that the unit tests for the device planner are written in RelayScript, so
without rethinking the text format for virtual devices, we can’t run the unit
tests.
A second motivation is that the current text representation is clunky: to
specify the virtual device of a subexpr, the user needs to wrap that
expression in an `on_device` op. In the current implementation, let-bound
variables and arguments to functions must also be wrapped in `on_device`.
While I am rethinking the text format, I’d like to rethink this as well.
# Goal
The goal for the text representation is a RelayScript program that preserves
all virtual device information from device planning in a **minimal
representation**. The minimal representation is the least amount of information
(subexprs assigned virtual devices) we need to reconstruct the device planned
program using simple lexical scoping rules.
Additionally, I want to be able to reconstruct the **complete
representation** from the **minimal representation** without using device
planning itself. This will let us express the expected result of running
PlanDevices in RelayScript without running PlanDevices itself.
# Proposed design
In general, I propose removing the `on_device` op from the RelayScript
representation, and simplifying the way function virtual devices are
represented in text. I will:
1) formalize the minimal representation
2) introduce syntax for the critical virtual devices
3) introduce a pass that uses simple, lexical scoping rules to expand the
minimal representation into the complete representation
## The minimal representation of device planning information
The current text format for representing the device planned program in
RelayScript preserves the virtual devices for
1) function parameters and the function result
2) let-bound variables
3) inputs to `device_copy`
For our minimal representation, we don’t need 3, since the virtual devices
of the inputs of `device_copy` must agree with the virtual devices specified in
the `device_copy` op itself.
In this example from the current text representation, the virtual device
specified in `%1` is the same as the source virtual device in `%2`. We can
remove the `on_device` op and not lose any information.
```python
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32],
          %z: Tensor[(5, 7), float32],
          param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][1], meta[VirtualDevice][0]],
          result_virtual_device=meta[VirtualDevice][0]) {
  %0 = add(%x, %y);
  %1 = on_device(%0, virtual_device=meta[VirtualDevice][1], constrain_result=True);
  %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]);
  subtract(%2, %z)
}
```
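To make the redundancy concrete, here is a minimal Python sketch over a toy expression encoding (nested dicts, not the real Relay IR). `strip_redundant_on_device` is a hypothetical helper, and the integer devices stand in for the `meta[VirtualDevice][i]` entries. It drops an `on_device` wrapper that sits directly under a `device_copy`, and raises if the two virtual devices disagree:

```python
def strip_redundant_on_device(expr):
    """Drop on_device wrappers that sit directly under a device_copy.

    Toy encoding (an assumption, not Relay's IR): a variable is a str; a call
    is a dict with "op", "args" and "attrs" keys.
    """
    if isinstance(expr, str):  # variable reference, nothing to strip
        return expr
    args = [strip_redundant_on_device(arg) for arg in expr["args"]]
    attrs = expr.get("attrs", {})
    if expr["op"] == "device_copy":
        inner = args[0]
        if isinstance(inner, dict) and inner["op"] == "on_device":
            # The wrapper must agree with the copy's source device.
            if inner["attrs"]["virtual_device"] != attrs["src_virtual_device"]:
                raise ValueError("on_device disagrees with device_copy source")
            args = [inner["args"][0]]  # keep only the wrapped expression
    return {"op": expr["op"], "args": args, "attrs": attrs}


# Body of @main from the example above, in the toy encoding.
body = {"op": "subtract", "attrs": {}, "args": [
    {"op": "device_copy",
     "attrs": {"src_virtual_device": 1, "dst_virtual_device": 0},
     "args": [{"op": "on_device",
               "attrs": {"virtual_device": 1, "constrain_result": True},
               "args": [{"op": "add", "attrs": {}, "args": ["%x", "%y"]}]}]},
    "%z",
]}
minimal = strip_redundant_on_device(body)
```

After the call, the `device_copy` takes the `add` directly as its input, and its `src_virtual_device` still records where that input lives.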
So, we’ll define our minimal representation to consist of:
1. The virtual devices of variable bindings (namely, function parameters and
let-bound variables)
2. The virtual device of function results
We’ll call these the **critical virtual devices.**
## Syntax for the critical virtual devices
Piggy-backing off the current text representation, we’d like to represent
the virtual devices for let-bound variables and function parameters directly
after the variable definition in a structure that looks like attributes.
Here is an example:
```python
def @main(%x: Tensor[(3, 3, 4), float32] {virtual_device=meta[VirtualDevice][0]},
virtual_device=meta[VirtualDevice][1] /* result virtual device */) {
  %0 = split(%x, indices_or_sections=3);
  let %t {virtual_device=meta[VirtualDevice][0]} = %0;
  %1 = %t.0;
  %2 = %t.1;
  %3 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
  %4 = device_copy(%2, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
  subtract(%3, %4)
}
```
The function result virtual device is represented directly in the function’s
attributes, just like in the current implementation. We are still using the
`meta` syntax to represent the value of the function’s virtual device. In the
parser, we’ll “promote” the function’s `virtual_device` attribute to
first-class by setting the `virtual_device_` field of the function. We won’t
put the virtual device in the attributes of the function.
For variable definitions, the virtual device will be represented directly
after the type annotation of the variable in text that looks like attributes.
We won’t actually add attributes to variable definitions. However, by using the
same syntax as attributes, we can reuse the utilities in the parser for parsing
attributes to parse the virtual device. (If there are fields other than
`virtual_device` in the fake attributes, the parser will fail.) An advantage to
this approach is that if we do want to add attributes to bound variables in the
future, we don’t need to change our syntax at all.
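As a rough illustration of that parser check, here is a small Python sketch. The real parser is not written like this; `extract_virtual_device` is a hypothetical name, and the attribute-like text is assumed to have already been parsed into a plain dict:

```python
def extract_virtual_device(fake_attrs):
    """Return the virtual device from the attribute-like dict after a variable.

    Any field other than `virtual_device` is rejected, since variables do not
    actually carry attributes in this design.
    """
    extra = set(fake_attrs) - {"virtual_device"}
    if extra:
        raise ValueError(f"unexpected fields on variable: {sorted(extra)}")
    # None means the variable was written without a virtual device.
    return fake_attrs.get("virtual_device")


device = extract_virtual_device({"virtual_device": "meta[VirtualDevice][0]"})
```

If bound variables gain real attributes later, only the rejection step would need to change; the surface syntax could stay the same.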
## Expansion of minimal representation / propagation of ‘critical’ virtual devices
We’ll introduce a new pass, called DPL (“device plan lite”), which
propagates the ‘critical’ virtual devices and `device_copy` virtual devices so
that every subexpr’s `virtual_device_` field is populated. This pass will
follow simple lexical propagation rules; if it finds a ‘critical’ virtual
device that is not set, it will fail.
Note that in the current implementation, the minimal representation is not
expanded into the complete representation at all; rather, during traversal in
the DeviceAwareVisitExpr, the visitor keeps track of what the current virtual
device is using simple lexical scoping rules. Every time you traverse the
program, you must recompute the virtual devices of all the subexpressions. With
DPL, we only have to do the propagation once, and we can get rid of
DeviceAwareVisitExpr completely.
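A minimal Python sketch of what the propagation in DPL could look like, on a toy IR rather than Relay's. The class names, the string devices, and the exact rules are assumptions for illustration: a call inherits the device of its enclosing scope, the input of a `device_copy` is visited with the copy's source device, and a let-bound value is visited with the device of the bound variable.

```python
class Var:
    def __init__(self, name, virtual_device=None):
        self.name = name
        self.virtual_device = virtual_device  # critical: must be supplied

class Call:
    def __init__(self, op, args):
        self.op, self.args = op, args
        self.virtual_device = None  # filled in by propagation

class DeviceCopy:
    def __init__(self, arg, src, dst):
        self.arg, self.src, self.dst = arg, src, dst
        self.virtual_device = None

class Let:
    def __init__(self, var, value, body):
        self.var, self.value, self.body = var, value, body
        self.virtual_device = None

class Function:
    def __init__(self, params, body, result_virtual_device=None):
        self.params, self.body = params, body
        self.result_virtual_device = result_virtual_device  # critical


def require(device, what):
    """Fail if a critical virtual device is missing."""
    if device is None:
        raise ValueError(f"critical virtual device not set for {what}")
    return device


def propagate(expr, current):
    """Fill in virtual_device on every subexpr, given the scope's device."""
    if isinstance(expr, Var):
        require(expr.virtual_device, f"variable {expr.name}")
    elif isinstance(expr, DeviceCopy):
        expr.virtual_device = expr.dst
        propagate(expr.arg, expr.src)  # the input lives on the source device
    elif isinstance(expr, Let):
        expr.virtual_device = current
        bound = require(expr.var.virtual_device, f"let-bound {expr.var.name}")
        propagate(expr.value, bound)
        propagate(expr.body, current)
    elif isinstance(expr, Call):
        expr.virtual_device = current
        for arg in expr.args:
            propagate(arg, current)
    else:
        raise TypeError(f"unsupported node: {type(expr).__name__}")


def device_plan_lite(func):
    for param in func.params:
        require(param.virtual_device, f"parameter {param.name}")
    propagate(func.body, require(func.result_virtual_device, "function result"))
    return func


# add(x, y) on D1, copied to D0, then subtract(..., z) on D0.
x, y, z = Var("x", "D1"), Var("y", "D1"), Var("z", "D0")
add = Call("add", [x, y])
copied = DeviceCopy(add, src="D1", dst="D0")
main = Function([x, y, z], Call("subtract", [copied, z]), "D0")
device_plan_lite(main)
print(add.virtual_device, copied.virtual_device, main.body.virtual_device)
# prints: D1 D0 D0
```

Because the devices are written onto the nodes once, later traversals can read `virtual_device` directly instead of tracking the current scope.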
## Tests in device planning
The test cases will use the new RelayScript syntax and the DPL pass to test
device planning.
Let `input` be the input program (in text format), containing `on_device`
ops and `device_copy` ops, and `expected` be our expected output program (in
text format), which has the virtual device information for every ‘critical’
virtual device.
Then, let `complete_output = DP(parse(input))`, where `DP` is device planning
(PlanDevices) and `complete_output` is a fully device planned program with all
the `virtual_device_` fields propagated (the complete representation).
Now, let `minimal_expected = parse(expected)` be the minimal representation
of the fully device planned program. `minimal_expected` doesn’t have all the
`virtual_device_` fields propagated, but it does contain enough information to
completely reconstruct the device planned program through simple lexical rules.
We then use DPL (“device plan lite”) to reconstruct the complete representation from it.
Finally, we can check that `complete_output == DPL(minimal_expected)`. It is
also true that `parse(print(complete_output)) == minimal_expected`.
Note that the DPL pass will also be useful for reconstructing any virtual
device information that is removed or not propagated correctly by some other
Relay pass. We expect that this may occasionally happen. As long as the
‘critical’ virtual devices are preserved, we can run DPL to get the complete
representation.