gussmith23 commented on a change in pull request #5812:
URL: https://github.com/apache/incubator-tvm/pull/5812#discussion_r477643995



##########
File path: 3rdparty/posit/posit-wrapper.cc
##########
@@ -0,0 +1,211 @@
+#include <tvm/runtime/c_runtime_api.h>
+
+#include <cstdint>
+
+#include "universal/posit/posit.hpp"
+// must go after posit.hpp
+#include "universal/posit/math/exponent.hpp"
+#include "universal/posit/math/hyperbolic.hpp"
+#include "universal/posit/math/logarithm.hpp"
+#include "universal/posit/math/sqrt.hpp"
+
+// Reinterprets the bits of `in` as a posit<8, 2> bit pattern (via an 8-bit
+// bitblock) and returns the corresponding posit. Declared outside the
+// extern "C" block below; it returns a C++ class type.
+TVM_DLL sw::unum::posit<8, 2> Uint8ToPosit8es2(uint8_t in) {
+  sw::unum::bitblock<8> bb;
+  bb = static_cast<unsigned long long>(in);
+  return sw::unum::posit<8, 2>().set(bb);
+}
+
+extern "C" {
+// Returns `in` unchanged; no conversion is performed.
+// NOTE(review): presumably `in` is already a raw posit<8, 2> bit pattern (the
+// encoding Uint8ToPosit8es2 decodes), making this a named pass-through --
+// confirm with its caller.
+TVM_DLL uint8_t RawPosit8es2(uint8_t in) { return in; }
+
+// Encodes a posit<8, 2> as its raw 8-bit pattern. It has C linkage but takes
+// a C++ class by value, so only C++ code can call it.
+TVM_DLL uint8_t Posit8es2toUint8(sw::unum::posit<8, 2> in) {
+  return static_cast<uint8_t>(in.get().to_ullong());
+}
+
+// Decodes the posit<8, 2> bit pattern `in` and converts it to float.
+TVM_DLL float Posit8es2ToFloat(uint8_t in) { return Uint8ToPosit8es2(in).operator float(); }
+
+// Converts `in` to posit<8, 2> and returns the result's raw bit pattern.
+TVM_DLL uint8_t FloatToPosit8es2(float in) {
+  auto posit = sw::unum::posit<8, 2>(in);
+  return Posit8es2toUint8(posit);
+}
+
+// Converts the integer `in` to posit<8, 2> and returns its raw bit pattern.
+// TODO(gus) how wide should the input be?
+TVM_DLL uint8_t IntToPosit8es2(int in) { return Posit8es2toUint8(sw::unum::posit<8, 2>(in)); }
+
+// Binary arithmetic: decode both bit patterns, apply the posit<8, 2>
+// operator, and re-encode the result as a bit pattern.
+TVM_DLL uint8_t Posit8es2Add(uint8_t a, uint8_t b) {
+  return Posit8es2toUint8(Uint8ToPosit8es2(a) + Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Sub(uint8_t a, uint8_t b) {
+  return Posit8es2toUint8(Uint8ToPosit8es2(a) - Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Mul(uint8_t a, uint8_t b) {
+  return Posit8es2toUint8(Uint8ToPosit8es2(a) * Uint8ToPosit8es2(b));
+}
+
+TVM_DLL uint8_t Posit8es2Div(uint8_t a, uint8_t b) {
+  return Posit8es2toUint8(Uint8ToPosit8es2(a) / Uint8ToPosit8es2(b));
+}
+
+// Returns the larger of `a` and `b` under posit ordering; when `a` is not
+// greater than `b`, `b` is returned.
+TVM_DLL uint8_t Posit8es2Max(uint8_t a, uint8_t b) {
+  auto a_p = Uint8ToPosit8es2(a);
+  auto b_p = Uint8ToPosit8es2(b);
+  return Posit8es2toUint8(a_p > b_p ? a_p : b_p);
+}
+
+// Unary math: decode, apply the Universal library's sqrt/exp/log/tanh for
+// posit<8, 2>, and re-encode the result.
+TVM_DLL uint8_t Posit8es2Sqrt(uint8_t a) {
+  return Posit8es2toUint8(sw::unum::sqrt(Uint8ToPosit8es2(a)));
+}
+
+TVM_DLL uint8_t Posit8es2Exp(uint8_t a) {
+  return Posit8es2toUint8(sw::unum::exp(Uint8ToPosit8es2(a)));
+}
+
+TVM_DLL uint8_t Posit8es2Log(uint8_t a) {
+  return Posit8es2toUint8(sw::unum::log(Uint8ToPosit8es2(a)));
+}
+
+// Logistic sigmoid 1 / (exp(-a) + 1), evaluated in posit<8, 2> arithmetic.
+TVM_DLL uint8_t Posit8es2Sigmoid(uint8_t a) {
+  auto posit_one = sw::unum::posit<8, 2>(1);
+  return Posit8es2toUint8(posit_one / (sw::unum::exp(-Uint8ToPosit8es2(a)) + posit_one));
+}
+
+TVM_DLL uint8_t Posit8es2Tanh(uint8_t a) {
+  return Posit8es2toUint8(sw::unum::tanh(Uint8ToPosit8es2(a)));
+}
+}  // extern "C"
+
+// Reinterprets the bits of `in` as a posit<16, 2> bit pattern (via a 16-bit
+// bitblock) and returns the corresponding posit. Declared outside the
+// extern "C" block below; it returns a C++ class type.
+TVM_DLL sw::unum::posit<16, 2> Uint16ToPosit16es2(uint16_t in) {
+  sw::unum::bitblock<16> bb;
+  bb = static_cast<unsigned long long>(in);
+  return sw::unum::posit<16, 2>().set(bb);
+}
+
+extern "C" {
+// Returns `in` unchanged; no conversion is performed.
+// NOTE(review): presumably `in` is already a raw posit<16, 2> bit pattern (the
+// encoding Uint16ToPosit16es2 decodes), making this a named pass-through --
+// confirm with its caller.
+TVM_DLL uint16_t RawPosit16es2(uint16_t in) { return in; }
+
+// Encodes a posit<16, 2> as its raw 16-bit pattern. It has C linkage but
+// takes a C++ class by value, so only C++ code can call it.
+TVM_DLL uint16_t Posit16es2toUint16(sw::unum::posit<16, 2> in) {
+  return static_cast<uint16_t>(in.get().to_ullong());
+}
+
+// Decodes the posit<16, 2> bit pattern `in` and converts it to float.
+TVM_DLL float Posit16es2ToFloat(uint16_t in) { return Uint16ToPosit16es2(in).operator float(); }
+
+// Converts `in` to posit<16, 2> and returns the result's raw bit pattern.
+TVM_DLL uint16_t FloatToPosit16es2(float in) {
+  auto posit = sw::unum::posit<16, 2>(in);
+  return Posit16es2toUint16(posit);
+}
+
+// Converts the integer `in` to posit<16, 2> and returns its raw bit pattern.
+// TODO(gus) how wide should the input be?
+TVM_DLL uint16_t IntToPosit16es2(int in) { return Posit16es2toUint16(sw::unum::posit<16, 2>(in)); }
+
+// Binary arithmetic: decode both bit patterns, apply the posit<16, 2>
+// operator, and re-encode the result as a bit pattern.
+TVM_DLL uint16_t Posit16es2Add(uint16_t a, uint16_t b) {
+  return Posit16es2toUint16(Uint16ToPosit16es2(a) + Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Sub(uint16_t a, uint16_t b) {
+  return Posit16es2toUint16(Uint16ToPosit16es2(a) - Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Mul(uint16_t a, uint16_t b) {
+  return Posit16es2toUint16(Uint16ToPosit16es2(a) * Uint16ToPosit16es2(b));
+}
+
+TVM_DLL uint16_t Posit16es2Div(uint16_t a, uint16_t b) {
+  return Posit16es2toUint16(Uint16ToPosit16es2(a) / Uint16ToPosit16es2(b));
+}
+
+// Returns the larger of `a` and `b` under posit ordering; when `a` is not
+// greater than `b`, `b` is returned.
+TVM_DLL uint16_t Posit16es2Max(uint16_t a, uint16_t b) {
+  auto a_p = Uint16ToPosit16es2(a);
+  auto b_p = Uint16ToPosit16es2(b);
+  return Posit16es2toUint16(a_p > b_p ? a_p : b_p);
+}
+
+// Unary math: decode, apply the Universal library's sqrt/exp/log/tanh for
+// posit<16, 2>, and re-encode the result.
+TVM_DLL uint16_t Posit16es2Sqrt(uint16_t a) {
+  return Posit16es2toUint16(sw::unum::sqrt(Uint16ToPosit16es2(a)));
+}
+
+TVM_DLL uint16_t Posit16es2Exp(uint16_t a) {
+  return Posit16es2toUint16(sw::unum::exp(Uint16ToPosit16es2(a)));
+}
+
+TVM_DLL uint16_t Posit16es2Log(uint16_t a) {
+  return Posit16es2toUint16(sw::unum::log(Uint16ToPosit16es2(a)));
+}
+
+// Logistic sigmoid 1 / (exp(-a) + 1), evaluated in posit<16, 2> arithmetic.
+TVM_DLL uint16_t Posit16es2Sigmoid(uint16_t a) {
+  auto posit_one = sw::unum::posit<16, 2>(1);
+  return Posit16es2toUint16(posit_one / (sw::unum::exp(-Uint16ToPosit16es2(a)) + posit_one));
+}
+
+TVM_DLL uint16_t Posit16es2Tanh(uint16_t a) {
+  return Posit16es2toUint16(sw::unum::tanh(Uint16ToPosit16es2(a)));
+}
+}  // extern "C"
+
+// Reinterprets the bits of `in` as a posit<32, 2> bit pattern (via a 32-bit
+// bitblock) and returns the corresponding posit. Declared outside the
+// extern "C" block below; it returns a C++ class type.
+TVM_DLL sw::unum::posit<32, 2> Uint32ToPosit32es2(uint32_t in) {
+  sw::unum::bitblock<32> bb;
+  bb = static_cast<unsigned long long>(in);
+  return sw::unum::posit<32, 2>().set(bb);
+}
+
+extern "C" {
+// Returns `in` unchanged; no conversion is performed.
+// NOTE(review): presumably `in` is already a raw posit<32, 2> bit pattern (the
+// encoding Uint32ToPosit32es2 decodes), making this a named pass-through --
+// confirm with its caller and document why the identity is needed.
+TVM_DLL uint32_t RawPosit32es2(uint32_t in) { return in; }

Review comment:
       As written this is confusing: it takes a `uint32_t` and returns the
same `uint32_t` unchanged. Is there something I'm missing? It doesn't seem to
be doing anything at the moment!




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to