This is an automated email from the ASF dual-hosted git repository.
ziheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 5aafff9 Bring Your Own Datatypes (#5812)
5aafff9 is described below
commit 5aafff913b963879f6ea6f24e01533793ea1a68a
Author: Gus Smith <[email protected]>
AuthorDate: Sat Sep 26 15:49:20 2020 -0700
Bring Your Own Datatypes (#5812)
* Add ChangeDatatype pass and unittest
* [WIP] Jared's work on Fri
This was work that Jared did on my computer, trying to get Inception v3
running.
* Fix simplify inference to work over different data types.
* Formatting
* Copy setup code from other test file
* Logging in Relay
* Remove duplicate TVM_DLL
* Add Sub, Mul, Div, Max to bfloat lib
* Fix previous broken rebased commit
* Remove line
* Add LowerCustomDatatypes to build passes
* Upcast ints to custom datatypes too, as well as to floats
* Add and use convert_ndarray
* Lower Call
* Relay: create constant scalars of custom dtypes
We use the same method we use in TVM: store the value in a double.
* Custom datatype formatting in Relay
* Update unittests
* Add simpler example that's not working yet
* Add Python unittests to Makefile
* Fix bug
* Fix function name in GetPackedFunc call
* convert_ndarray makes its own executor
* Add simple test case
* Move setup() calls
* Use convert_ndarray
* Change import to make it more specific
* Fix another Registry::Get call
* Allow users to register minimum functions for custom datatypes
This commit allows users to register global functions named
`tvm.datatype.min.<type name>` which take the number of bits in the custom
type and return the corresponding minimum value (as a double).
A similar commit will need to be created for max, whenever that ends up
being needed!
* Remove check for float
* Add test
* Fix inception test
* Add MobileNet
* Lower custom datatypes before intrinsics
* Add exp and sqrt bfloat functions
* [buggy commit] Lower intrinsics like sqrt, exp
This commit has bugs in it, I'm fairly certain.
* Formatting
* Fix bug
* Add lowering for new ops in test
* Add int to bfloat
* Remove print
* Add all tests
* Correct image size
* Add TODO
* Add "notbfloat" type
This type is for testing purposes. It just stores a float in a uint32. It was
used to confirm the fact that my bfloat "implementation" is very numerically
unstable and was causing issues when running the model.
* Convert arguments
Not sure how necessary this actually is.
* Rewrite custom datatype constants in Relay
* Add test_ops
* Print constants in Relay
* Use topi.testing
* Test conv2d
* Add test_model
* Comment out model tests
* Register notbfloat
This could be unregistered at some point later
* Add commented code
Remove later
* Add posit tests
* test_ops_same_function
* [temporary] move incomplete commit to macbook
* Add more to tests
* Formatting
* Uncomment add
* Remove bad tests
* Change comments
* Change function name and docstring
* Change main function
* Restructure tests
* Fix visibility of posit functions
* YAPF
* Switching keywords around to resolve build errors on some systems
* Improve test by running smaller mobilenet
* Add test_cast
* Change datatype name; add simple test
* Rename to posit32
* Merge 3 posit types into one file
* Add a nop type
* Remove bfloat
* Refactor test comments
* Refactor conv2d test
* Add optional tolerance arguments
* Add posit8 and posit16
* Add comment about posit8
* Whoops -- actually add noptype to CMakeLists
* Add rtol, atol to run_workload
* Add noptype to tests
* Run noptype over other models, too
* Pass correct arguments to calls
* Fix line length errors
* Raise tolerances (again) to avoid flaky test
* fix style
* add test for tanh, log, sigmoid
* Remove references to bfloat, notbfloat
* Change comments
* Remove old test file
* fix min func
* refactoring unit test file
* use posits es2
* cleanup
* comment
* comment if_then_else
* support different bit widths
* use random seed to create stable tests
* update documentation
* removed nop-type and code consistency
* add batchnorm test
* rebase and update
* fix tests and format
* pylint
* change order of include
* include order
* fix style
* remove posit c linkage
* update universal
* fix style
* fix test
* fix overflow error with minfunc and posits
* style
* use change_dtype to convert params
* update universal
* fix fatal error
* fix constant repr
* minor update to posites2
* update universal
* fix rst
* fix invalid import and sqrt
* update universal
* comments
* comments and expand testing
* increase atol/rtol for custom[posites2]32
* Re-add newline
* Remove comment
* Remove opt level and comment
* Change docstring
* Add TODO
* Add file header and newline
* Update docstring
* Update file docstring
* Update docstrings
* Delete todos
* create_min_lower_func
* add better debugging message
* docs
* add BYODT tutorial
* add todo
* Reformat some of tutorial to RST, plus code fixes
* tutorial notebook runs now
* fix hyperlink
* rebase
* add to tutorial
* fix mobilenet model
* add skip tag
* black lint
* add compiler flag and add dummy float
* myfloat and posites2 test
* remove universal
* lint
* lint
* add setup
* build with USE_POSIT for CI/CD
* fix posit cmake
* add cd /
* undo docker changes
* change tutorial to use myfloat
* move files
* lint
* fix
* remove filter
* fix lint
* fix suggestions
Co-authored-by: Jared Roesch <[email protected]>
Co-authored-by: Andrew Liu <[email protected]>
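For reference, the registry-key naming scheme introduced by this patch's `register_op` (python/tvm/target/datatype.py) can be sketched in standalone Python. This mirrors the string construction visible in the diff; it is illustrative only (not part of the patch), and the `llvm`/`posites2` names are just example values:

```python
# Sketch of the lowering-function naming scheme used by register_op.
# Cast ops embed both the destination and source type names; intrinsic
# Calls embed the intrinsic name; all other ops use the type name alone.

def lower_func_name(op_name, target, src_type_name,
                    dest_type_name=None, intrinsic_name=None):
    """Build the global-registry key under which a lowering function lives."""
    if op_name == "Cast":
        assert dest_type_name is not None
        return ("tvm.datatype.lower." + target + "." + op_name + "."
                + dest_type_name + "." + src_type_name)
    if op_name == "Call" and intrinsic_name is not None:
        return ("tvm.datatype.lower." + target + "." + op_name + ".intrin."
                + intrinsic_name + "." + src_type_name)
    return "tvm.datatype.lower." + target + "." + op_name + "." + src_type_name

print(lower_func_name("Cast", "llvm", "float", dest_type_name="posites2"))
# tvm.datatype.lower.llvm.Cast.posites2.float
print(lower_func_name("Add", "llvm", "posites2"))
# tvm.datatype.lower.llvm.Add.posites2
```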
---
3rdparty/bfloat16/bfloat16.cc | 84 ---
CMakeLists.txt | 5 +-
LICENSE | 1 -
cmake/config.cmake | 3 +
cmake/modules/LibInfo.cmake | 1 +
.../modules/contrib/Posit.cmake | 28 +-
licenses/LICENSE.bfloat16.txt | 9 -
python/tvm/driver/build_module.py | 2 +
python/tvm/relay/backend/_backend.py | 3 +
python/tvm/relay/frontend/__init__.py | 1 +
python/tvm/relay/frontend/change_datatype.py | 107 ++++
python/tvm/target/datatype.py | 303 +++++++++--
src/arith/rewrite_simplify.cc | 4 +-
src/driver/driver_api.cc | 2 +
src/relay/backend/utils.h | 6 +-
src/relay/transforms/pattern_util.h | 77 +--
src/support/libinfo.cc | 5 +
src/target/datatype/myfloat/myfloat.cc | 144 ++++++
src/target/datatype/posit/posit-wrapper.cc | 238 +++++++++
src/target/datatype/registry.cc | 19 +
src/target/datatype/registry.h | 21 +-
src/tir/op/op.cc | 17 +-
src/tir/transforms/lower_custom_datatypes.cc | 18 +-
tests/python/unittest/test_custom_datatypes.py | 562 +++++++++++++++++++++
.../unittest/test_target_custom_datatypes.py | 154 ------
tutorials/dev/bring_your_own_datatypes.py | 408 +++++++++++++++
26 files changed, 1872 insertions(+), 350 deletions(-)
diff --git a/3rdparty/bfloat16/bfloat16.cc b/3rdparty/bfloat16/bfloat16.cc
deleted file mode 100644
index 674feb4..0000000
--- a/3rdparty/bfloat16/bfloat16.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- Copyright (c) 2019 by Contributors
- \file tvm/src/codegen/custom_datatypes/mybfloat16.cc
- \brief Small bfloat16 library for use in unittests
-
- Code originally from TensorFlow; taken and simplified. Original license:
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
-==============================================================================*/
-
-#include <tvm/runtime/c_runtime_api.h>
-
-#include <cstddef>
-#include <cstdint>
-
-void FloatToBFloat16(const float* src, uint16_t* dst, size_t size) {
- const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
- uint16_t* q = reinterpret_cast<uint16_t*>(dst);
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- for (; size != 0; p += 2, q++, size--) {
- *q = p[0];
- }
-#else
- for (; size != 0; p += 2, q++, size--) {
- *q = p[1];
- }
-#endif
-}
-
-void BFloat16ToFloat(const uint16_t* src, float* dst, size_t size) {
- const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
- uint16_t* q = reinterpret_cast<uint16_t*>(dst);
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- for (; size != 0; p++, q += 2, size--) {
- q[0] = *p;
- q[1] = 0;
- }
-#else
- for (; size != 0; p++, q += 2, size--) {
- q[0] = 0;
- q[1] = *p;
- }
-#endif
-}
-
-void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst, size_t size) {
- float a_f, b_f;
- BFloat16ToFloat(a, &a_f, 1);
- BFloat16ToFloat(b, &b_f, 1);
- float out_f = a_f + b_f;
- FloatToBFloat16(&out_f, dst, 1);
-}
-
-extern "C" {
-TVM_DLL uint16_t FloatToBFloat16_wrapper(float in);
-TVM_DLL float BFloat16ToFloat_wrapper(uint16_t in);
-TVM_DLL uint16_t BFloat16Add_wrapper(uint16_t a, uint16_t b);
-
-uint16_t FloatToBFloat16_wrapper(float in) {
- uint16_t out;
- FloatToBFloat16(&in, &out, 1);
- return out;
-}
-
-float BFloat16ToFloat_wrapper(uint16_t in) {
- float out;
- BFloat16ToFloat(&in, &out, 1);
- return out;
-}
-
-uint16_t BFloat16Add_wrapper(uint16_t a, uint16_t b) {
- uint16_t out;
- BFloat16Add(&a, &b, &out, 1);
- return out;
-}
-}
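The deleted helpers above implement bfloat16 as the high 16 bits of an IEEE-754 float32 (truncation, no rounding), with arithmetic done by round-tripping through float. A hypothetical Python equivalent, for illustration only:

```python
import struct

def float_to_bfloat16(x):
    # Reinterpret the float32 bit pattern and keep only the top 16 bits,
    # as FloatToBFloat16 does (truncation, no rounding).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return bits >> 16

def bfloat16_to_float(b):
    # Place the 16 stored bits back in the high half of a float32;
    # the low mantissa bits become zero (BFloat16ToFloat).
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

def bfloat16_add(a, b):
    # Mirrors BFloat16Add: widen to float, add, narrow back.
    return float_to_bfloat16(bfloat16_to_float(a) + bfloat16_to_float(b))

print(hex(float_to_bfloat16(1.0)))        # 0x3f80
print(bfloat16_to_float(0x3F80))          # 1.0
```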
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a29364b..f3182a4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -56,6 +56,7 @@ tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt")
tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")
# Contrib library options
+tvm_option(USE_BYOC_POSIT "Build with BYOC software emulated posit custom datatype" OFF)
tvm_option(USE_BLAS "The blas library to be linked" none)
tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
@@ -257,6 +258,7 @@ endif(USE_VM_PROFILER)
file(GLOB DATATYPE_SRCS src/target/datatype/*.cc)
list(APPEND COMPILER_SRCS ${DATATYPE_SRCS})
+list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")
file(GLOB RUNTIME_SRCS
src/runtime/*.cc
@@ -279,8 +281,6 @@ if (INDEX_DEFAULT_I64)
add_definitions(-DTVM_INDEX_DEFAULT_I64=1)
endif()
-list(APPEND RUNTIME_SRCS 3rdparty/bfloat16/bfloat16.cc)
-
if(USE_RPC)
message(STATUS "Build with RPC support...")
file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc)
@@ -334,6 +334,7 @@ include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/CODEGENC.cmake)
include(cmake/modules/contrib/DNNL.cmake)
include(cmake/modules/contrib/Random.cmake)
+include(cmake/modules/contrib/Posit.cmake)
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/NNPack.cmake)
diff --git a/LICENSE b/LICENSE
index 1314c16..52b2219 100644
--- a/LICENSE
+++ b/LICENSE
@@ -209,7 +209,6 @@ for text of these licenses.
Apache Software Foundation License 2.0
--------------------------------------
-3rdparty/bfloat16/bfloat16.cc
3rdparty/dlpack
3rdparty/dmlc-core
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 6ed660c..7e8df55 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -120,6 +120,9 @@ set(USE_LLVM OFF)
#---------------------------------------------
# Contrib libraries
#---------------------------------------------
+# Whether to build with BYOC software emulated posit custom datatype
+set(USE_BYOC_POSIT OFF)
+
# Whether use BLAS, choices: openblas, atlas, apple
set(USE_BLAS none)
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index e7685f3..53d88ff 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -53,6 +53,7 @@ function(add_lib_info src_file)
TVM_INFO_HIDE_PRIVATE_SYMBOLS="${HIDE_PRIVATE_SYMBOLS}"
TVM_INFO_USE_TF_TVMDSOOP="${USE_TF_TVMDSOOP}"
TVM_INFO_USE_FALLBACK_STL_MAP="${USE_FALLBACK_STL_MAP}"
+ TVM_INFO_USE_BYOC_POSIT="${USE_BYOC_POSIT}"
TVM_INFO_USE_BLAS="${USE_BLAS}"
TVM_INFO_USE_MKL="${USE_MKL}"
TVM_INFO_USE_MKLDNN="${USE_MKLDNN}"
diff --git a/python/tvm/relay/frontend/__init__.py b/cmake/modules/contrib/Posit.cmake
similarity index 59%
copy from python/tvm/relay/frontend/__init__.py
copy to cmake/modules/contrib/Posit.cmake
index 7154f5a..cd2f2f6 100644
--- a/python/tvm/relay/frontend/__init__.py
+++ b/cmake/modules/contrib/Posit.cmake
@@ -14,23 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-Frontends for constructing Relay programs.
-Contains the model importers currently defined
-for Relay.
-"""
-
-from __future__ import absolute_import
-
-from .mxnet import from_mxnet
-from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var
-from .keras import from_keras
-from .onnx import from_onnx
-from .tflite import from_tflite
-from .coreml import from_coreml
-from .caffe2 import from_caffe2
-from .tensorflow import from_tensorflow
-from .darknet import from_darknet
-from .pytorch import from_pytorch
-from .caffe import from_caffe
+if(USE_BYOC_POSIT)
+ message(STATUS "Build with contrib.posit")
+ if (NOT UNIVERSAL_PATH)
+ message(FATAL_ERROR "Fail to get Universal path")
+ endif(NOT UNIVERSAL_PATH)
+
+ include_directories(${UNIVERSAL_PATH}/include)
+ list(APPEND COMPILER_SRCS "src/target/datatype/posit/posit-wrapper.cc")
+endif(USE_BYOC_POSIT)
diff --git a/licenses/LICENSE.bfloat16.txt b/licenses/LICENSE.bfloat16.txt
deleted file mode 100644
index ce537b4..0000000
--- a/licenses/LICENSE.bfloat16.txt
+++ /dev/null
@@ -1,9 +0,0 @@
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
\ No newline at end of file
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index 1c11a6c..2b8346d 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -264,6 +264,7 @@ def _build_for_device(input_mod, target, target_host):
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.Simplify(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
+ tvm.tir.transform.LowerCustomDatatypes(),
tvm.tir.transform.LowerIntrin(),
]
)
@@ -279,6 +280,7 @@ def _build_for_device(input_mod, target, target_host):
tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
+ tvm.tir.transform.LowerCustomDatatypes(),
tvm.tir.transform.LowerIntrin(),
tvm.tir.transform.CombineContextCall(),
]
diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py
index 641ff04..65b0c0b 100644
--- a/python/tvm/relay/backend/_backend.py
+++ b/python/tvm/relay/backend/_backend.py
@@ -90,6 +90,9 @@ def _tensor_value_repr(tvalue):
@tvm._ffi.register_func("relay._constant_repr")
def _tensor_constant_repr(tvalue):
+ dtype = tvm.runtime.DataType(tvalue.data.dtype)
+ if tvm.target.datatype.get_type_registered(dtype.type_code):
+ return "custom tensor of type " + str(dtype.type_code)
return str(tvalue.data.asnumpy())
diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py
index 7154f5a..7e16499 100644
--- a/python/tvm/relay/frontend/__init__.py
+++ b/python/tvm/relay/frontend/__init__.py
@@ -34,3 +34,4 @@ from .tensorflow import from_tensorflow
from .darknet import from_darknet
from .pytorch import from_pytorch
from .caffe import from_caffe
+from .change_datatype import ChangeDatatype
diff --git a/python/tvm/relay/frontend/change_datatype.py b/python/tvm/relay/frontend/change_datatype.py
new file mode 100644
index 0000000..dc80b3e
--- /dev/null
+++ b/python/tvm/relay/frontend/change_datatype.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""Change Datatype Pass"""
+from ..function import Function
+from ..expr_functor import ExprMutator
+from ..transform.transform import function_pass
+from ..expr import var, bind
+
+
+@function_pass(opt_level=0)
+class ChangeDatatype(ExprMutator):
+ """Mutator for changing the datatype of Relay programs.
+
+ This pass should be useful for users of the Bring Your Own Datatypes
+ framework.
+ TODO(@gussmith23 @hypercubestart) Add link to documentation when it exists
+
+ Example:
+
+ .. code-block:: python
+
+ from tvm.relay.testing.inception_v3 import get_workload
+ mod, params = get_workload()
+
+ def change_dtype(mod, params, src, dst):
+ mod = ChangeDatatype(src, dst)(mod)
+ params = dict((p, tvm.nd.array(params[p].asnumpy().astype(dst))) for p in params)
+ return mod, params
+
+ mod, params = change_dtype(mod, params, "float32", "custom[posites2]32")
+
+ Parameters
+ ----------
+ src : String
+ The source datatype name, e.g. "float" or "posites2" (but not "float32"
+ or "custom[posites2]32").
+ dst : String
+ The destination datatype name, in the same format.
+
+ Returns
+ -------
+ mod : tvm.IRModule
+ Module where all nodes of dtype `src` have been changed to have dtype
+ `dst`.
+ """
+
+ def __init__(self, src, dst):
+ self.src = src
+ self.dst = dst
+ super().__init__()
+
+ def transform_function(self, func, mod, ctx):
+ return self.visit(func)
+
+ def visit_constant(self, const):
+ if const.data.dtype == self.src:
+ return const.astype(self.dst)
+ return const
+
+ def visit_function(self, fn):
+ new_params = []
+ binds = {}
+
+ for param in fn.params:
+ # Get the parameter's type annotation.
+ var_type = param.type_annotation
+
+ # See if we want to replace dtype.
+ if var_type.dtype == self.src:
+ dtype = self.dst
+ else:
+ dtype = var_type.dtype
+
+ # Generate new variable.
+ new_param = var(param.name_hint, shape=var_type.shape, dtype=dtype)
+
+ new_params.append(new_param)
+ binds[param] = new_param
+
+ new_body = self.visit(fn.body)
+ # Rewrite the body to use new parameters.
+ new_body = bind(new_body, binds)
+
+ # Construct the updated function and return.
+ return Function(
+ new_params,
+ new_body,
+ # You could change the return type, if you use None it will re-infer.
+ None,
+ type_params=fn.type_params,
+ attrs=fn.attrs,
+ )
diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py
index 5d3ca5f..cdd8149 100644
--- a/python/tvm/target/datatype.py
+++ b/python/tvm/target/datatype.py
@@ -14,73 +14,154 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Custom datatype functionality"""
-import tvm._ffi
+"""Bring Your Own Datatypes custom datatype framework
-import tvm.runtime._ffi_api
-from tvm.runtime import DataType
-import tvm.tir
-from tvm.tir.expr import Cast as _Cast, FloatImm as _FloatImm
+TODO(@gussmith23 @hypercubestart) link to BYODT docs when they exist"""
+import tvm
+from tvm.runtime import convert, DataType
+from tvm.tir.expr import (
+ Call as _Call,
+ Cast as _Cast,
+ FloatImm as _FloatImm,
+ BinaryOpExpr as _BinaryOpExpr,
+)
+from tvm.tir.op import call_pure_extern
+from tvm._ffi import register_func as _register_func
+from tvm.tir import call_intrin
def register(type_name, type_code):
"""Register a custom datatype with the given type name and type code
- Currently, the type code is manually allocated by the user, and the
- user must ensure that no two custom types share the same code.
- Generally, this should be straightforward, as the user will be
- manually registering all of their custom types.
+
+ Currently, the type code is manually allocated by the user, and the user
+ must ensure that no two custom types share the same code. Generally, this
+ should be straightforward, as the user will be manually registering all of
+ their custom types.
+
+ Example:
+
+ .. code-block:: python
+
+ # Register a dtype named 'posites2' under type code 130.
+ tvm.datatype.register('posites2', 130)
+
Parameters
----------
type_name : str
- The name of the custom datatype
+ The name of the custom datatype.
type_code : int
- The type's code, which should be >= kCustomBegin
+ The type's code, which should be >= kCustomBegin. See
+ include/tvm/runtime/data_type.h.
"""
tvm.runtime._ffi_api._datatype_register(type_name, type_code)
def get_type_name(type_code):
- """Get the type name from the type code
+ """Get the type name of a custom datatype from the type code.
+
+ Note that this only works for custom datatypes registered with
+ tvm.datatype.register(). It does not work for TVM-native types.
+
+ Example:
+
+ .. code-block:: python
+
+ tvm.datatype.register('posites2', 130)
+ assert tvm.datatype.get_type_name(130) == 'posites2'
Parameters
----------
type_code : int
- The type code
+ The type code of the custom datatype.
+
+ Returns
+ -------
+ type_name : String
+ The name of the custom datatype.
+
"""
return tvm.runtime._ffi_api._datatype_get_type_name(type_code)
def get_type_code(type_name):
- """Get the type code from the type name
+ """Get the type code of a custom datatype from its type name
+
+ Note that this only works for custom datatypes registered with
+ tvm.datatype.register(). It does not work for TVM-native types.
+
+ Example:
+
+ .. code-block:: python
+
+ tvm.datatype.register('posites2', 130)
+ assert tvm.datatype.get_type_code('posites2') == 130
Parameters
----------
type_name : str
The type name
+
+ Returns
+ -------
+ type_code : int
+ The type code of the custom datatype.
"""
return tvm.runtime._ffi_api._datatype_get_type_code(type_name)
def get_type_registered(type_code):
- """Get a boolean representing whether the type is registered
+ """Returns true if a custom datatype is registered under the given type code
+
+ Example:
+
+ .. code-block:: python
+
+ tvm.datatype.register('posites2', 130)
+ assert tvm.datatype.get_type_registered(130)
Parameters
----------
type_code: int
The type code
+
+ Returns
+ -------
+ type_registered : bool
+ True if a custom datatype is registered under this type code, and false
+ otherwise.
"""
return tvm.runtime._ffi_api._datatype_get_type_registered(type_code)
-def register_op(lower_func, op_name, target, type_name, src_type_name=None):
- """Register an external function which computes the given op.
+def register_op(
+ lower_func, op_name, target, src_type_name, dest_type_name=None, intrinsic_name=None
+):
+ """Register a lowering function for a specific operator of a custom datatype
+
+ At build time, Relay must lower operators over custom datatypes into
+ operators it understands how to compile. For each custom datatype operator
+ which Relay finds while lowering custom datatypes, Relay expects to find a
+ user-defined lowering function. Users register their user-defined lowering
+ functions using this function.
- Currently, this will only work with Casts and binary expressions
- whose arguments are named `a` and `b`.
- TODO(gus) figure out what other special cases must be handled by
- looking through expr.py.
+ Users should use create_lower_func to create their lowering function. It
+ should serve most use-cases.
+
+ Currently, this will work with Casts, intrinsics (e.g. sqrt, sigmoid), and
+ binary expressions (e.g. Add, Sub, Mul, Div).
+
+ See the LowerCustomDatatypes pass to see how registered functions are used.
+
+ Lowering Functions
+ ------------------
+ TODO(@gussmith23) Get the terminology right here.
+ Lowering functions take in a Relay node, and should return a semantically
+ equivalent Relay node which Relay can build. This means that the returned
+ node should not contain any custom datatypes. Users should likely not need
+ to define lowering functions by hand -- see the helper function
+ create_lower_func.
Parameters
----------
@@ -89,43 +89,123 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None):
op_name : str
The name of the operation which the function computes, given by its
- class name (e.g. Add, LE, Cast).
+ class name (e.g. Add, LE, Cast, Call).
target : str
The name of codegen target.
- type_name : str
- The name of the custom datatype, e.g. posit (but not custom[posit]8).
-
src_type_name : str
- If op_name is "Cast", then this should be set to the source datatype of
+ The name of the custom datatype, e.g. posites2 (but not custom[posites2]32).
+ If op_name is not "Cast", then target type is guaranteed to be the same as src_type_name.
+
+ dest_type_name : str
+ If op_name is "Cast", then this is required and should be set to the dest datatype of
the argument to the Cast. If op_name is not "Cast", this is unused.
+
+ intrinsic_name : str
+ If op_name is "Call" and intrinsic_name is not None, then we assume the
+ op is a Call to an Intrinsic, and intrinsic_name is the intrinsic's
+ name.
"""
if op_name == "Cast":
- assert src_type_name is not None
+ assert dest_type_name is not None
+ lower_func_name = (
+ "tvm.datatype.lower."
+ + target
+ + "."
+ + op_name
+ + "."
+ + dest_type_name
+ + "."
+ + src_type_name
+ )
+ elif op_name == "Call" and intrinsic_name is not None:
lower_func_name = (
- "tvm.datatype.lower." + target + "." + op_name + "." + type_name + "." + src_type_name
+ "tvm.datatype.lower."
+ + target
+ + "."
+ + op_name
+ + ".intrin."
+ + intrinsic_name
+ + "."
+ + src_type_name
)
else:
- lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." + type_name
+ lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." + src_type_name
tvm._ffi.register_func(lower_func_name, lower_func)
-def create_lower_func(extern_func_name):
+def register_min_func(func, type_name):
+ """Register the function that returns the minimum representable value of type_name.
+
+ Operators such as max pooling and argmax require the minimum
+ finite value representable by the datatype the op is operating on.
+ Users can use this function to register a function that returns a TIR expression node
+ outputting the minimum representable value of their custom data type.
+
+ Users should use create_min_lower_func to create their lowering function. It
+ should serve most use-cases.
+
+ Note: for special cases when it is known that the custom datatype is representable
+ by a float, the user can create their own lowering func that returns a FloatImm.
+ This allows optimizations such as rewrites to work as expected on custom datatypes.
+
+ Parameters
+ ----------
+ func : function
+ Input is an integer num_bits, should return a TIR expression node that
+ represents a scalar tensor of type custom[type_name]num_bits with the minimum
+ representable value.
+
+ type_name : str
+ The name of the custom datatype, e.g. posites2 (but not custom[posites2]32).
+ """
+ _register_func("tvm.datatype.min." + type_name, func)
+
+
+def create_min_lower_func(extern_func_map, type_name):
+ """Returns a lowering function for getting the minimum value of a custom datatype.
+
+ Parameters
+ ----------
+ extern_func_map : map
+ A map from bit lengths to the name of the extern "C" function to lower to.
+
+ type_name : string
+ The name of the custom datatype, e.g. posites2 (but not custom[posites2]32).
+ """
+
+ def lower(num_bits):
+ dtype = f"custom[{type_name}]{num_bits}"
+
+ if num_bits not in extern_func_map:
+ raise RuntimeError(f"missing minimum function for {dtype}")
+
+ return call_pure_extern(dtype, extern_func_map[num_bits])
+
+ return lower
+
+
+def create_lower_func(extern_func_map):
"""Returns a function which lowers an operation to a function call.
Parameters
----------
- extern_func_name : str
- The name of the extern "C" function to lower to
+ extern_func_map : map
+ If lowering a Cast, extern_func_map should be a map from tuples of
+ (src_bit_length, dest_bit_length) to the name of the extern "C" function to lower to.
+
+ Otherwise, for unary and binary ops, it should simply be a map
+ from bit_length to the name of the extern "C" function to lower to.
"""
def lower(op):
"""
- Takes an op---either a Cast or a binary op (e.g. an Add) and returns a
+ Takes an op---either a Cast, Call, or a binary op (e.g. an Add) and returns a
call to the specified external function, passing the op's argument
- (Cast) or arguments (a binary op). The return type of the call depends
+ or arguments. The return type of the call depends
on the type of the op: if it is a custom type, then a uint of the same
width as the custom type is returned. Otherwise, the type is
unchanged."""
@@ -135,8 +296,74 @@ def create_lower_func(extern_func_name):
dtype = "uint" + str(t.bits)
if t.lanes > 1:
dtype += "x" + str(t.lanes)
- if isinstance(op, (_Cast, _FloatImm)):
- return tvm.tir.call_pure_extern(dtype, extern_func_name, op.value)
- return tvm.tir.call_pure_extern(dtype, extern_func_name, op.a, op.b)
+
+ key = t.bits
+ if isinstance(op, _Cast):
+ src_bits = DataType(op.value.dtype).bits
+ key = (src_bits, t.bits)
+
+ if key not in extern_func_map:
+ raise RuntimeError(f"missing key {key} in extern_func_map for {op.astext()}")
+
+ if isinstance(op, _Cast):
+ return call_pure_extern(dtype, extern_func_map[key], op.value)
+ if isinstance(op, _FloatImm):
+ return call_pure_extern(dtype, extern_func_map[key], op.value)
+ if isinstance(op, _Call):
+ return call_pure_extern(dtype, extern_func_map[key], *op.args)
+ if isinstance(op, _BinaryOpExpr):
+ return call_pure_extern(dtype, extern_func_map[key], op.a, op.b)
+
+ raise RuntimeError(f"lowering unsupported op: {op.astext()}")
return lower
+
+
+def lower_ite(ite_op):
+ """Lowered if then else function that calls intrinsic if_then_else.
+ Unlike a function lowered by create_lower_func, this function
+ calls the tvm intrinsic if_then_else.
+
+ Parameters
+ ----------
+ ite_op : Op
+ Takes an if then else op and returns a
+ call to tir.if_then_else function, passing the op's
+ arguments. The return type of the call is a uint of the same
+ width as the custom type.
+ """
+ dtype = ite_op.dtype
+ t = tvm.DataType(dtype)
+ assert get_type_registered(t.type_code)
+ dtype = "uint" + str(t.bits)
+ if t.lanes > 1:
+ dtype += "x" + str(t.lanes)
+ return call_intrin(
+ dtype,
+ "tir.if_then_else",
+ convert(ite_op.args[0]),
+ convert(ite_op.args[1]),
+ convert(ite_op.args[2]),
+ )
+
+
+def lower_call_pure_extern(op):
+ """Lowered call pure extern function that calls intrinsic call_pure_extern.
+ Unlike a function lowered by create_lower_func, this function
+ calls the tvm intrinsic call_pure_extern.
+
+ Parameters
+ ----------
+ op : Op
+ Takes a call_pure_extern op and returns a
+ call to tir.call_pure_extern function, passing the op's
+ arguments. The return type of the call is a uint of the same
+ width as the custom type.
+ """
+ dtype = op.dtype
+ t = tvm.DataType(dtype)
+ assert get_type_registered(t.type_code)
+ dtype = "uint" + str(t.bits)
+ if t.lanes > 1:
+ dtype += "x" + str(t.lanes)
+ return call_intrin(dtype, "tir.call_pure_extern", *op.args)
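The reworked `create_lower_func` above dispatches on a key derived from the op: a Cast is keyed by `(src_bit_length, dest_bit_length)`, every other op by its bit length alone. A small standalone sketch of that convention (the extern function names here are hypothetical placeholders, not names from the patch):

```python
# Sketch of the dispatch-key rule used by the new create_lower_func.

def extern_func_key(op_name, bits, src_bits=None):
    # Casts need both bit lengths; unary/binary ops need only their own.
    if op_name == "Cast":
        assert src_bits is not None, "Casts need the source bit length"
        return (src_bits, bits)
    return bits

extern_func_map = {
    (32, 32): "FloatToCustom32",  # lower Cast(float32 -> custom 32-bit)
    32: "Custom32Add",            # lower a 32-bit binary op such as Add
}

print(extern_func_map[extern_func_key("Cast", 32, src_bits=32)])
print(extern_func_map[extern_func_key("Add", 32)])
```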
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index a7aded6..c237edc 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -30,6 +30,7 @@
#include <algorithm>
+#include "../target/datatype/registry.h"
#include "const_fold.h"
#include "pattern_match.h"
@@ -460,7 +461,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
// x / 2.0 = x * 0.5
if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
- CHECK(op->dtype.is_float());
+ CHECK(op->dtype.is_float() ||
+ datatype::Registry::Global()->GetTypeRegistered(op->dtype.code()));
return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
}
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index e4851b5..d05d846 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -209,6 +209,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
}),
BindTarget(target_host),
tir::transform::LowerTVMBuiltin(),
+ tir::transform::LowerCustomDatatypes(),
tir::transform::LowerIntrin(),
tir::transform::LowerDeviceStorageAccessInfo(),
tir::transform::CombineContextCall(),
@@ -225,6 +226,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
BindTarget(target),
tir::transform::LowerWarpMemory(),
tir::transform::Simplify(),
+ tir::transform::LowerCustomDatatypes(),
tir::transform::LowerIntrin(),
tir::transform::LowerDeviceStorageAccessInfo(),
};
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index d6edd10..07f4226 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -149,8 +149,12 @@ inline std::string DType2String(const tvm::DataType dtype) {
os << "int";
} else if (dtype.is_uint()) {
os << "uint";
+ } else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) {
+ os << "custom["
+ << (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string()
+ << "]";
} else {
- LOG(FATAL) << "Unknown type";
+ LOG(FATAL) << "Unknown type with code " <<
static_cast<unsigned>(dtype.code());
}
os << dtype.bits();
return os.str();
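The `DType2String` change above makes registered custom types print as `custom[<name>]<bits>` instead of aborting. A minimal Python sketch of that naming scheme (a hypothetical helper, for illustration only):

```python
def dtype_to_string(kind, bits, type_name=None):
    # Simplified mirror of DType2String: custom types render as
    # custom[<name>]<bits>, builtin types as e.g. "int32" or "float32".
    if kind == "custom":
        return f"custom[{type_name}]{bits}"
    return f"{kind}{bits}"
```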
diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h
index 39fbec5..17e7304 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -35,6 +35,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/registry.h>
#include <tvm/tir/data_layout.h>
#include <limits>
@@ -51,42 +52,46 @@ namespace relay {
* \brief Dispatch DataType to the C++ data type
* during runtime.
*/
-#define TVM_DTYPE_DISPATCH(type, DType, ...) \
- if (type == DataType::Float(64)) { \
- typedef double DType; \
- { __VA_ARGS__ } \
- } else if (type == DataType::Float(32)) { \
- typedef float DType; \
- { __VA_ARGS__ } \
- } else if (type == DataType::Float(16)) { \
- typedef uint16_t DType; \
- { __VA_ARGS__ } \
- } else if (type == DataType::Int(64)) { \
- typedef int64_t DType; \
- { __VA_ARGS__ } \
- } else if (type == DataType::Int(32)) { \
- typedef int32_t DType; \
- { __VA_ARGS__ } \
- } else if (type == DataType::Int(16)) { \
- typedef int16_t DType; \
- { __VA_ARGS__ } \
- } else if (type == DataType::Int(8)) { \
- typedef int8_t DType; \
- { __VA_ARGS__ } \
- } else if (type == DataType::UInt(64)) { \
- typedef uint64_t DType; \
- { __VA_ARGS__ } \
- } else if (type == DataType::UInt(32)) { \
- typedef uint32_t DType; \
- { __VA_ARGS__ } \
- } else if (type == DataType::UInt(16)) { \
- typedef uint16_t DType; \
- { __VA_ARGS__ } \
- } else if (type == DataType::UInt(8)) { \
- typedef uint8_t DType; \
- { __VA_ARGS__ } \
- } else { \
- LOG(FATAL) << "unknown data type " << type; \
+#define TVM_DTYPE_DISPATCH(type, DType, ...) \
+ if (type == DataType::Float(64)) { \
+ typedef double DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Float(32)) { \
+ typedef float DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Float(16)) { \
+ typedef uint16_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Int(64)) { \
+ typedef int64_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Int(32)) { \
+ typedef int32_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Int(16)) { \
+ typedef int16_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::Int(8)) { \
+ typedef int8_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::UInt(64)) { \
+ typedef uint64_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::UInt(32)) { \
+ typedef uint32_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::UInt(16)) { \
+ typedef uint16_t DType; \
+ { __VA_ARGS__ } \
+ } else if (type == DataType::UInt(8)) { \
+ typedef uint8_t DType; \
+ { __VA_ARGS__ } \
+ } else if ((*tvm::runtime::Registry::Get("runtime._datatype_get_type_registered"))( \
+ static_cast<uint8_t>(type.code()))) { \
+ typedef double DType; \
+ { __VA_ARGS__ } \
+ } else { \
+ LOG(FATAL) << "unknown data type " << type; \
}
/*!
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index 0bdae82..3daf4e4 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -120,6 +120,10 @@
#define TVM_INFO_USE_FALLBACK_STL_MAP "NOT-FOUND"
#endif
+#ifndef TVM_INFO_USE_BYOC_POSIT
+#define TVM_INFO_USE_BYOC_POSIT "NOT-FOUND"
+#endif
+
#ifndef TVM_INFO_USE_BLAS
#define TVM_INFO_USE_BLAS "NOT-FOUND"
#endif
@@ -237,6 +241,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
{"HIDE_PRIVATE_SYMBOLS", TVM_INFO_HIDE_PRIVATE_SYMBOLS},
{"USE_TF_TVMDSOOP", TVM_INFO_USE_TF_TVMDSOOP},
{"USE_FALLBACK_STL_MAP", TVM_INFO_USE_FALLBACK_STL_MAP},
+ {"USE_BYOC_POSIT", TVM_INFO_USE_BYOC_POSIT},
{"USE_BLAS", TVM_INFO_USE_BLAS},
{"USE_MKL", TVM_INFO_USE_MKL},
{"USE_MKLDNN", TVM_INFO_USE_MKLDNN},
diff --git a/src/target/datatype/myfloat/myfloat.cc b/src/target/datatype/myfloat/myfloat.cc
new file mode 100644
index 0000000..c0c2fff
--- /dev/null
+++ b/src/target/datatype/myfloat/myfloat.cc
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file 3rdparty/byodt/my-custom-datatype.cc
+ * \brief Example Custom Datatype with the Bring Your Own Datatypes (BYODT) framework.
+ * This is a toy example that under the hood simulates floats.
+ *
+ * Users interested in using the BYODT framework can use this file as a template.
+ *
+ * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist?
+ */
+#include <tvm/runtime/c_runtime_api.h>
+
+#include <cmath>
+#include <cstdint>
+#include <limits>
+
+// Custom datatypes are stored as bits in a uint of the appropriate bit length.
+// Thus, when TVM calls these C functions,
+// the arguments are uints that need to be reinterpreted as your custom datatype.
+//
+// When returning, your custom datatype needs to be re-wrapped into a uint,
+// which can be thought of as just a wrapper for the raw bits that represent your custom datatype.
+template <class T>
+TVM_DLL T Uint32ToCustom32(uint32_t in) {
+ // This is a helper function to interpret the uint as your custom datatype.
+ // The following line should be replaced with the appropriate function
+ // that interprets the bits in `in` and returns your custom datatype
+ T* custom = reinterpret_cast<T*>(&in);
+ return *custom;
+}
+
+template <class T>
+TVM_DLL uint32_t Custom32ToUint32(T in) {
+ // This is a helper function to wrap your custom datatype in a uint.
+ // The following line should be replaced with the appropriate function
+ // that converts your custom datatype into a uint
+ uint32_t* bits = reinterpret_cast<uint32_t*>(&in);
+ return *bits;
+}
+
+extern "C" {
+TVM_DLL uint32_t MinCustom32() {
+ // return minimum representable value
+ float min = std::numeric_limits<float>::lowest();
+ return Custom32ToUint32<float>(min);
+}
+
+TVM_DLL float Custom32ToFloat(uint32_t in) {
+ // cast from custom datatype to float
+ float custom_datatype = Uint32ToCustom32<float>(in);
+ // our custom datatype is float, so the following redundant cast to float
+ // is to remind users to cast their own custom datatype to float
+ return static_cast<float>(custom_datatype);
+}
+
+TVM_DLL uint32_t FloatToCustom32(float in) {
+ // cast from float to custom datatype
+ return Custom32ToUint32<float>(in);
+}
+
+TVM_DLL uint32_t Custom32Add(uint32_t a, uint32_t b) {
+ // add operation
+ float acustom = Uint32ToCustom32<float>(a);
+ float bcustom = Uint32ToCustom32<float>(b);
+ return Custom32ToUint32<float>(acustom + bcustom);
+}
+
+TVM_DLL uint32_t Custom32Sub(uint32_t a, uint32_t b) {
+ // subtract
+ float acustom = Uint32ToCustom32<float>(a);
+ float bcustom = Uint32ToCustom32<float>(b);
+ return Custom32ToUint32<float>(acustom - bcustom);
+}
+
+TVM_DLL uint32_t Custom32Mul(uint32_t a, uint32_t b) {
+ // multiply
+ float acustom = Uint32ToCustom32<float>(a);
+ float bcustom = Uint32ToCustom32<float>(b);
+ return Custom32ToUint32<float>(acustom * bcustom);
+}
+
+TVM_DLL uint32_t Custom32Div(uint32_t a, uint32_t b) {
+ // divide
+ float acustom = Uint32ToCustom32<float>(a);
+ float bcustom = Uint32ToCustom32<float>(b);
+ return Custom32ToUint32<float>(acustom / bcustom);
+}
+
+TVM_DLL uint32_t Custom32Max(uint32_t a, uint32_t b) {
+ // max
+ float acustom = Uint32ToCustom32<float>(a);
+ float bcustom = Uint32ToCustom32<float>(b);
+ return Custom32ToUint32<float>(acustom > bcustom ? acustom : bcustom);
+}
+
+TVM_DLL uint32_t Custom32Sqrt(uint32_t a) {
+ // sqrt
+ float acustom = Uint32ToCustom32<float>(a);
+ return Custom32ToUint32<float>(sqrt(acustom));
+}
+
+TVM_DLL uint32_t Custom32Exp(uint32_t a) {
+ // exponential
+ float acustom = Uint32ToCustom32<float>(a);
+ return Custom32ToUint32<float>(exp(acustom));
+}
+
+TVM_DLL uint32_t Custom32Log(uint32_t a) {
+ // log
+ float acustom = Uint32ToCustom32<float>(a);
+ return Custom32ToUint32<float>(log(acustom));
+}
+
+TVM_DLL uint32_t Custom32Sigmoid(uint32_t a) {
+ // sigmoid
+ float acustom = Uint32ToCustom32<float>(a);
+ float one = 1.0f;
+ return Custom32ToUint32<float>(one / (one + exp(-acustom)));
+}
+
+TVM_DLL uint32_t Custom32Tanh(uint32_t a) {
+ // tanh
+ float acustom = Uint32ToCustom32<float>(a);
+ return Custom32ToUint32<float>(tanh(acustom));
+}
+}
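The uint-wrapping convention used throughout `myfloat.cc` can be mimicked in plain Python with `struct`, which may be useful for prototyping a datatype before writing the C implementation (a sketch under the assumption that the custom type is float32 under the hood, as in this example library; the function names here are hypothetical Python analogues of the C helpers):

```python
import struct

def uint32_to_custom32(bits):
    # Reinterpret the raw uint32 bits as the "custom" type (here float32),
    # mirroring Uint32ToCustom32 in myfloat.cc.
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def custom32_to_uint32(value):
    # Wrap the custom value's bits back into a uint32,
    # mirroring Custom32ToUint32.
    return struct.unpack("<I", struct.pack("<f", value))[0]

def custom32_add(a, b):
    # Mirrors Custom32Add: unwrap both operands, add in the underlying
    # type, then rewrap the result as raw bits.
    return custom32_to_uint32(uint32_to_custom32(a) + uint32_to_custom32(b))
```

For example, adding the bit patterns of 1.0f (`0x3F800000`) and 2.0f (`0x40000000`) yields the bit pattern of 3.0f (`0x40400000`).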
diff --git a/src/target/datatype/posit/posit-wrapper.cc b/src/target/datatype/posit/posit-wrapper.cc
new file mode 100644
index 0000000..96dbfe1
--- /dev/null
+++ b/src/target/datatype/posit/posit-wrapper.cc
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file 3rdparty/posit/posit-wrapper.cc
+ * \brief Wrapper over the Universal library for Bring Your Own Datatypes tests
+ * Use the SET_POSIT flag to include this file in the build.
+ *
+ * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist?
+ */
+#include <tvm/runtime/c_runtime_api.h>
+
+#include <cstdint>
+
+#include "universal/posit/posit.hpp"
+// must go after posit.hpp
+#include "universal/posit/math/exponent.hpp"
+#include "universal/posit/math/hyperbolic.hpp"
+#include "universal/posit/math/logarithm.hpp"
+#include "universal/posit/math/sqrt.hpp"
+#include "universal/posit/numeric_limits.hpp"
+
+TVM_DLL sw::unum::posit<8, 2> Uint8ToPosit8es2(uint8_t in) {
+ sw::unum::bitblock<8> bb;
+ bb = static_cast<uint64_t>(in);
+ return sw::unum::posit<8, 2>().set(bb);
+}
+
+extern "C" {
+TVM_DLL uint8_t Posit8es2toUint8(sw::unum::posit<8, 2> in) {
+ return static_cast<uint8_t>(in.get().to_ullong());
+}
+
+TVM_DLL uint8_t MinPosit8es2() {
+ auto min = std::numeric_limits<sw::unum::posit<8, 2>>::lowest();
+ return Posit8es2toUint8(min);
+}
+
+TVM_DLL float Posit8es2ToFloat(uint8_t in) { return Uint8ToPosit8es2(in).operator float(); }
+
+TVM_DLL uint8_t FloatToPosit8es2(float in) {
+ auto posit = sw::unum::posit<8, 2>(in);
+ return Posit8es2toUint8(posit);
+}
+
+TVM_DLL uint8_t Posit8es2Add(uint8_t a, uint8_t b) {
+ return Posit8es2toUint8(Uint8ToPosit8es2(a) + Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Sub(uint8_t a, uint8_t b) {
+ return Posit8es2toUint8(Uint8ToPosit8es2(a) - Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Mul(uint8_t a, uint8_t b) {
+ return Posit8es2toUint8(Uint8ToPosit8es2(a) * Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Div(uint8_t a, uint8_t b) {
+ return Posit8es2toUint8(Uint8ToPosit8es2(a) / Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Max(uint8_t a, uint8_t b) {
+ auto a_p = Uint8ToPosit8es2(a);
+ auto b_p = Uint8ToPosit8es2(b);
+ return Posit8es2toUint8(a_p > b_p ? a_p : b_p);
+}
+
+TVM_DLL uint8_t Posit8es2Sqrt(uint8_t a) {
+ return Posit8es2toUint8(sw::unum::sqrt(Uint8ToPosit8es2(a)));
+}
+
+TVM_DLL uint8_t Posit8es2Exp(uint8_t a) {
+ return Posit8es2toUint8(sw::unum::exp(Uint8ToPosit8es2(a)));
+}
+
+TVM_DLL uint8_t Posit8es2Log(uint8_t a) {
+ return Posit8es2toUint8(sw::unum::log(Uint8ToPosit8es2(a)));
+}
+
+TVM_DLL uint8_t Posit8es2Sigmoid(uint8_t a) {
+ auto posit_one = sw::unum::posit<8, 2>(1);
+ return Posit8es2toUint8(posit_one / (sw::unum::exp(-Uint8ToPosit8es2(a)) + posit_one));
+}
+
+TVM_DLL uint8_t Posit8es2Tanh(uint8_t a) {
+ return Posit8es2toUint8(sw::unum::tanh(Uint8ToPosit8es2(a)));
+}
+}
+
+TVM_DLL sw::unum::posit<16, 2> Uint16ToPosit16es2(uint16_t in) {
+ sw::unum::bitblock<16> bb;
+ bb = static_cast<uint64_t>(in);
+ return sw::unum::posit<16, 2>().set(bb);
+}
+
+extern "C" {
+TVM_DLL uint16_t Posit16es2toUint16(sw::unum::posit<16, 2> in) {
+ return static_cast<uint16_t>(in.get().to_ullong());
+}
+
+TVM_DLL uint16_t MinPosit16es2() {
+ auto min = std::numeric_limits<sw::unum::posit<16, 2>>::lowest();
+ return Posit16es2toUint16(min);
+}
+
+TVM_DLL float Posit16es2ToFloat(uint16_t in) { return Uint16ToPosit16es2(in).operator float(); }
+
+TVM_DLL uint16_t FloatToPosit16es2(float in) {
+ auto posit = sw::unum::posit<16, 2>(in);
+ return Posit16es2toUint16(posit);
+}
+
+TVM_DLL uint16_t Posit16es2Add(uint16_t a, uint16_t b) {
+ return Posit16es2toUint16(Uint16ToPosit16es2(a) + Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Sub(uint16_t a, uint16_t b) {
+ return Posit16es2toUint16(Uint16ToPosit16es2(a) - Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Mul(uint16_t a, uint16_t b) {
+ return Posit16es2toUint16(Uint16ToPosit16es2(a) * Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Div(uint16_t a, uint16_t b) {
+ return Posit16es2toUint16(Uint16ToPosit16es2(a) / Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Max(uint16_t a, uint16_t b) {
+ auto a_p = Uint16ToPosit16es2(a);
+ auto b_p = Uint16ToPosit16es2(b);
+ return Posit16es2toUint16(a_p > b_p ? a_p : b_p);
+}
+
+TVM_DLL uint16_t Posit16es2Sqrt(uint16_t a) {
+ return Posit16es2toUint16(sw::unum::sqrt(Uint16ToPosit16es2(a)));
+}
+
+TVM_DLL uint16_t Posit16es2Exp(uint16_t a) {
+ return Posit16es2toUint16(sw::unum::exp(Uint16ToPosit16es2(a)));
+}
+
+TVM_DLL uint16_t Posit16es2Log(uint16_t a) {
+ return Posit16es2toUint16(sw::unum::log(Uint16ToPosit16es2(a)));
+}
+
+TVM_DLL uint16_t Posit16es2Sigmoid(uint16_t a) {
+ auto posit_one = sw::unum::posit<16, 2>(1);
+ return Posit16es2toUint16(posit_one / (sw::unum::exp(-Uint16ToPosit16es2(a)) + posit_one));
+}
+
+TVM_DLL uint16_t Posit16es2Tanh(uint16_t a) {
+ return Posit16es2toUint16(sw::unum::tanh(Uint16ToPosit16es2(a)));
+}
+}
+
+TVM_DLL sw::unum::posit<32, 2> Uint32ToPosit32es2(uint32_t in) {
+ sw::unum::bitblock<32> bb;
+ bb = static_cast<uint64_t>(in);
+ return sw::unum::posit<32, 2>().set(bb);
+}
+
+extern "C" {
+TVM_DLL uint32_t Posit32es2ToUint32(sw::unum::posit<32, 2> in) {
+ return static_cast<uint32_t>(in.get().to_ullong());
+}
+
+TVM_DLL uint32_t MinPosit32es2() {
+ auto min = std::numeric_limits<sw::unum::posit<32, 2>>::lowest();
+ return Posit32es2ToUint32(min);
+}
+
+TVM_DLL float Posit32es2ToFloat(uint32_t in) { return Uint32ToPosit32es2(in).operator float(); }
+
+TVM_DLL uint32_t FloatToPosit32es2(float in) {
+ auto posit = sw::unum::posit<32, 2>(in);
+ return Posit32es2ToUint32(posit);
+}
+
+TVM_DLL uint32_t Posit32es2Add(uint32_t a, uint32_t b) {
+ return Posit32es2ToUint32(Uint32ToPosit32es2(a) + Uint32ToPosit32es2(b));
+}
+
+TVM_DLL uint32_t Posit32es2Sub(uint32_t a, uint32_t b) {
+ return Posit32es2ToUint32(Uint32ToPosit32es2(a) - Uint32ToPosit32es2(b));
+}
+
+TVM_DLL uint32_t Posit32es2Mul(uint32_t a, uint32_t b) {
+ return Posit32es2ToUint32(Uint32ToPosit32es2(a) * Uint32ToPosit32es2(b));
+}
+
+TVM_DLL uint32_t Posit32es2Div(uint32_t a, uint32_t b) {
+ return Posit32es2ToUint32(Uint32ToPosit32es2(a) / Uint32ToPosit32es2(b));
+}
+
+TVM_DLL uint32_t Posit32es2Max(uint32_t a, uint32_t b) {
+ auto a_p = Uint32ToPosit32es2(a);
+ auto b_p = Uint32ToPosit32es2(b);
+ return Posit32es2ToUint32(a_p > b_p ? a_p : b_p);
+}
+
+TVM_DLL uint32_t Posit32es2Sqrt(uint32_t a) {
+ return Posit32es2ToUint32(sw::unum::sqrt(Uint32ToPosit32es2(a)));
+}
+
+TVM_DLL uint32_t Posit32es2Exp(uint32_t a) {
+ return Posit32es2ToUint32(sw::unum::exp(Uint32ToPosit32es2(a)));
+}
+
+TVM_DLL uint32_t Posit32es2Log(uint32_t a) {
+ return Posit32es2ToUint32(sw::unum::log(Uint32ToPosit32es2(a)));
+}
+
+TVM_DLL uint32_t Posit32es2Sigmoid(uint32_t a) {
+ auto posit_one = sw::unum::posit<32, 2>(1);
+ return Posit32es2ToUint32(posit_one / (posit_one + sw::unum::exp(-Uint32ToPosit32es2(a))));
+}
+
+TVM_DLL uint32_t Posit32es2Tanh(uint32_t a) {
+ return Posit32es2ToUint32(sw::unum::tanh(Uint32ToPosit32es2(a)));
+}
+}
diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc
index 5ed3ce4..c84f917 100644
--- a/src/target/datatype/registry.cc
+++ b/src/target/datatype/registry.cc
@@ -91,6 +91,13 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t
return runtime::Registry::Get(ss.str());
}
+const runtime::PackedFunc* GetMinFunc(uint8_t type_code) {
+ std::ostringstream ss;
+ ss << "tvm.datatype.min.";
+ ss << datatype::Registry::Global()->GetTypeName(type_code);
+ return runtime::Registry::Get(ss.str());
+}
+
const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code) {
std::ostringstream ss;
ss << "tvm.datatype.lower.";
@@ -100,6 +107,18 @@ const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8
return runtime::Registry::Get(ss.str());
}
+const runtime::PackedFunc* GetIntrinLowerFunc(const std::string& target, const std::string& name,
+ uint8_t type_code) {
+ std::ostringstream ss;
+ ss << "tvm.datatype.lower.";
+ ss << target;
+ ss << ".Call.intrin.";
+ ss << name;
+ ss << ".";
+ ss << datatype::Registry::Global()->GetTypeName(type_code);
+ return runtime::Registry::Get(ss.str());
+}
+
uint64_t ConvertConstScalar(uint8_t type_code, double value) {
std::ostringstream ss;
ss << "tvm.datatype.convertconstscalar.float.";
diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h
index 5df8ef8..01e7b82 100644
--- a/src/target/datatype/registry.h
+++ b/src/target/datatype/registry.h
@@ -43,6 +43,8 @@ namespace datatype {
* For Casts: tvm.datatype.lower.<target>.Cast.<type>.<src_type>
* Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from
* float to myfloat.
+ * For intrinsic Calls: tvm.datatype.lower.<target>.Call.intrin.<name>.<type>
+ * Example: tvm.datatype.lower.llvm.Call.intrin.sqrt.myfloat
* For other ops: tvm.datatype.lower.<target>.<op>.<type>
* Examples: tvm.datatype.lower.llvm.Add.myfloat
* tvm.datatype.lower.llvm.FloatImm.posit
@@ -60,7 +62,7 @@ class Registry {
* manually allocated by the user, and the user must ensure that no two custom types share the
* same code. Generally, this should be straightforward, as the user will be manually registering
* all of their custom types.
- * \param type_name The name of the type, e.g. "bfloat"
+ * \param type_name The name of the type, e.g. "posites2"
+ * \param type_code The type code, which should be greater than TVMArgTypeCode::kTVMExtEnd
*/
void Register(const std::string& type_name, uint8_t type_code);
@@ -112,6 +114,13 @@ class Registry {
uint64_t ConvertConstScalar(uint8_t type_code, double value);
/*!
+ * \brief Get a function returning the minimum value for a datatype.
+ * \param type_code The datatype
+ * \return Function which takes the width of the datatype and returns the min value
+ */
+const runtime::PackedFunc* GetMinFunc(uint8_t type_code);
+
+/*!
* \brief Get lowering function for Cast ops
* \param target The target we are lowering to, e.g. "llvm"
* \param type_code The datatype being cast to
@@ -130,6 +139,16 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t
const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code);
/*!
+ * \brief Get lowering function for intrinsic Calls/pure intrinsic Calls
+ * \param target The target we are lowering to, e.g. "llvm"
+ * \param type_code The datatype of the Call
+ * \param name The intrinsic name
+ * \return Lowering function for intrinsic Calls for the provided target and type
+ */
+const runtime::PackedFunc* GetIntrinLowerFunc(const std::string& target, const std::string& name,
+ uint8_t type_code);
+
+/*!
* \brief Get lowering function for other ops
* \param target The target we are lowering to, e.g. "llvm"
* \param type_code The datatype of the op
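The registry helpers declared above all build their lookup keys by string concatenation, following the naming scheme documented at the top of `registry.h`. A small sketch of that scheme (`lower_func_name` is a hypothetical illustration, not a TVM API):

```python
def lower_func_name(target, op, type_name, intrin=None):
    # Mirrors the registry naming scheme documented in registry.h:
    #   tvm.datatype.lower.<target>.<op>.<type>
    #   tvm.datatype.lower.<target>.Call.intrin.<name>.<type>
    if intrin is not None:
        return f"tvm.datatype.lower.{target}.Call.intrin.{intrin}.{type_name}"
    return f"tvm.datatype.lower.{target}.{op}.{type_name}"
```

For example, an Add lowering for `myfloat` on LLVM is registered under `tvm.datatype.lower.llvm.Add.myfloat`, while a sqrt intrinsic lowering goes under the `Call.intrin` form.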
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 6dc485f..6d94a08 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -32,6 +32,7 @@
#include <cmath>
// Centralized header for constant folders.
#include "../../arith/const_fold.h"
+#include "../../target/datatype/registry.h"
namespace tvm {
@@ -114,10 +115,14 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) {  // NOLINT(*)
// require the types to be relatively consistent
// This will reduce the amount of code generated by operators
// and also help users find potential type conversion problems.
- if (!lhs.dtype().is_float() && rhs.dtype().is_float()) {
+ if (!lhs.dtype().is_float() &&
+ (rhs.dtype().is_float() ||
+ datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) {
// int->float
lhs = cast(rhs.dtype(), lhs);
- } else if (lhs.dtype().is_float() && !rhs.dtype().is_float()) {
+ } else if ((lhs.dtype().is_float() ||
+ datatype::Registry::Global()->GetTypeRegistered(lhs.dtype().code())) &&
+ !rhs.dtype().is_float()) {
// int->float
rhs = cast(lhs.dtype(), rhs);
} else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) ||
@@ -174,7 +179,13 @@ PrimExpr max_value(const DataType& dtype) {
PrimExpr min_value(const DataType& dtype) {
using namespace tir;
CHECK_EQ(dtype.lanes(), 1);
- if (dtype.is_int()) {
+ if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) {
+ auto f = datatype::GetMinFunc(dtype.code());
+ CHECK(f) << "No minimum function registered for custom dtype " <<
(unsigned int)dtype.code();
+ // TODO(@hypercubestart) Document this change (and others associated with the overflowing
+ // floatimm min bug)
+ return (*f)(dtype.bits());
+ } else if (dtype.is_int()) {
if (dtype.bits() == 64) {
return IntImm(dtype, std::numeric_limits<int64_t>::lowest());
} else if (dtype.bits() < 64) {
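With the change to `min_value` above, a custom type's minimum is obtained from a user-registered function named `tvm.datatype.min.<type name>` that takes the bit width and returns the minimum value. For a float-backed 32-bit type like `myfloat`, such a function might look like this in Python (a hedged sketch; `myfloat_min` is a hypothetical name mirroring `MinCustom32` in `myfloat.cc`):

```python
import struct

def myfloat_min(num_bits):
    # Candidate body for a function registered as "tvm.datatype.min.myfloat":
    # takes the bit width of the type and returns its minimum value (here the
    # lowest finite float32, matching std::numeric_limits<float>::lowest()).
    assert num_bits == 32, "myfloat is only defined at 32 bits"
    # 0xFF7FFFFF is the bit pattern of the most negative finite float32.
    return struct.unpack("<f", struct.pack("<I", 0xFF7FFFFF))[0]
```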
diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc
index ae9584f..a0faa17 100644
--- a/src/tir/transforms/lower_custom_datatypes.cc
+++ b/src/tir/transforms/lower_custom_datatypes.cc
@@ -23,6 +23,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
@@ -50,7 +51,6 @@ class CustomDatatypesLowerer : public StmtExprMutator {
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
datatype::Registry::Global()->GetTypeRegistered(src_type_code);
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<CastNode>();
if (toBeLowered) {
auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
CHECK(lower) << "Cast lowering function for target " << target_ << " destination type "
@@ -97,6 +97,22 @@ class CustomDatatypesLowerer : public StmtExprMutator {
return expr;
}
+ inline PrimExpr VisitExpr_(const CallNode* call) final {
+ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(call->dtype.code());
+ PrimExpr expr = StmtExprMutator::VisitExpr_(call);
+ call = expr.as<CallNode>();
+ if (toBeLowered) {
+ auto op = call->op.as<OpNode>();
+ CHECK(op != nullptr) << "Lowering non-intrinsic Calls not implemented";
+ auto lower = datatype::GetIntrinLowerFunc(target_, op->name, call->dtype.code());
+ CHECK(lower) << "Intrinsic lowering function for target " << target_ << ", intrinsic name "
+ << op->name << ", type " << static_cast<unsigned>(call->dtype.code())
+ << " not found";
+ return (*lower)(expr);
+ }
+ return expr;
+ }
+
#define DEFINE_MUTATE(OP, NodeName) \
inline PrimExpr VisitExpr_(const NodeName* op) final { \
auto type_code = op->dtype.code(); \
diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py
new file mode 100644
index 0000000..337d703
--- /dev/null
+++ b/tests/python/unittest/test_custom_datatypes.py
@@ -0,0 +1,562 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for the Bring Your Own Datatype framework.
+
+TODO(@gussmith23 @hypercubestart) link to documentation"""
+import tvm
+import tvm.topi.testing
+import numpy as np
+import pytest
+from numpy.random import MT19937, RandomState, SeedSequence
+from tvm import relay
+from tvm.relay.testing.layers import batch_norm_infer
+from tvm.target.datatype import (
+ register,
+ register_min_func,
+ register_op,
+ create_lower_func,
+ lower_ite,
+ lower_call_pure_extern,
+ create_min_lower_func,
+)
+from tvm.tir.op import call_pure_extern
+
+# note: we can't use relay.testing models because params are randomly initialized,
+# which leads the output to have the same values
+# get mobilenet model from Gluon CV
+# because: https://discuss.tvm.apache.org/t/mobilenet-intermediate-values-are-0/7812
+def get_mobilenet():
+ dshape = (1, 3, 224, 224)
+ from mxnet.gluon.model_zoo.vision import get_model
+
+ block = get_model("mobilenet0.25", pretrained=True)
+ shape_dict = {"data": dshape}
+ return relay.frontend.from_mxnet(block, shape_dict)
+
+
+# use real image instead of random data for end-to-end model training
+# or else output would all be around the same value
+def get_cat_image(dimensions):
+ from tvm.contrib.download import download_testdata
+ from PIL import Image
+
+ url = "https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png"
+ dst = "cat.png"
+ real_dst = download_testdata(url, dst, module="data")
+ img = Image.open(real_dst).resize(dimensions)
+ # CoreML's standard model image format is BGR
+ img_bgr = np.array(img)[:, :, ::-1]
+ img = np.transpose(img_bgr, (2, 0, 1))[np.newaxis, :]
+ return np.asarray(img, dtype="float32")
+
+
+# we use a random seed to generate input_data
+# to guarantee stable tests
+rs = RandomState(MT19937(SeedSequence(123456789)))
+
+
+def convert_ndarray(dst_dtype, array):
+ """Converts NDArray(s) into the specified datatype"""
+ x = relay.var("x", shape=array.shape, dtype=str(array.dtype))
+ cast = relay.Function([x], x.astype(dst_dtype))
+ with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+ return relay.create_executor("graph").evaluate(cast)(array)
+
+
+def change_dtype(src, dst, module, params):
+ """Convert constants and functions in module from src type to dst type.
+ Returns changed module and converted params of type dst_type.
+ """
+ module = relay.frontend.ChangeDatatype(src, dst)(module)
+ module = relay.transform.InferType()(module)
+ params = {k: convert_ndarray(dst, v) for k, v in params.items()}
+ return module, params
+
+
+def compare(module, input, src_dtype, dst_dtype, rtol, atol, params={}, target="llvm"):
+ module = relay.transform.SimplifyInference()(module)
+ ex = relay.create_executor("graph", mod=module)
+
+ correct = ex.evaluate()(*input, **params)
+ module, converted_params = change_dtype(src_dtype, dst_dtype, module, params)
+ ex = relay.create_executor("graph", mod=module, target=target)
+ # converts all inputs to dst_dtype
+ x_converted = [convert_ndarray(dst_dtype, arr) for arr in input]
+
+ # Vectorization is not implemented with custom datatypes
+ with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+ maybe_correct = ex.evaluate()(*x_converted, **converted_params)
+ # currently this only works for comparing single output
+ maybe_correct_converted = convert_ndarray(src_dtype, maybe_correct)
+ np.testing.assert_allclose(
+ maybe_correct_converted.asnumpy(), correct.asnumpy(), rtol=rtol, atol=atol
+ )
+
+
+def setup_myfloat():
+ """Set up tests for myfloat (a custom datatype that under the hood is
float)
+
+ Currently, this registers some custom datatypes using the Bring Your
+ Own Datatypes framework.
+ """
+
+ # To use datatype operations in an external library, you should first load
+ # the library containing the datatype implementation:
+ # CDLL("libposit.so", RTLD_GLOBAL)
+ # In this case, the datatype library we are using is built right into TVM,
+ # so we do not need to explicitly load any library.
+
+ # You can pick a code for your datatype arbitrarily, as long as it is
+ # greater than 128 and has not already been chosen.
+ register("myfloat", 131)
+
+ register_op(
+ create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm",
"float", "myfloat"
+ )
+ register_op(
+ create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm",
"myfloat", "float"
+ )
+ register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm",
"myfloat")
+ register_op(
+ create_lower_func(
+ {
+ 32: "Custom32Sub",
+ }
+ ),
+ "Sub",
+ "llvm",
+ "myfloat",
+ )
+ register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm",
"myfloat")
+ register_op(
+ create_lower_func(
+ {
+ 32: "FloatToCustom32",
+ }
+ ),
+ "FloatImm",
+ "llvm",
+ "myfloat",
+ )
+ register_op(
+ create_lower_func(
+ {
+ 32: "Custom32Div",
+ }
+ ),
+ "Div",
+ "llvm",
+ "myfloat",
+ )
+    register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm", "myfloat")
+ register_op(
+ create_lower_func({32: "Custom32Sqrt"}),
+ "Call",
+ "llvm",
+ "myfloat",
+ intrinsic_name="tir.sqrt",
+ )
+ register_op(
+        create_lower_func({32: "Custom32Exp"}), "Call", "llvm", "myfloat", intrinsic_name="tir.exp"
+ )
+ register_op(
+        create_lower_func({32: "Custom32Log"}), "Call", "llvm", "myfloat", intrinsic_name="tir.log"
+ )
+ register_op(
+ create_lower_func({32: "Custom32Sigmoid"}),
+ "Call",
+ "llvm",
+ "myfloat",
+ intrinsic_name="tir.sigmoid",
+ )
+ register_op(
+ create_lower_func({32: "Custom32Tanh"}),
+ "Call",
+ "llvm",
+ "myfloat",
+ intrinsic_name="tir.tanh",
+ )
+    register_op(lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else")
+ register_op(
+        lower_call_pure_extern, "Call", "llvm", "myfloat", intrinsic_name="tir.call_pure_extern"
+ )
+
+    register_min_func(create_min_lower_func({32: "MinCustom32"}, "myfloat"), "myfloat")
+
+
+def setup_posites2():
+ """Set up tests for posites2
+ Currently, this registers some custom datatypes using the Bring Your
+ Own Datatypes framework.
+ """
+
+ # To use datatype operations in an external library, you should first load
+ # the library containing the datatype implementation:
+ # CDLL("libposit.so", RTLD_GLOBAL)
+ # In this case, the datatype library we are using is built right into TVM,
+ # so we do not need to explicitly load any library.
+
+ # You can pick a code for your datatype arbitrarily, as long as it is
+ # greater than 128 and has not already been chosen.
+
+ register("posites2", 132)
+
+ register_op(
+ create_lower_func(
+ {
+ (32, 32): "FloatToPosit32es2",
+ (32, 16): "FloatToPosit16es2",
+ (32, 8): "FloatToPosit8es2",
+ }
+ ),
+ "Cast",
+ "llvm",
+ "float",
+ "posites2",
+ )
+ register_op(
+ create_lower_func(
+ {
+ (32, 32): "Posit32es2ToFloat",
+ (16, 32): "Posit16es2ToFloat",
+ (8, 32): "Posit8es2ToFloat",
+ }
+ ),
+ "Cast",
+ "llvm",
+ "posites2",
+ "float",
+ )
+ register_op(
+        create_lower_func({32: "Posit32es2Add", 16: "Posit16es2Add", 8: "Posit8es2Add"}),
+ "Add",
+ "llvm",
+ "posites2",
+ )
+ register_op(
+        create_lower_func({32: "Posit32es2Sub", 16: "Posit16es2Sub", 8: "Posit8es2Sub"}),
+ "Sub",
+ "llvm",
+ "posites2",
+ )
+ register_op(
+ create_lower_func(
+            {32: "FloatToPosit32es2", 16: "FloatToPosit16es2", 8: "FloatToPosit8es2"}
+ ),
+ "FloatImm",
+ "llvm",
+ "posites2",
+ )
+ register_op(
+        create_lower_func({32: "Posit32es2Mul", 16: "Posit16es2Mul", 8: "Posit8es2Mul"}),
+ "Mul",
+ "llvm",
+ "posites2",
+ )
+ register_op(
+        create_lower_func({32: "Posit32es2Div", 16: "Posit16es2Div", 8: "Posit8es2Div"}),
+ "Div",
+ "llvm",
+ "posites2",
+ )
+ register_op(
+        create_lower_func({32: "Posit32es2Max", 16: "Posit16es2Max", 8: "Posit8es2Max"}),
+ "Max",
+ "llvm",
+ "posites2",
+ )
+ register_op(
+        create_lower_func({32: "Posit32es2Sqrt", 16: "Posit16es2Sqrt", 8: "Posit8es2Sqrt"}),
+ "Call",
+ "llvm",
+ "posites2",
+ intrinsic_name="tir.sqrt",
+ )
+    register_op(lower_ite, "Call", "llvm", "posites2", intrinsic_name="tir.if_then_else")
+ register_op(
+        lower_call_pure_extern, "Call", "llvm", "posites2", intrinsic_name="tir.call_pure_extern"
+ )
+ register_op(
+        create_lower_func({32: "Posit32es2Exp", 16: "Posit16es2Exp", 8: "Posit8es2Exp"}),
+ "Call",
+ "llvm",
+ "posites2",
+ intrinsic_name="tir.exp",
+ )
+ register_op(
+        create_lower_func({32: "Posit32es2Log", 16: "Posit16es2Log", 8: "Posit8es2Log"}),
+ "Call",
+ "llvm",
+ "posites2",
+ intrinsic_name="tir.log",
+ )
+ register_op(
+ create_lower_func(
+            {32: "Posit32es2Sigmoid", 16: "Posit16es2Sigmoid", 8: "Posit8es2Sigmoid"}
+ ),
+ "Call",
+ "llvm",
+ "posites2",
+ intrinsic_name="tir.sigmoid",
+ )
+ register_op(
+        create_lower_func({32: "Posit32es2Tanh", 16: "Posit16es2Tanh", 8: "Posit8es2Tanh"}),
+ "Call",
+ "llvm",
+ "posites2",
+ intrinsic_name="tir.tanh",
+ )
+
+ register_min_func(
+ create_min_lower_func(
+            {32: "MinPosit32es2", 16: "MinPosit16es2", 8: "MinPosit8es2"}, "posites2"
+ ),
+ "posites2",
+ )
+
+
+def run_ops(src_dtype, dst_dtype, rtol=1e-7, atol=1e-7):
+ """Run the same op, but with two different datatypes"""
+ # used for unary ops, first shape in binary ops
+ shape1 = (5, 10, 5)
+ # second shape for binary ops
+ shape2 = (5,)
+
+ def check_unary_op(op, src_dtype, dst_dtype, shape):
+ t1 = relay.TensorType(shape, src_dtype)
+ x = relay.var("x", t1)
+ z = op(x)
+ x_data = rs.rand(*shape).astype(t1.dtype)
+
+ module = tvm.IRModule.from_expr(relay.Function([x], z))
+
+ compare(module, (x_data,), src_dtype, dst_dtype, rtol, atol)
+
+ # test unary ops
+ for op in [
+ relay.nn.softmax,
+ tvm.relay.log,
+ tvm.relay.exp,
+ tvm.relay.sqrt,
+ tvm.relay.rsqrt,
+ tvm.relay.sigmoid,
+ tvm.relay.tanh,
+ relay.nn.relu,
+ relay.nn.batch_flatten,
+ ]:
+ check_unary_op(op, src_dtype, dst_dtype, shape1)
+
+ # test unary ops over 4d data
+    for op in [relay.nn.max_pool2d, relay.nn.avg_pool2d, relay.nn.global_avg_pool2d]:
+        shape_4d = (3, 32, 32, 32)
+        check_unary_op(op, src_dtype, dst_dtype, shape_4d)
+
+ def check_binary_op(opfunc, src_dtype, dst_dtype):
+ t1 = relay.TensorType(shape1, src_dtype)
+ t2 = relay.TensorType(shape2, src_dtype)
+ x = relay.var("x", t1)
+ y = relay.var("y", t2)
+ z = opfunc(x, y)
+ x_data = rs.rand(*shape1).astype(t1.dtype)
+ y_data = rs.rand(*shape2).astype(t2.dtype)
+ module = tvm.IRModule.from_expr(relay.Function([x, y], z))
+
+ compare(module, (x_data, y_data), src_dtype, dst_dtype, rtol, atol)
+
+ for op in [
+ relay.add,
+ relay.subtract,
+ relay.divide,
+ relay.multiply,
+ ]:
+ check_binary_op(op, src_dtype, dst_dtype)
+
+    # We would like to test tvm_if_then_else, but Relay's IfNode is not
+    # lowered to this intrinsic, so to keep our tests consistent with Relay,
+    # we do not unit test it here.
+    # Note: tvm_if_then_else is tested as part of the MobileNet model.
+
+
+def run_model(get_workload, input, src_dtype, dst_dtype, rtol=1e-4, atol=1e-4):
+ module, params = get_workload()
+
+ # we don't generate random data here
+ # because then the output data would all be around the same value
+ compare(module, input, src_dtype, dst_dtype, rtol, atol, params)
+
+
+def run_conv2d(src_dtype, dst_dtype, rtol=1e-7, atol=1e-4):
+ def run_test_conv2d(
+ src_dtype,
+ dst_dtype,
+ scale,
+ dshape,
+ kshape,
+ padding=(1, 1),
+ groups=1,
+ dilation=(1, 1),
+ **attrs,
+ ):
+ x = relay.var("x", shape=dshape, dtype=src_dtype)
+ w = relay.var("w", shape=kshape, dtype=src_dtype)
+        y = relay.nn.conv2d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs)
+ module = tvm.IRModule.from_expr(relay.Function([x, w], y))
+ data = rs.uniform(-scale, scale, size=dshape).astype(src_dtype)
+ kernel = rs.uniform(-scale, scale, size=kshape).astype(src_dtype)
+
+ compare(module, (data, kernel), src_dtype, dst_dtype, rtol, atol)
+
+ # depthwise conv2d
+ dshape = (1, 32, 18, 18)
+ kshape = (32, 1, 3, 3)
+ run_test_conv2d(
+ src_dtype,
+ dst_dtype,
+ 1,
+ dshape,
+ kshape,
+ padding=(1, 1),
+ channels=32,
+ groups=32,
+ kernel_size=(3, 3),
+ )
+
+ # CUDA is disabled for 'direct' schedule:
+ # https://github.com/dmlc/tvm/pull/3070#issuecomment-486597553
+ # group conv2d
+ dshape = (1, 32, 18, 18)
+ kshape = (32, 4, 3, 3)
+ run_test_conv2d(
+ src_dtype,
+ dst_dtype,
+ 1,
+ dshape,
+ kshape,
+ padding=(1, 1),
+ channels=32,
+ groups=8,
+ kernel_size=(3, 3),
+ )
+ # also group conv2d
+ dshape = (1, 32, 18, 18)
+ kshape = (64, 1, 3, 3)
+ run_test_conv2d(
+ src_dtype,
+ dst_dtype,
+ 1,
+ dshape,
+ kshape,
+ padding=(1, 1),
+ channels=64,
+ groups=32,
+ kernel_size=(3, 3),
+ )
+
+ # normal conv2d
+ dshape = (1, 3, 224, 224)
+ kshape = (10, 3, 3, 3)
+ run_test_conv2d(
+        src_dtype, dst_dtype, 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(3, 3)
+ )
+
+ # dilated conv2d
+ dshape = (1, 3, 18, 18)
+ kshape = (10, 3, 3, 3)
+ run_test_conv2d(
+ src_dtype,
+ dst_dtype,
+ 1,
+ dshape,
+ kshape,
+ padding=(1, 1),
+ channels=10,
+ kernel_size=(3, 3),
+ dilation=(3, 3),
+ )
+
+
+def run_batchnorm(src_dtype, dst_dtype, rtol=1e-6, atol=1e-6):
+ shape = (3, 32, 32)
+ t = relay.TensorType(shape, src_dtype)
+ x = relay.var("x", t)
+ bn = batch_norm_infer(data=x, epsilon=2e-5, scale=False, name="bn_x")
+ f = relay.Function(relay.analysis.free_vars(bn), bn)
+
+ x_data = rs.rand(*shape).astype(t.dtype)
+ module = tvm.IRModule.from_expr(f)
+
+ zero_data = np.zeros((32), "float32")
+ compare(
+ module,
+ (x_data, zero_data, zero_data, zero_data, zero_data),
+ src_dtype,
+ dst_dtype,
+ rtol,
+ atol,
+ )
+
+
+def test_myfloat():
+ setup_myfloat()
+ run_ops("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
+ run_conv2d("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
+ run_batchnorm("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
+
+ # mxnet python package not available
+ # run_model(get_mobilenet, (get_cat_image((224, 224)), ),
+ # 'float32',
+ # 'custom[myfloat]32')
+
+
+def _has_posit():
+ return tvm.support.libinfo()["USE_BYOC_POSIT"] == "ON"
+
+
[email protected](not _has_posit(), reason="compiled with USE_BYOC_POSIT flag OFF")
+def test_posites2():
+ setup_posites2()
+ run_ops("float32", "custom[posites2]8", rtol=1, atol=1)
+ run_ops("float32", "custom[posites2]16", rtol=0.01, atol=1)
+ run_ops("float32", "custom[posites2]32", rtol=1e-6, atol=1e-6)
+
+ run_conv2d("float32", "custom[posites2]8", rtol=1, atol=1)
+ run_conv2d("float32", "custom[posites2]16", rtol=0.01, atol=1)
+ run_conv2d("float32", "custom[posites2]32")
+
+ run_batchnorm("float32", "custom[posites2]8", rtol=1, atol=1)
+ run_batchnorm("float32", "custom[posites2]16", rtol=0.01, atol=1)
+ run_batchnorm("float32", "custom[posites2]32", rtol=1e-4, atol=1e-4)
+    # We expected posit8 to be faster, but it's not.
+    # run_model(get_mobilenet, (get_cat_image((224, 224)), ), 'float32', 'custom[posit8]8')
+    # run_model(get_mobilenet, (get_cat_image((224, 224)), ), 'float32', 'custom[posit32]32')
+    # run_model(get_inception, (get_cat_image((229, 229)), ), 'float32', 'custom[posit32]32')
+    # run_model(get_resnet, (get_cat_image((224, 224)), ), 'float32', 'custom[posit32]32')
+
+ # can't run cifar-10 sizes because dimensions
+ # don't match pretrained weights
+
+ # runs on the order of minutes...
+ # run_model(get_inception, (get_cat_image((229, 229)), ),
+ # 'float32',
+ # 'custom[posites2]32')
+ # run_model(get_resnet, (get_cat_image((224, 224)), ),
+ # 'float32',
+ # 'custom[posites2]32')
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/tests/python/unittest/test_target_custom_datatypes.py b/tests/python/unittest/test_target_custom_datatypes.py
deleted file mode 100644
index 9b2b85d..0000000
--- a/tests/python/unittest/test_target_custom_datatypes.py
+++ /dev/null
@@ -1,154 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import tvm
-from tvm import te
-from ctypes import *
-from tvm import topi
-import numpy as np
-
-tgt = "llvm"
-
-
-def setup_module():
- # You must first load the library containing the datatype implementation.
- # In this case, we have built the test functions used below right into TVM.
- # CDLL("libmybfloat16.so", RTLD_GLOBAL)
-
- tvm.target.datatype.register("bfloat", 129)
-
- tvm.target.datatype.register_op(
- tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"),
- "Cast",
- "llvm",
- "bfloat",
- "float",
- )
- tvm.target.datatype.register_op(
- tvm.target.datatype.create_lower_func("BFloat16ToFloat_wrapper"),
- "Cast",
- "llvm",
- "float",
- "bfloat",
- )
- tvm.target.datatype.register_op(
-        tvm.target.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", "llvm", "bfloat"
- )
- tvm.target.datatype.register_op(
- tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"),
- "FloatImm",
- "llvm",
- "bfloat",
- )
-
-
-def lower_datatypes_and_build(schedule, args):
- """Create schedule and lower, manually lowering datatypes.
-
- Once datatype lowering is integrated directly into TVM's lower/build
- process, we won't need to do this manually.
-    TODO(gus) integrate datatype lowering into build process; change this test"""
- mod = tvm.lower(schedule, args)
- target = tvm.target.Target(tgt)
- mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
- mod = tvm.tir.transform.LowerCustomDatatypes()(mod)
- return tvm.build(mod, target=tgt)
-
-
-def test_bfloat_add_and_cast_1():
- X = te.placeholder((3,), name="X")
- Y = te.placeholder((3,), name="Y")
- Z = topi.cast(
-        topi.cast(X, dtype="custom[bfloat]16") + topi.cast(Y, dtype="custom[bfloat]16"),
- dtype="float",
- )
-
- s = te.create_schedule([Z.op])
- built_cast = lower_datatypes_and_build(s, [X, Y, Z])
-
- ctx = tvm.context(tgt, 0)
-
- # Used float32 calculator at http://www.weitz.de/ieee/. Generated float32s
- # with at most 7-bit mantissas which, when added, produce a result with at
- # most 7-bit mantissas. This is to ensure there are no errors due to
- # float32->bfloat16 conversions.
-    x = tvm.nd.array(np.array([4.4103796e-32, 14942208.0, 1.78125]).astype("float32"), ctx=ctx)
-    y = tvm.nd.array(np.array([-3.330669e-14, 19660800.0, 2.25]).astype("float32"), ctx=ctx)
-    z_expected = np.array([-3.330669e-14, 34603008.0, 4.03125]).astype("float32")
- z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx)
-
- built_cast(x, y, z)
-
- assert np.array_equal(z_expected, z.asnumpy())
-
-
-def test_bfloat_add_and_cast_2():
- X = te.placeholder((3,), name="X")
- Y = te.placeholder((3,), name="Y")
- Z = topi.cast(
-        topi.cast(X, dtype="custom[bfloat]16") + topi.cast(Y, dtype="custom[bfloat]16"),
- dtype="float",
- )
-
- s = te.create_schedule([Z.op])
- built_cast = lower_datatypes_and_build(s, [X, Y, Z])
-
- ctx = tvm.context(tgt, 0)
-
- # Used float32 calculator at http://www.weitz.de/ieee/. Generated
- # unconstrained float32s for the operands and copied them in to x and y.
-    # Then, to simulate float32->bfloat16 conversion implemented by the mybfloat
- # library, I cut off all but 7 bits of the mantissa. I then added the
- # numbers. To simulate bfloat16 add implemented in mybfloat, I cut off all
- # but 7 bits of the result's mantissa. I then copied that value into
- # z_expected.
-    x = tvm.nd.array(np.array([1.2348297, -1.0298302e25, 1.2034023e-30]).astype("float32"), ctx=ctx)
-    y = tvm.nd.array(np.array([-2.4992788, -9.888288e19, 9.342338e-29]).astype("float32"), ctx=ctx)
-    z_expected = np.array([-1.25, -1.027587e25, 9.426888e-29]).astype("float32")
- z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx)
-
- built_cast(x, y, z)
-
- assert np.array_equal(z_expected, z.asnumpy())
-
-
-def test_bfloat_add_and_cast_FloatImm():
- X = te.placeholder((3,), name="X")
- Z = topi.cast(
-        topi.add(topi.cast(X, dtype="custom[bfloat]16"), tvm.tir.FloatImm("custom[bfloat]16", 1.5)),
- dtype="float",
- )
-
- s = te.create_schedule([Z.op])
- built_cast = lower_datatypes_and_build(s, [X, Z])
-
- ctx = tvm.context(tgt, 0)
-
- x = tvm.nd.array(np.array([0.0, 1.0, 1.5]).astype("float32"), ctx=ctx)
- z_expected = np.array([1.5, 2.5, 3.0]).astype("float32")
- z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx)
-
- built_cast(x, z)
-
- assert np.array_equal(z_expected, z.asnumpy())
-
-
-if __name__ == "__main__":
- setup_module()
- test_bfloat_add_and_cast_1()
- test_bfloat_add_and_cast_2()
- test_bfloat_add_and_cast_FloatImm()
diff --git a/tutorials/dev/bring_your_own_datatypes.py b/tutorials/dev/bring_your_own_datatypes.py
new file mode 100644
index 0000000..cbb1b99
--- /dev/null
+++ b/tutorials/dev/bring_your_own_datatypes.py
@@ -0,0 +1,408 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Bring Your Own Datatypes to TVM
+===============================
+**Authors**: `Gus Smith <https://github.com/gussmith23>`_, `Andrew Liu <https://github.com/hypercubestart>`_
+
+In this tutorial, we will show you how to utilize the Bring Your Own Datatypes framework to use your own custom datatypes in TVM.
+Note that the Bring Your Own Datatypes framework currently only handles **software emulated versions of datatypes**.
+The framework does not support compiling for custom accelerator datatypes out-of-the-box.
+
+Datatype Libraries
+------------------
+
+The Bring Your Own Datatypes framework allows users to register their own datatype implementations alongside TVM's native datatypes (such as ``float``).
+In the wild, these datatype implementations often appear as libraries.
+For example:
+
+- `libposit <https://github.com/cjdelisle/libposit>`_, a posit library
+- `Stillwater Universal <https://github.com/stillwater-sc/universal>`_, a library with posits, fixed-point numbers, and other types
+- `SoftFloat <https://github.com/ucb-bar/berkeley-softfloat-3>`_, Berkeley's software implementation of IEEE 754 floating-point
+
+The Bring Your Own Datatypes framework enables users to plug these datatype implementations into TVM!
+
+In this section, we will use an example library we have already implemented, located at ``3rdparty/byodt/myfloat.cc``.
+This datatype, which we dubbed "myfloat", is really just an IEEE 754 float under the hood, but it serves as a useful example to show that any datatype can be used in the BYODT framework.
+
+Setup
+-----
+
+Since we do not use any 3rdparty library, there is no setup needed.
+
+If you would like to try this with your own datatype library, first bring the library's functions into the process space with ``CDLL``:
+
+.. code-block :: python
+
+ ctypes.CDLL('my-datatype-lib.so', ctypes.RTLD_GLOBAL)
+"""
+
+######################
+# A Simple TVM Program
+# --------------------
+#
+# We'll begin by writing a simple program in TVM; afterwards, we will rewrite it to use custom datatypes.
+import tvm
+from tvm import relay
+
+# Our basic program: Z = X + Y
+x = relay.var("x", shape=(3,), dtype="float32")
+y = relay.var("y", shape=(3,), dtype="float32")
+z = x + y
+program = relay.Function([x, y], z)
+module = tvm.IRModule.from_expr(program)
+
+######################################################################
+# Now, we create random inputs to feed into this program using numpy:
+
+import numpy as np
+
+np.random.seed(23) # for reproducibility
+
+x_input = np.random.rand(3).astype("float32")
+y_input = np.random.rand(3).astype("float32")
+print("x: {}".format(x_input))
+print("y: {}".format(y_input))
+
+######################################################################
+# Finally, we're ready to run the program:
+
+ex = relay.create_executor(mod=module)
+
+z_output = ex.evaluate()(x_input, y_input)
+print("z: {}".format(z_output))
+
+######################################################################
+# Adding Custom Datatypes
+# -----------------------
+# Now, we will do the same, but we will use a custom datatype for our intermediate computation.
+#
+# We use the same input variables ``x`` and ``y`` as above, but before adding ``x + y``, we first cast both ``x`` and ``y`` to a custom datatype via the ``relay.cast(...)`` call.
+#
+# Note how we specify the custom datatype: we indicate it using the special ``custom[...]`` syntax.
+# Additionally, note the "32" after the datatype: this is the bitwidth of the custom datatype. This tells TVM that each instance of ``myfloat`` is 32 bits wide.
+
+try:
+ with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+ x_myfloat = relay.cast(x, dtype="custom[myfloat]32")
+ y_myfloat = relay.cast(y, dtype="custom[myfloat]32")
+ z_myfloat = x_myfloat + y_myfloat
+ z = relay.cast(z_myfloat, dtype="float32")
+except tvm.TVMError as e:
+ # Print last line of error
+ print(str(e).split("\n")[-1])
+
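As an aside, the ``custom[...]`` dtype strings above follow a simple ``custom[name]bits`` pattern. Purely for illustration, a hypothetical parser (this helper is not part of TVM's API; TVM parses these strings internally) might look like:

```python
import re


def parse_custom_dtype(s):
    # Hypothetical helper: split "custom[myfloat]32" into ("myfloat", 32).
    # Only for illustrating the syntax; not how TVM does it internally.
    m = re.fullmatch(r"custom\[(\w+)\](\d+)", s)
    if m is None:
        raise ValueError("not a custom dtype string: " + s)
    return m.group(1), int(m.group(2))


print(parse_custom_dtype("custom[myfloat]32"))  # ('myfloat', 32)
```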
+######################################################################
+# Trying to generate this program throws an error from TVM.
+# TVM does not know how to handle any custom datatype out of the box!
+# We first have to register the custom type with TVM, giving it a name and a type code:
+
+tvm.target.datatype.register("myfloat", 150)
+
+######################################################################
+# Note that the type code, 150, is currently chosen manually by the user.
+# See ``TVMTypeCode::kCustomBegin`` in `include/tvm/runtime/data_type.h <https://github.com/apache/incubator-tvm/blob/master/include/tvm/runtime/data_type.h>`_.
+# Now we can generate our program again:
+
+x_myfloat = relay.cast(x, dtype="custom[myfloat]32")
+y_myfloat = relay.cast(y, dtype="custom[myfloat]32")
+z_myfloat = x_myfloat + y_myfloat
+z = relay.cast(z_myfloat, dtype="float32")
+program = relay.Function([x, y], z)
+module = tvm.IRModule.from_expr(program)
+
+######################################################################
+# Now we have a Relay program that uses myfloat!
+print(program)
+
+######################################################################
+# Now that we can express our program without errors, let's try running it!
+try:
+ with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+ ex = relay.create_executor("graph", mod=module)
+ z_output_myfloat = ex.evaluate()(x_input, y_input)
+        print("z: {}".format(z_output_myfloat))
+except tvm.TVMError as e:
+ # Print last line of error
+ print(str(e).split("\n")[-1])
+
+######################################################################
+# Now, trying to compile this program throws an error.
+# Let's dissect this error.
+#
+# The error is occurring during the process of lowering the custom datatype code to code that TVM can compile and run.
+# TVM is telling us that it cannot find a *lowering function* for the ``Cast`` operation, when casting from source type 2 (``float``, in TVM), to destination type 150 (our custom datatype).
+# When lowering custom datatypes, if TVM encounters an operation over a custom datatype, it looks for a user-registered *lowering function*, which tells it how to lower the operation to an operation over datatypes it understands.
+# We have not told TVM how to lower ``Cast`` operations for our custom datatypes; thus, the source of this error.
+#
+# To fix this error, we simply need to specify a lowering function:
+
+tvm.target.datatype.register_op(
+ tvm.target.datatype.create_lower_func(
+ {
+ (32, 32): "FloatToCustom32", # cast from float32 to myfloat32
+ }
+ ),
+ "Cast",
+ "llvm",
+ "float",
+ "myfloat",
+)
+
+######################################################################
+# The ``register_op(...)`` call takes a lowering function, and a number of parameters which specify exactly the operation which should be lowered with the provided lowering function.
+# In this case, the arguments we pass specify that this lowering function is for lowering a ``Cast`` from ``float`` to ``myfloat`` for target ``"llvm"``.
+#
+# The lowering function passed into this call is very general: it should take an operation of the specified type (in this case, `Cast`) and return another operation which only uses datatypes which TVM understands.
+#
+# In the general case, we expect users to implement operations over their custom datatypes using calls to an external library.
+# In our example, our ``myfloat`` library implements a ``Cast`` from ``float`` to 32-bit ``myfloat`` in the function ``FloatToCustom32``.
+# To provide for the general case, we have made a helper function, ``create_lower_func(...)``,
+# which does just this: given a dictionary, it replaces the given operation with a ``Call`` to the appropriate function name provided based on the op and the bit widths.
+# It additionally removes usages of the custom datatype by storing the custom datatype in an opaque ``uint`` of the appropriate width; in our case, a ``uint32_t``.
+# For more information, see `the source code <https://github.com/apache/incubator-tvm/blob/master/python/tvm/target/datatype.py>`_.
+
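Since ``myfloat`` is bit-for-bit a ``float`` stored in an opaque ``uint32_t``, the ``FloatToCustom32``/``Custom32ToFloat`` pair in our library amounts to a bit-level reinterpretation. A minimal pure-Python sketch of that idea (not the actual library code, which lives in ``3rdparty/byodt/myfloat.cc``):

```python
import struct


def float_to_custom32(x):
    # Reinterpret the float32 bit pattern as an opaque uint32,
    # mirroring what FloatToCustom32 conceptually does for myfloat.
    return struct.unpack("<I", struct.pack("<f", x))[0]


def custom32_to_float(bits):
    # Reinterpret the opaque uint32 back as a float32.
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Round-tripping recovers the original float32 value exactly.
print(custom32_to_float(float_to_custom32(1.5)))  # 1.5
```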
+# We can now re-try running the program:
+try:
+ with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+ ex = relay.create_executor("graph", mod=module)
+ z_output_myfloat = ex.evaluate()(x_input, y_input)
+ print("z: {}".format(z_output_myfloat))
+except tvm.TVMError as e:
+ # Print last line of error
+ print(str(e).split("\n")[-1])
+
+######################################################################
+# This new error tells us that the ``Add`` lowering function is not found, which is good news, as it's no longer complaining about the ``Cast``!
+# We know what to do from here: we just need to register the lowering functions for the other operations in our program.
+#
+# Note that for ``Add``, ``create_lower_func`` takes in a dict where the key is an integer.
+# For ``Cast`` operations, we require a 2-tuple to specify the ``src_bit_length`` and the ``dest_bit_length``,
+# while for all other operations, the bit length is the same between the operands so we only require one integer to specify ``bit_length``.
+tvm.target.datatype.register_op(
+ tvm.target.datatype.create_lower_func({32: "Custom32Add"}),
+ "Add",
+ "llvm",
+ "myfloat",
+)
+tvm.target.datatype.register_op(
+ tvm.target.datatype.create_lower_func({(32, 32): "Custom32ToFloat"}),
+ "Cast",
+ "llvm",
+ "myfloat",
+ "float",
+)
+
+# Now, we can run our program without errors.
+with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+ compiled = ex.evaluate(program)
+ z_output_myfloat = compiled(x_input, y_input)
+print("z: {}".format(z_output_myfloat))
+
+print("x:\t\t{}".format(x_input))
+print("y:\t\t{}".format(y_input))
+print("z (float32):\t{}".format(z_output))
+print("z (myfloat32):\t{}".format(z_output_myfloat))
+
+# Perhaps as expected, the ``myfloat32`` results and ``float32`` results are exactly the same!
+
+######################################################################
+# Running Models With Custom Datatypes
+# ------------------------------------
+#
+# We will first choose the model which we would like to run with myfloat.
+# In this case we use `Mobilenet <https://arxiv.org/abs/1704.04861>`_.
+# We choose Mobilenet due to its small size.
+# In this alpha state of the Bring Your Own Datatypes framework, we have not implemented any software optimizations for running software emulations of custom datatypes; the result is poor performance due to many calls into our datatype emulation library.
+#
+# First let us define two helper functions to get the mobilenet model and a cat image.
+
+
+def get_mobilenet():
+ dshape = (1, 3, 224, 224)
+ from mxnet.gluon.model_zoo.vision import get_model
+
+ block = get_model("mobilenet0.25", pretrained=True)
+ shape_dict = {"data": dshape}
+ return relay.frontend.from_mxnet(block, shape_dict)
+
+
+def get_cat_image():
+ from tvm.contrib.download import download_testdata
+ from PIL import Image
+
+    url = "https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png"
+ dst = "cat.png"
+ real_dst = download_testdata(url, dst, module="data")
+ img = Image.open(real_dst).resize((224, 224))
+ # CoreML's standard model image format is BGR
+ img_bgr = np.array(img)[:, :, ::-1]
+ img = np.transpose(img_bgr, (2, 0, 1))[np.newaxis, :]
+ return np.asarray(img, dtype="float32")
+
+
+module, params = get_mobilenet()
+
+######################################################################
+# It's easy to execute MobileNet with native TVM:
+
+ex = tvm.relay.create_executor("graph", mod=module)
+input = get_cat_image()
+result = ex.evaluate()(input, **params).asnumpy()
+# print first 10 elements
+print(result.flatten()[:10])
+
+######################################################################
+# Now, we would like to change the model to use myfloat internally. To do so, we need to convert the network. We first define a helper function to convert tensors:
+
+
+def convert_ndarray(dst_dtype, array):
+ """Converts an NDArray into the specified datatype"""
+ x = relay.var("x", shape=array.shape, dtype=str(array.dtype))
+ cast = relay.Function([x], x.astype(dst_dtype))
+ with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+ return relay.create_executor("graph").evaluate(cast)(array)
+
+
+######################################################################
+# Now, to actually convert the entire network, we have written `a pass in Relay <https://github.com/gussmith23/tvm/blob/ea174c01c54a2529e19ca71e125f5884e728da6e/python/tvm/relay/frontend/change_datatype.py#L21>`_ which simply converts all nodes within the model to use the new datatype.
+
+from tvm.relay.frontend.change_datatype import ChangeDatatype
+
+src_dtype = "float32"
+dst_dtype = "custom[myfloat]32"
+
+# Currently, custom datatypes only work if you run simplify_inference beforehand
+module = tvm.relay.transform.SimplifyInference()(module)
+
+# Run type inference before changing datatype
+module = tvm.relay.transform.InferType()(module)
+
+# Change datatype from float to myfloat and re-infer types
+cdtype = ChangeDatatype(src_dtype, dst_dtype)
+expr = cdtype.visit(module["main"])
+module = tvm.relay.transform.InferType()(module)
+
+# We also convert the parameters:
+params = {k: convert_ndarray(dst_dtype, v) for k, v in params.items()}
+
+# We also need to convert our input:
+input = convert_ndarray(dst_dtype, input)
+
+# Finally, we can try to run the converted model:
+try:
+ # Vectorization is not implemented with custom datatypes.
+ with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+ result_myfloat = ex.evaluate(expr)(input, **params)
+except tvm.TVMError as e:
+ print(str(e).split("\n")[-1])
+
+######################################################################
+# When we attempt to run the model, we get a familiar error telling us that more functions need to be registered for myfloat.
+#
+# Because this is a neural network, many more operations are required.
+# Here, we register all the needed functions:
+
+tvm.target.datatype.register_op(
+ tvm.target.datatype.create_lower_func({32: "FloatToCustom32"}),
+ "FloatImm",
+ "llvm",
+ "myfloat",
+)
+
+tvm.target.datatype.register_op(
+    tvm.target.datatype.lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else"
+)
+
+tvm.target.datatype.register_op(
+ tvm.target.datatype.lower_call_pure_extern,
+ "Call",
+ "llvm",
+ "myfloat",
+ intrinsic_name="tir.call_pure_extern",
+)
+
+tvm.target.datatype.register_op(
+ tvm.target.datatype.create_lower_func({32: "Custom32Mul"}),
+ "Mul",
+ "llvm",
+ "myfloat",
+)
+tvm.target.datatype.register_op(
+ tvm.target.datatype.create_lower_func({32: "Custom32Div"}),
+ "Div",
+ "llvm",
+ "myfloat",
+)
+
+tvm.target.datatype.register_op(
+ tvm.target.datatype.create_lower_func({32: "Custom32Sqrt"}),
+ "Call",
+ "llvm",
+ "myfloat",
+ intrinsic_name="tir.sqrt",
+)
+
+tvm.target.datatype.register_op(
+ tvm.target.datatype.create_lower_func({32: "Custom32Sub"}),
+ "Sub",
+ "llvm",
+ "myfloat",
+)
+
+tvm.target.datatype.register_op(
+ tvm.target.datatype.create_lower_func({32: "Custom32Exp"}),
+ "Call",
+ "llvm",
+ "myfloat",
+ intrinsic_name="tir.exp",
+)
+
+tvm.target.datatype.register_op(
+ tvm.target.datatype.create_lower_func({32: "Custom32Max"}),
+ "Max",
+ "llvm",
+ "myfloat",
+)
+
+tvm.target.datatype.register_min_func(
+ tvm.target.datatype.create_min_lower_func({32: "MinCustom32"}, "myfloat"),
+ "myfloat",
+)
+
+######################################################################
+# Note we are making use of two new functions: ``register_min_func`` and ``create_min_lower_func``.
+#
+# The function passed to ``register_min_func`` takes in an integer ``num_bits`` for the bit length, and should return an operation representing the minimum finite representable value for the custom datatype with the specified bit length.
+#
+# Similar to ``register_op`` and ``create_lower_func``, ``create_min_lower_func`` handles the general case where the minimum representable custom datatype value is implemented using calls to an external library.
+#
+# Now we can finally run the model:
+
+# Vectorization is not implemented with custom datatypes.
+with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+ result_myfloat = ex.evaluate(expr)(input, **params)
+ result_myfloat = convert_ndarray(src_dtype, result_myfloat).asnumpy()
+ # print first 10 elements
+ print(result_myfloat.flatten()[:10])
+
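For intuition, the minimum function registered via ``register_min_func`` earlier conceptually computes the smallest finite value representable at a given bit width. Since myfloat is bit-for-bit float32, a pure-Python sketch (not TVM code; the helper name is hypothetical) is:

```python
def myfloat_min(num_bits):
    # Hypothetical sketch of what MinCustom32 computes: the minimum finite
    # float32 value, -(2 - 2**-23) * 2**127, since myfloat is just float32.
    assert num_bits == 32, "our example library only implements 32 bits"
    return -(2 - 2 ** -23) * 2 ** 127


print(myfloat_min(32))  # -3.4028234663852886e+38 (float32 min)
```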
+# Again, note that the output using 32-bit myfloat is exactly the same as with 32-bit floats,
+# because myfloat is exactly a float!
+np.testing.assert_array_equal(result, result_myfloat)