This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new 0ac0c25 [SYSTEMDS-263] ONNX graph importer (Python API, docs, tests)
0ac0c25 is described below
commit 0ac0c2571b39e96f7a117fd317d73443632f6f26
Author: Lukas Timpl <[email protected]>
AuthorDate: Thu May 14 23:39:04 2020 +0200
[SYSTEMDS-263] ONNX graph importer (Python API, docs, tests)
This PR implements a first poc-implementation for an ONNX importer.
It adds support for the following operators: Add, Sub, MatMul, Neg, Xor,
Or, And, Relu, Tanh, Sigmoid, Softmax, Dropout, MaxPool, Conv, If; as
well as the logic for nested sub-graphs.
AMLS project SS 2020
Closes #904.
---
.github/workflows/python.yml | 17 +-
.gitignore | 3 +
docs/Tasks.txt | 3 +-
docs/onnx-systemds-design.md | 46 --
pom.xml | 1 +
.../python/docs/source/assets/sample_graph.png | Bin 0 -> 35508 bytes
src/main/python/docs/source/index.rst | 8 +
src/main/python/docs/source/onnx_systemds.rst | 59 +++
.../python/docs/source/onnx_systemds_design.rst | 217 ++++++++++
src/main/python/systemds/__init__.py | 2 +-
src/main/python/systemds/onnx_systemds/README.md | 22 +
src/main/python/systemds/onnx_systemds/__init__.py | 14 +
src/main/python/systemds/onnx_systemds/convert.py | 53 +++
.../python/systemds/onnx_systemds/onnx_helper.py | 218 ++++++++++
.../python/systemds/onnx_systemds/operator_gen.py | 465 +++++++++++++++++++++
src/main/python/systemds/onnx_systemds/render.py | 215 ++++++++++
.../templates/graph_function.dml.jinja | 54 +++
.../onnx_systemds/templates/graph_header.dml.jinja | 22 +
.../onnx_systemds/templates/main.dml.jinja | 26 ++
.../templates/matrix_initialize.dml.jinja | 24 ++
.../onnx_systemds/templates/model_header.dml.jinja | 36 ++
.../templates/module_import.dml.jinja | 17 +
.../operators/2input_1output_operator.dml.jinja | 18 +
.../templates/operators/function_call.dml.jinja | 31 ++
.../templates/operators/if_operator.dml.jinja | 19 +
.../templates/operators/neg.dml.jinja | 18 +
.../onnx_systemds/templates/util.dml.jinja | 42 ++
src/main/python/systemds/onnx_systemds/util.py | 40 ++
src/main/python/{systemds => tests}/__init__.py | 5 -
.../python/{systemds => tests/onnx}/__init__.py | 4 -
.../dml_wrapper/simple_conv_layer_2_wrapper.dml | 27 ++
.../onnx/dml_wrapper/simple_conv_layer_wrapper.dml | 25 ++
.../dml_wrapper/simple_dropout_layer_wrapper.dml | 22 +
.../onnx/dml_wrapper/simple_if_graph_wrapper.dml | 27 ++
.../dml_wrapper/simple_mat_add_mul_sub_wrapper.dml | 24 ++
.../onnx/dml_wrapper/simple_mat_add_wrapper.dml | 24 ++
.../dml_wrapper/simple_mat_initialized_wrapper.dml | 21 +
.../dml_wrapper/simple_maxpool_layer_wrapper.dml | 22 +
.../simple_relu_tanh_sigmoid_softmax_wrapper.dml | 27 ++
.../simple_conv_layer_2_reference.out | 5 +
.../simple_conv_layer_reference.out | 25 ++
.../output_reference/simple_if_graph_reference.out | 5 +
.../simple_mat_add_mul_sub_reference.out | 4 +
.../output_reference/simple_mat_add_reference.out | 4 +
.../simple_mat_initialized_reference.out | 9 +
.../simple_maxpool_layer_reference.out | 25 ++
.../simple_relu_tanh_sigmoid_softmax_reference.out | 11 +
.../tests/onnx/test_models/model_generate.py | 388 +++++++++++++++++
src/main/python/tests/onnx/test_simple.py | 65 +++
src/main/python/tests/onnx/util.py | 84 ++++
50 files changed, 2485 insertions(+), 58 deletions(-)
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 27c12ec..156d843 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -72,9 +72,12 @@ jobs:
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{
hashFiles('src/main/python/setup.py') }}
restore-keys: |
${{ runner.os }}-pip-${{ matrix.python-version }}-
+
+ - name: Install protobuf
+ run: sudo apt-get install protobuf-compiler libprotoc-dev
- name: Install pip Dependencies
- run: pip install numpy py4j wheel scipy sklearn
+ run: pip install numpy py4j wheel scipy sklearn jinja2 onnx
- name: Build Python Package
run: |
@@ -97,3 +100,15 @@ jobs:
cd src/main/python
python -m unittest tests/lineage/*.py
echo "Exit Status: " $?
+
+ - name: Run onnx-systemds python tests
+ run: |
+ export SYSTEMDS_ROOT=$(pwd)
+ export PATH=$SYSTEMDS_ROOT/bin:$PATH
+ cd src/main/python
+ echo "Creating models"
+ python tests/onnx/test_models/model_generate.py
+ ls tests/onnx/test_models/*.onnx
+ echo "Beginning tests"
+ python -m unittest tests/onnx/*.py
+ echo "Exit Status: " $?
diff --git a/.gitignore b/.gitignore
index 87a0fcd..e986cea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,6 +50,9 @@ src/main/python/NOTICE
src/main/python/dist
src/main/python/docs/build
src/main/python/docs/source/_build
+src/main/python/tests/onnx/output_test
+src/main/python/tests/onnx/dml_output
+src/main/python/tests/onnx/test_models/*.onnx
# User configuration files
conf/SystemDS-config.xml
diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index e83d0ee..8163e9f 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -213,7 +213,8 @@ SYSTEMDS-250 Extended Slice Finding
SYSTEMDS-260 Misc Tools
* 261 Stable marriage algorithm OK
* 262 Data augmentation tool for data cleaning OK
- * 263 ONNX graph importer/exporter
+ * 263 ONNX graph importer (Python API, docs, tests) OK
+ * 264 ONNX graph exporter
SYSTEMDS-270 Compressed Matrix Blocks
* 271 Reintroduce compressed matrix blocks from SystemML OK
diff --git a/docs/onnx-systemds-design.md b/docs/onnx-systemds-design.md
deleted file mode 100644
index 9650f9c..0000000
--- a/docs/onnx-systemds-design.md
+++ /dev/null
@@ -1,46 +0,0 @@
-# onnx-systemds
-
-A tool for importing/exporting
[ONNX](https://github.com/onnx/onnx/blob/master/docs/IR.md) graphs into/from
SystemDS DML scripts.
-
-
-## Goals
-
-* Support for importing [operators of the ONNX base
definition](https://github.com/onnx/onnx/blob/master/docs/Operators.md)
-
-* Support for importing [operators defined by
ONNX-ML](https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md)
-
-* Support for exporting DML script to ONNX graphs
-
-## Limitations
-
-* Not able to support all data types / operators as they are not currently
supported by SystemDS
-
-
-
-## Suggested Implementation
-
-Since the ONNX specification includes the conditional operators
[loop](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Loop) and
[if](https://github.com/onnx/onnx/blob/master/docs/Operators.md#If), a direct
conversion from ONNX to the internal HOP might not be ideal.
-
-Hence my suggested implementation is a dedicated tool invoked from command
line which generates DML scripts. This also enables optimizations performed by
the compiler at both graph and program level.
-
-### Example Call
-
-```bash
-onnx-systemds model.onx --out model_script.dml
-```
-
-
-### Tooling
-
-* Due to the availability of a [Python
API](https://github.com/onnx/onnx/blob/master/docs/PythonAPIOverview.md) for
ONNX, I would suggest implementing the tool in Python
-* Another advantage of Python is good support for template engines e.g.
[Jinja](https://jinja.palletsprojects.com/en/2.11.x/)
-* An implementation could use templates for various operators which are then
combined into a script
-
-### Implementation Details
-
-ONNX is a [serialized
graph](https://github.com/onnx/onnx/blob/master/docs/IR.md#graphs) structured
as a sorted list of nodes that form a DAG (directed acyclic graph).
-
-1. Loading in the serialized structure
-2.
[Checking](https://github.com/onnx/onnx/blob/master/docs/PythonAPIOverview.md#checking-an-onnx-model)
model and
[converting](https://github.com/onnx/onnx/blob/master/docs/PythonAPIOverview.md#converting-version-of-an-onnx-model-within-default-domain-aionnx)
models to a common version
-3. Building a simple internal graph structure (for arbitrary operators)
-4. Generating the DML script while traversing this graph (provided information
in doc_strings and other description variables are added as comments to improve
human-readability of the generated script)
diff --git a/pom.xml b/pom.xml
index 4c25f07..6d6ab8d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -508,6 +508,7 @@
<exclude>**/*.libsvm</exclude>
<exclude>**/*.mtx</exclude>
<exclude>**/*.mtd</exclude>
+
<exclude>**/*.out</exclude>
<exclude>**/part-*</exclude>
<exclude>**/*.keep</exclude>
<exclude>**/target/**</exclude>
diff --git a/src/main/python/docs/source/assets/sample_graph.png
b/src/main/python/docs/source/assets/sample_graph.png
new file mode 100644
index 0000000..630a98d
Binary files /dev/null and
b/src/main/python/docs/source/assets/sample_graph.png differ
diff --git a/src/main/python/docs/source/index.rst
b/src/main/python/docs/source/index.rst
index cdcb0a2..5b6cf53 100644
--- a/src/main/python/docs/source/index.rst
+++ b/src/main/python/docs/source/index.rst
@@ -59,3 +59,11 @@ tensors (multi-dimensional arrays) whose first dimension may
have a heterogeneou
:caption: Central Classes
matrix.rst
+
+.. toctree::
+ :maxdepth: 1
+ :hidden:
+ :caption: onnx-systemds
+
+ onnx_systemds.rst
+ onnx_systemds_design.rst
diff --git a/src/main/python/docs/source/onnx_systemds.rst
b/src/main/python/docs/source/onnx_systemds.rst
new file mode 100644
index 0000000..4d9a4f4
--- /dev/null
+++ b/src/main/python/docs/source/onnx_systemds.rst
@@ -0,0 +1,59 @@
+.. -------------------------------------------------------------
+..
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements. See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership. The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License. You may obtain a copy of the License at
+..
+.. http://www.apache.org/licenses/LICENSE-2.0
+..
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied. See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+..
+.. -------------------------------------------------------------
+
+QuickStart
+=============
+onnx-systemds is a tool for importing/exporting onnx graphs into/from SystemDS
DML scripts.
+
+Prerequisites
+---------------
+To run onnx-systemds you need to:
+
+- install `onnx <https://github.com/onnx/onnx>`_: `Installation instructions
<https://github.com/onnx/onnx#installation>`_
+- `set up the environment
<https://github.com/apache/systemml/blob/master/bin/README.md>`_
+
+Usage
+------
+An example call from the ``src/main/python`` directory of systemds::
+
+ python -m systemds.onnx_systemds.convert
tests/onnx/test_models/simple_mat_add.onnx
+
+
+This will generate the dml script ``simple_mat_add.dml`` in the current
directory.
+
+Run Tests
+---------
+From the ``src/main/python`` directory of systemds:
+
+First, generate the test models::
+
+ python tests/onnx/test_models/model_generate.py
+
+Then you can run the tests::
+
+ python -m unittest tests/onnx/test_simple.py
+
+
+Converter
+---------
+It is also possible to invoke the converter from within python.
+
+.. autofunction:: systemds.onnx_systemds.convert.onnx2systemds
\ No newline at end of file
diff --git a/src/main/python/docs/source/onnx_systemds_design.rst
b/src/main/python/docs/source/onnx_systemds_design.rst
new file mode 100644
index 0000000..5ff05a0
--- /dev/null
+++ b/src/main/python/docs/source/onnx_systemds_design.rst
@@ -0,0 +1,217 @@
+.. -------------------------------------------------------------
+..
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements. See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership. The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License. You may obtain a copy of the License at
+..
+.. http://www.apache.org/licenses/LICENSE-2.0
+..
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied. See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+..
+.. -------------------------------------------------------------
+
+Design
+======
+
+This document describes the initial design of `onnx-systemds`.
+
+For dealing with different operator-set versions of onnx the current strategy
is to use the
+`converter provided by onnx
<https://github.com/onnx/onnx/blob/master/docs/PythonAPIOverview.md#converting-version-of-an-onnx-model-within-default-domain-aionnx>`_
to convert to a common version.
+
+However, the converter does not support adapters for all op-sets/operators so
this conversion will fail for many models.
+On the onnx repository you can find a list of
+`currently supported adapters
<https://github.com/onnx/onnx/blob/master/onnx/version_converter.py#L21>`_
+
+
+Goals
+-----
+
+ - Support for importing `operators of the ONNX base definition
<https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
+ - Support for importing `operators defined by ONNX-ML
<https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md>`_
+ - Support for exporting DML script to ONNX graphs
+
+
+Limitations
+------------
+
+ - Not able to support all data types / operators as they are not currently
supported by SystemDS
+
+Onnx - Operators
+-----------------
+
+Onnx includes several very simple and also more complex operators.
+When implementing an operator it's best to have a look at the
+`operator schemas
<https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
+which precisely define the inputs, outputs and attributes of the operation.
+
+Besides the standard onnx definition, there also exists onnx-ML, whose operator
schemas are defined in a
+`separate document
<https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md>`_.
+It is an extension of the standard onnx format, however currently only onnx
standard operators are supported.
+
+Onnx - Files
+-------------
+
+Onnx uses the `ProtoBuf format
<https://developers.google.com/protocol-buffers/>`_.
+It specifies this representation in several ``.proto``/``.proto3``
+`files <https://github.com/onnx/onnx/tree/master/onnx>`_ again with dedicated
files for onnx-ML.
+These files are helpful to understand the underlying structure and values that
are possible.
+
+Protobuf creates the underlying structure such that you can access elements of
the onnx graph as if they were
+class members. For more information take a look at
+`Google's protocol-buffer documentation
+<https://developers.google.com/protocol-buffers/docs/pythontutorial#the-protocol-buffer-api>`_.
+
+This is also why in its current form, this converter does not convert the
protobuf-structure into an internal format,
+as the provided protobuf structure can already be conveniently used. Instead,
+there exist a number of onnx-helper functions/classes (see ``onnx_helper.py``).
+
+Traversing the Graph
+---------------------
+
+For creating the script, it is essential to insert computations in the right
order into the dml-script.
+To do this, the converter builds a tree-structure (DAG) from the protobuf-nodes
+(see `render.gen_graph_functions`).
+
+ - For traversing the graph, we start from the bottom.
+ - The converter starts with the graph-outputs as available outputs.
+ - It generates the dml snippets in reverse-order
+
+Graph traversal
+^^^^^^^^^^^^^^^^
+
+1. Find a node for which all outputs are available.
+
+2. Process the node:
+
+ - Generate the dml parts for this node
+ - add its inputs to the list of available outputs
+ - remove the node from the graph
+
+3. if there are nodes left restart at 1.
+
+Example
+^^^^^^^
+
+In the example below with the nodes ``Add``, ``MatMul`` and ``Sub``, we would
start with ``F`` as available output.
+Therefore the first node to insert would be ``Sub``. After inserting ``Sub``
its inputs become available outputs,
+therefore all outputs of ``MatMul`` become available. Finally, after removing
``MatMul`` from the graph all outputs
+of ``Add`` are available, and it can be removed from the graph as well.
+
+.. image:: assets/sample_graph.png
+ :width: 200px
+ :align: center
+ :alt: sample graph
+
+
+Rendering DML scripts
+---------------------
+
+The main idea of this converter is that the logic for generating the actual
dml-syntax is handled by
+`Jinja templates <https://jinja.palletsprojects.com/en/2.11.x/>`_ (located in
``/templates``).
+Therefore the python code stays uncluttered, because it does not have to merge
strings together to produce valid
+dml-syntax and instead simply provides the elements that are needed to render
the script.
+
+The template-engine then takes these inputs and renders a human readable
script with valid dml syntax.
+To improve readability the generator also automatically adds the doc-strings
which are part of the onnx-definitions as
+comments to the script.
+
+When traversing the graph, a script part is generated for each node consisting
of three elements:
+
+ - `dml_script` The actual script snippet for the node
+ - `imports` Imports required for the node
+ - `sub_graphs` Any sub_graphs of the node that need to be handled
+
+The function that is called for rendering a specific operator is defined in
the dictionary
+``operator_generators`` in ``render.py``
+
+1. `dml_script`
+^^^^^^^^^^^^^^^^^^
+
+Depending on the operator this can be a function call or a more complex
dml-snippet.
+This part is generated by the template-engine when the corresponding template
is rendered.
+
+Many onnx-operators can be handled by a single template file. There exists a
``function_call.dml.jinja``
+template which should be able to handle a large number of operators.
+
+2. `imports`
+^^^^^^^^^^^^^
+
+Some operators are handled by calling scripts provided by systemds located in
``$SYSTEMDS_ROOT/scripts``.
+To enable these imports, the converter automatically resolves the
``$SYSTEMDS_ROOT``
+environment variable and adds a ``setwd($SYSTEMDS_ROOT/scripts)`` to the script.
+
+3. `sub_graphs`
+^^^^^^^^^^^^^^^^^
+
+Since sub-graphs have their own variable scope and are independent, they are
handled as separate functions.
+The converter generates a function for each graph in the model.
+In the main-graph, the sub-graph is replaced by a function call to the
sub-graph function.
+To handle this the function ``render.gen_graph_functions`` recursively calls
itself to render sub-graph functions
+(and also the sub-graph functions of sub-graphs and so on...).
+
+Final Script
+------------
+
+In the final render all required imports, the sub-functions and the
main-function are combined in a single dml-file.
+
+Implementing new operators
+----------------------------
+
+When implementing an operator it's best to have a look at the
+`operator schemas
<https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
+which exactly define the inputs, outputs and attributes of the operation
+
+It is also nice to have a test-model to work with; to generate one, refer to
+``tests/onnx/test_models/model_generate.py``.
+
+To implement a new operator, the function that handles the operator needs to
be defined in the ``operator_generators`` dictionary
+located in ``render.py``.
+All functions listed in this dictionary need to have the same call structure.
+
+If there exists a dml-script (in ``$SYSTEMDS_ROOT/scripts``) that provides the
functionality the operator
+can be implemented by translating the arguments/inputs, adding the
import-render and function-call-render to this script.
+
+Testing models
+---------------
+
+onnx provides a convenient way for
+`creating models
<https://github.com/onnx/onnx/blob/master/docs/PythonAPIOverview.md#creating-an-onnx-model-using-helper-functions>`_
+using helper functions in python. All current test-models are produced like
this (see ``tests/onnx/test_models``).
+
+Creating a Testcase
+^^^^^^^^^^^^^^^^^^^^^
+
+The current test-system takes a model, converts it to dml using the converter
and then runs a
+``dml_wrapper`` which calls the model-function using the script
``$SYSTEMDS_ROOT/bin/systemds``.
+Finally, the output (stored by the dml-wrapper) is compared to a reference
output.
+
+When creating files stick to the naming conventions of other files in the same
folder.
+
+Steps:
+""""""""
+
+1. Create a model in ``tests/onnx/test_models``, e.g. ``sample_model.onnx``
+
+2. Create a dml wrapper that calls the model-function in
``tests/onnx/dml_wrapper/sample_model_wrapper.dml``
+
+ - The wrapper needs to call the model-function and store the output to
``output_test/sample_model.out``
+ - The name of the model-function is generated from the model-name (see
``util.generate_function_name`` )
+
+3. Provide a reference output in
``tests/onnx/output_reference/sample_model_reference.out``
+
+4. Create the unit test function.
+
+Tools
+------
+
+ - `PyCharm <https://www.jetbrains.com/pycharm/>`_ in the professional version
allows you to `debug template files
<https://www.jetbrains.com/help/pycharm/templates.html#debug>`_ which can be
handy
+ - `Netron <https://github.com/lutzroeder/netron>`_ is a nice free tool for
viewing onnx-graphs
\ No newline at end of file
diff --git a/src/main/python/systemds/__init__.py
b/src/main/python/systemds/__init__.py
index e51fbf8..f4e5a1d 100644
--- a/src/main/python/systemds/__init__.py
+++ b/src/main/python/systemds/__init__.py
@@ -19,4 +19,4 @@
#
#-------------------------------------------------------------
-__all__ = ['context', 'matrix']
+__all__ = ['context', 'matrix', 'onnx_systemds']
diff --git a/src/main/python/systemds/onnx_systemds/README.md
b/src/main/python/systemds/onnx_systemds/README.md
new file mode 100644
index 0000000..5b855de
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/README.md
@@ -0,0 +1,22 @@
+<!--
+{% comment %}
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+{% end comment %}
+-->
+
+# onnx-systemds
+
+A tool for importing/exporting
[onnx](https://github.com/onnx/onnx/blob/master/docs/IR.md) graphs into/from
SystemDS DML scripts.
diff --git a/src/main/python/systemds/onnx_systemds/__init__.py
b/src/main/python/systemds/onnx_systemds/__init__.py
new file mode 100644
index 0000000..27bcb2e
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/__init__.py
@@ -0,0 +1,14 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/main/python/systemds/onnx_systemds/convert.py
b/src/main/python/systemds/onnx_systemds/convert.py
new file mode 100644
index 0000000..5ee062b
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/convert.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os.path
+import systemds.onnx_systemds.onnx_helper as onnx_helper
+from systemds.onnx_systemds import render
+
+
+def init_argparse() -> argparse.ArgumentParser:
+ arg_parser = argparse.ArgumentParser(description="Convert onnx models into
dml scripts")
+ arg_parser.add_argument("input", type=str)
+ arg_parser.add_argument("-o", "--output", type=str,
+ help="output file", required=False)
+ return arg_parser
+
+
+def onnx2systemds(input_onnx_file: str, output_dml_file: str = None) -> None:
+ """
+ Loads the model from the input file and generates a dml file.
+
+ :param input_onnx_file: the onnx input file
+ :param output_dml_file: (optional) the dml output file,
+ if this parameter is not given the output file will have the same name
as the input file
+ """
+ if not os.path.isfile(input_onnx_file):
+ raise Exception("Invalid input-file: " + str(input_onnx_file))
+
+ if not output_dml_file:
+ output_dml_file =
os.path.splitext(os.path.basename(input_onnx_file))[0] + ".dml"
+
+ model = onnx_helper.load_model(input_onnx_file)
+ render.gen_script(model, output_dml_file)
+
+
+if __name__ == '__main__':
+ parser = init_argparse()
+ args = parser.parse_args()
+ input_file = args.input
+ output_file = args.output
+ onnx2systemds(input_file, output_file)
diff --git a/src/main/python/systemds/onnx_systemds/onnx_helper.py
b/src/main/python/systemds/onnx_systemds/onnx_helper.py
new file mode 100644
index 0000000..0425ea4
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/onnx_helper.py
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import functools
+
+import onnx
+import onnx.version_converter
+
+
+class TreeNode:
+ def __init__(self, node: onnx.NodeProto):
+ self.node = node
+ self.parent_nodes = list()
+ self.child_nodes = list()
+
+
+class NodeTree:
+ """ A simple class for representing a tree structure of nodes """
+
+ def __init__(self, nodes: [onnx.NodeProto]):
+ self.nodes = [TreeNode(node) for node in nodes]
+ self.root_nodes = [] # nodes that have no parents
+ self.end_nodes = [] # nodes that have no children
+
+ # find parents and children for each node
+ for tree_node in self.nodes:
+ for compare_tree_node in self.nodes:
+ if tree_node != compare_tree_node:
+ for node_output in tree_node.node.output:
+ if node_output in compare_tree_node.node.input:
+ tree_node.child_nodes.append(compare_tree_node)
+ compare_tree_node.parent_nodes.append(tree_node)
+
+ for node in self.nodes:
+ if len(node.child_nodes) == 0:
+ self.end_nodes.append(node)
+ if len(node.parent_nodes) == 0:
+ self.root_nodes.append(node)
+
+ def remove_end_node(self, node: TreeNode):
+ """
+ Removes the given end-node from the tree.
+ Removing a non-existing or non end-node raises an exception.
+ :param node: The node that shall be removed
+ """
+ if node not in self.end_nodes:
+ raise Exception("Can only remove end nodes")
+ self.end_nodes.remove(node)
+ self.nodes.remove(node)
+
+ for parent_node in node.parent_nodes:
+ parent_node.child_nodes.remove(node)
+ self.end_nodes += node.parent_nodes
+ node.parent_nodes = []
+
+
+def load_model(onnx_file: str) -> onnx.ModelProto:
+ """
+ Loads the onnx file, checks the model and converts it to a common version
if necessary.
+
+ :param onnx_file:
+ :return: the loaded onnx-model
+ """
+ TARGET_VERSION = 12
+ model = onnx.load(onnx_file)
+ onnx.checker.check_model(model)
+ if len(list(model.opset_import)) == 1 and
list(model.opset_import)[0].version == TARGET_VERSION:
+ return model
+ else:
+ return onnx.version_converter.convert_version(model, TARGET_VERSION)
+
+
+def get_value_info(graph: onnx.GraphProto, name: str) -> onnx.ValueInfoProto:
+ """
+ Searches the `graph` for the given `name` and returns the associated
ValueInfo,
+ if the name is not found None is returned.
+
+ :param graph: the onnx-graph that shall be searched
+ :param name: the name of the value
+ :return: the value-info or None if it is not found
+ """
+ for info in graph.input:
+ if info.name == name:
+ return info
+
+ for info in graph.value_info:
+ if info.name == name:
+ return info
+
+ for info in graph.output:
+ if info.name == name:
+ return info
+
+ return None
+
+
+def get_graph_inputs_without_initializers(graph: onnx.GraphProto) ->
[onnx.ValueInfoProto]:
+ """
+ Returns all inputs of the `graph` that have no associated initializer
values.
+
+ :param graph: the onnx-graph
+ :return: list of uninitialized inputs
+ """
+ inputs_without_initializers = []
+ for input in graph.input:
+ has_initializer = False
+ for initializer in graph.initializer:
+ if initializer.name == input.name:
+ has_initializer = True
+ break
+
+ if not has_initializer:
+ inputs_without_initializers.append(input)
+
+ return inputs_without_initializers
+
+
+def get_graph_inputs_with_initializers(graph: onnx.GraphProto) ->
[(onnx.ValueInfoProto, onnx.TensorProto)]:
+ """
+ Returns all initialized inputs of the `graph` with their corresponding
initializer.
+
+ :param graph: the onnx-graph
+ :return: list of tuples of (input, initializer)
+ """
+ inputs_with_initializers = []
+
+ for input in graph.input:
+ for initializer in graph.initializer:
+ if initializer.name == input.name:
+ inputs_with_initializers.append((input, initializer))
+
+ return inputs_with_initializers
+
+
+class PreparedValue:
+ """ Class for preparing onnx value structures for writing them to the dml
script """
+ def __init__(self, value_info: onnx.ValueInfoProto, initializer:
onnx.TensorProto = None):
+
+ systemds_supported_types = ["integer", "boolean", "double", "string"]
+
+ # TODO: these type translations are not correct double -> float
+ # Translating onnx types to systemds types
+ type_translation = {
+ 1: "double", # float
+ 2: "unsigned integer", # uint8_t
+ 3: "integer", # int8_t
+ 4: "unsigned integer", # uint16_t
+ 5: "integer", # int16_t
+ 6: "integer", # int32_t
+ 7: "long", # int64_t
+ 8: "string",
+ 9: "boolean", # bool
+
+ 10: "double", # float16,
+ 11: "double",
+ 12: "unsigned integer", # uint32
+ 13: "unsigned long", # uint64
+
+ 14: "COMPLEX64",
+ 15: "COMPLEX128",
+ 16: "BFLOAT16"
+ }
+
+ if value_info.type.tensor_type.elem_type not in
type_translation.keys():
+ raise NotImplementedError("Only support Tensor Types")
+
+ # TODO: add support for other data types
+
+ self.value_type =
type_translation[value_info.type.tensor_type.elem_type]
+ if self.value_type not in systemds_supported_types:
+ raise NotImplementedError("The type " + self.value_type + " is
currently not supported")
+
+ self.shape = []
+ dims = get_valueinfo_dimensions(value_info)
+
+ if len(dims) == 1 and dims[0] == 1:
+ self.data_type = "scalar"
+ self.shape = [1]
+ else:
+ self.data_type = "matrix"
+ if self.value_type != "double":
+ raise NotImplementedError("A matrix can only have the type
double")
+ shape_dimensions = value_info.type.tensor_type.shape.dim
+ for dim in shape_dimensions:
+ # TODO: shapes with no value but instead name -> support?
+ if len(dim.dim_param) != 0:
+ raise NotImplementedError("Only support dim_value")
+ self.shape.append(dim.dim_value)
+
+ if len(self.shape) > 2:
+ # TODO: not sure this is the solution for every instance of
this problem
+ # Multiply all shapes right
+ rows = self.shape[0]
+ cols = functools.reduce(lambda s0, s1: s0 * s1, self.shape[1:])
+ self.shape = [rows, cols]
+
+ self.identifier_name = value_info.name
+ self.description = value_info.doc_string
+ self.initializer = None
+
+ if initializer:
+ self.initializer_values = list(initializer.float_data)
+
+
+def get_valueinfo_dimensions(value_info: onnx.ValueInfoProto) -> [int]:
+ return [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]
+
diff --git a/src/main/python/systemds/onnx_systemds/operator_gen.py
b/src/main/python/systemds/onnx_systemds/operator_gen.py
new file mode 100644
index 0000000..b8021ac
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/operator_gen.py
@@ -0,0 +1,465 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from random import randint
+
+import jinja2
+import onnx
+import systemds.onnx_systemds.onnx_helper as onnx_helper
+from systemds.onnx_systemds import util
+
+
+class GeneratedScriptPart:
+    """
+    Result of generating the script for a single node: the rendered dml snippet,
+    the rendered import statements it needs and the sub-graphs it references.
+    """
+
+    def __init__(self, dml_script: str, imports: [str] = None, sub_graphs: [onnx.GraphProto] = None):
+        """
+        :param dml_script: the rendered dml snippet
+        :param imports: (optional) rendered import statements, defaults to an empty list
+        :param sub_graphs: (optional) referenced sub-graphs, defaults to an empty list
+        """
+        self.dml_script = dml_script
+        # A fresh list per instance, never a shared default
+        self.imports = [] if imports is None else imports
+        self.sub_graphs = [] if sub_graphs is None else sub_graphs
+
+
+def gen_simple_function_call(env: jinja2.environment.Environment, graph: onnx.GraphProto,
+                             node: onnx.NodeProto) -> GeneratedScriptPart:
+    """
+    Generates a simple function call by directly providing the node inputs as arguments
+    and node outputs as outputs to a function call. Additionally adds the required imports.
+
+    Operators without an entry in the import table are rendered as a call to a function
+    named like the operator type, without namespace and without any import.
+
+    :param env: Jinja environment to load the template files
+    :param graph: the onnx-graph for which the script shall be generated
+    :param node: the onnx-node for which the script shall be generated
+    :return: The generated script part
+    :raises Exception: if the node does not have exactly one output or carries attributes
+    """
+    operator_template = env.get_template("operators/" + "function_call.dml.jinja")
+    import_template = env.get_template("module_import.dml.jinja")
+
+    if len(node.output) != 1:
+        raise Exception("Function call needs output")
+
+    if len(node.attribute) != 0:
+        raise Exception("Attributes not supported for this generator")
+
+    required_import = {
+        "Relu": {"path": "/nn/layers/relu.dml", "import_name": "relu_layer", "function_name": "forward"},
+        "Tanh": {"path": "/nn/layers/tanh.dml", "import_name": "tanh_layer", "function_name": "forward"},
+        "Sigmoid": {"path": "/nn/layers/sigmoid.dml", "import_name": "sigmoid_layer", "function_name": "forward"},
+        "Softmax": {"path": "/nn/layers/softmax.dml", "import_name": "softmax_layer", "function_name": "forward"}
+    }
+
+    # Only operators backed by a layer script contribute an import. Returning an
+    # empty string as import would later be counted as a real import.
+    imports = []
+    function_name = node.op_type
+    function_namespace = ""
+    module_import = required_import.get(node.op_type)
+    if module_import is not None:
+        imports.append(import_template.render(
+            path=module_import["path"],
+            name=module_import["import_name"]
+        ))
+        function_name = module_import["function_name"]
+        function_namespace = module_import["import_name"]
+
+    node_render = operator_template.render(
+        function_namespace=function_namespace,
+        function=function_name,
+        arguments=list(node.input),
+        outputs=list(node.output),
+        doc_string=node.doc_string
+    )
+    return GeneratedScriptPart(imports=imports, dml_script=node_render)
+
+
+def gen_2input_1output_operator(env: jinja2.environment.Environment, graph: onnx.GraphProto,
+                                node: onnx.NodeProto) -> GeneratedScriptPart:
+    """
+    Generates simple operator calls like 'z = x + y' which have two inputs (left and right)
+    and one output.
+
+    :param env: Jinja environment to load the template files
+    :param graph: the onnx-graph for which the script shall be generated
+    :param node: the onnx-node for which the script shall be generated
+    :return: The generated script part
+    :raises Exception: if the node has attributes or not exactly 2 inputs and 1 output
+    :raises KeyError: if the operator type of the node has no dml operator assigned
+    """
+    operator = {
+        "Add": "+",
+        "Sub": "-",
+        "MatMul": "%*%",
+        "And": "&",
+        "Or": "|"
+    }
+    operator_template = env.get_template("operators/2input_1output_operator.dml.jinja")
+
+    if len(node.attribute) != 0:
+        raise Exception("attributes not supported for operator")
+
+    # Exactly two inputs and one output are indexed below, fewer must be rejected as well
+    if len(node.input) != 2 or len(node.output) != 1:
+        raise Exception("Operator needs 2 inputs and 1 output")
+
+    node_render = operator_template.render(
+        input_0=node.input[0],
+        input_1=node.input[1],
+        output=node.output[0],
+        operator=operator[node.op_type],
+        doc_string=node.doc_string
+    )
+    return GeneratedScriptPart(node_render)
+
+
+def gen_1input_1output_mat_operator(env: jinja2.environment.Environment, graph: onnx.GraphProto,
+                                    node: onnx.NodeProto) -> GeneratedScriptPart:
+    """
+    Renders operators of the form 'y = -x', i.e. operators with a single input and a
+    single output, using the template assigned to the operator type.
+
+    :param env: Jinja environment to load the template files
+    :param graph: the onnx-graph for which the script shall be generated
+    :param node: the onnx-node for which the script shall be generated
+    :return: The generated script part
+    """
+    template_for_operator = {
+        "Neg": "neg.dml.jinja",
+    }
+
+    # The template is resolved before the node is validated
+    template_name = template_for_operator[node.op_type]
+    operator_template = env.get_template("operators/" + template_name)
+
+    if len(node.attribute) != 0:
+        raise Exception("attributes not supported for operator")
+
+    node_inputs = list(node.input)
+    node_outputs = list(node.output)
+    if not (len(node_inputs) == 1 and len(node_outputs) == 1):
+        raise Exception("Operator needs 1 input and 1 output")
+
+    return GeneratedScriptPart(
+        dml_script=operator_template.render(
+            input=node_inputs[0],
+            output=node_outputs[0],
+            doc_string=node.doc_string
+        )
+    )
+
+
+def gen_dropout_call(env: jinja2.environment.Environment, graph: onnx.GraphProto,
+                     node: onnx.NodeProto) -> GeneratedScriptPart:
+    """
+    Generates the call of the dropout layer (`/nn/layers/dropout.dml`) for the given node.
+
+    :param env: Jinja environment to load the template files
+    :param graph: the onnx-graph for which the script shall be generated
+    :param node: the onnx-node for which the script shall be generated
+    :return: The generated script part
+    :raises Exception: if the node has no input or attributes other than a single `ratio`
+    """
+    operator_template = env.get_template("operators/" + "function_call.dml.jinja")
+    import_template = env.get_template("module_import.dml.jinja")
+
+    function_namespace = "dropout_layer"
+    function_name = "forward"
+    path = "/nn/layers/dropout.dml"
+
+    # * Inputs:
+    # * - X: Inputs, of shape (any, any).
+    # * - p: Probability of keeping a neuron output.
+    # * - seed: [Optional: -1] Random number generator seed to allow for
+    # * deterministic evaluation. Set to -1 for a random seed.
+    # * Outputs:
+    # * - out: Outputs, of same shape as `X`.
+    # * - mask: Dropout mask used to compute the output.
+
+    if len(node.input) < 1:
+        raise Exception("Invalid number of inputs")
+
+    X = node.input[0]
+    p = 0.5
+    seed = -1
+    attributes = list(node.attribute)
+    if len(attributes) > 0:
+        if attributes[0].name != "ratio" or len(attributes) > 1:
+            raise Exception("Error in generating dropout call invalid attributes" + str(attributes))
+        # The onnx `ratio` is the fraction of values that is dropped, whereas the
+        # layer expects the probability of keeping a value.
+        p = 1 - attributes[0].f
+
+    import_render = import_template.render(
+        path=path,
+        name=function_namespace
+    )
+
+    node_render = operator_template.render(
+        function_namespace=function_namespace,
+        function=function_name,
+        arguments=[X, p, seed],
+        outputs=list(node.output),
+        doc_string=node.doc_string
+    )
+    return GeneratedScriptPart(imports=[import_render], dml_script=node_render)
+
+
+def __compute_pad(auto_pad: str, Hf: int, Wf: int, strides: [int], pads: [int], Hin: int, Win: int):
+    """
+    Computes the symmetric padding (height, width) for the given `auto_pad` mode.
+
+    :param auto_pad: one of "NOTSET", "SAME_UPPER", "SAME_LOWER" or "VALID" (str or bytes)
+    :param Hf: filter height
+    :param Wf: filter width
+    :param strides: strides as [stride height, stride width]
+    :param pads: explicit pads as [top, left, bottom, right], only used for "NOTSET"
+    :param Hin: input height
+    :param Win: input width
+    :return: Tuple (padh, padw)
+    :raises Exception: if the padding is not symmetric or `auto_pad` is invalid
+    """
+    # String attributes of a protobuf message arrive as bytes
+    if isinstance(auto_pad, bytes):
+        auto_pad = auto_pad.decode("utf-8")
+
+    strideh = strides[0]
+    stridew = strides[1]
+
+    if auto_pad == "NOTSET":
+        padh = pads[0]
+        padw = pads[1]
+        if pads[0] != pads[2] or pads[1] != pads[3]:
+            raise Exception("Only support symmetric pads")
+    elif auto_pad in ("SAME_UPPER", "SAME_LOWER"):
+        # pad such that output size matches input
+        total_padh = Hin * (strideh - 1) + Hf - strideh
+        total_padw = Win * (stridew - 1) + Wf - stridew
+        # An odd total can not be split evenly between both sides
+        if total_padh % 2 != 0 or total_padw % 2 != 0:
+            raise Exception("Only support symmetric pads")
+        padh = total_padh // 2
+        padw = total_padw // 2
+    elif auto_pad == "VALID":
+        # no padding
+        padh = 0
+        padw = 0
+    else:
+        raise Exception("Invalid auto_pad value")
+
+    return padh, padw
+
+
+def gen_maxpool_call(env: jinja2.environment.Environment, graph: onnx.GraphProto,
+                     node: onnx.NodeProto) -> GeneratedScriptPart:
+    """
+    Generates the call of the 2D max-pooling layer (`/nn/layers/max_pool2d.dml`)
+    for the given node.
+
+    :param env: Jinja environment to load the template files
+    :param graph: the onnx-graph for which the script shall be generated
+    :param node: the onnx-node for which the script shall be generated
+    :return: The generated script part
+    :raises NotImplementedError: for non-2D pooling and for non-default values of
+        `ceil_mode`, `dilations` and `storage_order`
+    :raises Exception: for an invalid number of inputs/outputs, unknown attributes,
+        a missing `kernel_shape` or unsupported padding
+    """
+    operator_template = env.get_template("operators/" + "function_call.dml.jinja")
+    import_template = env.get_template("module_import.dml.jinja")
+
+    function_namespace = "maxpool_layer"
+    function_name = "forward"
+    path = "/nn/layers/max_pool2d.dml"
+
+    # * Inputs:
+    # * - X: Inputs, of shape (N, C*Hin*Win).
+    # * - C: Number of input channels (dimensionality of input depth).
+    # * - Hin: Input height.
+    # * - Win: Input width.
+    # * - Hf: Filter height.
+    # * - Wf: Filter width.
+    # * - strideh: Stride over height.
+    # * - stridew: Stride over width.
+    # * - padh: Padding for top and bottom sides.
+    # * A typical value is 0.
+    # * - padw: Padding for left and right sides.
+    # * A typical value is 0.
+    # *
+    # * Outputs:
+    # * - out: Outputs, of shape (N, C*Hout*Wout).
+    # * - Hout: Output height.
+    # * - Wout: Output width.
+
+    if len(node.input) != 1:
+        raise Exception("Invalid number of inputs")
+    if len(node.output) < 1 or len(node.output) > 2:
+        raise Exception("Invalid number of outputs")
+
+    # Inputs
+    x = onnx_helper.get_value_info(graph, node.input[0])
+    # dimensions are (N x C x H x W), where N is the batch size, C is the number of channels,
+    # and H and W are the height and the width
+    x_shape = onnx_helper.get_valueinfo_dimensions(x)
+    # All four dimensions are indexed below, so shorter shapes must be rejected too
+    if len(x_shape) != 4:
+        raise NotImplementedError("Currently only MaxPool-2D supported")
+
+    # TODO: the batch size (x_shape[0]) is currently not used
+    C = x_shape[1]
+    Hin = x_shape[2]
+    Win = x_shape[3]
+
+    # Attributes
+    auto_pad = "NOTSET"
+    kernel_shape = None
+    pads = [0, 0, 0, 0]
+    strides = [1, 1]
+    for attribute in node.attribute:
+        if attribute.name == "auto_pad":
+            # single string attributes are stored as bytes in the field `s`
+            auto_pad = attribute.s.decode("utf-8")
+        elif attribute.name == "ceil_mode":
+            # single int attributes are stored in the field `i`, 0 is the default
+            if attribute.i != 0:
+                raise NotImplementedError("Currently no support for ceil_mode")
+        elif attribute.name == "dilations":
+            # a dilation of 1 along every axis is the default
+            if any(dilation != 1 for dilation in attribute.ints):
+                raise NotImplementedError("Currently no support for dilations")
+        elif attribute.name == "kernel_shape":
+            kernel_shape = attribute.ints
+        elif attribute.name == "pads":
+            pads = attribute.ints
+        elif attribute.name == "storage_order":
+            # 0 is the default
+            if attribute.i != 0:
+                raise NotImplementedError("Currently no support for storage_order")
+        elif attribute.name == "strides":
+            strides = attribute.ints
+        else:
+            raise Exception("Invalid Attribute")
+
+    if kernel_shape is None:
+        raise Exception("kernel_shape attribute is required")
+    if len(kernel_shape) != 2:
+        raise NotImplementedError("Currently only MaxPool-2D supported")
+
+    Hf = kernel_shape[0]
+    Wf = kernel_shape[1]
+    strideh = strides[0]
+    stridew = strides[1]
+    padh, padw = __compute_pad(auto_pad, Hf, Wf, strides, pads, Hin, Win)
+
+    # Create render
+    node_render = operator_template.render(
+        function_namespace=function_namespace,
+        function=function_name,
+        arguments=[x.name, C, Hin, Win, Hf, Wf, strideh, stridew, padh, padw],
+        outputs=list(node.output),
+        doc_string=node.doc_string
+    )
+
+    import_render = import_template.render(
+        path=path,
+        name=function_namespace
+    )
+
+    return GeneratedScriptPart(imports=[import_render], dml_script=node_render)
+
+
+def gen_conv_call(env: jinja2.environment.Environment, graph: onnx.GraphProto, node: onnx.NodeProto) \
+        -> GeneratedScriptPart:
+    """
+    Generates the call of the 2D convolution layer (`/nn/layers/conv2d.dml`) for the
+    given node. If the node provides no bias, a zero-bias is generated in front of the call.
+
+    :param env: Jinja environment to load the template files
+    :param graph: the onnx-graph for which the script shall be generated
+    :param node: the onnx-node for which the script shall be generated
+    :return: The generated script part
+    :raises NotImplementedError: for non-2D convolutions and for non-default values
+        of `dilations` and `group`
+    :raises Exception: for an invalid number of inputs/outputs, unknown attributes,
+        a `kernel_shape` contradicting the weights or unsupported padding
+    """
+    operator_template = env.get_template("operators/" + "function_call.dml.jinja")
+    import_template = env.get_template("module_import.dml.jinja")
+
+    function_namespace = "conv_layer"
+    function_name = "forward"
+    path = "/nn/layers/conv2d.dml"
+
+    # * Inputs:
+    # * - X: Inputs, of shape (N, C*Hin*Win).
+    # * - W: Weights, of shape (F, C*Hf*Wf).
+    # * - b: Biases, of shape (F, 1).
+    # * - C: Number of input channels (dimensionality of input depth).
+    # * - Hin: Input height.
+    # * - Win: Input width.
+    # * - Hf: Filter height.
+    # * - Wf: Filter width.
+    # * - strideh: Stride over height.
+    # * - stridew: Stride over width.
+    # * - padh: Padding for top and bottom sides.
+    # * - padw: Padding for left and right sides.
+    # *
+    # * Outputs:
+    # * - out: Outputs, of shape (N, F*Hout*Wout).
+    # * - Hout: Output height.
+    # * - Wout: Output width.
+
+    if len(node.input) < 2 or len(node.input) > 3:
+        raise Exception("Invalid number of inputs")
+
+    if len(node.output) > 1:
+        raise Exception("Invalid number of outputs")
+
+    # Inputs
+    x = onnx_helper.get_value_info(graph, node.input[0])
+    # size (N x C x H x W), where N is the batch size, C is the number of channels,
+    # and H and W are the height and width
+    x_shape = onnx_helper.get_valueinfo_dimensions(x)
+    # All four dimensions are indexed below, so shorter shapes must be rejected too
+    if len(x_shape) != 4:
+        raise NotImplementedError("Currently only Conv-2D supported")
+    # TODO: the batch size (x_shape[0]) is currently not used
+    C = x_shape[1]
+    Hin = x_shape[2]
+    Win = x_shape[3]
+
+    w = onnx_helper.get_value_info(graph, node.input[1])
+    W_shape = onnx_helper.get_valueinfo_dimensions(w)
+    if len(W_shape) != 4:
+        raise NotImplementedError("Currently only Conv-2D supported")
+    M = W_shape[0]
+    # TODO: channels per group (W_shape[1]) are currently not used
+    Hf = W_shape[2]
+    Wf = W_shape[3]
+
+    bias_initializer_render = ""
+    if len(node.input) == 2:
+        # Generate 0-bias if no bias given
+        generated_bias_identifier = "gen_bias"
+        while onnx_helper.get_value_info(graph, generated_bias_identifier) is not None:
+            # append a random digit until the identifier is not taken by the graph
+            generated_bias_identifier += str(randint(0, 9))
+
+        bias_init_template = env.get_template("matrix_initialize.dml.jinja")
+        bias_initializer_render = bias_init_template.render(
+            identifier_name=generated_bias_identifier,
+            initializer_values=[0] * M,
+            rows=M,
+            cols=1
+        )
+        bias = generated_bias_identifier
+    else:
+        # the bias is the third input (index 2)
+        bias = node.input[2]
+
+    # Attributes
+    auto_pad = "NOTSET"
+    pads = [0, 0, 0, 0]
+    strides = [1, 1]
+    for attribute in node.attribute:
+        if attribute.name == "auto_pad":
+            # single string attributes are stored as bytes in the field `s`
+            auto_pad = attribute.s.decode("utf-8")
+        elif attribute.name == "dilations":
+            # a dilation of 1 along every axis is the default
+            if any(dilation != 1 for dilation in attribute.ints):
+                raise NotImplementedError("Currently no support for dilations")
+        elif attribute.name == "group":
+            # single int attributes are stored in the field `i`; the generated call
+            # has no group argument, so only the default of 1 can be translated
+            if attribute.i != 1:
+                raise NotImplementedError("Currently no support for group")
+        elif attribute.name == "kernel_shape":
+            kernel_shape = attribute.ints
+            if len(kernel_shape) != 2 or kernel_shape[0] != Hf or kernel_shape[1] != Wf:
+                raise Exception("Invalid kernel shape")
+        elif attribute.name == "pads":
+            pads = attribute.ints
+        elif attribute.name == "strides":
+            strides = attribute.ints
+        else:
+            raise Exception("Invalid Attribute")
+
+    strideh = strides[0]
+    stridew = strides[1]
+    padh, padw = __compute_pad(auto_pad, Hf, Wf, strides, pads, Hin, Win)
+
+    node_render = operator_template.render(
+        function_namespace=function_namespace,
+        function=function_name,
+        arguments=[x.name, w.name, bias, C, Hin, Win, Hf, Wf, strideh, stridew, padh, padw],
+        outputs=list(node.output),
+        doc_string=node.doc_string
+    )
+
+    import_render = import_template.render(
+        path=path,
+        name=function_namespace
+    )
+
+    return GeneratedScriptPart(imports=[import_render], dml_script=bias_initializer_render + "\n" + node_render)
+
+
+def gen_if_call(env: jinja2.environment.Environment, graph: onnx.GraphProto, node: onnx.NodeProto) \
+        -> GeneratedScriptPart:
+    """
+    Generates the if-statement for the given node. Each branch is rendered as a call
+    to the function generated for the sub-graph of that branch; both sub-graphs are
+    returned so that their functions get generated as well.
+
+    :param env: Jinja environment to load the template files
+    :param graph: the onnx-graph for which the script shall be generated
+    :param node: the onnx-node for which the script shall be generated
+    :return: The generated script part including the sub-graphs (else, then)
+    :raises Exception: if the node does not have exactly one input or not exactly
+        the attributes `else_branch` and `then_branch`
+    """
+    operator_template = env.get_template("operators/if_operator.dml.jinja")
+    function_call_template = env.get_template("operators/function_call.dml.jinja")
+
+    if len(node.input) != 1:
+        raise Exception("Wrong number of inputs")
+    if len(node.attribute) != 2:
+        raise Exception("Wrong number of attributes")
+
+    # Attributes are identified by name, their position in the node is not relied upon
+    branches = {attribute.name: attribute.g for attribute in node.attribute}
+    if "else_branch" not in branches or "then_branch" not in branches:
+        raise Exception("Wrong attributes")
+
+    else_graph = branches["else_branch"]
+    then_graph = branches["then_branch"]
+
+    else_call = function_call_template.render(
+        doc_string="",
+        function_namespace="",
+        function=util.generate_function_name(else_graph.name),
+        arguments=[i.name for i in else_graph.input],
+        outputs=[o.name for o in else_graph.output],
+    )
+
+    then_call = function_call_template.render(
+        doc_string="",
+        function_namespace="",
+        function=util.generate_function_name(then_graph.name),
+        arguments=[i.name for i in then_graph.input],
+        outputs=[o.name for o in then_graph.output],
+    )
+
+    sub_graphs = [else_graph, then_graph]
+
+    node_render = operator_template.render(
+        cond=node.input[0],
+        then_function_call=then_call,
+        else_function_call=else_call
+    )
+
+    return GeneratedScriptPart(dml_script=node_render, sub_graphs=sub_graphs)
diff --git a/src/main/python/systemds/onnx_systemds/render.py
b/src/main/python/systemds/onnx_systemds/render.py
new file mode 100644
index 0000000..8140f95
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/render.py
@@ -0,0 +1,215 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from systemds.onnx_systemds import util, operator_gen
+import onnx
+import systemds.onnx_systemds.onnx_helper as onnx_helper
+import jinja2
+
+# Each operator listed shall be supported by this converter
+# Maps the onnx operator type (`node.op_type`) to the function of operator_gen that
+# generates the script for it. gen_node_script looks the generator up in this table
+# and calls it with (env, graph, node).
+operator_generators = {
+ "Add": operator_gen.gen_2input_1output_operator,
+ "Sub": operator_gen.gen_2input_1output_operator,
+ "MatMul": operator_gen.gen_2input_1output_operator,
+ "Neg": operator_gen.gen_1input_1output_mat_operator,
+ "Xor": operator_gen.gen_simple_function_call,
+ "Or": operator_gen.gen_2input_1output_operator,
+ "And": operator_gen.gen_2input_1output_operator,
+ "Relu": operator_gen.gen_simple_function_call,
+ "Tanh": operator_gen.gen_simple_function_call,
+ "Sigmoid": operator_gen.gen_simple_function_call,
+ "Softmax": operator_gen.gen_simple_function_call,
+ "Dropout": operator_gen.gen_dropout_call,
+ "MaxPool": operator_gen.gen_maxpool_call,
+ "Conv": operator_gen.gen_conv_call,
+ "If": operator_gen.gen_if_call,
+}
+
+
+def gen_node_script(env: jinja2.environment.Environment, graph: onnx.GraphProto, node: onnx.NodeProto) \
+        -> operator_gen.GeneratedScriptPart:
+    """
+    Generates a dml script snippet, the required imports and sub-graphs for the given `node`
+
+    :param env: Jinja environment to load the template files
+    :param graph: the onnx graph for which the script shall be generated
+    :param node: the node for which the dml snippet shall be generated
+    :return: The generated script-part
+    :raises KeyError: if there is no generator for the operator type of the node
+    """
+    # Look the generator up separately from calling it, so that a KeyError raised
+    # inside a generator is not reported as an unsupported operator.
+    generator = operator_generators.get(node.op_type)
+    if generator is None:
+        print("Operator " + str(node.op_type) + " not supported")
+        raise KeyError(node.op_type)
+    return generator(env, graph, node)
+
+
+def gen_graph_functions(env: jinja2.environment.Environment, main_graph: onnx.GraphProto) -> ([str], str, [str]):
+    """
+    Walks the node tree of the onnx-graph from its end nodes upwards and generates the
+    script string of every node, the required imports and the functions of all sub-graphs.
+    The returned lists are ordered such that they can be inserted in the dml script as they are.
+
+    :param env: Jinja environment to load the template files
+    :param main_graph: the onnx graph for which the script shall be generated
+    :return: Tuple (imports, main function, sub-graph functions)
+    """
+    node_scripts_bottom_up = []
+    functions_of_sub_graphs = []
+    imports = set()  # set to avoid duplicate imports
+
+    node_tree = onnx_helper.NodeTree(main_graph.node)
+    available_outputs = [graph_output.name for graph_output in list(main_graph.output)]
+
+    while len(node_tree.nodes) != 0:
+        # The next node to insert is the first end node whose outputs are all available
+        next_tree_node = next(
+            (candidate for candidate in node_tree.end_nodes
+             if all(output in available_outputs for output in list(candidate.node.output))),
+            None
+        )
+        if not next_tree_node:
+            raise Exception("Error in parsing nodes, did not find a next node to compute")
+
+        # Insert generated parts
+        generated_node = gen_node_script(env, main_graph, next_tree_node.node)
+        imports.update(generated_node.imports)
+        node_scripts_bottom_up.append(generated_node.dml_script)
+
+        # Sub-graphs hand their imports and their own sub-graph functions upwards,
+        # the main function of a sub-graph becomes a sub-graph function itself
+        for sub_graph in generated_node.sub_graphs:
+            inherited_imports, sub_graph_function, nested_functions = gen_graph_functions(env, sub_graph)
+            imports.update(inherited_imports)
+            functions_of_sub_graphs.extend(nested_functions)
+            functions_of_sub_graphs.append(sub_graph_function)
+
+        # The inputs of the inserted node become available, the node itself is removed
+        available_outputs.extend(list(next_tree_node.node.input))
+        node_tree.remove_end_node(next_tree_node)
+
+    # The scripts were collected from the outputs upwards, the function body needs the opposite order
+    node_scripts_bottom_up.reverse()
+    main_graph_function = render_function(env, main_graph, node_scripts_bottom_up)
+    return list(imports), main_graph_function, functions_of_sub_graphs
+
+
+def render_function(env: jinja2.environment.Environment, graph: onnx.GraphProto,
+                    generated_node_scripts: [str]) -> str:
+    """
+    Renders the dml function of the given `graph` with the `generated_node_scripts`
+    as function-body.
+
+    :param env: Jinja environment to load the template files
+    :param graph: the graph for which the function shall be generated
+    :param generated_node_scripts: the node scripts in correct order for the function-body
+    :return: the generated function
+    """
+    function_template = env.get_template("graph_function.dml.jinja")
+
+    initialized_inputs = onnx_helper.get_graph_inputs_with_initializers(graph)
+    plain_inputs = onnx_helper.get_graph_inputs_without_initializers(graph)
+
+    # prepare inputs, outputs and initializers for the template
+    prepared_inputs = []
+    for value_info in plain_inputs:
+        prepared_inputs.append(onnx_helper.PreparedValue(value_info))
+
+    prepared_outputs = []
+    for value_info in list(graph.output):
+        prepared_outputs.append(onnx_helper.PreparedValue(value_info))
+
+    prepared_initializers = []
+    for value_info, initializer in initialized_inputs:
+        prepared_initializers.append(onnx_helper.PreparedValue(value_info, initializer))
+
+    # render function
+    return function_template.render(
+        function_inputs=prepared_inputs,
+        function_outputs=prepared_outputs,
+        function_start_initializers=prepared_initializers,
+        graph_function_name=util.generate_function_name(graph.name),
+        graph_function_description=graph.doc_string,
+        node_scripts=generated_node_scripts
+    )
+
+
+def gen_model_header(env: jinja2.environment.Environment, model: onnx.ModelProto) -> str:
+    """
+    Generates the header of the script for the given `model`.
+    The model itself is only read and never modified.
+
+    :param env: Jinja environment to load the template files
+    :param model: the onnx model for which the header shall be generated
+    :return: the generated header
+    """
+    header_template = env.get_template("model_header.dml.jinja")
+
+    # An empty domain is displayed as "ONNX". This is resolved into a local name
+    # instead of being written back into the model.
+    opset_import = []
+    for opset in model.opset_import:
+        domain = opset.domain if len(opset.domain) != 0 else "ONNX"
+        opset_import.append(domain + "/" + str(opset.version))
+
+    header_infos = {
+        "ir_version": model.ir_version,
+        "producer_name": model.producer_name,
+        "producer_version": model.producer_version,
+        "domain": model.domain,
+        "model_version": model.model_version,
+        "doc_string": model.doc_string,
+    }
+    # Each metadata entry is a key/value pair
+    metadata_props = [[prop.key, prop.value] for prop in model.metadata_props]
+
+    model_header_render = header_template.render(
+        header_components=header_infos,
+        opset_import=opset_import,
+        metadata_props=metadata_props
+    )
+    return model_header_render
+
+
+def gen_script(model: onnx.ModelProto, output_file: str = None) -> str:
+    """
+    Renders the complete dml script of the given `model` and returns it.
+    The script is additionally written to `output_file` if one is given.
+
+    :param model: the model for which the dml script shall be generated
+    :param output_file: (optional) the file to which the script shall be written
+    :return: the generated dml-script
+    """
+    template_dir = os.path.dirname(os.path.realpath(__file__)) + '/templates/'
+    env = jinja2.Environment(loader=jinja2.FileSystemLoader(template_dir))
+    model_header_render = gen_model_header(env, model)
+    imports, main_function, sub_functions = gen_graph_functions(env, model.graph)
+
+    # the working directory only needs to be set to enable imports
+    wdir = util.resolve_systemds_root() + "/scripts" if len(imports) > 0 else ""
+
+    result_render = env.get_template("main.dml.jinja").render(
+        title="This file was generated by onnx-systemds",
+        model_header_render=model_header_render,
+        wdir=wdir,
+        imports=imports,
+        main_function=main_function,
+        sub_functions=sub_functions
+    )
+    if not output_file:
+        return result_render
+
+    # create the target directory on demand before writing the script
+    target_dir = os.path.dirname(output_file)
+    if len(target_dir) > 0:
+        os.makedirs(target_dir, exist_ok=True)
+    with open(output_file, 'w') as script_file:
+        script_file.write(result_render)
+
+    return result_render
diff --git
a/src/main/python/systemds/onnx_systemds/templates/graph_function.dml.jinja
b/src/main/python/systemds/onnx_systemds/templates/graph_function.dml.jinja
new file mode 100644
index 0000000..d35d70a
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/templates/graph_function.dml.jinja
@@ -0,0 +1,54 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to you under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ Renders the dml function of one graph: an optional description line, comment
+ tables for the inputs and outputs, the function signature, the initializers
+ and the node scripts as function-body.
+ Template variables: graph_function_description, function_inputs,
+ function_outputs, graph_function_name, function_start_initializers,
+ node_scripts. The macro parameter_description renders the column titles of
+ the input and output tables.
+ #}
+{%- macro parameter_description() -%}
+# {{ "%-15.15s" | format("NAME") }} {{ "%-8.8s" | format("TYPE") }} {{
"%-12.12s" | format("VALUE_TYPE") }} {{ "%-8.8s" | format("SHAPE") }} {{
"%-10s" | format("MEANING") }}
+{%- endmacro -%}
+
+{% import 'util.dml.jinja' as util %}
+{%- if graph_function_description != "" -%} # {{ graph_function_description +
"\n" }} {%- endif -%}
+{%- if function_inputs | length > 0 -%}
+#
--------------------------------------------------------------------------------------------
+# INPUT PARAMETERS:
+#
--------------------------------------------------------------------------------------------
+{{ parameter_description() }}
+{% for input in function_inputs -%}
+# {{ "%-15.15s" | format(input.identifier_name) }} {{ "%-8.8s" |
format(input.data_type,) }} {{ "%-12.12s" | format(input.value_type,) }} {{
"%-8.8s" | format(input.shape,) }} {{ "%-10s" | format(input.description,) }}
+{% endfor -%}
+{%- endif %}
+{%- if function_outputs | length > 0 -%}
+#
---------------------------------------------------------------------------------------------
+# OUTPUTS:
+#
---------------------------------------------------------------------------------------------
+{{ parameter_description() }}
+{% for output in function_outputs -%}
+# {{ "%-15.15s" | format(output.identifier_name) }} {{ "%-8.8s" |
format(output.data_type,) }} {{ "%-12.12s" | format(output.value_type,) }} {{
"%-8.8s" | format(output.shape,) }} {{ "%-10s" | format(output.description,) }}
+{% endfor -%}
+#
---------------------------------------------------------------------------------------------
+{%- endif %}
+{{ graph_function_name }} = function(
+{{ util.generate_call_list(function_inputs) }}
+)
+return (
+{{ util.generate_call_list(function_outputs) }}
+) {
+{%- for initializer in function_start_initializers %}
+ {{ util.initialize_variable(initializer) }}
+{% endfor %}
+{% for node_script in node_scripts -%}
+{{ node_script | indent(4, first=true) }}
+{% endfor -%}
+}
\ No newline at end of file
diff --git
a/src/main/python/systemds/onnx_systemds/templates/graph_header.dml.jinja
b/src/main/python/systemds/onnx_systemds/templates/graph_header.dml.jinja
new file mode 100644
index 0000000..7e831fe
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/templates/graph_header.dml.jinja
@@ -0,0 +1,22 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to you under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ Renders a comment banner titled "Graph" that lists every key/value pair of
+ the template variable header_components on its own comment line.
+ #}
+# -------------------------------------------------------------
+# Graph
+{%- for key, value in header_components.items() %}
+# {{ key }}: {{ value }}
+{%- endfor %}
+# -------------------------------------------------------------
diff --git a/src/main/python/systemds/onnx_systemds/templates/main.dml.jinja
b/src/main/python/systemds/onnx_systemds/templates/main.dml.jinja
new file mode 100644
index 0000000..db77074
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/templates/main.dml.jinja
@@ -0,0 +1,26 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to you under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ Skeleton of the generated script. Emits, in this order: the title, the
+ rendered model header, a setwd call if wdir is not empty, the imports, the
+ sub-graph functions and finally the main function.
+ Template variables: title, model_header_render, wdir, imports,
+ sub_functions, main_function.
+ #}
+### {{ title }}
+{{ model_header_render }}
+{% if wdir != "" -%} setwd("{{ wdir }}") {%- endif %}
+{% for import in imports -%}
+{{ import }}
+{% endfor %}
+{% for sub_function in sub_functions %}
+{{ sub_function }}
+{% endfor %}
+{{ main_function }}
diff --git
a/src/main/python/systemds/onnx_systemds/templates/matrix_initialize.dml.jinja
b/src/main/python/systemds/onnx_systemds/templates/matrix_initialize.dml.jinja
new file mode 100644
index 0000000..a0d930c
--- /dev/null
+++
b/src/main/python/systemds/onnx_systemds/templates/matrix_initialize.dml.jinja
@@ -0,0 +1,24 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to you under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ Renders the assignment of a matrix(...) call to identifier_name. The
+ initializer_values are written into the string argument separated by single
+ blanks, followed by the rows and cols arguments.
+ Template variables: identifier_name, initializer_values, rows, cols.
+ #}
+{{ identifier_name }} = matrix("
+ {%- for value in initializer_values -%}
+ {{ value }}
+ {%- if not loop.last %} {% endif -%}
+ {%- endfor %}",
+ rows={{ rows }},
+ cols={{ cols -}}
+)
\ No newline at end of file
diff --git
a/src/main/python/systemds/onnx_systemds/templates/model_header.dml.jinja
b/src/main/python/systemds/onnx_systemds/templates/model_header.dml.jinja
new file mode 100644
index 0000000..e1ba18d
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/templates/model_header.dml.jinja
@@ -0,0 +1,36 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to you under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ #}
+# -------------------------------------------------------------
+# Model
+{% for key, value in header_components.items() -%}
+{%- if value != "" -%}
+# {{ key }}: {{ value }} {%- if not loop.last -%} {{ "\n" }} {%- endif -%}
+{%- endif %}
+{%- endfor -%}
+# opset_import:
+{%- for opset_domain in opset_import %}
+# - domain/opset = {{ opset_domain }}
+{%- endfor %}
+{%- if metadata_props | length > 0 -%}
+{%- for key, value in metadata_props %}
+{% if loop.first %}
+# metadata_props:
+{% endif %}
+# - {{ key }}: {{ value }}
+{%- endfor %}
+{% endif %}
+# -------------------------------------------------------------
diff --git
a/src/main/python/systemds/onnx_systemds/templates/module_import.dml.jinja
b/src/main/python/systemds/onnx_systemds/templates/module_import.dml.jinja
new file mode 100644
index 0000000..a2a3b92
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/templates/module_import.dml.jinja
@@ -0,0 +1,17 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to you under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ #}
+source("{{ path }}") as {{ name }}
\ No newline at end of file
diff --git
a/src/main/python/systemds/onnx_systemds/templates/operators/2input_1output_operator.dml.jinja
b/src/main/python/systemds/onnx_systemds/templates/operators/2input_1output_operator.dml.jinja
new file mode 100644
index 0000000..5d3d2fa
--- /dev/null
+++
b/src/main/python/systemds/onnx_systemds/templates/operators/2input_1output_operator.dml.jinja
@@ -0,0 +1,18 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to you under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ #}
+{% if doc_string != "" %}# {{ doc_string }} {% endif %}
+{{ output }} = ({{ input_0 }} {{ operator }} {{ input_1 }})
\ No newline at end of file
diff --git
a/src/main/python/systemds/onnx_systemds/templates/operators/function_call.dml.jinja
b/src/main/python/systemds/onnx_systemds/templates/operators/function_call.dml.jinja
new file mode 100644
index 0000000..96aecfa
--- /dev/null
+++
b/src/main/python/systemds/onnx_systemds/templates/operators/function_call.dml.jinja
@@ -0,0 +1,31 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to you under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ #}
+{% if doc_string != "" %}# {{ doc_string + "\n" }} {% endif %}
+ {%- if outputs | length == 1 -%}
+{%- for output in outputs -%}{{ output }} = {% endfor -%}
+{% elif outputs | length > 1 -%}
+[
+{%- for output in outputs -%}
+{{ output }}
+{%- if loop.last %}] = {% else -%}, {% endif -%}
+{%- endfor -%}
+{%- endif -%}
+{% if function_namespace != "" -%}{{ function_namespace }}::{%- endif -%}{{
function }}(
+ {%- for argument in arguments -%}
+ {{ argument }}
+ {%- if not loop.last -%}, {%- endif -%}
+{%- endfor -%})
\ No newline at end of file
diff --git
a/src/main/python/systemds/onnx_systemds/templates/operators/if_operator.dml.jinja
b/src/main/python/systemds/onnx_systemds/templates/operators/if_operator.dml.jinja
new file mode 100644
index 0000000..816987c
--- /dev/null
+++
b/src/main/python/systemds/onnx_systemds/templates/operators/if_operator.dml.jinja
@@ -0,0 +1,19 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to you under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ #}
+if ({{ cond }}) { {{ then_function_call | indent(4) }}
+} else { {{ else_function_call | indent(4) }}
+}
\ No newline at end of file
diff --git
a/src/main/python/systemds/onnx_systemds/templates/operators/neg.dml.jinja
b/src/main/python/systemds/onnx_systemds/templates/operators/neg.dml.jinja
new file mode 100644
index 0000000..4a733f0
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/templates/operators/neg.dml.jinja
@@ -0,0 +1,18 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to you under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ #}
+{% if doc_string != "" %}# {{ doc_string }} {% endif %}
+{{ output }} = (-{{ input }})
diff --git a/src/main/python/systemds/onnx_systemds/templates/util.dml.jinja
b/src/main/python/systemds/onnx_systemds/templates/util.dml.jinja
new file mode 100644
index 0000000..b3f4084
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/templates/util.dml.jinja
@@ -0,0 +1,42 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to you under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ #}
+{% macro generate_call_list(variables) -%}
+{#- Renders one typed entry per element of `variables`:
+    "matrix [<value_type>] <identifier_name>" when data_type is "matrix",
+    "<value_type> <identifier_name>" when data_type is "scalar".
+    Every entry except the last is followed by a comma and a newline. -#}
+{% for input in variables -%}
+ {%- if input.data_type == "matrix" -%}
+ {{ ' ' + "matrix" }} [{{ input.value_type }}] {{ input.identifier_name }}
+ {%- elif input.data_type == "scalar" -%}
+ {{ ' ' + input.value_type }} {{ input.identifier_name }}
+ {%- endif -%}
+ {%- if not loop.last -%},
+{% endif %}
+{%- endfor %}
+{%- endmacro %}
+
+{% macro initialize_variable(variable) -%}
+{#- Renders the DML statement that initializes `variable`:
+    a matrix("<space separated values>", rows=shape[0], cols=shape[1]) call
+    when data_type is "matrix", or a typed assignment of the first
+    initializer value when data_type is "scalar".
+    The scalar branch must read from the macro parameter `variable`;
+    no `input` name is defined in this macro's scope. -#}
+{%- if variable.data_type == "matrix" -%}
+{{ variable.identifier_name }} = matrix("
+ {%- for value in variable.initializer_values -%}
+ {{ value }}
+ {%- if not loop.last %} {% endif -%}
+ {%- endfor %}",
+ rows={{ variable.shape[0] }},
+ cols={{ variable.shape[1] -}}
+)
+{%- elif variable.data_type == "scalar" -%}
+ {{ variable.value_type }} {{ variable.identifier_name }} = {{ variable.initializer_values[0] }}
+ {%- endif -%}
+{%- endmacro %}
\ No newline at end of file
diff --git a/src/main/python/systemds/onnx_systemds/util.py
b/src/main/python/systemds/onnx_systemds/util.py
new file mode 100644
index 0000000..4ba9638
--- /dev/null
+++ b/src/main/python/systemds/onnx_systemds/util.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+
+
+def generate_function_name(graph_name: str) -> str:
+    """
+    Derives a function name from an ONNX graph name.
+
+    The name is lower-cased, every '-', '|' and ' ' becomes '_', all
+    remaining characters outside [0-9a-z_] are dropped and the result is
+    prefixed with "gen_".
+    :param graph_name: The name of the graph.
+    :return: the constructed function name.
+    """
+    lowered = graph_name.lower()
+    underscored = re.sub(r"[-| ]", "_", lowered)
+    # The "gen_" prefix only contains allowed characters, so it can be
+    # attached after the clean-up without changing the result.
+    cleaned = re.sub(r"[^0-9a-z_]", "", underscored)
+    return "gen_" + cleaned
+
+
+def resolve_systemds_root() -> str:
+    """
+    Looks up the SystemDS root directory in the environment variables.
+    :return: The value of the SYSTEMDS_ROOT environment variable.
+    :raises SystemExit: with exit code -1, after printing an error message,
+        if SYSTEMDS_ROOT is not set.
+    """
+    try:
+        return os.environ['SYSTEMDS_ROOT']
+    except KeyError:
+        # Name the variable that is actually read, so the message tells the
+        # user which one to set.
+        print("ERROR environment variable SYSTEMDS_ROOT not set, could not resolve path to module")
+        # Same effect as exit(-1), without relying on the exit() helper that
+        # is only injected by the site module.
+        raise SystemExit(-1) from None
diff --git a/src/main/python/systemds/__init__.py
b/src/main/python/tests/__init__.py
similarity index 83%
copy from src/main/python/systemds/__init__.py
copy to src/main/python/tests/__init__.py
index e51fbf8..217e5db 100644
--- a/src/main/python/systemds/__init__.py
+++ b/src/main/python/tests/__init__.py
@@ -1,4 +1,3 @@
-#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
@@ -16,7 +15,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-#-------------------------------------------------------------
-
-__all__ = ['context', 'matrix']
diff --git a/src/main/python/systemds/__init__.py
b/src/main/python/tests/onnx/__init__.py
similarity index 83%
copy from src/main/python/systemds/__init__.py
copy to src/main/python/tests/onnx/__init__.py
index e51fbf8..fe95886 100644
--- a/src/main/python/systemds/__init__.py
+++ b/src/main/python/tests/onnx/__init__.py
@@ -1,4 +1,3 @@
-#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
@@ -17,6 +16,3 @@
# specific language governing permissions and limitations
# under the License.
#
-#-------------------------------------------------------------
-
-__all__ = ['context', 'matrix']
diff --git
a/src/main/python/tests/onnx/dml_wrapper/simple_conv_layer_2_wrapper.dml
b/src/main/python/tests/onnx/dml_wrapper/simple_conv_layer_2_wrapper.dml
new file mode 100644
index 0000000..d1cfcec
--- /dev/null
+++ b/src/main/python/tests/onnx/dml_wrapper/simple_conv_layer_2_wrapper.dml
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source("tests/onnx/dml_output/simple_conv_layer_2.dml") as simple_conv
+
+
+x = matrix("0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34.", rows=1,
cols=35)
+W = matrix("1. 1. 1. 1. 1. 1. 1. 1. 1.", rows=1, cols=9)
+
+[o0, o1, o2] = simple_conv::gen_a_simple_convolution_graph(x, W)
+out = append(toString(o0), toString(o1))
+out = append(out, toString(o2))
+
+write(out, "tests/onnx/output_test/simple_conv_layer_2.out")
+
diff --git
a/src/main/python/tests/onnx/dml_wrapper/simple_conv_layer_wrapper.dml
b/src/main/python/tests/onnx/dml_wrapper/simple_conv_layer_wrapper.dml
new file mode 100644
index 0000000..0dd1daa
--- /dev/null
+++ b/src/main/python/tests/onnx/dml_wrapper/simple_conv_layer_wrapper.dml
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source("tests/onnx/dml_output/simple_conv_layer.dml") as simple_conv
+
+
+x = matrix("0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
19. 20. 21. 22. 23. 24.", rows=1, cols=25)
+W = matrix("1. 1. 1. 1. 1. 1. 1. 1. 1.", rows=1, cols=9)
+
+out = simple_conv::gen_a_simple_convolution_graph(x, W)
+
+write(out, "tests/onnx/output_test/simple_conv_layer.out")
+
diff --git
a/src/main/python/tests/onnx/dml_wrapper/simple_dropout_layer_wrapper.dml
b/src/main/python/tests/onnx/dml_wrapper/simple_dropout_layer_wrapper.dml
new file mode 100644
index 0000000..25ff9cf
--- /dev/null
+++ b/src/main/python/tests/onnx/dml_wrapper/simple_dropout_layer_wrapper.dml
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source("tests/onnx/dml_output/simple_dropout_layer.dml") as simple_dropout
+
+A = matrix("3 8 9 10", rows=2, cols=2)
+[O, mask] = simple_dropout::gen_simple_dropout_graph(A)
+out = append(toString(O), toString(mask))
+
+write(out, "tests/onnx/output_test/simple_dropout_layer.out")
diff --git a/src/main/python/tests/onnx/dml_wrapper/simple_if_graph_wrapper.dml
b/src/main/python/tests/onnx/dml_wrapper/simple_if_graph_wrapper.dml
new file mode 100644
index 0000000..755ec52
--- /dev/null
+++ b/src/main/python/tests/onnx/dml_wrapper/simple_if_graph_wrapper.dml
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source("tests/onnx/dml_output/simple_if_graph.dml") as simple_if_graph
+
+A = matrix("3 8 9 10", rows=2, cols=2)
+cond = TRUE
+O_true = simple_if_graph::gen_a_simple_if_graph(A, cond)
+cond = FALSE
+O_false = simple_if_graph::gen_a_simple_if_graph(A, cond)
+
+out = append(toString(O_true), toString(O_false))
+write(out, "tests/onnx/output_test/simple_if_graph.out")
+
+
diff --git
a/src/main/python/tests/onnx/dml_wrapper/simple_mat_add_mul_sub_wrapper.dml
b/src/main/python/tests/onnx/dml_wrapper/simple_mat_add_mul_sub_wrapper.dml
new file mode 100644
index 0000000..efcdb87
--- /dev/null
+++ b/src/main/python/tests/onnx/dml_wrapper/simple_mat_add_mul_sub_wrapper.dml
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source("tests/onnx/dml_output/simple_mat_add_mul_sub.dml") as
simple_mat_add_mul_sub
+
+A = matrix("3 8 9 10", rows=2, cols=2)
+B = matrix("2 5 7 8", rows=2, cols=2)
+C = matrix("1 2 3 4", rows=2, cols=2)
+O = matrix(0, rows=2, cols=2)
+O =
simple_mat_add_mul_sub::gen_a_simple_matrix_addition_multiplication_and_substraction_test_graph(A,
B, C)
+
+write(O, "tests/onnx/output_test/simple_mat_add_mul_sub.out")
diff --git a/src/main/python/tests/onnx/dml_wrapper/simple_mat_add_wrapper.dml
b/src/main/python/tests/onnx/dml_wrapper/simple_mat_add_wrapper.dml
new file mode 100644
index 0000000..943e93d
--- /dev/null
+++ b/src/main/python/tests/onnx/dml_wrapper/simple_mat_add_wrapper.dml
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source("tests/onnx/dml_output/simple_mat_add.dml") as simple_mat_add
+
+A = matrix("3 8 9 10", rows=2, cols=2)
+B = matrix("2 5 7 8", rows=2, cols=2)
+C = matrix("1 2 3 4", rows=2, cols=2)
+O = matrix(0, rows=2, cols=2)
+O = simple_mat_add::gen_a_simple_matrix_addition_test_graph(A, B, C)
+
+write(O, "tests/onnx/output_test/simple_mat_add.out")
diff --git
a/src/main/python/tests/onnx/dml_wrapper/simple_mat_initialized_wrapper.dml
b/src/main/python/tests/onnx/dml_wrapper/simple_mat_initialized_wrapper.dml
new file mode 100644
index 0000000..f99a168
--- /dev/null
+++ b/src/main/python/tests/onnx/dml_wrapper/simple_mat_initialized_wrapper.dml
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source("tests/onnx/dml_output/simple_mat_initialized.dml") as
simple_mat_initialized
+
+O = simple_mat_initialized::gen_simple_mat_initialized_graph()
+
+write(O, "tests/onnx/output_test/simple_mat_initialized.out")
+
diff --git
a/src/main/python/tests/onnx/dml_wrapper/simple_maxpool_layer_wrapper.dml
b/src/main/python/tests/onnx/dml_wrapper/simple_maxpool_layer_wrapper.dml
new file mode 100644
index 0000000..c6604af
--- /dev/null
+++ b/src/main/python/tests/onnx/dml_wrapper/simple_maxpool_layer_wrapper.dml
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source("tests/onnx/dml_output/simple_maxpool_layer.dml") as simple_maxpool
+
+x = matrix("1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
25", rows=1, cols=25)
+out = simple_maxpool::gen_a_simple_maxpool_graph(x)
+
+write(out, "tests/onnx/output_test/simple_maxpool_layer.out")
+
diff --git
a/src/main/python/tests/onnx/dml_wrapper/simple_relu_tanh_sigmoid_softmax_wrapper.dml
b/src/main/python/tests/onnx/dml_wrapper/simple_relu_tanh_sigmoid_softmax_wrapper.dml
new file mode 100644
index 0000000..f306094
--- /dev/null
+++
b/src/main/python/tests/onnx/dml_wrapper/simple_relu_tanh_sigmoid_softmax_wrapper.dml
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source("tests/onnx/dml_output/simple_relu_tanh_sigmoid_softmax.dml") as
simple_relu_tanh
+
+A = matrix("0.2 2 12 20", rows=2, cols=2)
+B = matrix("0.2 2 12 20", rows=2, cols=2)
+C = matrix("0.2 2 12 20", rows=2, cols=2)
+D = matrix("0.2 2 12 20", rows=2, cols=2)
+[O, U, I, J] = simple_relu_tanh::gen_simple_relu_tanh_sigmoid_softmax_graph(A,
B, C, D)
+out = append(toString(O), toString(U))
+out = append(out, toString(I))
+out = append(out, toString(J))
+
+write(out, "tests/onnx/output_test/simple_relu_tanh_sigmoid_softmax.out")
\ No newline at end of file
diff --git
a/src/main/python/tests/onnx/output_reference/simple_conv_layer_2_reference.out
b/src/main/python/tests/onnx/output_reference/simple_conv_layer_2_reference.out
new file mode 100644
index 0000000..b7283d2
--- /dev/null
+++
b/src/main/python/tests/onnx/output_reference/simple_conv_layer_2_reference.out
@@ -0,0 +1,5 @@
+12.000 27.000 24.000 63.000 108.000 81.000 123.000 198.000 141.000 112.000
177.000 124.000
+
+54.000 72.000 144.000 162.000 234.000 252.000
+
+21.000 33.000 99.000 117.000 189.000 207.000 171.000 183.000
diff --git
a/src/main/python/tests/onnx/output_reference/simple_conv_layer_reference.out
b/src/main/python/tests/onnx/output_reference/simple_conv_layer_reference.out
new file mode 100644
index 0000000..a0631f6
--- /dev/null
+++
b/src/main/python/tests/onnx/output_reference/simple_conv_layer_reference.out
@@ -0,0 +1,25 @@
+1 1 12.0
+1 2 21.0
+1 3 27.0
+1 4 33.0
+1 5 24.0
+1 6 33.0
+1 7 54.0
+1 8 63.0
+1 9 72.0
+1 10 51.0
+1 11 63.0
+1 12 99.0
+1 13 108.0
+1 14 117.0
+1 15 81.0
+1 16 93.0
+1 17 144.0
+1 18 153.0
+1 19 162.0
+1 20 111.0
+1 21 72.0
+1 22 111.0
+1 23 117.0
+1 24 123.0
+1 25 84.0
diff --git
a/src/main/python/tests/onnx/output_reference/simple_if_graph_reference.out
b/src/main/python/tests/onnx/output_reference/simple_if_graph_reference.out
new file mode 100644
index 0000000..4c5603d
--- /dev/null
+++ b/src/main/python/tests/onnx/output_reference/simple_if_graph_reference.out
@@ -0,0 +1,5 @@
+0.995 1.000
+1.000 1.000
+
+3.000 8.000
+9.000 10.000
diff --git
a/src/main/python/tests/onnx/output_reference/simple_mat_add_mul_sub_reference.out
b/src/main/python/tests/onnx/output_reference/simple_mat_add_mul_sub_reference.out
new file mode 100644
index 0000000..6f41b8a
--- /dev/null
+++
b/src/main/python/tests/onnx/output_reference/simple_mat_add_mul_sub_reference.out
@@ -0,0 +1,4 @@
+1 1 -41.0
+1 2 -54.0
+2 1 -61.0
+2 2 -94.0
diff --git
a/src/main/python/tests/onnx/output_reference/simple_mat_add_reference.out
b/src/main/python/tests/onnx/output_reference/simple_mat_add_reference.out
new file mode 100644
index 0000000..0b12bb0
--- /dev/null
+++ b/src/main/python/tests/onnx/output_reference/simple_mat_add_reference.out
@@ -0,0 +1,4 @@
+1 1 9.0
+1 2 23.0
+2 1 28.0
+2 2 32.0
diff --git
a/src/main/python/tests/onnx/output_reference/simple_mat_initialized_reference.out
b/src/main/python/tests/onnx/output_reference/simple_mat_initialized_reference.out
new file mode 100644
index 0000000..54c1684
--- /dev/null
+++
b/src/main/python/tests/onnx/output_reference/simple_mat_initialized_reference.out
@@ -0,0 +1,9 @@
+1 1 1.0
+1 2 2.0
+1 3 3.0
+2 1 4.0
+2 2 5.0
+2 3 6.0
+3 1 7.0
+3 2 8.0
+3 3 9.0
diff --git
a/src/main/python/tests/onnx/output_reference/simple_maxpool_layer_reference.out
b/src/main/python/tests/onnx/output_reference/simple_maxpool_layer_reference.out
new file mode 100644
index 0000000..a594c76
--- /dev/null
+++
b/src/main/python/tests/onnx/output_reference/simple_maxpool_layer_reference.out
@@ -0,0 +1,25 @@
+1 1 13.0
+1 2 14.0
+1 3 15.0
+1 4 15.0
+1 5 15.0
+1 6 18.0
+1 7 19.0
+1 8 20.0
+1 9 20.0
+1 10 20.0
+1 11 23.0
+1 12 24.0
+1 13 25.0
+1 14 25.0
+1 15 25.0
+1 16 23.0
+1 17 24.0
+1 18 25.0
+1 19 25.0
+1 20 25.0
+1 21 23.0
+1 22 24.0
+1 23 25.0
+1 24 25.0
+1 25 25.0
diff --git
a/src/main/python/tests/onnx/output_reference/simple_relu_tanh_sigmoid_softmax_reference.out
b/src/main/python/tests/onnx/output_reference/simple_relu_tanh_sigmoid_softmax_reference.out
new file mode 100644
index 0000000..d92ee57
--- /dev/null
+++
b/src/main/python/tests/onnx/output_reference/simple_relu_tanh_sigmoid_softmax_reference.out
@@ -0,0 +1,11 @@
+0.200 2.000
+12.000 20.000
+
+0.197 0.964
+1.000 1.000
+
+0.550 0.881
+1.000 1.000
+
+0.142 0.858
+0.000 1.000
diff --git a/src/main/python/tests/onnx/test_models/model_generate.py
b/src/main/python/tests/onnx/test_models/model_generate.py
new file mode 100644
index 0000000..3d64eb6
--- /dev/null
+++ b/src/main/python/tests/onnx/test_models/model_generate.py
@@ -0,0 +1,388 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import onnx
+from onnx import helper
+
+
+def save_graph(graph_def, name):
+    """
+    Wraps the given graph in an ONNX model and stores it in the directory of this script.
+
+    :param graph_def: the graph that shall be saved
+    :param name: the file name of the resulting model file
+    """
+    script_dir = os.path.dirname(os.path.realpath(__file__))
+    target_path = script_dir + "/" + name
+    model_def = helper.make_model(
+        graph_def,
+        producer_name="onnx-systemds test-graph generator"
+    )
+    onnx.save_model(model_def, target_path)
+
+
+def generate_simple_add_graph():
+    """
+    Generates simple_mat_add.onnx, a graph computing F = A + ((A + B) + D) on 2x2 float matrices.
+
+    A, B and D are the graph inputs, C and E are intermediate values and F is the output.
+    """
+    A, B, C, D, E, F = [
+        helper.make_tensor_value_info(var, onnx.TensorProto.FLOAT, [2, 2],
+                                      doc_string="This is a description of variable " + var)
+        for var in "ABCDEF"
+    ]
+
+    # Every node gets its own name: all three nodes are Add operations, and names
+    # within the node namespace of a graph have to be unique.
+    nodes = [
+        helper.make_node("Add", ['A', 'B'], ['C'], name="AddNodeName",
+                         doc_string="This is a description of this Add operation"),
+        helper.make_node("Add", ['C', 'D'], ['E'], name="SecondAddNodeName",
+                         doc_string="This is a description of this second Add operation"),
+        helper.make_node("Add", ['A', 'E'], ['F'], name="ThirdAddNodeName")
+    ]
+
+    graph = helper.make_graph(
+        nodes=nodes,
+        name="A simple matrix addition test graph",
+        inputs=[A, B, D],
+        outputs=[F],
+        initializer=None,
+        doc_string="Doc string of a simple matrix addition test graph",
+        value_info=[C, E]
+    )
+
+    save_graph(graph, "simple_mat_add.onnx")
+
+
+def generate_simple_mat_add_mul_sub_graph():
+    """
+    Generates simple_mat_add_mul_sub.onnx, a graph computing F = A - MatMul(A + B, D) on 2x2 float matrices.
+
+    A, B and D are the graph inputs, C and E are intermediate values and F is the output.
+    """
+    A, B, C, D, E, F = [
+        helper.make_tensor_value_info(var, onnx.TensorProto.FLOAT, [2, 2],
+                                      doc_string="This is a description of variable " + var)
+        for var in "ABCDEF"
+    ]
+
+    # The Sub node has its own name instead of re-using the name of the MatMul node,
+    # since names within the node namespace of a graph have to be unique.
+    nodes = [
+        helper.make_node("Add", ['A', 'B'], ['C'], name="AddNodeName",
+                         doc_string="This is a description of this Add operation"),
+        helper.make_node("MatMul", ['C', 'D'], ['E'], name="MulNodeName",
+                         doc_string="This is a description of this mul operation"),
+        helper.make_node("Sub", ['A', 'E'], ['F'], name="SubNodeName")
+    ]
+
+    graph = helper.make_graph(
+        nodes=nodes,
+        name="A simple matrix addition, multiplication and substraction test graph",
+        inputs=[A, B, D],
+        outputs=[F],
+        initializer=None,
+        doc_string="Doc string with additional information",
+        value_info=[C, E]
+    )
+
+    save_graph(graph, "simple_mat_add_mul_sub.onnx")
+
+
+def generate_simple_initialized_graph():
+    """
+    Generates simple_mat_initialized.onnx, a graph computing D = B_init + Neg(A_init).
+
+    Both inputs are 3x3 float matrices whose values are provided as initializers of the graph.
+    """
+    A_init = helper.make_tensor("A_init", onnx.TensorProto.FLOAT, [3, 3],
+                                [1, 2, 3, 4, 5, 6, 7, 8, 9])
+    B_init = helper.make_tensor("B_init", onnx.TensorProto.FLOAT, [3, 3],
+                                [2, 4, 6, 8, 10, 12, 14, 16, 18])
+
+    # B_init is a 3x3 matrix just like A_init, so both carry the same description
+    B_init_valinfo = helper.make_tensor_value_info(B_init.name, onnx.TensorProto.FLOAT, B_init.dims,
+                                                   doc_string="A 3x3 matrix")
+    A_init_valinfo = helper.make_tensor_value_info(A_init.name, onnx.TensorProto.FLOAT, A_init.dims,
+                                                   doc_string="A 3x3 matrix")
+    C = helper.make_tensor_value_info("C", onnx.TensorProto.FLOAT, [3, 3],
+                                      doc_string="This is the output C")
+    D = helper.make_tensor_value_info("D", onnx.TensorProto.FLOAT, [3, 3],
+                                      doc_string="This is the output D")
+
+    nodes = [
+        helper.make_node("Neg", ["A_init"], ["C"]),
+        helper.make_node("Add", ["B_init", "C"], ["D"])
+    ]
+
+    graph = helper.make_graph(
+        nodes=nodes,
+        name="Simple mat initialized graph",
+        inputs=[A_init_valinfo, B_init_valinfo],
+        outputs=[D],
+        initializer=[A_init, B_init],
+        value_info=[A_init_valinfo, B_init_valinfo, C, D]
+    )
+
+    save_graph(graph, "simple_mat_initialized.onnx")
+
+
+def generate_simple_boolean_noshape():
+    """
+    Generates simple_bool_and_or_xor_noshape.onnx, a graph computing
+    E = Xor(B, And(Or(A, B), A)) on boolean values declared with an empty shape.
+    """
+    bool_infos = [
+        helper.make_tensor_value_info(var, onnx.TensorProto.BOOL,
+                                      doc_string="This is a description of variable " + var,
+                                      shape=[])
+        for var in "ABCDE"
+    ]
+    A, B, C, D, E = bool_infos
+
+    or_node = helper.make_node("Or", ["A", "B"], ["C"])
+    and_node = helper.make_node("And", ["C", "A"], ["D"])
+    xor_node = helper.make_node("Xor", ["B", "D"], ["E"])
+
+    graph = helper.make_graph(
+        nodes=[or_node, and_node, xor_node],
+        name="Simple bool and or xor noshape graph",
+        inputs=[A, B],
+        outputs=[E],
+        value_info=[A, B, C, D, E]
+    )
+
+    save_graph(graph, "simple_bool_and_or_xor_noshape.onnx")
+
+
+def generate_simple_relu_tanh_sigmoid_softmax():
+    """
+    Generates simple_relu_tanh_sigmoid_softmax.onnx, a graph with four independent layer calls
+    on 2x2 float matrices: E = Relu(A), F = Tanh(B), G = Sigmoid(C) and H = Softmax(D).
+    """
+    A, B, C, D, E, F, G, H = [
+        helper.make_tensor_value_info(var, onnx.TensorProto.FLOAT, [2, 2],
+                                      doc_string="This is a description of variable " + var)
+        for var in "ABCDEFGH"
+    ]
+
+    nodes = [
+        helper.make_node("Relu", ["A"], ["E"], doc_string="Call of Relu function"),
+        helper.make_node("Tanh", ["B"], ["F"], doc_string="Call of Tanh function"),
+        helper.make_node("Sigmoid", ["C"], ["G"], doc_string="Call of Sigmoid function"),
+        helper.make_node("Softmax", ["D"], ["H"], doc_string="Call of Softmax function")
+    ]
+
+    graph = helper.make_graph(
+        nodes=nodes,
+        name="Simple relu tanh sigmoid softmax graph",
+        inputs=[A, B, C, D],
+        outputs=[E, F, G, H],
+        # C is listed as well, so that all tensors of the graph are described consistently
+        value_info=[A, B, C, D, E, F, G, H],
+        doc_string="This graph tests simple nn layer calls"
+    )
+
+    save_graph(graph, "simple_relu_tanh_sigmoid_softmax.onnx")
+
+
+def generate_simple_dropout_layer():
+    """
+    Generates simple_dropout_layer.onnx, a graph chaining two Dropout nodes with ratio 0.1:
+    the first maps x to y, the second maps y to z and additionally outputs the mask.
+    """
+    x, y, z = [
+        helper.make_tensor_value_info(var, onnx.TensorProto.FLOAT, [2, 2],
+                                      doc_string="This is a description of variable " + var)
+        for var in ("x", "y", "z")
+    ]
+    mask = helper.make_tensor_value_info('mask', onnx.TensorProto.BOOL, [2, 2],
+                                         doc_string="This is a description of variable mask")
+
+    first_dropout = helper.make_node(op_type='Dropout', inputs=['x'], outputs=['y'], ratio=.1)
+    second_dropout = helper.make_node(op_type='Dropout', inputs=['y'], outputs=['z', 'mask'], ratio=.1)
+
+    graph = helper.make_graph(
+        nodes=[first_dropout, second_dropout],
+        name="Simple dropout graph",
+        inputs=[x],
+        outputs=[z, mask],
+        value_info=[x, y, z],
+        doc_string="This graph tests a simple dropout layer call"
+    )
+
+    save_graph(graph, "simple_dropout_layer.onnx")
+
+
+def generate_simple_conv():
+    """
+    Generates simple_conv_layer.onnx, a graph with a single padded Conv node that maps the
+    1x1x5x5 input x and the 1x1x3x3 kernel W to the 1x1x5x5 output y.
+    """
+    input_info = helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1, 1, 5, 5],
+                                               doc_string="A multi dimensional input value")
+    kernel_info = helper.make_tensor_value_info("W", onnx.TensorProto.FLOAT, [1, 1, 3, 3],
+                                                doc_string="A multi dimensional input value")
+    output_info = helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1, 1, 5, 5],
+                                                doc_string="Y output")
+
+    # Convolution with padding
+    # Default values for other attributes: strides=[1, 1], dilations=[1, 1], groups=1
+    conv_node = helper.make_node(
+        op_type='Conv',
+        inputs=['x', 'W'],
+        outputs=['y'],
+        kernel_shape=[3, 3],
+        pads=[1, 1, 1, 1],
+    )
+
+    graph = helper.make_graph(
+        nodes=[conv_node],
+        name="A simple convolution graph",
+        inputs=[input_info, kernel_info],
+        outputs=[output_info],
+        value_info=[input_info, kernel_info]
+    )
+
+    save_graph(graph, "simple_conv_layer.onnx")
+
+
+def generate_simple_conv_2():
+    """
+    Generates simple_conv_layer_2.onnx, a graph applying three Conv nodes with strides=2 and
+    different paddings to the same 1x1x7x5 input x and 1x1x3x3 kernel W.
+    """
+    def strided_conv(output_name, pads):
+        # Default values for other attributes: dilations=[1, 1], groups=1
+        return helper.make_node(
+            op_type='Conv',
+            inputs=['x', 'W'],
+            outputs=[output_name],
+            kernel_shape=[3, 3],
+            pads=pads,
+            strides=[2, 2],
+        )
+
+    input_info = helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1, 1, 7, 5],
+                                               doc_string="A multi dimensional input value")
+    kernel_info = helper.make_tensor_value_info("W", onnx.TensorProto.FLOAT, [1, 1, 3, 3],
+                                                doc_string="A multi dimensional input value")
+
+    # Convolution with strides=2 and padding
+    y_0_info = helper.make_tensor_value_info("y_0", onnx.TensorProto.FLOAT, [1, 1, 4, 3],
+                                             doc_string="Y0 output")
+    conv_padded = strided_conv('y_0', [1, 1, 1, 1])
+
+    # Convolution with strides=2 and no padding
+    y_1_info = helper.make_tensor_value_info("y_1", onnx.TensorProto.FLOAT, [1, 1, 3, 2],
+                                             doc_string="Y1 output")
+    conv_unpadded = strided_conv('y_1', [0, 0, 0, 0])
+
+    # Convolution with strides=2 and padding only along one dimension
+    # (the H dimension in NxCxHxW tensor)
+    y_2_info = helper.make_tensor_value_info("y_2", onnx.TensorProto.FLOAT, [1, 1, 4, 2],
+                                             doc_string="Y2 output")
+    conv_asymmetric = strided_conv('y_2', [1, 0, 1, 0])
+
+    graph = helper.make_graph(
+        nodes=[conv_padded, conv_unpadded, conv_asymmetric],
+        name="A simple convolution graph",
+        inputs=[input_info, kernel_info],
+        outputs=[y_0_info, y_1_info, y_2_info],
+        value_info=[input_info, kernel_info]
+    )
+
+    save_graph(graph, "simple_conv_layer_2.onnx")
+
+
+def generate_simple_maxpool():
+    """
+    Generates simple_maxpool_layer.onnx, a graph with a single MaxPool node (5x5 kernel,
+    padding of 2 on every side) that maps the 1x1x5x5 input x to the 1x1x5x5 output y.
+    """
+    input_info = helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1, 1, 5, 5],
+                                               doc_string="A multi dimensional input value")
+    output_info = helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1, 1, 5, 5],
+                                                doc_string="Y output")
+
+    pool_node = helper.make_node(
+        op_type='MaxPool',
+        inputs=['x'],
+        outputs=['y'],
+        kernel_shape=[5, 5],
+        pads=[2, 2, 2, 2]
+    )
+
+    graph = helper.make_graph(
+        nodes=[pool_node],
+        name="A simple maxpool graph",
+        inputs=[input_info],
+        outputs=[output_info],
+        value_info=[input_info]
+    )
+
+    save_graph(graph, "simple_maxpool_layer.onnx")
+
+
+def generate_simple_if():
+    """
+    Generates simple_if_graph.onnx, a graph with a single If node: the then-branch sub-graph
+    computes E = Tanh(A) and the else-branch sub-graph computes E = Relu(A).
+    """
+    A = helper.make_tensor_value_info('A', onnx.TensorProto.FLOAT, [2, 2],
+                                      doc_string="This is a description of variable A")
+    E = helper.make_tensor_value_info('E', onnx.TensorProto.FLOAT, [2, 2],
+                                      doc_string="This is a description of variable E")
+    condition = helper.make_tensor_value_info('cond', onnx.TensorProto.BOOL, [1],
+                                              doc_string="Condition for the if")
+
+    def branch_graph(op_type, node_doc, graph_name):
+        # Each branch is a sub-graph with one node mapping A to E
+        branch_node = helper.make_node(op_type, ["A"], ["E"], doc_string=node_doc)
+        return helper.make_graph(
+            nodes=[branch_node],
+            name=graph_name,
+            inputs=[A],
+            outputs=[E],
+            value_info=[A]
+        )
+
+    else_graph = branch_graph("Relu", "Call of Relu function in else branch", "The Else branch graph")
+    then_graph = branch_graph("Tanh", "Call of Tanh function in then branch", "The Then branch graph")
+
+    if_node = helper.make_node(
+        op_type="If",
+        name="A simple if graph",
+        inputs=["cond"],
+        outputs=["E"],
+        else_branch=else_graph,
+        then_branch=then_graph
+    )
+
+    graph = helper.make_graph(
+        nodes=[if_node],
+        name="A simple if graph",
+        inputs=[A, condition],
+        outputs=[E]
+    )
+
+    save_graph(graph, "simple_if_graph.onnx")
+
+
+if __name__ == '__main__':
+    # Generators are run in this order; the call of generate_simple_boolean_noshape is disabled
+    model_generators = (
+        generate_simple_add_graph,
+        # generate_simple_boolean_noshape,
+        generate_simple_mat_add_mul_sub_graph,
+        generate_simple_initialized_graph,
+        generate_simple_relu_tanh_sigmoid_softmax,
+        generate_simple_dropout_layer,
+        generate_simple_conv,
+        generate_simple_conv_2,
+        generate_simple_maxpool,
+        generate_simple_if,
+    )
+    for generate_model in model_generators:
+        generate_model()
diff --git a/src/main/python/tests/onnx/test_simple.py
b/src/main/python/tests/onnx/test_simple.py
new file mode 100644
index 0000000..bb7d822
--- /dev/null
+++ b/src/main/python/tests/onnx/test_simple.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import tests.onnx.util as util
+
+
+class TestSimpleOperators(unittest.TestCase):
+    """
+    Every test converts one onnx test-model to dml, runs it with systemds and compares
+    the output with the reference output (see util.run_and_compare_output).
+    """
+
+    def test_simple_mat_add(self):
+        util.run_and_compare_output("simple_mat_add", self)
+
+    def test_simple_mat_add_mul_sub(self):
+        util.run_and_compare_output("simple_mat_add_mul_sub", self)
+
+    def test_simple_mat_initialized(self):
+        util.run_and_compare_output("simple_mat_initialized", self)
+
+    def test_simple_relu_tanh_sigmoid_softmax(self):
+        util.run_and_compare_output("simple_relu_tanh_sigmoid_softmax", self)
+
+    def test_simple_conv2d_layer(self):
+        util.run_and_compare_output("simple_conv_layer", self)
+
+    def test_simple_conv2d_layer_2(self):
+        util.run_and_compare_output("simple_conv_layer_2", self)
+
+    def test_simple_maxpool_layer(self):
+        util.run_and_compare_output("simple_maxpool_layer", self)
+
+    def test_simple_if_graph(self):
+        util.run_and_compare_output("simple_if_graph", self)
+
+    # The following tests are skipped instead of commented out, so that they stay
+    # syntactically checked and show up as skipped in the test report.
+
+    # TODO: dml implementation of dropout does not work
+    @unittest.skip("dml implementation of dropout does not work")
+    def test_simple_dropout_layer(self):
+        util.run_and_compare_output("simple_dropout_layer", self)
+
+    # TODO: dml does not support boolean matrices?
+    @unittest.skip("dml does not support boolean matrices?")
+    def test_simple_bool_and_or_xor_noshape(self):
+        util.run_and_compare_output("simple_bool_and_or_xor_noshape", self)
+
+
+# Run all test cases of this module when it is executed directly as a script.
+if __name__ == '__main__':
+ unittest.main()
diff --git a/src/main/python/tests/onnx/util.py
b/src/main/python/tests/onnx/util.py
new file mode 100644
index 0000000..1ee82e6
--- /dev/null
+++ b/src/main/python/tests/onnx/util.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+import unittest
+
+from systemds.onnx_systemds.convert import onnx2systemds
+from systemds.onnx_systemds.util import resolve_systemds_root
+
+
+def invoke_systemds(input_file: str, args: [str] = None) -> int:
+    """
+    Runs systemds by running the script in $SYSTEMDS_ROOT_PATH/bin/systemds with the provided input_file,
+    will fail if environment variable SYSTEMDS_ROOT_PATH is not set.
+
+    Stdout and stderr of systemds are captured; they are printed if systemds fails or times out,
+    stderr is also printed if systemds succeeds but wrote something to it.
+
+    :param input_file: the dml script to run
+    :param args: additional arguments if needed
+    :return: the return-code of systemds, or 124 if systemds did not finish within the timeout
+    """
+    if args is None:
+        args = []
+
+    systemds_root_path = resolve_systemds_root()
+
+    # systemds gets the script path relative to the current working directory
+    relative_input = os.path.relpath(input_file, os.getcwd())
+    command = [systemds_root_path + "/bin/systemds", relative_input] + args
+
+    try:
+        res = subprocess.run(command,
+                             check=True,
+                             stdout=subprocess.PIPE,
+                             stderr=subprocess.PIPE,
+                             timeout=10000)
+    except subprocess.CalledProcessError as systemds_error:
+        print("SYSTEMDS FAILED!")
+        print("error code: " + str(systemds_error.returncode))
+        print("Stdout:")
+        print(systemds_error.output.decode("utf-8"))
+        print("Stderr:")
+        print(systemds_error.stderr.decode("utf-8"))
+        return systemds_error.returncode
+    except subprocess.TimeoutExpired as systemds_timeout:
+        # Without this handler a timeout would raise out of this function although it
+        # promises to return a code; the captured output may be None on a timeout.
+        print("SYSTEMDS TIMED OUT!")
+        print("Stdout:")
+        print((systemds_timeout.stdout or b"").decode("utf-8"))
+        print("Stderr:")
+        print((systemds_timeout.stderr or b"").decode("utf-8"))
+        return 124
+
+    stderr = res.stderr.decode("utf-8")
+    if len(stderr) != 0:
+        print("No exception but stderr was not empty:")
+        print(stderr)
+
+    return res.returncode
+
+
+def run_and_compare_output(name: str, test_case: unittest.TestCase) -> None:
+    """
+    Converts the onnx-model to dml, runs systemds with the dml-wrapper and compares the
+    resulting output with the reference output. All paths are relative to the current
+    working directory.
+
+    :param name: The name of the test-case (also used for finding onnx-model, dml-wrapper and reference output)
+    :param test_case: The testcase
+    """
+    model_path = "tests/onnx/test_models/{}.onnx".format(name)
+    dml_path = "tests/onnx/dml_output/{}.dml".format(name)
+    wrapper_path = "tests/onnx/dml_wrapper/{}_wrapper.dml".format(name)
+    reference_path = "tests/onnx/output_reference/{}_reference.out".format(name)
+    output_path = "tests/onnx/output_test/{}.out".format(name)
+
+    onnx2systemds(model_path, dml_path)
+    return_code = invoke_systemds(wrapper_path)
+    test_case.assertEqual(return_code, 0, "systemds failed")
+
+    # We read the file content such that pytest can present the actual difference between the files
+    with open(reference_path) as reference_file:
+        reference_content = reference_file.read()
+
+    with open(output_path) as output_file:
+        test_content = output_file.read()
+
+    test_case.assertEqual(
+        test_content,
+        reference_content,
+        "generated output differed from reference output"
+    )