KUAN-HSUN-LI commented on a change in pull request #801:
URL: https://github.com/apache/submarine/pull/801#discussion_r749035130
##########
File path: submarine-sdk/pysubmarine/submarine/tracking/client.py
##########
@@ -101,24 +109,53 @@ def save_model(
:param artifact_path: Relative path of the artifact in the minio pod.
:param registered_model_name: If not None, register model into the
model registry with
this name. If None, the model only be
saved in minio pod.
+ :param input_dim: Save the input dimension of the given model to the description file.
+ :param output_dim: Save the output dimension of the given model to the description file.
"""
pattern = r"[0-9A-Za-z][0-9A-Za-z-_]*[0-9A-Za-z]|[0-9A-Za-z]"
if not re.fullmatch(pattern, artifact_path):
raise Exception(
- "Artifact_path must only contains numbers, characters, hyphen and underscore. "
- " Artifact_path must starts and ends with numbers or characters."
+ "Artifact_path must only contains numbers, characters, hyphen and underscore. "
+ "Artifact_path must starts and ends with numbers or characters."
)
with tempfile.TemporaryDirectory() as tempdir:
+ description: Dict[str, Any] = dict()
+ model_save_dir = os.path.join(tempdir, "1")
+ os.mkdir(model_save_dir)
if model_type == "pytorch":
import submarine.models.pytorch
- submarine.models.pytorch.save_model(model, tempdir)
+ submarine.models.pytorch.save_model(model, model_save_dir)
+ description["model_type"] = "pytorch"
Review comment:
We are using Triton to serve the PyTorch model, so we need to save it with `torch.jit.trace`. The input dimension (`input_dim`) must therefore be required when saving a PyTorch model, because `torch.jit.trace` needs an example input with the model's input shape.
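A minimal sketch of the check this comment asks for, assuming `input_dim` is a sequence of ints that excludes the batch dimension. The helper name `example_input_shape` is hypothetical and not part of pysubmarine; it only shows how `save_model` could reject a PyTorch model without `input_dim` and derive the shape of the dummy tensor to trace with.

```python
from typing import Optional, Sequence, Tuple


def example_input_shape(
    model_type: str, input_dim: Optional[Sequence[int]]
) -> Optional[Tuple[int, ...]]:
    """Hypothetical helper: shape of the example input used to trace a model.

    Assumes input_dim excludes the batch dimension.
    """
    if model_type != "pytorch":
        # Only PyTorch models are traced in this sketch.
        return None
    if input_dim is None:
        # torch.jit.trace needs an example input, so input_dim is mandatory.
        raise ValueError("input_dim is required when model_type is 'pytorch'")
    if len(input_dim) == 0 or any(int(d) <= 0 for d in input_dim):
        raise ValueError(
            f"input_dim must be a non-empty list of positive ints, got {input_dim!r}"
        )
    # Prepend a batch dimension of 1 for the dummy tensor.
    return (1, *(int(d) for d in input_dim))


if __name__ == "__main__":
    print(example_input_shape("pytorch", [3, 224, 224]))
```

Inside `save_model`, the returned shape could feed `torch.jit.trace(model, torch.randn(*shape))`, and the traced module could then be written with `traced.save(os.path.join(model_save_dir, "model.pt"))`; `model.pt` is the default file name Triton's PyTorch backend loads from a version directory such as `1/`.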
##########
File path: submarine-sdk/pysubmarine/submarine/tracking/client.py
##########
@@ -101,24 +109,53 @@ def save_model(
with tempfile.TemporaryDirectory() as tempdir:
+ description: Dict[str, Any] = dict()
+ model_save_dir = os.path.join(tempdir, "1")
+ os.mkdir(model_save_dir)
if model_type == "pytorch":
import submarine.models.pytorch
- submarine.models.pytorch.save_model(model, tempdir)
+ submarine.models.pytorch.save_model(model, model_save_dir)
+ description["model_type"] = "pytorch"
elif model_type == "tensorflow":
import submarine.models.tensorflow
- submarine.models.tensorflow.save_model(model, tempdir)
+ submarine.models.tensorflow.save_model(model, model_save_dir)
+ description["model_type"] = "tensorflow"
else:
raise Exception("No valid type of model has been matched to {}".format(model_type))
+
+ # Write description file
+ if input_dim is not None:
+ description["input"] = [
+ {
+ "name": "INPUT__0",
+ "data_type": "TYPE_FP32",
+ "dims": str(input_dim),
+ }
+ ]
+ if output_dim is not None:
+ description["output"] = [
+ {
+ "name": "OUTPUT__0",
+ "data_type": "TYPE_FP32",
+ "dims": output_dim,
+ }
+ ]
Review comment:
I think saving the input and output dimensions is enough.
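One possible reading of this suggestion, sketched below: keep the model type plus the two dimension lists, and leave tensor names and data types out of the description file. The function names, the `input_dim`/`output_dim` keys, and the `description.json` file name are assumptions for illustration only. The sketch also stores both dimensions the same way, since the diff stringifies `input_dim` (`str(input_dim)`) but stores `output_dim` as-is.

```python
import json
import os
import tempfile
from typing import Any, Dict, Optional, Sequence


def build_description(
    model_type: str,
    input_dim: Optional[Sequence[int]] = None,
    output_dim: Optional[Sequence[int]] = None,
) -> Dict[str, Any]:
    """Hypothetical description that keeps only the model type and dims."""
    description: Dict[str, Any] = {"model_type": model_type}
    if input_dim is not None:
        # Plain list of ints, so both fields round-trip identically through JSON.
        description["input_dim"] = [int(d) for d in input_dim]
    if output_dim is not None:
        description["output_dim"] = [int(d) for d in output_dim]
    return description


def write_description(directory: str, description: Dict[str, Any]) -> str:
    """Write the description as JSON; the file name is an assumption."""
    path = os.path.join(directory, "description.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(description, f)
    return path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        path = write_description(
            tempdir, build_description("pytorch", [3, 224, 224], [10])
        )
        with open(path, encoding="utf-8") as f:
            print(json.load(f))
```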