You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/12/06 12:52:11 UTC

[systemds] 01/02: [SYSTEMDS-3238] Python GMM test

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit eee15cca444ab777d2955b624b8dd3073ba899c8
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Fri Dec 3 16:03:17 2021 +0100

    [SYSTEMDS-3238] Python GMM test
    
    This commit adds a small GMM test in Python for outlier/anomaly detection.
---
 .../test_source_list.py => algorithms/test_gmm.py} | 53 +++++++++++-----------
 src/main/python/tests/source/test_source_list.py   |  2 +-
 2 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/src/main/python/tests/source/test_source_list.py b/src/main/python/tests/algorithms/test_gmm.py
similarity index 51%
copy from src/main/python/tests/source/test_source_list.py
copy to src/main/python/tests/algorithms/test_gmm.py
index d4ab391..066c957 100644
--- a/src/main/python/tests/source/test_source_list.py
+++ b/src/main/python/tests/algorithms/test_gmm.py
@@ -21,15 +21,13 @@
 
 import unittest
 
-import numpy as np
 from systemds.context import SystemDSContext
-from systemds.operator.algorithm.builtin.scale import scale
+from systemds.operator.algorithm import gmm, gmmPredict
 
 
-class TestSource_01(unittest.TestCase):
+class TestGMM(unittest.TestCase):
 
     sds: SystemDSContext = None
-    source_path: str = "./tests/source/source_with_list_input.dml"
 
     @classmethod
     def setUpClass(cls):
@@ -39,25 +37,28 @@ class TestSource_01(unittest.TestCase):
     def tearDownClass(cls):
         cls.sds.close()
 
-    def test_single_return(self):
-        arr = self.sds.array(self.sds.full((10, 10), 4))
-        c = self.sds.source(self.source_path, "test").func(arr)
-        res = c.sum().compute()
-        self.assertTrue(res == 10*10*4)
-
-    def test_input_multireturn(self):
-        m = self.sds.full((10, 10), 2)
-        [a, b, c] = scale(m, True, True)
-        arr = self.sds.array(a, b, c)
-        c = self.sds.source(self.source_path, "test").func(arr)
-        res = c.sum().compute(verbose=True)
-        self.assertTrue(res == 0)
-
-    # [SYSTEMDS-3224] https://issues.apache.org/jira/browse/SYSTEMDS-3224
-    # def test_multi_return(self):
-    #     arr = self.sds.array(
-    #         self.sds.full((10, 10), 4),
-    #         self.sds.full((3, 3), 5))
-    #     [b, c] = self.sds.source(self.source_path, "test", True).func2(arr)
-    #     res = c.sum().compute()
-    #     self.assertTrue(res == 10*10*4)
+    def test_lm_simple(self):
+        a = self.sds.rand(500, 10, -100, 100, pdf="normal", seed=10)
+        features = a  # training data all not outliers
+
+        notOutliers = self.sds.rand(10, 10, -1, 1,  seed=10)  # inside a
+        outliers = self.sds.rand(10, 10, 1150, 1200, seed=10)  # outliers
+
+        test = outliers.rbind(notOutliers)  # testing data half outliers
+
+        n_gaussian = 4
+
+        [_, _, _, _, mu, precision_cholesky, wight] = gmm(
+            features, False, n_components=n_gaussian, seed=10)
+
+        [_, pp] = gmmPredict(
+            test, wight, mu, precision_cholesky, model=self.sds.scalar("VVV"))
+
+        outliers = pp.max(axis=1) < 0.99
+        ret = outliers.compute()
+
+        self.assertTrue(ret.sum() == 10)
+
+
+if __name__ == "__main__":
+    unittest.main(exit=False)
diff --git a/src/main/python/tests/source/test_source_list.py b/src/main/python/tests/source/test_source_list.py
index d4ab391..02e63e9 100644
--- a/src/main/python/tests/source/test_source_list.py
+++ b/src/main/python/tests/source/test_source_list.py
@@ -50,7 +50,7 @@ class TestSource_01(unittest.TestCase):
         [a, b, c] = scale(m, True, True)
         arr = self.sds.array(a, b, c)
         c = self.sds.source(self.source_path, "test").func(arr)
-        res = c.sum().compute(verbose=True)
+        res = c.sum().compute()
         self.assertTrue(res == 0)
 
     # [SYSTEMDS-3224] https://issues.apache.org/jira/browse/SYSTEMDS-3224