electriclilies opened a new pull request #10352:
URL: https://github.com/apache/tvm/pull/10352


   This PR is just one part of the design I outline below.
   
   
   # Background
   
   Currently, users can write virtual devices into the program by specifying the `function_result_virtual_device` and `function_param_virtual_devices` in the function attributes. Users can also introduce annotations using the `on_device` op, mostly on function arguments but sometimes on let-bound variables.
   
   Device planning solves the constraints created by the user-specified virtual devices: for each sub-expression, it determines a unique and consistent virtual device. We’ll call the resulting program, in which every sub-expression carries its virtual device, the **complete representation**.
   
   # Motivation
   
   Currently, the function parameter virtual devices and the function result virtual device are represented through attributes on the function. Since there is already a text format for function attributes, we can represent both in the text format just by putting them in the attributes, so we get the text representation automatically.
   
   However, once we move the function virtual devices out of the attributes, we no longer get the text representation for free. The most immediate challenge is that the unit tests for the device planner are written in RelayScript, so without rethinking the text format for virtual devices, we can’t run the unit tests.
   
   A second motivation is that the current text representation is clunky: to specify the virtual device of a subexpression, the user must wrap that expression in an `on_device` op. In the current implementation, both let-bound variables and function arguments must be wrapped in `on_device`. While I am rethinking the text format, I’d like to rethink this as well.
   
   # Goal
   
   The goal for the text representation is a RelayScript program that preserves all virtual device information from device planning in a **minimal representation**. The minimal representation is the least amount of information (that is, the smallest set of subexpressions annotated with virtual devices) needed to reconstruct the device-planned program using simple lexical scoping rules.
   
   Additionally, I want to be able to reconstruct the **complete 
representation** from the **minimal representation** without using device 
planning itself. This will let us express the expected result of running 
PlanDevices in RelayScript without running PlanDevices itself. 
   
   # Proposed design
   
   In general, I propose removing the `on_device` op from the RelayScript 
representation, and simplifying the way function virtual devices are 
represented in text. I will
   
   1) formalize the minimal representation
   
   2) introduce syntax for the critical virtual devices
   
   3) introduce a pass that uses simple lexical scoping rules to expand the minimal representation into the complete representation
   
   ## The minimal representation of device planning information
   
   The current text format for representing the device planned program in 
RelayScript preserves the virtual devices for
   
   1) function parameters and the function result
   
   2) let-bound variables
   
   3) inputs to `device_copy`
   
   For our minimal representation, we don’t need item 3, since the virtual device of each input to `device_copy` must agree with the source virtual device specified in the `device_copy` op itself.
   
   In this example from the current text representation, the virtual device specified in `%1` is the same as the source virtual device in `%2`, so we can remove the `on_device` op without losing any information.
   
   ```python
   def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32],
             param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][1], meta[VirtualDevice][0]],
             result_virtual_device=meta[VirtualDevice][0]) {
     %0 = add(%x, %y);
     %1 = on_device(%0, virtual_device=meta[VirtualDevice][1], constrain_result=True);
     %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]);
     subtract(%2, %z)
   }
   ```
   
   So, we’ll define our minimal representation to consist of:
   
   1. The virtual devices of variable bindings (namely, function parameters and 
let-bound variables)
   2. The virtual device of function results
   
   We’ll call these the **critical virtual devices**.
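   To make the definition concrete, here is a sketch over a hypothetical toy IR (these classes and `critical_virtual_devices` are illustrative, not TVM’s) in which only binders, meaning `Var`s bound as parameters or by `let`, and the function result carry a virtual device. The helper just collects those entries; nested functions are skipped to keep it short.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union


@dataclass
class Var:
    name: str
    virtual_device: Optional[int] = None  # only meaningful where the var is bound


@dataclass
class Call:
    op: str
    args: List["Expr"] = field(default_factory=list)


@dataclass
class Let:
    var: Var
    value: "Expr"
    body: "Expr"


@dataclass
class Function:
    params: List[Var]
    body: "Expr"
    result_virtual_device: Optional[int] = None


Expr = Union[Var, Call, Let, Function]


def critical_virtual_devices(fn: Function) -> Dict[str, Optional[int]]:
    """Collect the critical virtual devices: parameters, let-bound vars and the result."""
    found: Dict[str, Optional[int]] = {p.name: p.virtual_device for p in fn.params}
    found["<result>"] = fn.result_virtual_device

    def visit(e: Expr) -> None:
        if isinstance(e, Let):
            found[e.var.name] = e.var.virtual_device  # a binding site
            visit(e.value)
            visit(e.body)
        elif isinstance(e, Call):
            for arg in e.args:
                visit(arg)
        # Var uses are not binding sites; nested Functions are skipped in this sketch.

    visit(fn.body)
    return found


x = Var("%x", 0)
t = Var("%t", 0)
fn = Function([x], Let(t, Call("split", [x]), Call("concatenate", [t])), result_virtual_device=1)
print(critical_virtual_devices(fn))  # {'%x': 0, '<result>': 1, '%t': 0}
```

   No call node contributes an entry; call devices are left to be recovered from these binders by lexical rules.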
   
   ## Syntax for the critical virtual devices
   
   Piggy-backing off the current text representation, we’d like to represent the virtual devices of let-bound variables and function parameters directly after each variable definition, in a structure that looks like attributes.
   
   Here is an example:
   
   ```python
   def @main(%x: Tensor[(3, 3, 4), float32] {virtual_device=meta[VirtualDevice][0]},
             virtual_device=meta[VirtualDevice][1] /* result virtual device */) {
     %0 = split(%x, indices_or_sections=3);
     let %t {virtual_device=meta[VirtualDevice][0]} = %0;
     %1 = %t.0;
     %2 = %t.1;
     %3 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
     %4 = device_copy(%2, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
     subtract(%3, %4)
   }
   ```
   
   The function result virtual device is written among the function’s attributes, just like in the current implementation, and we still use the `meta` syntax to represent the virtual device value. However, the parser will “promote” the function’s `virtual_device` attribute to first-class status by setting the function’s `virtual_device_` field; the virtual device will not be stored in the function’s attributes.
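   A minimal sketch of the promotion step, assuming a hypothetical `ParsedFunction` node and assuming the parser has already turned the function’s attribute text into a dictionary (neither is TVM’s actual parser API). The `virtual_device` key is popped out of the attributes and stored in its own field.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class ParsedFunction:
    """Hypothetical function node with a first-class `virtual_device` field."""
    attrs: Dict[str, Any] = field(default_factory=dict)
    virtual_device: Optional[Any] = None  # result virtual device; None if unset


def build_function(parsed_attrs: Dict[str, Any]) -> ParsedFunction:
    """Promote `virtual_device` from the parsed attributes to the function's field.

    The key is removed from the attributes, so the device lives only in the field.
    """
    attrs = dict(parsed_attrs)  # copy so the caller's dict is left untouched
    virtual_device = attrs.pop("virtual_device", None)
    return ParsedFunction(attrs=attrs, virtual_device=virtual_device)


fn = build_function({"virtual_device": "meta[VirtualDevice][1]", "Primitive": 1})
print(fn.virtual_device)  # meta[VirtualDevice][1]
print(fn.attrs)  # {'Primitive': 1}
```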
   
   For variable definitions, the virtual device will be written directly after the variable’s type annotation, in text that looks like attributes. We won’t actually add attributes to variable definitions. However, by using the same syntax as attributes, we can reuse the parser’s attribute-parsing utilities to parse the virtual device. (If the fake attributes contain any field other than `virtual_device`, the parser will fail.) An advantage of this approach is that if we do want to add attributes to bound variables in the future, we won’t need to change the syntax at all.
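   The reuse of the attribute parser can be sketched as follows. `parse_attr_block` is a hypothetical, much-simplified stand-in for the parser’s attribute utilities (values are kept as raw strings and may not contain commas), and `parse_binder_virtual_device` is a hypothetical wrapper that rejects any field other than `virtual_device`.

```python
from typing import Dict, Optional

ALLOWED_BINDER_FIELDS = {"virtual_device"}


def parse_attr_block(text: str) -> Dict[str, str]:
    """Parse `{key=value, ...}` into a dict of raw strings (toy attribute parser)."""
    text = text.strip()
    if not (text.startswith("{") and text.endswith("}")):
        raise SyntaxError(f"expected '{{...}}', got {text!r}")
    attrs: Dict[str, str] = {}
    inner = text[1:-1].strip()
    if not inner:
        return attrs
    for item in inner.split(","):
        key, sep, value = item.partition("=")
        if not sep or not key.strip():
            raise SyntaxError(f"malformed attribute {item.strip()!r}")
        attrs[key.strip()] = value.strip()
    return attrs


def parse_binder_virtual_device(text: str) -> Optional[str]:
    """Parse the fake attributes after a binder; only `virtual_device` is allowed."""
    attrs = parse_attr_block(text)
    unknown = sorted(set(attrs) - ALLOWED_BINDER_FIELDS)
    if unknown:
        raise SyntaxError(f"unsupported binder fields: {unknown}")
    return attrs.get("virtual_device")


print(parse_binder_virtual_device("{virtual_device=meta[VirtualDevice][0]}"))  # meta[VirtualDevice][0]
```

   In this sketch, allowing real attributes on binders later would only mean growing `ALLOWED_BINDER_FIELDS`; the textual syntax stays the same.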
   
   ## Expansion of minimal representation / propagation of ‘critical’ virtual devices
   
   We’ll introduce a new pass, called DPL (“device plan lite”), which propagates the ‘critical’ virtual devices and the `device_copy` virtual devices so that every subexpression’s `virtual_device_` field is populated. This pass will follow simple lexical propagation rules; if it finds a ‘critical’ virtual device that is not set, it will fail.
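   The exact propagation rules are not spelled out here, so the following is only a sketch under assumed rules, over a hypothetical toy IR (devices are strings such as `"vd0"` standing for `meta[VirtualDevice][0]`; none of these classes or functions are TVM APIs). It assumes that a function body is placed on the function’s result device; a `let` value is placed on its variable’s device while the `let` body stays on the enclosing device; a `device_copy` lives on its destination device and its input on its source device; a variable use takes its binder’s device; and every other call inherits the enclosing device. A missing critical virtual device raises `DPLError`.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union


@dataclass
class Var:
    name: str
    virtual_device: Optional[str] = None


@dataclass
class Call:
    op: str
    args: List["Expr"] = field(default_factory=list)
    virtual_device: Optional[str] = None


@dataclass
class DeviceCopy:
    arg: "Expr"
    src_virtual_device: str
    dst_virtual_device: str
    virtual_device: Optional[str] = None


@dataclass
class Let:
    var: Var
    value: "Expr"
    body: "Expr"
    virtual_device: Optional[str] = None


@dataclass
class Function:
    params: List[Var]
    body: "Expr"
    virtual_device: Optional[str] = None  # the function's result virtual device


Expr = Union[Var, Call, DeviceCopy, Let, Function]


class DPLError(Exception):
    """Raised when a critical virtual device is missing."""


def device_plan_lite(fn: Function) -> Function:
    """Set `virtual_device` on every sub-expression, in place, by lexical scoping."""

    def require(vd: Optional[str], what: str) -> str:
        if vd is None:
            raise DPLError(f"critical virtual device not set: {what}")
        return vd

    def plan(e: Expr, scope: Optional[str], env: Dict[str, str]) -> None:
        if isinstance(e, Function):
            # Parameters and result are critical; the body runs on the result device.
            inner = dict(env)
            for p in e.params:
                inner[p.name] = require(p.virtual_device, f"parameter {p.name}")
            plan(e.body, require(e.virtual_device, "function result"), inner)
        elif isinstance(e, Let):
            # The bound value runs on the let variable's (critical) device;
            # the let node and its body stay on the enclosing device.
            var_vd = require(e.var.virtual_device, f"let-bound {e.var.name}")
            e.virtual_device = scope
            plan(e.value, var_vd, env)
            plan(e.body, scope, {**env, e.var.name: var_vd})
        elif isinstance(e, DeviceCopy):
            # The copy's result is on dst; its input is evaluated on src.
            e.virtual_device = e.dst_virtual_device
            plan(e.arg, e.src_virtual_device, env)
        elif isinstance(e, Var):
            # A variable use lives wherever its binder was placed.
            e.virtual_device = env[e.name]
        elif isinstance(e, Call):
            # Any other call, and its arguments, inherit the enclosing device.
            e.virtual_device = scope
            for arg in e.args:
                plan(arg, scope, env)

    plan(fn, fn.virtual_device, {})
    return fn


# The split example: %x and %t on vd0, function result on vd1.
x = Var("%x", "vd0")
t = Var("%t", "vd0")
body = Call("subtract", [
    DeviceCopy(Call("tuple_get_item_0", [t]), "vd0", "vd1"),
    DeviceCopy(Call("tuple_get_item_1", [t]), "vd0", "vd1"),
])
fn = device_plan_lite(Function([x], Let(t, Call("split", [x]), body), "vd1"))
print(fn.body.value.virtual_device)  # vd0 (split runs where %t is bound)
print(body.args[0].arg.virtual_device)  # vd0 (%t.0 runs on the copy's source device)
print(body.virtual_device)  # vd1 (subtract runs on the result device)
```

   Because the devices are written into the nodes once, a later visitor in this model can read `virtual_device` directly instead of re-deriving it from a scope stack on every traversal.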
   
   Note that in the current implementation, the minimal representation is never expanded into the complete representation. Instead, during traversal in `DeviceAwareVisitExpr`, the visitor tracks the current virtual device using simple lexical scoping rules, so every traversal of the program must recompute the virtual devices of all subexpressions. With DPL, we only have to do the propagation once, and we can get rid of `DeviceAwareVisitExpr` completely.
   
   ## Tests in device planning
   
   The test cases will use the new RelayScript syntax and the DPL pass to test device planning.
   
   Let `input` be the input program (in text format), containing `on_device` 
ops and `device_copy` ops, and `expected` be our expected output program (in 
text format), which has the virtual device information for every ‘critical’ 
virtual device.
   
   Then, let `complete_output = DP(parse(input))`, where `DP` is device planning (PlanDevices) and `complete_output` is the fully device-planned program with all the `virtual_device_` fields populated (the complete representation).
   
   Now, let `minimal_expected = parse(expected)` be the minimal representation of the fully device-planned program. `minimal_expected` doesn’t have all the `virtual_device_` fields populated, but it contains enough information to completely reconstruct the device-planned program through simple lexical rules. We then use DPL (“device plan lite”) to recreate the complete representation from it.
   
   Finally, we can check that `complete_output == DPL(minimal_expected)`, where `==` means structural equality. It should also hold that `parse(print(complete_output)) == minimal_expected`, since printing keeps only the critical virtual devices.
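   The check can be sketched as a small hypothetical helper, `check_device_planning`. Every program-level operation is passed in as a callable, because this sketch does not assume any particular Python API for DPL (which does not exist yet) or for the parser and printer; in a real test, `plan_devices` would be the PlanDevices pass and `structural_equal` could be `tvm.ir.structural_equal`.

```python
from typing import Any, Callable


def check_device_planning(
    input_text: str,
    expected_text: str,
    parse: Callable[[str], Any],
    plan_devices: Callable[[Any], Any],      # DP
    device_plan_lite: Callable[[Any], Any],  # DPL
    to_text: Callable[[Any], str],           # print
    structural_equal: Callable[[Any, Any], bool],
) -> None:
    """Hypothetical test helper for the DP / DPL round trip."""
    complete_output = plan_devices(parse(input_text))
    minimal_expected = parse(expected_text)
    # Printing keeps only the critical virtual devices, so re-parsing the printed
    # program should give back the minimal representation. Checked before running
    # DPL in case DPL updates its argument in place.
    assert structural_equal(parse(to_text(complete_output)), minimal_expected)
    # Expanding the minimal representation must reproduce full device planning.
    assert structural_equal(complete_output, device_plan_lite(minimal_expected))
```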
   
   Note that the DPL pass will also be useful for reconstructing any virtual 
device information that is removed or not propagated correctly by some other 
Relay pass. We expect that this may occasionally happen. As long as the 
‘critical’ virtual devices are preserved, we can run DPL to get the complete 
representation.
   
   
   


-- 
This is an automated message from the Apache Git Service.