You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@submarine.apache.org by pi...@apache.org on 2021/10/29 09:12:40 UTC

[submarine] branch master updated: SUBMARINE-1045. Add static type parameter in submarine-sdk

This is an automated email from the ASF dual-hosted git repository.

pingsutw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git


The following commit(s) were added to refs/heads/master by this push:
     new 1fb6652  SUBMARINE-1045. Add static type parameter in submarine-sdk
1fb6652 is described below

commit 1fb665277a55f76a37aa23244b150fcd42683d9d
Author: rayray2002 <ra...@gmail.com>
AuthorDate: Mon Oct 25 15:14:41 2021 +0800

    SUBMARINE-1045. Add static type parameter in submarine-sdk
    
    ### What is this PR for?
    <!-- A few sentences describing the overall goals of the pull request's commits.
    First time? Check out the contributing guide - https://submarine.apache.org/contribution/contributions.html
    -->
    Add static type annotations (type hints) to function parameters and return values in submarine-sdk
    
    ### What type of PR is it?
    [Improvement]
    
    ### Todos
    
    ### What is the Jira issue?
    <!-- * Open an issue on Jira https://issues.apache.org/jira/browse/SUBMARINE/
    * Put link here, and add [SUBMARINE-*Jira number*] in PR title, eg. `SUBMARINE-23. PR title`
    -->
    https://issues.apache.org/jira/projects/SUBMARINE/issues/SUBMARINE-1045
    
    ### How should this be tested?
    <!--
    * First time? Setup Travis CI as described on https://submarine.apache.org/contribution/contributions.html#continuous-integration
    * Strongly recommended: add automated unit tests for any new or changed behavior
    * Outline any manual steps to test the PR here.
    -->
    - Static type check with mypy passes
    - Runtime type check performed
    
    ### Screenshots (if appropriate)
    
    ### Questions:
    * Do the license files need updating? Yes/No
    * Are there breaking changes for older versions? Yes/No
    * Does this need new documentation? Yes/No
    
    Author: rayray2002 <ra...@gmail.com>
    
    Signed-off-by: Kevin <pi...@apache.org>
    
    Closes #781 from rayray2002/SUBMARINE-1045 and squashes the following commits:
    
    a846b3cb [rayray2002] SUBMARINE-1045. Add static type parameter in submarine-sdk
    53256dde [rayray2002] SUBMARINE-1045. Add static type parameter in submarine-sdk
---
 .../pysubmarine/submarine/artifacts/repository.py      |  8 ++++----
 .../submarine/entities/_submarine_object.py            |  8 ++++----
 .../submarine/entities/model_registry/model_stages.py  |  2 +-
 .../submarine/entities/model_registry/model_version.py |  2 +-
 .../entities/model_registry/registered_model.py        |  2 +-
 .../submarine/experiment/api/experiment_client.py      |  6 +++---
 .../submarine/experiment/models/code_spec.py           | 16 +++++++++-------
 submarine-sdk/pysubmarine/submarine/experiment/rest.py |  4 ++--
 .../pysubmarine/submarine/ml/pytorch/layers/core.py    | 14 +++++++-------
 submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py |  2 +-
 .../pysubmarine/submarine/ml/pytorch/metric.py         |  2 +-
 .../pysubmarine/submarine/ml/tensorflow/input/input.py |  8 ++++----
 .../pysubmarine/submarine/ml/tensorflow/layers/core.py | 18 ++++++++++--------
 .../pysubmarine/submarine/ml/tensorflow/optimizer.py   |  2 +-
 submarine-sdk/pysubmarine/submarine/models/client.py   | 18 +++++++++---------
 submarine-sdk/pysubmarine/submarine/models/pytorch.py  |  2 +-
 .../pysubmarine/submarine/models/tensorflow.py         |  2 +-
 .../submarine/store/tracking/sqlalchemy_store.py       |  2 +-
 submarine-sdk/pysubmarine/submarine/utils/__init__.py  |  2 +-
 submarine-sdk/pysubmarine/submarine/utils/db_utils.py  |  6 +++---
 submarine-sdk/pysubmarine/submarine/utils/env.py       | 12 ++++++------
 .../pysubmarine/submarine/utils/rest_utils.py          |  4 ++--
 .../pysubmarine/submarine/utils/validation.py          | 14 +++++++-------
 23 files changed, 80 insertions(+), 76 deletions(-)

