You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@submarine.apache.org by li...@apache.org on 2020/03/20 01:15:57 UTC

[submarine] branch master updated: SUBMARINE-427. [SDK] Add Neural Factorization Machine model

This is an automated email from the ASF dual-hosted git repository.

liuxun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git


The following commit(s) were added to refs/heads/master by this push:
     new 8caf5d8  SUBMARINE-427. [SDK] Add Neural Factorization Machine model
8caf5d8 is described below

commit 8caf5d85f2ab33625eee36b0c85dfb8161ca30ac
Author: pingsutw <pi...@gmail.com>
AuthorDate: Tue Mar 17 17:31:49 2020 +0800

    SUBMARINE-427. [SDK] Add Neural Factorization Machine model
    
    ### What is this PR for?
    Add a TensorFlow implementation of the [Neural Factorization Machine](https://arxiv.org/pdf/1708.05027.pdf) model.
    Several TensorFlow layers are shared by the CTR models, so move them into core.py
    to keep the code concise and make development more flexible.
    
    (DeepFM and FM will be refactored in a follow-up Jira ticket.)
    
    ### What type of PR is it?
    [Improvement]
    
    ### Todos
    * [ ] - Task
    
    ### What is the Jira issue?
    https://issues.apache.org/jira/browse/SUBMARINE-427
    
    ### How should this be tested?
    https://github.com/pingsutw/hadoop-submarine/actions/runs/58067991
    https://travis-ci.org/github/pingsutw/hadoop-submarine/builds/663851273
    
    ### Screenshots (if appropriate)
    
    ### Questions:
    * Do the license files need an update? No
    * Are there breaking changes for older versions? No
    * Does this need documentation? No
    
    Author: pingsutw <pi...@gmail.com>
    
    Closes #238 from pingsutw/SUBMARINE-427 and squashes the following commits:
    
    5de38c4 [pingsutw] SUBMARINE-427. [SDK] Add Neural Factorization Machine model
---
 submarine-sdk/pysubmarine/pylintrc                 |   1 +
 .../submarine/ml/{model => layers}/__init__.py     |   5 -
 .../pysubmarine/submarine/ml/layers/core.py        | 128 +++++++++++++++++++++
 .../pysubmarine/submarine/ml/model/__init__.py     |   3 +-
 .../submarine/ml/model/base_tf_model.py            |   5 +-
 .../pysubmarine/submarine/ml/model/deepfm.py       |  14 +--
 submarine-sdk/pysubmarine/submarine/ml/model/fm.py |   2 +-
 .../pysubmarine/submarine/ml/model/nfm.py          |  45 ++++++++
 .../pysubmarine/submarine/utils/tf_utils.py        |   9 +-
 .../__init__.py => tests/ml/model/test_nfm.py}     |  13 ++-
 10 files changed, 195 insertions(+), 30 deletions(-)

diff --git a/submarine-sdk/pysubmarine/pylintrc b/submarine-sdk/pysubmarine/pylintrc
index 6afd467..32f1321 100644
--- a/submarine-sdk/pysubmarine/pylintrc
+++ b/submarine-sdk/pysubmarine/pylintrc
@@ -78,6 +78,7 @@ confidence=
 disable=missing-docstring,
         print-statement,
         unnecessary-pass,
+        unused-argument,
         parameter-unpacking,
         unpacking-in-except,
         old-raise-syntax,
