You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@teaclave.apache.org by ms...@apache.org on 2020/10/16 18:56:57 UTC

[incubator-teaclave] branch master updated: [function] Add pca function (#424)

This is an automated email from the ASF dual-hosted git repository.

mssun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-teaclave.git


The following commit(s) were added to refs/heads/master by this push:
     new 752426e  [function] Add pca function (#424)
752426e is described below

commit 752426eb6553674cd65b1ccce117c9d3418304e4
Author: zEqueue <53...@users.noreply.github.com>
AuthorDate: Sat Oct 17 02:56:50 2020 +0800

    [function] Add pca function (#424)
---
 binder/src/binder.rs                               |   2 +
 executor/Cargo.toml                                |   2 +
 executor/src/builtin.rs                            |   7 +-
 function/README.md                                 |   3 +-
 function/src/lib.rs                                |   3 +
 function/src/principal_components_analysis.rs      | 179 +++++++++++++++++++++
 services/utils/service_app_utils/src/lib.rs        |   2 +
 .../expected_result.txt                            |  90 +++++++++++
 .../princopal_components_analysis/input.txt        |  90 +++++++++++
 9 files changed, 376 insertions(+), 2 deletions(-)

diff --git a/binder/src/binder.rs b/binder/src/binder.rs
index 44838b2..f60efa1 100644
--- a/binder/src/binder.rs
+++ b/binder/src/binder.rs
@@ -77,6 +77,8 @@ impl TeeBinder {
         }
     }
 
+    /// # Safety
+    /// Forcibly destroys the current enclave.
     pub unsafe fn destroy(&self) {
         let _ = sgx_destroy_enclave(self.enclave.geteid());
     }
diff --git a/executor/Cargo.toml b/executor/Cargo.toml
index df458ed..76ea1a9 100644
--- a/executor/Cargo.toml
+++ b/executor/Cargo.toml
@@ -38,6 +38,7 @@ full_builtin_function = [
   "builtin_ordered_set_intersect",
   "builtin_rsa_sign",
   "builtin_face_detection",
+  "builtin_principal_components_analysis",
 ]
 
 builtin_echo = []
@@ -50,6 +51,7 @@ builtin_private_join_and_compute = []
 builtin_ordered_set_intersect = []
 builtin_rsa_sign = []
 builtin_face_detection = []
+builtin_principal_components_analysis = []
 
 [dependencies]
 log           = { version = "0.4.6", features = ["release_max_level_info"] }
diff --git a/executor/src/builtin.rs b/executor/src/builtin.rs
index 510db29..afc9862 100644
--- a/executor/src/builtin.rs
+++ b/executor/src/builtin.rs
@@ -20,7 +20,8 @@ use std::prelude::v1::*;
 
 use teaclave_function::{
     Echo, FaceDetection, GbdtPredict, GbdtTrain, LogisticRegressionPredict,
-    LogisticRegressionTrain, OnlineDecrypt, OrderedSetIntersect, PrivateJoinAndCompute, RsaSign,
+    LogisticRegressionTrain, OnlineDecrypt, OrderedSetIntersect, PrincipalComponentsAnalysis,
+    PrivateJoinAndCompute, RsaSign,
 };
 use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveExecutor};
 
@@ -58,6 +59,10 @@ impl TeaclaveExecutor for BuiltinFunctionExecutor {
             OrderedSetIntersect::NAME => OrderedSetIntersect::new().run(arguments, runtime),
             #[cfg(feature = "builtin_rsa_sign")]
             RsaSign::NAME => RsaSign::new().run(arguments, runtime),
+            #[cfg(feature = "builtin_principal_components_analysis")]
+            PrincipalComponentsAnalysis::NAME => {
+                PrincipalComponentsAnalysis::new().run(arguments, runtime)
+            }
             #[cfg(feature = "builtin_face_detection")]
             FaceDetection::NAME => FaceDetection::new().run(arguments, runtime),
             _ => bail!("Function not found."),
diff --git a/function/README.md b/function/README.md
index af104d1..d1bbe01 100644
--- a/function/README.md
+++ b/function/README.md
@@ -25,8 +25,9 @@ Currently, we have these built-in functions:
     elements in the intersection. Users should calculate hash values of each item
     and upload them as a sorted list.
   - `builtin-rsa-sign`: Signing data with RSA key.
-  - `builtin-face-detection`: an implementation of Funnel-Structured cascade,
+  - `builtin-face-detection`: An implementation of Funnel-Structured cascade,
     which is designed for real-time multi-view face detection.
+  - `builtin-principal-components-analysis`: An example that performs principal components analysis (PCA).
   
 The function arguments are in JSON format and can be serialized to a Rust struct
 very easily. You can learn more about supported arguments in the implementation
diff --git a/function/src/lib.rs b/function/src/lib.rs
index a733017..6bb6ce1 100644
--- a/function/src/lib.rs
+++ b/function/src/lib.rs
@@ -30,6 +30,7 @@ mod logistic_regression_predict;
 mod logistic_regression_train;
 mod online_decrypt;
 mod ordered_set_intersect;
+mod principal_components_analysis;
 mod private_join_and_compute;
 mod rsa_sign;
 
@@ -41,6 +42,7 @@ pub use logistic_regression_predict::LogisticRegressionPredict;
 pub use logistic_regression_train::LogisticRegressionTrain;
 pub use online_decrypt::OnlineDecrypt;
 pub use ordered_set_intersect::OrderedSetIntersect;
+pub use principal_components_analysis::PrincipalComponentsAnalysis;
 pub use private_join_and_compute::PrivateJoinAndCompute;
 pub use rsa_sign::RsaSign;
 
@@ -61,6 +63,7 @@ pub mod tests {
             ordered_set_intersect::tests::run_tests(),
             rsa_sign::tests::run_tests(),
             face_detection::tests::run_tests(),
+            principal_components_analysis::tests::run_tests(),
         )
     }
 }
diff --git a/function/src/principal_components_analysis.rs b/function/src/principal_components_analysis.rs
new file mode 100644
index 0000000..398c94c
--- /dev/null
+++ b/function/src/principal_components_analysis.rs
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[cfg(feature = "mesalock_sgx")]
+use std::prelude::v1::*;
+
+use std::convert::TryFrom;
+use std::format;
+use std::io::{self, BufRead, BufReader, Write};
+
+use teaclave_types::{FunctionArguments, FunctionRuntime};
+
+use rusty_machine::learning::pca::PCA;
+use rusty_machine::learning::UnSupModel;
+use rusty_machine::linalg;
+use rusty_machine::linalg::BaseMatrix;
+const IN_DATA: &str = "input_data";
+const OUT_RESULT: &str = "output_data";
+
+#[derive(Default)]
+pub struct PrincipalComponentsAnalysis;
+
+#[derive(serde::Deserialize)]
+struct PrincipalComponentsAnalysisArguments {
+    n: usize,
+    center: bool,
+    feature_size: usize,
+}
+
+impl TryFrom<FunctionArguments> for PrincipalComponentsAnalysisArguments {
+    type Error = anyhow::Error;
+
+    fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
+        use anyhow::Context;
+        serde_json::from_str(&arguments.into_string()).context("Cannot deserialize arguments")
+    }
+}
+
+impl PrincipalComponentsAnalysis {
+    pub const NAME: &'static str = "builtin_principal_components_analysis";
+
+    pub fn new() -> Self {
+        Default::default()
+    }
+
+    pub fn run(
+        &self,
+        arguments: FunctionArguments,
+        runtime: FunctionRuntime,
+    ) -> anyhow::Result<String> {
+        let args = PrincipalComponentsAnalysisArguments::try_from(arguments)?;
+        let input = runtime.open_input(IN_DATA)?;
+        let (flattend_features, targets) = parse_input_data(input, args.feature_size)?;
+
+        let data_size = targets.len();
+        let input_features = linalg::Matrix::new(data_size, args.feature_size, flattend_features);
+
+        let mut model = PCA::new(args.n, args.center);
+        model.train(&input_features)?;
+
+        let predict_result = model.predict(&input_features)?;
+
+        let mut output = runtime.create_output(OUT_RESULT)?;
+        for i in 0..predict_result.rows() {
+            for j in 0..predict_result.cols() {
+                if j == predict_result.cols() - 1 {
+                    write!(&mut output, "{:?}", predict_result[[i, j]])?;
+                } else {
+                    write!(&mut output, "{:?},", predict_result[[i, j]])?;
+                }
+            }
+            writeln!(&mut output)?;
+        }
+
+        Ok(format!(
+            "transform {} rows * {} cols lines of data.",
+            predict_result.rows(),
+            predict_result.cols()
+        ))
+    }
+}
+
+fn parse_input_data(
+    input: impl io::Read,
+    feature_size: usize,
+) -> anyhow::Result<(Vec<f64>, Vec<f64>)> {
+    let reader = BufReader::new(input);
+    let mut targets = Vec::<f64>::new();
+    let mut features = Vec::new();
+
+    for line_result in reader.lines() {
+        let line = line_result?;
+        let trimed_line = line.trim();
+        anyhow::ensure!(!trimed_line.is_empty(), "Empty line");
+
+        let mut v: Vec<f64> = trimed_line
+            .split(',')
+            .map(|x| x.parse::<f64>())
+            .collect::<std::result::Result<_, _>>()?;
+
+        anyhow::ensure!(
+            v.len() == feature_size + 1,
+            "Data format error: column len = {}, expected = {}",
+            v.len(),
+            feature_size + 1
+        );
+
+        let label = v.swap_remove(feature_size);
+        targets.push(label);
+        features.extend(v);
+    }
+
+    Ok((features, targets))
+}
+
+#[cfg(feature = "enclave_unit_test")]
+pub mod tests {
+    use super::*;
+    use serde_json::json;
+    use std::path::Path;
+    use std::untrusted::fs;
+    use teaclave_crypto::*;
+    use teaclave_runtime::*;
+    use teaclave_test_utils::*;
+    use teaclave_types::*;
+
+    pub fn run_tests() -> bool {
+        run_tests!(test_pca_predict)
+    }
+
+    fn test_pca_predict() {
+        let args = FunctionArguments::from_json(json!({
+            "n": 2,
+            "feature_size": 4,
+            "center":true
+        }))
+        .unwrap();
+
+        let base = Path::new("fixtures/functions/princopal_components_analysis");
+
+        let input_data_file = base.join("input.txt");
+        let output_data_file = base.join("result.txt");
+        let expected_output = base.join("expected_result.txt");
+
+        let input_files = StagedFiles::new(hashmap!(
+            IN_DATA =>
+            StagedFileInfo::new(&input_data_file, TeaclaveFile128Key::random(), FileAuthTag::mock()),
+        ));
+
+        let output_files = StagedFiles::new(hashmap!(
+            OUT_RESULT =>
+            StagedFileInfo::new(&output_data_file, TeaclaveFile128Key::random(), FileAuthTag::mock()),
+        ));
+
+        let runtime = Box::new(RawIoRuntime::new(input_files, output_files));
+        let summary = PrincipalComponentsAnalysis::new()
+            .run(args, runtime)
+            .unwrap();
+        assert_eq!(summary, "transform 90 rows * 2 cols lines of data.");
+
+        let result = fs::read_to_string(&output_data_file).unwrap();
+        let expected = fs::read_to_string(&expected_output).unwrap();
+        assert_eq!(&result[..], &expected[..]);
+    }
+}
diff --git a/services/utils/service_app_utils/src/lib.rs b/services/utils/service_app_utils/src/lib.rs
index 382dbdc..87be6a7 100644
--- a/services/utils/service_app_utils/src/lib.rs
+++ b/services/utils/service_app_utils/src/lib.rs
@@ -54,6 +54,8 @@ impl TeaclaveServiceLauncher {
         self.tee.finalize();
     }
 