diff --git a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
index 5a60648..9bee7ae 100644
--- a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
+++ b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
@@ -19,7 +19,7 @@ import boto3
 
 
 class Repository:
-    def __init__(self, experiment_id):
+    def __init__(self, experiment_id: str):
         self.client = boto3.client(
             "s3",
             aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
@@ -28,10 +28,10 @@ class Repository:
         )
         self.dest_path = experiment_id
 
-    def _upload_file(self, local_file, bucket, key):
+    def _upload_file(self, local_file: str, bucket: str, key: str) -> None:
         self.client.upload_file(Filename=local_file, Bucket=bucket, Key=key)
 
-    def _list_artifact_subfolder(self, artifact_path):
+    def _list_artifact_subfolder(self, artifact_path: str):
         response = self.client.list_objects(
             Bucket="submarine",
             Prefix=os.path.join(self.dest_path, artifact_path) + "/",
@@ -39,7 +39,7 @@ class Repository:
         )
         return response.get("CommonPrefixes")
 
-    def log_artifact(self, local_file, artifact_path):
+    def log_artifact(self, local_file: str, artifact_path: str) -> None:
         bucket = "submarine"
         dest_path = self.dest_path
         dest_path = os.path.join(dest_path, artifact_path)
diff --git a/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py b/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py
index ffcea7f..db2ad09 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py
@@ -31,11 +31,11 @@ class _SubmarineObject:
         filtered_dict = {key: value for key, value in the_dict.items() if key in cls._properties()}
         return cls(**filtered_dict)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return to_string(self)
 
 
-def to_string(obj):
+def to_string(obj) -> str:
     return _SubmarineObjectPrinter().to_string(obj)
 
 
@@ -48,10 +48,10 @@ class _SubmarineObjectPrinter:
         super(_SubmarineObjectPrinter, self).__init__()
         self.printer = pprint.PrettyPrinter()
 
-    def to_string(self, obj):
+    def to_string(self, obj) -> str:
         if isinstance(obj, _SubmarineObject):
             return "<%s: %s>" % (get_classname(obj), self._entity_to_string(obj))
         return self.printer.pformat(obj)
 
-    def _entity_to_string(self, entity):
+    def _entity_to_string(self, entity) -> str:
         return ", ".join(["%s=%s" % (key, self.to_string(value)) for key, value in entity])
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py
index 4a5e565..3d3f556 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py
@@ -26,7 +26,7 @@ ALL_STAGES = [STAGE_NONE, STAGE_DEVELOPING, STAGE_PRODUCTION, STAGE_ARCHIVED]
 _CANONICAL_MAPPING = {stage.lower(): stage for stage in ALL_STAGES}
 
 
-def get_canonical_stage(stage):
+def get_canonical_stage(stage: str) -> str:
     key = stage.lower()
     if key not in _CANONICAL_MAPPING:
         raise SubmarineException(f"Invalid Model Version stage {stage}.")
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
index 86652b6..0b43e0a 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
@@ -98,6 +98,6 @@ class ModelVersion(_SubmarineObject):
         return self._description
 
     @property
-    def tags(self):
+    def tags(self) -> list:
         """List of strings."""
         return self._tags
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py
index b88ac22..f94c5a1 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py
@@ -49,6 +49,6 @@ class RegisteredModel(_SubmarineObject):
         return self._description
 
     @property
-    def tags(self):
+    def tags(self) -> list:
         """List of strings"""
         return self._tags
diff --git a/submarine-sdk/pysubmarine/submarine/experiment/api/experiment_client.py b/submarine-sdk/pysubmarine/submarine/experiment/api/experiment_client.py
index d659876..f54481e 100644
--- a/submarine-sdk/pysubmarine/submarine/experiment/api/experiment_client.py
+++ b/submarine-sdk/pysubmarine/submarine/experiment/api/experiment_client.py
@@ -38,7 +38,7 @@ def generate_host():
 
 
 class ExperimentClient:
