You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ok...@apache.org on 2018/07/23 18:12:09 UTC

madlib git commit: DT/RF: Ensure cat features are recorded per group

Repository: madlib
Updated Branches:
  refs/heads/master ebd453cbb -> 0f7834e9e


DT/RF: Ensure cat features are recorded per group

JIRA: MADLIB-1254

If tree_train/forest_train is run with grouping enabled and one of the
groups has a categorical feature with just a single level, then that
categorical feature is eliminated for that group. If other groups retain
the feature, then we end up with an incorrect "bins" data structure being
built as part of DT.

This commit fixes this issue by recording the categorical features
present in each group separately.

Closes #296


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/0f7834e9
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/0f7834e9
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/0f7834e9

Branch: refs/heads/master
Commit: 0f7834e9e9c8c7a26e9041f6a74725d76323bd35
Parents: ebd453c
Author: Rahul Iyer <ri...@apache.org>
Authored: Mon Jul 23 11:10:48 2018 -0700
Committer: Orhan Kislal <ok...@pivotal.io>
Committed: Mon Jul 23 11:10:48 2018 -0700

----------------------------------------------------------------------
 .../recursive_partitioning/decision_tree.py_in  | 191 ++++++++++++++-----
 .../recursive_partitioning/random_forest.py_in  |  52 ++---
 2 files changed, 155 insertions(+), 88 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/0f7834e9/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
index 57c3025..89acd8a 100644
--- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
@@ -14,27 +14,31 @@ from operator import itemgetter
 from itertools import groupby
 from collections import Iterable
 
-from validation.cross_validation import cross_validation_grouping_w_params
+from internal.db_utils import quote_literal
+
 from utilities.control import MinWarning
 from utilities.control import OptimizerControl
 from utilities.control import HashaggControl
-from utilities.validate_args import get_cols
-from utilities.validate_args import get_cols_and_types
-from utilities.validate_args import _get_table_schema_names
-from utilities.validate_args import get_expr_type
-from utilities.validate_args import table_exists
-from utilities.validate_args import table_is_empty
-from utilities.validate_args import columns_exist_in_table
-from utilities.validate_args import is_var_valid
-from utilities.validate_args import unquote_ident
 from utilities.utilities import _assert
 from utilities.utilities import _array_to_string
-from utilities.utilities import extract_keyvalue_params
-from utilities.utilities import unique_string
 from utilities.utilities import add_postfix
+from utilities.utilities import extract_keyvalue_params
 from utilities.utilities import is_psql_numeric_type, is_psql_boolean_type
-from utilities.utilities import split_quoted_delimited_str
 from utilities.utilities import py_list_to_sql_string
+from utilities.utilities import split_quoted_delimited_str
+from utilities.utilities import unique_string
+
+from utilities.validate_args import _get_table_schema_names
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import get_cols
+from utilities.validate_args import get_cols_and_types
+from utilities.validate_args import get_expr_type
+from utilities.validate_args import is_var_valid
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import table_exists
+from utilities.validate_args import unquote_ident
+
+from validation.cross_validation import cross_validation_grouping_w_params
 # ------------------------------------------------------------
 
 
@@ -265,6 +269,7 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion,
 
     dep_n_levels = len(dep_list) if dep_list else 1
 
+    cat_features_info_table = unique_string()
     if not grouping_cols:   # non-grouping case
         # 3)  Find the splitting bins, one dict containing two arrays:
         #       categorical bins and continuous bins
@@ -276,12 +281,15 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion,
         cat_features = bins['cat_features']
         if not cat_features and not con_features:
             plpy.error("Decision tree: None of the input features are valid")
+        _create_cat_features_info_table(cat_features_info_table, bins)
 
         # 4) Run tree train till the training is finished
         #  finished: 0 = running, 1 = finished training, 2 = terminated prematurely
         tree = _tree_train_using_bins(**locals())
         tree['grp_key'] = ''
         tree['cp'] = grp_key_to_cp[tree['grp_key']]
+        tree['cat_features'] = cat_features
+        tree['con_features'] = con_features
         tree_states = [tree]
     else:
         grouping_array_str = get_grouping_array_str(training_table_name, grouping_cols)
@@ -305,16 +313,20 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion,
                 if not cat_features and not con_features:
                     plpy.error("Decision tree: None of the input features "
                                "are valid for some groups")
