You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by nj...@apache.org on 2019/07/24 22:47:13 UTC

[madlib] 01/02: Association Rules: Improve performance

This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 0f6073cc0a03a0c9c216788de46e781df3e68b99
Author: Nandish Jayaram <nj...@apache.org>
AuthorDate: Thu Jul 11 14:58:09 2019 -0700

    Association Rules: Improve performance
    
    JIRA: MADLIB-1327
    Association rules was slow due to a blow up in number of candidate
    itemsets before checking for their support and graduating them to
    frequent itemsets. This PR ensures the candidate itemsets are
    ordered, and only a few consecutive itemsets must be considered to merge
    and create a new candidate itemset for checking support. This also
    reduces the result of a join query significantly (previously, the result
    of the join were all potential candidates, which were redundant and
    which may have resulted in itemsets that were larger than what was
    considered in the current iteration).
    This commit also adds relevant tests in dev-check.
    
    Closes #423
    Co-authored-by: Orhan Kislal <ok...@apache.org>
---
 src/modules/assoc_rules/assoc_rules.cpp            |  40 +++-
 .../postgres/modules/assoc_rules/assoc_rules.py_in | 201 +++++++++++++++------
 .../modules/assoc_rules/assoc_rules.sql_in         | 147 ++++++++++++++-
 .../modules/assoc_rules/test/assoc_rules.sql_in    |  90 +++++++--
 4 files changed, 394 insertions(+), 84 deletions(-)

diff --git a/src/modules/assoc_rules/assoc_rules.cpp b/src/modules/assoc_rules/assoc_rules.cpp
index 85d0046..550d7ce 100644
--- a/src/modules/assoc_rules/assoc_rules.cpp
+++ b/src/modules/assoc_rules/assoc_rules.cpp
@@ -23,6 +23,8 @@ typedef struct perm_fctx
     char*    positions;
     int32    pos_len;
     int32    num_elems;
+    int32    max_LHS_size;
+    int32    max_RHS_size;
     int32    num_calls;
 
     /* type information for the result type*/
@@ -38,6 +40,8 @@ typedef struct perm_fctx
  * @param args      Two-element array.
  *                  args[0] is the text form of a closed frequent pattern.
  *                  args[1] is the number of items in the pattern.
+ *                  args[2] is the max number of elements in the lhs of the rule
+ *                  args[3] is the max number of elements in the rhs of the rule
  * @param max_call  The number of  will be generated.
  *
  * @return  The struct including the variables which will be used
@@ -54,13 +58,15 @@ gen_rules_from_cfp::SRF_init(AnyType &args) {
     myfctx->positions = positions;
     myfctx->pos_len   = static_cast<int32_t>(strlen(positions));
     myfctx->num_elems = args[1].getAs<int32>();
+    myfctx->max_LHS_size = args[2].getAs<int32>();
+    myfctx->max_RHS_size = args[3].getAs<int32>();
     myfctx->num_calls = (1 << myfctx->num_elems) - 2;
     myfctx->flags     = new bool[myfctx->num_elems];
+
     memset(myfctx->flags, 0, sizeof(bool) * myfctx->num_elems);
     // return type id is TEXTOID, get the related information
     madlib_get_typlenbyvalalign
         (TEXTOID, &myfctx->typlen, &myfctx->typbyval, &myfctx->typalign);
-
     return myfctx;
 }
 
@@ -99,7 +105,9 @@ gen_rules_from_cfp::SRF_next(void *user_fctx, bool *is_last_call) {
         return Null();
     }
 
+    *is_last_call = false;
     // find the next permutation of the closed frequent pattern