-    def __init__(self, host=generate_host()):
+    def __init__(self, host: str = generate_host()):
         """
         Submarine experiment client constructor
         :param host: An HTTP URI like http://submarine-server:8080.
@@ -59,7 +59,7 @@ class ExperimentClient:
         response = self.experiment_api.create_experiment(experiment_spec=experiment_spec)
         return response.result
 
-    def wait_for_finish(self, id, polling_interval=10):
+    def wait_for_finish(self, id, polling_interval: float = 10):
         """
         Waits until experiment is finished or failed
         :param id: submarine experiment id
@@ -75,7 +75,7 @@ class ExperimentClient:
             index = self._log_pod(id, index)
             time.sleep(polling_interval)
 
-    def _log_pod(self, id, index):
+    def _log_pod(self, id, index: int):
         response = self.experiment_api.get_log(id)
         log_contents = response.result["logContent"]
         if len(log_contents) == 0:
diff --git a/submarine-sdk/pysubmarine/submarine/experiment/models/code_spec.py b/submarine-sdk/pysubmarine/submarine/experiment/models/code_spec.py
index b319f90..06e5f0f 100644
--- a/submarine-sdk/pysubmarine/submarine/experiment/models/code_spec.py
+++ b/submarine-sdk/pysubmarine/submarine/experiment/models/code_spec.py
@@ -52,7 +52,9 @@ class CodeSpec(object):
 
     attribute_map = {"sync_mode": "syncMode", "url": "url"}
 
-    def __init__(self, sync_mode=None, url=None, local_vars_configuration=None):  # noqa: E501
+    def __init__(
+        self, sync_mode: str = None, url: str = None, local_vars_configuration: Configuration = None
+    ):  # noqa: E501
         """CodeSpec - a model defined in OpenAPI"""  # noqa: E501
         if local_vars_configuration is None:
             local_vars_configuration = Configuration()
@@ -78,7 +80,7 @@ class CodeSpec(object):
         return self._sync_mode
 
     @sync_mode.setter
-    def sync_mode(self, sync_mode):
+    def sync_mode(self, sync_mode: str) -> None:
         """Sets the sync_mode of this CodeSpec.
 
 
@@ -99,7 +101,7 @@ class CodeSpec(object):
         return self._url
 
     @url.setter
-    def url(self, url):
+    def url(self, url: str) -> None:
         """Sets the url of this CodeSpec.
 
 
@@ -135,22 +137,22 @@ class CodeSpec(object):
 
         return result
 
-    def to_str(self):
+    def to_str(self) -> str:
         """Returns the string representation of the model"""
         return pprint.pformat(self.to_dict())
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """For `print` and `pprint`"""
         return self.to_str()
 
-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         """Returns true if both objects are equal"""
         if not isinstance(other, CodeSpec):
             return False
 
         return self.to_dict() == other.to_dict()
 
-    def __ne__(self, other):
+    def __ne__(self, other) -> bool:
         """Returns true if both objects are not equal"""
         if not isinstance(other, CodeSpec):
             return True
diff --git a/submarine-sdk/pysubmarine/submarine/experiment/rest.py b/submarine-sdk/pysubmarine/submarine/experiment/rest.py
index 06b8d93..18fd44d 100644
--- a/submarine-sdk/pysubmarine/submarine/experiment/rest.py
+++ b/submarine-sdk/pysubmarine/submarine/experiment/rest.py
@@ -55,13 +55,13 @@ class RESTResponse(io.IOBase):
         """Returns a dictionary of the response headers."""
         return self.urllib3_response.getheaders()
 
-    def getheader(self, name, default=None):
+    def getheader(self, name: str, default=None):
         """Returns a given response header."""
         return self.urllib3_response.getheader(name, default)
 
 
 class RESTClientObject(object):
