This is an automated email from the ASF dual-hosted git repository.

ziheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 5aafff9  Bring Your Own Datatypes (#5812)
5aafff9 is described below

commit 5aafff913b963879f6ea6f24e01533793ea1a68a
Author: Gus Smith <[email protected]>
AuthorDate: Sat Sep 26 15:49:20 2020 -0700

    Bring Your Own Datatypes (#5812)
    
    * Add ChangeDatatype pass and unittest
    
    * [WIP] Jared's work on Fri
    
    This was work that Jared did on my computer, trying to get Inception v3 running.
    
    * Fix simplify inference to work over different data types.
    
    * Formatting
    
    * Copy setup code from other test file
    
    * Logging in Relay
    
    * Remove duplicate TVM_DLL
    
    * Add Sub, Mul, Div, Max to bfloat lib
    
    * Fix previous broken rebased commit
    
    * Remove line
    
    * Add LowerCustomDatatypes to build passes
    
    * Upcast ints to custom datatypes too, as well as to floats
    
    * Add and use convert_ndarray
    
    * Lower Call
    
    * Relay: create constant scalars of custom dtypes
    
    We use the same method we use in TVM: store the value in a double.
    
    * Custom datatype formatting in Relay
    
    * Update unittests
    
    * Add simpler example that's not working yet
    
    * Add Python unittests to Makefile
    
    * Fix bug
    
    * Fix function name in GetPackedFunc call
    
    * convert_ndarray makes its own executor
    
    * Add simple test case
    
    * Move setup() calls
    
    * Use convert_ndarray
    
    * Change import to make it more specific
    
    * Fix another Registry::Get call
    
    * Allow users to register minimum functions for custom datatypes
    
    This commit allows users to register global functions named
    `tvm.datatype.min.<type name>` which take the number of bits in the custom type
    and return the corresponding minimum value (as a double).
    
    A similar commit will need to be created for max, whenever that ends up being needed!
    
    * Remove check for float
    
    * Add test
    
    * Fix inception test
    
    * Add MobileNet
    
    * Lower custom datatypes before intrinsics
    
    * Add exp and sqrt bfloat functions
    
    * [buggy commit] Lower intrinsics like sqrt, exp
    
    This commit has bugs in it, I'm fairly certain.
    
    * Formatting
    
    * Fix bug
    
    * Add lowering for new ops in test
    
    * Add int to bfloat
    
    * Remove print
    
    * Add all tests
    
    * Correct image size
    
    * Add TODO
    
    * Add "notbfloat" type
    
    This type is for testing purposes. It just stores a float in a uint32. It was
    used to confirm the fact that my bfloat "implementation" is very numerically
    unstable and was causing issues when running the model.
    
    * Convert arguments
    
    Not sure how necessary this actually is.
    
    * Rewrite custom datatype constants in Relay
    
    * Add test_ops
    
    * Print constants in Relay
    
    * Use topi.testing
    
    * Test conv2d
    
    * Add test_model
    
    * Comment out model tests
    
    * Register notbfloat
    
    This could be unregistered at some point later
    
    * Add commented code
    
    Remove later
    
    * Add posit tests
    
    * test_ops_same_function
    
    * [temporary] move incomplete commit to macbook
    
    * Add more to tests
    
    * Formatting
    
    * Uncomment add
    
    * Remove bad tests
    
    * Change comments
    
    * Change function name and docstring
    
    * Change main function
    
    * Restructure tests
    
    * Fix visibility of posit functions
    
    * YAPF
    
    * Switching keywords around to resolve build errors on some systems
    
    * Improve test by running smaller mobilenet
    
    * Add test_cast
    
    * Change datatype name; add simple test
    
    * Rename to posit32
    
    * Merge 3 posit types into one file
    
    * Add a nop type
    
    * Remove bfloat
    
    * Refactor test comments
    
    * Refactor conv2d test
    
    * Add optional tolerance arguments
    
    * Add posit8 and posit16
    
    * Add comment about posit8
    
    * Whoops -- actually add noptype to CMakeLists
    
    * Add rtol, atol to run_workload
    
    * Add noptype to tests
    
    * Run noptype over other models, too
    
    * Pass correct arguments to calls
    
    * Fix line length errors
    
    * Raise tolerances (again) to avoid flaky test
    
    * fix style
    
    * add test for tanh, log, sigmoid
    
    * Remove references to bfloat, notbfloat
    
    * Change comments
    
    * Remove old test file
    
    * fix min func
    
    * refactoring unit test file
    
    * use posits es2
    
    * cleanup
    
    * comment
    
    * comment if_then_else
    
    * support different bit widths
    
    * use random seed to create stable tests
    
    * update documentation
    
    * removed nop-type and code consistency
    
    * add batchnorm test
    
    * rebase and update
    
    * fix tests and format
    
    * pylint
    
    * change order of include
    
    * include order
    
    * fix style
    
    * remove posit c linkage
    
    * update universal
    
    * fix style
    
    * fix test
    
    * fix overflow error with minfunc and posits
    
    * style
    
    * use change_dtype to convert params
    
    * update universal
    
    * fix fatal error
    
    * fix constant repr
    
    * minor update to posites2
    
    * update universal
    
    * fix rst
    
    * fix invalid import and sqrt
    
    * update universal
    
    * comments
    
    * comments and expand testing
    
    * increase atol/rtol for custom[posites2]32
    
    * Re-add newline
    
    * Remove comment
    
    * Remove opt level and comment
    
    * Change docstring
    
    * Add TODO
    
    * Add file header and newline
    
    * Update docstring
    
    * Update file docstring
    
    * Update docstrings
    
    * Delete todos
    
    * create_min_lower_func
    
    * add better debugging message
    
    * docs
    
    * add BYODT tutorial
    
    * add todo
    
    * Reformat some of tutorial to RST, plus code fixes
    
    * tutorial notebook runs now
    
    * fix hyperlink
    
    * rebase
    
    * add to tutorial
    
    * fix mobilenet model
    
    * add skip tag
    
    * black lint
    
    * add compiler flag and add dummy float
    
    * myfloat and posites2 test
    
    * remove universal
    
    * lint
    
    * lint
    
    * add setup
    
    * build with USE_POSIT for CI/CD
    
    * fix posit cmake
    
    * add cd /
    
    * undo docker changes
    
    * change tutorial to use myfloat
    
    * move files
    
    * lint
    
    * fix
    
    * remove filter
    
    * fix lint
    
    * fix suggestions
    
    Co-authored-by: Jared Roesch <[email protected]>
    Co-authored-by: Andrew Liu <[email protected]>
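    The `tvm.datatype.min.<type name>` registry convention described above can be
    sketched with a small pure-Python stand-in (a hypothetical illustration of the
    naming scheme, not TVM's real `_register_func`/`GetPackedFunc` machinery; the
    "posites2" minimum below is a placeholder value, not the true posit minimum):

    ```python
    # Hypothetical stand-in for TVM's global-function registry.
    _registry = {}

    def register_func(name, func):
        """Register a global function under a dotted name (mimics tvm._ffi.register_func)."""
        _registry[name] = func

    def register_min_func(func, type_name):
        """Register a minimum-value function under 'tvm.datatype.min.<type name>'."""
        register_func("tvm.datatype.min." + type_name, func)

    def get_min_func(type_name):
        """Look up the minimum-value function for a custom datatype."""
        key = "tvm.datatype.min." + type_name
        if key not in _registry:
            raise KeyError("no minimum function registered for " + type_name)
        return _registry[key]

    # Placeholder minimum: NOT the real posit minimum, just an illustration of
    # "take the number of bits, return the corresponding value as a double".
    register_min_func(lambda num_bits: -(2.0 ** (num_bits - 2)), "posites2")

    print(get_min_func("posites2")(8))  # prints -64.0
    ```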
---
 3rdparty/bfloat16/bfloat16.cc                      |  84 ---
 CMakeLists.txt                                     |   5 +-
 LICENSE                                            |   1 -
 cmake/config.cmake                                 |   3 +
 cmake/modules/LibInfo.cmake                        |   1 +
 .../modules/contrib/Posit.cmake                    |  28 +-
 licenses/LICENSE.bfloat16.txt                      |   9 -
 python/tvm/driver/build_module.py                  |   2 +
 python/tvm/relay/backend/_backend.py               |   3 +
 python/tvm/relay/frontend/__init__.py              |   1 +
 python/tvm/relay/frontend/change_datatype.py       | 107 ++++
 python/tvm/target/datatype.py                      | 303 +++++++++--
 src/arith/rewrite_simplify.cc                      |   4 +-
 src/driver/driver_api.cc                           |   2 +
 src/relay/backend/utils.h                          |   6 +-
 src/relay/transforms/pattern_util.h                |  77 +--
 src/support/libinfo.cc                             |   5 +
 src/target/datatype/myfloat/myfloat.cc             | 144 ++++++
 src/target/datatype/posit/posit-wrapper.cc         | 238 +++++++++
 src/target/datatype/registry.cc                    |  19 +
 src/target/datatype/registry.h                     |  21 +-
 src/tir/op/op.cc                                   |  17 +-
 src/tir/transforms/lower_custom_datatypes.cc       |  18 +-
 tests/python/unittest/test_custom_datatypes.py     | 562 +++++++++++++++++++++
 .../unittest/test_target_custom_datatypes.py       | 154 ------
 tutorials/dev/bring_your_own_datatypes.py          | 408 +++++++++++++++
 26 files changed, 1872 insertions(+), 350 deletions(-)

diff --git a/3rdparty/bfloat16/bfloat16.cc b/3rdparty/bfloat16/bfloat16.cc
deleted file mode 100644
index 674feb4..0000000
--- a/3rdparty/bfloat16/bfloat16.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
-    Copyright (c) 2019 by Contributors
-   \file tvm/src/codegen/custom_datatypes/mybfloat16.cc
-   \brief Small bfloat16 library for use in unittests
-
-  Code originally from TensorFlow; taken and simplified. Original license:
-
-  Licensed under the Apache License, Version 2.0 (the "License");
-  you may not use this file except in compliance with the License.
-  You may obtain a copy of the License at
-  http://www.apache.org/licenses/LICENSE-2.0
-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  See the License for the specific language governing permissions and
-  limitations under the License.
-  ==============================================================================*/
-
-#include <tvm/runtime/c_runtime_api.h>
-
-#include <cstddef>
-#include <cstdint>
-
-void FloatToBFloat16(const float* src, uint16_t* dst, size_t size) {
-  const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
-  uint16_t* q = reinterpret_cast<uint16_t*>(dst);
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-  for (; size != 0; p += 2, q++, size--) {
-    *q = p[0];
-  }
-#else
-  for (; size != 0; p += 2, q++, size--) {
-    *q = p[1];
-  }
-#endif
-}
-
-void BFloat16ToFloat(const uint16_t* src, float* dst, size_t size) {
-  const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
-  uint16_t* q = reinterpret_cast<uint16_t*>(dst);
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-  for (; size != 0; p++, q += 2, size--) {
-    q[0] = *p;
-    q[1] = 0;
-  }
-#else
-  for (; size != 0; p++, q += 2, size--) {
-    q[0] = 0;
-    q[1] = *p;
-  }
-#endif
-}
-
-void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst, size_t size) {
-  float a_f, b_f;
-  BFloat16ToFloat(a, &a_f, 1);
-  BFloat16ToFloat(b, &b_f, 1);
-  float out_f = a_f + b_f;
-  FloatToBFloat16(&out_f, dst, 1);
-}
-
-extern "C" {
-TVM_DLL uint16_t FloatToBFloat16_wrapper(float in);
-TVM_DLL float BFloat16ToFloat_wrapper(uint16_t in);
-TVM_DLL uint16_t BFloat16Add_wrapper(uint16_t a, uint16_t b);
-
-uint16_t FloatToBFloat16_wrapper(float in) {
-  uint16_t out;
-  FloatToBFloat16(&in, &out, 1);
-  return out;
-}
-
-float BFloat16ToFloat_wrapper(uint16_t in) {
-  float out;
-  BFloat16ToFloat(&in, &out, 1);
-  return out;
-}
-
-uint16_t BFloat16Add_wrapper(uint16_t a, uint16_t b) {
-  uint16_t out;
-  BFloat16Add(&a, &b, &out, 1);
-  return out;
-}
-}
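For reference, the truncation scheme implemented by the removed bfloat16 library can be sketched in pure Python (an illustrative reimplementation, not part of this change): a float32's high 16 bits become the bfloat16 payload, and widening zero-fills the low 16 bits.

```python
import struct

def float_to_bfloat16(x):
    """Truncate a float32 to bfloat16 by keeping the high 16 bits."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return (bits >> 16) & 0xFFFF

def bfloat16_to_float(b):
    """Widen a bfloat16 back to float32 by zero-filling the low 16 bits."""
    (x,) = struct.unpack("<f", struct.pack("<I", b << 16))
    return x

def bfloat16_add(a, b):
    """Add two bfloat16 values by round-tripping through float32, as the C library did."""
    return float_to_bfloat16(bfloat16_to_float(a) + bfloat16_to_float(b))

one = float_to_bfloat16(1.0)           # 0x3F80
print(bfloat16_to_float(bfloat16_add(one, one)))  # prints 2.0
```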
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a29364b..f3182a4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -56,6 +56,7 @@ tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt")
 tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")
 
 # Contrib library options
+tvm_option(USE_BYOC_POSIT "Build with BYOC software emulated posit custom datatype" OFF)
 tvm_option(USE_BLAS "The blas library to be linked" none)
 tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
 tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
@@ -257,6 +258,7 @@ endif(USE_VM_PROFILER)
 
 file(GLOB DATATYPE_SRCS src/target/datatype/*.cc)
 list(APPEND COMPILER_SRCS ${DATATYPE_SRCS})
+list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")
 
 file(GLOB RUNTIME_SRCS
   src/runtime/*.cc
@@ -279,8 +281,6 @@ if (INDEX_DEFAULT_I64)
   add_definitions(-DTVM_INDEX_DEFAULT_I64=1)
 endif()
 
-list(APPEND RUNTIME_SRCS 3rdparty/bfloat16/bfloat16.cc)
-
 if(USE_RPC)
   message(STATUS "Build with RPC support...")
   file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc)
@@ -334,6 +334,7 @@ include(cmake/modules/contrib/BLAS.cmake)
 include(cmake/modules/contrib/CODEGENC.cmake)
 include(cmake/modules/contrib/DNNL.cmake)
 include(cmake/modules/contrib/Random.cmake)
+include(cmake/modules/contrib/Posit.cmake)
 include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
 include(cmake/modules/contrib/Sort.cmake)
 include(cmake/modules/contrib/NNPack.cmake)
diff --git a/LICENSE b/LICENSE
index 1314c16..52b2219 100644
--- a/LICENSE
+++ b/LICENSE
@@ -209,7 +209,6 @@ for text of these licenses.
 Apache Software Foundation License 2.0
 --------------------------------------
 
-3rdparty/bfloat16/bfloat16.cc
 3rdparty/dlpack
 3rdparty/dmlc-core
 
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 6ed660c..7e8df55 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -120,6 +120,9 @@ set(USE_LLVM OFF)
 #---------------------------------------------
 # Contrib libraries
 #---------------------------------------------
+# Whether to build with BYOC software emulated posit custom datatype
+set(USE_BYOC_POSIT OFF)
+
 # Whether use BLAS, choices: openblas, atlas, apple
 set(USE_BLAS none)
 
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index e7685f3..53d88ff 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -53,6 +53,7 @@ function(add_lib_info src_file)
     TVM_INFO_HIDE_PRIVATE_SYMBOLS="${HIDE_PRIVATE_SYMBOLS}"
     TVM_INFO_USE_TF_TVMDSOOP="${USE_TF_TVMDSOOP}"
     TVM_INFO_USE_FALLBACK_STL_MAP="${USE_FALLBACK_STL_MAP}"
+    TVM_INFO_USE_BYOC_POSIT="${USE_BYOC_POSIT}"
     TVM_INFO_USE_BLAS="${USE_BLAS}"
     TVM_INFO_USE_MKL="${USE_MKL}"
     TVM_INFO_USE_MKLDNN="${USE_MKLDNN}"
diff --git a/python/tvm/relay/frontend/__init__.py b/cmake/modules/contrib/Posit.cmake
similarity index 59%
copy from python/tvm/relay/frontend/__init__.py
copy to cmake/modules/contrib/Posit.cmake
index 7154f5a..cd2f2f6 100644
--- a/python/tvm/relay/frontend/__init__.py
+++ b/cmake/modules/contrib/Posit.cmake
@@ -14,23 +14,13 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""
-Frontends for constructing Relay programs.
 
-Contains the model importers currently defined
-for Relay.
-"""
-
-from __future__ import absolute_import
-
-from .mxnet import from_mxnet
-from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var
-from .keras import from_keras
-from .onnx import from_onnx
-from .tflite import from_tflite
-from .coreml import from_coreml
-from .caffe2 import from_caffe2
-from .tensorflow import from_tensorflow
-from .darknet import from_darknet
-from .pytorch import from_pytorch
-from .caffe import from_caffe
+if(USE_BYOC_POSIT)
+  message(STATUS "Build with contrib.posit")
+  if (NOT UNIVERSAL_PATH)
+    message(FATAL_ERROR "Fail to get Universal path")
+  endif(NOT UNIVERSAL_PATH)
+  
+  include_directories(${UNIVERSAL_PATH}/include)
+  list(APPEND COMPILER_SRCS "src/target/datatype/posit/posit-wrapper.cc")
+endif(USE_BYOC_POSIT)
diff --git a/licenses/LICENSE.bfloat16.txt b/licenses/LICENSE.bfloat16.txt
deleted file mode 100644
index ce537b4..0000000
--- a/licenses/LICENSE.bfloat16.txt
+++ /dev/null
@@ -1,9 +0,0 @@
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
\ No newline at end of file
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index 1c11a6c..2b8346d 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -264,6 +264,7 @@ def _build_for_device(input_mod, target, target_host):
             tvm.tir.transform.LowerWarpMemory(),
             tvm.tir.transform.Simplify(),
             tvm.tir.transform.LowerDeviceStorageAccessInfo(),
+            tvm.tir.transform.LowerCustomDatatypes(),
             tvm.tir.transform.LowerIntrin(),
         ]
     )
@@ -279,6 +280,7 @@ def _build_for_device(input_mod, target, target_host):
             tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
             tvm.tir.transform.LowerTVMBuiltin(),
             tvm.tir.transform.LowerDeviceStorageAccessInfo(),
+            tvm.tir.transform.LowerCustomDatatypes(),
             tvm.tir.transform.LowerIntrin(),
             tvm.tir.transform.CombineContextCall(),
         ]
diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py
index 641ff04..65b0c0b 100644
--- a/python/tvm/relay/backend/_backend.py
+++ b/python/tvm/relay/backend/_backend.py
@@ -90,6 +90,9 @@ def _tensor_value_repr(tvalue):
 
 @tvm._ffi.register_func("relay._constant_repr")
 def _tensor_constant_repr(tvalue):
+    dtype = tvm.runtime.DataType(tvalue.data.dtype)
+    if tvm.target.datatype.get_type_registered(dtype.type_code):
+        return "custom tensor of type " + dtype.type_code
     return str(tvalue.data.asnumpy())
 
 
diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py
index 7154f5a..7e16499 100644
--- a/python/tvm/relay/frontend/__init__.py
+++ b/python/tvm/relay/frontend/__init__.py
@@ -34,3 +34,4 @@ from .tensorflow import from_tensorflow
 from .darknet import from_darknet
 from .pytorch import from_pytorch
 from .caffe import from_caffe
+from .change_datatype import ChangeDatatype
diff --git a/python/tvm/relay/frontend/change_datatype.py b/python/tvm/relay/frontend/change_datatype.py
new file mode 100644
index 0000000..dc80b3e
--- /dev/null
+++ b/python/tvm/relay/frontend/change_datatype.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""Change Datatype Pass"""
+from ..function import Function
+from ..expr_functor import ExprMutator
+from ..transform.transform import function_pass
+from ..expr import var, bind
+
+
+@function_pass(opt_level=0)
+class ChangeDatatype(ExprMutator):
+    """Mutator for changing the datatype of Relay programs.
+
+    This pass should be useful for users of the Bring Your Own Datatypes
+    framework.
+    TODO(@gussmith23 @hypercubestart) Add link to documentation when it exists
+
+    Example:
+
+    .. code-block:: python
+
+        from tvm.relay.testing.inception_v3 import get_workload
+        mod, params = get_workload()
+
+        def change_dtype(mod, params, src, dst):
+            mod = ChangeDatatype(src, dst)(mod)
+            params = dict((p, tvm.nd.array(params[p].asnumpy().astype(dst))) for p in params)
+            return mod, params
+
+        mod, params = change_dtype(mod, params, "float32", "custom[posites2]32")
+
+    Parameters
+    ----------
+    src : String
+        The source datatype name, e.g. "float" or "posites2" (but not "float32"
+        or "custom[posites2]32").
+    dst : String
+        The destination datatype name, in the same format.
+
+    Returns
+    -------
+    mod : tvm.IRModule
+        Module where all nodes of dtype `src` have been changed to have dtype
+        `dst`.
+    """
+
+    def __init__(self, src, dst):
+        self.src = src
+        self.dst = dst
+        super().__init__()
+
+    def transform_function(self, func, mod, ctx):
+        return self.visit(func)
+
+    def visit_constant(self, const):
+        if const.data.dtype == self.src:
+            return const.astype(self.dst)
+        return const
+
+    def visit_function(self, fn):
+        new_params = []
+        binds = {}
+
+        for param in fn.params:
+            # Get the parameter's type annotation.
+            var_type = param.type_annotation
+
+            # See if we want to replace dtype.
+            if var_type.dtype == self.src:
+                dtype = self.dst
+            else:
+                dtype = var_type.dtype
+
+            # Generate new variable.
+            new_param = var(param.name_hint, shape=var_type.shape, dtype=dtype)
+
+            new_params.append(new_param)
+            binds[param] = new_param
+
+        new_body = self.visit(fn.body)
+        # Rewrite the body to use new parameters.
+        new_body = bind(new_body, binds)
+
+        # Construct the updated function and return.
+        return Function(
+            new_params,
+            new_body,
+            # You could change the return type; if you pass None it will be re-inferred.
+            None,
+            type_params=fn.type_params,
+            attrs=fn.attrs,
+        )
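The docstring above distinguishes bare type names like "posites2" from full dtype strings like "custom[posites2]32". A hypothetical helper (not part of this change) shows how such a dtype string decomposes:

```python
import re

def parse_custom_dtype(dtype):
    """Split a BYODT dtype string like 'custom[posites2]32' into (type_name, bits).

    Plain dtypes such as 'float32' are returned unchanged as (None, dtype).
    """
    m = re.fullmatch(r"custom\[(\w+)\](\d+)", dtype)
    if m is None:
        return None, dtype
    return m.group(1), int(m.group(2))

print(parse_custom_dtype("custom[posites2]32"))  # prints ('posites2', 32)
print(parse_custom_dtype("float32"))             # prints (None, 'float32')
```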
diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py
index 5d3ca5f..cdd8149 100644
--- a/python/tvm/target/datatype.py
+++ b/python/tvm/target/datatype.py
@@ -14,73 +14,154 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Custom datatype functionality"""
-import tvm._ffi
+"""Bring Your Own Datatypes custom datatype framework
 
-import tvm.runtime._ffi_api
-from tvm.runtime import DataType
-import tvm.tir
-from tvm.tir.expr import Cast as _Cast, FloatImm as _FloatImm
+TODO(@gussmith23 @hypercubestart) link to BYODT docs when they exist"""
+import tvm
+from tvm.runtime import convert, DataType
+from tvm.tir.expr import (
+    Call as _Call,
+    Cast as _Cast,
+    FloatImm as _FloatImm,
+    BinaryOpExpr as _BinaryOpExpr,
+)
+from tvm.tir.op import call_pure_extern
+from tvm._ffi import register_func as _register_func
+from tvm.tir import call_intrin
 
 
 def register(type_name, type_code):
     """Register a custom datatype with the given type name and type code
-    Currently, the type code is manually allocated by the user, and the
-    user must ensure that no two custom types share the same code.
-    Generally, this should be straightforward, as the user will be
-    manually registering all of their custom types.
+
+    Currently, the type code is manually allocated by the user, and the user
+    must ensure that no two custom types share the same code. Generally, this
+    should be straightforward, as the user will be manually registering all of
+    their custom types.
+
+    Example:
+
+    .. code-block:: python
+
+        # Register a dtype named 'posites2' under type code 130.
+        tvm.datatype.register('posites2', 130)
+
 
     Parameters
     ----------
     type_name : str
-        The name of the custom datatype
+        The name of the custom datatype.
 
     type_code : int
-        The type's code, which should be >= kCustomBegin
+        The type's code, which should be >= kCustomBegin. See
+        include/tvm/runtime/data_type.h.
     """
     tvm.runtime._ffi_api._datatype_register(type_name, type_code)
 
 
 def get_type_name(type_code):
-    """Get the type name from the type code
+    """Get the type name of a custom datatype from the type code.
+
+    Note that this only works for custom datatypes registered with
+    tvm.datatype.register(). It does not work for TVM-native types.
+
+    Example:
+
+    .. code-block:: python
+
+        tvm.datatype.register('posites2', 130)
+        assert tvm.datatype.get_type_name(130) == 'posites2'
 
     Parameters
     ----------
     type_code : int
-        The type code
+        The type code of the custom datatype.
+
+    Returns
+    -------
+    type_name : String
+        The name of the custom datatype.
+
     """
     return tvm.runtime._ffi_api._datatype_get_type_name(type_code)
 
 
 def get_type_code(type_name):
-    """Get the type code from the type name
+    """Get the type code of a custom datatype from its type name
+
+    Note that this only works for custom datatypes registered with
+    tvm.datatype.register(). It does not work for TVM-native types.
+
+    Example:
+
+    .. code-block:: python
+
+        tvm.datatype.register('posites2', 130)
+        assert tvm.datatype.get_type_code('posites2') == 130
 
     Parameters
     ----------
     type_name : str
         The type name
+
+    Returns
+    -------
+    type_code : int
+        The type code of the custom datatype.
     """
     return tvm.runtime._ffi_api._datatype_get_type_code(type_name)
 
 
 def get_type_registered(type_code):
-    """Get a boolean representing whether the type is registered
+    """Returns true if a custom datatype is registered under the given type 
code
+
+    Example:
+
+    .. code-block:: python
+
+        tvm.datatype.register('posites2', 130)
+        assert tvm.datatype.get_type_registered(130)
 
     Parameters
     ----------
     type_code: int
         The type code
+
+    Returns
+    -------
+    type_registered : bool
+        True if a custom datatype is registered under this type code, and false
+        otherwise.
     """
     return tvm.runtime._ffi_api._datatype_get_type_registered(type_code)
 
 
-def register_op(lower_func, op_name, target, type_name, src_type_name=None):
-    """Register an external function which computes the given op.
+def register_op(
+    lower_func, op_name, target, src_type_name, dest_type_name=None, intrinsic_name=None
+):
+    """Register a lowering function for a specific operator of a custom 
datatype
+
+    At build time, Relay must lower operators over custom datatypes into
+    operators it understands how to compile. For each custom datatype operator
+    which Relay finds while lowering custom datatypes, Relay expects to find a
+    user-defined lowering function. Users register their user-defined lowering
+    functions using this function.
 
-    Currently, this will only work with Casts and binary expressions
-    whose arguments are named `a` and `b`.
-    TODO(gus) figure out what other special cases must be handled by
-        looking through expr.py.
+    Users should use create_lower_func to create their lowering function. It
+    should serve most use-cases.
+
+    Currently, this will work with Casts, intrinsics (e.g. sqrt, sigmoid), and
+    binary expressions (e.g. Add, Sub, Mul, Div).
+
+    See the LowerCustomDatatypes pass to see how registered functions are used.
+
+    Lowering Functions
+    ------------------
+    TODO(@gussmith23) Get the terminology right here.
+    Lowering functions take in a Relay node, and should return a semantically
+    equivalent Relay node which Relay can build. This means that the returned
+    node should not contain any custom datatypes. Users should likely not need
+    to define lowering functions by hand -- see the helper function
+    create_lower_func.
 
     Parameters
     ----------
@@ -89,43 +170,123 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None):
 
     op_name : str
         The name of the operation which the function computes, given by its
-        class name (e.g. Add, LE, Cast).
+        class name (e.g. Add, LE, Cast, Call).
 
     target : str
         The name of codegen target.
 
-    type_name : str
-        The name of the custom datatype, e.g. posit (but not custom[posit]8).
-
     src_type_name : str
-        If op_name is "Cast", then this should be set to the source datatype of
+        The name of the custom datatype, e.g. posites2 (but not custom[posites2]32).
+        If op_name is not "Cast", then target type is guaranteed to be the same as src_type_name.
+
+    dest_type_name : str
+        If op_name is "Cast", then this is required and should be set to the 
dest datatype of
         the argument to the Cast. If op_name is not "Cast", this is unused.
+
+    intrinsic_name : str
+        If op_name is "Call" and intrinsic_name is not None, then we assume the
+        op is a Call to an Intrinsic, and intrinsic_name is the intrinsic's
+        name.
     """
 
     if op_name == "Cast":
-        assert src_type_name is not None
+        assert dest_type_name is not None
+        lower_func_name = (
+            "tvm.datatype.lower."
+            + target
+            + "."
+            + op_name
+            + "."
+            + dest_type_name
+            + "."
+            + src_type_name
+        )
+    elif op_name == "Call" and intrinsic_name is not None:
         lower_func_name = (
-            "tvm.datatype.lower." + target + "." + op_name + "." + type_name + 
"." + src_type_name
+            "tvm.datatype.lower."
+            + target
+            + "."
+            + op_name
+            + ".intrin."
+            + intrinsic_name
+            + "."
+            + src_type_name
         )
     else:
-        lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." + type_name
+        lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." + src_type_name
     tvm._ffi.register_func(lower_func_name, lower_func)
 
 
-def create_lower_func(extern_func_name):
+def register_min_func(func, type_name):
+    """Register the function that returns the minimum representable value of 
type_name.
+
+    Operators such as max pooling and argmax require the minimum
+    finite value representable by the datatype the op is operating on.
+    Users can use this function to register a function that returns a TIR expression node
+    outputting the minimum representable value of their custom data type.
+
+    Users should use create_min_lower_func to create their lowering function. It
+    should serve most use-cases.
+
+    Note: for special cases when it is known that the custom datatype is representable
+    by a float, the user can create their own lowering func that returns a FloatImm.
+    The benefits are allowing optimizations such as rewrites to work as expected on custom
+    datatypes.
+
+    Parameters
+    ----------
+    func : function
+        Input is an integer num_bits, should return a TIR expression node that
+        represents a scalar tensor of type custom[type_name]num_bits with the minimum
+        representable value.
+
+    type_name : str
+        The name of the custom datatype, e.g. posites2 (but not custom[posites2]32).
+    """
+    _register_func("tvm.datatype.min." + type_name, func)
+
+
+def create_min_lower_func(extern_func_map, type_name):
+    """Returns a lowering function for getting the minimum value of a custom 
datatype.
+
+    Parameters
+    ----------
+    extern_func_map : map
+        A map from bit lengths to the name of the extern "C" function to lower to.
+
+    type_name : string
+        The name of the custom datatype, e.g. posites2 (but not 
custom[posites2]32).
+    """
+
+    def lower(num_bits):
+        dtype = f"custom[{type_name}]{num_bits}"
+
+        if num_bits not in extern_func_map:
+            raise RuntimeError(f"missing minimum function for {dtype}")
+
+        return call_pure_extern(dtype, extern_func_map[num_bits])
+
+    return lower
+
+
+def create_lower_func(extern_func_map):
     """Returns a function which lowers an operation to a function call.
 
     Parameters
     ----------
-    extern_func_name : str
-        The name of the extern "C" function to lower to
+    extern_func_map : map
+        If lowering a Cast, extern_func_map should be a map from tuples of
+        (src_bit_length, dest_bit_length) to the name of the extern "C" function to lower to.
+
+        Otherwise, for unary and binary ops, it should simply be a map
+        from bit_length to the name of the extern "C" function to lower to.
     """
 
     def lower(op):
         """
-        Takes an op---either a Cast or a binary op (e.g. an Add) and returns a
+        Takes an op---either a Cast, a Call, or a binary op (e.g. an Add)---and returns a
         call to the specified external function, passing the op's argument
-        (Cast) or arguments (a binary op). The return type of the call depends
+        or arguments. The return type of the call depends
         on the type of the op: if it is a custom type, then a uint of the same
         width as the custom type is returned. Otherwise, the type is
         unchanged."""
@@ -135,8 +296,74 @@ def create_lower_func(extern_func_name):
             dtype = "uint" + str(t.bits)
             if t.lanes > 1:
                 dtype += "x" + str(t.lanes)
-        if isinstance(op, (_Cast, _FloatImm)):
-            return tvm.tir.call_pure_extern(dtype, extern_func_name, op.value)
-        return tvm.tir.call_pure_extern(dtype, extern_func_name, op.a, op.b)
+
+        key = t.bits
+        if isinstance(op, _Cast):
+            src_bits = DataType(op.value.dtype).bits
+            key = (src_bits, t.bits)
+
+        if key not in extern_func_map:
+            raise RuntimeError(f"missing key {key} in extern_func_map for {op.astext()}")
+
+        if isinstance(op, _Cast):
+            return call_pure_extern(dtype, extern_func_map[key], op.value)
+        if isinstance(op, _FloatImm):
+            return call_pure_extern(dtype, extern_func_map[key], op.value)
+        if isinstance(op, _Call):
+            return call_pure_extern(dtype, extern_func_map[key], *op.args)
+        if isinstance(op, _BinaryOpExpr):
+            return call_pure_extern(dtype, extern_func_map[key], op.a, op.b)
+
+        raise RuntimeError(f"lowering unsupported op: {op.astext()}")
 
     return lower
+
+
+def lower_ite(ite_op):
+    """Lowered if then else function that calls intrinsic if_then_else.
+    Unlike a function lowered by create_lower_func, this function
+    calls the tvm intrinsic if_then_else.
+
+    Parameters
+    ----------
+    ite_op : Op
+        Takes an if_then_else op and returns a call to the
+        tir.if_then_else intrinsic, passing the op's arguments.
+        The return type of the call is a uint of the same width
+        as the custom type.
+    """
+    dtype = ite_op.dtype
+    t = tvm.DataType(dtype)
+    assert get_type_registered(t.type_code)
+    dtype = "uint" + str(t.bits)
+    if t.lanes > 1:
+        dtype += "x" + str(t.lanes)
+    return call_intrin(
+        dtype,
+        "tir.if_then_else",
+        convert(ite_op.args[0]),
+        convert(ite_op.args[1]),
+        convert(ite_op.args[2]),
+    )
+
+
+def lower_call_pure_extern(op):
+    """Lowered call pure extern function that calls intrinsic call_pure_extern.
+    Unlike a function lowered by create_lower_func, this function
+    calls the tvm intrinsic call_pure_extern.
+
+    Parameters
+    ----------
+    op : Op
+        Takes a call_pure_extern op and returns a call to the
+        tir.call_pure_extern intrinsic, passing the op's arguments.
+        The return type of the call is a uint of the same width
+        as the custom type.
+    """
+    dtype = op.dtype
+    t = tvm.DataType(dtype)
+    assert get_type_registered(t.type_code)
+    dtype = "uint" + str(t.bits)
+    if t.lanes > 1:
+        dtype += "x" + str(t.lanes)
+    return call_intrin(dtype, "tir.call_pure_extern", *op.args)
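
[Editor's note: the map-keying rule in create_lower_func above (casts keyed by (src_bits, dest_bits) tuples, all other ops keyed by bit length alone) can be sketched without TVM. The helper names select_key and lower_to_name below are hypothetical, introduced only for illustration:]

```python
# Toy, TVM-free sketch of create_lower_func's key-selection rule.
def select_key(op_kind, dest_bits, src_bits=None):
    """Casts are keyed by (src_bits, dest_bits); other ops by bit length."""
    if op_kind == "Cast":
        return (src_bits, dest_bits)
    return dest_bits

def lower_to_name(extern_func_map, op_kind, dest_bits, src_bits=None):
    """Look up the extern "C" function name for an op, as the lowering does."""
    key = select_key(op_kind, dest_bits, src_bits)
    if key not in extern_func_map:
        raise RuntimeError(f"missing key {key} in extern_func_map")
    return extern_func_map[key]

# Example maps, analogous to those a posites2 user might register:
cast_map = {(32, 8): "FloatToPosit8es2", (8, 32): "Posit8es2ToFloat"}
add_map = {8: "Posit8es2Add", 16: "Posit16es2Add", 32: "Posit32es2Add"}

assert lower_to_name(cast_map, "Cast", 8, src_bits=32) == "FloatToPosit8es2"
assert lower_to_name(add_map, "Add", 16) == "Posit16es2Add"
```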
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index a7aded6..c237edc 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -30,6 +30,7 @@
 
 #include <algorithm>
 
+#include "../target/datatype/registry.h"
 #include "const_fold.h"
 #include "pattern_match.h"
 
@@ -460,7 +461,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
 
   // x / 2.0 = x * 0.5
   if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
-    CHECK(op->dtype.is_float());
+    CHECK(op->dtype.is_float() ||
+          datatype::Registry::Global()->GetTypeRegistered(op->dtype.code()));
     return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
   }
 
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index e4851b5..d05d846 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -209,6 +209,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
       }),
       BindTarget(target_host),
       tir::transform::LowerTVMBuiltin(),
+      tir::transform::LowerCustomDatatypes(),
       tir::transform::LowerIntrin(),
       tir::transform::LowerDeviceStorageAccessInfo(),
       tir::transform::CombineContextCall(),
@@ -225,6 +226,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
       BindTarget(target),
       tir::transform::LowerWarpMemory(),
       tir::transform::Simplify(),
+      tir::transform::LowerCustomDatatypes(),
       tir::transform::LowerIntrin(),
       tir::transform::LowerDeviceStorageAccessInfo(),
   };
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index d6edd10..07f4226 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -149,8 +149,12 @@ inline std::string DType2String(const tvm::DataType dtype) {
     os << "int";
   } else if (dtype.is_uint()) {
     os << "uint";
+  } else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) {
+    os << "custom["
+       << (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string()
+       << "]";
   } else {
-    LOG(FATAL) << "Unknown type";
+    LOG(FATAL) << "Unknown type with code " << static_cast<unsigned>(dtype.code());
   }
   os << dtype.bits();
   return os.str();
diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h
index 39fbec5..17e7304 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -35,6 +35,7 @@
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/registry.h>
 #include <tvm/tir/data_layout.h>
 
 #include <limits>
@@ -51,42 +52,46 @@ namespace relay {
  * \brief Dispatch DataType to the C++ data type
  *  during runtime.
  */
-#define TVM_DTYPE_DISPATCH(type, DType, ...)    \
-  if (type == DataType::Float(64)) {            \
-    typedef double DType;                       \
-    { __VA_ARGS__ }                             \
-  } else if (type == DataType::Float(32)) {     \
-    typedef float DType;                        \
-    { __VA_ARGS__ }                             \
-  } else if (type == DataType::Float(16)) {     \
-    typedef uint16_t DType;                     \
-    { __VA_ARGS__ }                             \
-  } else if (type == DataType::Int(64)) {       \
-    typedef int64_t DType;                      \
-    { __VA_ARGS__ }                             \
-  } else if (type == DataType::Int(32)) {       \
-    typedef int32_t DType;                      \
-    { __VA_ARGS__ }                             \
-  } else if (type == DataType::Int(16)) {       \
-    typedef int16_t DType;                      \
-    { __VA_ARGS__ }                             \
-  } else if (type == DataType::Int(8)) {        \
-    typedef int8_t DType;                       \
-    { __VA_ARGS__ }                             \
-  } else if (type == DataType::UInt(64)) {      \
-    typedef uint64_t DType;                     \
-    { __VA_ARGS__ }                             \
-  } else if (type == DataType::UInt(32)) {      \
-    typedef uint32_t DType;                     \
-    { __VA_ARGS__ }                             \
-  } else if (type == DataType::UInt(16)) {      \
-    typedef uint16_t DType;                     \
-    { __VA_ARGS__ }                             \
-  } else if (type == DataType::UInt(8)) {       \
-    typedef uint8_t DType;                      \
-    { __VA_ARGS__ }                             \
-  } else {                                      \
-    LOG(FATAL) << "unknown data type " << type; \
+#define TVM_DTYPE_DISPATCH(type, DType, ...)                                          \
+  if (type == DataType::Float(64)) {                                                  \
+    typedef double DType;                                                             \
+    { __VA_ARGS__ }                                                                   \
+  } else if (type == DataType::Float(32)) {                                           \
+    typedef float DType;                                                              \
+    { __VA_ARGS__ }                                                                   \
+  } else if (type == DataType::Float(16)) {                                           \
+    typedef uint16_t DType;                                                           \
+    { __VA_ARGS__ }                                                                   \
+  } else if (type == DataType::Int(64)) {                                             \
+    typedef int64_t DType;                                                            \
+    { __VA_ARGS__ }                                                                   \
+  } else if (type == DataType::Int(32)) {                                             \
+    typedef int32_t DType;                                                            \
+    { __VA_ARGS__ }                                                                   \
+  } else if (type == DataType::Int(16)) {                                             \
+    typedef int16_t DType;                                                            \
+    { __VA_ARGS__ }                                                                   \
+  } else if (type == DataType::Int(8)) {                                              \
+    typedef int8_t DType;                                                             \
+    { __VA_ARGS__ }                                                                   \
+  } else if (type == DataType::UInt(64)) {                                            \
+    typedef uint64_t DType;                                                           \
+    { __VA_ARGS__ }                                                                   \
+  } else if (type == DataType::UInt(32)) {                                            \
+    typedef uint32_t DType;                                                           \
+    { __VA_ARGS__ }                                                                   \
+  } else if (type == DataType::UInt(16)) {                                            \
+    typedef uint16_t DType;                                                           \
+    { __VA_ARGS__ }                                                                   \
+  } else if (type == DataType::UInt(8)) {                                             \
+    typedef uint8_t DType;                                                            \
+    { __VA_ARGS__ }                                                                   \
+  } else if ((*tvm::runtime::Registry::Get("runtime._datatype_get_type_registered"))( \
+                 static_cast<uint8_t>(type.code()))) {                                \
+    typedef double DType;                                                             \
+    { __VA_ARGS__ }                                                                   \
+  } else {                                                                            \
+    LOG(FATAL) << "unknown data type " << type;                                       \
   }
 
 /*!
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index 0bdae82..3daf4e4 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -120,6 +120,10 @@
 #define TVM_INFO_USE_FALLBACK_STL_MAP "NOT-FOUND"
 #endif
 
+#ifndef TVM_INFO_USE_BYOC_POSIT
+#define TVM_INFO_USE_BYOC_POSIT "NOT-FOUND"
+#endif
+
 #ifndef TVM_INFO_USE_BLAS
 #define TVM_INFO_USE_BLAS "NOT-FOUND"
 #endif
@@ -237,6 +241,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
       {"HIDE_PRIVATE_SYMBOLS", TVM_INFO_HIDE_PRIVATE_SYMBOLS},
       {"USE_TF_TVMDSOOP", TVM_INFO_USE_TF_TVMDSOOP},
       {"USE_FALLBACK_STL_MAP", TVM_INFO_USE_FALLBACK_STL_MAP},
+      {"USE_BYOC_POSIT", TVM_INFO_USE_BYOC_POSIT},
       {"USE_BLAS", TVM_INFO_USE_BLAS},
       {"USE_MKL", TVM_INFO_USE_MKL},
       {"USE_MKLDNN", TVM_INFO_USE_MKLDNN},
diff --git a/src/target/datatype/myfloat/myfloat.cc b/src/target/datatype/myfloat/myfloat.cc
new file mode 100644
index 0000000..c0c2fff
--- /dev/null
+++ b/src/target/datatype/myfloat/myfloat.cc
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/target/datatype/myfloat/myfloat.cc
+ * \brief Example custom datatype with the Bring Your Own Datatypes (BYODT) framework.
+ * This is a toy example that under the hood simulates floats.
+ *
+ * Users interested in using the BYODT framework can use this file as a template.
+ *
+ * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist?
+ */
+#include <tvm/runtime/c_runtime_api.h>
+
+#include <cmath>
+#include <cstdint>
+#include <limits>
+
+// Custom datatypes are stored as bits in a uint of the appropriate bit length.
+// Thus, when TVM calls these C functions,
+// the arguments are uints that need to be reinterpreted as your custom datatype.
+//
+// When returning, your custom datatype needs to be re-wrapped into a uint,
+// which can be thought of as just a wrapper for the raw bits that represent your custom datatype.
+template <class T>
+TVM_DLL T Uint32ToCustom32(uint32_t in) {
+  // This is a helper function to interpret the uint as your custom datatype.
+  // The following line should be replaced with the appropriate function
+  // that interprets the bits in `in` and returns your custom datatype
+  T* custom = reinterpret_cast<T*>(&in);
+  return *custom;
+}
+
+template <class T>
+TVM_DLL uint32_t Custom32ToUint32(T in) {
+  // This is a helper function to wrap your custom datatype in a uint.
+  // The following line should be replaced with the appropriate function
+  // that converts your custom datatype into a uint.
+  uint32_t* bits = reinterpret_cast<uint32_t*>(&in);
+  return *bits;
+}
+
+extern "C" {
+TVM_DLL uint32_t MinCustom32() {
+  // return minimum representable value
+  float min = std::numeric_limits<float>::lowest();
+  return Custom32ToUint32<float>(min);
+}
+
+TVM_DLL float Custom32ToFloat(uint32_t in) {
+  // cast from custom datatype to float
+  float custom_datatype = Uint32ToCustom32<float>(in);
+  // our custom datatype is float, so the following redundant cast to float
+  // is to remind users to cast their own custom datatype to float
+  return static_cast<float>(custom_datatype);
+}
+
+TVM_DLL uint32_t FloatToCustom32(float in) {
+  // cast from float to custom datatype
+  return Custom32ToUint32<float>(in);
+}
+
+TVM_DLL uint32_t Custom32Add(uint32_t a, uint32_t b) {
+  // add operation
+  float acustom = Uint32ToCustom32<float>(a);
+  float bcustom = Uint32ToCustom32<float>(b);
+  return Custom32ToUint32<float>(acustom + bcustom);
+}
+
+TVM_DLL uint32_t Custom32Sub(uint32_t a, uint32_t b) {
+  // subtract
+  float acustom = Uint32ToCustom32<float>(a);
+  float bcustom = Uint32ToCustom32<float>(b);
+  return Custom32ToUint32<float>(acustom - bcustom);
+}
+
+TVM_DLL uint32_t Custom32Mul(uint32_t a, uint32_t b) {
+  // multiply
+  float acustom = Uint32ToCustom32<float>(a);
+  float bcustom = Uint32ToCustom32<float>(b);
+  return Custom32ToUint32<float>(acustom * bcustom);
+}
+
+TVM_DLL uint32_t Custom32Div(uint32_t a, uint32_t b) {
+  // divide
+  float acustom = Uint32ToCustom32<float>(a);
+  float bcustom = Uint32ToCustom32<float>(b);
+  return Custom32ToUint32<float>(acustom / bcustom);
+}
+
+TVM_DLL uint32_t Custom32Max(uint32_t a, uint32_t b) {
+  // max
+  float acustom = Uint32ToCustom32<float>(a);
+  float bcustom = Uint32ToCustom32<float>(b);
+  return Custom32ToUint32<float>(acustom > bcustom ? acustom : bcustom);
+}
+
+TVM_DLL uint32_t Custom32Sqrt(uint32_t a) {
+  // sqrt
+  float acustom = Uint32ToCustom32<float>(a);
+  return Custom32ToUint32<float>(sqrt(acustom));
+}
+
+TVM_DLL uint32_t Custom32Exp(uint32_t a) {
+  // exponential
+  float acustom = Uint32ToCustom32<float>(a);
+  return Custom32ToUint32<float>(exp(acustom));
+}
+
+TVM_DLL uint32_t Custom32Log(uint32_t a) {
+  // log
+  float acustom = Uint32ToCustom32<float>(a);
+  return Custom32ToUint32<float>(log(acustom));
+}
+
+TVM_DLL uint32_t Custom32Sigmoid(uint32_t a) {
+  // sigmoid
+  float acustom = Uint32ToCustom32<float>(a);
+  float one = 1.0f;
+  return Custom32ToUint32<float>(one / (one + exp(-acustom)));
+}
+
+TVM_DLL uint32_t Custom32Tanh(uint32_t a) {
+  // tanh
+  float acustom = Uint32ToCustom32<float>(a);
+  return Custom32ToUint32<float>(tanh(acustom));
+}
+}
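
[Editor's note: the bit-reinterpretation pattern that myfloat.cc implements with reinterpret_cast can be mirrored in a few lines of Python using struct. This is a toy sketch for illustration only; the function names below are made up and are not part of this commit:]

```python
import struct

def uint32_to_custom32(bits):
    # Reinterpret the raw uint32 bits as the "custom" type (here: IEEE float),
    # mirroring Uint32ToCustom32 in myfloat.cc.
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def custom32_to_uint32(value):
    # Wrap the custom value back into its raw uint32 bits,
    # mirroring Custom32ToUint32 in myfloat.cc.
    return struct.unpack("<I", struct.pack("<f", value))[0]

def custom32_add(a_bits, b_bits):
    # Unwrap both operands, operate in the custom type, re-wrap the result,
    # mirroring Custom32Add in myfloat.cc.
    return custom32_to_uint32(uint32_to_custom32(a_bits) + uint32_to_custom32(b_bits))

one, two = custom32_to_uint32(1.0), custom32_to_uint32(2.0)
assert uint32_to_custom32(custom32_add(one, two)) == 3.0
```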
diff --git a/src/target/datatype/posit/posit-wrapper.cc b/src/target/datatype/posit/posit-wrapper.cc
new file mode 100644
index 0000000..96dbfe1
--- /dev/null
+++ b/src/target/datatype/posit/posit-wrapper.cc
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/target/datatype/posit/posit-wrapper.cc
+ * \brief Wrapper over the Universal library for Bring Your Own Datatypes tests.
+ * Use the USE_BYOC_POSIT flag to include this file in the build.
+ *
+ * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist?
+ */
+#include <tvm/runtime/c_runtime_api.h>
+
+#include <cstdint>
+
+#include "universal/posit/posit.hpp"
+// must go after posit.hpp
+#include "universal/posit/math/exponent.hpp"
+#include "universal/posit/math/hyperbolic.hpp"
+#include "universal/posit/math/logarithm.hpp"
+#include "universal/posit/math/sqrt.hpp"
+#include "universal/posit/numeric_limits.hpp"
+
+TVM_DLL sw::unum::posit<8, 2> Uint8ToPosit8es2(uint8_t in) {
+  sw::unum::bitblock<8> bb;
+  bb = static_cast<uint64_t>(in);
+  return sw::unum::posit<8, 2>().set(bb);
+}
+
+extern "C" {
+TVM_DLL uint8_t Posit8es2toUint8(sw::unum::posit<8, 2> in) {
+  return static_cast<uint8_t>(in.get().to_ullong());
+}
+
+TVM_DLL uint8_t MinPosit8es2() {
+  auto min = std::numeric_limits<sw::unum::posit<8, 2>>::lowest();
+  return Posit8es2toUint8(min);
+}
+
+TVM_DLL float Posit8es2ToFloat(uint8_t in) { return Uint8ToPosit8es2(in).operator float(); }
+
+TVM_DLL uint8_t FloatToPosit8es2(float in) {
+  auto posit = sw::unum::posit<8, 2>(in);
+  return Posit8es2toUint8(posit);
+}
+
+TVM_DLL uint8_t Posit8es2Add(uint8_t a, uint8_t b) {
+  return Posit8es2toUint8(Uint8ToPosit8es2(a) + Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Sub(uint8_t a, uint8_t b) {
+  return Posit8es2toUint8(Uint8ToPosit8es2(a) - Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Mul(uint8_t a, uint8_t b) {
+  return Posit8es2toUint8(Uint8ToPosit8es2(a) * Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Div(uint8_t a, uint8_t b) {
+  return Posit8es2toUint8(Uint8ToPosit8es2(a) / Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Max(uint8_t a, uint8_t b) {
+  auto a_p = Uint8ToPosit8es2(a);
+  auto b_p = Uint8ToPosit8es2(b);
+  return Posit8es2toUint8(a_p > b_p ? a_p : b_p);
+}
+
+TVM_DLL uint8_t Posit8es2Sqrt(uint8_t a) {
+  return Posit8es2toUint8(sw::unum::sqrt(Uint8ToPosit8es2(a)));
+}
+
+TVM_DLL uint8_t Posit8es2Exp(uint8_t a) {
+  return Posit8es2toUint8(sw::unum::exp(Uint8ToPosit8es2(a)));
+}
+
+TVM_DLL uint8_t Posit8es2Log(uint8_t a) {
+  return Posit8es2toUint8(sw::unum::log(Uint8ToPosit8es2(a)));
+}
+
+TVM_DLL uint8_t Posit8es2Sigmoid(uint8_t a) {
+  auto posit_one = sw::unum::posit<8, 2>(1);
+  return Posit8es2toUint8(posit_one / (sw::unum::exp(-Uint8ToPosit8es2(a)) + posit_one));
+}
+
+TVM_DLL uint8_t Posit8es2Tanh(uint8_t a) {
+  return Posit8es2toUint8(sw::unum::tanh(Uint8ToPosit8es2(a)));
+}
+}
+
+TVM_DLL sw::unum::posit<16, 2> Uint16ToPosit16es2(uint16_t in) {
+  sw::unum::bitblock<16> bb;
+  bb = static_cast<uint64_t>(in);
+  return sw::unum::posit<16, 2>().set(bb);
+}
+
+extern "C" {
+TVM_DLL uint16_t Posit16es2toUint16(sw::unum::posit<16, 2> in) {
+  return static_cast<uint16_t>(in.get().to_ullong());
+}
+
+TVM_DLL uint16_t MinPosit16es2() {
+  auto min = std::numeric_limits<sw::unum::posit<16, 2>>::lowest();
+  return Posit16es2toUint16(min);
+}
+
+TVM_DLL float Posit16es2ToFloat(uint16_t in) { return Uint16ToPosit16es2(in).operator float(); }
+
+TVM_DLL uint16_t FloatToPosit16es2(float in) {
+  auto posit = sw::unum::posit<16, 2>(in);
+  return Posit16es2toUint16(posit);
+}
+
+TVM_DLL uint16_t Posit16es2Add(uint16_t a, uint16_t b) {
+  return Posit16es2toUint16(Uint16ToPosit16es2(a) + Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Sub(uint16_t a, uint16_t b) {
+  return Posit16es2toUint16(Uint16ToPosit16es2(a) - Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Mul(uint16_t a, uint16_t b) {
+  return Posit16es2toUint16(Uint16ToPosit16es2(a) * Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Div(uint16_t a, uint16_t b) {
+  return Posit16es2toUint16(Uint16ToPosit16es2(a) / Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Max(uint16_t a, uint16_t b) {
+  auto a_p = Uint16ToPosit16es2(a);
+  auto b_p = Uint16ToPosit16es2(b);
+  return Posit16es2toUint16(a_p > b_p ? a_p : b_p);
+}
+
+TVM_DLL uint16_t Posit16es2Sqrt(uint16_t a) {
+  return Posit16es2toUint16(sw::unum::sqrt(Uint16ToPosit16es2(a)));
+}
+
+TVM_DLL uint16_t Posit16es2Exp(uint16_t a) {
+  return Posit16es2toUint16(sw::unum::exp(Uint16ToPosit16es2(a)));
+}
+
+TVM_DLL uint16_t Posit16es2Log(uint16_t a) {
+  return Posit16es2toUint16(sw::unum::log(Uint16ToPosit16es2(a)));
+}
+
+TVM_DLL uint16_t Posit16es2Sigmoid(uint16_t a) {
+  auto posit_one = sw::unum::posit<16, 2>(1);
+  return Posit16es2toUint16(posit_one / (sw::unum::exp(-Uint16ToPosit16es2(a)) + posit_one));
+}
+
+TVM_DLL uint16_t Posit16es2Tanh(uint16_t a) {
+  return Posit16es2toUint16(sw::unum::tanh(Uint16ToPosit16es2(a)));
+}
+}
+
+TVM_DLL sw::unum::posit<32, 2> Uint32ToPosit32es2(uint32_t in) {
+  sw::unum::bitblock<32> bb;
+  bb = static_cast<uint64_t>(in);
+  return sw::unum::posit<32, 2>().set(bb);
+}
+
+extern "C" {
+TVM_DLL uint32_t Posit32es2ToUint32(sw::unum::posit<32, 2> in) {
+  return static_cast<uint32_t>(in.get().to_ullong());
+}
+
+TVM_DLL uint32_t MinPosit32es2() {
+  auto min = std::numeric_limits<sw::unum::posit<32, 2>>::lowest();
+  return Posit32es2ToUint32(min);
+}
+
+TVM_DLL float Posit32es2ToFloat(uint32_t in) { return Uint32ToPosit32es2(in).operator float(); }
+
+TVM_DLL uint32_t FloatToPosit32es2(float in) {
+  auto posit = sw::unum::posit<32, 2>(in);
+  return Posit32es2ToUint32(posit);
+}
+
+TVM_DLL uint32_t Posit32es2Add(uint32_t a, uint32_t b) {
+  return Posit32es2ToUint32(Uint32ToPosit32es2(a) + Uint32ToPosit32es2(b));
+}
+
+TVM_DLL uint32_t Posit32es2Sub(uint32_t a, uint32_t b) {
+  return Posit32es2ToUint32(Uint32ToPosit32es2(a) - Uint32ToPosit32es2(b));
+}
+
+TVM_DLL uint32_t Posit32es2Mul(uint32_t a, uint32_t b) {
+  return Posit32es2ToUint32(Uint32ToPosit32es2(a) * Uint32ToPosit32es2(b));
+}
+
+TVM_DLL uint32_t Posit32es2Div(uint32_t a, uint32_t b) {
+  return Posit32es2ToUint32(Uint32ToPosit32es2(a) / Uint32ToPosit32es2(b));
+}
+
+TVM_DLL uint32_t Posit32es2Max(uint32_t a, uint32_t b) {
+  auto a_p = Uint32ToPosit32es2(a);
+  auto b_p = Uint32ToPosit32es2(b);
+  return Posit32es2ToUint32(a_p > b_p ? a_p : b_p);
+}
+
+TVM_DLL uint32_t Posit32es2Sqrt(uint32_t a) {
+  return Posit32es2ToUint32(sw::unum::sqrt(Uint32ToPosit32es2(a)));
+}
+
+TVM_DLL uint32_t Posit32es2Exp(uint32_t a) {
+  return Posit32es2ToUint32(sw::unum::exp(Uint32ToPosit32es2(a)));
+}
+
+TVM_DLL uint32_t Posit32es2Log(uint32_t a) {
+  return Posit32es2ToUint32(sw::unum::log(Uint32ToPosit32es2(a)));
+}
+
+TVM_DLL uint32_t Posit32es2Sigmoid(uint32_t a) {
+  auto posit_one = sw::unum::posit<32, 2>(1);
+  return Posit32es2ToUint32(posit_one / (posit_one + sw::unum::exp(-Uint32ToPosit32es2(a))));
+}
+
+TVM_DLL uint32_t Posit32es2Tanh(uint32_t a) {
+  return Posit32es2ToUint32(sw::unum::tanh(Uint32ToPosit32es2(a)));
+}
+}
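
[Editor's note: both example libraries above are wired into TVM through registered lowering functions whose names follow the scheme documented in registry.h below. A small hypothetical helper (lower_func_name is not part of TVM, only an illustration of the scheme) makes the convention concrete:]

```python
def lower_func_name(target, op, type_name, src_type_name=None, intrin=None):
    """Build a lowering-function registry name per the registry.h scheme (sketch)."""
    if op == "Cast":
        # Casts: tvm.datatype.lower.<target>.Cast.<type>.<src_type>
        return f"tvm.datatype.lower.{target}.Cast.{type_name}.{src_type_name}"
    if intrin is not None:
        # Intrinsic Calls: tvm.datatype.lower.<target>.Call.intrin.<name>.<type>
        return f"tvm.datatype.lower.{target}.Call.intrin.{intrin}.{type_name}"
    # Other ops: tvm.datatype.lower.<target>.<op>.<type>
    return f"tvm.datatype.lower.{target}.{op}.{type_name}"

assert lower_func_name("llvm", "Cast", "myfloat", src_type_name="float") == \
    "tvm.datatype.lower.llvm.Cast.myfloat.float"
assert lower_func_name("llvm", "Call", "myfloat", intrin="sqrt") == \
    "tvm.datatype.lower.llvm.Call.intrin.sqrt.myfloat"
assert lower_func_name("llvm", "Add", "myfloat") == "tvm.datatype.lower.llvm.Add.myfloat"
```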
diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc
index 5ed3ce4..c84f917 100644
--- a/src/target/datatype/registry.cc
+++ b/src/target/datatype/registry.cc
@@ -91,6 +91,13 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t
   return runtime::Registry::Get(ss.str());
 }
 
+const runtime::PackedFunc* GetMinFunc(uint8_t type_code) {
+  std::ostringstream ss;
+  ss << "tvm.datatype.min.";
+  ss << datatype::Registry::Global()->GetTypeName(type_code);
+  return runtime::Registry::Get(ss.str());
+}
+
 const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code) {
   std::ostringstream ss;
   ss << "tvm.datatype.lower.";
@@ -100,6 +107,18 @@ const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8
   return runtime::Registry::Get(ss.str());
 }
 
+const runtime::PackedFunc* GetIntrinLowerFunc(const std::string& target, const std::string& name,
+                                              uint8_t type_code) {
+  std::ostringstream ss;
+  ss << "tvm.datatype.lower.";
+  ss << target;
+  ss << ".Call.intrin.";
+  ss << name;
+  ss << ".";
+  ss << datatype::Registry::Global()->GetTypeName(type_code);
+  return runtime::Registry::Get(ss.str());
+}
+
 uint64_t ConvertConstScalar(uint8_t type_code, double value) {
   std::ostringstream ss;
   ss << "tvm.datatype.convertconstscalar.float.";
diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h
index 5df8ef8..01e7b82 100644
--- a/src/target/datatype/registry.h
+++ b/src/target/datatype/registry.h
@@ -43,6 +43,8 @@ namespace datatype {
  *      For Casts: tvm.datatype.lower.<target>.Cast.<type>.<src_type>
  *        Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from
  *                 float to myfloat.
+ * For intrinsic Calls: tvm.datatype.lower.<target>.Call.intrin.<name>.<type>
+ *             Example: tvm.datatype.lower.llvm.Call.intrin.sqrt.myfloat
  *  For other ops: tvm.datatype.lower.<target>.<op>.<type>
  *       Examples: tvm.datatype.lower.llvm.Add.myfloat
  *                 tvm.datatype.lower.llvm.FloatImm.posit
@@ -60,7 +62,7 @@ class Registry {
   * manually allocated by the user, and the user must ensure that no two custom types share the
   * same code. Generally, this should be straightforward, as the user will be manually registering
    * all of their custom types.
-   * \param type_name The name of the type, e.g. "bfloat"
+   * \param type_name The name of the type, e.g. "posites2"
   * \param type_code The type code, which should be greater than TVMArgTypeCode::kTVMExtEnd
    */
   void Register(const std::string& type_name, uint8_t type_code);
@@ -112,6 +114,13 @@ class Registry {
 uint64_t ConvertConstScalar(uint8_t type_code, double value);
 
 /*!
+ * \brief Get a function returning the minimum value for a datatype.
+ * \param type_code The datatype
+ * \return Function which takes the width of the datatype and returns the min value
+ */
+const runtime::PackedFunc* GetMinFunc(uint8_t type_code);
+
+/*!
  * \brief Get lowering function for Cast ops
  * \param target The target we are lowering to, e.g. "llvm"
  * \param type_code The datatype being cast to
@@ -130,6 +139,16 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t
 const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code);
 
 /*!
+ * \brief Get lowering function for intrinsic Calls/pure intrinsic Calls
+ * \param target The target we are lowering to, e.g. "llvm"
+ * \param type_code The datatype of the Call
+ * \param name The intrinsic name
+ * \return Lowering function for intrinsic Calls for the provided target and type
+ */
+const runtime::PackedFunc* GetIntrinLowerFunc(const std::string& target, const std::string& name,
+                                              uint8_t type_code);
+
+/*!
  * \brief Get lowering function for other ops
  * \param target The target we are lowering to, e.g. "llvm"
  * \param type_code The datatype of the op
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 6dc485f..6d94a08 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -32,6 +32,7 @@
 #include <cmath>
 // Centralized header for constant folders.
 #include "../../arith/const_fold.h"
+#include "../../target/datatype/registry.h"
 
 namespace tvm {
 
@@ -114,10 +115,14 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) {  // NOLINT(*)
   // require the types to be relatively consistent
   // This will reduce the amount of code generated by operators
   // and also help users find potential type conversion problems.
-  if (!lhs.dtype().is_float() && rhs.dtype().is_float()) {
+  if (!lhs.dtype().is_float() &&
+      (rhs.dtype().is_float() ||
+       datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) {
     // int->float
     lhs = cast(rhs.dtype(), lhs);
-  } else if (lhs.dtype().is_float() && !rhs.dtype().is_float()) {
+  } else if ((lhs.dtype().is_float() ||
+              datatype::Registry::Global()->GetTypeRegistered(lhs.dtype().code())) &&
+             !rhs.dtype().is_float()) {
     // int->float
     rhs = cast(lhs.dtype(), rhs);
   } else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) ||
@@ -174,7 +179,13 @@ PrimExpr max_value(const DataType& dtype) {
 PrimExpr min_value(const DataType& dtype) {
   using namespace tir;
   CHECK_EQ(dtype.lanes(), 1);
-  if (dtype.is_int()) {
+  if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) {
+    auto f = datatype::GetMinFunc(dtype.code());
+    CHECK(f) << "No minimum function registered for custom dtype " << (unsigned int)dtype.code();
+    // TODO(@hypercubestart) Document this change (and others associated with the overflowing
+    // floatimm min bug)
+    return (*f)(dtype.bits());
+  } else if (dtype.is_int()) {
     if (dtype.bits() == 64) {
       return IntImm(dtype, std::numeric_limits<int64_t>::lowest());
     } else if (dtype.bits() < 64) {
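The new `min_value` branch above consults a user-registered minimum function (per the commit message, a global function named ``tvm.datatype.min.<type name>`` that takes the bit width) before falling back to the built-in integer limits. A simplified standalone model of that dispatch (the names here are illustrative, not TVM's actual API):

```python
# Simplified model of the min_value dispatch added in op.cc above:
# custom type codes consult a user-registered minimum function first;
# built-in signed integers fall back to two's-complement limits.
_MIN_FUNCS = {}  # type_code -> callable taking the bit width


def register_min_func(type_code, func):
    """Register a minimum-value function for a custom type code."""
    _MIN_FUNCS[type_code] = func


def min_value(type_code, bits, is_int):
    if type_code in _MIN_FUNCS:
        # Custom datatype path: delegate to the registered function.
        return _MIN_FUNCS[type_code](bits)
    if is_int:
        # Built-in signed integer path: -2^(bits-1).
        return -(2 ** (bits - 1))
    raise NotImplementedError("float paths are elided from this sketch")
```
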
diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc
index ae9584f..a0faa17 100644
--- a/src/tir/transforms/lower_custom_datatypes.cc
+++ b/src/tir/transforms/lower_custom_datatypes.cc
@@ -23,6 +23,7 @@
 
 #include <tvm/runtime/registry.h>
 #include <tvm/target/target.h>
+#include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
@@ -50,7 +51,6 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
                        datatype::Registry::Global()->GetTypeRegistered(src_type_code);
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<CastNode>();
     if (toBeLowered) {
      auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
      CHECK(lower) << "Cast lowering function for target " << target_ << " destination type "
@@ -97,6 +97,22 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return expr;
   }
 
+  inline PrimExpr VisitExpr_(const CallNode* call) final {
+    bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(call->dtype.code());
+    PrimExpr expr = StmtExprMutator::VisitExpr_(call);
+    call = expr.as<CallNode>();
+    if (toBeLowered) {
+      auto op = call->op.as<OpNode>();
+      CHECK(op != nullptr) << "Lowering non-intrinsic Calls not implemented";
+      auto lower = datatype::GetIntrinLowerFunc(target_, op->name, call->dtype.code());
+      CHECK(lower) << "Intrinsic lowering function for target " << target_ << ", intrinsic name "
+                   << op->name << ", type " << static_cast<unsigned>(call->dtype.code())
+                   << " not found";
+      return (*lower)(expr);
+    }
+    return expr;
+  }
+
 #define DEFINE_MUTATE(OP, NodeName)                                              \
   inline PrimExpr VisitExpr_(const NodeName* op) final {                         \
     auto type_code = op->dtype.code();                                           \
diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py
new file mode 100644
index 0000000..337d703
--- /dev/null
+++ b/tests/python/unittest/test_custom_datatypes.py
@@ -0,0 +1,562 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for the Bring Your Own Datatype framework.
+
+TODO(@gussmith23 @hypercubestart) link to documentation"""
+import tvm
+import tvm.topi.testing
+import numpy as np
+import pytest
+from numpy.random import MT19937, RandomState, SeedSequence
+from tvm import relay
+from tvm.relay.testing.layers import batch_norm_infer
+from tvm.target.datatype import (
+    register,
+    register_min_func,
+    register_op,
+    create_lower_func,
+    lower_ite,
+    lower_call_pure_extern,
+    create_min_lower_func,
+)
+from tvm.tir.op import call_pure_extern
+
+# Note: we can't use relay.testing models because their params are randomly initialized,
+# which leads the outputs to all have roughly the same values.
+# We get the mobilenet model from Gluon CV instead
+# because: https://discuss.tvm.apache.org/t/mobilenet-intermediate-values-are-0/7812
+def get_mobilenet():
+    dshape = (1, 3, 224, 224)
+    from mxnet.gluon.model_zoo.vision import get_model
+
+    block = get_model("mobilenet0.25", pretrained=True)
+    shape_dict = {"data": dshape}
+    return relay.frontend.from_mxnet(block, shape_dict)
+
+
+# Use a real image instead of random data for the end-to-end model tests,
+# or else the outputs would all be around the same value.
+def get_cat_image(dimensions):
+    from tvm.contrib.download import download_testdata
+    from PIL import Image
+
+    url = "https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png"
+    dst = "cat.png"
+    real_dst = download_testdata(url, dst, module="data")
+    img = Image.open(real_dst).resize(dimensions)
+    # CoreML's standard model image format is BGR
+    img_bgr = np.array(img)[:, :, ::-1]
+    img = np.transpose(img_bgr, (2, 0, 1))[np.newaxis, :]
+    return np.asarray(img, dtype="float32")
+
+
+# we use a random seed to generate input_data
+# to guarantee stable tests
+rs = RandomState(MT19937(SeedSequence(123456789)))
+
+
+def convert_ndarray(dst_dtype, array):
+    """Converts NDArray(s) into the specified datatype"""
+    x = relay.var("x", shape=array.shape, dtype=str(array.dtype))
+    cast = relay.Function([x], x.astype(dst_dtype))
+    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+        return relay.create_executor("graph").evaluate(cast)(array)
+
+
+def change_dtype(src, dst, module, params):
+    """Convert constants and functions in module from src type to dst type.
+    Returns changed module and converted params of type dst_type.
+    """
+    module = relay.frontend.ChangeDatatype(src, dst)(module)
+    module = relay.transform.InferType()(module)
+    params = {k: convert_ndarray(dst, v) for k, v in params.items()}
+    return module, params
+
+
+def compare(module, input, src_dtype, dst_dtype, rtol, atol, params={}, target="llvm"):
+    module = relay.transform.SimplifyInference()(module)
+    ex = relay.create_executor("graph", mod=module)
+
+    correct = ex.evaluate()(*input, **params)
+    module, converted_params = change_dtype(src_dtype, dst_dtype, module, params)
+    ex = relay.create_executor("graph", mod=module, target=target)
+    # converts all inputs to dst_dtype
+    x_converted = [convert_ndarray(dst_dtype, arr) for arr in input]
+
+    # Vectorization is not implemented with custom datatypes
+    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+        maybe_correct = ex.evaluate()(*x_converted, **converted_params)
+        # currently this only works for comparing a single output
+        maybe_correct_converted = convert_ndarray(src_dtype, maybe_correct)
+    np.testing.assert_allclose(
+        maybe_correct_converted.asnumpy(), correct.asnumpy(), rtol=rtol, atol=atol
+    )
+
+
+def setup_myfloat():
+    """Set up tests for myfloat (a custom datatype that under the hood is 
float)
+
+    Currently, this registers some custom datatypes using the Bring Your
+    Own Datatypes framework.
+    """
+
+    # To use datatype operations in an external library, you should first load
+    # the library containing the datatype implementation:
+    # CDLL("libposit.so", RTLD_GLOBAL)
+    # In this case, the datatype library we are using is built right into TVM,
+    # so we do not need to explicitly load any library.
+
+    # You can pick a code for your datatype arbitrarily, as long as it is
+    # greater than 128 and has not already been chosen.
+    register("myfloat", 131)
+
+    register_op(
+        create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm", "float", "myfloat"
+    )
+    register_op(
+        create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm", "myfloat", "float"
+    )
+    register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm", "myfloat")
+    register_op(
+        create_lower_func(
+            {
+                32: "Custom32Sub",
+            }
+        ),
+        "Sub",
+        "llvm",
+        "myfloat",
+    )
+    register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm", "myfloat")
+    register_op(
+        create_lower_func(
+            {
+                32: "FloatToCustom32",
+            }
+        ),
+        "FloatImm",
+        "llvm",
+        "myfloat",
+    )
+    register_op(
+        create_lower_func(
+            {
+                32: "Custom32Div",
+            }
+        ),
+        "Div",
+        "llvm",
+        "myfloat",
+    )
+    register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm", "myfloat")
+    register_op(
+        create_lower_func({32: "Custom32Sqrt"}),
+        "Call",
+        "llvm",
+        "myfloat",
+        intrinsic_name="tir.sqrt",
+    )
+    register_op(
+        create_lower_func({32: "Custom32Exp"}), "Call", "llvm", "myfloat", intrinsic_name="tir.exp"
+    )
+    register_op(
+        create_lower_func({32: "Custom32Log"}), "Call", "llvm", "myfloat", intrinsic_name="tir.log"
+    )
+    register_op(
+        create_lower_func({32: "Custom32Sigmoid"}),
+        "Call",
+        "llvm",
+        "myfloat",
+        intrinsic_name="tir.sigmoid",
+    )
+    register_op(
+        create_lower_func({32: "Custom32Tanh"}),
+        "Call",
+        "llvm",
+        "myfloat",
+        intrinsic_name="tir.tanh",
+    )
+    register_op(lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else")
+    register_op(
+        lower_call_pure_extern, "Call", "llvm", "myfloat", intrinsic_name="tir.call_pure_extern"
+    )
+
+    register_min_func(create_min_lower_func({32: "MinCustom32"}, "myfloat"), "myfloat")
+
+
+def setup_posites2():
+    """Set up tests for posites2
+    Currently, this registers some custom datatypes using the Bring Your
+    Own Datatypes framework.
+    """
+
+    # To use datatype operations in an external library, you should first load
+    # the library containing the datatype implementation:
+    # CDLL("libposit.so", RTLD_GLOBAL)
+    # In this case, the datatype library we are using is built right into TVM,
+    # so we do not need to explicitly load any library.
+
+    # You can pick a code for your datatype arbitrarily, as long as it is
+    # greater than 128 and has not already been chosen.
+
+    register("posites2", 132)
+
+    register_op(
+        create_lower_func(
+            {
+                (32, 32): "FloatToPosit32es2",
+                (32, 16): "FloatToPosit16es2",
+                (32, 8): "FloatToPosit8es2",
+            }
+        ),
+        "Cast",
+        "llvm",
+        "float",
+        "posites2",
+    )
+    register_op(
+        create_lower_func(
+            {
+                (32, 32): "Posit32es2ToFloat",
+                (16, 32): "Posit16es2ToFloat",
+                (8, 32): "Posit8es2ToFloat",
+            }
+        ),
+        "Cast",
+        "llvm",
+        "posites2",
+        "float",
+    )
+    register_op(
+        create_lower_func({32: "Posit32es2Add", 16: "Posit16es2Add", 8: "Posit8es2Add"}),
+        "Add",
+        "llvm",
+        "posites2",
+    )
+    register_op(
+        create_lower_func({32: "Posit32es2Sub", 16: "Posit16es2Sub", 8: "Posit8es2Sub"}),
+        "Sub",
+        "llvm",
+        "posites2",
+    )
+    register_op(
+        create_lower_func(
+            {32: "FloatToPosit32es2", 16: "FloatToPosit16es2", 8: "FloatToPosit8es2"}
+        ),
+        "FloatImm",
+        "llvm",
+        "posites2",
+    )
+    register_op(
+        create_lower_func({32: "Posit32es2Mul", 16: "Posit16es2Mul", 8: "Posit8es2Mul"}),
+        "Mul",
+        "llvm",
+        "posites2",
+    )
+    register_op(
+        create_lower_func({32: "Posit32es2Div", 16: "Posit16es2Div", 8: "Posit8es2Div"}),
+        "Div",
+        "llvm",
+        "posites2",
+    )
+    register_op(
+        create_lower_func({32: "Posit32es2Max", 16: "Posit16es2Max", 8: "Posit8es2Max"}),
+        "Max",
+        "llvm",
+        "posites2",
+    )
+    register_op(
+        create_lower_func({32: "Posit32es2Sqrt", 16: "Posit16es2Sqrt", 8: "Posit8es2Sqrt"}),
+        "Call",
+        "llvm",
+        "posites2",
+        intrinsic_name="tir.sqrt",
+    )
+    register_op(lower_ite, "Call", "llvm", "posites2", intrinsic_name="tir.if_then_else")
+    register_op(
+        lower_call_pure_extern, "Call", "llvm", "posites2", intrinsic_name="tir.call_pure_extern"
+    )
+    register_op(
+        create_lower_func({32: "Posit32es2Exp", 16: "Posit16es2Exp", 8: "Posit8es2Exp"}),
+        "Call",
+        "llvm",
+        "posites2",
+        intrinsic_name="tir.exp",
+    )
+    register_op(
+        create_lower_func({32: "Posit32es2Log", 16: "Posit16es2Log", 8: "Posit8es2Log"}),
+        "Call",
+        "llvm",
+        "posites2",
+        intrinsic_name="tir.log",
+    )
+    register_op(
+        create_lower_func(
+            {32: "Posit32es2Sigmoid", 16: "Posit16es2Sigmoid", 8: "Posit8es2Sigmoid"}
+        ),
+        "Call",
+        "llvm",
+        "posites2",
+        intrinsic_name="tir.sigmoid",
+    )
+    register_op(
+        create_lower_func({32: "Posit32es2Tanh", 16: "Posit16es2Tanh", 8: "Posit8es2Tanh"}),
+        "Call",
+        "llvm",
+        "posites2",
+        intrinsic_name="tir.tanh",
+    )
+
+    register_min_func(
+        create_min_lower_func(
+            {32: "MinPosit32es2", 16: "MinPosit16es2", 8: "MinPosit8es2"}, "posites2"
+        ),
+        "posites2",
+    )
+
+
+def run_ops(src_dtype, dst_dtype, rtol=1e-7, atol=1e-7):
+    """Run the same op, but with two different datatypes"""
+    # used for unary ops, first shape in binary ops
+    shape1 = (5, 10, 5)
+    # second shape for binary ops
+    shape2 = (5,)
+
+    def check_unary_op(op, src_dtype, dst_dtype, shape):
+        t1 = relay.TensorType(shape, src_dtype)
+        x = relay.var("x", t1)
+        z = op(x)
+        x_data = rs.rand(*shape).astype(t1.dtype)
+
+        module = tvm.IRModule.from_expr(relay.Function([x], z))
+
+        compare(module, (x_data,), src_dtype, dst_dtype, rtol, atol)
+
+    # test unary ops
+    for op in [
+        relay.nn.softmax,
+        tvm.relay.log,
+        tvm.relay.exp,
+        tvm.relay.sqrt,
+        tvm.relay.rsqrt,
+        tvm.relay.sigmoid,
+        tvm.relay.tanh,
+        relay.nn.relu,
+        relay.nn.batch_flatten,
+    ]:
+        check_unary_op(op, src_dtype, dst_dtype, shape1)
+
+    # test unary ops over 4d data
+    for op in [relay.nn.max_pool2d, relay.nn.avg_pool2d, relay.nn.global_avg_pool2d]:
+        shape_2d = (3, 32, 32, 32)
+        check_unary_op(op, src_dtype, dst_dtype, shape_2d)
+
+    def check_binary_op(opfunc, src_dtype, dst_dtype):
+        t1 = relay.TensorType(shape1, src_dtype)
+        t2 = relay.TensorType(shape2, src_dtype)
+        x = relay.var("x", t1)
+        y = relay.var("y", t2)
+        z = opfunc(x, y)
+        x_data = rs.rand(*shape1).astype(t1.dtype)
+        y_data = rs.rand(*shape2).astype(t2.dtype)
+        module = tvm.IRModule.from_expr(relay.Function([x, y], z))
+
+        compare(module, (x_data, y_data), src_dtype, dst_dtype, rtol, atol)
+
+    for op in [
+        relay.add,
+        relay.subtract,
+        relay.divide,
+        relay.multiply,
+    ]:
+        check_binary_op(op, src_dtype, dst_dtype)
+
+    # We would like to test tvm_if_then_else directly, but Relay's IfNode is not
+    # lowered to this intrinsic, so to keep our tests consistent with Relay we skip the unit test.
+    # Note: tvm_if_then_else is exercised as part of the mobilenet model.
+
+
+def run_model(get_workload, input, src_dtype, dst_dtype, rtol=1e-4, atol=1e-4):
+    module, params = get_workload()
+
+    # we don't generate random data here
+    # because then the output data would all be around the same value
+    compare(module, input, src_dtype, dst_dtype, rtol, atol, params)
+
+
+def run_conv2d(src_dtype, dst_dtype, rtol=1e-7, atol=1e-4):
+    def run_test_conv2d(
+        src_dtype,
+        dst_dtype,
+        scale,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        groups=1,
+        dilation=(1, 1),
+        **attrs,
+    ):
+        x = relay.var("x", shape=dshape, dtype=src_dtype)
+        w = relay.var("w", shape=kshape, dtype=src_dtype)
+        y = relay.nn.conv2d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs)
+        module = tvm.IRModule.from_expr(relay.Function([x, w], y))
+        data = rs.uniform(-scale, scale, size=dshape).astype(src_dtype)
+        kernel = rs.uniform(-scale, scale, size=kshape).astype(src_dtype)
+
+        compare(module, (data, kernel), src_dtype, dst_dtype, rtol, atol)
+
+    # depthwise conv2d
+    dshape = (1, 32, 18, 18)
+    kshape = (32, 1, 3, 3)
+    run_test_conv2d(
+        src_dtype,
+        dst_dtype,
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        channels=32,
+        groups=32,
+        kernel_size=(3, 3),
+    )
+
+    # CUDA is disabled for 'direct' schedule:
+    # https://github.com/dmlc/tvm/pull/3070#issuecomment-486597553
+    # group conv2d
+    dshape = (1, 32, 18, 18)
+    kshape = (32, 4, 3, 3)
+    run_test_conv2d(
+        src_dtype,
+        dst_dtype,
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        channels=32,
+        groups=8,
+        kernel_size=(3, 3),
+    )
+    # also group conv2d
+    dshape = (1, 32, 18, 18)
+    kshape = (64, 1, 3, 3)
+    run_test_conv2d(
+        src_dtype,
+        dst_dtype,
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        channels=64,
+        groups=32,
+        kernel_size=(3, 3),
+    )
+
+    # normal conv2d
+    dshape = (1, 3, 224, 224)
+    kshape = (10, 3, 3, 3)
+    run_test_conv2d(
+        src_dtype, dst_dtype, 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=(3, 3)
+    )
+
+    # dilated conv2d
+    dshape = (1, 3, 18, 18)
+    kshape = (10, 3, 3, 3)
+    run_test_conv2d(
+        src_dtype,
+        dst_dtype,
+        1,
+        dshape,
+        kshape,
+        padding=(1, 1),
+        channels=10,
+        kernel_size=(3, 3),
+        dilation=(3, 3),
+    )
+
+
+def run_batchnorm(src_dtype, dst_dtype, rtol=1e-6, atol=1e-6):
+    shape = (3, 32, 32)
+    t = relay.TensorType(shape, src_dtype)
+    x = relay.var("x", t)
+    bn = batch_norm_infer(data=x, epsilon=2e-5, scale=False, name="bn_x")
+    f = relay.Function(relay.analysis.free_vars(bn), bn)
+
+    x_data = rs.rand(*shape).astype(t.dtype)
+    module = tvm.IRModule.from_expr(f)
+
+    zero_data = np.zeros((32), "float32")
+    compare(
+        module,
+        (x_data, zero_data, zero_data, zero_data, zero_data),
+        src_dtype,
+        dst_dtype,
+        rtol,
+        atol,
+    )
+
+
+def test_myfloat():
+    setup_myfloat()
+    run_ops("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
+    run_conv2d("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
+    run_batchnorm("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
+
+    # mxnet python package not available
+    # run_model(get_mobilenet, (get_cat_image((224, 224)), ),
+    #           'float32',
+    #           'custom[myfloat]32')
+
+
+def _has_posit():
+    return tvm.support.libinfo()["USE_BYOC_POSIT"] == "ON"
+
+
[email protected](not _has_posit(), reason="compiled with USE_BYOC_POSIT flag OFF")
+def test_posites2():
+    setup_posites2()
+    run_ops("float32", "custom[posites2]8", rtol=1, atol=1)
+    run_ops("float32", "custom[posites2]16", rtol=0.01, atol=1)
+    run_ops("float32", "custom[posites2]32", rtol=1e-6, atol=1e-6)
+
+    run_conv2d("float32", "custom[posites2]8", rtol=1, atol=1)
+    run_conv2d("float32", "custom[posites2]16", rtol=0.01, atol=1)
+    run_conv2d("float32", "custom[posites2]32")
+
+    run_batchnorm("float32", "custom[posites2]8", rtol=1, atol=1)
+    run_batchnorm("float32", "custom[posites2]16", rtol=0.01, atol=1)
+    run_batchnorm("float32", "custom[posites2]32", rtol=1e-4, atol=1e-4)
+    # We expected posit8 to be faster, but it's not.
+    # run_model(get_mobilenet, (get_cat_image((224, 224)), ), 'float32', 'custom[posit8]8')
+    # run_model(get_mobilenet, (get_cat_image((224, 224)), ), 'float32', 'custom[posit32]32')
+    # run_model(get_inception, (get_cat_image((229, 229)), ), 'float32', 'custom[posit32]32')
+    # run_model(get_resnet, (get_cat_image((224, 224)), ), 'float32', 'custom[posit32]32')
+
+    # can't run cifar-10 sizes because dimensions
+    # don't match pretrained weights
+
+    # runs on the order of minutes...
+    # run_model(get_inception, (get_cat_image((229, 229)), ),
+    #           'float32',
+    #           'custom[posites2]32')
+    # run_model(get_resnet, (get_cat_image((224, 224)), ),
+    #           'float32',
+    #           'custom[posites2]32')
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
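The dict arguments passed to ``create_lower_func`` throughout these tests map operand bit-widths (or ``(src_bits, dest_bits)`` pairs for casts) to extern function names. A minimal standalone sketch of that key scheme (``pick_extern_symbol`` is a hypothetical helper for illustration, not part of TVM):

```python
def pick_extern_symbol(extern_map, bits, src_bits=None):
    # Casts are keyed by a (src_bits, dest_bits) pair; all other ops are keyed
    # by the operand bit-width alone, mirroring the dicts the tests pass to
    # create_lower_func.
    key = (src_bits, bits) if src_bits is not None else bits
    try:
        return extern_map[key]
    except KeyError:
        raise KeyError("no extern function registered for key %r" % (key,))
```
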
diff --git a/tests/python/unittest/test_target_custom_datatypes.py b/tests/python/unittest/test_target_custom_datatypes.py
deleted file mode 100644
index 9b2b85d..0000000
--- a/tests/python/unittest/test_target_custom_datatypes.py
+++ /dev/null
@@ -1,154 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import tvm
-from tvm import te
-from ctypes import *
-from tvm import topi
-import numpy as np
-
-tgt = "llvm"
-
-
-def setup_module():
-    # You must first load the library containing the datatype implementation.
-    # In this case, we have built the test functions used below right into TVM.
-    # CDLL("libmybfloat16.so", RTLD_GLOBAL)
-
-    tvm.target.datatype.register("bfloat", 129)
-
-    tvm.target.datatype.register_op(
-        tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"),
-        "Cast",
-        "llvm",
-        "bfloat",
-        "float",
-    )
-    tvm.target.datatype.register_op(
-        tvm.target.datatype.create_lower_func("BFloat16ToFloat_wrapper"),
-        "Cast",
-        "llvm",
-        "float",
-        "bfloat",
-    )
-    tvm.target.datatype.register_op(
-        tvm.target.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", 
"llvm", "bfloat"
-    )
-    tvm.target.datatype.register_op(
-        tvm.target.datatype.create_lower_func("FloatToBFloat16_wrapper"),
-        "FloatImm",
-        "llvm",
-        "bfloat",
-    )
-
-
-def lower_datatypes_and_build(schedule, args):
-    """Create schedule and lower, manually lowering datatypes.
-
-    Once datatype lowering is integrated directly into TVM's lower/build
-    process, we won't need to do this manually.
-    TODO(gus) integrate datatype lowering into build process; change this test"""
-    mod = tvm.lower(schedule, args)
-    target = tvm.target.Target(tgt)
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
-    mod = tvm.tir.transform.LowerCustomDatatypes()(mod)
-    return tvm.build(mod, target=tgt)
-
-
-def test_bfloat_add_and_cast_1():
-    X = te.placeholder((3,), name="X")
-    Y = te.placeholder((3,), name="Y")
-    Z = topi.cast(
-        topi.cast(X, dtype="custom[bfloat]16") + topi.cast(Y, 
dtype="custom[bfloat]16"),
-        dtype="float",
-    )
-
-    s = te.create_schedule([Z.op])
-    built_cast = lower_datatypes_and_build(s, [X, Y, Z])
-
-    ctx = tvm.context(tgt, 0)
-
-    # Used float32 calculator at http://www.weitz.de/ieee/. Generated float32s
-    # with at most 7-bit mantissas which, when added, produce a result with at
-    # most 7-bit mantissas. This is to ensure there are no errors due to
-    # float32->bfloat16 conversions.
-    x = tvm.nd.array(np.array([4.4103796e-32, 14942208.0, 1.78125]).astype("float32"), ctx=ctx)
-    y = tvm.nd.array(np.array([-3.330669e-14, 19660800.0, 2.25]).astype("float32"), ctx=ctx)
-    z_expected = np.array([-3.330669e-14, 34603008.0, 4.03125]).astype("float32")
-    z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx)
-
-    built_cast(x, y, z)
-
-    assert np.array_equal(z_expected, z.asnumpy())
-
-
-def test_bfloat_add_and_cast_2():
-    X = te.placeholder((3,), name="X")
-    Y = te.placeholder((3,), name="Y")
-    Z = topi.cast(
-        topi.cast(X, dtype="custom[bfloat]16") + topi.cast(Y, 
dtype="custom[bfloat]16"),
-        dtype="float",
-    )
-
-    s = te.create_schedule([Z.op])
-    built_cast = lower_datatypes_and_build(s, [X, Y, Z])
-
-    ctx = tvm.context(tgt, 0)
-
-    # Used float32 calculator at http://www.weitz.de/ieee/. Generated
-    # unconstrained float32s for the operands and copied them in to x and y.
-# Then, to simulate float32->bfloat16 conversion implemented by the mybfloat
-    # library, I cut off all but 7 bits of the mantissa. I then added the
-    # numbers. To simulate bfloat16 add implemented in mybfloat, I cut off all
-    # but 7 bits of the result's mantissa. I then copied that value into
-    # z_expected.
-    x = tvm.nd.array(np.array([1.2348297, -1.0298302e25, 1.2034023e-30]).astype("float32"), ctx=ctx)
-    y = tvm.nd.array(np.array([-2.4992788, -9.888288e19, 9.342338e-29]).astype("float32"), ctx=ctx)
-    z_expected = np.array([-1.25, -1.027587e25, 9.426888e-29]).astype("float32")
-    z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx)
-
-    built_cast(x, y, z)
-
-    assert np.array_equal(z_expected, z.asnumpy())
-
-
-def test_bfloat_add_and_cast_FloatImm():
-    X = te.placeholder((3,), name="X")
-    Z = topi.cast(
-        topi.add(topi.cast(X, dtype="custom[bfloat]16"), 
tvm.tir.FloatImm("custom[bfloat]16", 1.5)),
-        dtype="float",
-    )
-
-    s = te.create_schedule([Z.op])
-    built_cast = lower_datatypes_and_build(s, [X, Z])
-
-    ctx = tvm.context(tgt, 0)
-
-    x = tvm.nd.array(np.array([0.0, 1.0, 1.5]).astype("float32"), ctx=ctx)
-    z_expected = np.array([1.5, 2.5, 3.0]).astype("float32")
-    z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx)
-
-    built_cast(x, z)
-
-    assert np.array_equal(z_expected, z.asnumpy())
-
-
-if __name__ == "__main__":
-    setup_module()
-    test_bfloat_add_and_cast_1()
-    test_bfloat_add_and_cast_2()
-    test_bfloat_add_and_cast_FloatImm()
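The removed tests above compute their expected values by cutting a float32's mantissa down to 7 bits to mimic the mybfloat library's float32-to-bfloat16 conversion. That trick can be reproduced in plain Python by zeroing the low 16 bits of the float's bit pattern (a simplified model: pure truncation, no rounding):

```python
import struct


def float32_to_bfloat16_trunc(x):
    # Round-trip through the raw float32 bit pattern and zero the low 16 bits,
    # leaving the sign, the 8 exponent bits, and the top 7 mantissa bits --
    # the bfloat16 layout the removed tests simulate by hand.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


# Reproduce the first element of test_bfloat_add_and_cast_2's expectation:
# truncate both operands, add, and truncate the sum.
a = float32_to_bfloat16_trunc(1.2348297)
b = float32_to_bfloat16_trunc(-2.4992788)
result = float32_to_bfloat16_trunc(a + b)
```
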
diff --git a/tutorials/dev/bring_your_own_datatypes.py b/tutorials/dev/bring_your_own_datatypes.py
new file mode 100644
index 0000000..cbb1b99
--- /dev/null
+++ b/tutorials/dev/bring_your_own_datatypes.py
@@ -0,0 +1,408 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Bring Your Own Datatypes to TVM
+===============================
+**Authors**: `Gus Smith <https://github.com/gussmith23>`_, `Andrew Liu <https://github.com/hypercubestart>`_
+
+In this tutorial, we will show you how to use the Bring Your Own Datatypes framework to bring your own custom datatypes into TVM.
+Note that the Bring Your Own Datatypes framework currently handles only **software-emulated versions of datatypes**.
+The framework does not support compiling for custom accelerator datatypes out of the box.
+
+Datatype Libraries
+------------------
+
+The Bring Your Own Datatypes framework allows users to register their own datatype implementations alongside TVM's native datatypes (such as ``float``).
+In the wild, these datatype implementations often appear as libraries.
+For example:
+
+- `libposit <https://github.com/cjdelisle/libposit>`_, a posit library
+- `Stillwater Universal <https://github.com/stillwater-sc/universal>`_, a library with posits, fixed-point numbers, and other types
+- `SoftFloat <https://github.com/ucb-bar/berkeley-softfloat-3>`_, Berkeley's software implementation of IEEE 754 floating-point
+
+The Bring Your Own Datatypes framework enables users to plug these datatype implementations into TVM!
+
+In this section, we will use an example library we have already implemented, located at ``3rdparty/byodt/myfloat.cc``.
+This datatype, which we dubbed "myfloat", is really just an IEEE 754 float under the hood, but it serves as a useful example
+to show that any datatype can be used in the BYODT framework.
+
+Setup
+-----
+
+Since the example datatype library is built into TVM itself, there is no setup needed.
+
+If you would like to try this with your own datatype library, first bring the library's functions into the process space with ``CDLL``:
+
+.. code-block :: python
+
+    ctypes.CDLL('my-datatype-lib.so', ctypes.RTLD_GLOBAL)
+"""
+
+######################
+# A Simple TVM Program
+# --------------------
+#
+# We'll begin by writing a simple program in TVM; afterwards, we will re-write it to use custom datatypes.
+import tvm
+from tvm import relay
+
+# Our basic program: Z = X + Y
+x = relay.var("x", shape=(3,), dtype="float32")
+y = relay.var("y", shape=(3,), dtype="float32")
+z = x + y
+program = relay.Function([x, y], z)
+module = tvm.IRModule.from_expr(program)
+
+######################################################################
+# Now, we create random inputs to feed into this program using numpy:
+
+import numpy as np
+
+np.random.seed(23)  # for reproducibility
+
+x_input = np.random.rand(3).astype("float32")
+y_input = np.random.rand(3).astype("float32")
+print("x: {}".format(x_input))
+print("y: {}".format(y_input))
+
+######################################################################
+# Finally, we're ready to run the program:
+
+ex = relay.create_executor(mod=module)
+
+z_output = ex.evaluate()(x_input, y_input)
+print("z: {}".format(z_output))
+
+######################################################################
+# Adding Custom Datatypes
+# -----------------------
+# Now, we will do the same, but we will use a custom datatype for our intermediate computation.
+#
+# We use the same input variables ``x`` and ``y`` as above, but before adding ``x + y``, we first cast both ``x`` and ``y`` to a custom datatype via the ``relay.cast(...)`` call.
+#
+# Note how we specify the custom datatype: we indicate it using the special ``custom[...]`` syntax.
+# Additionally, note the "32" after the datatype: this is the bitwidth of the custom datatype. This tells TVM that each instance of ``myfloat`` is 32 bits wide.
+
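As an aside, the ``custom[...]`` string itself has a simple shape. The sketch below is purely illustrative (it is not TVM's actual parser, and the helper name is hypothetical), but it shows how the type name and bitwidth are encoded:

```python
import re

# Illustrative only: decompose a "custom[<name>]<bits>" dtype string.
def parse_custom_dtype(dtype):
    m = re.fullmatch(r"custom\[(\w+)\](\d+)", dtype)
    if m is None:
        raise ValueError("not a custom dtype: {}".format(dtype))
    return m.group(1), int(m.group(2))

print(parse_custom_dtype("custom[myfloat]32"))  # ('myfloat', 32)
```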
+try:
+    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+        x_myfloat = relay.cast(x, dtype="custom[myfloat]32")
+        y_myfloat = relay.cast(y, dtype="custom[myfloat]32")
+        z_myfloat = x_myfloat + y_myfloat
+        z = relay.cast(z_myfloat, dtype="float32")
+except tvm.TVMError as e:
+    # Print last line of error
+    print(str(e).split("\n")[-1])
+
+######################################################################
+# Trying to generate this program throws an error from TVM.
+# TVM does not know how to handle any custom datatype out of the box!
+# We first have to register the custom type with TVM, giving it a name and a type code:
+
+tvm.target.datatype.register("myfloat", 150)
+
+######################################################################
+# Note that the type code, 150, is currently chosen manually by the user.
+# See ``TVMTypeCode::kCustomBegin`` in `include/tvm/runtime/data_type.h <https://github.com/apache/incubator-tvm/blob/master/include/tvm/runtime/data_type.h>`_.
+# Now we can generate our program again:
+
+x_myfloat = relay.cast(x, dtype="custom[myfloat]32")
+y_myfloat = relay.cast(y, dtype="custom[myfloat]32")
+z_myfloat = x_myfloat + y_myfloat
+z = relay.cast(z_myfloat, dtype="float32")
+program = relay.Function([x, y], z)
+module = tvm.IRModule.from_expr(program)
+
+######################################################################
+# Now we have a Relay program that uses myfloat!
+print(program)
+
+######################################################################
+# Now that we can express our program without errors, let's try running it!
+try:
+    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+        ex = relay.create_executor("graph", mod=module)
+        z_output_myfloat = ex.evaluate()(x_input, y_input)
+        print("z: {}".format(z_output_myfloat))
+except tvm.TVMError as e:
+    # Print last line of error
+    print(str(e).split("\n")[-1])
+
+######################################################################
+# Now, trying to compile this program throws an error.
+# Let's dissect this error.
+#
+# The error is occurring during the process of lowering the custom datatype code to code that TVM can compile and run.
+# TVM is telling us that it cannot find a *lowering function* for the ``Cast`` operation, when casting from source type 2 (``float``, in TVM) to destination type 150 (our custom datatype).
+# When lowering custom datatypes, if TVM encounters an operation over a custom datatype, it looks for a user-registered *lowering function*, which tells it how to lower the operation to an operation over datatypes it understands.
+# We have not told TVM how to lower ``Cast`` operations for our custom datatypes; hence the error.
+#
+# To fix this error, we simply need to specify a lowering function:
+
+tvm.target.datatype.register_op(
+    tvm.target.datatype.create_lower_func(
+        {
+            (32, 32): "FloatToCustom32",  # cast from float32 to myfloat32
+        }
+    ),
+    "Cast",
+    "llvm",
+    "float",
+    "myfloat",
+)
+
+######################################################################
+# The ``register_op(...)`` call takes a lowering function, and a number of parameters which specify exactly the operation which should be lowered with the provided lowering function.
+# In this case, the arguments we pass specify that this lowering function is for lowering a ``Cast`` from ``float`` to ``myfloat`` for target ``"llvm"``.
+#
+# The lowering function passed into this call is very general: it should take an operation of the specified type (in this case, ``Cast``) and return another operation which only uses datatypes which TVM understands.
+#
+# In the general case, we expect users to implement operations over their custom datatypes using calls to an external library.
+# In our example, our ``myfloat`` library implements a ``Cast`` from ``float`` to 32-bit ``myfloat`` in the function ``FloatToCustom32``.
+# To provide for the general case, we have made a helper function, ``create_lower_func(...)``, which does just this: given a dictionary, it replaces the given operation with a ``Call`` to the appropriate function name, chosen based on the op and the bit widths.
+# It additionally removes usages of the custom datatype by storing the custom datatype in an opaque ``uint`` of the appropriate width; in our case, a ``uint32_t``.
+# For more information, see `the source code <https://github.com/apache/incubator-tvm/blob/master/python/tvm/target/datatype.py>`_.
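To make the bookkeeping concrete, here is a pure-Python sketch (the helper is hypothetical, not part of the TVM API) of the table lookup that ``create_lower_func`` performs when choosing which extern function to call:

```python
# Illustrative sketch of the name lookup inside create_lower_func:
# Cast-like ops are keyed by a (src_bits, dest_bits) tuple; other ops
# are keyed by a single bit width (see the registrations in this tutorial).
def pick_extern_func(table, key):
    if key not in table:
        raise KeyError("no lowering function registered for {}".format(key))
    return table[key]

cast_table = {(32, 32): "FloatToCustom32"}
add_table = {32: "Custom32Add"}

print(pick_extern_func(cast_table, (32, 32)))  # FloatToCustom32
print(pick_extern_func(add_table, 32))         # Custom32Add
```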
+
+# We can now re-try running the program:
+try:
+    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+        ex = relay.create_executor("graph", mod=module)
+        z_output_myfloat = ex.evaluate()(x_input, y_input)
+        print("z: {}".format(z_output_myfloat))
+except tvm.TVMError as e:
+    # Print last line of error
+    print(str(e).split("\n")[-1])
+
+######################################################################
+# This new error tells us that the ``Add`` lowering function is not found, which is good news, as it's no longer complaining about the ``Cast``!
+# We know what to do from here: we just need to register the lowering functions for the other operations in our program.
+#
+# Note that for ``Add``, ``create_lower_func`` takes in a dict where the key is an integer.
+# For ``Cast`` operations, we require a 2-tuple to specify the ``src_bit_length`` and the ``dest_bit_length``, while for all other operations, the bit length is the same between the operands, so we only require one integer to specify ``bit_length``.
+tvm.target.datatype.register_op(
+    tvm.target.datatype.create_lower_func({32: "Custom32Add"}),
+    "Add",
+    "llvm",
+    "myfloat",
+)
+tvm.target.datatype.register_op(
+    tvm.target.datatype.create_lower_func({(32, 32): "Custom32ToFloat"}),
+    "Cast",
+    "llvm",
+    "myfloat",
+    "float",
+)
+
+# Now, we can run our program without errors.
+with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+    compiled = ex.evaluate(program)
+    z_output_myfloat = compiled(x_input, y_input)
+print("z: {}".format(z_output_myfloat))
+
+print("x:\t\t{}".format(x_input))
+print("y:\t\t{}".format(y_input))
+print("z (float32):\t{}".format(z_output))
+print("z (myfloat32):\t{}".format(z_output_myfloat))
+
+# Perhaps as expected, the ``myfloat32`` results and the ``float32`` results are exactly the same!
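This makes sense given how the example library stores values: a float32 tucked into an opaque ``uint32``, as described above. The round trip through that storage representation is lossless, which a short pure-Python sketch (using ``struct``, independent of TVM) can demonstrate:

```python
import struct

# Reinterpret a float32's bits as a uint32 (the opaque storage type)
def float_to_custom32(x):
    return struct.unpack("<I", struct.pack("<f", x))[0]

# Reinterpret the uint32 bits back into a float32
def custom32_to_float(bits):
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(custom32_to_float(float_to_custom32(1.5)))  # 1.5
```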
+
+######################################################################
+# Running Models With Custom Datatypes
+# ------------------------------------
+#
+# We will first choose the model which we would like to run with myfloat.
+# In this case we use `Mobilenet <https://arxiv.org/abs/1704.04861>`_.
+# We choose Mobilenet due to its small size.
+# In this alpha state of the Bring Your Own Datatypes framework, we have not implemented any software optimizations for running software emulations of custom datatypes; the result is poor performance due to many calls into our datatype emulation library.
+#
+# First let us define two helper functions to get the mobilenet model and a cat image.
+
+
+def get_mobilenet():
+    dshape = (1, 3, 224, 224)
+    from mxnet.gluon.model_zoo.vision import get_model
+
+    block = get_model("mobilenet0.25", pretrained=True)
+    shape_dict = {"data": dshape}
+    return relay.frontend.from_mxnet(block, shape_dict)
+
+
+def get_cat_image():
+    from tvm.contrib.download import download_testdata
+    from PIL import Image
+
+    url = "https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png"
+    dst = "cat.png"
+    real_dst = download_testdata(url, dst, module="data")
+    img = Image.open(real_dst).resize((224, 224))
+    # CoreML's standard model image format is BGR
+    img_bgr = np.array(img)[:, :, ::-1]
+    img = np.transpose(img_bgr, (2, 0, 1))[np.newaxis, :]
+    return np.asarray(img, dtype="float32")
+
+
+module, params = get_mobilenet()
+
+######################################################################
+# It's easy to execute MobileNet with native TVM:
+
+ex = tvm.relay.create_executor("graph", mod=module)
+input = get_cat_image()
+result = ex.evaluate()(input, **params).asnumpy()
+# print first 10 elements
+print(result.flatten()[:10])
+
+######################################################################
+# Now, we would like to change the model to use myfloat internally. To do so, we need to convert the network. We first define a function which will help us convert tensors:
+
+
+def convert_ndarray(dst_dtype, array):
+    """Converts an NDArray into the specified datatype"""
+    x = relay.var("x", shape=array.shape, dtype=str(array.dtype))
+    cast = relay.Function([x], x.astype(dst_dtype))
+    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+        return relay.create_executor("graph").evaluate(cast)(array)
+
+
+######################################################################
+# Now, to actually convert the entire network, we have written `a pass in Relay <https://github.com/gussmith23/tvm/blob/ea174c01c54a2529e19ca71e125f5884e728da6e/python/tvm/relay/frontend/change_datatype.py#L21>`_ which simply converts all nodes within the model to use the new datatype.
+
+from tvm.relay.frontend.change_datatype import ChangeDatatype
+
+src_dtype = "float32"
+dst_dtype = "custom[myfloat]32"
+
+# Currently, custom datatypes only work if you run simplify_inference beforehand
+module = tvm.relay.transform.SimplifyInference()(module)
+
+# Run type inference before changing datatype
+module = tvm.relay.transform.InferType()(module)
+
+# Change datatype from float to myfloat and re-infer types
+cdtype = ChangeDatatype(src_dtype, dst_dtype)
+expr = cdtype.visit(module["main"])
+module = tvm.relay.transform.InferType()(module)
+
+# We also convert the parameters:
+params = {k: convert_ndarray(dst_dtype, v) for k, v in params.items()}
+
+# We also need to convert our input:
+input = convert_ndarray(dst_dtype, input)
+
+# Finally, we can try to run the converted model:
+try:
+    # Vectorization is not implemented with custom datatypes.
+    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+        result_myfloat = ex.evaluate(expr)(input, **params)
+except tvm.TVMError as e:
+    print(str(e).split("\n")[-1])
+
+######################################################################
+# When we attempt to run the model, we get a familiar error telling us that more functions need to be registered for myfloat.
+#
+# Because this is a neural network, many more operations are required.
+# Here, we register all the needed functions:
+
+tvm.target.datatype.register_op(
+    tvm.target.datatype.create_lower_func({32: "FloatToCustom32"}),
+    "FloatImm",
+    "llvm",
+    "myfloat",
+)
+
+tvm.target.datatype.register_op(
+    tvm.target.datatype.lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else"
+)
+
+tvm.target.datatype.register_op(
+    tvm.target.datatype.lower_call_pure_extern,
+    "Call",
+    "llvm",
+    "myfloat",
+    intrinsic_name="tir.call_pure_extern",
+)
+
+tvm.target.datatype.register_op(
+    tvm.target.datatype.create_lower_func({32: "Custom32Mul"}),
+    "Mul",
+    "llvm",
+    "myfloat",
+)
+tvm.target.datatype.register_op(
+    tvm.target.datatype.create_lower_func({32: "Custom32Div"}),
+    "Div",
+    "llvm",
+    "myfloat",
+)
+
+tvm.target.datatype.register_op(
+    tvm.target.datatype.create_lower_func({32: "Custom32Sqrt"}),
+    "Call",
+    "llvm",
+    "myfloat",
+    intrinsic_name="tir.sqrt",
+)
+
+tvm.target.datatype.register_op(
+    tvm.target.datatype.create_lower_func({32: "Custom32Sub"}),
+    "Sub",
+    "llvm",
+    "myfloat",
+)
+
+tvm.target.datatype.register_op(
+    tvm.target.datatype.create_lower_func({32: "Custom32Exp"}),
+    "Call",
+    "llvm",
+    "myfloat",
+    intrinsic_name="tir.exp",
+)
+
+tvm.target.datatype.register_op(
+    tvm.target.datatype.create_lower_func({32: "Custom32Max"}),
+    "Max",
+    "llvm",
+    "myfloat",
+)
+
+tvm.target.datatype.register_min_func(
+    tvm.target.datatype.create_min_lower_func({32: "MinCustom32"}, "myfloat"),
+    "myfloat",
+)
+
+######################################################################
+# Note we are making use of two new functions: ``register_min_func`` and ``create_min_lower_func``.
+#
+# ``register_min_func`` takes in a function which, given an integer ``num_bits`` for the bit length, returns an operation representing the minimum finite representable value for the custom datatype with the specified bit length.
+#
+# Similar to ``register_op`` and ``create_lower_func``, ``create_min_lower_func`` handles the general case where the minimum representable custom datatype value is implemented using calls to an external library.
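For intuition: since our example ``myfloat`` is a plain float32 under the hood, the value a function like ``MinCustom32`` would return is simply the minimum finite float32. A pure-Python sketch (illustrative only, no TVM required; assumes the library's min function behaves this way):

```python
import struct

# Minimum finite float32: sign bit set, largest non-infinite exponent
# (0xFE), all-ones mantissa -> -(2 - 2**-23) * 2**127
min_f32 = struct.unpack("<f", struct.pack("<I", 0xFF7FFFFF))[0]
print(min_f32)  # -3.4028234663852886e+38
```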
+#
+# Now we can finally run the model:
+
+# Vectorization is not implemented with custom datatypes.
+with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+    result_myfloat = ex.evaluate(expr)(input, **params)
+    result_myfloat = convert_ndarray(src_dtype, result_myfloat).asnumpy()
+    # print first 10 elements
+    print(result_myfloat.flatten()[:10])
+
+# Again, note that the output using 32-bit myfloat is exactly the same as the output using 32-bit floats, because myfloat is exactly a float!
+np.testing.assert_array_equal(result, result_myfloat)