diff --git a/submarine-sdk/pysubmarine/submarine/ml/model/__init__.py b/submarine-sdk/pysubmarine/submarine/ml/layers/__init__.py
similarity index 91%
copy from submarine-sdk/pysubmarine/submarine/ml/model/__init__.py
copy to submarine-sdk/pysubmarine/submarine/ml/layers/__init__.py
index cf6064d..a6eb1b5 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/model/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/layers/__init__.py
@@ -12,8 +12,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from .deepfm import DeepFM
-from .fm import FM
-
-__all__ = ["DeepFM", "FM"]
diff --git a/submarine-sdk/pysubmarine/submarine/ml/layers/core.py b/submarine-sdk/pysubmarine/submarine/ml/layers/core.py
new file mode 100644
index 0000000..9a52024
--- /dev/null
+++ b/submarine-sdk/pysubmarine/submarine/ml/layers/core.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+
+
+def batch_norm_layer(x, train_phase, scope_bn, batch_norm_decay):
+    """
+    Batch normalization that switches between training and inference statistics.
+    Two tf.contrib.layers.batch_norm ops are built in the same `scope_bn`: the training
+    op creates the variables (reuse=None) and the inference op reuses them (reuse=True),
+    so both share gamma/beta (center=True, scale=True) and the moving mean/variance.
+    updates_collections=None makes the training op update the moving statistics in place.
+    :param x: input tensor.
+    :param train_phase: Python bool or tensor castable to tf.bool; True selects the
+           training-mode output, False the inference-mode output.
+    :param scope_bn: variable scope name shared by the two batch_norm ops.
+    :param batch_norm_decay: decay for the moving averages.
+    :return: output of the branch selected by `train_phase`.
+    """
+    # NOTE(review): both ops are created outside the tf.cond branch functions, and tf.cond
+    # only gates ops built inside them, so in graph mode both are evaluated regardless of
+    # `train_phase`; the in-place moving-average update of `bn_train` may therefore also
+    # run in EVAL/PREDICT - confirm. When `train_phase` is a Python bool (as in dnn_layer)
+    # the branch could be chosen in Python instead.
+    bn_train = tf.contrib.layers.batch_norm(x, decay=batch_norm_decay, center=True, scale=True,
+                                            updates_collections=None, is_training=True,
+                                            reuse=None, scope=scope_bn)
+    bn_infer = tf.contrib.layers.batch_norm(x, decay=batch_norm_decay, center=True, scale=True,
+                                            updates_collections=None, is_training=False,
+                                            reuse=True, scope=scope_bn)
+    return tf.cond(tf.cast(train_phase, tf.bool), lambda: bn_train, lambda: bn_infer)
+
+
+def dnn_layer(deep_inputs, estimator_mode, batch_norm, deep_layers, dropout, batch_norm_decay=0.9,
+              l2_reg=0, **kwargs):
+    """
+    Multi-layer perceptron (MLP) tower ending in a single linear output unit.
+    :param deep_inputs: A tensor of at least rank 2 and static value for the last dimension; i.e.
+           [batch_size, depth], [None, None, None, channels].
+    :param estimator_mode: Standard names for Estimator model modes. `TRAIN`, `EVAL`, `PREDICT`.
+           Batch norm selects its training-mode output, and dropout is applied, only
+           when this is `TRAIN`.
+    :param batch_norm: Whether to apply batch normalization to each hidden layer. It is applied
+           to the fully connected output, i.e. AFTER fully_connected's default (ReLU) activation.
+    :param deep_layers: list of positive integers; one hidden layer per entry, with that many units.
+    :param dropout: sequence of floats indexed per hidden layer (needs at least len(deep_layers)
+           entries). Each value is passed to tf.nn.dropout as `keep_prob`, so it is the
+           fraction of units KEPT, not the fraction dropped.
+    :param batch_norm_decay: Decay for the moving average.
+           Reasonable values for decay are close to 1.0, typically in the
+           multiple-nines range: 0.999, 0.99, 0.9, etc.
+    :param l2_reg: non-negative float (0 disables it).
+           L2 regularizer strength applied to the weights of every fully connected layer.
+    :param kwargs: ignored; lets callers pass a whole params dict (e.g. **params['training']).
+    :return: 1-D tensor of logits (the single output unit reshaped with shape=[-1]).
+    """
+    # NOTE(review): weights_regularizer only adds penalties to the REGULARIZATION_LOSSES
+    # collection; the loss built by submarine.utils.tf_utils.get_estimator_spec does not
+    # appear to add tf.losses.get_regularization_loss(), so l2_reg may have no effect - confirm.
+    with tf.variable_scope("DNN_Layer"):
+        if batch_norm:
+            # train_phase is only defined when batch_norm is set, which is also the only
+            # case in which it is read below.
+            if estimator_mode == tf.estimator.ModeKeys.TRAIN:
+                train_phase = True
+            else:
+                train_phase = False
+
+        for i in range(len(deep_layers)):
+            # Hidden layer i, using fully_connected's default activation (ReLU).
+            deep_inputs = tf.contrib.layers.fully_connected(
+                inputs=deep_inputs, num_outputs=deep_layers[i],
+                weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg),
+                scope='mlp%d' % i)
+            if batch_norm:
+                # Batch norm is applied to the already-activated layer output.
+                deep_inputs = batch_norm_layer(
+                    deep_inputs, train_phase=train_phase,
+                    scope_bn='bn_%d' % i, batch_norm_decay=batch_norm_decay)
+            if estimator_mode == tf.estimator.ModeKeys.TRAIN:
+                # dropout[i] is a keep probability (see docstring); no dropout outside TRAIN.
+                deep_inputs = tf.nn.dropout(deep_inputs, keep_prob=dropout[i])
+
+        # Output unit: linear (activation_fn=tf.identity), then flattened to 1-D.
+        deep_out = tf.contrib.layers.fully_connected(
+            inputs=deep_inputs, num_outputs=1, activation_fn=tf.identity,
+            weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg),
+            scope='deep_out')
+        deep_out = tf.reshape(deep_out, shape=[-1])
+        return deep_out
+
+
+def linear_layer(features, feature_size, field_size, l2_reg=0, **kwargs):
+    """
+    First-order (linear) term: linear_bias + sum over fields of w[feat_id] * feat_val.
+    :param features: dict with 'feat_ids' (indices into a weight table of size feature_size)
+           and 'feat_vals'; both are reshaped to [-1, field_size].
+    :param feature_size: number of entries in the per-feature weight table.
+    :param field_size: number of fields (feature slots) per example.
+    :param l2_reg: non-negative float.
+           L2 regularizer strength attached to the per-feature weight table (not the bias).
+    :param kwargs: ignored; lets callers pass a whole params dict (e.g. **params['training']).
+    :return: 1-D tensor, one value per row of the reshaped [-1, field_size] inputs.
+    """
+    feat_ids = features['feat_ids']
+    feat_ids = tf.reshape(feat_ids, shape=[-1, field_size])
+    feat_vals = features['feat_vals']
+    feat_vals = tf.reshape(feat_vals, shape=[-1, field_size])
+
+    regularizer = tf.contrib.layers.l2_regularizer(l2_reg)
+    with tf.variable_scope("LinearLayer_Layer"):
+        # Global bias, zero-initialized and not regularized.
+        linear_bias = tf.get_variable(name='linear_bias', shape=[1],
+                                      initializer=tf.constant_initializer(0.0))
+        # One scalar weight per feature id.
+        # NOTE(review): `regularizer=` only adds the penalty to the REGULARIZATION_LOSSES
+        # collection; get_estimator_spec does not appear to add it to the loss - confirm.
+        linear_weight = tf.get_variable(name='linear_weight', shape=[feature_size],
+                                        initializer=tf.glorot_normal_initializer(),
+                                        regularizer=regularizer)
+
+    # [-1, field_size] weights, scaled by the feature values and summed over fields.
+    feat_weights = tf.nn.embedding_lookup(linear_weight, feat_ids)
+    linear_out = tf.reduce_sum(tf.multiply(feat_weights, feat_vals), 1) + linear_bias
+    return linear_out
+
+
+def bilinear_layer(features, feature_size, field_size, embedding_size, l2_reg=0, **kwargs):
+    """
+    Bi-Interaction pooling layer used in Neural FM: compresses the pairwise element-wise
+    products of the value-scaled feature embeddings into one single vector,
+        0.5 * ((sum_i v_i * x_i)^2 - sum_i (v_i * x_i)^2)  ==  sum_{i<j} (v_i * x_i) * (v_j * x_j)
+    computed element-wise over the embedding dimension.
+    :param features: dict with 'feat_ids' (indices into the embedding table) and 'feat_vals';
+           both are reshaped to [-1, field_size].
+    :param feature_size: number of rows in the embedding table.
+    :param field_size: number of fields (feature slots) per example.
+    :param embedding_size: width of each feature embedding.
+    :param l2_reg: non-negative float.
+           L2 regularizer strength attached to the embedding table.
+    :param kwargs: ignored; lets callers pass a whole params dict (e.g. **params['training']).
+    :return: tensor of shape [-1, embedding_size].
+    """
+    feat_ids = features['feat_ids']
+    feat_ids = tf.reshape(feat_ids, shape=[-1, field_size])
+    feat_vals = features['feat_vals']
+    feat_vals = tf.reshape(feat_vals, shape=[-1, field_size])
+
+    with tf.variable_scope("BilinearLayer_Layer"):
+        # NOTE(review): `regularizer=` only adds the penalty to the REGULARIZATION_LOSSES
+        # collection; get_estimator_spec does not appear to add it to the loss - confirm.
+        regularizer = tf.contrib.layers.l2_regularizer(l2_reg)
+        embedding_dict = tf.get_variable(name='embedding_dict',
+                                         shape=[feature_size, embedding_size],
+                                         initializer=tf.glorot_normal_initializer(),
+                                         regularizer=regularizer)
+
+    # [-1, field_size, embedding_size], each embedding scaled by its feature value.
+    embeddings = tf.nn.embedding_lookup(embedding_dict, feat_ids)
+    feat_vals = tf.reshape(feat_vals, shape=[-1, field_size, 1])
+    embeddings = tf.multiply(embeddings, feat_vals)
+    # Square-of-sum minus sum-of-squares over the field axis leaves only cross terms.
+    sum_square = tf.square(tf.reduce_sum(embeddings, 1))
+    square_sum = tf.reduce_sum(tf.square(embeddings), 1)
+    bilinear_out = 0.5 * tf.subtract(sum_square, square_sum)
+    return bilinear_out
diff --git a/submarine-sdk/pysubmarine/submarine/ml/model/__init__.py b/submarine-sdk/pysubmarine/submarine/ml/model/__init__.py
index cf6064d..febeb99 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/model/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/model/__init__.py
@@ -15,5 +15,6 @@
 
 from .deepfm import DeepFM
 from .fm import FM