-    def __init__(self, configuration, pools_size=4, maxsize=None):
+    def __init__(self, configuration, pools_size: int = 4, maxsize: int = None):
         # urllib3.PoolManager will pass all kw parameters to connectionpool
         # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/poolmanager.py#L75  # noqa: E501
         # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/connectionpool.py#L680  # noqa: E501
diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py
index 265ea1f..fd7d9e9 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py
@@ -19,7 +19,7 @@ from torch import nn
 
 # pylint: disable=W0223
 class FeatureLinear(nn.Module):
-    def __init__(self, num_features, out_features):
+    def __init__(self, num_features: int, out_features: int):
         """
         :param num_features: number of total features.
         :param out_features: The number of output features.
@@ -28,7 +28,7 @@ class FeatureLinear(nn.Module):
         self.weight = nn.Embedding(num_embeddings=num_features, embedding_dim=out_features)
         self.bias = nn.Parameter(torch.zeros((out_features,)))
 
-    def forward(self, feature_idx, feature_value):
+    def forward(self, feature_idx: torch.LongTensor, feature_value: torch.LongTensor):
         """
         :param feature_idx: torch.LongTensor (batch_size, num_fields)
         :param feature_value: torch.LongTensor (batch_size, num_fields)
@@ -39,11 +39,11 @@ class FeatureLinear(nn.Module):
 
 
 class FeatureEmbedding(nn.Module):
-    def __init__(self, num_features, embedding_dim):
+    def __init__(self, num_features: int, embedding_dim):
         super().__init__()
         self.weight = nn.Embedding(num_embeddings=num_features, embedding_dim=embedding_dim)
 
-    def forward(self, feature_idx, feature_value):
+    def forward(self, feature_idx: torch.LongTensor, feature_value: torch.LongTensor):
         """
         :param feature_idx: torch.LongTensor (batch_size, num_fields)
         :param feature_value: torch.LongTensor (batch_size, num_fields)
@@ -52,7 +52,7 @@ class FeatureEmbedding(nn.Module):
 
 
 class PairwiseInteraction(nn.Module):
-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         """
         :param x: torch.Tensor (batch_size, num_fields, embedding_dim)
         """
@@ -65,7 +65,7 @@ class PairwiseInteraction(nn.Module):
 
 
 class DNN(nn.Module):
-    def __init__(self, in_features, out_features, hidden_units, dropout_rates):
+    def __init__(self, in_features: int, out_features: int, hidden_units, dropout_rates):
         super().__init__()
         *layers, out_layer = list(zip([in_features, *hidden_units], [*hidden_units, out_features]))
         self.net = nn.Sequential(
@@ -81,7 +81,7 @@ class DNN(nn.Module):
             nn.Linear(*out_layer)
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.FloatTensor):
         """
         :param x: torch.FloatTensor (batch_size, in_features)
         """
diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py
index 234c237..952863f 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/loss.py
@@ -23,7 +23,7 @@ class LossKey:
     BCEWithLogitsLoss = "BCEWithLogitsLoss".lower()
 
 
-def get_loss_fn(key):
+def get_loss_fn(key: str):
     key = key.lower()
     if key == LossKey.BCELoss:
         return nn.BCELoss
diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/metric.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/metric.py
index 43f3d26..0c838fd 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/metric.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/metric.py
@@ -24,7 +24,7 @@ class MetricKey:
     RECALL = "recall"
 
 
-def get_metric_fn(key):
+def get_metric_fn(key: str):
     key = key.lower()
     if key == MetricKey.F1_SCORE:
         return metrics.f1_score
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py
index f779fb6..0cc7259 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py
@@ -24,10 +24,10 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE
 
 def libsvm_input_fn(
     filepath,
-    batch_size=256,
-    num_epochs=3,  # pylint: disable=W0613
-    perform_shuffle=False,
-    delimiter=" ",
+    batch_size: int = 256,
+    num_epochs: int = 3,  # pylint: disable=W0613
+    perform_shuffle: bool = False,
+    delimiter: str = " ",
     **kwargs
 ):
     def _input_fn():
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
index dbe048f..47f3698 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
@@ -43,12 +43,12 @@ def batch_norm_layer(x, train_phase, scope_bn, batch_norm_decay):
 
 def dnn_layer(
     inputs,
-    estimator_mode,
-    batch_norm,
-    deep_layers,
+    estimator_mode: str,
+    batch_norm: bool,
+    deep_layers: list,
     dropout,
-    batch_norm_decay=0.9,
-    l2_reg=0,
+    batch_norm_decay: float = 0.9,
+    l2_reg: float = 0,
     **kwargs
 ):
     """