+    /// # Safety
+    /// Forcibly destroys the current enclave.
     pub unsafe fn destroy(&self) {
         self.tee.destroy();
     }
diff --git a/tests/fixtures/functions/princopal_components_analysis/expected_result.txt b/tests/fixtures/functions/princopal_components_analysis/expected_result.txt
new file mode 100644
index 0000000..6aa5cf8
--- /dev/null
+++ b/tests/fixtures/functions/princopal_components_analysis/expected_result.txt
@@ -0,0 +1,90 @@
+-2.4154563605850874,-0.20881505750478446
+-2.449295412107796,0.27299651162552563
+-2.6227634353481397,0.2569251860576255
+-2.4803771707374023,0.4289823305744881
+-2.459461856392966,-0.208306079588436
+-2.0085238353978125,-0.6114019714802297
+-2.5526923066246896,0.2200715379096242
+-2.3585271275989363,-0.05384719438852015
+-2.6228897996388616,0.6876502126052235
+-2.4079356705593935,0.2071680966917535
+-2.236677084460471,-0.5366767193515287
+-2.345603390420664,0.10162964664409319
+-2.5219205008224166,0.327957397006823
+-2.959624831811964,0.6226584518850686
+-2.3711981361957664,-1.0707740885163295
+-2.1093321999011203,-1.1965858832277436
+-2.35035196189616,-0.6797590970826204
+-2.379097919093624,-0.1945821630700737
+-1.9283434866350828,-0.7645314034633046
+-2.3168574332232006,-0.3835472873684529
+-2.0425464754571334,-0.29644375085135616
+-2.2727601431470164,-0.3006295910340834
+-2.9463565717839426,0.00011191407320459978
+-2.0345322440674942,0.02252107380045902
+-2.0892322955469034,0.15289749084588627
+-2.242114701635465,0.23798129461071355
+-2.200353212991423,-0.008292124118500943
+-2.293732681737343,-0.2609195559201949
+-2.371450864777209,-0.2093240354211336
+-2.3663923404743787,0.3081930302594185
+-2.3223868446665006,0.3076840523430701
+-2.1407436557233805,-0.30215652478313
+-2.3765242147371333,-0.6872612617528595
+-2.3245617117855066,-0.966383790066429
+-2.4079356705593935,0.2071680966917535
+-2.599420525303255,0.03225456520900455
+-2.355846803317045,-0.5026794581694134
+-2.4079356705593935,0.2071680966917535
+-2.7160856798481694,0.6018761293049668
+-2.322260480375779,-0.1230409742045277
+-2.500821597941368,-0.14247766465466383
+-2.5892886510405004,1.0277088572212831
+-2.7315633770176113,0.46450652550564864
+-2.1353751785932173,-0.04851113714873852
+-1.9386708652333895,-0.30095726733135136
+3.6644800957337753,-0.43616791631609914
+0.7839434335499555,1.3596816610630107
+3.197972583300874,-0.25386830247434217
+2.5840429361778803,0.35058717701735165
+3.191129023175738,-0.6000164377141677
+1.9318553632493356,-0.08297805758300825
+2.068838580707495,0.34958540009777905
+2.434319570078061,-0.07059977319461158
+1.6104370939325388,0.9342765128420694
+1.8543779929919721,0.7330491864665256
+2.1755781037497424,0.06309296833832703
+2.216444303934199,0.09428288294927929
+3.7606514273952145,-1.0235179354940704
+4.060790115152506,-0.15088885546225636
+1.5606613738988573,0.8615851069194693
+2.698739466363877,-0.214518805139321
+1.465496869330631,0.780326605558503
+3.7653230302594998,-0.36513570536690204
+1.6543853654945266,0.33618915987668574
+2.5457504403499147,-0.1732818362763856
+2.882118295449793,-0.4419967725591048
+1.524922838062061,0.3196088563924367
+1.5586355252940485,0.2685223138097245
+2.390187709979461,0.36063423126933486
+2.6539650463871345,-0.3672715204304037
+3.107422457351042,-0.274322948757176
+3.5040967439848414,-1.241639128197301
+2.4265461514709243,0.3748671257040457
+1.7084852556845886,0.25898423747408994
+2.0423563750391875,0.6059549134737012
+3.346178467816925,-0.5623423142656742
+2.416563295722366,0.06041788299152115
+2.1724388081263206,0.09479186086562776
+1.4369118464463038,0.32062681222513423
+2.377390337091911,-0.22556763631087645
+2.584846430369159,-0.01030283057353329
+2.193736125201076,-0.2483696916432482
+1.6803246341193767,0.7305695161926307
+2.833386882389893,-0.11114646252211743
+2.691184206315768,-0.11635025853754255
+2.21439871096407,-0.024208048710975677
+1.791678535780019,0.5048809395113123
+2.0327900920433652,0.07148082761690772
+2.1730241437585724,0.08120020557162264
+1.6570162940969073,0.44108843624293476
diff --git a/tests/fixtures/functions/princopal_components_analysis/input.txt b/tests/fixtures/functions/princopal_components_analysis/input.txt
new file mode 100644
index 0000000..3568733
--- /dev/null
+++ b/tests/fixtures/functions/princopal_components_analysis/input.txt
@@ -0,0 +1,90 @@
+5.1,3.5,1.4,0.2,0
+4.9,3.0,1.4,0.2,0
+4.7,3.2,1.3,0.2,0
+4.6,3.1,1.5,0.2,0
+5.0,3.6,1.4,0.2,0
+5.4,3.9,1.7,0.4,0
+4.6,3.4,1.4,0.3,0
+5.0,3.4,1.5,0.2,0
+4.4,2.9,1.4,0.2,0
+4.9,3.1,1.5,0.1,0
+5.4,3.7,1.5,0.2,0
+4.8,3.4,1.6,0.2,0
+4.8,3.0,1.4,0.1,0
+4.3,3.0,1.1,0.1,0
+5.8,4.0,1.2,0.2,0
+5.7,4.4,1.5,0.4,0
+5.4,3.9,1.3,0.4,0
+5.1,3.5,1.4,0.3,0
+5.7,3.8,1.7,0.3,0
+5.1,3.8,1.5,0.3,0
+5.4,3.4,1.7,0.2,0
+5.1,3.7,1.5,0.4,0
+4.6,3.6,1.0,0.2,0
+5.1,3.3,1.7,0.5,0
+4.8,3.4,1.9,0.2,0
+5.0,3.0,1.6,0.2,0
+5.0,3.4,1.6,0.4,0
+5.2,3.5,1.5,0.2,0
+5.2,3.4,1.4,0.2,0
+4.7,3.2,1.6,0.2,0
+4.8,3.1,1.6,0.2,0
+5.4,3.4,1.5,0.4,0
+5.2,4.1,1.5,0.1,0
+5.5,4.2,1.4,0.2,0
+4.9,3.1,1.5,0.1,0
+5.0,3.2,1.2,0.2,0
+5.5,3.5,1.3,0.2,0
+4.9,3.1,1.5,0.1,0
+4.4,3.0,1.3,0.2,0
+5.1,3.4,1.5,0.2,0
+5.0,3.5,1.3,0.3,0
+4.5,2.3,1.3,0.3,0
+4.4,3.2,1.3,0.2,0
+5.0,3.5,1.6,0.6,0
+5.1,3.8,1.9,0.4,0
+7.6,3.0,6.6,2.1,1
+4.9,2.5,4.5,1.7,1
+7.3,2.9,6.3,1.8,1
+6.7,2.5,5.8,1.8,1
+7.2,3.6,6.1,2.5,1
+6.5,3.2,5.1,2.0,1
+6.4,2.7,5.3,1.9,1
+6.8,3.0,5.5,2.1,1
+5.7,2.5,5.0,2.0,1
+5.8,2.8,5.1,2.4,1
+6.4,3.2,5.3,2.3,1
+6.5,3.0,5.5,1.8,1
+7.7,3.8,6.7,2.2,1
+7.7,2.6,6.9,2.3,1
+6.0,2.2,5.0,1.5,1
+6.9,3.2,5.7,2.3,1
+5.6,2.8,4.9,2.0,1
+7.7,2.8,6.7,2.0,1
+6.3,2.7,4.9,1.8,1
+6.7,3.3,5.7,2.1,1
+7.2,3.2,6.0,1.8,1
+6.2,2.8,4.8,1.8,1
+6.1,3.0,4.9,1.8,1
+6.4,2.8,5.6,2.1,1
+7.2,3.0,5.8,1.6,1
+7.4,2.8,6.1,1.9,1
+7.9,3.8,6.4,2.0,1
+6.4,2.8,5.6,2.2,1
+6.3,2.8,5.1,1.5,1
+6.1,2.6,5.6,1.4,1
+7.7,3.0,6.1,2.3,1
+6.3,3.4,5.6,2.4,1
+6.4,3.1,5.5,1.8,1
+6.0,3.0,4.8,1.8,1
+6.9,3.1,5.4,2.1,1
+6.7,3.1,5.6,2.4,1
+6.9,3.1,5.1,2.3,1
+5.8,2.7,5.1,1.9,1
+6.8,3.2,5.9,2.3,1
+6.7,3.3,5.7,2.5,1
+6.7,3.0,5.2,2.3,1
+6.3,2.5,5.0,1.9,1
+6.5,3.0,5.2,2.0,1
+6.2,3.4,5.4,2.3,1
+5.9,3.0,5.1,1.8,1
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@teaclave.apache.org
For additional commands, e-mail: commits-help@teaclave.apache.org