+
     for (i = 0; i < myfctx->num_elems; ++i) {
         if (!myfctx->flags[i]) {
             myfctx->flags[i] = true;
@@ -109,6 +117,33 @@ gen_rules_from_cfp::SRF_next(void *user_fctx, bool *is_last_call) {
         }
     }
 
+    // If the target max size is greater than the current number of elements to
+    // consider, there is no need to actually check for lhs or rhs sizes.
+
+    if (myfctx->max_LHS_size <= myfctx->num_elems ||
+        myfctx->max_RHS_size <= myfctx->num_elems){
+
+        int countLHS = 0;
+        int countRHS = 0;
+
+        // flags[i]=True means that element is on the lhs (and vice versa)
+        for (i = 0; i < myfctx->num_elems; ++i) {
+            if (!myfctx->flags[i]) {
+                countRHS ++;
+            } else {
+                countLHS ++;
+            }
+        }
+
+        // If this rule is not viable (one side is larger than the limit)
+        // Reduce the num_calls to indicate that it is processed and
+        // return Null to skip the operation
+        if (countLHS > myfctx->max_LHS_size || countRHS > myfctx->max_RHS_size){
+            --myfctx->num_calls;
+            return Null();
+        }
+    }
+
     pre_text  = new char[myfctx->pos_len];
     post_text = new char[myfctx->pos_len];
     result    = new Datum[2];
@@ -160,12 +195,11 @@ gen_rules_from_cfp::SRF_next(void *user_fctx, bool *is_last_call) {
 
     ArrayHandle<text*> arr(construct_array(result, 2, TEXTOID,
             myfctx->typlen, myfctx->typbyval, myfctx->typalign));
-    
+
     delete[] pre_text;
     delete[] post_text;
 
     --myfctx->num_calls;
-    *is_last_call = false;
     return arr;
 }
 
diff --git a/src/ports/postgres/modules/assoc_rules/assoc_rules.py_in b/src/ports/postgres/modules/assoc_rules/assoc_rules.py_in
index e67d887..6b7a7c9 100644
--- a/src/ports/postgres/modules/assoc_rules/assoc_rules.py_in
+++ b/src/ports/postgres/modules/assoc_rules/assoc_rules.py_in
@@ -11,6 +11,8 @@ import time
 import plpy
 from utilities.validate_args import columns_exist_in_table
 from utilities.validate_args import table_exists
+from utilities.control import MinWarning
+
 
 """
 @brief if the given condition is false, then raise an error with the message
@@ -56,9 +58,14 @@ def __float_le(val1, val2):
 @param verbose         determining if output contains comments
 @param max_itemset_size determines the maximum size of frequent itemsets allowed
                         to generate association rules from
+@param max_lhs_size    determines the maximum size of the lhs of the rule
+@param max_rhs_size    determines the maximum size of the rhs of the rule
 """
+
+@MinWarning("warning")
 def assoc_rules(madlib_schema, support, confidence, tid_col,
-                item_col, input_table, output_schema, verbose, max_itemset_size):
+                item_col, input_table, output_schema, verbose,
+                max_itemset_size, max_lhs_size, max_rhs_size):
 
     begin_func_exec = time.time();
     begin_step_exec = time.time();
@@ -68,6 +75,11 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
     elif max_itemset_size <= 1:
         plpy.error("ERROR: max_itemset_size has to be greater than 1.")
 
+    #Validate LHS RHS
+    __assert(max_lhs_size is None or max_lhs_size > 0,
+             "max_lhs_size param must be a positive number.")
+    __assert(max_rhs_size is None or max_rhs_size > 0,
+             "max_rhs_size param must be a positive number.")
     #check parameters
     __assert(
             support is not None and
@@ -177,6 +189,10 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
     num_tranx = rv[0]["c1"];
     num_prod = rv[0]["c2"];
     min_supp_tranx = float(num_tranx) * support;
+    # Set default values to the max possible number that an itemset can
+    # have on its LHS or RHS for a given input dataset.
+    max_lhs_size = num_prod if max_lhs_size is None else max_lhs_size
+    max_rhs_size = num_prod if max_rhs_size is None else max_rhs_size
 
     # get the items whose counts are greater than the given
     # support counts. Each item will be given a continuous number.
@@ -213,10 +229,6 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
     if verbose :
         plpy.info("finished encoding input table: {0}".format(
                 time.time() - begin_step_exec));
-
-    begin_step_exec = time.time();
-
-    if verbose:
         plpy.info("Beginning iteration #1");
 
     begin_step_exec = time.time();
@@ -247,6 +259,20 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
         m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (id)')
         """.format(madlib_schema, madlib_schema));
 
+    # this table adds a new column to order the table by its setlist
+    plpy.execute("DROP TABLE IF EXISTS assoc_rule_sets_loop_ordered");
+    plpy.execute("""
+        CREATE TEMP TABLE assoc_rule_sets_loop_ordered
+            (
+            id          BIGINT,
+            set_list    {0}.svec,
+            support     FLOAT8,
+            tids        {1}.svec,
+            newrownum   BIGINT
+            )
+        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (newrownum)')
+        """.format(madlib_schema, madlib_schema));
+
     plpy.execute("""
          INSERT INTO assoc_rule_sets_loop (id, set_list, support, tids)
          SELECT
@@ -262,6 +288,13 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
                     num_tranx)
          );
 
