mkroening opened a new issue #9927:
URL: https://github.com/apache/tvm/issues/9927
### Expected behavior
Compilation succeeds when running
```console
tvmc compile --target llvm model.onnx
```
on an ONNX model containing an `Einsum` with large input dimensions (ONNX Runtime runs the same model without problems).
Model:
```onnx
ir_version: 8
graph {
node {
input: "a"
output: "b"
op_type: "Einsum"
attribute {
name: "equation"
s: "i->"
type: STRING
}
}
name: "test-model"
input {
name: "a"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 8301
}
}
}
}
}
output {
name: "b"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 15
}
```
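For reference, the equation `i->` sums the single input over its only axis, so `b` should be the scalar sum of `a`. Below is a minimal NumPy check of the expected result. It assumes NumPy is installed (the `onnx` package depends on it) and uses an all-ones input so the float32 sum is exact:

```python
import numpy as np

m = 8301  # input size used in the model above
a = np.ones(m, dtype=np.float32)  # all ones: every partial sum is exact in float32

# "i->": axis i does not appear in the output, so it is summed away
b = np.einsum("i->", a)

# b is 0-dimensional, matching the empty output shape declared for "b"
print(b.shape, b)
```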
### Actual behavior
On my machine, starting with a size of `8301`, I get segmentation faults
_sometimes_, increasingly frequently with higher values.
This back trace makes it look like a stack overflow (I used a value of
`88301`, since the problem appears less frequently in `gdb`):
```
# [..]
Thread 1 "python3" received signal SIGSEGV, Segmentation fault.
0x00007fffdf50beb3 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
(gdb) bt
#0 0x00007fffdf50beb3 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#1 0x00007fffdf39ea4a in tvm::te::(anonymous
namespace)::ComputeVerifier::VisitExpr(tvm::PrimExpr const&) ()
from /home/mkroening/Development/tvm/install/lib/libtvm.so
#2 0x00007fffdf50beb6 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#3 0x00007fffdf39ea4a in tvm::te::(anonymous
namespace)::ComputeVerifier::VisitExpr(tvm::PrimExpr const&) ()
from /home/mkroening/Development/tvm/install/lib/libtvm.so
#4 0x00007fffdf50beb6 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#5 0x00007fffdf39ea4a in tvm::te::(anonymous
namespace)::ComputeVerifier::VisitExpr(tvm::PrimExpr const&) ()
from /home/mkroening/Development/tvm/install/lib/libtvm.so
# [..]
#149324 0x00007fffdf50beb6 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149325 0x00007fffdf39b585 in tvm::te::(anonymous
namespace)::ComputeVerifier::Run() () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149326 0x00007fffdf39bc89 in tvm::te::ComputeOp::ComputeOp(std::string,
std::string, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::ObjectRef,
void, void>, tvm::runtime::Array<tvm::tir::IterVar, void>,
tvm::runtime::Array<tvm::PrimExpr, void>) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149327 0x00007fffdf39c9b1 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149328 0x00007fffdf9848ef in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149329 0x00007fffdf95bb52 in std::_Function_handler<void
(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*),
tvm::topi::__mk_TVM26::{lambda(tvm::runtime::TVMArgs,
tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&,
tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149330 0x00007fffe0356449 in TVMFuncCall () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149331 0x00007fffdd8c782c in __pyx_f_3tvm_4_ffi_4_cy3_4core_FuncCall3
(__pyx_v_nargs=2, __pyx_v_ret_tcode=0x7fffffff9868,
__pyx_v_ret_val=0x7fffffff9870, __pyx_v_args=0x7fffdb4929c0,
__pyx_v_chandle=0x1269500) at tvm/_ffi/_cython/core.cpp:7416
#149332 __pyx_f_3tvm_4_ffi_4_cy3_4core_FuncCall (__pyx_v_chandle=0x1269500,
__pyx_v_args=__pyx_v_args@entry=0x7fffdb4929c0,
__pyx_v_ret_val=__pyx_v_ret_val@entry=0x7fffffff9870,
__pyx_v_ret_tcode=__pyx_v_ret_tcode@entry=0x7fffffff9868) at
tvm/_ffi/_cython/core.cpp:7508
#149333 0x00007fffdd8c8044 in
__pyx_pf_3tvm_4_ffi_4_cy3_4core_14PackedFuncBase_4__call__
(__pyx_v_self=0x7fffdcc7cdb0, __pyx_v_self=0x7fffdcc7cdb0,
__pyx_v_args=0x7fffdb4929c0) at tvm/_ffi/_cython/core.cpp:8403
#149334 __pyx_pw_3tvm_4_ffi_4_cy3_4core_14PackedFuncBase_5__call__
(__pyx_v_self=0x7fffdcc7cdb0, __pyx_args=0x7fffdb4929c0, __pyx_kwds=<optimized
out>) at tvm/_ffi/_cython/core.cpp:8367
#149335 0x00000000005f6a46 in _PyObject_MakeTpCall ()
#149336 0x0000000000570a1f in _PyEval_EvalFrameDefault ()
#149337 0x00000000005696da in _PyEval_EvalCodeWithName ()
#149338 0x00000000005f6403 in _PyFunction_Vectorcall ()
#149339 0x00000000005f55f2 in PyObject_Call ()
#149340 0x000000000056cbfb in _PyEval_EvalFrameDefault ()
#149341 0x00000000005696da in _PyEval_EvalCodeWithName ()
#149342 0x00000000004ea9d2 in _PyFunction_FastCallDict ()
#149343 0x00007fffdd8c613b in __Pyx_PyObject_Call (kw=0x0,
arg=0x7fffdb492380, func=0x7fffdb4bd9d0) at tvm/_ffi/_cython/core.cpp:13483
#149344 __pyx_f_3tvm_4_ffi_4_cy3_4core_tvm_callback (__pyx_v_args=<optimized
out>, __pyx_v_type_codes=<optimized out>, __pyx_v_num_args=<optimized out>,
__pyx_v_ret=0x7fffffffa000, __pyx_v_fhandle=<optimized out>) at
tvm/_ffi/_cython/core.cpp:4542
#149345 0x00007fffe0351737 in std::_Function_handler<void
(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*),
TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs,
tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&,
tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149346 0x00007fffe01f5ce6 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149347 0x00007fffe01f5ec8 in std::_Function_handler<void
(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*),
tvm::relay::__mk_TVM6::{lambda(tvm::runtime::TVMArgs,
tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&,
tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149348 0x00007fffe0356449 in TVMFuncCall () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149349 0x00007fffdd8c7bb6 in __pyx_f_3tvm_4_ffi_4_cy3_4core_FuncCall
(__pyx_v_chandle=0x120b020, __pyx_v_args=__pyx_v_args@entry=0x7fffe30bc0e0,
__pyx_v_ret_val=__pyx_v_ret_val@entry=0x7fffffffa290,
__pyx_v_ret_tcode=__pyx_v_ret_tcode@entry=0x7fffffffa288) at
tvm/_ffi/_cython/core.cpp:7619
#149350 0x00007fffdd8c8044 in
__pyx_pf_3tvm_4_ffi_4_cy3_4core_14PackedFuncBase_4__call__
(__pyx_v_self=0x7fffdccd19a0, __pyx_v_self=0x7fffdccd19a0,
__pyx_v_args=0x7fffe30bc0e0) at tvm/_ffi/_cython/core.cpp:8403
#149351 __pyx_pw_3tvm_4_ffi_4_cy3_4core_14PackedFuncBase_5__call__
(__pyx_v_self=0x7fffdccd19a0, __pyx_args=0x7fffe30bc0e0, __pyx_kwds=<optimized
out>) at tvm/_ffi/_cython/core.cpp:8367
#149352 0x00000000005f6a46 in _PyObject_MakeTpCall ()
#149353 0x0000000000570612 in _PyEval_EvalFrameDefault ()
#149354 0x000000000050ac5e in ?? ()
#149355 0x00000000005703e6 in _PyEval_EvalFrameDefault ()
#149356 0x00000000005696da in _PyEval_EvalCodeWithName ()
#149357 0x00000000005f6403 in _PyFunction_Vectorcall ()
#149358 0x000000000056b3fe in _PyEval_EvalFrameDefault ()
#149359 0x00000000006b398c in ?? ()
#149360 0x00000000004ea887 in _PyFunction_FastCallDict ()
#149361 0x00007fffdd8c613b in __Pyx_PyObject_Call (kw=0x0,
arg=0x7fffdb490a80, func=0x7fffdd429f70) at tvm/_ffi/_cython/core.cpp:13483
#149362 __pyx_f_3tvm_4_ffi_4_cy3_4core_tvm_callback (__pyx_v_args=<optimized
out>, __pyx_v_type_codes=<optimized out>, __pyx_v_num_args=<optimized out>,
__pyx_v_ret=0x7fffffffab40, __pyx_v_fhandle=<optimized out>) at
tvm/_ffi/_cython/core.cpp:4542
#149363 0x00007fffe0351737 in std::_Function_handler<void
(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*),
TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs,
tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&,
tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149364 0x00007fffe0149db8 in
tvm::relay::tec::ScheduleBuilder::VisitExpr_(tvm::relay::CallNode const*) ()
from /home/mkroening/Development/tvm/install/lib/libtvm.so
#149365 0x00007fffe0141fb4 in
tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void>
(tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&,
tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void>
(tvm::RelayExpr const&)>*)#6}::_FUN(tvm::runtime::ObjectRef const&,
tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void>
(tvm::RelayExpr const&)>*) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149366 0x00007fffe014edd1 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149367 0x00007fffe014b02d in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149368 0x00007fffe0140e30 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149369 0x00007fffe0139f3d in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149370 0x00007fffe013aecb in
tvm::relay::tec::TECompilerImpl::Lower(tvm::relay::tec::CCacheKey const&,
tvm::runtime::String) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149371 0x00007fffe0136948 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149372 0x00007fffe01387ab in
tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode
const*) () from /home/mkroening/Development/tvm/install/lib/libtvm.so
#149373 0x00007fffdff4514a in
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode
const*) () from /home/mkroening/Development/tvm/install/lib/libtvm.so
#149374 0x00007fffdffed1a4 in tvm::relay::ExprFunctor<tvm::RelayExpr
(tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&,
tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr
const&)>*)#6}::_FUN(tvm::runtime::ObjectRef const&,
tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149375 0x00007fffe01d8773 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149376 0x00007fffe01d6173 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149377 0x00007fffdff430ca in
tvm::relay::transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode
const*) [clone .localalias] () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149378 0x00007fffe0132555 in
tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode
const*) () from /home/mkroening/Development/tvm/install/lib/libtvm.so
#149379 0x00007fffdff46cca in
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode
const*) () from /home/mkroening/Development/tvm/install/lib/libtvm.so
#149380 0x00007fffdffed184 in tvm::relay::ExprFunctor<tvm::RelayExpr
(tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&,
tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr
const&)>*)#5}::_FUN(tvm::runtime::ObjectRef const&,
tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149381 0x00007fffe01d8773 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149382 0x00007fffe012a24e in std::_Function_handler<void
(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*),
tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function,
tvm::IRModule,
tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::tec::LowerTensorExpr(tvm::runtime::String
const&, tvm::relay::tec::TECompiler, std::function<void (tvm::BaseFunc)>,
tvm::VirtualDevice)::{lambda(tvm::relay::Function, tvm::IRModule,
tvm::transform::PassContext)#1}>(tvm::relay::tec::LowerTensorExpr(tvm::runtime::String
const&, tvm::relay::tec::TECompiler, std::function<void (tvm::BaseFunc)>,
tvm::VirtualDevice)::{lambda(tvm::relay::Function, tvm::IRModule,
tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&,
tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149383 0x00007fffe01fc412 in
tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149384 0x00007fffdf15d845 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149385 0x00007fffdf15efb0 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149386 0x00007fffe0125629 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149387 0x00007fffe0129d84 in std::_Function_handler<void
(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*),
tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule,
tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::tec::LowerTEPass(tvm::runtime::String
const&, std::function<void (tvm::BaseFunc)>,
tvm::VirtualDevice)::{lambda(tvm::IRModule,
tvm::transform::PassContext)#1}>(tvm::relay::tec::LowerTEPass(tvm::runtime::String
const&, std::function<void (tvm::BaseFunc)>,
tvm::VirtualDevice)::{lambda(tvm::IRModule,
tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&,
tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149388 0x00007fffdf16124e in
tvm::transform::ModulePassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149389 0x00007fffdf15d845 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149390 0x00007fffdf160a02 in
tvm::transform::SequentialNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149391 0x00007fffdf15d845 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149392 0x00007fffdf15efb0 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149393 0x00007fffe00f3c7d in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149394 0x00007fffe00f59ae in std::_Function_handler<void
(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*),
tvm::relay::backend::GraphExecutorCodegenModule::GetFunction(std::string
const&, tvm::runtime::ObjectPtr<tvm::runtime::Object>
const&)::{lambda(tvm::runtime::TVMArgs,
tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&,
tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149395 0x00007fffe00cff91 in ?? () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149396 0x00007fffe00d29fd in std::_Function_handler<void
(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*),
tvm::relay::backend::RelayBuildModule::GetFunction(std::string const&,
tvm::runtime::ObjectPtr<tvm::runtime::Object>
const&)::{lambda(tvm::runtime::TVMArgs,
tvm::runtime::TVMRetValue*)#3}>::_M_invoke(std::_Any_data const&,
tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149397 0x00007fffe0356449 in TVMFuncCall () from
/home/mkroening/Development/tvm/install/lib/libtvm.so
#149398 0x00007fffdd8c7bb6 in __pyx_f_3tvm_4_ffi_4_cy3_4core_FuncCall
(__pyx_v_chandle=0x18d7ae0, __pyx_v_args=__pyx_v_args@entry=0x7fffdb486400,
__pyx_v_ret_val=__pyx_v_ret_val@entry=0x7fffffffcaa0,
__pyx_v_ret_tcode=__pyx_v_ret_tcode@entry=0x7fffffffca98) at
tvm/_ffi/_cython/core.cpp:7619
#149399 0x00007fffdd8c8044 in
__pyx_pf_3tvm_4_ffi_4_cy3_4core_14PackedFuncBase_4__call__
(__pyx_v_self=0x7fffdb4470e0, __pyx_v_self=0x7fffdb4470e0,
__pyx_v_args=0x7fffdb486400) at tvm/_ffi/_cython/core.cpp:8403
#149400 __pyx_pw_3tvm_4_ffi_4_cy3_4core_14PackedFuncBase_5__call__
(__pyx_v_self=0x7fffdb4470e0, __pyx_args=0x7fffdb486400, __pyx_kwds=<optimized
out>) at tvm/_ffi/_cython/core.cpp:8367
#149401 0x00000000005f6a46 in _PyObject_MakeTpCall ()
#149402 0x0000000000570a1f in _PyEval_EvalFrameDefault ()
#149403 0x00000000005696da in _PyEval_EvalCodeWithName ()
#149404 0x000000000050add0 in ?? ()
#149405 0x000000000056c5d1 in _PyEval_EvalFrameDefault ()
#149406 0x00000000005696da in _PyEval_EvalCodeWithName ()
#149407 0x00000000005f6403 in _PyFunction_Vectorcall ()
#149408 0x000000000056c5d1 in _PyEval_EvalFrameDefault ()
#149409 0x00000000005696da in _PyEval_EvalCodeWithName ()
#149410 0x00000000005f6403 in _PyFunction_Vectorcall ()
#149411 0x000000000056c5d1 in _PyEval_EvalFrameDefault ()
#149412 0x00000000005f6226 in _PyFunction_Vectorcall ()
#149413 0x00000000005703e6 in _PyEval_EvalFrameDefault ()
#149414 0x00000000005f6226 in _PyFunction_Vectorcall ()
#149415 0x000000000056b3fe in _PyEval_EvalFrameDefault ()
#149416 0x00000000005f6226 in _PyFunction_Vectorcall ()
#149417 0x00000000005703e6 in _PyEval_EvalFrameDefault ()
#149418 0x00000000005696da in _PyEval_EvalCodeWithName ()
#149419 0x000000000068db17 in PyEval_EvalCode ()
#149420 0x0000000000600f34 in ?? ()
#149421 0x00000000005c4ad0 in ?? ()
#149422 0x000000000056b3fe in _PyEval_EvalFrameDefault ()
#149423 0x00000000005696da in _PyEval_EvalCodeWithName ()
#149424 0x00000000005f6403 in _PyFunction_Vectorcall ()
#149425 0x000000000056b3fe in _PyEval_EvalFrameDefault ()
#149426 0x00000000005696da in _PyEval_EvalCodeWithName ()
#149427 0x00000000005f6403 in _PyFunction_Vectorcall ()
#149428 0x00000000005f55f2 in PyObject_Call ()
#149429 0x00000000006b7662 in ?? ()
#149430 0x00000000006b7a69 in Py_RunMain ()
#149431 0x00000000006b7c8d in Py_BytesMain ()
#149432 0x00007ffff7dee0b3 in __libc_start_main (main=0x4eed30 <main>,
argc=7, argv=0x7fffffffe298, init=<optimized out>, fini=<optimized out>,
rtld_fini=<optimized out>, stack_end=0x7fffffffe288) at ../csu/libc-start.c:308
#149433 0x00000000005fb12e in _start ()
```
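In the backtrace, about 149,000 frames alternate between `ComputeVerifier::VisitExpr` and an unnamed frame before reaching `ComputeVerifier::Run`. That pattern looks like roughly two native frames per level of expression nesting. One plausible explanation, inferred from the trace and not checked against the TOPI `einsum` source, is that the reduction is emitted as an unrolled chain `a[0] + a[1] + ... + a[n-1]`. The nesting depth of such a chain grows linearly with the input size, so any recursive visitor needs stack space proportional to `n`.

The toy Python model below reproduces that pattern. It uses plain tuples, not TVM objects, and `unrolled_sum`, `depth_recursive` and `depth_iterative` are hypothetical helpers. A recursive traversal of the chain for `n = 8301` exceeds Python's recursion limit, while the same traversal with an explicit stack completes:

```python
import sys


def unrolled_sum(n):
    """Hypothetical: build a[0] + a[1] + ... + a[n-1] as a left-nested add chain.

    Nodes are plain tuples: ("load", i) or ("add", lhs, rhs).
    """
    expr = ("load", 0)
    for i in range(1, n):
        expr = ("add", expr, ("load", i))
    return expr


def depth_recursive(expr):
    """Recursive visitor: one call frame per nesting level, like a recursive ExprVisitor."""
    if expr[0] == "load":
        return 1
    return 1 + max(depth_recursive(expr[1]), depth_recursive(expr[2]))


def depth_iterative(expr):
    """Same traversal using an explicit, heap-allocated stack instead of call frames."""
    max_depth = 0
    stack = [(expr, 1)]
    while stack:
        node, depth = stack.pop()
        max_depth = max(max_depth, depth)
        if node[0] == "add":
            stack.append((node[1], depth + 1))
            stack.append((node[2], depth + 1))
    return max_depth


n = 8301  # smallest size reported to crash
expr = unrolled_sum(n)
print(depth_iterative(expr))  # prints 8301: nesting depth equals n
try:
    depth_recursive(expr)
except RecursionError:
    print("recursive visitor exceeded the recursion limit of", sys.getrecursionlimit())
```

A reduction expressed as a single reduce node over an iteration axis (such as the one `te.sum` produces) has constant nesting depth regardless of `n`.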
### Environment
- Operating System: Ubuntu 20.04.3 LTS
- TVM version: f9d8c2b99615f074fd7b0ae95d04825ab443fa33
### Steps to reproduce
Create `model.onnx` with:
```python
import onnx
from onnx import helper
from onnx import TensorProto

m = 8301  # crashes start around this size; larger values crash more often
A = helper.make_tensor_value_info('a', TensorProto.FLOAT, [m])
B = helper.make_tensor_value_info('b', TensorProto.FLOAT, [])  # scalar output
Eqn = 'i->'  # sum over the only axis

node_def = helper.make_node(
    'Einsum',  # op_type
    ['a'],  # inputs
    ['b'],  # outputs
    equation=Eqn,  # attribute
)

graph_def = helper.make_graph(
    [node_def],  # nodes
    'test-model',  # name
    [A],  # inputs
    [B],  # outputs
)

# Pin opset 15 to match the model dump above
model_def = helper.make_model(graph_def, opset_imports=[helper.make_opsetid('', 15)])
print('The model is:\n{}'.format(model_def))
onnx.checker.check_model(model_def)
print('The model is checked!')
onnx.save(model_def, 'model.onnx')
```
I originally hit this with batched matrix multiplications, but reduced it to a sum over a vector.
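The reduced case keeps the essential pattern, because every output element of a batched matrix multiply is a sum over the contracted axis. The NumPy sketch below uses the hypothetical equation `bij,bjk->bik`; the equation from the original model is not given here. Its small integer-valued operands keep the float32 results exact:

```python
import numpy as np

rng = np.random.default_rng(0)
# Small integer-valued float32 operands so all sums of products are exact.
x = rng.integers(-3, 4, size=(2, 3, 4)).astype(np.float32)  # (batch, i, j)
y = rng.integers(-3, 4, size=(2, 4, 5)).astype(np.float32)  # (batch, j, k)

# Hypothetical batched matmul equation: axis j is contracted (summed away).
z = np.einsum("bij,bjk->bik", x, y)

# Each output element is an "i->"-style sum over a vector of products.
row = x[1, 2, :] * y[1, :, 3]
print(z[1, 2, 3], np.einsum("i->", row))
```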
Thanks a lot for your help! :)