This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 71b9c074ac [SYSTEMDS-3739] Python API Builtin diag
71b9c074ac is described below
commit 71b9c074ac04989b7c938b62ca97ce64788aef15
Author: e-strauss <[email protected]>
AuthorDate: Mon Sep 2 11:49:39 2024 +0200
[SYSTEMDS-3739] Python API Builtin diag
This patch adds the built-in operator diag
to the Python API.
Closes #2085
---
src/main/python/systemds/operator/nodes/matrix.py | 8 ++-
src/main/python/tests/matrix/test_diag.py | 71 +++++++++++++++++++++++
2 files changed, 78 insertions(+), 1 deletion(-)
diff --git a/src/main/python/systemds/operator/nodes/matrix.py
b/src/main/python/systemds/operator/nodes/matrix.py
index 23c40422eb..ed47192a92 100644
--- a/src/main/python/systemds/operator/nodes/matrix.py
+++ b/src/main/python/systemds/operator/nodes/matrix.py
@@ -387,7 +387,13 @@ class Matrix(OperationNode):
:return: the OperationNode representing this operation
"""
return Matrix(self.sds_context, 'cholesky', [self])
-
+
+    def diag(self) -> 'Matrix':
+        """Apply the SystemDS builtin ``diag`` to this matrix.
+
+        For an (n x 1) matrix this creates an (n x n) diagonal matrix; for a
+        square matrix it takes the diagonal, returned as an (n x 1) matrix.
+
+        :return: the OperationNode representing this operation
+        """
+        return Matrix(self.sds_context, 'diag', [self])
+
def svd(self) -> 'Matrix':
"""
diff --git a/src/main/python/tests/matrix/test_diag.py
b/src/main/python/tests/matrix/test_diag.py
new file mode 100644
index 0000000000..7e3f103aeb
--- /dev/null
+++ b/src/main/python/tests/matrix/test_diag.py
@@ -0,0 +1,71 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import unittest
+import numpy as np
+from systemds.context import SystemDSContext
+
+
+class TestDIAG(unittest.TestCase):
+    """Compare ``Matrix.diag`` in SystemDS against ``numpy.diag``.
+
+    For a vector input both sides produce the full diagonal matrix. For a
+    square input SystemDS yields the diagonal as an (n x 1) column, while
+    numpy returns a 1-D array, so the numpy result is reshaped to a column.
+    """
+
+    sds: SystemDSContext = None
+
+    @classmethod
+    def setUpClass(cls):
+        # Share one context across all tests instead of creating (and
+        # closing) a new one for every test method.
+        cls.sds = SystemDSContext()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.sds.close()
+
+    def _assert_diag(self, input_matrix, expected):
+        """Run ``diag`` on ``input_matrix`` in SystemDS and compare the
+        computed result with ``expected``.
+        """
+        sds_result = self.sds.from_numpy(input_matrix).diag().compute()
+        # Same tolerances as the former np.allclose(..., 1e-9) check, but
+        # assert_allclose also requires matching shapes (no broadcasting)
+        # and reports the mismatching entries on failure.
+        np.testing.assert_allclose(sds_result, expected, rtol=1e-9, atol=1e-8)
+
+    def test_diag_basic1(self):
+        input_matrix = np.array([1, 2, 3, 4])
+        self._assert_diag(input_matrix, np.diag(input_matrix))
+
+    def test_diag_basic2(self):
+        input_matrix = np.array([[1, 2, 3, 4],
+                                 [5, 6, 7, 8],
+                                 [9, 10, 11, 12],
+                                 [13, 14, 15, 16]])
+        expected = np.reshape(np.diag(input_matrix), (-1, 1))
+        self._assert_diag(input_matrix, expected)
+
+    def test_diag_random1(self):
+        # Seeded so that a failing input can be reproduced.
+        input_matrix = np.random.default_rng(7).random(10)
+        self._assert_diag(input_matrix, np.diag(input_matrix))
+
+    def test_diag_random2(self):
+        # Seeded so that a failing input can be reproduced.
+        input_matrix = np.random.default_rng(7).random((10, 10))
+        expected = np.reshape(np.diag(input_matrix), (-1, 1))
+        self._assert_diag(input_matrix, expected)
+
+
+if __name__ == '__main__':
+    unittest.main()