You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@submarine.apache.org by pi...@apache.org on 2021/01/18 06:39:41 UTC

[submarine] branch master updated: SUBMARINE-426. [SDK] Add Convolutional Click Prediction Model

This is an automated email from the ASF dual-hosted git repository.

pingsutw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e0f41c  SUBMARINE-426. [SDK] Add Convolutional Click Prediction Model
3e0f41c is described below

commit 3e0f41c53b15e412dd5167b910f36a075c3744c6
Author: Lisa <ae...@gmail.com>
AuthorDate: Tue Jan 12 18:46:51 2021 +0800

    SUBMARINE-426. [SDK] Add Convolutional Click Prediction Model
    
    ### What is this PR for?
    Add TensorFlow implementation of Convolutional Click Prediction Model
    
    [CIKM 2015][A Convolutional Click Prediction Model](http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf)
    
    ### What type of PR is it?
    [Improvement]
    
    ### Todos
    * [ ] - Task
    
    ### What is the Jira issue?
    https://issues.apache.org/jira/browse/SUBMARINE-426
    
    ### How should this be tested?
    https://travis-ci.org/github/aeioulisa/submarine/builds/753005415
    
    ### Screenshots (if appropriate)
    
    ### Questions:
    * Do the license files need updating? No
    * Are there breaking changes for older versions? No
    * Does this need documentation? No
    
    Author: Lisa <ae...@gmail.com>
    
    Closes #487 from aeioulisa/SUBMARINE-426 and squashes the following commits:
    
    793db57 [Lisa] Repair the import error
    a45ecfe [Lisa] remove redundant logs
    9e6e39c [Lisa] Fix code style
    5e39d7e [Lisa] Add parameters
    1b5e42f [Lisa] remove README.md
    633b1e6 [Lisa] Add Convolutional Click Prediction Model
---
 .../pysubmarine/example/tensorflow/ccpm/ccpm.json  | 36 ++++++++++++
 .../example/tensorflow/ccpm/ccpm_distributed.json  | 36 ++++++++++++
 .../tensorflow/ccpm/run_ccpm.py}                   | 22 ++++++--
 .../submarine/ml/tensorflow/layers/core.py         | 53 +++++++++++++++++
 .../submarine/ml/tensorflow/model/__init__.py      |  3 +-
 .../submarine/ml/tensorflow/model/ccpm.py          | 66 ++++++++++++++++++++++
 .../submarine/ml/tensorflow/parameters.py          |  2 +
 .../ml/tensorflow/model/test_ccpm.py}              | 13 +++--
 8 files changed, 222 insertions(+), 9 deletions(-)