+    plpy.execute("""
+            INSERT INTO assoc_rule_sets_loop_ordered(id, set_list, support,
+                                                     tids, newrownum)
+            SELECT *, row_number() over (order by set_list) as newrownum
+            FROM assoc_rule_sets_loop
+            """
+            );
     rv = plpy.execute("""
         SELECT count(id) as c1, max(id) as c2
         FROM assoc_rule_sets_loop
@@ -275,19 +308,6 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
             "internal error: num_item_loop must be equal to max_item_loop"
         );
 
-    # to avoid the self cross join of the table assoc_rule_sets_loop,
-    # we use the following table to generate the join relationship
-    # so that we can use inner join (hash join)
-    plpy.execute("DROP TABLE IF EXISTS rule_set_rel");
-    plpy.execute("""
-         CREATE TEMP TABLE rule_set_rel
-            (
-            sid     BIGINT,
-            did     BIGINT
-            )
-         m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (sid)')
-         """);
-
     # As two different svecs may have the same hash key,
     # we use this table to assign a unique ID for each svec.
     plpy.execute("DROP TABLE IF EXISTS assoc_loop_aux");
@@ -309,7 +329,7 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
             time.time() - begin_step_exec));
 
     iter = 0;
-
+    num_products_threshold = num_supp_prod
     while num_item_loop > 0 and iter < max_itemset_size:
         begin_step_exec = time.time();
         iter = iter + 1;
@@ -317,13 +337,6 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
         if verbose  :
             plpy.info("Beginning iteration # {0}".format(iter + 1));
 
-        plpy.execute("TRUNCATE TABLE rule_set_rel");
-        plpy.execute("""
-             INSERT INTO rule_set_rel(sid, did)
-             SELECT t1.id, generate_series(t1.id + 1, {0})
-             FROM assoc_rule_sets_loop t1
-             """.format(num_item_loop));
-
         plpy.execute("""
              INSERT INTO assoc_rule_sets
                 (text_svec, set_list, support, iteration)
@@ -339,30 +352,56 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
                  time.time() - begin_step_exec));
 
         # generate the patterns for the next iteration
+
+        # The vector operations are used to calculate the boolean OR operation
+        # efficiently.
+        # Example:
+        # a1 = [1,1,0,0,0,1]            a2 = [1,1,0,0,1,1]
+        # svec_a1 = {2,3,1}:{1,0,1}     svec_a2 = {2,2,2}:{1,0,1}
+        # x = svec_a1 + svec_a2 = [2,2,0,0,1,2] -> {2,2,1,1}:{2,0,1,2}
+        # y = svec_a1 * svec_a2 = [1,1,0,0,0,1] -> {2,3,1}:{1,0,1}
+        # x - y = [1,1,0,0,1,1] -> {2,2,2}:{1,0,1}
+
+        # The boolean OR operation is used to combine two itemsets to create a larger one.
+        # (a,b,c) combined with (a,b,d) will give (a,b,c,d)
+        # (a,b,c) combined with (c,d,e) will give (a,b,c,d,e)
+
+        # The assoc_rule_sets_loop_ordered table is ordered by the setlist.
+        # This ensures that the sets (a,b,c) and (a,b,d) are close to each other.
+        # t3.newrownum-t1.newrownum <= {num_products_threshold} check ensures that we are
+        # looking at a subset of pairs of itemsets instead of all pairs.
+
+        # Let's assume we are trying to get to (a,b,c,d).
+        # If the support for (a,b,c,d) is above the threshold, then (a,b,c),
+        # (a,b,d) and (b,c,d) must all already be in our list.
+        # In this case it doesn't matter if we skip trying (a,b,c) and (b,c,d)
+        # combination because the result is already covered by (a,b,c) + (a,b,d).
+
+        # The iterp1 check ensures that the new itemset is of a certain size.
+        # At every iteration, we increase the target itemset size by one.
+
         plpy.execute("ALTER SEQUENCE assoc_loop_aux_id_seq RESTART WITH 1");
         plpy.execute("TRUNCATE TABLE assoc_loop_aux");
         plpy.execute("""
            INSERT INTO assoc_loop_aux(set_list, support, tids)
-           SELECT DISTINCT ON({0}.svec_to_string(set_list)) set_list,
-                   {1}.svec_l1norm(tids)::FLOAT8 / {2},
+           SELECT DISTINCT ON({madlib_schema}.svec_to_string(set_list)) set_list,
+                   {madlib_schema}.svec_l1norm(tids)::FLOAT8 / {num_tranx},
                    tids
            FROM (
              SELECT
-                {3}.svec_minus(
-                    {4}.svec_plus(t1.set_list, t3.set_list),
-                    {5}.svec_mult(t1.set_list, t3.set_list)
+                {madlib_schema}.svec_minus(
+                    {madlib_schema}.svec_plus(t1.set_list, t3.set_list),
+                    {madlib_schema}.svec_mult(t1.set_list, t3.set_list)
                 ) as set_list,
-                {6}.svec_mult(t1.tids, t3.tids) as tids
-             FROM assoc_rule_sets_loop t1,
-                  rule_set_rel t2,
-                  assoc_rule_sets_loop t3
-             WHERE t1.id = t2.sid and t2.did = t3.id
+                {madlib_schema}.svec_mult(t1.tids, t3.tids) as tids
+             FROM assoc_rule_sets_loop_ordered t1,
+                  assoc_rule_sets_loop_ordered t3
+             WHERE t1.newrownum < t3.newrownum AND
+                   t3.newrownum-t1.newrownum <= {num_products_threshold}
            ) t
-           WHERE {7}.svec_l1norm(set_list)::INT = {8} AND
-                 {9}.svec_l1norm(tids)::FLOAT8 >= {10}
-           """.format(madlib_schema, madlib_schema, num_tranx, madlib_schema,
-                      madlib_schema, madlib_schema, madlib_schema, madlib_schema,
-                      iter + 1, madlib_schema, min_supp_tranx)
+           WHERE {madlib_schema}.svec_l1norm(set_list)::INT = {iterp1} AND
+                 {madlib_schema}.svec_l1norm(tids)::FLOAT8 >= {min_supp_tranx}
+           """.format(iterp1=iter + 1, **locals())
            );
 
         plpy.execute("TRUNCATE TABLE assoc_rule_sets_loop");
@@ -371,6 +410,13 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
             SELECT * FROM assoc_loop_aux
             """
             );
