You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@teaclave.apache.org by ms...@apache.org on 2020/03/21 04:20:14 UTC
[incubator-teaclave] branch develop updated: [types] Common types refactoring
This is an automated email from the ASF dual-hosted git repository.
mssun pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/incubator-teaclave.git
The following commit(s) were added to refs/heads/develop by this push:
new 4a0f9dc [types] Common types refactoring
4a0f9dc is described below
commit 4a0f9dc9ea2811fbfd402ac87ceb647fb4502f0a
Author: Mingshen Sun <bo...@mssun.me>
AuthorDate: Fri Mar 20 21:04:21 2020 -0700
[types] Common types refactoring
---
services/execution/enclave/src/service.rs | 39 ++++----
services/management/enclave/src/service.rs | 47 +++++-----
services/management/enclave/src/task.rs | 5 +-
tests/integration/enclave/src/teaclave_worker.rs | 4 +-
types/src/staged_task.rs | 108 +++++++++++-----------
types/src/task.rs | 3 +-
types/src/worker.rs | 110 ++++++++++++++++++-----
worker/src/function/echo.rs | 10 +--
worker/src/function/gbdt_prediction.rs | 8 +-
worker/src/function/gbdt_training.rs | 31 +++----
worker/src/function/mesapy.rs | 22 +++--
worker/src/function/mod.rs | 4 +-
worker/src/worker.rs | 10 +--
13 files changed, 238 insertions(+), 163 deletions(-)
diff --git a/services/execution/enclave/src/service.rs b/services/execution/enclave/src/service.rs
index c8b7027..7aae06c 100644
--- a/services/execution/enclave/src/service.rs
+++ b/services/execution/enclave/src/service.rs
@@ -148,7 +148,7 @@ fn finalize_task(task: &StagedTask) -> Result<()> {
let agent_dir_path = Path::new(&agent_dir);
let mut file_request_info = vec![];
- for (key, value) in task.output_map.iter() {
+ for (key, value) in task.output_data.iter() {
let mut src = agent_dir_path.to_path_buf();
src.push(&format!("{}.out", key));
let handle_file_info = HandleFileInfo::new(&src, &value.url);
@@ -171,7 +171,7 @@ fn prepare_task(task: &StagedTask) -> WorkerInvocation {
let executor_type = task.executor_type();
let function_name = task.function_name.clone();
let function_payload = String::from_utf8_lossy(&task.function_payload).to_string();
- let function_args = TeaclaveFunctionArguments::new(&task.arg_list);
+ let function_args = task.function_arguments.clone();
let agent_dir = format!("/tmp/teaclave_agent/{}", task.task_id);
let agent_dir_path = Path::new(&agent_dir);
@@ -181,7 +181,7 @@ fn prepare_task(task: &StagedTask) -> WorkerInvocation {
let mut input_file_map: HashMap<String, (PathBuf, TeaclaveFileCryptoInfo)> = HashMap::new();
let mut file_request_info = vec![];
- for (key, value) in task.input_map.iter() {
+ for (key, value) in task.input_data.iter() {
let mut dest = agent_dir_path.to_path_buf();
dest.push(&format!("{}.in", key));
let info = HandleFileInfo::new(&dest, &value.url);
@@ -204,7 +204,7 @@ fn prepare_task(task: &StagedTask) -> WorkerInvocation {
let input_files = TeaclaveWorkerFileRegistry::new(converted_input_file_map);
let mut output_file_map: HashMap<String, TeaclaveWorkerOutputFileInfo> = HashMap::new();
- for (key, value) in task.output_map.iter() {
+ for (key, value) in task.output_data.iter() {
let mut dest = agent_dir_path.to_path_buf();
dest.push(&format!("{}.out", key));
let crypto = match value.crypto_info {
@@ -246,9 +246,9 @@ pub mod tests {
let staged_task = StagedTask::new()
.task_id(task_id)
.function_name("echo")
- .args(arg_map)
- .input(input_map)
- .output(output_map);
+ .function_arguments(arg_map.into())
+ .input_data(input_map)
+ .output_data(output_map);
let invocation = prepare_task(&staged_task);
@@ -271,7 +271,7 @@ pub mod tests {
"data_sample_ratio".to_string() => "1.0".to_string(),
"min_leaf_size".to_string() => "1".to_string(),
"loss".to_string() => "LAD".to_string(),
- "training_optimization_level".to_string() => "2".to_string()
+ "training_optimization_level".to_string() => "2".to_string(),
);
let fixture_dir = format!(
"file:///{}/fixtures/functions/gbdt_training",
@@ -287,23 +287,18 @@ pub mod tests {
let crypto = TeaclaveFileRootKey128::new(&[0; 16]).unwrap();
let crypto_info = TeaclaveFileCryptoInfo::TeaclaveFileRootKey128(crypto);
- let input_data = InputData {
- url: input_url,
- hash: "".to_string(),
- crypto_info,
- };
- let output_data = OutputData {
- url: output_url,
- crypto_info,
- };
- let input_map = hashmap!("training_data".to_string() => input_data);
- let output_map = hashmap!("trained_model".to_string() => output_data);
+ let training_input_data = InputDataValue::new(&input_url, "", crypto_info);
+ let model_output_data = OutputDataValue::new(&output_url, crypto_info);
+
+ let input_data = hashmap!("training_data".to_string() => training_input_data);
+ let output_data = hashmap!("trained_model".to_string() => model_output_data);
+
let staged_task = StagedTask::new()
.task_id(task_id)
.function_name("gbdt_training")
- .args(arg_map)
- .input(input_map)
- .output(output_map);
+ .function_arguments(arg_map.into())
+ .input_data(input_data)
+ .output_data(output_data);
let invocation = prepare_task(&staged_task);
diff --git a/services/management/enclave/src/service.rs b/services/management/enclave/src/service.rs
index ae4e717..2365760 100644
--- a/services/management/enclave/src/service.rs
+++ b/services/management/enclave/src/service.rs
@@ -27,7 +27,7 @@ use teaclave_service_enclave_utils::teaclave_service;
use teaclave_types::Function;
#[cfg(test_mode)]
use teaclave_types::{FunctionInput, FunctionOutput};
-use teaclave_types::{InputData, OutputData, StagedTask, Task, TaskStatus};
+use teaclave_types::{InputDataValue, OutputDataValue, StagedTask, Task, TaskStatus};
use teaclave_types::{Storable, TeaclaveInputFile, TeaclaveOutputFile};
use teaclave_types::{TeaclaveServiceResponseError, TeaclaveServiceResponseResult};
use thiserror::Error;
@@ -352,7 +352,7 @@ impl TeaclaveManagement for TeaclaveManagementService {
creator: task.creator,
function_id: task.function_id,
function_owner: task.function_owner,
- arg_list: task.arg_list,
+ arg_list: task.function_arguments.into(),
input_data_owner_list: task.input_data_owner_list,
output_data_owner_list: task.output_data_owner_list,
participants: task.participants,
@@ -494,15 +494,15 @@ impl TeaclaveManagement for TeaclaveManagementService {
.read_from_db(task.function_id.as_bytes())
.map_err(|_| TeaclaveManagementError::PermissionDenied)?;
- let arg_list: HashMap<String, String> = task.arg_list.clone();
- let mut input_map: HashMap<String, InputData> = HashMap::new();
- let mut output_map: HashMap<String, OutputData> = HashMap::new();
+ let function_arguments = task.function_arguments.clone();
+ let mut input_map: HashMap<String, InputDataValue> = HashMap::new();
+ let mut output_map: HashMap<String, OutputDataValue> = HashMap::new();
for (data_name, data_id) in task.input_map.iter() {
- let input_data: InputData = if TeaclaveInputFile::match_prefix(data_id) {
+ let input_data: InputDataValue = if TeaclaveInputFile::match_prefix(data_id) {
let input_file: TeaclaveInputFile = self
.read_from_db(data_id.as_bytes())
.map_err(|_| TeaclaveManagementError::PermissionDenied)?;
- InputData::from_input_file(input_file)
+ InputDataValue::from_teaclave_input_file(&input_file)
} else {
return Err(TeaclaveManagementError::PermissionDenied.into());
};
@@ -510,14 +510,14 @@ impl TeaclaveManagement for TeaclaveManagementService {
}
for (data_name, data_id) in task.output_map.iter() {
- let output_data: OutputData = if TeaclaveOutputFile::match_prefix(data_id) {
+ let output_data: OutputDataValue = if TeaclaveOutputFile::match_prefix(data_id) {
let output_file: TeaclaveOutputFile = self
.read_from_db(data_id.as_bytes())
.map_err(|_| TeaclaveManagementError::PermissionDenied)?;
if output_file.hash.is_some() {
return Err(TeaclaveManagementError::PermissionDenied.into());
}
- OutputData::from_output_file(output_file)
+ OutputDataValue::from_teaclave_output_file(&output_file)
} else {
return Err(TeaclaveManagementError::PermissionDenied.into());
};
@@ -526,10 +526,12 @@ impl TeaclaveManagement for TeaclaveManagementService {
let staged_task = StagedTask::new()
.task_id(task.task_id)
- .function(&function)
- .args(arg_list)
- .input(input_map)
- .output(output_map);
+ .function_id(function.function_id)
+ .function_name(&function.name)
+ .function_payload(function.payload)
+ .function_arguments(function_arguments)
+ .input_data(input_map)
+ .output_data(output_map);
self.enqueue_to_db(StagedTask::get_queue_key().as_bytes(), &staged_task)?;
task.status = TaskStatus::Running;
self.write_to_db(&task)
@@ -761,18 +763,15 @@ pub mod tests {
};
let mut arg_list = HashMap::new();
arg_list.insert("arg".to_string(), "data".to_string());
+ let function_arguments = arg_list.into();
let url = Url::parse("s3://bucket_id/path?token=mock_token").unwrap();
let hash = "a6d604b5987b693a19d94704532b5d928c2729f24dfd40745f8d03ac9ac75a8b".to_string();
let crypto_info = TeaclaveFileCryptoInfo::TeaclaveFileRootKey128(
TeaclaveFileRootKey128::new(&[0; 16]).unwrap(),
);
- let input_data = InputData {
- url: url.clone(),
- hash,
- crypto_info,
- };
- let output_data = OutputData { url, crypto_info };
+ let input_data = InputDataValue::new(&url, hash, crypto_info);
+ let output_data = OutputDataValue::new(&url, crypto_info);
let mut input_map = HashMap::new();
input_map.insert("input".to_string(), input_data);
let mut output_map = HashMap::new();
@@ -780,10 +779,12 @@ pub mod tests {
let staged_task = StagedTask::new()
.task_id(Uuid::new_v4())
- .function(&function)
- .args(arg_list)
- .input(input_map)
- .output(output_map);
+ .function_id(function.function_id)
+ .function_name(&function.name)
+ .function_payload(function.payload)
+ .function_arguments(function_arguments)
+ .input_data(input_map)
+ .output_data(output_map);
let value = staged_task.to_vec().unwrap();
let deserialized_data = StagedTask::from_slice(&value).unwrap();
diff --git a/services/management/enclave/src/task.rs b/services/management/enclave/src/task.rs
index 2d3a584..b407573 100644
--- a/services/management/enclave/src/task.rs
+++ b/services/management/enclave/src/task.rs
@@ -47,12 +47,13 @@ pub(crate) fn create_task(
participants.insert(user_id.clone());
}
}
+ let function_arguments = arg_list.into();
let task = Task {
task_id,
creator,
function_id: function.external_id(),
function_owner: function.owner,
- arg_list,
+ function_arguments,
input_data_owner_list,
output_data_owner_list,
participants,
@@ -65,7 +66,7 @@ pub(crate) fn create_task(
};
// check arguments
let function_args: HashSet<String> = function.arg_list.into_iter().collect();
- let provide_args: HashSet<String> = task.arg_list.keys().cloned().collect();
+ let provide_args: HashSet<String> = task.function_arguments.inner().keys().cloned().collect();
let diff: HashSet<_> = function_args.difference(&provide_args).collect();
ensure!(diff.is_empty(), "bad arguments");
diff --git a/tests/integration/enclave/src/teaclave_worker.rs b/tests/integration/enclave/src/teaclave_worker.rs
index 1a8274d..27b6456 100644
--- a/tests/integration/enclave/src/teaclave_worker.rs
+++ b/tests/integration/enclave/src/teaclave_worker.rs
@@ -3,8 +3,8 @@ use std::prelude::v1::*;
use teaclave_types::hashmap;
use teaclave_types::read_all_bytes;
+use teaclave_types::FunctionArguments;
use teaclave_types::TeaclaveFileRootKey128;
-use teaclave_types::TeaclaveFunctionArguments;
use teaclave_types::TeaclaveWorkerFileRegistry;
use teaclave_types::TeaclaveWorkerInputFileInfo;
use teaclave_types::TeaclaveWorkerOutputFileInfo;
@@ -13,7 +13,7 @@ use teaclave_types::WorkerInvocation;
use teaclave_worker::Worker;
fn test_start_worker() {
- let function_args = TeaclaveFunctionArguments::new(&hashmap!(
+ let function_args = FunctionArguments::from_map(&hashmap!(
"feature_size" => "4",
"max_depth" => "4",
"iterations" => "100",
diff --git a/types/src/staged_task.rs b/types/src/staged_task.rs
index e38912d..1444d2a 100644
--- a/types/src/staged_task.rs
+++ b/types/src/staged_task.rs
@@ -1,38 +1,77 @@
-use crate::{
- Function, Storable, TeaclaveExecutorSelector, TeaclaveFileCryptoInfo, TeaclaveInputFile,
- TeaclaveOutputFile,
-};
-use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::prelude::v1::*;
+
+use serde::{Deserialize, Serialize};
use url::Url;
use uuid::Uuid;
+use crate::{
+ FunctionArguments, Storable, TeaclaveExecutorSelector, TeaclaveFileCryptoInfo,
+ TeaclaveInputFile, TeaclaveOutputFile,
+};
+
const STAGED_TASK_PREFIX: &str = "staged-"; // staged-task-uuid
pub const QUEUE_KEY: &str = "staged-task";
+pub type FunctionInputData = HashMap<String, InputDataValue>;
+pub type FunctionOutputData = HashMap<String, OutputDataValue>;
+
#[derive(Debug, Deserialize, Serialize)]
-pub struct InputData {
+pub struct InputDataValue {
pub url: Url,
pub hash: String,
pub crypto_info: TeaclaveFileCryptoInfo,
}
+impl InputDataValue {
+ pub fn new(url: &Url, hash: impl ToString, crypto_info: TeaclaveFileCryptoInfo) -> Self {
+ Self {
+ url: url.to_owned(),
+ hash: hash.to_string(),
+ crypto_info,
+ }
+ }
+
+ pub fn from_teaclave_input_file(file: &TeaclaveInputFile) -> Self {
+ Self {
+ url: file.url.to_owned(),
+ hash: file.hash.to_owned(),
+ crypto_info: file.crypto_info,
+ }
+ }
+}
+
#[derive(Debug, Deserialize, Serialize)]
-pub struct OutputData {
+pub struct OutputDataValue {
pub url: Url,
pub crypto_info: TeaclaveFileCryptoInfo,
}
+impl OutputDataValue {
+ pub fn new(url: &Url, crypto_info: TeaclaveFileCryptoInfo) -> Self {
+ Self {
+ url: url.to_owned(),
+ crypto_info,
+ }
+ }
+
+ pub fn from_teaclave_output_file(file: &TeaclaveOutputFile) -> Self {
+ Self {
+ url: file.url.to_owned(),
+ crypto_info: file.crypto_info,
+ }
+ }
+}
+
#[derive(Debug, Default, Deserialize, Serialize)]
pub struct StagedTask {
pub task_id: Uuid,
- pub function_id: String,
+ pub function_id: Uuid,
pub function_name: String,
pub function_payload: Vec<u8>,
- pub arg_list: HashMap<String, String>,
- pub input_map: HashMap<String, InputData>,
- pub output_map: HashMap<String, OutputData>,
+ pub function_arguments: FunctionArguments,
+ pub input_data: FunctionInputData,
+ pub output_data: FunctionOutputData,
}
impl Storable for StagedTask {
@@ -45,25 +84,6 @@ impl Storable for StagedTask {
}
}
-impl InputData {
- pub fn from_input_file(file: TeaclaveInputFile) -> InputData {
- InputData {
- url: file.url,
- hash: file.hash,
- crypto_info: file.crypto_info,
- }
- }
-}
-
-impl OutputData {
- pub fn from_output_file(file: TeaclaveOutputFile) -> OutputData {
- OutputData {
- url: file.url,
- crypto_info: file.crypto_info,
- }
- }
-}
-
impl StagedTask {
pub fn new() -> Self {
Self::default()
@@ -73,18 +93,9 @@ impl StagedTask {
Self { task_id, ..self }
}
- pub fn function(self, function: &Function) -> Self {
+ pub fn function_id(self, function_id: Uuid) -> Self {
Self {
- function_id: function.external_id(),
- function_name: function.name.clone(),
- function_payload: function.payload.clone(),
- ..self
- }
- }
-
- pub fn function_id(self, function_id: impl Into<String>) -> Self {
- Self {
- function_id: function_id.into(),
+ function_id,
..self
}
}
@@ -103,23 +114,20 @@ impl StagedTask {
}
}
- pub fn args(self, args: HashMap<String, String>) -> Self {
+ pub fn function_arguments(self, function_arguments: FunctionArguments) -> Self {
Self {
- arg_list: args,
+ function_arguments,
..self
}
}
- pub fn input(self, input: HashMap<String, InputData>) -> Self {
- Self {
- input_map: input,
- ..self
- }
+ pub fn input_data(self, input_data: FunctionInputData) -> Self {
+ Self { input_data, ..self }
}
- pub fn output(self, output: HashMap<String, OutputData>) -> Self {
+ pub fn output_data(self, output_data: FunctionOutputData) -> Self {
Self {
- output_map: output,
+ output_data,
..self
}
}
diff --git a/types/src/task.rs b/types/src/task.rs
index b83af0c..8d08a4d 100644
--- a/types/src/task.rs
+++ b/types/src/task.rs
@@ -1,3 +1,4 @@
+use crate::FunctionArguments;
use crate::Storable;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
@@ -27,7 +28,7 @@ pub struct Task {
pub creator: String,
pub function_id: String,
pub function_owner: String,
- pub arg_list: HashMap<String, String>,
+ pub function_arguments: FunctionArguments,
pub input_data_owner_list: HashMap<String, DataOwnerList>,
pub output_data_owner_list: HashMap<String, DataOwnerList>,
pub participants: HashSet<String>,
diff --git a/types/src/worker.rs b/types/src/worker.rs
index 502e265..62e358c 100644
--- a/types/src/worker.rs
+++ b/types/src/worker.rs
@@ -11,17 +11,21 @@ use std::untrusted::fs::File;
use std::fs::File;
use anyhow;
+use anyhow::Context;
+use anyhow::Result;
use crate::TeaclaveFileCryptoInfo;
use crate::TeaclaveFileRootKey128;
use protected_fs::ProtectedFile;
use serde::{Deserialize, Serialize};
+use std::str::FromStr;
#[macro_export]
macro_rules! hashmap {
- ($( $key: expr => $val: expr ),*) => {{
+ ($( $key: expr => $value: expr,)+) => { hashmap!($($key => $value),+) };
+ ($( $key: expr => $value: expr ),*) => {{
let mut map = ::std::collections::HashMap::new();
- $( map.insert($key, $val); )*
+ $( map.insert($key, $value); )*
map
}}
}
@@ -223,41 +227,107 @@ where
}
}
-#[derive(Serialize, Deserialize, Debug, Default)]
-pub struct TeaclaveFunctionArguments {
- pub args: HashMap<String, String>,
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ArgumentValue {
+ inner: String,
}
-impl TeaclaveFunctionArguments {
- pub fn new<K, V>(input: &HashMap<K, V>) -> Self
+impl ArgumentValue {
+ pub fn new(value: String) -> Self {
+ Self { inner: value }
+ }
+
+ pub fn inner(&self) -> &String {
+ &self.inner
+ }
+
+ pub fn as_str(&self) -> &str {
+ &self.inner
+ }
+
+ pub fn as_usize(&self) -> Result<usize> {
+ usize::from_str(&self.inner).with_context(|| format!("cannot parse {}", self.inner))
+ }
+
+ pub fn as_u32(&self) -> Result<u32> {
+ u32::from_str(&self.inner).with_context(|| format!("cannot parse {}", self.inner))
+ }
+
+ pub fn as_f32(&self) -> Result<f32> {
+ f32::from_str(&self.inner).with_context(|| format!("cannot parse {}", self.inner))
+ }
+
+ pub fn as_f64(&self) -> Result<f64> {
+ f64::from_str(&self.inner).with_context(|| format!("cannot parse {}", self.inner))
+ }
+
+ pub fn as_u8(&self) -> Result<u8> {
+ u8::from_str(&self.inner).with_context(|| format!("cannot parse {}", self.inner))
+ }
+}
+
+impl std::fmt::Display for ArgumentValue {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.inner)
+ }
+}
+
+#[derive(Clone, Serialize, Deserialize, Debug, Default)]
+pub struct FunctionArguments {
+ #[serde(flatten)]
+ pub inner: HashMap<String, ArgumentValue>,
+}
+
+impl<S: core::default::Default + std::hash::BuildHasher> From<FunctionArguments>
+ for HashMap<String, String, S>
+{
+ fn from(arguments: FunctionArguments) -> Self {
+ arguments
+ .inner()
+ .iter()
+ .map(|(k, v)| (k.to_owned(), v.as_str().to_owned()))
+ .collect()
+ }
+}
+
+impl From<HashMap<String, String>> for FunctionArguments {
+ fn from(map: HashMap<String, String>) -> Self {
+ FunctionArguments::from_map(&map)
+ }
+}
+
+impl FunctionArguments {
+ pub fn from_map<K, V>(input: &HashMap<K, V>) -> Self
where
K: std::string::ToString,
V: std::string::ToString,
{
- let args = input.iter().fold(HashMap::new(), |mut acc, (k, v)| {
- acc.insert(k.to_string(), v.to_string());
+ let inner = input.iter().fold(HashMap::new(), |mut acc, (k, v)| {
+ acc.insert(k.to_string(), ArgumentValue::new(v.to_string()));
acc
});
- TeaclaveFunctionArguments { args }
+ Self { inner }
}
- pub fn try_get<T: std::str::FromStr>(&self, key: &str) -> anyhow::Result<T> {
- self.args
+ pub fn inner(&self) -> &HashMap<String, ArgumentValue> {
+ &self.inner
+ }
+
+ pub fn get(&self, key: &str) -> anyhow::Result<&ArgumentValue> {
+ self.inner
.get(key)
- .ok_or_else(|| anyhow::anyhow!("Cannot find function argument: {}", key))
- .and_then(|s| {
- s.parse::<T>()
- .map_err(|_| anyhow::anyhow!("parse argument error"))
- })
+ .with_context(|| format!("key not found: {}", key))
}
pub fn into_vec(self) -> Vec<String> {
let mut vector = Vec::new();
- self.args.into_iter().for_each(|(k, v)| {
+
+ self.inner.into_iter().for_each(|(k, v)| {
vector.push(k);
- vector.push(v);
+ vector.push(v.to_string());
});
+
vector
}
}
@@ -274,7 +344,7 @@ pub struct WorkerInvocation {
pub executor_type: TeaclaveExecutorSelector, // "native" | "python"
pub function_name: String, // "gbdt_training" | "mesapy" |
pub function_payload: String,
- pub function_args: TeaclaveFunctionArguments,
+ pub function_args: FunctionArguments,
pub input_files: TeaclaveWorkerFileRegistry<TeaclaveWorkerInputFileInfo>,
pub output_files: TeaclaveWorkerFileRegistry<TeaclaveWorkerOutputFileInfo>,
}
diff --git a/worker/src/function/echo.rs b/worker/src/function/echo.rs
index 7857c72..cbd2e60 100644
--- a/worker/src/function/echo.rs
+++ b/worker/src/function/echo.rs
@@ -21,7 +21,7 @@ use std::prelude::v1::*;
use crate::function::TeaclaveFunction;
use crate::runtime::TeaclaveRuntime;
use anyhow;
-use teaclave_types::TeaclaveFunctionArguments;
+use teaclave_types::FunctionArguments;
#[derive(Default)]
pub struct Echo;
@@ -30,9 +30,9 @@ impl TeaclaveFunction for Echo {
fn execute(
&self,
_runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
- args: TeaclaveFunctionArguments,
+ arguments: FunctionArguments,
) -> anyhow::Result<String> {
- let message: String = args.try_get("message")?;
+ let message = arguments.get("message")?.to_string();
Ok(message)
}
}
@@ -43,7 +43,7 @@ pub mod tests {
use teaclave_test_utils::*;
use teaclave_types::hashmap;
- use teaclave_types::TeaclaveFunctionArguments;
+ use teaclave_types::FunctionArguments;
use teaclave_types::TeaclaveWorkerFileRegistry;
use crate::function::TeaclaveFunction;
@@ -54,7 +54,7 @@ pub mod tests {
}
fn test_echo() {
- let func_args = TeaclaveFunctionArguments::new(&hashmap!(
+ let func_args = FunctionArguments::from_map(&hashmap!(
"message" => "Hello Teaclave!"
));
diff --git a/worker/src/function/gbdt_prediction.rs b/worker/src/function/gbdt_prediction.rs
index dfbf527..61c1431 100644
--- a/worker/src/function/gbdt_prediction.rs
+++ b/worker/src/function/gbdt_prediction.rs
@@ -26,7 +26,7 @@ use serde_json;
use crate::function::TeaclaveFunction;
use crate::runtime::TeaclaveRuntime;
-use teaclave_types::TeaclaveFunctionArguments;
+use teaclave_types::FunctionArguments;
use gbdt::decision_tree::Data;
use gbdt::gradient_boost::GBDT;
@@ -42,7 +42,7 @@ impl TeaclaveFunction for GbdtPrediction {
fn execute(
&self,
runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
- _args: TeaclaveFunctionArguments,
+ _args: FunctionArguments,
) -> anyhow::Result<String> {
let mut json_model = String::new();
let mut f = runtime.open_input(IN_MODEL)?;
@@ -101,8 +101,8 @@ pub mod tests {
use std::untrusted::fs;
use teaclave_types::hashmap;
+ use teaclave_types::FunctionArguments;
use teaclave_types::TeaclaveFileRootKey128;
- use teaclave_types::TeaclaveFunctionArguments;
use teaclave_types::TeaclaveWorkerFileRegistry;
use teaclave_types::TeaclaveWorkerInputFileInfo;
use teaclave_types::TeaclaveWorkerOutputFileInfo;
@@ -115,7 +115,7 @@ pub mod tests {
}
fn test_gbdt_prediction() {
- let func_args = TeaclaveFunctionArguments::default();
+ let func_args = FunctionArguments::default();
let plain_if_model = "fixtures/functions/gbdt_prediction/model.txt";
let plain_if_data = "fixtures/functions/gbdt_prediction/test_data.txt";
diff --git a/worker/src/function/gbdt_training.rs b/worker/src/function/gbdt_training.rs
index 340fde8..d647c60 100644
--- a/worker/src/function/gbdt_training.rs
+++ b/worker/src/function/gbdt_training.rs
@@ -26,7 +26,7 @@ use serde_json;
use crate::function::TeaclaveFunction;
use crate::runtime::TeaclaveRuntime;
-use teaclave_types::TeaclaveFunctionArguments;
+use teaclave_types::FunctionArguments;
use gbdt::config::Config;
use gbdt::decision_tree::Data;
@@ -42,18 +42,19 @@ impl TeaclaveFunction for GbdtTraining {
fn execute(
&self,
runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
- args: TeaclaveFunctionArguments,
+ arguments: FunctionArguments,
) -> anyhow::Result<String> {
log::debug!("start traning...");
- let feature_size: usize = args.try_get("feature_size")?;
- let max_depth: u32 = args.try_get("max_depth")?;
- let iterations: usize = args.try_get("iterations")?;
- let shrinkage: f32 = args.try_get("shrinkage")?;
- let feature_sample_ratio: f64 = args.try_get("feature_sample_ratio")?;
- let data_sample_ratio: f64 = args.try_get("data_sample_ratio")?;
- let min_leaf_size: usize = args.try_get("min_leaf_size")?;
- let loss: String = args.try_get("loss")?;
- let training_optimization_level: u8 = args.try_get("training_optimization_level")?;
+ let feature_size = arguments.get("feature_size")?.as_usize()?;
+ let max_depth: u32 = arguments.get("max_depth")?.as_u32()?;
+ let iterations: usize = arguments.get("iterations")?.as_usize()?;
+ let shrinkage: f32 = arguments.get("shrinkage")?.as_f32()?;
+ let feature_sample_ratio: f64 = arguments.get("feature_sample_ratio")?.as_f64()?;
+ let data_sample_ratio: f64 = arguments.get("data_sample_ratio")?.as_f64()?;
+ let min_leaf_size: usize = arguments.get("min_leaf_size")?.as_usize()?;
+ let loss = arguments.get("loss")?.as_str();
+ let training_optimization_level: u8 =
+ arguments.get("training_optimization_level")?.as_u8()?;
log::debug!("open input...");
// read input
@@ -68,7 +69,7 @@ impl TeaclaveFunction for GbdtTraining {
cfg.set_max_depth(max_depth);
cfg.set_iterations(iterations);
cfg.set_shrinkage(shrinkage);
- cfg.set_loss(&loss);
+ cfg.set_loss(loss);
cfg.set_min_leaf_size(min_leaf_size);
cfg.set_data_sample_ratio(data_sample_ratio);
cfg.set_feature_sample_ratio(feature_sample_ratio);
@@ -136,8 +137,8 @@ pub mod tests {
use std::untrusted::fs;
use teaclave_types::hashmap;
+ use teaclave_types::FunctionArguments;
use teaclave_types::TeaclaveFileRootKey128;
- use teaclave_types::TeaclaveFunctionArguments;
use teaclave_types::TeaclaveWorkerFileRegistry;
use teaclave_types::TeaclaveWorkerInputFileInfo;
use teaclave_types::TeaclaveWorkerOutputFileInfo;
@@ -150,7 +151,7 @@ pub mod tests {
}
fn test_gbdt_training() {
- let func_args = TeaclaveFunctionArguments::new(&hashmap!(
+ let func_arguments = FunctionArguments::from_map(&hashmap!(
"feature_size" => "4",
"max_depth" => "4",
"iterations" => "100",
@@ -179,7 +180,7 @@ pub mod tests {
let runtime = Box::new(RawIoRuntime::new(input_files, output_files));
let function = GbdtTraining;
- let summary = function.execute(runtime, func_args).unwrap();
+ let summary = function.execute(runtime, func_arguments).unwrap();
assert_eq!(summary, "Trained 120 lines of data.");
let result = fs::read_to_string(&plain_output).unwrap();
diff --git a/worker/src/function/mesapy.rs b/worker/src/function/mesapy.rs
index 662a92c..f7460b2 100644
--- a/worker/src/function/mesapy.rs
+++ b/worker/src/function/mesapy.rs
@@ -23,7 +23,7 @@ use itertools::Itertools;
use crate::function::TeaclaveFunction;
use crate::runtime::TeaclaveRuntime;
-use teaclave_types::TeaclaveFunctionArguments;
+use teaclave_types::FunctionArguments;
use crate::function::context::reset_thread_context;
use crate::function::context::set_thread_context;
@@ -52,18 +52,18 @@ impl TeaclaveFunction for Mesapy {
fn execute(
&self,
runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
- args: TeaclaveFunctionArguments,
+ args: FunctionArguments,
) -> anyhow::Result<String> {
- let script = args.try_get::<String>("py_payload")?;
- let py_args = args.try_get::<String>("py_args")?;
- let py_args: TeaclaveFunctionArguments = serde_json::from_str(&py_args)?;
+ let script = args.get("py_payload")?.as_str();
+ let py_args = args.get("py_args")?.as_str();
+ let py_args: FunctionArguments = serde_json::from_str(py_args)?;
let py_argv = py_args.into_vec();
let cstr_argv: Vec<_> = py_argv
.iter()
.map(|arg| CString::new(arg.as_str()).unwrap())
.collect();
- let mut script_bytes = script.into_bytes();
+ let mut script_bytes = script.to_owned().into_bytes();
script_bytes.push(0u8);
let mut p_argv: Vec<_> = cstr_argv
@@ -108,8 +108,8 @@ pub mod tests {
use crate::function::TeaclaveFunction;
use crate::runtime::RawIoRuntime;
use teaclave_types::hashmap;
+ use teaclave_types::FunctionArguments;
use teaclave_types::TeaclaveFileRootKey128;
- use teaclave_types::TeaclaveFunctionArguments;
use teaclave_types::TeaclaveWorkerFileRegistry;
use teaclave_types::TeaclaveWorkerInputFileInfo;
use teaclave_types::TeaclaveWorkerOutputFileInfo;
@@ -119,7 +119,7 @@ pub mod tests {
}
fn test_mesapy() {
- let py_args = TeaclaveFunctionArguments::new(&hashmap!("--name" => "Teaclave"));
+ let py_args = FunctionArguments::from_map(&hashmap!("--name" => "Teaclave"));
let py_payload = r#"
def entrypoint(argv):
in_file_id = "in_f1"
@@ -182,12 +182,10 @@ def entrypoint(argv):
};
let runtime = Box::new(RawIoRuntime::new(input_files, output_files));
- let func_args = TeaclaveFunctionArguments {
- args: hashmap!(
+ let func_args = FunctionArguments::from_map(&hashmap!(
"py_payload".to_string() => py_payload.to_string(),
"py_args".to_string() => serde_json::to_string(&py_args).unwrap()
- ),
- };
+ ));
let function = Mesapy;
let summary = function.execute(runtime, func_args).unwrap();
diff --git a/worker/src/function/mod.rs b/worker/src/function/mod.rs
index 0cdcc87..88d6261 100644
--- a/worker/src/function/mod.rs
+++ b/worker/src/function/mod.rs
@@ -20,13 +20,13 @@ use std::prelude::v1::*;
use crate::runtime::TeaclaveRuntime;
use anyhow;
-use teaclave_types::TeaclaveFunctionArguments;
+use teaclave_types::FunctionArguments;
pub trait TeaclaveFunction {
fn execute(
&self,
runtime: Box<dyn TeaclaveRuntime + Send + Sync>,
- args: TeaclaveFunctionArguments,
+ args: FunctionArguments,
) -> anyhow::Result<String>;
// TODO: Add more flexible control support on a running function
diff --git a/worker/src/worker.rs b/worker/src/worker.rs
index 6a37d88..d28897f 100644
--- a/worker/src/worker.rs
+++ b/worker/src/worker.rs
@@ -25,7 +25,7 @@ use anyhow;
use serde_json;
use teaclave_types::{
- TeaclaveExecutorSelector, TeaclaveFunctionArguments, TeaclaveWorkerFileRegistry,
+ FunctionArguments, TeaclaveExecutorSelector, TeaclaveWorkerFileRegistry,
TeaclaveWorkerInputFileInfo, TeaclaveWorkerOutputFileInfo, WorkerCapability, WorkerInvocation,
};
@@ -139,9 +139,9 @@ fn setup_runtimes() -> HashMap<String, RuntimeBuilder> {
// script arguments from the wrapped argument.
fn prepare_arguments(
executor_type: TeaclaveExecutorSelector,
- function_args: TeaclaveFunctionArguments,
+ function_args: FunctionArguments,
function_payload: String,
-) -> anyhow::Result<TeaclaveFunctionArguments> {
+) -> anyhow::Result<FunctionArguments> {
let unified_args = match executor_type {
TeaclaveExecutorSelector::Native => {
anyhow::ensure!(
@@ -156,10 +156,10 @@ fn prepare_arguments(
"Python function payload must not be empty!"
);
let mut wrap_args = HashMap::new();
- let req_args = serde_json::to_string(&function_args.args)?;
+ let req_args = serde_json::to_string(&function_args)?;
wrap_args.insert("py_payload".to_string(), function_payload);
wrap_args.insert("py_args".to_string(), req_args);
- TeaclaveFunctionArguments { args: wrap_args }
+ FunctionArguments::from_map(&wrap_args)
}
};
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@teaclave.apache.org
For additional commands, e-mail: commits-help@teaclave.apache.org