diff --git a/submarine-sdk/pysubmarine/example/tensorflow/ccpm/ccpm.json b/submarine-sdk/pysubmarine/example/tensorflow/ccpm/ccpm.json
new file mode 100644
index 0000000..ad54a38
--- /dev/null
+++ b/submarine-sdk/pysubmarine/example/tensorflow/ccpm/ccpm.json
@@ -0,0 +1,36 @@
+{
+  "input": {
+    "train_data": ["../../data/tr.libsvm"],
+    "valid_data": ["../../data/va.libsvm"],
+    "test_data": ["../../data/te.libsvm"],
+    "type": "libsvm"
+  },
+  "output": {
+    "save_model_dir": "./experiment",
+    "metric": "auc"
+  },
+  "training": {
+    "batch_size" : 512,
+    "field_size": 39,
+    "num_epochs": 3,
+    "feature_size": 117581,
+    "embedding_size": 256,
+    "learning_rate": 0.0005,
+    "batch_norm_decay": 0.9,
+    "l2_reg": 0.0001,
+    "deep_layers": [400, 400, 400],
+    "conv_kernel_width": [6,5],
+    "conv_filters": [4,4],
+    "dropout": [0.3, 0.3, 0.3],
+    "batch_norm": false,
+    "optimizer": "adam",
+    "log_steps": 10,
+    "seed": 77,
+    "mode": "local"
+  },
+  "resource": {
+    "num_cpu": 4,
+    "num_gpu": 0,
+    "num_thread": 0
+  }
+}
diff --git a/submarine-sdk/pysubmarine/example/tensorflow/ccpm/ccpm_distributed.json b/submarine-sdk/pysubmarine/example/tensorflow/ccpm/ccpm_distributed.json
new file mode 100644
index 0000000..d4f93b0
--- /dev/null
+++ b/submarine-sdk/pysubmarine/example/tensorflow/ccpm/ccpm_distributed.json
@@ -0,0 +1,36 @@
+{
+  "input": {
+    "train_data": ["hdfs:///user/submarine/data/tr.libsvm"],
+    "valid_data": ["hdfs:///user/submarine/data/va.libsvm"],
+    "test_data": ["hdfs:///user/submarine/data/te.libsvm"],
+    "type": "libsvm"
+  },
+  "output": {
+    "save_model_dir": "hdfs:///user/submarine/deepfm",
+    "metric": "auc"
+  },
+  "training": {
+    "batch_size" : 512,
+    "field_size": 39,
+    "num_epochs": 3,
+    "feature_size": 117581,
+    "embedding_size": 256,
+    "learning_rate": 0.0005,
+    "batch_norm_decay": 0.9,
+    "l2_reg": 0.0001,
+    "deep_layers": [400, 400, 400],
+    "conv_kernel_width": [6,5],
+    "conv_filters": [4,4],
+    "dropout": [0.3, 0.3, 0.3],
+    "batch_norm": false,
+    "optimizer": "adam",
+    "log_steps": 10,
+    "seed": 77,
+    "mode": "distributed"
+  },
+  "resource": {
+    "num_cpu": 4,
+    "num_gpu": 0,
+    "num_thread": 0
+  }
+}
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/__init__.py b/submarine-sdk/pysubmarine/example/tensorflow/ccpm/run_ccpm.py
similarity index 55%
copy from submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/__init__.py
copy to submarine-sdk/pysubmarine/example/tensorflow/ccpm/run_ccpm.py
index febeb99..88acfca 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/__init__.py
+++ b/submarine-sdk/pysubmarine/example/tensorflow/ccpm/run_ccpm.py
@@ -13,8 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .deepfm import DeepFM
-from .fm import FM
-from .nfm import NFM
+from submarine.ml.tensorflow.model import CCPM
+import argparse
 
-__all__ = ["DeepFM", "FM", "NFM"]
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-conf", help="a JSON configuration file for CCPM", type=str)
+    parser.add_argument("-task_type", default='train',
+                        help="train or evaluate, by default is train")
+    args = parser.parse_args()
+    json_path = args.conf
+    task_type = args.task_type
+
+    model = CCPM(json_path=json_path)
+
+    if task_type == 'train':
+        model.train()
+    if task_type == 'evaluate':
+        result = model.evaluate()
+        print("Model metrics : ", result)
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
index 8afb9e2..ec0f18f 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import tensorflow as tf
+from tensorflow.keras.layers import Layer
 
 
 def batch_norm_layer(x, train_phase, scope_bn, batch_norm_decay):
@@ -181,3 +182,55 @@ def fm_layer(inputs, **kwargs):
         square_sum = tf.reduce_sum(tf.square(inputs), 1)
         fm_out = 0.5 * tf.reduce_sum(tf.subtract(sum_square, square_sum), 1)
     return fm_out