+        plpy.execute("TRUNCATE TABLE assoc_rule_sets_loop_ordered");
+        plpy.execute("""
+            INSERT INTO assoc_rule_sets_loop_ordered(id, set_list, support, tids, newrownum)
+            SELECT *, row_number() over (order by set_list) as newrownum
+            FROM assoc_rule_sets_loop
+            """
+            );
 
         rv = plpy.execute("""
             SELECT count(id) as c1, max(id) as c2
@@ -383,6 +429,27 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
             (num_item_loop == 0 and max_item_loop is None),
             "internal error: num_item_loop must be equal to max_item_loop"
             );
+        # num_products_threshold should be equal to the number of distinct
+        # products that are present in the previous iteration's frequent
+        # itemsets.
+        # We can ideally do away with the if condition, but it's a trade-off:
+        # The query in the if condition might be a considerable overhead when
+        # the original number of products is very low compared to the number
+        # of rows in assoc_rule_sets_loop in a given iteration. On the other
+        # hand, it can be a considerable improvement if we have a lot of
+        # distinct products and only a small number of them are present in
+        # frequent itemsets. But this is specific to datasets and parameters
+        # (such as support), so the following if statement is a compromise.
+        if num_item_loop < num_supp_prod:
+            # Get number of 1's from all set_lists in assoc_rule_sets_loop
+            num_products_threshold = plpy.execute("""
+                SELECT {madlib_schema}.svec_l1norm(
+                    {madlib_schema}.svec_count_nonzero(a)) AS cnt
+                FROM (
+                    SELECT {madlib_schema}.svec_sum(set_list) AS a
+                    FROM assoc_rule_sets_loop
+                    ) t
+                """.format(**locals()))[0]['cnt']
 
         if verbose :
             plpy.info("{0} Frequent itemsets found in this iteration".format(
@@ -419,15 +486,18 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
                 END
             FROM (
                 SELECT
-                    {0}.gen_rules_from_cfp(text_svec, iteration) as item,
+                    {madlib_schema}.gen_rules_from_cfp(text_svec,
+                                                       iteration,
+                                                       {max_lhs_size},
+                                                       {max_rhs_size}) as item,
                     support as support_xy
                 FROM assoc_rule_sets
                 WHERE iteration > 1
             ) t, assoc_rule_sets x, assoc_rule_sets y
             WHERE t.item[1] = x.text_svec AND
                   t.item[2] = y.text_svec AND
-                  (t.support_xy / x.support) >= {1}
-            """.format(madlib_schema, confidence)
+                  (t.support_xy / x.support) >= {confidence}
+            """.format(**locals())
             );
 
         # generate the readable rules
@@ -475,6 +545,7 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
              """.format(output_schema, num_tranx)
              );
 
+
         # if in verbose mode, we will keep all the intermediate tables
         if not verbose :
             plpy.execute("""
@@ -487,7 +558,7 @@ def assoc_rules(madlib_schema, support, confidence, tid_col,
                 DROP TABLE IF EXISTS assoc_enc_input;
                 DROP TABLE IF EXISTS assoc_rule_sets;
                 DROP TABLE IF EXISTS assoc_rule_sets_loop;
-                DROP TABLE IF EXISTS rule_set_rel;
+                DROP TABLE IF EXISTS assoc_rule_sets_loop_ordered;
                 DROP TABLE IF EXISTS assoc_loop_aux;
                 """);
 