+                _create_cat_features_info_table(cat_features_info_table, bins)
 
                 # 3b) Load each group's tree state in memory and set to the initial tree
                 tree_states = _tree_train_grps_using_bins(**locals())
                 for tree in tree_states:
+                    grp_key = tree['grp_key']
                     if len(grp_key_to_cp.values()) == 1:
                         # for train w/out CV, the cp value remains the same for
                         # all groups. This is passed as a single-element list.
                         tree['cp'] = grp_key_to_cp.values()[0]
                     else:
-                        tree['cp'] = grp_key_to_cp[tree['grp_key']]
+                        tree['cp'] = grp_key_to_cp[grp_key]
+                    tree['cat_features'] = bins['grp_to_cat_features'][grp_key]
+                    tree['con_features'] = bins['con_features']
 
     # 5) prune the tree using provided 'cp' value and produce a list of
     #   cp values if cross-validation is required (cp_list = [] if not)
@@ -333,10 +345,10 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion,
 
         importance_vectors = _compute_var_importance(
             schema_madlib, tree,
-            len(cat_features), len(con_features))
+            len(tree['cat_features']), len(tree['con_features']))
         tree.update(**importance_vectors)
 
-    return tree_states, bins, dep_list, n_rows
+    return tree_states, bins, dep_list, n_rows, cat_features_info_table
 # -------------------------------------------------------------------------
 
 