@@ -100,7 +100,7 @@ def dnn_layer(
     return deep_out
 
 
-def linear_layer(features, feature_size, field_size, l2_reg=0, **kwargs):
+def linear_layer(features, feature_size, field_size, l2_reg: float = 0, **kwargs):
     """
     Layer which represents linear function.
     :param features: input features
@@ -131,7 +131,9 @@ def linear_layer(features, feature_size, field_size, l2_reg=0, **kwargs):
     return linear_out
 
 
-def embedding_layer(features, feature_size, field_size, embedding_size, l2_reg=0, **kwargs):
+def embedding_layer(
+    features, feature_size, field_size, embedding_size, l2_reg: float = 0, **kwargs
+):
     """
     Turns positive integers (indexes) into dense vectors of fixed size.
     eg. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]
@@ -199,7 +201,7 @@ class KMaxPooling(Layer):
       - **axis**: positive integer, the dimension to look for elements.
     """
 
-    def __init__(self, k=1, axis=-1, **kwargs):
+    def __init__(self, k: int = 1, axis: int = -1, **kwargs):
 
         self.dims = 1
         self.k = k
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py
index dd61d6e..3ab9bbb 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py
@@ -29,7 +29,7 @@ class OptimizerKey(object):
     FTRL = "ftrl"
 
 
-def get_optimizer(optimizer_key, learning_rate):
+def get_optimizer(optimizer_key: str, learning_rate: float):
     optimizer_key = optimizer_key.lower()
 
     if optimizer_key == OptimizerKey.ADAM:
diff --git a/submarine-sdk/pysubmarine/submarine/models/client.py b/submarine-sdk/pysubmarine/submarine/models/client.py
index e633188..da13082 100644
--- a/submarine-sdk/pysubmarine/submarine/models/client.py
+++ b/submarine-sdk/pysubmarine/submarine/models/client.py
@@ -33,10 +33,10 @@ from .utils import exist_ps, get_job_id, get_worker_index
 class ModelsClient:
     def __init__(
         self,
-        tracking_uri=None,
-        registry_uri=None,
-        aws_access_key_id=None,
-        aws_secret_access_key=None,
+        tracking_uri: str = None,
+        registry_uri: str = None,
+        aws_access_key_id: str = None,
+        aws_secret_access_key: str = None,
     ):
         """
         Set up mlflow server connection, including: s3 endpoint, aws, tracking server
@@ -69,26 +69,26 @@ class ModelsClient:
         experiment_id = self._get_or_create_experiment(experiment_name)
         return mlflow.start_run(run_name=run_name, experiment_id=experiment_id)
 
-    def log_param(self, key, value):
+    def log_param(self, key: str, value: str):
         mlflow.log_param(key, value)
 
     def log_params(self, params):
         mlflow.log_params(params)
 
-    def log_metric(self, key, value, step=None):
+    def log_metric(self, key: str, value: str, step=None):
         mlflow.log_metric(key, value, step)
 
     def log_metrics(self, metrics, step=None):
         mlflow.log_metrics(metrics, step)
 
-    def load_model(self, name, version):
+    def load_model(self, name: str, version: str):
         model = mlflow.pyfunc.load_model(model_uri=f"models:/{name}/{version}")
         return model
 
-    def update_model(self, name, new_name):
+    def update_model(self, name: str, new_name: str):
         self.client.rename_registered_model(name=name, new_name=new_name)
 
