You are viewing a plain-text version of this content. The canonical link to it (originally given as a hyperlink) was not preserved in this plain-text copy.
Posted to commits@teaclave.apache.org by ms...@apache.org on 2020/04/07 22:32:51 UTC

[incubator-teaclave] branch develop updated: [function] Better function argument handling

This is an automated email from the ASF dual-hosted git repository.

mssun pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/incubator-teaclave.git


The following commit(s) were added to refs/heads/develop by this push:
     new 3329697  [function] Better function argument handling
3329697 is described below

commit 332969792de3e1869de29c5dc22f8da1868c8018
Author: Mingshen Sun <bo...@mssun.me>
AuthorDate: Tue Apr 7 15:18:44 2020 -0700

    [function] Better function argument handling
---
 attestation/src/verifier.rs                    |  1 +
 function/src/echo.rs                           | 21 +++++--
 function/src/gbdt_prediction.rs                | 15 +++--
 function/src/gbdt_training.rs                  | 76 ++++++++++++++++++--------
 function/src/logistic_regression_prediction.rs | 18 +++---
 function/src/logistic_regression_training.rs   | 49 ++++++++++++-----
 function/src/mesapy.rs                         | 12 ++--
 types/src/staged_function.rs                   |  4 +-
 8 files changed, 130 insertions(+), 66 deletions(-)

diff --git a/attestation/src/verifier.rs b/attestation/src/verifier.rs
index d688f18..c710669 100644
--- a/attestation/src/verifier.rs
+++ b/attestation/src/verifier.rs
@@ -119,6 +119,7 @@ impl rustls::ServerCertVerifier for AttestationReportVerifier {
 
 impl rustls::ClientCertVerifier for AttestationReportVerifier {
     fn offer_client_auth(&self) -> bool {
+        // If test_mode is on, then disable TLS client authentication.
         !cfg!(test_mode)
     }
 
diff --git a/function/src/echo.rs b/function/src/echo.rs
index 8dbac97..3cc15a5 100644
--- a/function/src/echo.rs
+++ b/function/src/echo.rs
@@ -18,19 +18,32 @@
 #[cfg(feature = "mesalock_sgx")]
 use std::prelude::v1::*;
 
-use teaclave_types::FunctionArguments;
-use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
+use std::convert::TryFrom;
+use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveFunction};
 
 #[derive(Default)]
 pub struct Echo;
 
+struct EchoArguments {
+    message: String,
+}
+
+impl TryFrom<FunctionArguments> for EchoArguments {
+    type Error = anyhow::Error;
+
+    fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
+        let message = arguments.get("message")?.to_string();
+        Ok(Self { message })
+    }
+}
+
 impl TeaclaveFunction for Echo {
     fn execute(
         &self,
-        _runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
+        _runtime: FunctionRuntime,
         arguments: FunctionArguments,
     ) -> anyhow::Result<String> {
-        let message = arguments.get("message")?.to_string();
+        let message = EchoArguments::try_from(arguments)?.message;
         Ok(message)
     }
 }
diff --git a/function/src/gbdt_prediction.rs b/function/src/gbdt_prediction.rs
index 16233e5..57ec08b 100644
--- a/function/src/gbdt_prediction.rs
+++ b/function/src/gbdt_prediction.rs
@@ -21,24 +21,23 @@ use std::prelude::v1::*;
 use std::format;
 use std::io::{self, BufRead, BufReader, Write};
 
-use teaclave_types::FunctionArguments;
-use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
+use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveFunction};
 
 use gbdt::decision_tree::Data;
 use gbdt::gradient_boost::GBDT;
 
+const IN_MODEL: &str = "model_file";
+const IN_DATA: &str = "data_file";
+const OUT_RESULT: &str = "result_file";
+
 #[derive(Default)]
 pub struct GbdtPrediction;
 
