This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new a400f82 [TFLite Runtime] Fix bug and re-enable RPC execution test (#5436)
a400f82 is described below
commit a400f825281f3c6f0688e8b16deea4ba12ee6bb5
Author: Michal Piszczek <[email protected]>
AuthorDate: Thu May 14 20:16:57 2020 -0700
[TFLite Runtime] Fix bug and re-enable RPC execution test (#5436)
---
src/runtime/contrib/tflite/tflite_runtime.cc | 8 +-
src/runtime/contrib/tflite/tflite_runtime.h | 3 +
src/runtime/module.cc | 2 +
tests/python/contrib/test_tflite_runtime.py | 202 ++++++++++++++++-----------
tests/scripts/task_config_build_cpu.sh | 3 +
5 files changed, 135 insertions(+), 83 deletions(-)
diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc
index 53d7754..8b34e90 100644
--- a/src/runtime/contrib/tflite/tflite_runtime.cc
+++ b/src/runtime/contrib/tflite/tflite_runtime.cc
@@ -93,8 +93,12 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) {
void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) {
const char* buffer = tflite_model_bytes.c_str();
size_t buffer_size = tflite_model_bytes.size();
+ // The buffer used to construct the model must be kept alive for
+ // dependent interpreters to be used.
+ flatBuffersBuffer_ = std::unique_ptr<char[]>(new char[buffer_size]);
+ std::memcpy(flatBuffersBuffer_.get(), buffer, buffer_size);
std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
+ tflite::FlatBufferModel::BuildFromBuffer(flatBuffersBuffer_.get(), buffer_size);
tflite::ops::builtin::BuiltinOpResolver resolver;
// Build interpreter
TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
@@ -173,5 +177,7 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx
TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = TFLiteRuntimeCreate(args[0], args[1]);
});
+
+TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate);
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h
index f61f6ee..f3e3bd9 100644
--- a/src/runtime/contrib/tflite/tflite_runtime.h
+++ b/src/runtime/contrib/tflite/tflite_runtime.h
@@ -26,6 +26,7 @@
#define TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_
#include <dlpack/dlpack.h>
+#include <tensorflow/lite/interpreter.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
@@ -93,6 +94,8 @@ class TFLiteRuntime : public ModuleNode {
*/
NDArray GetOutput(int index) const;
+ // Buffer backing the interpreter's model
+ std::unique_ptr<char[]> flatBuffersBuffer_;
// TFLite interpreter
std::unique_ptr<tflite::Interpreter> interpreter_;
// TVM context
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index be75ff2..46ef6fa 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -129,6 +129,8 @@ bool RuntimeEnabled(const std::string& target) {
f_name = "device_api.opencl";
} else if (target == "mtl" || target == "metal") {
f_name = "device_api.metal";
+ } else if (target == "tflite") {
+ f_name = "target.runtime.tflite";
} else if (target == "vulkan") {
f_name = "device_api.vulkan";
} else if (target == "stackvm") {
diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py
index 8c883b0..1b911b7 100644
--- a/tests/python/contrib/test_tflite_runtime.py
+++ b/tests/python/contrib/test_tflite_runtime.py
@@ -14,92 +14,130 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import pytest
+
import tvm
from tvm import te
import numpy as np
from tvm import rpc
from tvm.contrib import util, tflite_runtime
-# import tensorflow as tf
-# import tflite_runtime.interpreter as tflite
-
-
-def skipped_test_tflite_runtime():
-
- def create_tflite_model():
- root = tf.Module()
- root.const = tf.constant([1., 2.], tf.float32)
- root.f = tf.function(lambda x: root.const * x)
-
- input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
- concrete_func = root.f.get_concrete_function(input_signature)
- converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
- tflite_model = converter.convert()
- return tflite_model
-
-
- def check_local():
- tflite_fname = "model.tflite"
- tflite_model = create_tflite_model()
- temp = util.tempdir()
- tflite_model_path = temp.relpath(tflite_fname)
- open(tflite_model_path, 'wb').write(tflite_model)
-
- # inference via tflite interpreter python apis
- interpreter = tflite.Interpreter(model_path=tflite_model_path)
- interpreter.allocate_tensors()
- input_details = interpreter.get_input_details()
- output_details = interpreter.get_output_details()
-
- input_shape = input_details[0]['shape']
- tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
- interpreter.set_tensor(input_details[0]['index'], tflite_input)
- interpreter.invoke()
- tflite_output = interpreter.get_tensor(output_details[0]['index'])
-
- # inference via tvm tflite runtime
- with open(tflite_model_path, 'rb') as model_fin:
- runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
- runtime.set_input(0, tvm.nd.array(tflite_input))
- runtime.invoke()
- out = runtime.get_output(0)
- np.testing.assert_equal(out.asnumpy(), tflite_output)
-
-
- def check_remote():
- tflite_fname = "model.tflite"
- tflite_model = create_tflite_model()
- temp = util.tempdir()
- tflite_model_path = temp.relpath(tflite_fname)
- open(tflite_model_path, 'wb').write(tflite_model)
-
- # inference via tflite interpreter python apis
- interpreter = tflite.Interpreter(model_path=tflite_model_path)
- interpreter.allocate_tensors()
- input_details = interpreter.get_input_details()
- output_details = interpreter.get_output_details()
-
- input_shape = input_details[0]['shape']
- tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
- interpreter.set_tensor(input_details[0]['index'], tflite_input)
- interpreter.invoke()
- tflite_output = interpreter.get_tensor(output_details[0]['index'])
-
- # inference via remote tvm tflite runtime
- server = rpc.Server("localhost")
- remote = rpc.connect(server.host, server.port)
- ctx = remote.cpu(0)
- a = remote.upload(tflite_model_path)
-
- with open(tflite_model_path, 'rb') as model_fin:
- runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
- runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
- runtime.invoke()
- out = runtime.get_output(0)
- np.testing.assert_equal(out.asnumpy(), tflite_output)
-
- check_local()
- check_remote()
+
+
+def _create_tflite_model():
+ if not tvm.runtime.enabled("tflite"):
+ print("skip because tflite runtime is not enabled...")
+ return
+ if not tvm.get_global_func("tvm.tflite_runtime.create", True):
+ print("skip because tflite runtime is not enabled...")
+ return
+
+ try:
+ import tensorflow as tf
+ except ImportError:
+ print('skip because tensorflow not installed...')
+ return
+
+ root = tf.Module()
+ root.const = tf.constant([1., 2.], tf.float32)
+ root.f = tf.function(lambda x: root.const * x)
+
+ input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
+ concrete_func = root.f.get_concrete_function(input_signature)
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+ tflite_model = converter.convert()
+ return tflite_model
+
+
+@pytest.mark.skip('skip because accessing output tensor is flakey')
+def test_local():
+ if not tvm.runtime.enabled("tflite"):
+ print("skip because tflite runtime is not enabled...")
+ return
+ if not tvm.get_global_func("tvm.tflite_runtime.create", True):
+ print("skip because tflite runtime is not enabled...")
+ return
+
+ try:
+ import tensorflow as tf
+ except ImportError:
+ print('skip because tensorflow not installed...')
+ return
+
+ tflite_fname = "model.tflite"
+ tflite_model = _create_tflite_model()
+ temp = util.tempdir()
+ tflite_model_path = temp.relpath(tflite_fname)
+ open(tflite_model_path, 'wb').write(tflite_model)
+
+ # inference via tflite interpreter python apis
+ interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ input_shape = input_details[0]['shape']
+ tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+ interpreter.set_tensor(input_details[0]['index'], tflite_input)
+ interpreter.invoke()
+ tflite_output = interpreter.get_tensor(output_details[0]['index'])
+
+ # inference via tvm tflite runtime
+ with open(tflite_model_path, 'rb') as model_fin:
+ runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
+ runtime.set_input(0, tvm.nd.array(tflite_input))
+ runtime.invoke()
+ out = runtime.get_output(0)
+ np.testing.assert_equal(out.asnumpy(), tflite_output)
+
+
+def test_remote():
+ if not tvm.runtime.enabled("tflite"):
+ print("skip because tflite runtime is not enabled...")
+ return
+ if not tvm.get_global_func("tvm.tflite_runtime.create", True):
+ print("skip because tflite runtime is not enabled...")
+ return
+
+ try:
+ import tensorflow as tf
+ except ImportError:
+ print('skip because tensorflow not installed...')
+ return
+
+ tflite_fname = "model.tflite"
+ tflite_model = _create_tflite_model()
+ temp = util.tempdir()
+ tflite_model_path = temp.relpath(tflite_fname)
+ open(tflite_model_path, 'wb').write(tflite_model)
+
+ # inference via tflite interpreter python apis
+ interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ input_shape = input_details[0]['shape']
+ tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+ interpreter.set_tensor(input_details[0]['index'], tflite_input)
+ interpreter.invoke()
+ tflite_output = interpreter.get_tensor(output_details[0]['index'])
+
+ # inference via remote tvm tflite runtime
+ server = rpc.Server("localhost")
+ remote = rpc.connect(server.host, server.port)
+ ctx = remote.cpu(0)
+ a = remote.upload(tflite_model_path)
+
+ with open(tflite_model_path, 'rb') as model_fin:
+ runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
+ runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
+ runtime.invoke()
+ out = runtime.get_output(0)
+ np.testing.assert_equal(out.asnumpy(), tflite_output)
+
+ server.terminate()
+
if __name__ == "__main__":
- # skipped_test_tflite_runtime()
- pass
+ test_local()
+ test_remote()
diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh
index 9c1cf28..ce545bd 100755
--- a/tests/scripts/task_config_build_cpu.sh
+++ b/tests/scripts/task_config_build_cpu.sh
@@ -38,3 +38,6 @@ echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake
echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake
echo set\(USE_VTA_TSIM ON\) >> config.cmake
echo set\(USE_VTA_FSIM ON\) >> config.cmake
+echo set\(USE_TFLITE ON\) >> config.cmake
+echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake
+echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake