gussmith23 commented on a change in pull request #5812:
URL: https://github.com/apache/incubator-tvm/pull/5812#discussion_r473443975
##########
File path: python/tvm/target/datatype.py
##########
@@ -135,8 +155,40 @@ def lower(op):
dtype = "uint" + str(t.bits)
if t.lanes > 1:
dtype += "x" + str(t.lanes)
- if isinstance(op, (_Cast, _FloatImm)):
- return tvm.tir.call_pure_extern(dtype, extern_func_name, op.value)
- return tvm.tir.call_pure_extern(dtype, extern_func_name, op.a, op.b)
+ if isinstance(op, _Cast):
+ src_bits = bit_length(op.value.dtype)
+ return call_pure_extern(dtype, extern_func_map[(src_bits,
t.bits)], op.value)
+ if isinstance(op, _FloatImm):
+ return call_pure_extern(dtype, extern_func_map[t.bits], op.value)
+ if isinstance(op, _Call):
+ return call_pure_extern(dtype, extern_func_map[t.bits], *op.args)
+ if isinstance(op, _BinaryOpExpr):
+ return call_pure_extern(dtype, extern_func_map[t.bits], op.a, op.b)
+
+ raise RuntimeError(f"lowering unsupported op: {op.astext()}")
return lower
+
+def bit_length(type_str):
+ t = DataType(type_str)
+ return t.bits
+
+def lower_ite(ite_intrin):
Review comment:
- [ ] @gussmith23 Document `lower_ite`
##########
File path: 3rdparty/posit/posit-wrapper.cc
##########
@@ -0,0 +1,211 @@
+#include <tvm/runtime/c_runtime_api.h>
+
+#include <cstdint>
+
+#include "universal/posit/posit.hpp"
+// must go after posit.hpp
+#include "universal/posit/math/exponent.hpp"
+#include "universal/posit/math/hyperbolic.hpp"
+#include "universal/posit/math/logarithm.hpp"
+#include "universal/posit/math/sqrt.hpp"
Review comment:
Interesting -- include order shouldn't matter. Are we using the Universal
library correctly?
##########
File path: python/tvm/target/datatype.py
##########
@@ -94,33 +100,47 @@ class name (e.g. Add, LE, Cast).
target : str
The name of codegen target.
- type_name : str
+ src_type_name : str
The name of the custom datatype, e.g. posit (but not custom[posit]8).
Review comment:
Update the example in this docstring to `posites2`.
##########
File path: python/tvm/target/datatype.py
##########
@@ -94,33 +100,47 @@ class name (e.g. Add, LE, Cast).
target : str
The name of codegen target.
- type_name : str
+ src_type_name : str
The name of the custom datatype, e.g. posit (but not custom[posit]8).
- src_type_name : str
- If op_name is "Cast", then this should be set to the source datatype of
+ dest_type_name : str
+ If op_name is "Cast", then this is required and should be set to the
dest datatype of
the argument to the Cast. If op_name is not "Cast", this is unused.
+
+ intrinsic_name : str
+ If op_name is "Call" and intrinsic_name is not None, then we assume the
+ op is a Call to an Intrinsic, and intrinsic_name is the intrinsic's
+ name.
"""
if op_name == "Cast":
- assert src_type_name is not None
+ assert dest_type_name is not None
lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "."
\
- + type_name + "." + src_type_name
+ + dest_type_name + "." + src_type_name
+ elif op_name == "Call" and intrinsic_name is not None:
+ lower_func_name = "tvm.datatype.lower." + target + "." + op_name \
+ + ".intrin." + intrinsic_name + "." + src_type_name
else:
lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "."
\
- + type_name
+ + src_type_name
tvm._ffi.register_func(lower_func_name, lower_func)
+# TODO(gus) could probably make this a decorator if i want
Review comment:
- [ ] @gussmith23 Document
##########
File path: python/tvm/relay/frontend/change_datatype.py
##########
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""Change Datatype Pass"""
+from ..function import Function
+from ..expr_functor import ExprMutator
+from ..transform.transform import function_pass
+from ..expr import var, bind
+
+
+# TODO(@gussmith23) what's the right opt level here?
+@function_pass(opt_level=0)
+class ChangeDatatype(ExprMutator):
+ """Mutator for changing the datatype of Relay programs.
+
+ Example usage:
+ ```python
+ from tvm.relay.testing.inception_v3 import get_workload
+ expr, params = get_workload()
+
+ def change_dtype(src, dst, expr, params):
+ cdtype = ChangeDatatype(src, dst)
+ expr = cdtype.visit(expr)
+ expr = relay.ir_pass.infer_type(expr)
+ params = dict((p, tvm.nd.array(params[p].asnumpy().astype(dst))) for p
in params)
+ return expr, params
Review comment:
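@gussmith23 update this example code to the current API (e.g. use
`relay.transform.InferType` instead of `relay.ir_pass.infer_type`)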
##########
File path: src/target/datatype/registry.h
##########
@@ -60,7 +62,7 @@ class Registry {
* manually allocated by the user, and the user must ensure that no two
custom types share the
* same code. Generally, this should be straightforward, as the user will be
manually registering
* all of their custom types.
- * \param type_name The name of the type, e.g. "bfloat"
+ * \param type_name The name of the type, e.g. "posit"
Review comment:
Change the example in this comment to `posites2`.
##########
File path: 3rdparty/posit/posit-wrapper.cc
##########
@@ -0,0 +1,211 @@
+#include <tvm/runtime/c_runtime_api.h>
+
+#include <cstdint>
+
+#include "universal/posit/posit.hpp"
+// must go after posit.hpp
+#include "universal/posit/math/exponent.hpp"
+#include "universal/posit/math/hyperbolic.hpp"
+#include "universal/posit/math/logarithm.hpp"
+#include "universal/posit/math/sqrt.hpp"
+
+TVM_DLL sw::unum::posit<8, 2> Uint8ToPosit8es2(uint8_t in) {
+ sw::unum::bitblock<8> bb;
+ bb = static_cast<unsigned long long>(in);
+ return sw::unum::posit<8, 2>().set(bb);
+}
+
+extern "C" {
+TVM_DLL uint8_t RawPosit8es2(uint8_t in) { return in; }
+
+TVM_DLL uint8_t Posit8es2toUint8(sw::unum::posit<8, 2> in) {
+ return static_cast<uint8_t>(in.get().to_ullong());
+}
+
+TVM_DLL float Posit8es2ToFloat(uint8_t in) { return
Uint8ToPosit8es2(in).operator float(); }
+
+TVM_DLL uint8_t FloatToPosit8es2(float in) {
+ auto posit = sw::unum::posit<8, 2>(in);
+ return Posit8es2toUint8(posit);
+}
+
+// TODO(gus) how wide should the input be?
+TVM_DLL uint8_t IntToPosit8es2(int in) { return
Posit8es2toUint8(sw::unum::posit<8, 2>(in)); }
+
+TVM_DLL uint8_t Posit8es2Add(uint8_t a, uint8_t b) {
+ return Posit8es2toUint8(Uint8ToPosit8es2(a) + Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Sub(uint8_t a, uint8_t b) {
+ return Posit8es2toUint8(Uint8ToPosit8es2(a) - Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Mul(uint8_t a, uint8_t b) {
+ return Posit8es2toUint8(Uint8ToPosit8es2(a) * Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Div(uint8_t a, uint8_t b) {
+ return Posit8es2toUint8(Uint8ToPosit8es2(a) / Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Max(uint8_t a, uint8_t b) {
+ auto a_p = Uint8ToPosit8es2(a);
+ auto b_p = Uint8ToPosit8es2(b);
+ return Posit8es2toUint8(a_p > b_p ? a_p : b_p);
+}
+
+TVM_DLL uint8_t Posit8es2Sqrt(uint8_t a) {
+ return Posit8es2toUint8(sw::unum::sqrt(Uint8ToPosit8es2(a)));
+}
+
+TVM_DLL uint8_t Posit8es2Exp(uint8_t a) {
+ return Posit8es2toUint8(sw::unum::exp(Uint8ToPosit8es2(a)));
+}
+
+TVM_DLL uint8_t Posit8es2Log(uint8_t a) {
+ return Posit8es2toUint8(sw::unum::log(Uint8ToPosit8es2(a)));
+}
+
+TVM_DLL uint8_t Posit8es2Sigmoid(uint8_t a) {
+ auto posit_one = sw::unum::posit<8, 2>(1);
+ return Posit8es2toUint8(posit_one / (sw::unum::exp(-Uint8ToPosit8es2(a)) +
posit_one));
+}
+
+TVM_DLL uint8_t Posit8es2Tanh(uint8_t a) {
+ return Posit8es2toUint8(sw::unum::tanh(Uint8ToPosit8es2(a)));
+}
+}
+
+TVM_DLL sw::unum::posit<16, 2> Uint16ToPosit16es2(uint16_t in) {
+ sw::unum::bitblock<16> bb;
+ bb = static_cast<unsigned long long>(in);
+ return sw::unum::posit<16, 2>().set(bb);
+}
+
+extern "C" {
+TVM_DLL uint16_t RawPosit16es2(uint16_t in) { return in; }
+
+TVM_DLL uint16_t Posit16es2toUint16(sw::unum::posit<16, 2> in) {
+ return static_cast<uint16_t>(in.get().to_ullong());
+}
+
+TVM_DLL float Posit16es2ToFloat(uint16_t in) { return
Uint16ToPosit16es2(in).operator float(); }
+
+TVM_DLL uint16_t FloatToPosit16es2(float in) {
+ auto posit = sw::unum::posit<16, 2>(in);
+ return Posit16es2toUint16(posit);
+}
+
+// TODO(gus) how wide should the input be?
+TVM_DLL uint16_t IntToPosit16es2(int in) { return
Posit16es2toUint16(sw::unum::posit<16, 2>(in)); }
+
+TVM_DLL uint16_t Posit16es2Add(uint16_t a, uint16_t b) {
+ return Posit16es2toUint16(Uint16ToPosit16es2(a) + Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Sub(uint16_t a, uint16_t b) {
+ return Posit16es2toUint16(Uint16ToPosit16es2(a) - Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Mul(uint16_t a, uint16_t b) {
+ return Posit16es2toUint16(Uint16ToPosit16es2(a) * Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Div(uint16_t a, uint16_t b) {
+ return Posit16es2toUint16(Uint16ToPosit16es2(a) / Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Max(uint16_t a, uint16_t b) {
+ auto a_p = Uint16ToPosit16es2(a);
+ auto b_p = Uint16ToPosit16es2(b);
+ return Posit16es2toUint16(a_p > b_p ? a_p : b_p);
+}
+
+TVM_DLL uint16_t Posit16es2Sqrt(uint16_t a) {
+ return Posit16es2toUint16(sw::unum::sqrt(Uint16ToPosit16es2(a)));
+}
+
+TVM_DLL uint16_t Posit16es2Exp(uint16_t a) {
+ return Posit16es2toUint16(sw::unum::exp(Uint16ToPosit16es2(a)));
+}
+
+TVM_DLL uint16_t Posit16es2Log(uint16_t a) {
+ return Posit16es2toUint16(sw::unum::log(Uint16ToPosit16es2(a)));
+}
+
+TVM_DLL uint16_t Posit16es2Sigmoid(uint16_t a) {
+ auto posit_one = sw::unum::posit<16, 2>(1);
+ return Posit16es2toUint16(posit_one / (sw::unum::exp(-Uint16ToPosit16es2(a))
+ posit_one));
+}
+
+TVM_DLL uint16_t Posit16es2Tanh(uint16_t a) {
+ return Posit16es2toUint16(sw::unum::tanh(Uint16ToPosit16es2(a)));
+}
+}
+
+TVM_DLL sw::unum::posit<32, 2> Uint32ToPosit32es2(uint32_t in) {
+ sw::unum::bitblock<32> bb;
+ bb = static_cast<unsigned long long>(in);
+ return sw::unum::posit<32, 2>().set(bb);
+}
+
+extern "C" {
+TVM_DLL uint32_t RawPosit32es2(uint32_t in) { return in; }
Review comment:
What are these `RawPosit*` functions for?
##########
File path: tests/python/unittest/test_custom_datatypes.py
##########
@@ -0,0 +1,405 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities for changing datatypes of models."""
+import tvm
+import tvm.topi.testing
+import numpy as np
+import pytest
+from numpy.random import MT19937, RandomState, SeedSequence
+from tvm import relay
+from tvm.relay.testing.inception_v3 import get_workload as get_inception
+from tvm.relay.testing.resnet import get_workload as get_resnet
+from tvm.relay.testing.layers import batch_norm_infer
+from tvm.relay.testing.mobilenet import get_workload as get_mobilenet
+from tvm.target.datatype import register, register_min_func, register_op,
create_lower_func, lower_ite, lower_call_pure_extern
+from tvm.tir.op import call_pure_extern
+from nose.tools import nottest
+
+# we use a random seed to generate input_data
+# to guarantee stable tests
+rs = RandomState(MT19937(SeedSequence(123456789)))
+
+def convert_ndarray(dst_dtype, *args, **kwargs):
+ """Converts NDArray(s) into the specified datatype"""
+ def convert(array):
+ x = relay.var('x', shape=array.shape, dtype=str(array.dtype))
+ cast = relay.Function([x], x.astype(dst_dtype))
+ with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+ return relay.create_executor('graph').evaluate(cast)(array)
+
+ return (tuple([convert(x) for x in args]), {k: convert(v) for (k, v) in
kwargs.items()})
+
+
+def change_dtype(src, dst, module, params):
+ module = relay.frontend.ChangeDatatype(src, dst)(module)
+ module = relay.transform.InferType()(module)
+ params = dict((p, convert_ndarray(dst, params[p])) for p in params)
+ return module, params
+
+def compare(module, input, src_dtype, dst_dtype, rtol, atol, params = {},
target='llvm'):
+ module = relay.transform.SimplifyInference()(module)
+ ex = relay.create_executor("graph", mod=module)
+
+ correct = ex.evaluate()(*input, **params)
+
+ module, _ = change_dtype(src_dtype, dst_dtype, module, [])
+ ex = relay.create_executor("graph", mod=module, target=target)
+ # converts all inputs to dst_dtype
+ x_converted, x_params_converted = convert_ndarray(dst_dtype, *input,
**params)
+
+ # Vectorization is not implemented with custom datatypes
+ with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
Review comment:
- [ ] @gussmith23 document, wherever we document these requirements, that
vectorization must be disabled when using custom datatypes
##########
File path: python/tvm/target/datatype.py
##########
@@ -135,8 +155,40 @@ def lower(op):
dtype = "uint" + str(t.bits)
if t.lanes > 1:
dtype += "x" + str(t.lanes)
- if isinstance(op, (_Cast, _FloatImm)):
- return tvm.tir.call_pure_extern(dtype, extern_func_name, op.value)
- return tvm.tir.call_pure_extern(dtype, extern_func_name, op.a, op.b)
+ if isinstance(op, _Cast):
+ src_bits = bit_length(op.value.dtype)
+ return call_pure_extern(dtype, extern_func_map[(src_bits,
t.bits)], op.value)
+ if isinstance(op, _FloatImm):
+ return call_pure_extern(dtype, extern_func_map[t.bits], op.value)
+ if isinstance(op, _Call):
+ return call_pure_extern(dtype, extern_func_map[t.bits], *op.args)
+ if isinstance(op, _BinaryOpExpr):
+ return call_pure_extern(dtype, extern_func_map[t.bits], op.a, op.b)
+
+ raise RuntimeError(f"lowering unsupported op: {op.astext()}")
return lower
+
+def bit_length(type_str):
+ t = DataType(type_str)
+ return t.bits
+
+def lower_ite(ite_intrin):
+ dtype = ite_intrin.dtype
+ t = tvm.DataType(dtype)
+ assert get_type_registered(t.type_code)
+ dtype = "uint" + str(t.bits)
+ if t.lanes > 1:
+ dtype += "x" + str(t.lanes)
+ return call_intrin(dtype, "tir.if_then_else", convert(ite_intrin.args[0]),
+ convert(ite_intrin.args[1]),
+ convert(ite_intrin.args[2]))
+
+def lower_call_pure_extern(op):
Review comment:
- [ ] @gussmith23 Document `lower_call_pure_extern`
##########
File path: tests/python/unittest/test_custom_datatypes.py
##########
@@ -0,0 +1,405 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities for changing datatypes of models."""
+import tvm
+import tvm.topi.testing
+import numpy as np
+import pytest
+from numpy.random import MT19937, RandomState, SeedSequence
+from tvm import relay
+from tvm.relay.testing.inception_v3 import get_workload as get_inception
+from tvm.relay.testing.resnet import get_workload as get_resnet
+from tvm.relay.testing.layers import batch_norm_infer
+from tvm.relay.testing.mobilenet import get_workload as get_mobilenet
+from tvm.target.datatype import register, register_min_func, register_op,
create_lower_func, lower_ite, lower_call_pure_extern
+from tvm.tir.op import call_pure_extern
+from nose.tools import nottest
+
+# we use a random seed to generate input_data
+# to guarantee stable tests
+rs = RandomState(MT19937(SeedSequence(123456789)))
+
+def convert_ndarray(dst_dtype, *args, **kwargs):
+ """Converts NDArray(s) into the specified datatype"""
+ def convert(array):
+ x = relay.var('x', shape=array.shape, dtype=str(array.dtype))
+ cast = relay.Function([x], x.astype(dst_dtype))
+ with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+ return relay.create_executor('graph').evaluate(cast)(array)
+
+ return (tuple([convert(x) for x in args]), {k: convert(v) for (k, v) in
kwargs.items()})
+
+
+def change_dtype(src, dst, module, params):
+ module = relay.frontend.ChangeDatatype(src, dst)(module)
+ module = relay.transform.InferType()(module)
+ params = dict((p, convert_ndarray(dst, params[p])) for p in params)
+ return module, params
+
+def compare(module, input, src_dtype, dst_dtype, rtol, atol, params = {},
target='llvm'):
+ module = relay.transform.SimplifyInference()(module)
Review comment:
- [ ] @gussmith23 document here that `SimplifyInference` must be used.
- [ ] @gussmith23 (more importantly) document elsewhere that
`SimplifyInference` must be used.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]