-    def delete_model(self, name, version):
+    def delete_model(self, name: str, version: str):
         self.client.delete_model_version(name=name, version=version)
 
     def save_model(self, model_type, model, artifact_path, registered_model_name=None):
diff --git a/submarine-sdk/pysubmarine/submarine/models/pytorch.py b/submarine-sdk/pysubmarine/submarine/models/pytorch.py
index a143aa5..38cdd57 100644
--- a/submarine-sdk/pysubmarine/submarine/models/pytorch.py
+++ b/submarine-sdk/pysubmarine/submarine/models/pytorch.py
@@ -18,5 +18,5 @@ import os
 import torch
 
 
-def save_model(model, artifact_path):
+def save_model(model, artifact_path: str):
     torch.save(model, os.path.join(artifact_path, "model.pth"))
diff --git a/submarine-sdk/pysubmarine/submarine/models/tensorflow.py b/submarine-sdk/pysubmarine/submarine/models/tensorflow.py
index fbe5324..a947a91 100644
--- a/submarine-sdk/pysubmarine/submarine/models/tensorflow.py
+++ b/submarine-sdk/pysubmarine/submarine/models/tensorflow.py
@@ -14,5 +14,5 @@
 # limitations under the License.
 
 
-def save_model(model, artifact_path):
+def save_model(model, artifact_path: str):
     model.save(artifact_path)
diff --git a/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py b/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py
index 23e8a8d..e7e1e3a 100644
--- a/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py
@@ -103,7 +103,7 @@ class SqlAlchemyStore(AbstractStore):
         return make_managed_session
 
     @staticmethod
-    def _save_to_db(session, objs):
+    def _save_to_db(session, objs: object):
         """
         Store in db
         """
diff --git a/submarine-sdk/pysubmarine/submarine/utils/__init__.py b/submarine-sdk/pysubmarine/submarine/utils/__init__.py
index 4908ba6..1b0f045 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/__init__.py
@@ -19,7 +19,7 @@ from submarine.exceptions import SubmarineException
 from submarine.utils.db_utils import get_db_uri, set_db_uri
 
 
-def extract_db_type_from_uri(db_uri):
+def extract_db_type_from_uri(db_uri: str):
     """
     Parse the specified DB URI to extract the database type. Confirm the database type is
     supported. If a driver is specified, confirm it passes a plausible regex.
diff --git a/submarine-sdk/pysubmarine/submarine/utils/db_utils.py b/submarine-sdk/pysubmarine/submarine/utils/db_utils.py
index b23ce2d..8fcc9a5 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/db_utils.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/db_utils.py
@@ -22,14 +22,14 @@ _DB_URI_ENV_VAR = "SUBMARINE_DB_URI"
 _db_uri = None
 
 
-def is_db_uri_set():
+def is_db_uri_set() -> bool:
     """Returns True if the DB URI has been set, False otherwise."""
     if _db_uri or env.get_env(_DB_URI_ENV_VAR):
         return True
     return False
 
 
-def set_db_uri(uri):
+def set_db_uri(uri: str):
     """
     Set the DB URI. This does not affect the currently active run (if one exists),
     but takes effect for successive runs.
@@ -38,7 +38,7 @@ def set_db_uri(uri):
     _db_uri = uri
 
 
-def get_db_uri():
+def get_db_uri() -> str:
     """
     Get the current DB URI.
     :return: The DB URI.
diff --git a/submarine-sdk/pysubmarine/submarine/utils/env.py b/submarine-sdk/pysubmarine/submarine/utils/env.py
index 3797efc..110a134 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/env.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/env.py
@@ -19,22 +19,22 @@ import os
 from collections.abc import Mapping
 
 
-def get_env(variable_name):
+def get_env(variable_name: str):
     return os.environ.get(variable_name)
 
 
-def unset_variable(variable_name):
+def unset_variable(variable_name: str) -> None:
     if variable_name in os.environ:
         del os.environ[variable_name]
 
 
-def check_env_exists(variable_name):
+def check_env_exists(variable_name: str) -> bool:
     if variable_name not in os.environ:
         return False
     return True
 
 
-def get_from_json(path, defaultParams):
+def get_from_json(path: str, defaultParams: dict):
     """
     If model parameters not specify in Json, use parameter in defaultParams
     :param path: The json file that specifies the model parameters.
