This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e43555f739 [REFACTOR][DataType] Phase out target custom datatype support (#19760)
e43555f739 is described below
commit e43555f739474e318f1cfc15bd7a1951180d304b
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jun 14 09:11:13 2026 -0400
[REFACTOR][DataType] Phase out target custom datatype support (#19760)
## Summary
The in-tree target custom datatype (BYODT, "Bring Your Own Datatypes") path
adds maintenance burden while current development focuses on core datatypes.
This PR phases out the built-in registry and lowering implementation while
keeping core dtype behavior intact.
- Remove the target/datatype implementation, the BYODT posit build option,
  and the related Python helpers
- Remove the custom datatype lowering pass from the TIRX and S-TIR
  finalization pipelines
- Simplify the remaining TIRX dtype handling back to built-in/core datatypes
- Drop the BYODT dependency (Stillwater Universal library) install from the
  CPU and GPU CI Docker images
---
CMakeLists.txt | 6 -
cmake/modules/contrib/Posit.cmake | 26 --
docker/Dockerfile.ci_cpu | 4 -
docker/Dockerfile.ci_gpu | 4 -
include/tvm/tirx/op.h | 7 -
include/tvm/tirx/transform.h | 9 -
python/tvm/s_tir/pipeline.py | 2 -
python/tvm/target/__init__.py | 1 -
python/tvm/target/datatype.py | 379 ---------------------------
python/tvm/tirx/compilation_pipeline.py | 2 -
python/tvm/tirx/transform/transform.py | 13 -
src/arith/rewrite_simplify.cc | 1 -
src/target/datatype/myfloat/myfloat.cc | 144 ----------
src/target/datatype/posit/posit-wrapper.cc | 242 -----------------
src/target/datatype/registry.cc | 138 ----------
src/target/datatype/registry.h | 182 -------------
src/target/llvm/codegen_llvm.cc | 7 +-
src/tirx/op/op.cc | 25 +-
src/tirx/transform/lower_custom_datatypes.cc | 266 -------------------
19 files changed, 8 insertions(+), 1450 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ad99c4c6ac..0eb8c90184 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -87,7 +87,6 @@ tvm_option(USE_CCACHE "Use ccache if found when invoking
compiler" AUTO)
# 3rdparty libraries
tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt")
# Contrib library options
-tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom
datatype" OFF)
tvm_option(USE_BLAS "The blas library to be linked" none)
tvm_option(USE_AMX "Enable Intel AMX" OFF)
tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
@@ -356,10 +355,6 @@ tvm_file_glob(GLOB CODEGEN_SRCS
list(APPEND COMPILER_SRCS ${CODEGEN_SRCS})
-tvm_file_glob(GLOB DATATYPE_SRCS src/target/datatype/*.cc)
-list(APPEND COMPILER_SRCS ${DATATYPE_SRCS})
-list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")
-
tvm_file_glob(GLOB RUNTIME_SRCS
src/runtime/*.cc
src/runtime/vm/*.cc
@@ -464,7 +459,6 @@ include(cmake/modules/contrib/DNNL.cmake)
include(cmake/modules/contrib/AMX.cmake)
include(cmake/modules/contrib/CUTLASS.cmake)
include(cmake/modules/contrib/Random.cmake)
-include(cmake/modules/contrib/Posit.cmake)
include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/CoreML.cmake)
include(cmake/modules/contrib/TensorRT.cmake)
diff --git a/cmake/modules/contrib/Posit.cmake
b/cmake/modules/contrib/Posit.cmake
deleted file mode 100644
index b8d180ee44..0000000000
--- a/cmake/modules/contrib/Posit.cmake
+++ /dev/null
@@ -1,26 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-if(USE_BYODT_POSIT)
- message(STATUS "Build with contrib.posit")
- if (NOT UNIVERSAL_PATH)
- message(FATAL_ERROR "Fail to get Universal path")
- endif(NOT UNIVERSAL_PATH)
-
- include_directories(${UNIVERSAL_PATH}/include)
- list(APPEND COMPILER_SRCS "src/target/datatype/posit/posit-wrapper.cc")
-endif(USE_BYODT_POSIT)
diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu
index 8e31b310fe..e823db54b2 100644
--- a/docker/Dockerfile.ci_cpu
+++ b/docker/Dockerfile.ci_cpu
@@ -63,10 +63,6 @@ RUN bash /install/ubuntu_install_dnnl.sh
COPY install/ubuntu_install_xgboost.sh /install/ubuntu_install_xgboost.sh
RUN bash /install/ubuntu_install_xgboost.sh
-# BYODT deps
-COPY install/ubuntu_install_universal.sh /install/ubuntu_install_universal.sh
-RUN bash /install/ubuntu_install_universal.sh
-
# TensorFlow deps
COPY install/ubuntu_install_tensorflow.sh /install/ubuntu_install_tensorflow.sh
RUN bash /install/ubuntu_install_tensorflow.sh
diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu
index df15215b94..2f9139f842 100644
--- a/docker/Dockerfile.ci_gpu
+++ b/docker/Dockerfile.ci_gpu
@@ -115,10 +115,6 @@ RUN bash /install/ubuntu_install_vulkan.sh
COPY install/ubuntu_install_xgboost.sh /install/ubuntu_install_xgboost.sh
RUN bash /install/ubuntu_install_xgboost.sh
-# BYODT deps
-COPY install/ubuntu_install_universal.sh /install/ubuntu_install_universal.sh
-RUN bash /install/ubuntu_install_universal.sh
-
# sccache
COPY install/ubuntu_install_sccache.sh /install/ubuntu_install_sccache.sh
RUN bash /install/ubuntu_install_sccache.sh
diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h
index 2027665712..7a7584aff2 100644
--- a/include/tvm/tirx/op.h
+++ b/include/tvm/tirx/op.h
@@ -998,13 +998,6 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType
value, Span span = Span())
}
if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float6() ||
t.is_float4())
return FloatImm(t, static_cast<double>(value), span);
- // For now, we store const scalar values of custom datatypes within doubles;
later, during the
- // datatypes lowering pass, we will lower the value to its true
representation in the format
- // specified by the datatype.
- // TODO(gus) when do we need to start worrying about doubles not being
precise enough?
- if (static_cast<uint8_t>(t.code()) >=
static_cast<uint8_t>(DataType::kCustomBegin)) {
- return FloatImm(t, static_cast<double>(value), span);
- }
TVM_FFI_THROW(InternalError) << "cannot make const for type " << t;
throw;
}
diff --git a/include/tvm/tirx/transform.h b/include/tvm/tirx/transform.h
index 32a3ea8b29..e5a754f6c5 100644
--- a/include/tvm/tirx/transform.h
+++ b/include/tvm/tirx/transform.h
@@ -153,15 +153,6 @@ TVM_DLL Pass MakePackedAPI();
*/
TVM_DLL Pass RemapThreadAxis(ffi::Map<ffi::String, IterVar> axis_map);
-/*!
- * \brief Lower custom datatypes.
- *
- * See tvm::datatypes::Registry for more information on adding custom
datatypes.
- *
- * \return The pass.
- */
-TVM_DLL Pass LowerCustomDatatypes();
-
/*!
* \brief Annotate, split, and lower host/device functions.
*
diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py
index fb8310dc26..df1cd74a21 100644
--- a/python/tvm/s_tir/pipeline.py
+++ b/python/tvm/s_tir/pipeline.py
@@ -125,7 +125,6 @@ def finalize_host_passes(): # pylint:
disable=unused-argument
"""The default finalization passes for TIR backend."""
host_pass_list = [
tirx.transform.LowerTVMBuiltin(),
- tirx.transform.LowerCustomDatatypes(),
tirx.transform.LowerIntrin(),
]
return tvm.ir.transform.Sequential(host_pass_list)
@@ -136,7 +135,6 @@ def finalize_device_passes(): # pylint:
disable=unused-argument
device_pass_list = [
tirx.transform.LowerWarpMemory(),
tirx.transform.StmtSimplify(),
- tirx.transform.LowerCustomDatatypes(),
tirx.transform.LowerIntrin(),
]
return tvm.ir.transform.Sequential(device_pass_list)
diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py
index 5c6733cf8c..7303a0d097 100644
--- a/python/tvm/target/__init__.py
+++ b/python/tvm/target/__init__.py
@@ -34,6 +34,5 @@ and :py:func:`tvm.target.register_tag` to register new tags.
from .target import Target, TargetKind
from .virtual_device import VirtualDevice
from .tag import list_tags, register_tag
-from . import datatype
from . import codegen
from . import tag_registry # registers tags on import
diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py
deleted file mode 100644
index d7a47836b2..0000000000
--- a/python/tvm/target/datatype.py
+++ /dev/null
@@ -1,379 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# ruff: noqa: F821
-"""Bring Your Own Datatypes custom datatype framework
-
-TODO(@gussmith23 @hypercubestart) link to BYODT docs when they exist"""
-
-from tvm_ffi import get_global_func
-from tvm_ffi import register_global_func as _register_global_func
-
-import tvm
-from tvm.runtime import DataType, convert
-from tvm.tirx import call_intrin
-from tvm.tirx.expr import (
- BinaryOpExpr as _BinaryOpExpr,
-)
-from tvm.tirx.expr import (
- Call as _Call,
-)
-from tvm.tirx.expr import (
- Cast as _Cast,
-)
-from tvm.tirx.expr import (
- FloatImm as _FloatImm,
-)
-from tvm.tirx.op import call_pure_extern
-
-
-def register(type_name, type_code):
- """Register a custom datatype with the given type name and type code
-
- Currently, the type code is manually allocated by the user, and the user
- must ensure that no two custom types share the same code. Generally, this
- should be straightforward, as the user will be manually registering all of
- their custom types.
-
- Example:
-
- .. code-block:: python
-
- # Register a dtype named 'posites2' under type code 130.
- tvm.target.datatype.register('posites2', 130)
-
-
- Parameters
- ----------
- type_name : str
- The name of the custom datatype.
-
- type_code : int
- The type's code, which should be >= kCustomBegin. See
- include/tvm/runtime/data_type.h.
- """
- get_global_func("dtype.register_custom_type")(type_name, type_code)
-
-
-def get_type_name(type_code):
- """Get the type name of a custom datatype from the type code.
-
- Note that this only works for custom datatypes registered with
- tvm.target.datatype.register(). It does not work for TVM-native types.
-
- Example:
-
- .. code-block:: python
-
- tvm.target.datatype.register('posites2', 130)
- assert tvm.target.datatype.get_type_name(130) == 'posites2'
-
- Parameters
- ----------
- type_code : int
- The type code of the custom datatype.
-
- Returns
- -------
- type_name : String
- The name of the custom datatype.
-
- """
- return get_global_func("dtype.get_custom_type_name")(type_code)
-
-
-def get_type_code(type_name):
- """Get the type code of a custom datatype from its type name
-
- Note that this only works for custom datatypes registered with
- tvm.target.datatype.register(). It does not work for TVM-native types.
-
- Example:
-
- .. code-block:: python
-
- tvm.target.datatype.register('posites2', 130)
- assert tvm.target.datatype.get_type_code('posites2') == 130
-
- Parameters
- ----------
- type_name : str
- The type name
-
- Returns
- -------
- type_code : int
- The type code of the custom datatype.
- """
- return get_global_func("dtype.get_custom_type_code")(type_name)
-
-
-def get_type_registered(type_code):
- """Returns true if a custom datatype is registered under the given type
code
-
- Example:
-
- .. code-block:: python
-
- tvm.target.datatype.register('posites2', 130)
- assert tvm.target.datatype.get_type_registered(130)
-
- Parameters
- ----------
- type_code: int
- The type code
-
- Returns
- -------
- type_registered : bool
- True if a custom datatype is registered under this type code, and false
- otherwise.
- """
- return tvm.runtime._ffi_api._datatype_get_type_registered(type_code)
-
-
-def register_op(
- lower_func, op_name, target, src_type_name, dest_type_name=None,
intrinsic_name=None
-):
- """Register a lowering function for a specific operator of a custom
datatype
-
- At build time, Relay must lower operators over custom datatypes into
- operators it understands how to compile. For each custom datatype operator
- which Relay finds while lowering custom datatypes, Relay expects to find a
- user-defined lowering function. Users register their user-defined lowering
- functions using this function.
-
- Users should use create_lower_func to create their lowering function. It
- should serve most use-cases.
-
- Currently, this will work with Casts, intrinsics (e.g. sqrt, sigmoid), and
- binary expressions (e.g. Add, Sub, Mul, Div).
-
- See the LowerCustomDatatypes pass to see how registered functions are used.
-
- Lowering Functions
- ------------------
- TODO(@gussmith23) Get the terminology right here.
- Lowering functions take in a Relay node, and should return a semantically
- equivalent Relay node which Relay can build. This means that the returned
- node should not contain any custom datatypes. Users should likely not need
- to define lowering functions by hand -- see the helper function
- create_lower_func.
-
- Parameters
- ----------
- lower_func : function
- The lowering function to call. See create_lower_func.
-
- op_name : str
- The name of the operation which the function computes, given by its
- class name (e.g. Add, LE, Cast, Call).
-
- target : str
- The name of codegen target.
-
- src_type_name : str
- The name of the custom datatype, e.g. posites2 (but not
custom[posites2]32).
- If op_name is not "Cast", then target type is guaranteed to be the
same as src_type_name.
-
- dest_type_name : str
- If op_name is "Cast", then this is required and should be set to the
dest datatype of
- the argument to the Cast. If op_name is not "Cast", this is unused.
-
- intrinsic_name : str
- If op_name is "Call" and intrinsic_name is not None, then we assume the
- op is a Call to an Intrinsic, and intrinsic_name is the intrinsic's
- name.
- """
-
- if op_name == "Cast":
- assert dest_type_name is not None
- lower_func_name = (
- "tvm.datatype.lower."
- + target
- + "."
- + op_name
- + "."
- + dest_type_name
- + "."
- + src_type_name
- )
- elif op_name == "Call" and intrinsic_name is not None:
- lower_func_name = (
- "tvm.datatype.lower."
- + target
- + "."
- + op_name
- + ".intrin."
- + intrinsic_name
- + "."
- + src_type_name
- )
- else:
- lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "."
+ src_type_name
- tvm_ffi.register_global_func(lower_func_name, lower_func)
-
-
-def register_min_func(func, type_name):
- """Register the function that returns the minimum representable value of
type_name.
-
- Operators such as max pooling and argmax require the minimum
- finite value representable by the datatype the op operating on.
- Users can use this function to register a function that returns a TIR
expression node
- outputting the minimum representable value of their custom data type.
-
- Users should use create_min_lower_func to create their lowering function.
It
- should serve most use-cases.
-
- Note: for special cases when it is known that the custom datatype is
representable
- by a float, the user can create their own lowering func that returns a
FloatImm.
- The benefits are allowing optimizations such as rewrites to work as
expected on custom
- datatypes.
-
- Parameters
- ----------
- func : function
- Input is an integer num_bits, should return a TIR expression node that
- represents a scalar tensor of type custom[type_name]num_bits with the
minimum
- representable value.
-
- type_name : str
- The name of the custom datatype, e.g. posites2 (but not
custom[posites2]32).
- """
- _register_global_func("tvm.datatype.min." + type_name, func)
-
-
-def create_min_lower_func(extern_func_map, type_name):
- """Returns a lowering function for getting the minimum value of a custom
datatype.
-
- Parameters
- ----------
- extern_func_map : map
- A map from bit lengths to the name of the extern "C" function to lower
to.
-
- type_name : string
- The name of the custom datatype, e.g. posites2 (but not
custom[posites2]32).
- """
-
- def lower(num_bits):
- dtype = f"custom[{type_name}]{num_bits}"
-
- if num_bits not in extern_func_map:
- raise RuntimeError("missing minimum function for {dtype}")
-
- return call_pure_extern(dtype, extern_func_map[num_bits])
-
- return lower
-
-
-def create_lower_func(extern_func_map):
- """Returns a function which lowers an operation to a function call.
-
- Parameters
- ----------
- extern_func_map : map
- If lowering a Cast, extern_func_map should be a map from tuples of
- (src_bit_length, dest_bit_length) to the name of the extern "C"
function to lower to.
-
- Otherwise, for unary and binary ops, it should simply be a map
- from bit_length to the name of the extern "C" function to lower to.
- """
-
- def lower(op):
- """
- Takes an op---either a Cast, Call, or a binary op (e.g. an Add) and
returns a
- call to the specified external function, passing the op's argument
- or arguments. The return type of the call depends
- on the type of the op: if it is a custom type, then a uint of the same
- width as the custom type is returned. Otherwise, the type is
- unchanged."""
- dtype = op.dtype
- t = DataType(dtype)
- if get_type_registered(t.type_code):
- dtype = "uint" + str(t.bits)
- if t.lanes > 1:
- dtype += "x" + str(t.lanes)
-
- key = t.bits
- if isinstance(op, _Cast):
- src_bits = DataType(op.value.dtype).bits
- key = (src_bits, t.bits)
-
- if key not in extern_func_map:
- raise RuntimeError(f"missing key {key} in extern_func_map for
{op}")
-
- if isinstance(op, _Cast):
- return call_pure_extern(dtype, extern_func_map[key], op.value)
- if isinstance(op, _FloatImm):
- return call_pure_extern(dtype, extern_func_map[key], op.value)
- if isinstance(op, _Call):
- return call_pure_extern(dtype, extern_func_map[key], *op.args)
- if isinstance(op, _BinaryOpExpr):
- return call_pure_extern(dtype, extern_func_map[key], op.a, op.b)
-
- raise RuntimeError(f"lowering unsupported op: {op}")
-
- return lower
-
-
-def lower_ite(ite_op):
- """Lowered if then else function that calls intrinsic if_then_else.
- Unlike a function lowered by create_lower_func, this function
- calls the tvm intrinsic if_then_else.
-
- Parameters
- ----------
- ite_op : Op
- Takes an if then else op and returns a
- call to tirx.if_then_else function, passing the op's
- arguments. The return type of the call if a uint of the same
- width as the custom type is returned.
- """
- dtype = ite_op.dtype
- t = tvm.DataType(dtype)
- assert get_type_registered(t.type_code)
- dtype = "uint" + str(t.bits)
- if t.lanes > 1:
- dtype += "x" + str(t.lanes)
- return call_intrin(
- dtype,
- "tirx.if_then_else",
- convert(ite_op.args[0]),
- convert(ite_op.args[1]),
- convert(ite_op.args[2]),
- )
-
-
-def lower_call_pure_extern(op):
- """Lowered call pure extern function that calls intrinsic call_pure_extern.
- Unlike a function lowered by create_lower_func, this function
- calls the tvm intrinsic call_pure_extern.
-
- Parameters
- ----------
- ite_op : Op
- Takes a call_pure_extern op and returns a
- call to tirx.call_pure_extern function, passing the op's
- arguments. The return type of the call if a uint of the same
- width as the custom type is returned.
- """
- dtype = op.dtype
- t = tvm.DataType(dtype)
- assert get_type_registered(t.type_code)
- dtype = "uint" + str(t.bits)
- if t.lanes > 1:
- dtype += "x" + str(t.lanes)
- return call_intrin(dtype, "tirx.call_pure_extern", *op.args)
diff --git a/python/tvm/tirx/compilation_pipeline.py
b/python/tvm/tirx/compilation_pipeline.py
index d2847332b4..23dee416bb 100644
--- a/python/tvm/tirx/compilation_pipeline.py
+++ b/python/tvm/tirx/compilation_pipeline.py
@@ -103,7 +103,6 @@ def finalize_host_passes(): # pylint:
disable=unused-argument
"""The default finalization passes for TIR backend."""
host_pass_list = [
tirx.transform.LowerTVMBuiltin(),
- tirx.transform.LowerCustomDatatypes(),
tirx.transform.LowerIntrin(),
]
return tvm.ir.transform.Sequential(host_pass_list)
@@ -114,7 +113,6 @@ def finalize_device_passes(): # pylint:
disable=unused-argument
device_pass_list = [
tirx.transform.LowerWarpMemory(),
tirx.transform.StmtSimplify(),
- tirx.transform.LowerCustomDatatypes(),
tirx.transform.LowerIntrin(),
]
return tvm.ir.transform.Sequential(device_pass_list)
diff --git a/python/tvm/tirx/transform/transform.py
b/python/tvm/tirx/transform/transform.py
index 72a5b96202..ae6b942b66 100644
--- a/python/tvm/tirx/transform/transform.py
+++ b/python/tvm/tirx/transform/transform.py
@@ -245,19 +245,6 @@ def ConvertSSA():
return _ffi_api.ConvertSSA() # type: ignore
-def LowerCustomDatatypes():
- """Lower custom datatypes.
-
- See tvm::datatypes::Registry for more information on adding custom
datatypes.
-
- Returns
- -------
- fpass : tvm.transform.Pass
- The result pass
- """
- return _ffi_api.LowerCustomDatatypes() # type: ignore
-
-
def MakePackedAPI():
"""Transform the PrimFuncs in the module to a packed func API.
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index bec5091883..5a86cdd15a 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -34,7 +34,6 @@
#include <tuple>
#include <utility>
-#include "../target/datatype/registry.h"
#include "../tirx/analysis/check_contains.h"
#include "conjunctive_normal_form.h"
#include "const_fold.h"
diff --git a/src/target/datatype/myfloat/myfloat.cc
b/src/target/datatype/myfloat/myfloat.cc
deleted file mode 100644
index afee8a7c4b..0000000000
--- a/src/target/datatype/myfloat/myfloat.cc
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file 3rdparty/byodt/my-custom-datatype.cc
- * \brief Example Custom Datatype with the Bring Your Own Datatypes (BYODT)
framework.
- * This is a toy example that under the hood simulates floats.
- *
- * Users interested in using the BYODT framework can use this file as a
template.
- *
- * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist?
- */
-#include <tvm/runtime/base.h>
-
-#include <cmath>
-#include <cstdint>
-#include <limits>
-
-// Custom datatypes are stored as bits in a uint of the appropriate bit length.
-// Thus, when TVM calls these C functions,
-// the arguments of are uints that need to reinterpreted as your custom
datatype.
-//
-// When returning, your custom datatype needs to be re-wrapped into a uint,
-// which can be thought of as just a wrapper for the raw bits that represent
your custom datatype.
-template <class T>
-TVM_DLL T Uint32ToCustom32(uint32_t in) {
- // This is a helper function to interpret the uint as your custom dataype.
- // The following line should be replaced with the appropriate function
- // that interprets the bits in `in` and returns your custom datatype
- T* custom = reinterpret_cast<T*>(&in);
- return *custom;
-}
-
-template <class T>
-TVM_DLL uint32_t Custom32ToUint32(T in) {
- // This is a helper function to wrap your custom datatype in a uint.
- // the following line should be replaced with the appropriate function
- // that converts your custom datatype into a uint
- uint32_t* bits = reinterpret_cast<uint32_t*>(&in);
- return *bits;
-}
-
-extern "C" {
-TVM_DLL uint32_t MinCustom32() {
- // return minimum representable value
- float min = std::numeric_limits<float>::lowest();
- return Custom32ToUint32<float>(min);
-}
-
-TVM_DLL float Custom32ToFloat(uint32_t in) {
- // cast from custom datatype to float
- float custom_datatype = Uint32ToCustom32<float>(in);
- // our custom datatype is float, so the following redundant cast to float
- // is to remind users to cast their own custom datatype to float
- return static_cast<float>(custom_datatype);
-}
-
-TVM_DLL uint32_t FloatToCustom32(float in) {
- // cast from float to custom datatype
- return Custom32ToUint32<float>(in);
-}
-
-TVM_DLL uint32_t Custom32Add(uint32_t a, uint32_t b) {
- // add operation
- float acustom = Uint32ToCustom32<float>(a);
- float bcustom = Uint32ToCustom32<float>(b);
- return Custom32ToUint32<float>(acustom + bcustom);
-}
-
-TVM_DLL uint32_t Custom32Sub(uint32_t a, uint32_t b) {
- // subtract
- float acustom = Uint32ToCustom32<float>(a);
- float bcustom = Uint32ToCustom32<float>(b);
- return Custom32ToUint32<float>(acustom - bcustom);
-}
-
-TVM_DLL uint32_t Custom32Mul(uint32_t a, uint32_t b) {
- // multiply
- float acustom = Uint32ToCustom32<float>(a);
- float bcustom = Uint32ToCustom32<float>(b);
- return Custom32ToUint32<float>(acustom * bcustom);
-}
-
-TVM_DLL uint32_t Custom32Div(uint32_t a, uint32_t b) {
- // divide
- float acustom = Uint32ToCustom32<float>(a);
- float bcustom = Uint32ToCustom32<float>(b);
- return Custom32ToUint32<float>(acustom / bcustom);
-}
-
-TVM_DLL uint32_t Custom32Max(uint32_t a, uint32_t b) {
- // max
- float acustom = Uint32ToCustom32<float>(a);
- float bcustom = Uint32ToCustom32<float>(b);
- return Custom32ToUint32<float>(acustom > bcustom ? acustom : bcustom);
-}
-
-TVM_DLL uint32_t Custom32Sqrt(uint32_t a) {
- // sqrt
- float acustom = Uint32ToCustom32<float>(a);
- return Custom32ToUint32<float>(sqrt(acustom));
-}
-
-TVM_DLL uint32_t Custom32Exp(uint32_t a) {
- // exponential
- float acustom = Uint32ToCustom32<float>(a);
- return Custom32ToUint32<float>(exp(acustom));
-}
-
-TVM_DLL uint32_t Custom32Log(uint32_t a) {
- // log
- float acustom = Uint32ToCustom32<float>(a);
- return Custom32ToUint32<float>(log(acustom));
-}
-
-TVM_DLL uint32_t Custom32Sigmoid(uint32_t a) {
- // sigmoid
- float acustom = Uint32ToCustom32<float>(a);
- float one = 1.0f;
- return Custom32ToUint32<float>(one / (one + exp(-acustom)));
-}
-
-TVM_DLL uint32_t Custom32Tanh(uint32_t a) {
- // tanh
- float acustom = Uint32ToCustom32<float>(a);
- return Custom32ToUint32<float>(tanh(acustom));
-}
-}
diff --git a/src/target/datatype/posit/posit-wrapper.cc
b/src/target/datatype/posit/posit-wrapper.cc
deleted file mode 100644
index e05695e603..0000000000
--- a/src/target/datatype/posit/posit-wrapper.cc
+++ /dev/null
@@ -1,242 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file 3rdparty/posit/posit-wrapper.cc
- * \brief Wrapper over the Stillwater Universal library for Bring Your Own
Datatypes tests
- *
- * To compile TVM with this file,
- * 1. clone the Stillwater Universal repo from here
`https://github.com/stillwater-sc/universal`.
- * 2. set `SET_BYODT_POSIT` ON and `UNIVERSAL_PATH` as the path to the folder
containing Stillwater
- * Universal in your CMake file
- *
- * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist?
- */
-#include <tvm/runtime/base.h>
-
-#include <cstdint>
-
-#include "universal/posit/posit.hpp"
-// must go after posit.hpp
-#include "universal/posit/math/exponent.hpp"
-#include "universal/posit/math/hyperbolic.hpp"
-#include "universal/posit/math/logarithm.hpp"
-#include "universal/posit/math/sqrt.hpp"
-#include "universal/posit/numeric_limits.hpp"
-
-TVM_DLL sw::unum::posit<8, 2> Uint8ToPosit8es2(uint8_t in) {
- sw::unum::bitblock<8> bb;
- bb = static_cast<uint64_t>(in);
- return sw::unum::posit<8, 2>().set(bb);
-}
-
-extern "C" {
-TVM_DLL uint8_t Posit8es2toUint8(sw::unum::posit<8, 2> in) {
- return static_cast<uint8_t>(in.get().to_ullong());
-}
-
-TVM_DLL uint8_t MinPosit8es2() {
- auto min = std::numeric_limits<sw::unum::posit<8, 2>>::lowest();
- return Posit8es2toUint8(min);
-}
-
-TVM_DLL float Posit8es2ToFloat(uint8_t in) { return
Uint8ToPosit8es2(in).operator float(); }
-
-TVM_DLL uint8_t FloatToPosit8es2(float in) {
- auto posit = sw::unum::posit<8, 2>(in);
- return Posit8es2toUint8(posit);
-}
-
-TVM_DLL uint8_t Posit8es2Add(uint8_t a, uint8_t b) {
- return Posit8es2toUint8(Uint8ToPosit8es2(a) + Uint8ToPosit8es2(b));
-}
-
-TVM_DLL uint8_t Posit8es2Sub(uint8_t a, uint8_t b) {
- return Posit8es2toUint8(Uint8ToPosit8es2(a) - Uint8ToPosit8es2(b));
-}
-
-TVM_DLL uint8_t Posit8es2Mul(uint8_t a, uint8_t b) {
- return Posit8es2toUint8(Uint8ToPosit8es2(a) * Uint8ToPosit8es2(b));
-}
-
-TVM_DLL uint8_t Posit8es2Div(uint8_t a, uint8_t b) {
- return Posit8es2toUint8(Uint8ToPosit8es2(a) / Uint8ToPosit8es2(b));
-}
-
-TVM_DLL uint8_t Posit8es2Max(uint8_t a, uint8_t b) {
- auto a_p = Uint8ToPosit8es2(a);
- auto b_p = Uint8ToPosit8es2(b);
- return Posit8es2toUint8(a_p > b_p ? a_p : b_p);
-}
-
-TVM_DLL uint8_t Posit8es2Sqrt(uint8_t a) {
- return Posit8es2toUint8(sw::unum::sqrt(Uint8ToPosit8es2(a)));
-}
-
-TVM_DLL uint8_t Posit8es2Exp(uint8_t a) {
- return Posit8es2toUint8(sw::unum::exp(Uint8ToPosit8es2(a)));
-}
-
-TVM_DLL uint8_t Posit8es2Log(uint8_t a) {
- return Posit8es2toUint8(sw::unum::log(Uint8ToPosit8es2(a)));
-}
-
-TVM_DLL uint8_t Posit8es2Sigmoid(uint8_t a) {
- auto posit_one = sw::unum::posit<8, 2>(1);
- return Posit8es2toUint8(posit_one / (sw::unum::exp(-Uint8ToPosit8es2(a)) +
posit_one));
-}
-
-TVM_DLL uint8_t Posit8es2Tanh(uint8_t a) {
- return Posit8es2toUint8(sw::unum::tanh(Uint8ToPosit8es2(a)));
-}
-}
-
-TVM_DLL sw::unum::posit<16, 2> Uint16ToPosit16es2(uint16_t in) {
- sw::unum::bitblock<16> bb;
- bb = static_cast<uint64_t>(in);
- return sw::unum::posit<16, 2>().set(bb);
-}
-
-extern "C" {
-TVM_DLL uint16_t Posit16es2toUint16(sw::unum::posit<16, 2> in) {
- return static_cast<uint16_t>(in.get().to_ullong());
-}
-
-TVM_DLL uint8_t MinPosit16es2() {
- auto min = std::numeric_limits<sw::unum::posit<16, 2>>::lowest();
- return Posit16es2toUint16(min);
-}
-
-TVM_DLL float Posit16es2ToFloat(uint16_t in) { return
Uint16ToPosit16es2(in).operator float(); }
-
-TVM_DLL uint16_t FloatToPosit16es2(float in) {
- auto posit = sw::unum::posit<16, 2>(in);
- return Posit16es2toUint16(posit);
-}
-
-TVM_DLL uint16_t Posit16es2Add(uint16_t a, uint16_t b) {
- return Posit16es2toUint16(Uint16ToPosit16es2(a) + Uint16ToPosit16es2(b));
-}
-
-TVM_DLL uint16_t Posit16es2Sub(uint16_t a, uint16_t b) {
- return Posit16es2toUint16(Uint16ToPosit16es2(a) - Uint16ToPosit16es2(b));
-}
-
-TVM_DLL uint16_t Posit16es2Mul(uint16_t a, uint16_t b) {
- return Posit16es2toUint16(Uint16ToPosit16es2(a) * Uint16ToPosit16es2(b));
-}
-
-TVM_DLL uint16_t Posit16es2Div(uint16_t a, uint16_t b) {
- return Posit16es2toUint16(Uint16ToPosit16es2(a) / Uint16ToPosit16es2(b));
-}
-
-TVM_DLL uint16_t Posit16es2Max(uint16_t a, uint16_t b) {
- auto a_p = Uint16ToPosit16es2(a);
- auto b_p = Uint16ToPosit16es2(b);
- return Posit16es2toUint16(a_p > b_p ? a_p : b_p);
-}
-
-TVM_DLL uint16_t Posit16es2Sqrt(uint16_t a) {
- return Posit16es2toUint16(sw::unum::sqrt(Uint16ToPosit16es2(a)));
-}
-
-TVM_DLL uint16_t Posit16es2Exp(uint16_t a) {
- return Posit16es2toUint16(sw::unum::exp(Uint16ToPosit16es2(a)));
-}
-
-TVM_DLL uint16_t Posit16es2Log(uint16_t a) {
- return Posit16es2toUint16(sw::unum::log(Uint16ToPosit16es2(a)));
-}
-
-TVM_DLL uint16_t Posit16es2Sigmoid(uint16_t a) {
- auto posit_one = sw::unum::posit<16, 2>(1);
- return Posit16es2toUint16(posit_one / (sw::unum::exp(-Uint16ToPosit16es2(a))
+ posit_one));
-}
-
-TVM_DLL uint16_t Posit16es2Tanh(uint16_t a) {
- return Posit16es2toUint16(sw::unum::tanh(Uint16ToPosit16es2(a)));
-}
-}
-
-TVM_DLL sw::unum::posit<32, 2> Uint32ToPosit32es2(uint32_t in) {
- sw::unum::bitblock<32> bb;
- bb = static_cast<uint64_t>(in);
- return sw::unum::posit<32, 2>().set(bb);
-}
-
-extern "C" {
-TVM_DLL uint32_t Posit32es2ToUint32(sw::unum::posit<32, 2> in) {
- return static_cast<uint32_t>(in.get().to_ullong());
-}
-
-TVM_DLL uint8_t MinPosit32es2() {
- auto min = std::numeric_limits<sw::unum::posit<32, 2>>::lowest();
- return Posit32es2ToUint32(min);
-}
-
-TVM_DLL float Posit32es2ToFloat(uint32_t in) { return
Uint32ToPosit32es2(in).operator float(); }
-
-TVM_DLL uint32_t FloatToPosit32es2(float in) {
- auto posit = sw::unum::posit<32, 2>(in);
- return Posit32es2ToUint32(posit);
-}
-
-TVM_DLL uint32_t Posit32es2Add(uint32_t a, uint32_t b) {
- return Posit32es2ToUint32(Uint32ToPosit32es2(a) + Uint32ToPosit32es2(b));
-}
-
-TVM_DLL uint32_t Posit32es2Sub(uint32_t a, uint32_t b) {
- return Posit32es2ToUint32(Uint32ToPosit32es2(a) - Uint32ToPosit32es2(b));
-}
-
-TVM_DLL uint32_t Posit32es2Mul(uint32_t a, uint32_t b) {
- return Posit32es2ToUint32(Uint32ToPosit32es2(a) * Uint32ToPosit32es2(b));
-}
-
-TVM_DLL uint32_t Posit32es2Div(uint32_t a, uint32_t b) {
- return Posit32es2ToUint32(Uint32ToPosit32es2(a) / Uint32ToPosit32es2(b));
-}
-
-TVM_DLL uint32_t Posit32es2Max(uint32_t a, uint32_t b) {
- auto a_p = Uint32ToPosit32es2(a);
- auto b_p = Uint32ToPosit32es2(b);
- return Posit32es2ToUint32(a_p > b_p ? a_p : b_p);
-}
-
-TVM_DLL uint32_t Posit32es2Sqrt(uint32_t a) {
- return Posit32es2ToUint32(sw::unum::sqrt(Uint32ToPosit32es2(a)));
-}
-
-TVM_DLL uint32_t Posit32es2Exp(uint32_t a) {
- return Posit32es2ToUint32(sw::unum::exp(Uint32ToPosit32es2(a)));
-}
-
-TVM_DLL uint32_t Posit32es2Log(uint32_t a) {
- return Posit32es2ToUint32(sw::unum::log(Uint32ToPosit32es2(a)));
-}
-
-TVM_DLL uint32_t Posit32es2Sigmoid(uint32_t a) {
- auto posit_one = sw::unum::posit<32, 2>(1);
- return Posit32es2ToUint32(posit_one / (posit_one +
sw::unum::exp(-Uint32ToPosit32es2(a))));
-}
-
-TVM_DLL uint32_t Posit32es2Tanh(uint32_t a) {
- return Posit32es2ToUint32(sw::unum::tanh(Uint32ToPosit32es2(a)));
-}
-}
diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc
deleted file mode 100644
index 9d6459df6c..0000000000
--- a/src/target/datatype/registry.cc
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#include "registry.h"
-
-#include <tvm/ffi/function.h>
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/runtime/data_type.h>
-
-namespace tvm {
-namespace datatype {
-
-using ffi::Any;
-using ffi::PackedArgs;
-
-TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef()
- .def_packed("dtype.register_custom_type",
- [](ffi::PackedArgs args, ffi::Any* ret) {
- datatype::Registry::Global()->Register(
- args[0].cast<std::string>(),
static_cast<uint8_t>(args[1].cast<int>()));
- })
- .def_packed("dtype.get_custom_type_code",
- [](ffi::PackedArgs args, ffi::Any* ret) {
- *ret =
datatype::Registry::Global()->GetTypeCode(args[0].cast<std::string>());
- })
- .def_packed("dtype.get_custom_type_name",
- [](ffi::PackedArgs args, ffi::Any* ret) {
- *ret =
Registry::Global()->GetTypeName(args[0].cast<int>());
- })
- .def_packed("runtime._datatype_get_type_registered", [](ffi::PackedArgs
args, ffi::Any* ret) {
- *ret = Registry::Global()->GetTypeRegistered(args[0].cast<int>());
- });
-}
-
-Registry* Registry::Global() {
- static Registry inst;
- return &inst;
-}
-
-void Registry::Register(const std::string& type_name, uint8_t type_code) {
- TVM_FFI_ICHECK(type_code >= DataType::kCustomBegin)
- << "Please choose a type code >= DataType::kCustomBegin for custom
types";
- code_to_name_[type_code] = type_name;
- name_to_code_[type_name] = type_code;
-}
-
-uint8_t Registry::GetTypeCode(const std::string& type_name) {
- TVM_FFI_ICHECK(name_to_code_.find(type_name) != name_to_code_.end())
- << "Type name " << type_name << " not registered";
- return name_to_code_[type_name];
-}
-
-std::string Registry::GetTypeName(uint8_t type_code) {
- TVM_FFI_ICHECK(code_to_name_.find(type_code) != code_to_name_.end())
- << "Type code " << static_cast<unsigned>(type_code) << " not registered";
- return code_to_name_[type_code];
-}
-
-std::optional<tvm::ffi::Function> GetCastLowerFunc(const std::string& target,
uint8_t type_code,
- uint8_t src_type_code) {
- std::ostringstream ss;
- ss << "tvm.datatype.lower.";
- ss << target << ".";
- ss << "Cast"
- << ".";
-
- if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
- ss << datatype::Registry::Global()->GetTypeName(type_code);
- } else {
- ss <<
ffi::details::DLDataTypeCodeAsCStr(static_cast<DLDataTypeCode>(type_code));
- }
-
- ss << ".";
-
- if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) {
- ss << datatype::Registry::Global()->GetTypeName(src_type_code);
- } else {
- ss <<
ffi::details::DLDataTypeCodeAsCStr(static_cast<DLDataTypeCode>(src_type_code));
- }
- return tvm::ffi::Function::GetGlobal(ss.str());
-}
-
-std::optional<tvm::ffi::Function> GetMinFunc(uint8_t type_code) {
- std::ostringstream ss;
- ss << "tvm.datatype.min.";
- ss << datatype::Registry::Global()->GetTypeName(type_code);
- return tvm::ffi::Function::GetGlobal(ss.str());
-}
-
-std::optional<tvm::ffi::Function> GetFloatImmLowerFunc(const std::string&
target,
- uint8_t type_code) {
- std::ostringstream ss;
- ss << "tvm.datatype.lower.";
- ss << target;
- ss << ".FloatImm.";
- ss << datatype::Registry::Global()->GetTypeName(type_code);
- return tvm::ffi::Function::GetGlobal(ss.str());
-}
-
-std::optional<tvm::ffi::Function> GetIntrinLowerFunc(const std::string& target,
- const std::string& name,
uint8_t type_code) {
- std::ostringstream ss;
- ss << "tvm.datatype.lower.";
- ss << target;
- ss << ".Call.intrin.";
- ss << name;
- ss << ".";
- ss << datatype::Registry::Global()->GetTypeName(type_code);
- return tvm::ffi::Function::GetGlobal(ss.str());
-}
-
-uint64_t ConvertConstScalar(uint8_t type_code, double value) {
- std::ostringstream ss;
- ss << "tvm.datatype.convertconstscalar.float.";
- ss << datatype::Registry::Global()->GetTypeName(type_code);
- auto make_const_scalar_func = tvm::ffi::Function::GetGlobal(ss.str());
- return (*make_const_scalar_func)(value).cast<uint64_t>();
-}
-
-} // namespace datatype
-} // namespace tvm
diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h
deleted file mode 100644
index 363494e0fd..0000000000
--- a/src/target/datatype/registry.h
+++ /dev/null
@@ -1,182 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef TVM_TARGET_DATATYPE_REGISTRY_H_
-#define TVM_TARGET_DATATYPE_REGISTRY_H_
-
-#include <tvm/ffi/function.h>
-
-#include <string>
-#include <unordered_map>
-
-namespace tvm {
-namespace datatype {
-
-/*!
- * \brief Registry for custom datatypes.
- *
- * Adding custom datatypes currently requires two steps:
- * 1. Register the datatype with the registry via a call to
- * datatype::Registry::Register. This can also be done in Python
- * directly---see the TVM globals registered in the corresponding .cc file.
- * Currently, user should manually choose a type name and a type code,
- * ensuring that neither conflict with existing types.
- * 2. Register the lowering functions needed to
- * lower the custom datatype. In general, these will look like:
- * For Casts: tvm.datatype.lower.<target>.Cast.<type>.<src_type>
- * Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from
- * float to myfloat.
- * For intrinsic Calls: tvm.datatype.lower.<target>.Call.intrin.<name>.<type>
- * Example: tvm.datatype.lower.llvm.Call.intrin.sqrt.myfloat
- * For other ops: tvm.datatype.lower.<target>.<op>.<type>
- * Examples: tvm.datatype.lower.llvm.Add.myfloat
- * tvm.datatype.lower.llvm.FloatImm.posit
- */
-class Registry {
- public:
- /*!
- * \brief Get the global custom datatype registry singleton
- */
- static Registry* Global();
-
- /*!
- * \brief Register custom datatype
- * Register a custom datatype with the given type name and type code. Currently, the type code is
- * manually allocated by the user, and the user must ensure that no two custom types share the
- * same code. Generally, this should be straightforward, as the user will be manually registering
- * all of their custom types.
- * \param type_name The name of the type, e.g. "posites2"
- * \param type_code The type code, which should be greater than TVMArgTypeCode::kTVMExtEnd
- */
- void Register(const std::string& type_name, uint8_t type_code);
-
- /*!
- * \brief Get type code from type name
- * \param type_name The type name
- * \return The type code
- */
- uint8_t GetTypeCode(const std::string& type_name);
-
- /*!
- * \brief Get type name from type code
- * \param type_code The type code
- * \return The type name
- */
- std::string GetTypeName(uint8_t type_code);
-
- /*!
- * \brief Get bool representing whether type is registered, given the type code
- * \param type_code The type code
- * \return bool representing whether the type is registered
- */
- inline bool GetTypeRegistered(uint8_t type_code) {
- return code_to_name_.find(type_code) != code_to_name_.end();
- }
-
- /*!
- * \brief Get bool representing whether type is registered, given the type name
- * \param type_name The type name
- * \return bool representing whether the type is registered
- */
- inline bool GetTypeRegistered(std::string type_name) {
- return name_to_code_.find(type_name) != name_to_code_.end();
- }
-
- private:
- // TODO(gus) is there a typedef for the code?
- std::unordered_map<uint8_t, std::string> code_to_name_;
- std::unordered_map<std::string, uint8_t> name_to_code_;
-};
-
-/*!
- * \brief Convert scalar value to a custom datatype format
- * \param type_code The custom datatype to convert to, specified by type code
- * \param value The floating point value to convert
- * \return The value, encoded in the bits of a uint64_t
- */
-uint64_t ConvertConstScalar(uint8_t type_code, double value);
-
-/*!
- * \brief Get a function returning the minimum value for a datatype.
- * \param type_code The datatype
- * \return Function which takes the width of the datatype and returns the min value
- */
-std::optional<tvm::ffi::Function> GetMinFunc(uint8_t type_code);
-
-/*!
- * \brief Get lowering function for Cast ops
- * \param target The target we are lowering to, e.g. "llvm"
- * \param type_code The datatype being cast to
- * \param src_type_code The datatype being cast from
- * \return Lowering function for Cast ops for the provided target, type, and source type
- */
-std::optional<tvm::ffi::Function> GetCastLowerFunc(const std::string& target,
uint8_t type_code,
- uint8_t src_type_code);
-
-/*!
- * \brief Get lowering function for FloatImms
- * \param target The target we are lowering to, e.g. "llvm"
- * \param type_code The datatype of the FloatImm
- * \return Lowering function for FloatImms for the provided target and type
- */
-std::optional<tvm::ffi::Function> GetFloatImmLowerFunc(const std::string&
target,
- uint8_t type_code);
-
-/*!
- * \brief Get lowering function for intrinsic Calls/pure intrinsic Calls
- * \param target The target we are lowering to, e.g. "llvm"
- * \param type_code The datatype of the Call
- * \param name The intrinsic name
- * \return Lowering function for intrinsic Calls for the provided target and type
- */
-std::optional<tvm::ffi::Function> GetIntrinLowerFunc(const std::string& target,
- const std::string& name,
uint8_t type_code);
-
-/*!
- * \brief Get lowering function for other ops
- * \param target The target we are lowering to, e.g. "llvm"
- * \param type_code The datatype of the op
- * \return Lowering function for other ops for the provided target and type
- */
-#define DEFINE_GET_LOWER_FUNC_(OP)
\
- inline std::optional<tvm::ffi::Function> Get##OP##LowerFunc(const
std::string& target, \
- uint8_t
type_code) { \
- return tvm::ffi::Function::GetGlobal("tvm.datatype.lower." + target + "."
#OP "." + \
-
datatype::Registry::Global()->GetTypeName(type_code)); \
- }
-
-DEFINE_GET_LOWER_FUNC_(Add)
-DEFINE_GET_LOWER_FUNC_(Sub)
-DEFINE_GET_LOWER_FUNC_(Mul)
-DEFINE_GET_LOWER_FUNC_(Div)
-DEFINE_GET_LOWER_FUNC_(Mod)
-DEFINE_GET_LOWER_FUNC_(Min)
-DEFINE_GET_LOWER_FUNC_(Max)
-DEFINE_GET_LOWER_FUNC_(EQ)
-DEFINE_GET_LOWER_FUNC_(NE)
-DEFINE_GET_LOWER_FUNC_(LT)
-DEFINE_GET_LOWER_FUNC_(LE)
-DEFINE_GET_LOWER_FUNC_(GT)
-DEFINE_GET_LOWER_FUNC_(GE)
-// Later changes may need to add more lowering functions as we support workloads with more ops.
-
-} // namespace datatype
-} // namespace tvm
-
-#endif // TVM_TARGET_DATATYPE_REGISTRY_H_
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 97422bf9ed..88a28ebccb 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -597,11 +597,10 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
if (auto* ptr = type.as<PrimTypeNode>()) {
return DTypeToLLVMType(ptr->dtype);
} else if (auto* ptr = type.as<PointerTypeNode>()) {
- // LLVM IR doesn't allow void*, nor do we require custom datatypes
- // to have LLVM equivalents, so we need to recognize these
- // patterns explicitly.
+ // LLVM IR doesn't allow void*, so pointer element types that do not
+ // have an LLVM scalar equivalent need explicit handling.
if (auto* primtype = ptr->element_type.as<PrimTypeNode>()) {
- if (primtype->dtype.is_void() || primtype->dtype.code() >=
DataType::kCustomBegin) {
+ if (primtype->dtype.is_void()) {
return t_void_p_;
}
} else if (ptr->element_type->IsInstance<TensorMapTypeNode>()) {
diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc
index 64f7f575d6..5cf896e4fd 100644
--- a/src/tirx/op/op.cc
+++ b/src/tirx/op/op.cc
@@ -35,7 +35,6 @@
#include <cmath>
// Centralized header for constant folders.
#include "../../arith/const_fold.h"
-#include "../../target/datatype/registry.h"
#include "../analysis/check_contains.h"
namespace tvm {
@@ -211,22 +210,16 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
} else {
rhs = cast(ltype, rhs);
}
- } else if (!ltype.is_float() &&
- (rtype.is_float() ||
datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
+ } else if (!ltype.is_float() && rtype.is_float()) {
// Cast int->float when the other operand is a float
lhs = cast(rtype, lhs);
- } else if ((ltype.is_float() ||
datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
- !rtype.is_float()) {
+ } else if (ltype.is_float() && !rtype.is_float()) {
// Cast int->float when the other operand is a float
rhs = cast(ltype, rhs);
- } else if (!ltype.is_bfloat16() &&
- (rtype.is_bfloat16() ||
- datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
+ } else if (!ltype.is_bfloat16() && rtype.is_bfloat16()) {
// Cast int->bfloat16 when the other operand is a bfloat16
lhs = cast(rtype, lhs);
- } else if ((ltype.is_bfloat16() ||
- datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
- !rtype.is_bfloat16()) {
+ } else if (ltype.is_bfloat16() && !rtype.is_bfloat16()) {
// Cast int->bfloat16 when the other operand is a bfloat16
rhs = cast(ltype, rhs);
} else if (!ltype.is_float8() && rtype.is_float8()) {
@@ -369,15 +362,7 @@ PrimExpr max_value(const DataType& dtype, Span span) {
PrimExpr min_value(const DataType& dtype, Span span) {
using namespace tirx;
TVM_FFI_ICHECK_EQ(dtype.lanes(), 1);
- if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) {
- // TODO(tkonolige): need to convert all registered min functions to use the span.
- auto f = datatype::GetMinFunc(dtype.code());
- TVM_FFI_ICHECK(f) << "No minimum function registered for custom dtype "
- << (unsigned int)dtype.code();
- // TODO(@hypercubestart) Document this change (and others associated with the overflowing
- // floatimm min bug)
- return (*f)(dtype.bits()).cast<PrimExpr>();
- } else if (dtype.is_int()) {
+ if (dtype.is_int()) {
if (dtype.bits() == 64) {
return IntImm(dtype, std::numeric_limits<int64_t>::lowest(), span);
} else if (dtype.bits() < 64) {
diff --git a/src/tirx/transform/lower_custom_datatypes.cc b/src/tirx/transform/lower_custom_datatypes.cc
deleted file mode 100644
index d23cfef4fb..0000000000
--- a/src/tirx/transform/lower_custom_datatypes.cc
+++ /dev/null
@@ -1,266 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-/*!
- * \file tvm/src/pass/lower_custom_datatypes.cc
- * \brief Pass for lowering custom datatypes
- */
-
-#include <tvm/ffi/cast.h>
-#include <tvm/ffi/function.h>
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/target/target.h>
-#include <tvm/tirx/op.h>
-#include <tvm/tirx/stmt_functor.h>
-#include <tvm/tirx/transform.h>
-
-#include "../../target/datatype/registry.h"
-
-namespace tvm {
-namespace tirx {
-
-/*!
- * \brief Helper mutator to implement lowering of custom datatypes.
- *
- * Lowering datatypes works as follows: for every expression containing a custom
- * datatype, we search for a global (registered by the implementer of the custom
- * datatype) for lowering this type of expression, and uses it to lower the
- * expression.
- */
-class CustomDatatypesLowerer : public StmtExprMutator {
- public:
- explicit CustomDatatypesLowerer(const std::string& target) : target_(target)
{}
-
- PrimExpr VisitExpr_(const CastNode* op) final {
- auto type_code = op->dtype.code();
- auto src_type_code = op->value.dtype().code();
- // If either datatype is a registered custom datatype, we must lower.
- bool to_be_lowered =
datatype::Registry::Global()->GetTypeRegistered(type_code) ||
-
datatype::Registry::Global()->GetTypeRegistered(src_type_code);
- PrimExpr expr = StmtExprMutator::VisitExpr_(op);
- if (to_be_lowered) {
- auto lower = datatype::GetCastLowerFunc(target_, type_code,
src_type_code);
- TVM_FFI_ICHECK(lower) << "Cast lowering function for target " << target_
- << " destination type " <<
static_cast<unsigned>(type_code)
- << " source type " <<
static_cast<unsigned>(src_type_code)
- << " not found";
- return (*lower)(expr).cast<PrimExpr>();
- }
- return expr;
- }
-
- PrimExpr VisitExpr_(const FloatImmNode* imm) final {
- auto type_code = imm->dtype.code();
- auto e = ffi::GetRef<PrimExpr>(imm);
- if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
- auto lower = datatype::GetFloatImmLowerFunc(target_, type_code);
- TVM_FFI_ICHECK(lower) << "FloatImm lowering function for target " <<
target_ << " type "
- << static_cast<unsigned>(type_code) << " not
found";
- return (*lower)(e).cast<PrimExpr>();
- }
- return e;
- }
-
- PrimExpr VisitExpr_(const VarNode* op) final {
- Var var = ffi::GetRef<Var>(op);
-
- auto itr = var_remap_.find(var);
- if (itr != var_remap_.end()) {
- return itr->second;
- } else {
- return var;
- }
- }
-
- Stmt VisitStmt_(const AllocBufferNode* op) final {
- bool to_be_lowered =
datatype::Registry::Global()->GetTypeRegistered(op->buffer->dtype.code());
-
- if (to_be_lowered) {
- auto new_allocate_type = DataType::UInt(op->buffer->dtype.bits(),
op->buffer->dtype.lanes());
- auto new_buffer_var =
- Var(op->buffer->data->name_hint,
PointerType(PrimType(new_allocate_type)));
- var_remap_[op->buffer->data] = new_buffer_var;
- }
- Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<AllocBufferNode>();
-
- Buffer new_buf = GetRemappedBuffer(op->buffer);
- if (!new_buf.same_as(op->buffer)) {
- auto node = Downcast<AllocBuffer>(stmt);
- node.CopyOnWrite()->buffer = new_buf;
- return node;
- }
- return stmt;
- }
-
- Stmt VisitStmt_(const DeclBufferNode* op) final {
- auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
- return VisitBufferAccess(std::move(node));
- }
-
- PrimExpr VisitExpr_(const BufferLoadNode* op) final {
- auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
- auto modified = VisitBufferAccess(node);
-
- // Not needed for BufferStoreNode, so we can't just call
- // LegalizeDtype() in VisitBufferAccess.
- if (node.same_as(modified)) {
- return node;
-
- } else {
- auto writer = modified.CopyOnWrite();
- writer->LegalizeDType();
- return modified;
- }
- }
-
- Stmt VisitStmt_(const BufferStoreNode* op) final {
- auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
- return VisitBufferAccess(std::move(node));
- }
-
- template <typename Node>
- Node VisitBufferAccess(Node node) {
- Buffer new_buf = GetRemappedBuffer(node->buffer);
- if (!new_buf.same_as(node->buffer)) {
- auto writer = node.CopyOnWrite();
- writer->buffer = new_buf;
- }
-
- return node;
- }
-
- Buffer GetRemappedBuffer(Buffer buf) {
- auto key = buf;
- auto cache_it = buf_remap_.find(key);
- if (cache_it != buf_remap_.end()) {
- return cache_it->second;
- }
-
- bool to_be_lowered =
datatype::Registry::Global()->GetTypeRegistered(buf->dtype.code());
-
- if (to_be_lowered) {
- auto new_load_type = DataType::UInt(buf->dtype.bits());
- auto writer = buf.CopyOnWrite();
- writer->dtype = new_load_type;
-
- auto var_it = var_remap_.find(buf->data);
- if (var_it != var_remap_.end()) {
- writer->data = var_it->second;
- }
- }
-
- buf_remap_[key] = buf;
- return buf;
- }
-
- Stmt VisitStmt_(const AttrStmtNode* op) final {
- Stmt ret = StmtExprMutator::VisitStmt_(op);
- op = ret.as<AttrStmtNode>();
- // Due to legacy reasons, some attr node can contain
- // information(e.g. alignment) of buffer variables.
- // remap these vars when needed
- // TODO(tvm-team): remove the rewriting once the buffer var
- // attrs are being refactored into the corresponding definition node
- if (auto var_node = op->node.as<Var>()) {
- auto it = var_remap_.find(var_node.value());
- if (it != var_remap_.end()) {
- return AttrStmt(it->second, op->attr_key, op->value, op->body);
- }
- }
- return ret;
- }
-
- PrimExpr VisitExpr_(const CallNode* call) final {
- bool to_be_lowered =
datatype::Registry::Global()->GetTypeRegistered(call->dtype.code());
- PrimExpr expr = StmtExprMutator::VisitExpr_(call);
- call = expr.as<CallNode>();
- if (to_be_lowered) {
- auto op = call->op.as<OpNode>();
- TVM_FFI_ICHECK(op != nullptr) << "Lowering non-intrinsic Calls not
implemented";
- auto lower = datatype::GetIntrinLowerFunc(target_, op->name,
call->dtype.code());
- TVM_FFI_ICHECK(lower) << "Intrinsic lowering function for target " <<
target_
- << ", intrinsic name " << op->name << ", type "
- << static_cast<unsigned>(call->dtype.code()) << "
not found";
- return (*lower)(expr).cast<PrimExpr>();
- }
- return expr;
- }
-
-#define TVM_DEFINE_MUTATE_CUSTOM_DTYPE(OP, NodeName)
\
- PrimExpr VisitExpr_(const NodeName* op) final {
\
- auto type_code = op->dtype.code();
\
- bool to_be_lowered =
datatype::Registry::Global()->GetTypeRegistered(type_code); \
- PrimExpr expr = StmtExprMutator::VisitExpr_(op);
\
- op = expr.as<NodeName>();
\
- if (to_be_lowered) {
\
- auto lower = datatype::Get##OP##LowerFunc(target_, type_code);
\
- TVM_FFI_ICHECK(lower) << #OP " lowering function for target " << target_
<< " type " \
- << static_cast<unsigned>(type_code) << " not
found"; \
- return (*lower)(expr).cast<PrimExpr>();
\
- }
\
- return expr;
\
- }
-
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Add, AddNode);
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Sub, SubNode);
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mul, MulNode);
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Div, DivNode);
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mod, ModNode);
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Min, MinNode);
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Max, MaxNode);
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(EQ, EQNode);
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(NE, NENode);
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LT, LTNode);
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LE, LENode);
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GT, GTNode);
- TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GE, GENode);
- // Later changes may need to add more mutate functions as we support workloads with more ops.
-
-#undef TVM_DEFINE_MUTATE_CUSTOM_DTYPE
-
- private:
- std::string target_;
- // remap buffer vars
- std::unordered_map<Var, Var> var_remap_;
- std::unordered_map<Buffer, Buffer, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>
buf_remap_;
-};
-
-namespace transform {
-
-Pass LowerCustomDatatypes() {
- auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
- auto* n = f.CopyOnWrite();
- auto target = f->GetAttr<Target>(tvm::attr::kTarget);
- TVM_FFI_ICHECK(target.defined()) << "LowerCustomDatatypes: Require the
target attribute";
-
- n->body =
CustomDatatypesLowerer(target.value()->kind->name)(std::move(n->body));
- return f;
- };
- return CreatePrimFuncPass(pass_func, 0, "tirx.LowerCustomDatatypes", {});
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tirx.transform.LowerCustomDatatypes",
LowerCustomDatatypes);
-}
-
-} // namespace transform
-
-} // namespace tirx
-} // namespace tvm