@@ -529,19 +600,33 @@ def assoc_rules_help_message(schema_madlib, message=None, **kwargs):
                                 USAGE
 -----------------------------------------------------------------------
 SELECT {schema_madlib}.assoc_rules(
-    support,            -- FLOAT8, minimum level of support needed for each itemset to be included in result
-    confidence,         -- FLOAT8, minimum level of confidence needed for each rule to be included in result
-    tid_col,            -- TEXT, name of the column storing the transaction ids
-    item_col,           -- TEXT, name of the column storing the products
-    input_table,        -- TEXT, name of the table containing the input data
-    output_schema,      -- TEXT, name of the schema where the final results will be stored.
-                                The schema must be created before calling the function.  Alternatively, use
-                                <tt>NULL</tt> to output to the current schema.
-    verbose,            -- BOOLEAN, (optional, default: False) determines if details are printed for each
-                                iteration as the algorithm progresses
-    max_itemset_size    -- INTEGER, (optional, default: itemsets of all sizes) determines the maximum size of frequent
-                                itemsets allowed that are used for generating association rules. Value less
-                                than 2 throws an error.
+    support,            -- FLOAT8,  minimum level of support needed for each
+                                    itemset to be included in result
+    confidence,         -- FLOAT8,  minimum level of confidence needed for each
+                                    rule to be included in result
+    tid_col,            -- TEXT,    name of the column storing the transaction ids
+    item_col,           -- TEXT,    name of the column storing the products
+    input_table,        -- TEXT,    name of the table containing the input data
+    output_schema,      -- TEXT,    name of the schema where the final results
+                                    will be stored. The schema must be created
+                                    before calling the function.
+                                    Alternatively, use <tt>NULL</tt> to output
+                                    to the current schema.
+    verbose,            -- BOOLEAN, (optional, default: False) determines if
+                                    details are printed for each iteration as
+                                    the algorithm progresses
+    max_itemset_size,   -- INTEGER, (optional, default: itemsets of all sizes)
+                                    determines the maximum size of frequent
+                                    itemsets allowed that are used for generating
+                                    association rules. Value less than 2 throws an error.
+    max_lhs_size,       -- INTEGER, (optional, default: NULL) determines the
+                                    maximum size of the lhs of the rule.
+                                    NULL means there is no restriction on the
+                                    size of the left hand side.
+    max_rhs_size        -- INTEGER  (optional, default: NULL) determines the
+                                    maximum size of the rhs of the rule.
+                                    NULL means there is no restriction on the
+                                    size of the right hand side.
 );
 -------------------------------------------------------------------------
                                 OUTPUT TABLES
diff --git a/src/ports/postgres/modules/assoc_rules/assoc_rules.sql_in b/src/ports/postgres/modules/assoc_rules/assoc_rules.sql_in
index dafe117..321b4fa 100644
--- a/src/ports/postgres/modules/assoc_rules/assoc_rules.sql_in
+++ b/src/ports/postgres/modules/assoc_rules/assoc_rules.sql_in
@@ -179,7 +179,9 @@ assoc_rules( support,
              input_table,
              output_schema,
              verbose,
-             max_itemset_size
+             max_itemset_size,
+             max_LHS_size,
+             max_RHS_size
            );</pre>
 This generates all association rules that satisfy the specified minimum
 <em>support</em> and <em>confidence</em>.
@@ -273,6 +275,22 @@ This generates all association rules that satisfy the specified minimum
   This parameter can be used to reduce run time for data sets where itemset size is large,
   which is a common situation. If your query is not returning or is running too long,
   try using a lower value for this parameter.</dd>
+
+
+  <dt>max_LHS_size (optional)</dt>
+  <dd>INTEGER, default: NULL. Determines the maximum size of the left hand side
+  of the rule. Must be 1 or more.
+  This parameter can be used to reduce run time for data sets where itemset size is large,
+  which is a common situation. If your query is not returning or is running too long,
+  try using a lower value for this parameter.</dd>
+
+
+  <dt>max_RHS_size (optional)</dt>
+  <dd>INTEGER, default: NULL. Determines the maximum size of the right hand side
+  of the rule. Must be 1 or more.
+  This parameter can be used to reduce run time for data sets where itemset size is large,
+  which is a common situation. If your query is not returning or is running too long,
+  try using a lower value for this parameter.</dd>
 </dl>
 
 
@@ -397,6 +415,48 @@ Result:
 (2 rows)
 </pre>
 
+
+-# Limit the size of right hand side to 1.  This parameter is a good way to
+reduce long run times.
+<pre class="example">
+SELECT * FROM madlib.assoc_rules( .25,            -- Support
+                                  .5,             -- Confidence
+                                  'trans_id',     -- Transaction id col
+                                  'product',      -- Product col
+                                  'test_data',    -- Input data
+                                  NULL,           -- Output schema
+                                  TRUE,           -- Verbose output
+                                  NULL,           -- Max itemset size
+                                  NULL,           -- Max LHS size
+                                  1               -- Max RHS size
+                                );
+</pre>
+Result (iteration details not shown):
+<pre class="result">
+ output_schema | output_table | total_rules |   total_time
+---------------+--------------+-------------+-----------------
+ public        | assoc_rules  |           6 | 00:00:00.031362
+(1 row)
+</pre>
+The association rules are again stored in the assoc_rules table:
+<pre class="example">
+SELECT * FROM assoc_rules
+ORDER BY support DESC, confidence DESC;
+</pre>
+Result:
+<pre class="result">
+ ruleid |       pre       |   post    | count |      support      |    confidence     |       lift        |    conviction
+--------+-----------------+-----------+-------+-------------------+-------------------+-------------------+-------------------
+      4 | {diapers}       | {beer}    |     5 | 0.714285714285714 |                 1 |                 1 |                 0
+      3 | {beer}          | {diapers} |     5 | 0.714285714285714 | 0.714285714285714 |                 1 |                 1
+      1 | {chips}         | {beer}    |     3 | 0.428571428571429 |                 1 |                 1 |                 0
+      6 | {diapers,chips} | {beer}    |     2 | 0.285714285714286 |                 1 |                 1 |                 0
+      2 | {chips}         | {diapers} |     2 | 0.285714285714286 | 0.666666666666667 | 0.933333333333333 | 0.857142857142857
+      5 | {beer,chips}    | {diapers} |     2 | 0.285714285714286 | 0.666666666666667 | 0.933333333333333 | 0.857142857142857
+(6 rows)
+</pre>
+
+
 @anchor notes
 @par Notes
 
