You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@submarine.apache.org by pi...@apache.org on 2021/10/29 09:12:40 UTC
[submarine] branch master updated: SUBMARINE-1045. Add static type
parameter in submarine-sdk
This is an automated email from the ASF dual-hosted git repository.
pingsutw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push:
new 1fb6652 SUBMARINE-1045. Add static type parameter in submarine-sdk
1fb6652 is described below
commit 1fb665277a55f76a37aa23244b150fcd42683d9d
Author: rayray2002 <ra...@gmail.com>
AuthorDate: Mon Oct 25 15:14:41 2021 +0800
SUBMARINE-1045. Add static type parameter in submarine-sdk
### What is this PR for?
<!-- A few sentences describing the overall goals of the pull request's commits.
First time? Check out the contributing guide - https://submarine.apache.org/contribution/contributions.html
-->
Add static type annotations (parameter and return types) throughout submarine-sdk.
### What type of PR is it?
[Improvement]
### Todos
### What is the Jira issue?
<!-- * Open an issue on Jira https://issues.apache.org/jira/browse/SUBMARINE/
* Put link here, and add [SUBMARINE-*Jira number*] in PR title, eg. `SUBMARINE-23. PR title`
-->
https://issues.apache.org/jira/projects/SUBMARINE/issues/SUBMARINE-1045
### How should this be tested?
<!--
* First time? Setup Travis CI as described on https://submarine.apache.org/contribution/contributions.html#continuous-integration
* Strongly recommended: add automated unit tests for any new or changed behavior
* Outline any manual steps to test the PR here.
-->
- Ran mypy: passed.
- Ran a runtime type check.
### Screenshots (if appropriate)
### Questions:
* Do the license files need updating? Yes/No
* Are there breaking changes for older versions? Yes/No
* Does this need new documentation? Yes/No
Author: rayray2002 <ra...@gmail.com>
Signed-off-by: Kevin <pi...@apache.org>
Closes #781 from rayray2002/SUBMARINE-1045 and squashes the following commits:
a846b3cb [rayray2002] SUBMARINE-1045. Add static type parameter in submarine-sdk
53256dde [rayray2002] SUBMARINE-1045. Add static type parameter in submarine-sdk
---
.../pysubmarine/submarine/artifacts/repository.py | 8 ++++----
.../submarine/entities/_submarine_object.py | 8 ++++----
.../submarine/entities/model_registry/model_stages.py | 2 +-
.../submarine/entities/model_registry/model_version.py | 2 +-
.../entities/model_registry/registered_model.py | 2 +-
.../submarine/experiment/api/experiment_client.py | 6 +++---
.../submarine/experiment/models/code_spec.py | 16 +++++++++-------
submarine-sdk/pysubmarine/submarine/experiment/rest.py | 4 ++--
.../pysubmarine/submarine/ml/pytorch/layers/core.py | 14 +++++++-------
submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py | 2 +-
.../pysubmarine/submarine/ml/pytorch/metric.py | 2 +-
.../pysubmarine/submarine/ml/tensorflow/input/input.py | 8 ++++----
.../pysubmarine/submarine/ml/tensorflow/layers/core.py | 18 ++++++++++--------
.../pysubmarine/submarine/ml/tensorflow/optimizer.py | 2 +-
submarine-sdk/pysubmarine/submarine/models/client.py | 18 +++++++++---------
submarine-sdk/pysubmarine/submarine/models/pytorch.py | 2 +-
.../pysubmarine/submarine/models/tensorflow.py | 2 +-
.../submarine/store/tracking/sqlalchemy_store.py | 2 +-
submarine-sdk/pysubmarine/submarine/utils/__init__.py | 2 +-
submarine-sdk/pysubmarine/submarine/utils/db_utils.py | 6 +++---
submarine-sdk/pysubmarine/submarine/utils/env.py | 12 ++++++------
.../pysubmarine/submarine/utils/rest_utils.py | 4 ++--
.../pysubmarine/submarine/utils/validation.py | 14 +++++++-------
23 files changed, 80 insertions(+), 76 deletions(-)
diff --git a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
index 5a60648..9bee7ae 100644
--- a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
+++ b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
@@ -19,7 +19,7 @@ import boto3
class Repository:
- def __init__(self, experiment_id):
+ def __init__(self, experiment_id: str):
self.client = boto3.client(
"s3",
aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
@@ -28,10 +28,10 @@ class Repository:
)
self.dest_path = experiment_id
- def _upload_file(self, local_file, bucket, key):
+ def _upload_file(self, local_file: str, bucket: str, key: str) -> None:
self.client.upload_file(Filename=local_file, Bucket=bucket, Key=key)
- def _list_artifact_subfolder(self, artifact_path):
+ def _list_artifact_subfolder(self, artifact_path: str):
response = self.client.list_objects(
Bucket="submarine",
Prefix=os.path.join(self.dest_path, artifact_path) + "/",
@@ -39,7 +39,7 @@ class Repository:
)
return response.get("CommonPrefixes")
- def log_artifact(self, local_file, artifact_path):
+ def log_artifact(self, local_file: str, artifact_path: str) -> None:
bucket = "submarine"
dest_path = self.dest_path
dest_path = os.path.join(dest_path, artifact_path)
diff --git a/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py b/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py
index ffcea7f..db2ad09 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py
@@ -31,11 +31,11 @@ class _SubmarineObject:
filtered_dict = {key: value for key, value in the_dict.items() if key in cls._properties()}
return cls(**filtered_dict)
- def __repr__(self):
+ def __repr__(self) -> str:
return to_string(self)
-def to_string(obj):
+def to_string(obj) -> str:
return _SubmarineObjectPrinter().to_string(obj)
@@ -48,10 +48,10 @@ class _SubmarineObjectPrinter:
super(_SubmarineObjectPrinter, self).__init__()
self.printer = pprint.PrettyPrinter()
- def to_string(self, obj):
+ def to_string(self, obj) -> str:
if isinstance(obj, _SubmarineObject):
return "<%s: %s>" % (get_classname(obj), self._entity_to_string(obj))
return self.printer.pformat(obj)
- def _entity_to_string(self, entity):
+ def _entity_to_string(self, entity) -> str:
return ", ".join(["%s=%s" % (key, self.to_string(value)) for key, value in entity])
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py
index 4a5e565..3d3f556 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py
@@ -26,7 +26,7 @@ ALL_STAGES = [STAGE_NONE, STAGE_DEVELOPING, STAGE_PRODUCTION, STAGE_ARCHIVED]
_CANONICAL_MAPPING = {stage.lower(): stage for stage in ALL_STAGES}
-def get_canonical_stage(stage):
+def get_canonical_stage(stage: str) -> str:
key = stage.lower()
if key not in _CANONICAL_MAPPING:
raise SubmarineException(f"Invalid Model Version stage {stage}.")
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
index 86652b6..0b43e0a 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
@@ -98,6 +98,6 @@ class ModelVersion(_SubmarineObject):
return self._description
@property
- def tags(self):
+ def tags(self) -> list:
"""List of strings."""
return self._tags
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py
index b88ac22..f94c5a1 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py
@@ -49,6 +49,6 @@ class RegisteredModel(_SubmarineObject):
return self._description
@property
- def tags(self):
+ def tags(self) -> list:
"""List of strings"""
return self._tags
diff --git a/submarine-sdk/pysubmarine/submarine/experiment/api/experiment_client.py b/submarine-sdk/pysubmarine/submarine/experiment/api/experiment_client.py
index d659876..f54481e 100644
--- a/submarine-sdk/pysubmarine/submarine/experiment/api/experiment_client.py
+++ b/submarine-sdk/pysubmarine/submarine/experiment/api/experiment_client.py
@@ -38,7 +38,7 @@ def generate_host():
class ExperimentClient:
- def __init__(self, host=generate_host()):
+ def __init__(self, host: str = generate_host()):
"""
Submarine experiment client constructor
:param host: An HTTP URI like http://submarine-server:8080.
@@ -59,7 +59,7 @@ class ExperimentClient:
response = self.experiment_api.create_experiment(experiment_spec=experiment_spec)
return response.result
- def wait_for_finish(self, id, polling_interval=10):
+ def wait_for_finish(self, id, polling_interval: float = 10):
"""
Waits until experiment is finished or failed
:param id: submarine experiment id
@@ -75,7 +75,7 @@ class ExperimentClient:
index = self._log_pod(id, index)
time.sleep(polling_interval)
- def _log_pod(self, id, index):
+ def _log_pod(self, id, index: int):
response = self.experiment_api.get_log(id)
log_contents = response.result["logContent"]
if len(log_contents) == 0:
diff --git a/submarine-sdk/pysubmarine/submarine/experiment/models/code_spec.py b/submarine-sdk/pysubmarine/submarine/experiment/models/code_spec.py
index b319f90..06e5f0f 100644
--- a/submarine-sdk/pysubmarine/submarine/experiment/models/code_spec.py
+++ b/submarine-sdk/pysubmarine/submarine/experiment/models/code_spec.py
@@ -52,7 +52,9 @@ class CodeSpec(object):
attribute_map = {"sync_mode": "syncMode", "url": "url"}
- def __init__(self, sync_mode=None, url=None, local_vars_configuration=None): # noqa: E501
+ def __init__(
+ self, sync_mode: str = None, url: str = None, local_vars_configuration: Configuration = None
+ ): # noqa: E501
"""CodeSpec - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
@@ -78,7 +80,7 @@ class CodeSpec(object):
return self._sync_mode
@sync_mode.setter
- def sync_mode(self, sync_mode):
+ def sync_mode(self, sync_mode: str) -> None:
"""Sets the sync_mode of this CodeSpec.
@@ -99,7 +101,7 @@ class CodeSpec(object):
return self._url
@url.setter
- def url(self, url):
+ def url(self, url: str) -> None:
"""Sets the url of this CodeSpec.
@@ -135,22 +137,22 @@ class CodeSpec(object):
return result
- def to_str(self):
+ def to_str(self) -> str:
"""Returns the string representation of the model"""
return pprint.pformat(self.to_dict())
- def __repr__(self):
+ def __repr__(self) -> str:
"""For `print` and `pprint`"""
return self.to_str()
- def __eq__(self, other):
+ def __eq__(self, other) -> bool:
"""Returns true if both objects are equal"""
if not isinstance(other, CodeSpec):
return False
return self.to_dict() == other.to_dict()
- def __ne__(self, other):
+ def __ne__(self, other) -> bool:
"""Returns true if both objects are not equal"""
if not isinstance(other, CodeSpec):
return True
diff --git a/submarine-sdk/pysubmarine/submarine/experiment/rest.py b/submarine-sdk/pysubmarine/submarine/experiment/rest.py
index 06b8d93..18fd44d 100644
--- a/submarine-sdk/pysubmarine/submarine/experiment/rest.py
+++ b/submarine-sdk/pysubmarine/submarine/experiment/rest.py
@@ -55,13 +55,13 @@ class RESTResponse(io.IOBase):
"""Returns a dictionary of the response headers."""
return self.urllib3_response.getheaders()
- def getheader(self, name, default=None):
+ def getheader(self, name: str, default=None):
"""Returns a given response header."""
return self.urllib3_response.getheader(name, default)
class RESTClientObject(object):
- def __init__(self, configuration, pools_size=4, maxsize=None):
+ def __init__(self, configuration, pools_size: int = 4, maxsize: int = None):
# urllib3.PoolManager will pass all kw parameters to connectionpool
# https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/poolmanager.py#L75 # noqa: E501
# https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/connectionpool.py#L680 # noqa: E501
diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py
index 265ea1f..fd7d9e9 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py
@@ -19,7 +19,7 @@ from torch import nn
# pylint: disable=W0223
class FeatureLinear(nn.Module):
- def __init__(self, num_features, out_features):
+ def __init__(self, num_features: int, out_features: int):
"""
:param num_features: number of total features.
:param out_features: The number of output features.
@@ -28,7 +28,7 @@ class FeatureLinear(nn.Module):
self.weight = nn.Embedding(num_embeddings=num_features, embedding_dim=out_features)
self.bias = nn.Parameter(torch.zeros((out_features,)))
- def forward(self, feature_idx, feature_value):
+ def forward(self, feature_idx: torch.LongTensor, feature_value: torch.LongTensor):
"""
:param feature_idx: torch.LongTensor (batch_size, num_fields)
:param feature_value: torch.LongTensor (batch_size, num_fields)
@@ -39,11 +39,11 @@ class FeatureLinear(nn.Module):
class FeatureEmbedding(nn.Module):
- def __init__(self, num_features, embedding_dim):
+ def __init__(self, num_features: int, embedding_dim):
super().__init__()
self.weight = nn.Embedding(num_embeddings=num_features, embedding_dim=embedding_dim)
- def forward(self, feature_idx, feature_value):
+ def forward(self, feature_idx: torch.LongTensor, feature_value: torch.LongTensor):
"""
:param feature_idx: torch.LongTensor (batch_size, num_fields)
:param feature_value: torch.LongTensor (batch_size, num_fields)
@@ -52,7 +52,7 @@ class FeatureEmbedding(nn.Module):
class PairwiseInteraction(nn.Module):
- def forward(self, x):
+ def forward(self, x: torch.Tensor):
"""
:param x: torch.Tensor (batch_size, num_fields, embedding_dim)
"""
@@ -65,7 +65,7 @@ class PairwiseInteraction(nn.Module):
class DNN(nn.Module):
- def __init__(self, in_features, out_features, hidden_units, dropout_rates):
+ def __init__(self, in_features: int, out_features: int, hidden_units, dropout_rates):
super().__init__()
*layers, out_layer = list(zip([in_features, *hidden_units], [*hidden_units, out_features]))
self.net = nn.Sequential(
@@ -81,7 +81,7 @@ class DNN(nn.Module):
nn.Linear(*out_layer)
)
- def forward(self, x):
+ def forward(self, x: torch.FloatTensor):
"""
:param x: torch.FloatTensor (batch_size, in_features)
"""
diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py
index 234c237..952863f 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py
@@ -23,7 +23,7 @@ class LossKey:
BCEWithLogitsLoss = "BCEWithLogitsLoss".lower()
-def get_loss_fn(key):
+def get_loss_fn(key: str):
key = key.lower()
if key == LossKey.BCELoss:
return nn.BCELoss
diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/metric.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/metric.py
index 43f3d26..0c838fd 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/metric.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/metric.py
@@ -24,7 +24,7 @@ class MetricKey:
RECALL = "recall"
-def get_metric_fn(key):
+def get_metric_fn(key: str):
key = key.lower()
if key == MetricKey.F1_SCORE:
return metrics.f1_score
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py
index f779fb6..0cc7259 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py
@@ -24,10 +24,10 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE
def libsvm_input_fn(
filepath,
- batch_size=256,
- num_epochs=3, # pylint: disable=W0613
- perform_shuffle=False,
- delimiter=" ",
+ batch_size: int = 256,
+ num_epochs: int = 3, # pylint: disable=W0613
+ perform_shuffle: bool = False,
+ delimiter: str = " ",
**kwargs
):
def _input_fn():
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
index dbe048f..47f3698 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
@@ -43,12 +43,12 @@ def batch_norm_layer(x, train_phase, scope_bn, batch_norm_decay):
def dnn_layer(
inputs,
- estimator_mode,
- batch_norm,
- deep_layers,
+ estimator_mode: str,
+ batch_norm: bool,
+ deep_layers: list,
dropout,
- batch_norm_decay=0.9,
- l2_reg=0,
+ batch_norm_decay: float = 0.9,
+ l2_reg: float = 0,
**kwargs
):
"""
@@ -100,7 +100,7 @@ def dnn_layer(
return deep_out
-def linear_layer(features, feature_size, field_size, l2_reg=0, **kwargs):
+def linear_layer(features, feature_size, field_size, l2_reg: float = 0, **kwargs):
"""
Layer which represents linear function.
:param features: input features
@@ -131,7 +131,9 @@ def linear_layer(features, feature_size, field_size, l2_reg=0, **kwargs):
return linear_out
-def embedding_layer(features, feature_size, field_size, embedding_size, l2_reg=0, **kwargs):
+def embedding_layer(
+ features, feature_size, field_size, embedding_size, l2_reg: float = 0, **kwargs
+):
"""
Turns positive integers (indexes) into dense vectors of fixed size.
eg. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]
@@ -199,7 +201,7 @@ class KMaxPooling(Layer):
- **axis**: positive integer, the dimension to look for elements.
"""
- def __init__(self, k=1, axis=-1, **kwargs):
+ def __init__(self, k: int = 1, axis: int = -1, **kwargs):
self.dims = 1
self.k = k
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py
index dd61d6e..3ab9bbb 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py
@@ -29,7 +29,7 @@ class OptimizerKey(object):
FTRL = "ftrl"
-def get_optimizer(optimizer_key, learning_rate):
+def get_optimizer(optimizer_key: str, learning_rate: float):
optimizer_key = optimizer_key.lower()
if optimizer_key == OptimizerKey.ADAM:
diff --git a/submarine-sdk/pysubmarine/submarine/models/client.py b/submarine-sdk/pysubmarine/submarine/models/client.py
index e633188..da13082 100644
--- a/submarine-sdk/pysubmarine/submarine/models/client.py
+++ b/submarine-sdk/pysubmarine/submarine/models/client.py
@@ -33,10 +33,10 @@ from .utils import exist_ps, get_job_id, get_worker_index
class ModelsClient:
def __init__(
self,
- tracking_uri=None,
- registry_uri=None,
- aws_access_key_id=None,
- aws_secret_access_key=None,
+ tracking_uri: str = None,
+ registry_uri: str = None,
+ aws_access_key_id: str = None,
+ aws_secret_access_key: str = None,
):
"""
Set up mlflow server connection, including: s3 endpoint, aws, tracking server
@@ -69,26 +69,26 @@ class ModelsClient:
experiment_id = self._get_or_create_experiment(experiment_name)
return mlflow.start_run(run_name=run_name, experiment_id=experiment_id)
- def log_param(self, key, value):
+ def log_param(self, key: str, value: str):
mlflow.log_param(key, value)
def log_params(self, params):
mlflow.log_params(params)
- def log_metric(self, key, value, step=None):
+ def log_metric(self, key: str, value: str, step=None):
mlflow.log_metric(key, value, step)
def log_metrics(self, metrics, step=None):
mlflow.log_metrics(metrics, step)
- def load_model(self, name, version):
+ def load_model(self, name: str, version: str):
model = mlflow.pyfunc.load_model(model_uri=f"models:/{name}/{version}")
return model
- def update_model(self, name, new_name):
+ def update_model(self, name: str, new_name: str):
self.client.rename_registered_model(name=name, new_name=new_name)
- def delete_model(self, name, version):
+ def delete_model(self, name: str, version: str):
self.client.delete_model_version(name=name, version=version)
def save_model(self, model_type, model, artifact_path, registered_model_name=None):
diff --git a/submarine-sdk/pysubmarine/submarine/models/pytorch.py b/submarine-sdk/pysubmarine/submarine/models/pytorch.py
index a143aa5..38cdd57 100644
--- a/submarine-sdk/pysubmarine/submarine/models/pytorch.py
+++ b/submarine-sdk/pysubmarine/submarine/models/pytorch.py
@@ -18,5 +18,5 @@ import os
import torch
-def save_model(model, artifact_path):
+def save_model(model, artifact_path: str):
torch.save(model, os.path.join(artifact_path, "model.pth"))
diff --git a/submarine-sdk/pysubmarine/submarine/models/tensorflow.py b/submarine-sdk/pysubmarine/submarine/models/tensorflow.py
index fbe5324..a947a91 100644
--- a/submarine-sdk/pysubmarine/submarine/models/tensorflow.py
+++ b/submarine-sdk/pysubmarine/submarine/models/tensorflow.py
@@ -14,5 +14,5 @@
# limitations under the License.
-def save_model(model, artifact_path):
+def save_model(model, artifact_path: str):
model.save(artifact_path)
diff --git a/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py b/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py
index 23e8a8d..e7e1e3a 100644
--- a/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py
@@ -103,7 +103,7 @@ class SqlAlchemyStore(AbstractStore):
return make_managed_session
@staticmethod
- def _save_to_db(session, objs):
+ def _save_to_db(session, objs: object):
"""
Store in db
"""
diff --git a/submarine-sdk/pysubmarine/submarine/utils/__init__.py b/submarine-sdk/pysubmarine/submarine/utils/__init__.py
index 4908ba6..1b0f045 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/__init__.py
@@ -19,7 +19,7 @@ from submarine.exceptions import SubmarineException
from submarine.utils.db_utils import get_db_uri, set_db_uri
-def extract_db_type_from_uri(db_uri):
+def extract_db_type_from_uri(db_uri: str):
"""
Parse the specified DB URI to extract the database type. Confirm the database type is
supported. If a driver is specified, confirm it passes a plausible regex.
diff --git a/submarine-sdk/pysubmarine/submarine/utils/db_utils.py b/submarine-sdk/pysubmarine/submarine/utils/db_utils.py
index b23ce2d..8fcc9a5 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/db_utils.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/db_utils.py
@@ -22,14 +22,14 @@ _DB_URI_ENV_VAR = "SUBMARINE_DB_URI"
_db_uri = None
-def is_db_uri_set():
+def is_db_uri_set() -> bool:
"""Returns True if the DB URI has been set, False otherwise."""
if _db_uri or env.get_env(_DB_URI_ENV_VAR):
return True
return False
-def set_db_uri(uri):
+def set_db_uri(uri: str):
"""
Set the DB URI. This does not affect the currently active run (if one exists),
but takes effect for successive runs.
@@ -38,7 +38,7 @@ def set_db_uri(uri):
_db_uri = uri
-def get_db_uri():
+def get_db_uri() -> str:
"""
Get the current DB URI.
:return: The DB URI.
diff --git a/submarine-sdk/pysubmarine/submarine/utils/env.py b/submarine-sdk/pysubmarine/submarine/utils/env.py
index 3797efc..110a134 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/env.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/env.py
@@ -19,22 +19,22 @@ import os
from collections.abc import Mapping
-def get_env(variable_name):
+def get_env(variable_name: str):
return os.environ.get(variable_name)
-def unset_variable(variable_name):
+def unset_variable(variable_name: str) -> None:
if variable_name in os.environ:
del os.environ[variable_name]
-def check_env_exists(variable_name):
+def check_env_exists(variable_name: str) -> bool:
if variable_name not in os.environ:
return False
return True
-def get_from_json(path, defaultParams):
+def get_from_json(path: str, defaultParams: dict):
"""
If model parameters not specify in Json, use parameter in defaultParams
:param path: The json file that specifies the model parameters.
@@ -50,7 +50,7 @@ def get_from_json(path, defaultParams):
return get_from_dicts(params, defaultParams)
-def get_from_dicts(params, defaultParams):
+def get_from_dicts(params: dict, defaultParams: dict):
"""
If model parameters not specify in params, use parameter in defaultParams
:param params: parameters which will be merged
@@ -71,7 +71,7 @@ def get_from_dicts(params, defaultParams):
return dct
-def get_from_registry(key, registry):
+def get_from_registry(key: str, registry: dict):
if hasattr(key, "lower"):
key = key.lower()
if key in registry:
diff --git a/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py b/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py
index db71b13..5054af0 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py
@@ -51,7 +51,7 @@ def http_request(base_url, endpoint, method, json_body, timeout=60, headers=None
return result
-def _can_parse_as_json(string):
+def _can_parse_as_json(string: str) -> bool:
try:
json.loads(string)
return True
@@ -59,7 +59,7 @@ def _can_parse_as_json(string):
return False
-def verify_rest_response(response, endpoint):
+def verify_rest_response(response, endpoint: str):
"""Verify the return code and raise exception if the request was not successful."""
if response.status_code != 200:
if _can_parse_as_json(response.text):
diff --git a/submarine-sdk/pysubmarine/submarine/utils/validation.py b/submarine-sdk/pysubmarine/submarine/utils/validation.py
index 00e2c98..049a873 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/validation.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/validation.py
@@ -37,7 +37,7 @@ _BAD_CHARACTERS_MESSAGE = (
_UNSUPPORTED_DB_TYPE_MSG = "Supported database engines are {%s}" % ", ".join(DATABASE_ENGINES)
-def bad_path_message(name):
+def bad_path_message(name: str):
return (
"Names may be treated as files in certain cases, and must not resolve to other names"
" when treated as such. This name would resolve to '%s'"
@@ -45,12 +45,12 @@ def bad_path_message(name):
)
-def path_not_unique(name):
+def path_not_unique(name: str):
norm = posixpath.normpath(name)
return norm != name or norm == "." or norm.startswith("..") or norm.startswith("/")
-def _validate_param_name(name):
+def _validate_param_name(name: str):
"""Check that `name` is a valid parameter name and raise an exception if it isn't."""
if not _VALID_PARAM_AND_METRIC_NAMES.match(name):
raise SubmarineException(
@@ -63,7 +63,7 @@ def _validate_param_name(name):
)
-def _validate_metric_name(name):
+def _validate_metric_name(name: str):
"""Check that `name` is a valid metric name and raise an exception if it isn't."""
if not _VALID_PARAM_AND_METRIC_NAMES.match(name):
raise SubmarineException(
@@ -74,7 +74,7 @@ def _validate_metric_name(name):
raise SubmarineException("Invalid metric name: '%s'. %s" % (name, bad_path_message(name)))
-def _validate_length_limit(entity_name, limit, value):
+def _validate_length_limit(entity_name: str, limit: int, value):
if len(value) > limit:
raise SubmarineException(
"%s '%s' had length %s, which exceeded length limit of %s"
@@ -82,7 +82,7 @@ def _validate_length_limit(entity_name, limit, value):
)
-def validate_metric(key, value, timestamp, step):
+def validate_metric(key, value, timestamp, step) -> None:
"""
Check that a param with the specified key, value, timestamp is valid and raise an exception if
it isn't.
@@ -107,7 +107,7 @@ def validate_metric(key, value, timestamp, step):
)
-def validate_param(key, value):
+def validate_param(key, value) -> None:
"""
Check that a param with the specified key & value is valid and raise an exception if it
isn't.
---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org