This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 3ab5124d5a [SYSTEMDS-3754] Python API missing builtin trace
3ab5124d5a is described below
commit 3ab5124d5adb42d3dbee077b9c963eda344c98ba
Author: e-strauss <[email protected]>
AuthorDate: Tue Sep 3 12:24:31 2024 +0200
[SYSTEMDS-3754] Python API missing builtin trace
Closes #2091
---
src/main/python/systemds/operator/nodes/matrix.py | 7 +++++++
src/main/python/tests/matrix/test_aggregations.py | 9 +++++++++
2 files changed, 16 insertions(+)
diff --git a/src/main/python/systemds/operator/nodes/matrix.py b/src/main/python/systemds/operator/nodes/matrix.py
index ed47192a92..2862686ca3 100644
--- a/src/main/python/systemds/operator/nodes/matrix.py
+++ b/src/main/python/systemds/operator/nodes/matrix.py
@@ -259,6 +259,13 @@ class Matrix(OperationNode):
raise ValueError(
f"Axis has to be either 0, 1 or None, for column, row or complete
{self.operation}")
+ def trace(self) -> 'Scalar':
+ """Calculate trace.
+
+ :return: `Matrix` representing operation
+ """
+ return Scalar(self.sds_context, 'trace', [self])
+
def abs(self) -> 'Matrix':
"""Calculate absolute.
diff --git a/src/main/python/tests/matrix/test_aggregations.py b/src/main/python/tests/matrix/test_aggregations.py
index 1d7172e32b..1b345d6b21 100644
--- a/src/main/python/tests/matrix/test_aggregations.py
+++ b/src/main/python/tests/matrix/test_aggregations.py
@@ -112,5 +112,14 @@ class TestMatrixAggFn(unittest.TestCase):
self.assertTrue(np.allclose(
self.sds.from_numpy(m1).max(axis=1).compute(),
m1.max(axis=1).reshape(dim, 1)))
+ def test_trace1(self):
+ self.assertTrue(np.allclose(
+ self.sds.from_numpy(m1).trace().compute(), m1.trace()))
+
+ def test_trace2(self):
+ self.assertTrue(np.allclose(
+ self.sds.from_numpy(m2).trace().compute(), m2.trace()))
+
+
if __name__ == "__main__":
unittest.main(exit=False)