yangulei opened a new pull request #10112: URL: https://github.com/apache/tvm/pull/10112
### Motivation

We are enabling [bfloat16](https://discuss.tvm.apache.org/t/rfc-add-bfloat16-data-type/6778) in [BYOC-oneDNN](https://discuss.tvm.apache.org/t/rfc-byoc-intel-r-onednn-integration/11582) following the path:

[float32 graph] --> \<[AMP](https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994)\> --> [bfloat16 graph] --> \<BYOC\> --> [TVM + oneDNN module]

However, some passes such as `FoldConstant` could not handle bfloat16 before the improvements below.

### Changes

- Add runtime datatype dispatch, and skip the uint16 asserts, for bfloat16 compatibility.
- Add bfloat16 casting for unary intrinsic operators to enable graph optimization.
- Improve the `bf16_legalize` module to enable bfloat16 lowering.

With these improvements, a float32 graph can now be converted to bfloat16 through AMP and then lowered for inference in bfloat16 mode.

### Tested Models (gluoncv)

- ResNet<18/34/50/101/152>_v1b
- VGG<11/13/16/19>
- VGG<11/13/16/19>_bn
- DenseNet121
- InceptionV3

> By "tested" I mean I confirmed that the transformation was applied to the graph, and that a forward pass could run on CPU and roughly matched the fp32 output. I have no performance metrics or results on other devices yet, as @AndrewZhaoLuo said at https://github.com/apache/tvm/pull/8069.

### Pending

The bfloat16 support in BYOC-oneDNN builds on the [multi-blocking layout transform](https://github.com/apache/tvm/pull/9996) and the [extensions to BYOC-oneDNN](https://github.com/apache/tvm/pull/9995), both of which are still pending.
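For reviewers unfamiliar with the format: bfloat16 keeps float32's 8-bit exponent but truncates the mantissa to 7 bits, so the cast is essentially a 16-bit shift plus rounding. The sketch below is not taken from this PR's TIR lowering; it is a stand-alone, TVM-independent illustration of the float32 <-> bfloat16 conversion (with round-to-nearest-even) that the legalization pass implements on the IR level:

```python
import struct

def fp32_to_bf16_bits(value: float) -> int:
    """Convert a float32 value to its 16-bit bfloat16 encoding.

    bfloat16 is the top 16 bits of float32, so we round the low
    16 bits away using round-to-nearest-even and shift them off.
    """
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # Bias of 0x7FFF, plus 1 when the lowest kept bit is set,
    # implements round-to-nearest, ties-to-even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def bf16_bits_to_fp32(bits16: int) -> float:
    """Widen a bfloat16 encoding back to float32 (exact, no rounding)."""
    return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]

# Exactly representable values survive the round trip unchanged;
# others lose only the low 16 mantissa bits.
assert bf16_bits_to_fp32(fp32_to_bf16_bits(1.0)) == 1.0
```

Because widening back to float32 is exact, compute-heavy passes can legalize bfloat16 ops as cast-to-fp32, compute, cast-back, which is the usual strategy when no native bfloat16 arithmetic is available.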