+from .nfm import NFM
 
-__all__ = ["DeepFM", "FM"]
+__all__ = ["DeepFM", "FM", "NFM"]
diff --git a/submarine-sdk/pysubmarine/submarine/ml/model/base_tf_model.py b/submarine-sdk/pysubmarine/submarine/ml/model/base_tf_model.py
index faeb50d..9171d99 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/model/base_tf_model.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/model/base_tf_model.py
@@ -16,6 +16,7 @@
 from abc import ABC
 import logging
 import tensorflow as tf
+import numpy as np
 from submarine.ml.model.abstract_model import AbstractModel
 from submarine.ml.registries import input_fn_registry
 from submarine.ml.parameters import default_parameters
@@ -107,4 +108,6 @@ class BaseTFModel(AbstractModel, ABC):
         )
 
     def model_fn(self, features, labels, mode, params):
-        pass
+        """
+        Seed NumPy and TensorFlow (graph-level seed) from params["training"]["seed"].
+        Subclasses such as NFM call super().model_fn(...) first so the random ops in the
+        graph they build get deterministic seeds. This base implementation builds no model
+        and returns None; subclasses must build and return the EstimatorSpec themselves.
+        """
+        seed = params["training"]["seed"]
+        np.random.seed(seed)
+        tf.set_random_seed(seed)
diff --git a/submarine-sdk/pysubmarine/submarine/ml/model/deepfm.py b/submarine-sdk/pysubmarine/submarine/ml/model/deepfm.py
index 1cbc1b7..9c6d506 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/model/deepfm.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/model/deepfm.py
@@ -29,22 +29,12 @@ import logging
 import tensorflow as tf
 import numpy as np
 from submarine.ml.model.base_tf_model import BaseTFModel