@@ -459,6 +519,8 @@ CREATE TYPE MADLIB_SCHEMA.assoc_rules_results AS
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.gen_rules_from_cfp
     (
     TEXT,
+    INT,
+    INT,
     INT
     )
 RETURNS SETOF TEXT[] AS 'MODULE_PATHNAME'
@@ -498,6 +560,76 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.assoc_rules
     input_table TEXT,
     output_schema TEXT,
     verbose BOOLEAN,
+    max_itemset_size INTEGER,
+    max_lhs_size INTEGER,
+    max_rhs_size INTEGER
+   )
+RETURNS MADLIB_SCHEMA.assoc_rules_results
+AS $$
+    PythonFunctionBodyOnly(`assoc_rules', `assoc_rules')
+    with AOControl(False):
+        plpy.execute("SET client_min_messages = error;")
+        # schema_madlib comes from PythonFunctionBodyOnly
+        return assoc_rules.assoc_rules(schema_madlib,
+                                       support,
+                                       confidence,
+                                       tid_col,
+                                       item_col,
+                                       input_table,
+                                       output_schema,
+                                       verbose,
+                                       max_itemset_size,
+                                       max_lhs_size,
+                                       max_rhs_size
+                                       );
+
+$$ LANGUAGE plpythonu
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.assoc_rules  -- 9-argument overload: no max_rhs_size
+    (
+    support FLOAT8,
+    confidence FLOAT8,
+    tid_col TEXT,
+    item_col TEXT,
+    input_table TEXT,
+    output_schema TEXT,
+    verbose BOOLEAN,
+    max_itemset_size INTEGER,
+    max_lhs_size INTEGER  -- lower case, matching the name the Python body references
+   )
+RETURNS MADLIB_SCHEMA.assoc_rules_results
+AS $$
+    PythonFunctionBodyOnly(`assoc_rules', `assoc_rules')
+    with AOControl(False):
+        plpy.execute("SET client_min_messages = error;")
+        # schema_madlib comes from PythonFunctionBodyOnly
+        return assoc_rules.assoc_rules(schema_madlib,
+                                       support,
+                                       confidence,
+                                       tid_col,
+                                       item_col,
+                                       input_table,
+                                       output_schema,
+                                       verbose,
+                                       max_itemset_size,
+                                       max_lhs_size,
+                                       None  # max_rhs_size is not taken by this overload; None means no RHS limit
+                                       );
+
+$$ LANGUAGE plpythonu
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.assoc_rules
+    (
+    support FLOAT8,
+    confidence FLOAT8,
+    tid_col TEXT,
+    item_col TEXT,
+    input_table TEXT,
+    output_schema TEXT,
+    verbose BOOLEAN,
     max_itemset_size INTEGER
    )
 RETURNS MADLIB_SCHEMA.assoc_rules_results
@@ -514,7 +646,10 @@ AS $$
                                        input_table,
                                        output_schema,
                                        verbose,
-                                       max_itemset_size);
+                                       max_itemset_size,
+                                       None,
+                                       None
+                                       );
 
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
@@ -549,7 +684,9 @@ AS $$
                                        input_table,
                                        output_schema,
                                        False,
-                                       10);
+                                       10,
+                                       None,
+                                       None);
 
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
@@ -579,7 +716,9 @@ AS $$
                                        input_table,
                                        output_schema,
                                        verbose,
-                                       10);
+                                       10,
+                                       None,
+                                       None);
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
diff --git a/src/ports/postgres/modules/assoc_rules/test/assoc_rules.sql_in b/src/ports/postgres/modules/assoc_rules/test/assoc_rules.sql_in
index 5a6275e..e4e4af7 100644
--- a/src/ports/postgres/modules/assoc_rules/test/assoc_rules.sql_in
+++ b/src/ports/postgres/modules/assoc_rules/test/assoc_rules.sql_in
@@ -28,6 +28,8 @@ declare
     result1        TEXT;
     result2        TEXT;
     result3        TEXT;