-static IN_MODEL: &str = "model_file";
-static IN_DATA: &str = "data_file";
-static OUT_RESULT: &str = "result_file";
-
 impl TeaclaveFunction for GbdtPrediction {
     fn execute(
         &self,
-        runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
-        _args: FunctionArguments,
+        runtime: FunctionRuntime,
+        _arguments: FunctionArguments,
     ) -> anyhow::Result<String> {
         let mut json_model = String::new();
         let mut f = runtime.open_input(IN_MODEL)?;
diff --git a/function/src/gbdt_training.rs b/function/src/gbdt_training.rs
index 59cacf3..e911834 100644
--- a/function/src/gbdt_training.rs
+++ b/function/src/gbdt_training.rs
@@ -21,26 +21,35 @@ use std::prelude::v1::*;
 use std::format;
 use std::io::{self, BufRead, BufReader, Write};
 
-use teaclave_types::FunctionArguments;
-use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
+use std::convert::TryFrom;
+use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveFunction};
 
 use gbdt::config::Config;
 use gbdt::decision_tree::Data;
 use gbdt::gradient_boost::GBDT;
 
+const IN_DATA: &str = "training_data";
+const OUT_MODEL: &str = "trained_model";
+
 #[derive(Default)]
 pub struct GbdtTraining;
 
-static IN_DATA: &str = "training_data";
-static OUT_MODEL: &str = "trained_model";
+struct GbdtTrainingArguments {
+    feature_size: usize,
+    max_depth: u32,
+    iterations: usize,
+    shrinkage: f32,
+    feature_sample_ratio: f64,
+    data_sample_ratio: f64,
+    min_leaf_size: usize,
+    loss: String,
+    training_optimization_level: u8,
+}
 
-impl TeaclaveFunction for GbdtTraining {
-    fn execute(
-        &self,
-        runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
-        arguments: FunctionArguments,
-    ) -> anyhow::Result<String> {
-        log::debug!("start traning...");
+impl TryFrom<FunctionArguments> for GbdtTrainingArguments {
+    type Error = anyhow::Error;
+
+    fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
         let feature_size = arguments.get("feature_size")?.as_usize()?;
         let max_depth = arguments.get("max_depth")?.as_u32()?;
         let iterations = arguments.get("iterations")?.as_usize()?;
@@ -48,27 +57,50 @@ impl TeaclaveFunction for GbdtTraining {
         let feature_sample_ratio = arguments.get("feature_sample_ratio")?.as_f64()?;
         let data_sample_ratio = arguments.get("data_sample_ratio")?.as_f64()?;
         let min_leaf_size = arguments.get("min_leaf_size")?.as_usize()?;
-        let loss = arguments.get("loss")?.as_str();
+        let loss = arguments.get("loss")?.as_str().to_owned();
         let training_optimization_level = arguments.get("training_optimization_level")?.as_u8()?;
 
+        Ok(Self {
+            feature_size,
+            max_depth,
+            iterations,
+            shrinkage,
+            feature_sample_ratio,
+            data_sample_ratio,
+            min_leaf_size,
+            loss,
+            training_optimization_level,
+        })
+    }
+}
+
+impl TeaclaveFunction for GbdtTraining {
+    fn execute(
+        &self,
+        runtime: FunctionRuntime,
+        arguments: FunctionArguments,
+    ) -> anyhow::Result<String> {
+        log::debug!("start traning...");
+        let args = GbdtTrainingArguments::try_from(arguments)?;
+
         log::debug!("open input...");
         // read input
         let training_file = runtime.open_input(IN_DATA)?;
-        let mut train_dv = parse_training_data(training_file, feature_size)?;
+        let mut train_dv = parse_training_data(training_file, args.feature_size)?;
         let data_size = train_dv.len();
 
         // init gbdt config
         let mut cfg = Config::new();
         cfg.set_debug(false);
-        cfg.set_feature_size(feature_size);
-        cfg.set_max_depth(max_depth);
-        cfg.set_iterations(iterations);
-        cfg.set_shrinkage(shrinkage);
-        cfg.set_loss(loss);
-        cfg.set_min_leaf_size(min_leaf_size);
-        cfg.set_data_sample_ratio(data_sample_ratio);
-        cfg.set_feature_sample_ratio(feature_sample_ratio);
-        cfg.set_training_optimization_level(training_optimization_level);
+        cfg.set_feature_size(args.feature_size);
+        cfg.set_max_depth(args.max_depth);
+        cfg.set_iterations(args.iterations);
+        cfg.set_shrinkage(args.shrinkage);
+        cfg.set_loss(&args.loss);
+        cfg.set_min_leaf_size(args.min_leaf_size);
+        cfg.set_data_sample_ratio(args.data_sample_ratio);
+        cfg.set_feature_sample_ratio(args.feature_sample_ratio);
+        cfg.set_training_optimization_level(args.training_optimization_level);
 
         // start training
         let mut gbdt_train_mod = GBDT::new(&cfg);
diff --git a/function/src/logistic_regression_prediction.rs b/function/src/logistic_regression_prediction.rs
index a7b3a98..7aacd79 100644
--- a/function/src/logistic_regression_prediction.rs
+++ b/function/src/logistic_regression_prediction.rs
@@ -18,27 +18,27 @@
 #[cfg(feature = "mesalock_sgx")]
 use std::prelude::v1::*;
 
+use std::format;
+use std::io::{self, BufRead, BufReader, Write};
+
+use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveFunction};
+
 use rusty_machine::learning::logistic_reg::LogisticRegressor;
 use rusty_machine::learning::optim::grad_desc::GradientDesc;
 use rusty_machine::learning::SupModel;
 use rusty_machine::linalg;
 
-use std::format;
-use std::io::{self, BufRead, BufReader, Write};
-use teaclave_types::FunctionArguments;
-use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
+const MODEL_FILE: &str = "model_file";
+const INPUT_DATA: &str = "data_file";
+const RESULT: &str = "result_file";
 
 #[derive(Default)]
 pub struct LogitRegPrediction;
 
-static MODEL_FILE: &str = "model_file";
-static INPUT_DATA: &str = "data_file";
-static RESULT: &str = "result_file";
-
 impl TeaclaveFunction for LogitRegPrediction {
     fn execute(
         &self,
-        runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
+        runtime: FunctionRuntime,
         _arguments: FunctionArguments,
     ) -> anyhow::Result<String> {
         let mut model_json = String::new();
diff --git a/function/src/logistic_regression_training.rs b/function/src/logistic_regression_training.rs
index fbdc916..d79f4a7 100644
--- a/function/src/logistic_regression_training.rs
+++ b/function/src/logistic_regression_training.rs
@@ -18,43 +18,64 @@
 #[cfg(feature = "mesalock_sgx")]
 use std::prelude::v1::*;
 
+use std::convert::TryFrom;
+use std::format;
+use std::io::{self, BufRead, BufReader, Write};
+
+use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveFunction};
+
 use rusty_machine::learning::logistic_reg::LogisticRegressor;
 use rusty_machine::learning::optim::grad_desc::GradientDesc;
 use rusty_machine::learning::SupModel;
 use rusty_machine::linalg;
-use std::format;
-use std::io::{self, BufRead, BufReader, Write};
 
-use teaclave_types::FunctionArguments;
-use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
+const TRAINING_DATA: &str = "training_data";
+const OUT_MODEL_FILE: &str = "model_file";
 
 #[derive(Default)]
 pub struct LogitRegTraining;
 
-static TRAINING_DATA: &str = "training_data";
-static OUT_MODEL_FILE: &str = "model_file";
+struct LogitRegTrainingArguments {
+    alg_alpha: f64,
+    alg_iters: usize,
+    feature_size: usize,
+}
+
+impl TryFrom<FunctionArguments> for LogitRegTrainingArguments {
+    type Error = anyhow::Error;
+
+    fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
+        let alg_alpha = arguments.get("alg_alpha")?.as_f64()?;
+        let alg_iters = arguments.get("alg_iters")?.as_usize()?;
+        let feature_size = arguments.get("feature_size")?.as_usize()?;
+
+        Ok(Self {
+            alg_alpha,
+            alg_iters,
+            feature_size,
+        })
+    }
+}
 
 impl TeaclaveFunction for LogitRegTraining {
     fn execute(
         &self,
-        runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
+        runtime: FunctionRuntime,
         arguments: FunctionArguments,
     ) -> anyhow::Result<String> {
-        let alg_alpha = arguments.get("alg_alpha")?.as_f64()?;
-        let alg_iters = arguments.get("alg_iters")?.as_usize()?;
-        let feature_size = arguments.get("feature_size")?.as_usize()?;
+        let args = LogitRegTrainingArguments::try_from(arguments)?;
 
         let input = runtime.open_input(TRAINING_DATA)?;
-        let (flattend_features, targets) = parse_training_data(input, feature_size)?;
+        let (flattend_features, targets) = parse_training_data(input, args.feature_size)?;
         let data_size = targets.len();
-        let data_matrix = linalg::Matrix::new(data_size, feature_size, flattend_features);
+        let data_matrix = linalg::Matrix::new(data_size, args.feature_size, flattend_features);
         let targets = linalg::Vector::new(targets);
 
-        let gd = GradientDesc::new(alg_alpha, alg_iters);
+        let gd = GradientDesc::new(args.alg_alpha, args.alg_iters);
         let mut lr = LogisticRegressor::new(gd);
         lr.train(&data_matrix, &targets)?;
 
-        let model_json = serde_json::to_string(&lr).unwrap();
+        let model_json = serde_json::to_string(&lr)?;
         let mut model_file = runtime.create_output(OUT_MODEL_FILE)?;
         model_file.write_all(model_json.as_bytes())?;
 
diff --git a/function/src/mesapy.rs b/function/src/mesapy.rs
index 7b8c7ec..c77f8c8 100644
--- a/function/src/mesapy.rs
+++ b/function/src/mesapy.rs
@@ -18,13 +18,13 @@
 #[cfg(feature = "mesalock_sgx")]
 use std::prelude::v1::*;
 
-use teaclave_types::FunctionArguments;
-use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
+use std::ffi::CString;
 
 use crate::context::reset_thread_context;
 use crate::context::set_thread_context;
 use crate::context::Context;
-use std::ffi::CString;
+
+use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveFunction};
 
 const MAXPYBUFLEN: usize = 20480;
 const MESAPY_ERROR_BUFFER_TOO_SHORT: i64 = -1i64;
@@ -44,11 +44,7 @@ extern "C" {
 pub struct Mesapy;
 
 impl TeaclaveFunction for Mesapy {
-    fn execute(
-        &self,
-        runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
-        args: FunctionArguments,
-    ) -> anyhow::Result<String> {
+    fn execute(&self, runtime: FunctionRuntime, args: FunctionArguments) -> anyhow::Result<String> {
         let script = args.get("py_payload")?.as_str();
         let py_args = args.get("py_args")?.as_str();
         let py_args: FunctionArguments = serde_json::from_str(py_args)?;
diff --git a/types/src/staged_function.rs b/types/src/staged_function.rs
index 258fff4..8483b7a 100644
--- a/types/src/staged_function.rs
+++ b/types/src/staged_function.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::{Executor, ExecutorType, StagedFiles};
+use crate::{Executor, ExecutorType, StagedFiles, TeaclaveRuntime};
 
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
@@ -24,6 +24,8 @@ use std::str::FromStr;
 
 use anyhow::{Context, Result};
 
+pub type FunctionRuntime = Box<dyn TeaclaveRuntime + Send + Sync>;
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ArgumentValue {
     inner: String,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@teaclave.apache.org
For additional commands, e-mail: commits-help@teaclave.apache.org