You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by kh...@apache.org on 2021/01/29 00:04:18 UTC

[madlib] 01/02: utilities: Add new function for getting data distribution per segment

This is an automated email from the ASF dual-hosted git repository.

khannaekta pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit f833ac08c59e576d4cdd80b3ad4c4115bfa6f93d
Author: Ekta Khanna <ek...@pivotal.io>
AuthorDate: Tue Jan 19 17:12:15 2021 -0800

    utilities: Add new function for getting data distribution per segment
    
    JIRA: MADLIB-1463
    
    This commit adds a new function that returns, for each segment in the
    cluster, the number of segments on that segment's host that hold the
    input table's data (zero if the segment itself holds none of it).
    
    This function will be useful for the deep learning module, where the
    image preprocessor lets the user distribute the data to only part of
    the cluster.
    
    Co-authored-by: Nikhil Kak <nk...@vmware.com>
---
 .../deep_learning/madlib_keras_helper.py_in        | 36 ++++++++++++++
 .../test/unit_tests/test_madlib_keras.py_in        | 58 ++++++++++++++++++++++
 2 files changed, 94 insertions(+)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index 2dd17aa..735f1b2 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -410,3 +410,39 @@ def generate_row_string(configs_dict):
     if result_row_string[0] == ',':
         return result_row_string[1:]
     return result_row_string
+
+def get_data_distribution_per_segment(table_name):
+    """
+    Returns a list with count of segments on each host that the input
+    table's data is distributed on.
+    :param table_name: input table name
+    :return: list with one entry per segment, len == get_seg_number().
+    Each index of the list represents a segment (content id) of the cluster. If
+    the data is not distributed on that segment, the value is zero. Otherwise it
+    is the count of segments on that segment's host that have the data. Only
+    primary segments (role = 'p') are considered; mirrors share content ids.
+    For example, if there are 2 hosts and 3 segs per host
+    host1 - seg0, seg1, seg2
+    host2 - seg3, seg4, seg5
+    If the data is distributed on seg0, seg1 and seg3 then the return value is
+    [2,2,0,1,0,0]. On Postgres the return value is always [1].
+    """
+    if is_platform_pg():
+        return [1]
+    # Filter on role = 'p': gp_segment_configuration also lists mirrors.
+    res = plpy.execute("""
+                WITH cte AS (SELECT DISTINCT(gp_segment_id)
+                             FROM {table_name})
+                SELECT content, count as cnt
+                    FROM gp_segment_configuration
+                    JOIN (SELECT hostname, count(*)
+                          FROM gp_segment_configuration
+                          WHERE content in (SELECT * FROM cte) AND role = 'p'
+                          GROUP BY hostname) a
+                    USING (hostname)
+                    WHERE content in (SELECT * FROM cte) AND role = 'p'
+                ORDER BY 1""".format(table_name=table_name))
+    data_distribution_per_segment = [0] * get_seg_number()
+    for r in res:
+        data_distribution_per_segment[r['content']] = int(r['cnt'])
+    return data_distribution_per_segment
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 64395eb..31a61a8 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -1629,6 +1629,64 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
             "This is not valid PostgresSQL: SELECT {}[1]".format(metrics)
         )
 
+    def test_get_data_distribution_per_segment_all_segments(self):
+        # 3 hosts x 3 segments per host; data on every segment of every host.
+        self.subject.is_platform_pg = Mock(return_value=False)
+        self.subject.get_seg_number = Mock(return_value=9)
+        query_rows = [
+            {'content': 0, 'cnt': 3},
+            {'content': 1, 'cnt': 3},
+            {'content': 2, 'cnt': 3},
+            {'content': 3, 'cnt': 3},
+            {'content': 4, 'cnt': 3},
+            {'content': 5, 'cnt': 3},
+            {'content': 6, 'cnt': 3},
+            {'content': 7, 'cnt': 3},
+            {'content': 8, 'cnt': 3},
+        ]
+        self.plpy_mock_execute.side_effect = [query_rows]
+        result = self.subject.get_data_distribution_per_segment('source_table')
+        self.assertEqual([3, 3, 3, 3, 3, 3, 3, 3, 3], result)
+
+    def test_get_data_distribution_per_segment_on_some_hosts(self):
+        # 3 hosts x 3 segments per host; data on every segment of hosts 1 and 2.
+        self.subject.is_platform_pg = Mock(return_value=False)
+        self.subject.get_seg_number = Mock(return_value=9)
+        query_rows = [
+            {'content': 0, 'cnt': 3},
+            {'content': 1, 'cnt': 3},
+            {'content': 2, 'cnt': 3},
+            {'content': 3, 'cnt': 3},
+            {'content': 4, 'cnt': 3},
+            {'content': 5, 'cnt': 3},
+        ]
+        self.plpy_mock_execute.side_effect = [query_rows]
+        result = self.subject.get_data_distribution_per_segment('source_table')
+        self.assertEqual([3, 3, 3, 3, 3, 3, 0, 0, 0], result)
+
+    def test_get_data_distribution_per_segment_some_segments(self):
+        # 3 hosts x 3 segs per host; data on 3 segs of host 1, 2 of host 2, 1 of host 3.
+        self.subject.is_platform_pg = Mock(return_value=False)
+        self.subject.get_seg_number = Mock(return_value=9)
+        query_rows = [
+            {'content': 0, 'cnt': 3},
+            {'content': 1, 'cnt': 3},
+            {'content': 2, 'cnt': 3},
+            {'content': 3, 'cnt': 2},
+            {'content': 4, 'cnt': 2},
+            {'content': 8, 'cnt': 1},
+        ]
+        self.plpy_mock_execute.side_effect = [query_rows]
+        result = self.subject.get_data_distribution_per_segment('source_table')
+        self.assertEqual([3, 3, 3, 2, 2, 0, 0, 0, 1], result)
+
+    def test_get_data_distribution_per_segment_on_postgres(self):
+        # On Postgres there is no segment configuration to query; the
+        # function is expected to return [1].
+        self.subject.is_platform_pg = Mock(return_value=True)
+        res = self.subject.get_data_distribution_per_segment('source_table')
+        self.assertEqual([1], res)
+
 class MadlibKerasEvaluationMergeFinalTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')