+
+
+class KMaxPooling(Layer):
+    """K Max pooling that selects the k biggest value along the specific axis.
+      Input shape
+        -  nD tensor with shape: ``(batch_size, ..., input_dim)``.
+      Output shape
+        - nD tensor with shape: ``(batch_size, ..., output_dim)``.
+      Arguments
+        - **k**: positive integer, number of top elements to look for along the ``axis`` dimension.
+        - **axis**: positive integer, the dimension to look for elements.
+     """
+
+    def __init__(self, k=1, axis=-1, **kwargs):
+
+        self.dims = 1
+        self.k = k
+        self.axis = axis
+        super(KMaxPooling, self).__init__(**kwargs)
+
+    def build(self, input_shape):
+
+        if self.axis < 1 or self.axis > len(input_shape):
+            raise ValueError("axis must be 1~%d,now is %d" %
+                             (len(input_shape), self.axis))
+
+        if self.k < 1 or self.k > input_shape[self.axis]:
+            raise ValueError("k must be in 1 ~ %d,now k is %d" %
+                             (input_shape[self.axis], self.k))
+        self.dims = len(input_shape)
+        super(KMaxPooling, self).build(input_shape)
+
+    def call(self, inputs):
+
+        perm = list(range(self.dims))
+        perm[-1], perm[self.axis] = perm[self.axis], perm[-1]
+        shifted_input = tf.transpose(inputs, perm)
+
+        top_k = tf.nn.top_k(shifted_input, k=self.k, sorted=True, name=None)[0]
+        output = tf.transpose(top_k, perm)
+
+        return output
+
+    def compute_output_shape(self, input_shape):
+        output_shape = list(input_shape)
+        output_shape[self.axis] = self.k
+        return tuple(output_shape)
+
+    def get_config(self, ):
+        config = {'k': self.k, 'axis': self.axis}
+        base_config = super(KMaxPooling, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/__init__.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/__init__.py
index febeb99..7f561f9 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/__init__.py
@@ -16,5 +16,6 @@
 from .deepfm import DeepFM
 from .fm import FM
 from .nfm import NFM
+from .ccpm import CCPM
 
-__all__ = ["DeepFM", "FM", "NFM"]
+__all__ = ["DeepFM", "FM", "NFM", "CCPM"]
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/ccpm.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/ccpm.py
new file mode 100644
index 0000000..de41adf
--- /dev/null
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/ccpm.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import tensorflow as tf
+
+from submarine.ml.tensorflow.layers.core import (dnn_layer, embedding_layer, linear_layer,
+                                                 KMaxPooling)
+from submarine.ml.tensorflow.model.base_tf_model import BaseTFModel
+from submarine.utils.tf_utils import get_estimator_spec
+
+logger = logging.getLogger(__name__)
+
+
+class CCPM(BaseTFModel):
+    def model_fn(self, features, labels, mode, params):
+        super().model_fn(features, labels, mode, params)
+
+        if len(params['training']['conv_kernel_width']) != len(params['training']['conv_filters']):
+            raise ValueError(
+                "conv_kernel_width must have same element with conv_filters")
+
+        linear_logit = linear_layer(features, **params['training'])
+        embedding_outputs = embedding_layer(features, **params['training'])
+        conv_filters = params['training']['conv_filters']
+        conv_kernel_width = params['training']['conv_kernel_width']
+
+        n = params['training']['embedding_size']
+        conv_filters_len = len(conv_filters)
+        conv_input = tf.concat(embedding_outputs, axis=1)
+
+        pooling_result = tf.keras.layers.Lambda(
+            lambda x: tf.expand_dims(x, axis=3))(conv_input)
+
+        for i in range(1, conv_filters_len + 1):
+            filters = conv_filters[i - 1]
+            width = conv_kernel_width[i - 1]
+            p = pow(i / conv_filters_len, conv_filters_len - i)
+            k = max(1, int((1 - p) * n)) if i < conv_filters_len else 3
+
+            conv_result = tf.keras.layers.Conv2D(filters=filters, kernel_size=(width, 1),
+                                                 strides=(1, 1), padding='same',
+                                                 activation='tanh', use_bias=True, )(pooling_result)
+
+            pooling_result = KMaxPooling(
+                k=min(k, int(conv_result.shape[1])), axis=1)(conv_result)
+
+        flatten_result = tf.keras.layers.Flatten()(pooling_result)
+        deep_logit = dnn_layer(flatten_result, mode, **params['training'])
+
+        with tf.variable_scope("CCPM_out"):
+            logit = linear_logit + deep_logit
+
+        return get_estimator_spec(logit, labels, mode, params)
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/parameters.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/parameters.py
index a35312d..0815a2d 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/parameters.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/parameters.py
@@ -28,6 +28,8 @@ default_parameters = {
         "batch_norm_decay": 0.9,
         "l2_reg": 0.0001,
         "deep_layers": [400, 400, 400],
+        "conv_kernel_width": [6, 5],
+        "conv_filters": [4, 4],
         "dropout": [0.3, 0.3, 0.3],
         "batch_norm": "false",
         "optimizer": "adam",
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/__init__.py b/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_ccpm.py
similarity index 78%
copy from submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/__init__.py
copy to submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_ccpm.py
index febeb99..536c049 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/__init__.py
+++ b/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_ccpm.py
@@ -13,8 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .deepfm import DeepFM
-from .fm import FM
-from .nfm import NFM
+from submarine.ml.tensorflow.model import CCPM
 
-__all__ = ["DeepFM", "FM", "NFM"]
+
+def test_run_ccpm(get_model_param):
+    params = get_model_param
+
+    model = CCPM(model_params=params)
+    model.train()
+    model.evaluate()
+    model.predict()


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org