@@ -50,7 +50,7 @@ def get_from_json(path, defaultParams):
     return get_from_dicts(params, defaultParams)
 
 
-def get_from_dicts(params, defaultParams):
+def get_from_dicts(params: dict, defaultParams: dict):
     """
     If model parameters not specify in params, use parameter in defaultParams
     :param params: parameters which will be merged
@@ -71,7 +71,7 @@ def get_from_dicts(params, defaultParams):
     return dct
 
 
-def get_from_registry(key, registry):
+def get_from_registry(key: str, registry: dict):
     if hasattr(key, "lower"):
         key = key.lower()
     if key in registry:
diff --git a/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py b/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py
index db71b13..5054af0 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py
@@ -51,7 +51,7 @@ def http_request(base_url, endpoint, method, json_body, timeout=60, headers=None
     return result
 
 
-def _can_parse_as_json(string):
+def _can_parse_as_json(string: str) -> bool:
     try:
         json.loads(string)
         return True
@@ -59,7 +59,7 @@ def _can_parse_as_json(string):
         return False
 
 
-def verify_rest_response(response, endpoint):
+def verify_rest_response(response, endpoint: str):
     """Verify the return code and raise exception if the request was not successful."""
     if response.status_code != 200:
         if _can_parse_as_json(response.text):
diff --git a/submarine-sdk/pysubmarine/submarine/utils/validation.py b/submarine-sdk/pysubmarine/submarine/utils/validation.py
index 00e2c98..049a873 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/validation.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/validation.py
@@ -37,7 +37,7 @@ _BAD_CHARACTERS_MESSAGE = (
 _UNSUPPORTED_DB_TYPE_MSG = "Supported database engines are {%s}" % ", ".join(DATABASE_ENGINES)
 
 
-def bad_path_message(name):
+def bad_path_message(name: str):
     return (
         "Names may be treated as files in certain cases, and must not resolve to other names"
         " when treated as such. This name would resolve to '%s'"
@@ -45,12 +45,12 @@ def bad_path_message(name):
     )
 
 
-def path_not_unique(name):
+def path_not_unique(name: str):
     norm = posixpath.normpath(name)
     return norm != name or norm == "." or norm.startswith("..") or norm.startswith("/")
 
 
-def _validate_param_name(name):
+def _validate_param_name(name: str):
     """Check that `name` is a valid parameter name and raise an exception if it isn't."""
     if not _VALID_PARAM_AND_METRIC_NAMES.match(name):
         raise SubmarineException(
@@ -63,7 +63,7 @@ def _validate_param_name(name):
         )
 
 
-def _validate_metric_name(name):
+def _validate_metric_name(name: str):
     """Check that `name` is a valid metric name and raise an exception if it isn't."""
     if not _VALID_PARAM_AND_METRIC_NAMES.match(name):
         raise SubmarineException(
@@ -74,7 +74,7 @@ def _validate_metric_name(name):
         raise SubmarineException("Invalid metric name: '%s'. %s" % (name, bad_path_message(name)))
 
 
-def _validate_length_limit(entity_name, limit, value):
+def _validate_length_limit(entity_name: str, limit: int, value):
     if len(value) > limit:
         raise SubmarineException(
             "%s '%s' had length %s, which exceeded length limit of %s"
@@ -82,7 +82,7 @@ def _validate_length_limit(entity_name, limit, value):
         )
 
 
-def validate_metric(key, value, timestamp, step):
+def validate_metric(key, value, timestamp, step) -> None:
     """
     Check that a param with the specified key, value, timestamp is valid and raise an exception if
     it isn't.
@@ -107,7 +107,7 @@ def validate_metric(key, value, timestamp, step):
         )
 
 
-def validate_param(key, value):
+def validate_param(key, value) -> None:
     """
     Check that a param with the specified key & value is valid and raise an exception if it
     isn't.

---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org