This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d0de9067bc [Relax] Allow ingesting vector_norm from torch.export
(#17722)
d0de9067bc is described below
commit d0de9067bcb5360f540fb4528280129a4a43ec6a
Author: Hugo Latendresse <[email protected]>
AuthorDate: Mon Mar 10 11:15:36 2025 -0400
[Relax] Allow ingesting vector_norm from torch.export (#17722)
- Implement torch's vector_norm as a function of existing relax ops
- add a unit test
---
.../frontend/torch/base_fx_graph_translator.py | 36 +++++++++
.../frontend/torch/exported_program_translator.py | 2 +
tests/python/relax/test_from_exported_to_cuda.py | 92 ++++++++++++++++++++++
3 files changed, 130 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a9f54d91e3..8b771b5d2f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -306,6 +306,42 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return convert
+ ########## Linear Algebra ##########
+
+ def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
+
+ args = self.retrieve_args(node)
+
+ data = args[0]
+ # Default ord=2 if not supplied
+ ord_val = args[1] if len(args) > 1 else 2.0
+ dim = args[2] if len(args) > 2 else None
+ keepdim = args[3] if len(args) > 3 else False
+
+ # If ord_val is a Python float/int, wrap it in a Relax const
+ # so that it matches data's dtype.
+ dtype = data.struct_info.dtype
+ ord_expr = (
+ ord_val if isinstance(ord_val, relax.Expr) else
relax.const(float(ord_val), dtype)
+ )
+ # Reciprocal
+ reci_expr = (
+ relax.op.divide(relax.const(1.0, dtype), ord_expr)
+ if isinstance(ord_val, relax.Expr)
+ else relax.const(1.0 / float(ord_val), dtype)
+ )
+
+ # abs(data)
+ abs_data = self.block_builder.emit(relax.op.abs(data))
+ # abs_data^ord
+ abs_data_pow = self.block_builder.emit(relax.op.power(abs_data,
ord_expr))
+ # sum over dim
+ reduced = self.block_builder.emit(relax.op.sum(abs_data_pow, dim,
keepdims=keepdim))
+ # (sum(...))^(1/ord)
+ norm_val = self.block_builder.emit(relax.op.power(reduced, reci_expr))
+
+ return norm_val
+
########## Neural Network ##########
def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 335de7a240..e8e8706714 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -231,6 +231,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"__or__.Scalar": self._binary_op(relax.op.bitwise_or,
operator.or_),
"__xor__.Tensor": self._binary_op(relax.op.bitwise_xor,
operator.xor),
"__xor__.Scalar": self._binary_op(relax.op.bitwise_xor,
operator.xor),
+ # linear algebra
+ "linalg_vector_norm.default": self._linalg_vector_norm,
# neural network
"_native_batch_norm_legit_no_training.default":
self._batch_norm_legit_no_training,
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py
b/tests/python/relax/test_from_exported_to_cuda.py
new file mode 100644
index 0000000000..d39bb8e9fe
--- /dev/null
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import torch
+from torch.export import export
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relax.frontend.torch import from_exported_program
+
+
def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev):
    """Check that *torch_module* round-trips through torch.export into TVM.

    Exports the module with ``torch.export``, imports the resulting program
    into Relax, builds it for CUDA, and asserts the TVM output matches the
    PyTorch output (rtol/atol 1e-5) when run on the given device.
    """
    # Keep an untouched copy for TVM in case export/tracing mutates the input.
    tvm_input = raw_data.copy()
    torch_input = torch.from_numpy(raw_data)

    with torch.no_grad():
        exported_program = export(torch_module, (torch_input,))
        mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True)

    tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)

    relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda()))
    # TODO(review): evaluate relax.backend.cuda.pipeline.get_default_pipeline(target)
    # as an alternative compilation pipeline.
    ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline)
    vm = relax.VirtualMachine(ex, dev)

    gpu_data = tvm.nd.array(tvm_input, dev)
    gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
    gpu_out = vm["main"](gpu_data, *gpu_params)

    desired = torch_module(torch_input).detach().numpy()
    actual = gpu_out[0].numpy()
    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
+
+
@tvm.testing.parametrize_targets("cuda")
def test_linalg_vector_norm(target, dev):
    """End-to-end check of torch.linalg.vector_norm ingestion on CUDA.

    Covers four distinct (ord, dim) combinations; the original version
    duplicated two of them (VectorNorm2/3 repeated VectorNorm0/1 exactly),
    so half the cases added no coverage.
    """

    class VectorNorm0(torch.nn.Module):
        def forward(self, x):
            return torch.linalg.vector_norm(x, ord=1, dim=-1)

    class VectorNorm1(torch.nn.Module):
        def forward(self, x):
            return torch.linalg.vector_norm(x, ord=2, dim=2)

    class VectorNorm2(torch.nn.Module):
        def forward(self, x):
            # Odd integer order exercises a non-default exponent.
            return torch.linalg.vector_norm(x, ord=3, dim=-1)

    class VectorNorm3(torch.nn.Module):
        def forward(self, x):
            # Fractional order exercises non-integer power handling.
            return torch.linalg.vector_norm(x, ord=4.5, dim=2)

    raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32)

    for module_cls in (VectorNorm0, VectorNorm1, VectorNorm2, VectorNorm3):
        torch_module = module_cls().eval()
        assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


if __name__ == "__main__":
    tvm.testing.main()