@@ -387,7 +399,8 @@ def _build_tree(schema_madlib, is_classification, split_criterion,
 
     with MinWarning(msg_level):
         plpy.notice("Building tree for cross validation")
-        tree_states, bins, dep_list, n_rows = _get_tree_states(**locals())
+        tree_states, bins, dep_list, n_rows, cat_features_info_table = \
+            _get_tree_states(**locals())
         all_cols_types = dict([(f, get_expr_type(f, training_table_name))
                                for f in cat_features + con_features])
 
@@ -509,7 +522,8 @@ def tree_train(schema_madlib, training_table_name, output_table_name,
         grp_key_to_cp = {'': cp}
         # main training function to get trained decision trees
         plpy.notice("Getting initial tree")
-        tree_states, bins, dep_list, n_rows = _get_tree_states(**locals())
+        tree_states, bins, dep_list, n_rows, cat_features_info_table = \
+            _get_tree_states(**locals())
 
         # 5) Perform cross-validation to compute the lowest cp
         dep_n_levels = len(dep_list) if dep_list else 1
@@ -528,7 +542,8 @@ def _create_output_tables(schema_madlib, training_table_name, output_table_name,
                           id_col_name, dependent_variable, list_of_features,
                           list_of_features_to_exclude,
                           is_classification, n_all_rows, n_rows, dep_list, cp,
-                          all_cols_types, grouping_cols=None,
+                          all_cols_types, cat_features_info_table,
+                          grouping_cols=None,
                           use_existing_tables=False, running_cv=False,
                           n_folds=0, null_proxy=None, **kwargs):
     if not grouping_cols:
@@ -539,8 +554,9 @@ def _create_output_tables(schema_madlib, training_table_name, output_table_name,
     else:
         _create_grp_result_table(
             schema_madlib, tree_states, bins, bins['cat_features'],
-            bins['con_features'], output_table_name, grouping_cols,
-            training_table_name, use_existing_tables, running_cv, n_folds)
+            bins['con_features'], output_table_name, cat_features_info_table,
+            grouping_cols, training_table_name, use_existing_tables,
+            running_cv, n_folds)
 
     failed_groups = sum(row['finished'] != 1 for row in tree_states)
     _create_summary_table(
@@ -1005,7 +1021,8 @@ def _get_bins_grps(
         if len(use_cat_features) != len(cat_features):
             plpy.warning("Decision tree warning: Categorical columns with only "
                          "one value are dropped from the tree model.")
-            cat_features = [feature for feature in cat_features if feature in use_cat_features]
+            cat_features = [feature for feature in cat_features
+                            if feature in use_cat_features]
 
         # grp_col_to_levels is a list of tuples (pairs) with
         #   first value = group value,
@@ -1023,7 +1040,8 @@ def _get_bins_grps(
         grp_to_col_to_levels = [
             (grp_key, dict((row['colname'], row['levels']) for row in items))
             for grp_key, items in groupby(all_levels, key=itemgetter('grp_key'))]
-    if cat_features:
+        grp_to_cat_features = dict([(g, col_to_levels.keys())
+                                    for (g, col_to_levels) in grp_to_col_to_levels])
         # Below statements collect the grp_to_col_to_levels into multiple variables
         # From above eg.
         #   cat_items_list = [[0,1], [4,6,8], [0,1], [4,6], [0,1], [4,6,8]]
@@ -1039,7 +1057,11 @@ def _get_bins_grps(
     else:
         cat_n = []
         cat_origin = []
-        grp_key_cat=[con_splits['grp_key'] for con_splits in con_splits_all]
+        grp_key_cat = [con_splits['grp_key'] for con_splits in con_splits_all]
+        grp_to_col_to_levels = [(con_splits['grp_key'], dict())
+                                for con_splits in con_splits_all]
+        grp_to_cat_features = dict([(con_splits['grp_key'], list())
+                                   for con_splits in con_splits_all])
 
     if con_features:
         con = [con_splits['con_splits'] for con_splits in con_splits_all]
@@ -1055,10 +1077,80 @@ def _get_bins_grps(
                 cat_n=cat_n,
                 cat_features=cat_features,
                 grp_key_cat=grp_key_cat,
-                grouping_array_str=grouping_array_str)
+                grouping_array_str=grouping_array_str,
+                grp_to_col_to_levels=grp_to_col_to_levels,
+                grp_to_cat_features=grp_to_cat_features)
 # ------------------------------------------------------------
 
 
+def _create_cat_features_info_table(cat_features_info_table, bins):
+    """Create temp table `cat_features_info_table` with one row per group.
+
+    Columns: (gid, grp_key, cat_names, cat_n_levels, cat_levels_in_text).
+    If bins has 'grp_to_col_to_levels' (taken to mean grouping is enabled),
+    e.g. [('3', {'vs': [0, 1], 'cyl': [4, 6, 8]}), ('5', {'vs': [0, 1]})],
+    each group yields a row, roughly (array text is built by py_list_to_sql_string):
+        (1, '3', ARRAY['vs', 'cyl'], ARRAY[2, 3], ARRAY['0', '1', '4', '6', '8'])
+    with features ordered as in bins['cat_features']; a group without any
+    categorical feature gets NULL arrays. Otherwise a single row (gid 1,
+    grp_key '') is built from bins['cat_features'], 'cat_n' and 'cat_origin'.
+    """
+    cat_features_info_values = []
+    if 'grp_to_col_to_levels' in bins:
+        # Grouping enabled: the cat levels can differ from group to group
+        for i, (grp_key, col_to_levels) in enumerate(bins['grp_to_col_to_levels'], start=1):
+            grp_key_str = quote_literal(grp_key)
+            cat_names_levels = [(c, col_to_levels[c]) for c in bins['cat_features']
+                                if c in col_to_levels]
+            if cat_names_levels:
+                cat_names, cat_levels = zip(*cat_names_levels)
+                # categorical features in current group (expressed in an array)
+                cat_names_str = py_list_to_sql_string(
+                    map(quote_literal, cat_names), 'text', long_format=True)
+                # number of levels in each cat feature
+                cat_n_levels_str = py_list_to_sql_string(
+                    map(len, cat_levels), 'integer', long_format=True)
+                # flatten the levels across all cat features
+                cat_levels = [quote_literal(each_level)
+                              for sublist in cat_levels
+                              for each_level in sublist]
+                cat_levels_str = py_list_to_sql_string(cat_levels, 'text', long_format=True)
+            else:
+                # this group has no categorical features: store NULL arrays
+                cat_names_str = cat_n_levels_str = cat_levels_str = "NULL"
+
+            cat_features_info_values.append(
+                "({i}::INTEGER, {grp_key_str}::TEXT, {cat_names_str}::TEXT[], "
+                "{cat_n_levels_str}::INTEGER[], {cat_levels_str}::TEXT[])".
+                format(**locals()))
+    else:  # no grouping: a single row with an empty group key
+        if bins['cat_features']:
+            cat_names_str = py_list_to_sql_string(
+                map(quote_literal, bins['cat_features']), 'text', long_format=True)
+            cat_n_levels_str = py_list_to_sql_string(bins['cat_n'], 'integer', long_format=True)
+            cat_levels_str = py_list_to_sql_string(
+                map(quote_literal, bins['cat_origin']), 'text', long_format=True)
+        else:
+            cat_names_str = cat_n_levels_str = cat_levels_str = "NULL"
+        cat_features_info_values.append(
+            "(1::INTEGER, ''::TEXT, {0}::TEXT[], {1}::INTEGER[], {2}::TEXT[])".
+            format(cat_names_str, cat_n_levels_str, cat_levels_str))
+
+    sql_cat_features_info = """
+            CREATE TEMP TABLE {0} AS
+            SELECT *
+            FROM (
+                VALUES {1}
+            ) AS q(gid, grp_key, cat_names, cat_n_levels, cat_levels_in_text)
+            """.format(cat_features_info_table,
+                       ',\n'.join(cat_features_info_values))
+    plpy.notice("sql_cat_features_info:\n" + sql_cat_features_info)
+    # The statement is already fully formatted; running .format() on it again
+    # would treat any '{' or '}' inside quoted names/levels as format fields.
+    plpy.execute(sql_cat_features_info)
+# ------------------------------------------------------------------------------
+
+
 def get_feature_str(schema_madlib, boolean_cats,
                     cat_features, con_features,
                     levels_str, n_levels_str,
@@ -1194,7 +1286,7 @@ def _one_step_for_grps(
         con_features, boolean_cats, bins, n_bins, tree_states, weights,
         grouping_cols, grouping_array_str, dep_var, min_split, min_bucket,
         max_depth, filter_null, dep_n_levels, subsample, n_random_features,
-        max_n_surr=0, null_proxy=None):
+        cat_features_info_table, max_n_surr=0, null_proxy=None):
     """ One step of trees training with grouping support
     """
     # The function _map_catlevel_to_int maps a categorical variable value to its
@@ -1249,12 +1341,12 @@ def _one_step_for_grps(
                     FROM
                         {training_table_name} as src,
                         ( SELECT
-                            grp_key            AS {grp_key},
-                            finished           AS {finished},
-                            tree_state         AS {tree_state},
-                            con_splits         AS {con_splits},
-                            cat_n_levels       AS {cat_n_levels},
-                            cat_levels_in_text AS {cat_levels_in_text}
+                            grp_key                    AS {grp_key},
+                            finished                   AS {finished},
+                            tree_state                 AS {tree_state},
+                            con_splits                 AS {con_splits},
+                            cat_n_levels::INTEGER[]    AS {cat_n_levels},
+                            cat_levels_in_text::TEXT[] AS {cat_levels_in_text}
                           FROM
                             (   SELECT
                                     unnest($1) AS grp_key,
@@ -1264,11 +1356,11 @@ def _one_step_for_grps(
                             JOIN (
                                 SELECT
                                     unnest($4) AS grp_key,
-                                    unnest($9) AS con_splits
+                                    unnest($5) AS con_splits
                             ) AS con_splits
                             USING (grp_key)
                             JOIN
-                                {schema_madlib}._gen_cat_levels_set($5, $6, $7, $8) AS cat_levels
+                                {cat_features_info_table}
                             USING (grp_key)
                         ) AS needed_data
                     WHERE {grouping_array_str} = {grp_key}
@@ -1286,21 +1378,20 @@ def _one_step_for_grps(
                             JOIN
                              (  SELECT
                                     unnest($4) AS grp_key,
-                                    unnest($9) AS con_splits
+                                    unnest($5) AS con_splits
                              ) AS con_splits
                             USING (grp_key)
                     ) s2
             USING (grp_key)
     """
-    train_sql = "SELECT grp_key, (result).* from (" + sql + ") sub"
+    train_sql = "SELECT grp_key, (result).* FROM (" + sql + ") sub"
     train_sql = train_sql.format(aggregate=train_aggregate,
                                  apply_func=train_apply_func,
                                  # check_finished="AND " + finished + " = 0",
                                  **locals())
     train_sql_plan = plpy.prepare(train_sql,
-                                  ['text[]', 'integer[]', bytea8arr, 'text[]',
-                                   'text[]', 'integer[]', 'integer', 'text[]',
-                                   bytea8arr])
+                                  ['text[]', 'integer[]', bytea8arr,
+                                   'text[]', bytea8arr])
 
     unfinished_trees = [t for t in tree_states if t['finished'] == 0]
     finished_trees = [t for t in tree_states if t['finished'] != 0]
@@ -1312,10 +1403,6 @@ def _one_step_for_grps(
         [t['finished'] for t in unfinished_trees],
         [t['tree_state'] for t in unfinished_trees],
         bins['grp_key_con'],
-        bins['grp_key_cat'],
-        bins['cat_n'],
-        len(cat_features),
-        bins['cat_origin'],
         bins['con']]))
 
     if max_n_surr > 0:
@@ -1347,17 +1434,12 @@ def _one_step_for_grps(
                               **locals())
         surr_sql_plan = plpy.prepare(surr_sql,
                                      ['text[]', 'integer[]', bytea8arr,
-                                      'text[]', 'text[]', 'integer[]', 'integer',
                                       'text[]', bytea8arr])
         surr_trees = list(plpy.execute(surr_sql_plan, [
             [t['grp_key'] for t in updated_unfinished],
             [t['finished'] for t in updated_unfinished],
             [t['tree_state'] for t in updated_unfinished],
             bins['grp_key_con'],
-            bins['grp_key_cat'],
-            bins['cat_n'],
-            len(cat_features),
-            bins['cat_origin'],
             bins['con']]))
 
         surr_dict = dict()
@@ -1376,7 +1458,8 @@ def _one_step_for_grps(
 
 def _create_grp_result_table(
         schema_madlib, tree_states, bins, cat_features,
-        con_features, output_table_name, grouping_cols,
+        con_features, output_table_name, cat_features_info_table,
+        grouping_cols,
         training_table_name, use_existing_tables=False,
         running_cv=False, k=0):
     """ Create the output table for grouping case.
@@ -1435,7 +1518,7 @@ def _create_grp_result_table(
                         cat_n_levels as {cat_n_levels},
                         cat_levels_in_text as {cat_levels_in_text}
                     FROM
-                        {schema_madlib}._gen_cat_levels_set($6, $7, $8, $9)
+                        {cat_features_info_table}
                 ) s3
                 USING ({grp_key})
             """
@@ -1920,7 +2003,7 @@ def _compute_var_importance(schema_madlib, tree,
 
         Args:
             @param schema_madlib: str, MADlib schema name
-            @param tree: Tree data to prune
+            @param tree: dict. tree['tree_state'] is the trained tree (in byte form)
             @param n_cat_features: int, Number of categorical features
             @param n_con_features: int, Number of continuous features
 
@@ -2098,7 +2181,7 @@ def _xvalidate(schema_madlib, tree_states, training_table_name, output_table_nam
                 tree['pruned_depth'] = 0
         importance_vectors = _compute_var_importance(
             schema_madlib, tree,
-            len(cat_features), len(con_features))
+            len(tree['cat_features']), len(tree['con_features']))
         tree.update(**importance_vectors)
 
     plpy.execute("DROP TABLE {group_to_param_list_table}".format(**locals()))
@@ -2235,6 +2318,7 @@ def _tree_train_grps_using_bins(
         grouping_cols, grouping_array_str,
         dep_var_str, min_split, min_bucket, max_depth, filter_dep,
         dep_n_levels, is_classification, split_criterion,
+        cat_features_info_table,
         subsample=False, n_random_features=1, tree_terminated=None,
         max_n_surr=0, null_proxy=None,
         **kwargs):
@@ -2279,7 +2363,8 @@ def _tree_train_grps_using_bins(
             tree_states, weights, grouping_cols,
             grouping_array_str, dep_var_str, min_split, min_bucket,
             max_depth, filter_dep, dep_n_levels, subsample,
-            n_random_features, max_n_surr, null_proxy)
+            n_random_features, cat_features_info_table,
+            max_n_surr, null_proxy)
         level += 1
         plpy.notice("Finished training for level " + str(level))
 

http://git-wip-us.apache.org/repos/asf/madlib/blob/0f7834e9/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
index a048fa1..c06bed8 100644
--- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
@@ -14,17 +14,18 @@ from utilities.control import MinWarning
 from utilities.control import OptimizerControl
 from utilities.control import HashaggControl
 from utilities.utilities import _assert
-from utilities.utilities import unique_string
 from utilities.utilities import add_postfix
-from utilities.utilities import split_quoted_delimited_str
 from utilities.utilities import extract_keyvalue_params
 from utilities.utilities import py_list_to_sql_string
+from utilities.utilities import split_quoted_delimited_str
+from utilities.utilities import unique_string
+
+from utilities.validate_args import cols_in_tbl_valid
 from utilities.validate_args import get_cols_and_types
-from utilities.validate_args import is_var_valid
+from utilities.validate_args import get_expr_type
 from utilities.validate_args import input_tbl_valid
+from utilities.validate_args import is_var_valid
 from utilities.validate_args import output_tbl_valid
-from utilities.validate_args import cols_in_tbl_valid
-from utilities.validate_args import get_expr_type
 
 from decision_tree import _tree_train_using_bins
 from decision_tree import _tree_train_grps_using_bins
@@ -39,6 +40,7 @@ from decision_tree import _get_filter_str
 from decision_tree import _get_display_header
 from decision_tree import get_feature_str
 from decision_tree import _compute_var_importance
+from decision_tree import _create_cat_features_info_table
 # ------------------------------------------------------------
 
 
@@ -265,8 +267,7 @@ def forest_train(
         @param verbose: str, Verbosity of output messages
         @param sample_ratio: float, subsampling ratio for generating src_view
     """
-    msg_level = "'notice'" if verbose else "'warning'"
-
+    msg_level = "notice" if verbose else "warning"
     with MinWarning(msg_level):
         with OptimizerControl(False):
             # we disable optimizer (ORCA) for platforms that use it
@@ -430,31 +431,9 @@ def forest_train(
                                           is_classification, dep_n_levels, filter_null, null_proxy)
                     cat_features = bins['cat_features']
 
-                # a table for converting cat_features to integers
+                # a table for getting information of cat features for each group
                 cat_features_info_table = unique_string()
-                sql_cat_features_info = """
-                        CREATE TEMP TABLE {cat_features_info_table} AS
-                        SELECT
-                            gid,
-                            cat_n_levels,
-                            cat_levels_in_text
-                        FROM
-                        (
-                            SELECT *
-                            FROM {schema_madlib}._gen_cat_levels_set($1, $2, $3, $4)
-                        ) subq
-                        JOIN
-                            {grp_key_to_grp_cols}
-                        USING (grp_key)
-                        """.format(**locals())
-                plpy.notice("sql_cat_features_info:\n" + sql_cat_features_info)
-                plan_cat_features_info = plpy.prepare(
-                    sql_cat_features_info, ['text[]', 'integer[]', 'integer', 'text[]'])
-                plpy.execute(plan_cat_features_info, [
-                    bins['grp_key_cat'],
-                              bins['cat_n'],
-                              len(cat_features),
-                              bins['cat_origin']])
+                _create_cat_features_info_table(cat_features_info_table, bins)
 
                 con_splits_table = unique_string()
                 _create_con_splits_table(schema_madlib, con_splits_table,
@@ -587,8 +566,11 @@ def forest_train(
                             boolean_cats, num_bins, 'poisson_count', grouping_cols,
                             grouping_array_str, dep, min_split, min_bucket,
                             max_tree_depth, filter_null, dep_n_levels,
-                            is_classification, split_criterion, True,
-                            num_random_features, tree_terminated=tree_terminated,
+                            is_classification, split_criterion,
+                            cat_features_info_table,
+                            subsample=True,
+                            n_random_features=num_random_features,
+                            tree_terminated=tree_terminated,
                             max_n_surr=max_n_surr, null_proxy=null_proxy)
 
                         # If a tree for a group is terminated (not finished properly),
@@ -966,7 +948,7 @@ def _calculate_oob_prediction(
                         1 -- -1 shifted to 0 for null values
                     ),
                     {schema_madlib}.array_scalar_add(
-                        cat_n_levels,
+                        cat_n_levels::integer[],
                         1 -- -1 shifted to 0 for null values
                     )
                 ) AS cat_feature_distributions,
@@ -1024,7 +1006,7 @@ def _calculate_oob_prediction(
                     tree,
                     {cat_features_str}::integer[],
                     {con_features_str}::double precision[],
-                    cat_info.cat_n_levels,
+                    cat_info.cat_n_levels::integer[],
                     {num_permutations},
                     {dep},
                     {is_classification},