You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2019/09/25 00:21:07 UTC

[GitHub] [madlib] khannaekta commented on a change in pull request #443: DL: Add training for multiple models

khannaekta commented on a change in pull request #443: DL: Add training for multiple models
URL: https://github.com/apache/madlib/pull/443#discussion_r327871762
 
 

 ##########
 File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
 ##########
 @@ -0,0 +1,489 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import plpy
+import time
+import sys
+
+from keras.models import *
+from madlib_keras import compute_loss_and_metrics
+from madlib_keras import get_initial_weights
+from madlib_keras import get_model_arch_weights
+from madlib_keras import get_segments_and_gpus
+from madlib_keras import get_source_summary_table_dict
+from madlib_keras_helper import *
+from madlib_keras_model_selection import ModelSelectionSchema
+from madlib_keras_validator import *
+from madlib_keras_wrapper import *
+
+from utilities.control import MinWarning
+from utilities.control import OptimizerControl
+from utilities.utilities import unique_string
+from utilities.utilities import add_postfix
+from utilities.utilities import rotate
+from utilities.utilities import madlib_version
+from utilities.utilities import is_platform_pg
+import json
+from collections import defaultdict
+import random
+import datetime
+mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
+mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
+dist_key_col = DISTRIBUTION_KEY_COLNAME
+
+
+@MinWarning("warning")
+class FitMultipleModel():
+    def __init__(self, schema_madlib, source_table, model_output_table,
+                 model_selection_table, num_iterations,
+                 gpus_per_host=0, validation_table=None, **kwargs):
+        # set the random seed for visit order/scheduling
+        random.seed(1)
+        if is_platform_pg():
+            plpy.error(
+                "DL: Multiple model training is not supported on PostgreSQL.")
+        self.source_table = source_table
+        self.validation_table = validation_table
+        self.model_selection_table = model_selection_table
+        if self.model_selection_table:
+            self.model_selection_summary_table = add_postfix(self.model_selection_table, '_summary')
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(model_output_table, '_info')
+            self.model_summary_table = add_postfix(
+                model_output_table, '_summary')
+        self.num_iterations = num_iterations
+        self.module_name = 'madlib_keras_fit_multiple_model'
+        self.schema_madlib = schema_madlib
+        self.version = madlib_version(self.schema_madlib)
+        self.mst_key_col = ModelSelectionSchema.MST_KEY
+        self.model_id_col = ModelSelectionSchema.MODEL_ID
+        self.compile_params_col = ModelSelectionSchema.COMPILE_PARAMS
+        self.fit_params_col = ModelSelectionSchema.FIT_PARAMS
+        self.model_arch_table_col = ModelSelectionSchema.MODEL_ARCH_TABLE
+        self.model_weights_col=ModelArchSchema.MODEL_WEIGHTS
+        self.model_arch_col=ModelArchSchema.MODEL_ARCH
+        self.train_mst_metric_eval_time = defaultdict(list)
+        self.train_mst_loss = defaultdict(list)
+        self.train_mst_metric = defaultdict(list)
+        self.info_str = ""
+        self.seg_ids_train, self.images_per_seg_train = \
+            get_image_count_per_seg_for_minibatched_data_from_db(
+                self.source_table)
+        input_tbl_valid(self.model_selection_table, self.module_name)
+        input_tbl_valid(self.model_selection_summary_table, self.module_name,
+                        error_suffix_str="Please ensure that the model selection table ({0}) "
+                                         "has been created by "
+                                         "load_model_selection_table().".format(self.model_selection_table))
+        self.msts, self.model_arch_table = self.query_model_configs()
+        self.dep_shape_col = add_postfix(mb_dep_var_col, "_shape")
+        self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
+        self.fit_validator_train = FitMultipleInputValidator(
+            self.source_table, self.validation_table, self.model_output_table,
+            self.model_arch_table, mb_dep_var_col, mb_indep_var_col,
+            self.num_iterations, 1, False)
+
+        output_tbl_valid(self.model_info_table, self.module_name)
+        if self.validation_table:
+            self.valid_mst_metric_eval_time = defaultdict(list)
+            self.valid_mst_loss = defaultdict(list)
+            self.valid_mst_metric = defaultdict(list)
+            self.seg_ids_valid, self.images_per_seg_valid = \
+                get_image_count_per_seg_for_minibatched_data_from_db(
+                    self.validation_table)
+        self.mst_weights_tbl = unique_string(desp='mst_weights')
+        self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule')
+
+        self.dist_keys = self.query_dist_keys()
+        if len(self.msts) < len(self.dist_keys):
+            self.msts_for_schedule = self.msts + [None] * \
+                (len(self.dist_keys) - len(self.msts))
+        else:
+            self.msts_for_schedule = self.msts
+        random.shuffle(self.msts_for_schedule)
+        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
+        self.segments_per_host, self.gpus_per_host = get_segments_and_gpus(
+            gpus_per_host)
+        self.create_model_output_table()
+        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
+        self.fit_multiple_model()
+
+    def fit_multiple_model(self):
+        # WARNING: set orca off to prevent unwanted redistribution
+        with OptimizerControl(False):
+            original_cuda_env = None
+            if CUDA_VISIBLE_DEVICES_KEY in os.environ:
+                original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
+            self.start_training_time = datetime.datetime.now()
+            self.train_multiple_model()
+            self.end_training_time = datetime.datetime.now()
+            self.insert_info_table()
+            self.create_model_summary_table()
+        reset_cuda_env(original_cuda_env)
+
+    def train_multiple_model(self):
+        total_msts = len(self.msts_for_schedule)
+        for e in range(self.num_iterations):
+            for i in range(total_msts):
+                mst_row = [self.grand_schedule[dist_key][i]
+                           for dist_key in self.dist_keys]
+                self.create_mst_schedule_table(mst_row)
+                if i == 0:
+                    start_iteration = time.time()
+                self.run_training()
+                if i == (total_msts - 1):
+                    end_iteration = time.time()
+                    self.info_str = "\tTime for training in iteration {0}: {1} sec\n".format(e,
+                                                                                      end_iteration - start_iteration)
+            self.info_str += "\tTraining set after iteration {0}:".format(e)
+            self.evaluate_model(e, self.source_table, True)
+            if self.validation_table:
+                self.evaluate_model(e, self.validation_table, False)
+            plpy.info("\n"+self.info_str)
+
+    def evaluate_model(self, epoch, table, is_train):
+        if is_train:
+            mst_metric_eval_time = self.train_mst_metric_eval_time
+            mst_loss = self.train_mst_loss
+            mst_metric = self.train_mst_metric
+            seg_ids = self.seg_ids_train
+            images_per_seg = self.images_per_seg_train
+        else:
+            mst_metric_eval_time = self.valid_mst_metric_eval_time
+            mst_loss = self.valid_mst_loss
+            mst_metric = self.valid_mst_metric
+            seg_ids = self.seg_ids_valid
+            images_per_seg = self.images_per_seg_valid
+        for mst in self.msts:
+            state = self.query_weights(mst[self.mst_key_col])
+            model_arch, _ = get_model_arch_weights(self.model_arch_table, mst[self.model_id_col])
+            serialized_weights = \
+                madlib_keras_serializer.get_serialized_1d_weights_from_state(
+                    state)
+            loss_metric = compute_loss_and_metrics(
+                self.schema_madlib, table, "$madlib${0}$madlib$".format(
+                    mst[self.compile_params_col]),
+                model_arch,
+                serialized_weights,
+                self.gpus_per_host,
+                self.segments_per_host,
+                seg_ids,
+                images_per_seg, [], [], epoch, True)
+            metric_eval_time, metric, loss = loss_metric
+            mst_metric_eval_time[mst[self.mst_key_col]] \
+                .append(metric_eval_time)
+            mst_loss[mst[self.mst_key_col]].append(loss)
+            mst_metric[mst[self.mst_key_col]].append(metric)
+            self.info_str += "\n\tmst_key={0}: metric={1}, loss={2}".format(mst[self.mst_key_col], metric, loss)
+
+    def generate_schedule(self, msts):
+        grand_schedule = {}
+        for index, dist_key in enumerate(self.dist_keys):
+            grand_schedule[dist_key] = rotate(msts, index)
+        return grand_schedule
+
+    def query_model_configs(self):
+        msts_query = """
+                     SELECT * FROM {self.model_selection_table}
+                     ORDER BY {self.mst_key_col}
+                     """.format(**locals())
+        model_arch_table_query = """
+                                 SELECT {self.model_arch_table_col}
+                                 FROM {self.model_selection_summary_table}
+                                 """.format(**locals())
+        msts = list(plpy.execute(msts_query))
+        model_arch_table = plpy.execute(model_arch_table_query)[0][self.model_arch_table_col]
+        return msts, model_arch_table
+
+    def query_dist_keys(self):
+        dist_key_query = """
+                         SELECT DISTINCT({dist_key_col}) FROM {source_table}
+                         ORDER BY {dist_key_col}
+                         """.format(dist_key_col=dist_key_col,
+                                    source_table=self.source_table)
+        res = list(plpy.execute(dist_key_query))
+        res = [x[dist_key_col] for x in res]
+        return res
+
+    def create_mst_schedule_table(self, mst_row):
+        mst_temp_query = """
+                         CREATE TEMP TABLE {self.mst_current_schedule_tbl}
+                                ({self.model_id_col} INTEGER,
+                                 {self.compile_params_col} VARCHAR,
+                                 {self.fit_params_col} VARCHAR,
+                                 {dist_key_col} INTEGER,
+                                 {self.mst_key_col} INTEGER)
+                         """.format(dist_key_col=dist_key_col, **locals())
+        plpy.execute(mst_temp_query)
+        for mst, dist_key in zip(mst_row, self.dist_keys):
+            if mst:
+                model_id = mst[self.model_id_col]
+                compile_params = mst[self.compile_params_col]
+                fit_params = mst[self.fit_params_col]
+                mst_key = mst[self.mst_key_col]
+            else:
+                model_id = "NULL"
+                compile_params = "NULL"
+                fit_params = "NULL"
+                mst_key = "NULL"
+            mst_insert_query = """
+                               INSERT INTO {self.mst_current_schedule_tbl}
+                                   VALUES ({model_id},
+                                           $madlib${compile_params}$madlib$,
+                                           $madlib${fit_params}$madlib$,
+                                           {dist_key},
+                                           {mst_key})
+                                """.format(**locals())
+            plpy.execute(mst_insert_query)
+
+    def create_model_output_table(self):
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON)
+                                    """.format(self=self)
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({self.mst_key_col} INTEGER PRIMARY KEY,
+                                   {self.model_id_col} INTEGER,
+                                   {self.compile_params_col} TEXT,
+                                   {self.fit_params_col} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[])
+                               """.format(self=self)
+
+        plpy.execute(output_table_create_query)
+        plpy.execute(info_table_create_query)
+        for mst in self.msts:
+            model_arch, model_weights = get_model_arch_weights(self.model_arch_table,
+                                                               mst[self.model_id_col])
+            serialized_weights = get_initial_weights(self.model_output_table,
+                                                     model_arch,
+                                                     model_weights,
+                                                     False,
+                                                     self.gpus_per_host
+                                                     )
+            model = model_from_json(model_arch)
+
+            serialized_state = \
+                madlib_keras_serializer.serialize_state_with_nd_weights(
+                    0, model.get_weights())
+            # serialized_weights = madlib_keras_serializer.serialize_nd_weights(
+            #     model.get_weights())
+            model_size = sys.getsizeof(serialized_weights) / 1024.0
+            metrics_list = get_metrics_from_compile_param(
+                mst[self.compile_params_col])
+            is_metrics_specified = True if metrics_list else False
+            metrics_type = 'ARRAY{0}'.format(
+                metrics_list) if is_metrics_specified else 'NULL'
+
+            output_table_insert_query = """
+                                INSERT INTO {self.model_output_table}(
+                                    {self.mst_key_col}, {self.model_weights_col},
+                                    {self.model_arch_col})
+                                VALUES ({mst_key}, $1, $2)
+                                   """.format(self=self,
+                                              mst_key=mst[self.mst_key_col])
+            output_table_insert_query_prepared = plpy.prepare(
+                output_table_insert_query, ["bytea", "json"])
+            plpy.execute(output_table_insert_query_prepared, [
+                         serialized_state, json.dumps(model_arch)])
+            info_table_insert_query = """
+                    INSERT INTO {self.model_info_table}({self.mst_key_col},
+                                {self.model_id_col}, {self.compile_params_col},
+                                {self.fit_params_col}, model_type, model_size,
+                                metrics_type)
+                        VALUES ({mst_key_val}, {model_id},
+                                $madlib${compile_params}$madlib$,
+                                $madlib${fit_params}$madlib$, '{model_type}',
+                                {model_size}, {metrics_type})
+                """.format(self=self,
+                           mst_key_val=mst[self.mst_key_col],
+                           model_id=mst[self.model_id_col],
+                           compile_params=mst[self.compile_params_col],
+                           fit_params=mst[self.fit_params_col],
+                           model_type='madlib_keras',
+                           model_size=model_size,
+                           metrics_type=metrics_type)
+            plpy.execute(info_table_insert_query)
+
+    def create_model_summary_table(self):
+        src_summary_dict = get_source_summary_table_dict(self.fit_validator_train)
+        class_values = src_summary_dict['class_values']
+        dep_vartype = src_summary_dict['dep_vartype']
+        dependent_varname = \
+            src_summary_dict['dependent_varname_in_source_table']
+        independent_varname = \
+            src_summary_dict['independent_varname_in_source_table']
+        norm_const = src_summary_dict['norm_const']
+        num_classes = len(class_values)
+        class_values_colname = CLASS_VALUES_COLNAME
+        dependent_vartype_colname = DEPENDENT_VARTYPE_COLNAME
+        normalizing_const_colname = NORMALIZING_CONST_COLNAME
+        float32_sql_type = FLOAT32_SQL_TYPE
+        update_query = """
+                CREATE TABLE {self.model_summary_table} AS
 
 Review comment:
   Yes, in dev-check (madlib_keras_model_selection.sql_in:L194-207).

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services