+from submarine.ml.layers.core import batch_norm_layer
 from submarine.utils.tf_utils import get_estimator_spec
 
 logger = logging.getLogger(__name__)
 
 
-def batch_norm_layer(x, train_phase, scope_bn, batch_norm_decay):
-    bn_train = tf.contrib.layers.batch_norm(x, decay=batch_norm_decay, center=True, scale=True,
-                                            updates_collections=None, is_training=True,
-                                            reuse=None, scope=scope_bn)
-    bn_infer = tf.contrib.layers.batch_norm(x, decay=batch_norm_decay, center=True, scale=True,
-                                            updates_collections=None, is_training=False,
-                                            reuse=True, scope=scope_bn)
-    z = tf.cond(tf.cast(train_phase, tf.bool), lambda: bn_train, lambda: bn_infer)
-    return z
-
-
 class DeepFM(BaseTFModel):
     def model_fn(self, features, labels, mode, params):
         field_size = params["training"]["field_size"]
@@ -115,4 +105,4 @@ class DeepFM(BaseTFModel):
             y_bias = fm_bias * tf.ones_like(y_d, dtype=tf.float32)
             logit = y_bias + y_w + y_v + y_d
 
-        return get_estimator_spec(logit, labels, mode, params, [fm_vector, fm_weight])
+        return get_estimator_spec(logit, labels, mode, params)
diff --git a/submarine-sdk/pysubmarine/submarine/ml/model/fm.py b/submarine-sdk/pysubmarine/submarine/ml/model/fm.py
index fed9524..19e4c65 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/model/fm.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/model/fm.py
@@ -67,4 +67,4 @@ class FM(BaseTFModel):
 
         y = fm_bias + y_w + y_v
 
