You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@teaclave.apache.org by ms...@apache.org on 2020/04/07 22:32:51 UTC
[incubator-teaclave] branch develop updated: [function] Better function argument handling
This is an automated email from the ASF dual-hosted git repository.
mssun pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/incubator-teaclave.git
The following commit(s) were added to refs/heads/develop by this push:
new 3329697 [function] Better function argument handling
3329697 is described below
commit 332969792de3e1869de29c5dc22f8da1868c8018
Author: Mingshen Sun <bo...@mssun.me>
AuthorDate: Tue Apr 7 15:18:44 2020 -0700
[function] Better function argument handling
---
attestation/src/verifier.rs | 1 +
function/src/echo.rs | 21 +++++--
function/src/gbdt_prediction.rs | 15 +++--
function/src/gbdt_training.rs | 76 ++++++++++++++++++--------
function/src/logistic_regression_prediction.rs | 18 +++---
function/src/logistic_regression_training.rs | 49 ++++++++++++-----
function/src/mesapy.rs | 12 ++--
types/src/staged_function.rs | 4 +-
8 files changed, 130 insertions(+), 66 deletions(-)
diff --git a/attestation/src/verifier.rs b/attestation/src/verifier.rs
index d688f18..c710669 100644
--- a/attestation/src/verifier.rs
+++ b/attestation/src/verifier.rs
@@ -119,6 +119,7 @@ impl rustls::ServerCertVerifier for AttestationReportVerifier {
impl rustls::ClientCertVerifier for AttestationReportVerifier {
fn offer_client_auth(&self) -> bool {
+ // If test_mode is on, then disable TLS client authentication.
!cfg!(test_mode)
}
diff --git a/function/src/echo.rs b/function/src/echo.rs
index 8dbac97..3cc15a5 100644
--- a/function/src/echo.rs
+++ b/function/src/echo.rs
@@ -18,19 +18,32 @@
#[cfg(feature = "mesalock_sgx")]
use std::prelude::v1::*;
-use teaclave_types::FunctionArguments;
-use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
+use std::convert::TryFrom;
+use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveFunction};
#[derive(Default)]
pub struct Echo;
+struct EchoArguments {
+ message: String,
+}
+
+impl TryFrom<FunctionArguments> for EchoArguments {
+ type Error = anyhow::Error;
+
+ fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
+ let message = arguments.get("message")?.to_string();
+ Ok(Self { message })
+ }
+}
+
impl TeaclaveFunction for Echo {
fn execute(
&self,
- _runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
+ _runtime: FunctionRuntime,
arguments: FunctionArguments,
) -> anyhow::Result<String> {
- let message = arguments.get("message")?.to_string();
+ let message = EchoArguments::try_from(arguments)?.message;
Ok(message)
}
}
diff --git a/function/src/gbdt_prediction.rs b/function/src/gbdt_prediction.rs
index 16233e5..57ec08b 100644
--- a/function/src/gbdt_prediction.rs
+++ b/function/src/gbdt_prediction.rs
@@ -21,24 +21,23 @@ use std::prelude::v1::*;
use std::format;
use std::io::{self, BufRead, BufReader, Write};
-use teaclave_types::FunctionArguments;
-use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
+use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveFunction};
use gbdt::decision_tree::Data;
use gbdt::gradient_boost::GBDT;
+const IN_MODEL: &str = "model_file";
+const IN_DATA: &str = "data_file";
+const OUT_RESULT: &str = "result_file";
+
#[derive(Default)]
pub struct GbdtPrediction;
-static IN_MODEL: &str = "model_file";
-static IN_DATA: &str = "data_file";
-static OUT_RESULT: &str = "result_file";
-
impl TeaclaveFunction for GbdtPrediction {
fn execute(
&self,
- runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
- _args: FunctionArguments,
+ runtime: FunctionRuntime,
+ _arguments: FunctionArguments,
) -> anyhow::Result<String> {
let mut json_model = String::new();
let mut f = runtime.open_input(IN_MODEL)?;
diff --git a/function/src/gbdt_training.rs b/function/src/gbdt_training.rs
index 59cacf3..e911834 100644
--- a/function/src/gbdt_training.rs
+++ b/function/src/gbdt_training.rs
@@ -21,26 +21,35 @@ use std::prelude::v1::*;
use std::format;
use std::io::{self, BufRead, BufReader, Write};
-use teaclave_types::FunctionArguments;
-use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
+use std::convert::TryFrom;
+use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveFunction};
use gbdt::config::Config;
use gbdt::decision_tree::Data;
use gbdt::gradient_boost::GBDT;
+const IN_DATA: &str = "training_data";
+const OUT_MODEL: &str = "trained_model";
+
#[derive(Default)]
pub struct GbdtTraining;
-static IN_DATA: &str = "training_data";
-static OUT_MODEL: &str = "trained_model";
+struct GbdtTrainingArguments {
+ feature_size: usize,
+ max_depth: u32,
+ iterations: usize,
+ shrinkage: f32,
+ feature_sample_ratio: f64,
+ data_sample_ratio: f64,
+ min_leaf_size: usize,
+ loss: String,
+ training_optimization_level: u8,
+}
-impl TeaclaveFunction for GbdtTraining {
- fn execute(
- &self,
- runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
- arguments: FunctionArguments,
- ) -> anyhow::Result<String> {
- log::debug!("start traning...");
+impl TryFrom<FunctionArguments> for GbdtTrainingArguments {
+ type Error = anyhow::Error;
+
+ fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
let feature_size = arguments.get("feature_size")?.as_usize()?;
let max_depth = arguments.get("max_depth")?.as_u32()?;
let iterations = arguments.get("iterations")?.as_usize()?;
@@ -48,27 +57,50 @@ impl TeaclaveFunction for GbdtTraining {
let feature_sample_ratio = arguments.get("feature_sample_ratio")?.as_f64()?;
let data_sample_ratio = arguments.get("data_sample_ratio")?.as_f64()?;
let min_leaf_size = arguments.get("min_leaf_size")?.as_usize()?;
- let loss = arguments.get("loss")?.as_str();
+ let loss = arguments.get("loss")?.as_str().to_owned();
let training_optimization_level = arguments.get("training_optimization_level")?.as_u8()?;
+ Ok(Self {
+ feature_size,
+ max_depth,
+ iterations,
+ shrinkage,
+ feature_sample_ratio,
+ data_sample_ratio,
+ min_leaf_size,
+ loss,
+ training_optimization_level,
+ })
+ }
+}
+
+impl TeaclaveFunction for GbdtTraining {
+ fn execute(
+ &self,
+ runtime: FunctionRuntime,
+ arguments: FunctionArguments,
+ ) -> anyhow::Result<String> {
+ log::debug!("start traning...");
+ let args = GbdtTrainingArguments::try_from(arguments)?;
+
log::debug!("open input...");
// read input
let training_file = runtime.open_input(IN_DATA)?;
- let mut train_dv = parse_training_data(training_file, feature_size)?;
+ let mut train_dv = parse_training_data(training_file, args.feature_size)?;
let data_size = train_dv.len();
// init gbdt config
let mut cfg = Config::new();
cfg.set_debug(false);
- cfg.set_feature_size(feature_size);
- cfg.set_max_depth(max_depth);
- cfg.set_iterations(iterations);
- cfg.set_shrinkage(shrinkage);
- cfg.set_loss(loss);
- cfg.set_min_leaf_size(min_leaf_size);
- cfg.set_data_sample_ratio(data_sample_ratio);
- cfg.set_feature_sample_ratio(feature_sample_ratio);
- cfg.set_training_optimization_level(training_optimization_level);
+ cfg.set_feature_size(args.feature_size);
+ cfg.set_max_depth(args.max_depth);
+ cfg.set_iterations(args.iterations);
+ cfg.set_shrinkage(args.shrinkage);
+ cfg.set_loss(&args.loss);
+ cfg.set_min_leaf_size(args.min_leaf_size);
+ cfg.set_data_sample_ratio(args.data_sample_ratio);
+ cfg.set_feature_sample_ratio(args.feature_sample_ratio);
+ cfg.set_training_optimization_level(args.training_optimization_level);
// start training
let mut gbdt_train_mod = GBDT::new(&cfg);
diff --git a/function/src/logistic_regression_prediction.rs b/function/src/logistic_regression_prediction.rs
index a7b3a98..7aacd79 100644
--- a/function/src/logistic_regression_prediction.rs
+++ b/function/src/logistic_regression_prediction.rs
@@ -18,27 +18,27 @@
#[cfg(feature = "mesalock_sgx")]
use std::prelude::v1::*;
+use std::format;
+use std::io::{self, BufRead, BufReader, Write};
+
+use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveFunction};
+
use rusty_machine::learning::logistic_reg::LogisticRegressor;
use rusty_machine::learning::optim::grad_desc::GradientDesc;
use rusty_machine::learning::SupModel;
use rusty_machine::linalg;
-use std::format;
-use std::io::{self, BufRead, BufReader, Write};
-use teaclave_types::FunctionArguments;
-use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
+const MODEL_FILE: &str = "model_file";
+const INPUT_DATA: &str = "data_file";
+const RESULT: &str = "result_file";
#[derive(Default)]
pub struct LogitRegPrediction;
-static MODEL_FILE: &str = "model_file";
-static INPUT_DATA: &str = "data_file";
-static RESULT: &str = "result_file";
-
impl TeaclaveFunction for LogitRegPrediction {
fn execute(
&self,
- runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
+ runtime: FunctionRuntime,
_arguments: FunctionArguments,
) -> anyhow::Result<String> {
let mut model_json = String::new();
diff --git a/function/src/logistic_regression_training.rs b/function/src/logistic_regression_training.rs
index fbdc916..d79f4a7 100644
--- a/function/src/logistic_regression_training.rs
+++ b/function/src/logistic_regression_training.rs
@@ -18,43 +18,64 @@
#[cfg(feature = "mesalock_sgx")]
use std::prelude::v1::*;
+use std::convert::TryFrom;
+use std::format;
+use std::io::{self, BufRead, BufReader, Write};
+
+use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveFunction};
+
use rusty_machine::learning::logistic_reg::LogisticRegressor;
use rusty_machine::learning::optim::grad_desc::GradientDesc;
use rusty_machine::learning::SupModel;
use rusty_machine::linalg;
-use std::format;
-use std::io::{self, BufRead, BufReader, Write};
-use teaclave_types::FunctionArguments;
-use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
+const TRAINING_DATA: &str = "training_data";
+const OUT_MODEL_FILE: &str = "model_file";
#[derive(Default)]
pub struct LogitRegTraining;
-static TRAINING_DATA: &str = "training_data";
-static OUT_MODEL_FILE: &str = "model_file";
+struct LogitRegTrainingArguments {
+ alg_alpha: f64,
+ alg_iters: usize,
+ feature_size: usize,
+}
+
+impl TryFrom<FunctionArguments> for LogitRegTrainingArguments {
+ type Error = anyhow::Error;
+
+ fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
+ let alg_alpha = arguments.get("alg_alpha")?.as_f64()?;
+ let alg_iters = arguments.get("alg_iters")?.as_usize()?;
+ let feature_size = arguments.get("feature_size")?.as_usize()?;
+
+ Ok(Self {
+ alg_alpha,
+ alg_iters,
+ feature_size,
+ })
+ }
+}
impl TeaclaveFunction for LogitRegTraining {
fn execute(
&self,
- runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
+ runtime: FunctionRuntime,
arguments: FunctionArguments,
) -> anyhow::Result<String> {
- let alg_alpha = arguments.get("alg_alpha")?.as_f64()?;
- let alg_iters = arguments.get("alg_iters")?.as_usize()?;
- let feature_size = arguments.get("feature_size")?.as_usize()?;
+ let args = LogitRegTrainingArguments::try_from(arguments)?;
let input = runtime.open_input(TRAINING_DATA)?;
- let (flattend_features, targets) = parse_training_data(input, feature_size)?;
+ let (flattend_features, targets) = parse_training_data(input, args.feature_size)?;
let data_size = targets.len();
- let data_matrix = linalg::Matrix::new(data_size, feature_size, flattend_features);
+ let data_matrix = linalg::Matrix::new(data_size, args.feature_size, flattend_features);
let targets = linalg::Vector::new(targets);
- let gd = GradientDesc::new(alg_alpha, alg_iters);
+ let gd = GradientDesc::new(args.alg_alpha, args.alg_iters);
let mut lr = LogisticRegressor::new(gd);
lr.train(&data_matrix, &targets)?;
- let model_json = serde_json::to_string(&lr).unwrap();
+ let model_json = serde_json::to_string(&lr)?;
let mut model_file = runtime.create_output(OUT_MODEL_FILE)?;
model_file.write_all(model_json.as_bytes())?;
diff --git a/function/src/mesapy.rs b/function/src/mesapy.rs
index 7b8c7ec..c77f8c8 100644
--- a/function/src/mesapy.rs
+++ b/function/src/mesapy.rs
@@ -18,13 +18,13 @@
#[cfg(feature = "mesalock_sgx")]
use std::prelude::v1::*;
-use teaclave_types::FunctionArguments;
-use teaclave_types::{TeaclaveFunction, TeaclaveRuntime};
+use std::ffi::CString;
use crate::context::reset_thread_context;
use crate::context::set_thread_context;
use crate::context::Context;
-use std::ffi::CString;
+
+use teaclave_types::{FunctionArguments, FunctionRuntime, TeaclaveFunction};
const MAXPYBUFLEN: usize = 20480;
const MESAPY_ERROR_BUFFER_TOO_SHORT: i64 = -1i64;
@@ -44,11 +44,7 @@ extern "C" {
pub struct Mesapy;
impl TeaclaveFunction for Mesapy {
- fn execute(
- &self,
- runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
- args: FunctionArguments,
- ) -> anyhow::Result<String> {
+ fn execute(&self, runtime: FunctionRuntime, args: FunctionArguments) -> anyhow::Result<String> {
let script = args.get("py_payload")?.as_str();
let py_args = args.get("py_args")?.as_str();
let py_args: FunctionArguments = serde_json::from_str(py_args)?;
diff --git a/types/src/staged_function.rs b/types/src/staged_function.rs
index 258fff4..8483b7a 100644
--- a/types/src/staged_function.rs
+++ b/types/src/staged_function.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use crate::{Executor, ExecutorType, StagedFiles};
+use crate::{Executor, ExecutorType, StagedFiles, TeaclaveRuntime};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
@@ -24,6 +24,8 @@ use std::str::FromStr;
use anyhow::{Context, Result};
+pub type FunctionRuntime = Box<dyn TeaclaveRuntime + Send + Sync>;
+
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ArgumentValue {
inner: String,
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@teaclave.apache.org
For additional commands, e-mail: commits-help@teaclave.apache.org