+    result4        TEXT;
+    result5        TEXT;
     result_maxiter TEXT;
     res            MADLIB_SCHEMA.assoc_rules_results;
     output_schema  TEXT;
@@ -92,7 +94,9 @@ begin
         support double precision,
         confidence double precision,
         lift double precision,
-        conviction double precision
+        conviction double precision,
+        lhs_1d BOOL,
+        rhs_1d BOOL
     ) ;
 
 
@@ -107,13 +111,13 @@ begin
     INSERT INTO test1_exp_result VALUES (10, '{3,1}', '{2}', 0.20000000000000001, 1, 1.6666666666666667, 0);
     INSERT INTO test1_exp_result VALUES (3, '{1}', '{3}', 0.20000000000000001, 0.5, 1.2499999999999998, 1.2);
 
-    INSERT INTO test2_exp_result VALUES (7, '{chips,diapers}', '{beer}', 0.2857142857142857, 1, 1, 0);
-    INSERT INTO test2_exp_result VALUES (2, '{chips}', '{diapers}', 0.2857142857142857, 0.66666666666666663, 0.93333333333333324, 0.85714285714285698);
-    INSERT INTO test2_exp_result VALUES (1, '{chips}', '{diapers,beer}', 0.2857142857142857, 0.66666666666666663, 0.93333333333333324, 0.85714285714285698);
-    INSERT INTO test2_exp_result VALUES (6, '{diapers}', '{beer}', 0.7142857142857143, 1, 1, 0);
-    INSERT INTO test2_exp_result VALUES (4, '{beer}', '{diapers}', 0.7142857142857143, 0.7142857142857143, 1, 1);
-    INSERT INTO test2_exp_result VALUES (3, '{chips,beer}', '{diapers}', 0.2857142857142857, 0.66666666666666663, 0.93333333333333324, 0.85714285714285698);
-    INSERT INTO test2_exp_result VALUES (5, '{chips}', '{beer}', 0.42857142857142855, 1, 1, 0);
+    INSERT INTO test2_exp_result VALUES (7, '{chips,diapers}', '{beer}', 0.2857142857142857, 1, 1, 0, false, true);
+    INSERT INTO test2_exp_result VALUES (2, '{chips}', '{diapers}', 0.2857142857142857, 0.66666666666666663, 0.93333333333333324, 0.85714285714285698, true, true);
+    INSERT INTO test2_exp_result VALUES (1, '{chips}', '{diapers,beer}', 0.2857142857142857, 0.66666666666666663, 0.93333333333333324, 0.85714285714285698, true, false);
+    INSERT INTO test2_exp_result VALUES (6, '{diapers}', '{beer}', 0.7142857142857143, 1, 1, 0, true, true);
+    INSERT INTO test2_exp_result VALUES (4, '{beer}', '{diapers}', 0.7142857142857143, 0.7142857142857143, 1, 1, true, true);
+    INSERT INTO test2_exp_result VALUES (3, '{chips,beer}', '{diapers}', 0.2857142857142857, 0.66666666666666663, 0.93333333333333324, 0.85714285714285698, false, true);
+    INSERT INTO test2_exp_result VALUES (5, '{chips}', '{beer}', 0.42857142857142855, 1, 1, 0, true, true);
 
     res = MADLIB_SCHEMA.assoc_rules (.1, .5, 'trans_id', 'product', 'test_data1','madlib_installcheck_assoc_rules', false);
 
@@ -129,6 +133,9 @@ begin
           assoc_array_eq(t1.post, t2.post) AND
           abs(t1.support - t2.support) < 1E-10 AND
           abs(t1.confidence - t2.confidence) < 1E-10;
+    IF result1 = 'FAIL' THEN
+        RAISE EXCEPTION 'Association rules mining failed. No results were returned for result 1.';
+    END IF;
 
     PERFORM MADLIB_SCHEMA.assoc_rules (.1, .5, 'trans_id', 'product', 'test_data2','madlib_installcheck_assoc_rules', false);
     SELECT INTO result2 CASE WHEN count(*) = 7 then 'PASS' ELSE 'FAIL' END
@@ -137,29 +144,63 @@ begin
           assoc_array_eq(t1.post, t2.post) AND
           abs(t1.support - t2.support) < 1E-10 AND
           abs(t1.confidence - t2.confidence) < 1E-10;