-        return get_estimator_spec(y, labels, mode, params, [fm_vector, fm_weight])
+        return get_estimator_spec(y, labels, mode, params)
diff --git a/submarine-sdk/pysubmarine/submarine/ml/model/nfm.py b/submarine-sdk/pysubmarine/submarine/ml/model/nfm.py
new file mode 100644
index 0000000..957f32a
--- /dev/null
+++ b/submarine-sdk/pysubmarine/submarine/ml/model/nfm.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+TensorFlow implementation of NFM
+
+Reference:
+    [1] He X, Chua T S. Neural factorization machines for sparse predictive
+    analytics[C]//Proceedings of the 40th International ACM SIGIR conference on Research and
+    Development in Information Retrieval. ACM, 2017: 355-364. (https://arxiv.org/abs/1708.05027)
+"""
+
+import logging
+import tensorflow as tf
+from submarine.ml.model.base_tf_model import BaseTFModel
+from submarine.ml.layers.core import dnn_layer, bilinear_layer, linear_layer
+from submarine.utils.tf_utils import get_estimator_spec
+
+logger = logging.getLogger(__name__)
+
+
+class NFM(BaseTFModel):
+    """
+    Neural Factorization Machine (He & Chua, SIGIR 2017).
+    logit = linear term (with global bias) + MLP(Bi-Interaction pooling of the embeddings).
+    Every layer receives params['training'] as keyword arguments and reads: feature_size,
+    field_size, embedding_size, batch_norm, deep_layers, dropout and, optionally, l2_reg and
+    batch_norm_decay; other keys are absorbed by the layers' **kwargs. params['training']
+    therefore must not contain keys named 'features', 'deep_inputs' or 'estimator_mode'.
+    NOTE(review): the layers attach L2 regularizers to their variables, but the loss in
+    get_estimator_spec does not appear to add tf.losses.get_regularization_loss() - confirm.
+    """
+    def model_fn(self, features, labels, mode, params):
+        # Seeds NumPy/TensorFlow from params["training"]["seed"] (BaseTFModel.model_fn).
+        super().model_fn(features, labels, mode, params)
+
+        # First-order term: per-feature weights plus a global bias.
+        linear_logit = linear_layer(features, **params['training'])
+        # Second-order term: Bi-Interaction pooling fed through the MLP tower.
+        deep_inputs = bilinear_layer(features, **params['training'])
+        deep_logit = dnn_layer(deep_inputs, mode,  **params['training'])
+
+        with tf.variable_scope("NFM_out"):
+            logit = linear_logit + deep_logit
+
+        return get_estimator_spec(logit, labels, mode, params)
diff --git a/submarine-sdk/pysubmarine/submarine/utils/tf_utils.py b/submarine-sdk/pysubmarine/submarine/utils/tf_utils.py
index 3d42628..5273512 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/tf_utils.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/tf_utils.py
@@ -72,7 +72,7 @@ def get_tf_config(params):
     return tf_config
 
 
-def get_estimator_spec(logit, labels, mode, params, weights):
+def get_estimator_spec(logit, labels, mode, params):
     """
     Returns `EstimatorSpec` that a model_fn can return.
     :param logit: logits `Tensor` to be used.
@@ -80,10 +80,8 @@ def get_estimator_spec(logit, labels, mode, params, weights):
     :param mode: Estimator's `ModeKeys`.
     :param params: Optional dict of hyperparameters. Will receive what is passed to Estimator
      in params parameter.
-    :param weights: a list of weights that need L2 regularization
     :return:
     """
-    l2_reg = params["training"]["l2_reg"]
     learning_rate = params["training"]["learning_rate"]
     optimizer = params["training"]["optimizer"]
     metric = params['output']['metric']
@@ -101,11 +99,8 @@ def get_estimator_spec(logit, labels, mode, params, weights):
             export_outputs=export_outputs)
 
     with tf.name_scope("Loss"):
-        l2 = 0
-        for weight in weights:
-            l2 = l2_reg * tf.nn.l2_loss(weight)
         loss = tf.reduce_mean(
-            tf.nn.sigmoid_cross_entropy_with_logits(logits=logit, labels=labels)) + l2
+            tf.nn.sigmoid_cross_entropy_with_logits(logits=logit, labels=labels))
 
     # Provide an estimator spec for `ModeKeys.EVAL`
     eval_metric_ops = {}
diff --git a/submarine-sdk/pysubmarine/submarine/ml/model/__init__.py b/submarine-sdk/pysubmarine/tests/ml/model/test_nfm.py
similarity index 79%
copy from submarine-sdk/pysubmarine/submarine/ml/model/__init__.py
copy to submarine-sdk/pysubmarine/tests/ml/model/test_nfm.py
index cf6064d..03866e9 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/model/__init__.py
+++ b/submarine-sdk/pysubmarine/tests/ml/model/test_nfm.py
@@ -13,7 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .deepfm import DeepFM
-from .fm import FM
 
-__all__ = ["DeepFM", "FM"]
+from submarine.ml.model import NFM
+
+
+def test_run_nfm(get_model_param):
+    """
+    Smoke test: train, evaluate and predict with NFM end to end.
+    `get_model_param` is a pytest fixture, presumably provided by a conftest.py that is
+    not shown here - confirm. There are no assertions; the test only checks that the
+    three calls complete without raising.
+    """
+    params = get_model_param
+
+    model = NFM(model_params=params)
+    model.train()
+    model.evaluate()
+    model.predict()


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org