+    IF (result2 = 'FAIL') THEN
+        RAISE EXCEPTION 'Association rules mining failed. No results were returned for result 2.';
+    END IF;
+
+    -- Test for max_RHS_size=1. No rule whose RHS has more than 1 item may exist.
+    PERFORM MADLIB_SCHEMA.assoc_rules (.1, .5, 'trans_id', 'product', 'test_data2','madlib_installcheck_assoc_rules', false, NULL, NULL, 1);
+    SELECT INTO result3 CASE WHEN count(*) = 6 then 'PASS' ELSE 'FAIL' END
+    FROM assoc_rules t1, test2_exp_result t2
+    WHERE assoc_array_eq(t1.pre, t2.pre) AND
+          assoc_array_eq(t1.post, t2.post) AND
+          abs(t1.support - t2.support) < 1E-10 AND
+          abs(t1.confidence - t2.confidence) < 1E-10 AND
+          rhs_1d=true;
+    IF result3 = 'FAIL' THEN
+        RAISE EXCEPTION 'Association rules mining failed. Assertion failed when max_RHS_size=1';
+    END IF;
+
+    -- Test for max_LHS_size=1. No rule whose LHS has more than 1 item may exist.
+    PERFORM MADLIB_SCHEMA.assoc_rules (.1, .5, 'trans_id', 'product', 'test_data2','madlib_installcheck_assoc_rules', false, NULL, 1, NULL);
+    SELECT INTO result4 CASE WHEN count(*) = 5 then 'PASS' ELSE 'FAIL' END
+    FROM assoc_rules t1, test2_exp_result t2
+    WHERE assoc_array_eq(t1.pre, t2.pre) AND
+          assoc_array_eq(t1.post, t2.post) AND
+          abs(t1.support - t2.support) < 1E-10 AND
+          abs(t1.confidence - t2.confidence) < 1E-10 AND
+          lhs_1d=true;
+    IF result4 = 'FAIL' THEN
+        RAISE EXCEPTION 'Association rules mining failed. Assertion failed when max_LHS_size=1';
+    END IF;
+
+    -- Test for max_itemset_size=2. Neither the LHS nor the RHS of any rule may contain more than 1 item.
+    PERFORM MADLIB_SCHEMA.assoc_rules (.1, .5, 'trans_id', 'product', 'test_data2','madlib_installcheck_assoc_rules', false, 2, NULL, NULL);
+    SELECT INTO result5 CASE WHEN count(*) = 4 then 'PASS' ELSE 'FAIL' END
+    FROM assoc_rules t1, test2_exp_result t2
+    WHERE assoc_array_eq(t1.pre, t2.pre) AND
+          assoc_array_eq(t1.post, t2.post) AND
+          abs(t1.support - t2.support) < 1E-10 AND
+          abs(t1.confidence - t2.confidence) < 1E-10 AND
+          rhs_1d=true AND
+          lhs_1d=true;
+    IF result5 = 'FAIL' THEN
+        RAISE EXCEPTION 'Association rules mining failed. Assertion failed when max_itemset_size=2';
+    END IF;
 
     PERFORM MADLIB_SCHEMA.assoc_rules (.1, .5, 'trans_id', 'product', 'test_data2','madlib_installcheck_assoc_rules', false, 2);
     SELECT INTO result_maxiter CASE WHEN count(*) = 4 then 'PASS' ELSE 'FAIL' END
     FROM assoc_rules;
 
-    DROP TABLE IF EXISTS test_data1;
-    DROP TABLE IF EXISTS test_data2;
+
     DROP TABLE IF EXISTS test2_exp_result;
     DROP TABLE IF EXISTS test1_exp_result;
 
-    IF result3 = 'FAIL' THEN
-        RAISE EXCEPTION 'Input data transformation failed';
-    END IF;
-
-    IF (result1 = 'FAIL') OR (result2 = 'FAIL') THEN
-        RAISE EXCEPTION 'Association rules mining failed. No results were returned.';
-    END IF;
-
     IF result_maxiter = 'FAIL' THEN
         RAISE EXCEPTION 'Association rules mining error when max_iter parameter specified.';
     END IF;
 
-    RAISE INFO 'Association rules install check passed.';
+    RAISE INFO 'Association rules dev check output test cases passed.';
     RETURN;
 
 end $$ language plpgsql;
@@ -168,3 +209,14 @@ end $$ language plpgsql;
 -- Test
 ---------------------------------------------------------------------------
 SELECT install_test();
+
+-- Input validation test cases: max_LHS_size or max_RHS_size values below 1 must raise an error.
+SELECT MADLIB_SCHEMA.assert(MADLIB_SCHEMA.trap_error($TRAP$
+SELECT MADLIB_SCHEMA.assoc_rules (.1, .5, 'trans_id', 'product', 'test_data2','madlib_installcheck_assoc_rules', false, 2, 0, NULL);
+$TRAP$) = 1, 'Should error out if max_LHS_size is < 1');
+
+SELECT MADLIB_SCHEMA.assert(MADLIB_SCHEMA.trap_error($TRAP$
+SELECT MADLIB_SCHEMA.assoc_rules (.1, .5, 'trans_id', 'product', 'test_data2','madlib_installcheck_assoc_rules', false, NULL, 5, -1);
+$TRAP$) = 1, 'Should error out if max_RHS_size is < 1');
+DROP TABLE IF EXISTS test_data1;
+DROP TABLE IF EXISTS test_data2;