liuminghui233 commented on code in PR #9337:
URL: https://github.com/apache/iotdb/pull/9337#discussion_r1138305788


##########
mlnode/iotdb/mlnode/constant.py:
##########
@@ -16,6 +16,11 @@
 # under the License.
 #
 
-MLNODE_CONF_DIRECTORY_NAME = "conf"
+MLNODE_CONF_DIRECTORY_NAME = "resources/conf"

Review Comment:
   ```suggestion
   MLNODE_CONF_DIRECTORY_NAME = "conf"
   ```



##########
mlnode/iotdb/mlnode/constant.py:
##########
@@ -16,6 +16,11 @@
 # under the License.
 #
 
-MLNODE_CONF_DIRECTORY_NAME = "conf"
+MLNODE_CONF_DIRECTORY_NAME = "resources/conf"
 MLNODE_CONF_FILE_NAME = "iotdb-mlnode.toml"
 MLNODE_LOG_CONF_FILE_NAME = "logging_config.ini"
+
+MLNODE_MODEL_STORAGE_DIR = "ml_models"

Review Comment:
   ```suggestion
   MLNODE_MODEL_STORAGE_DIR = "models"
   ```



##########
mlnode/iotdb/mlnode/constant.py:
##########
@@ -16,6 +16,11 @@
 # under the License.
 #
 
-MLNODE_CONF_DIRECTORY_NAME = "conf"
+MLNODE_CONF_DIRECTORY_NAME = "resources/conf"
 MLNODE_CONF_FILE_NAME = "iotdb-mlnode.toml"
 MLNODE_LOG_CONF_FILE_NAME = "logging_config.ini"
+
+MLNODE_MODEL_STORAGE_DIR = "ml_models"
+MLNODE_MODEL_STORAGE_CACHESIZE = 30
+
+MLNODE_REQUEST_TEMPLATE = "resources/template"

Review Comment:
   ```suggestion
   MLNODE_REQUEST_TEMPLATE = "template"
   ```



##########
mlnode/iotdb/mlnode/constant.py:
##########
@@ -16,6 +16,11 @@
 # under the License.
 #
 
-MLNODE_CONF_DIRECTORY_NAME = "conf"
+MLNODE_CONF_DIRECTORY_NAME = "resources/conf"
 MLNODE_CONF_FILE_NAME = "iotdb-mlnode.toml"
 MLNODE_LOG_CONF_FILE_NAME = "logging_config.ini"
+
+MLNODE_MODEL_STORAGE_DIR = "ml_models"
+MLNODE_MODEL_STORAGE_CACHESIZE = 30

Review Comment:
   The cache size should be a configuration, not a constant



##########
mlnode/iotdb/mlnode/storage/model_storager.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import os
+import json
+import torch
+import shutil
+import torch.nn as nn
+from pylru import lrucache
+from iotdb.mlnode.constant import (MLNODE_MODEL_STORAGE_DIR,
+                                   MLNODE_MODEL_STORAGE_CACHESIZE)
+
+
+# TODO: Add permission check firstly
+# TODO: Consider concurrency, maybe
+class ModelStorager(object):

Review Comment:
   ```suggestion
   class ModelStorage(object):
   ```



##########
mlnode/iotdb/mlnode/storage/__init__.py:
##########


Review Comment:
   The `storage` package seems to be a redundant module level; consider flattening it.



##########
mlnode/iotdb/mlnode/constant.py:
##########
@@ -16,6 +16,11 @@
 # under the License.
 #
 
-MLNODE_CONF_DIRECTORY_NAME = "conf"
+MLNODE_CONF_DIRECTORY_NAME = "resources/conf"
 MLNODE_CONF_FILE_NAME = "iotdb-mlnode.toml"
 MLNODE_LOG_CONF_FILE_NAME = "logging_config.ini"
+
+MLNODE_MODEL_STORAGE_DIR = "ml_models"

Review Comment:
   This should be a configuration item rather than a constant.



##########
mlnode/iotdb/mlnode/storage/model_storager.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import os
+import json
+import torch
+import shutil
+import torch.nn as nn
+from pylru import lrucache
+from iotdb.mlnode.constant import (MLNODE_MODEL_STORAGE_DIR,
+                                   MLNODE_MODEL_STORAGE_CACHESIZE)
+
+
+# TODO: Add permission check firstly
+# TODO: Consider concurrency, maybe
+class ModelStorager(object):
+    def __init__(self,
+                 root_path: str = 'ml_models',
+                 cache_size: int = 30):

Review Comment:
   Read configuration parameters instead of passing in parameters



##########
mlnode/iotdb/mlnode/constant.py:
##########
@@ -16,6 +16,11 @@
 # under the License.
 #
 
-MLNODE_CONF_DIRECTORY_NAME = "conf"
+MLNODE_CONF_DIRECTORY_NAME = "resources/conf"
 MLNODE_CONF_FILE_NAME = "iotdb-mlnode.toml"
 MLNODE_LOG_CONF_FILE_NAME = "logging_config.ini"
+
+MLNODE_MODEL_STORAGE_DIR = "ml_models"
+MLNODE_MODEL_STORAGE_CACHESIZE = 30
+
+MLNODE_REQUEST_TEMPLATE = "resources/template"

Review Comment:
   This constant is not used anywhere; please remove it.



##########
mlnode/iotdb/mlnode/storage/model_storager.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import os
+import json
+import torch
+import shutil
+import torch.nn as nn
+from pylru import lrucache
+from iotdb.mlnode.constant import (MLNODE_MODEL_STORAGE_DIR,
+                                   MLNODE_MODEL_STORAGE_CACHESIZE)
+
+
+# TODO: Add permission check firstly
+# TODO: Consider concurrency, maybe
+class ModelStorager(object):
+    def __init__(self,
+                 root_path: str = 'ml_models',
+                 cache_size: int = 30):
+        current_path = os.getcwd()
+        self.root_path = os.path.join(current_path, root_path)
+        if not os.path.exists(self.root_path):
+            os.mkdir(self.root_path)
+        self._loaded_model_cache = lrucache(cache_size)
+
+    def save_model(self,
+                   model: nn.Module,
+                   model_config: dict,
+                   model_id: str,
+                   trial_id: str):
+        """
+        Return: True if successfully saved
+        """
+        fold_path = os.path.join(self.root_path, f'{model_id}')
+        if not os.path.exists(fold_path):
+            os.mkdir(fold_path)
+        sample_input = [torch.randn(1, model_config['input_len'], 
model_config['input_vars'])]
+        torch.jit.save(torch.jit.trace(model, sample_input),
+                       os.path.join(fold_path, f'{trial_id}.pt'),
+                       _extra_files={'model_config': json.dumps(model_config)})
+        return os.path.exists(os.path.join(fold_path, f'{trial_id}.pt'))
+
+    def load_model(self, model_id: str, trial_id: str):
+        file_path = os.path.join(self.root_path, f'{model_id}', 
f'{trial_id}.pt')
+        if model_id in self._loaded_model_cache:
+            return self._loaded_model_cache[file_path]

Review Comment:
   Should `file_path` be used to determine whether the cache is hit?



##########
mlnode/iotdb/mlnode/storage/model_storager.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import os
+import json
+import torch
+import shutil
+import torch.nn as nn
+from pylru import lrucache
+from iotdb.mlnode.constant import (MLNODE_MODEL_STORAGE_DIR,
+                                   MLNODE_MODEL_STORAGE_CACHESIZE)
+
+
+# TODO: Add permission check firstly
+# TODO: Consider concurrency, maybe
+class ModelStorager(object):
+    def __init__(self,
+                 root_path: str = 'ml_models',
+                 cache_size: int = 30):
+        current_path = os.getcwd()
+        self.root_path = os.path.join(current_path, root_path)
+        if not os.path.exists(self.root_path):
+            os.mkdir(self.root_path)
+        self._loaded_model_cache = lrucache(cache_size)
+
+    def save_model(self,
+                   model: nn.Module,
+                   model_config: dict,
+                   model_id: str,
+                   trial_id: str):
+        """
+        Return: True if successfully saved
+        """
+        fold_path = os.path.join(self.root_path, f'{model_id}')
+        if not os.path.exists(fold_path):
+            os.mkdir(fold_path)
+        sample_input = [torch.randn(1, model_config['input_len'], 
model_config['input_vars'])]
+        torch.jit.save(torch.jit.trace(model, sample_input),
+                       os.path.join(fold_path, f'{trial_id}.pt'),
+                       _extra_files={'model_config': json.dumps(model_config)})
+        return os.path.exists(os.path.join(fold_path, f'{trial_id}.pt'))

Review Comment:
   We should use `try...except...` here: if no exception is raised, the model can be 
considered successfully saved, instead of checking the file's existence afterwards.



##########
mlnode/iotdb/mlnode/storage/model_storager.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import os
+import json
+import torch
+import shutil
+import torch.nn as nn
+from pylru import lrucache
+from iotdb.mlnode.constant import (MLNODE_MODEL_STORAGE_DIR,
+                                   MLNODE_MODEL_STORAGE_CACHESIZE)
+
+
+# TODO: Add permission check firstly
+# TODO: Consider concurrency, maybe
+class ModelStorager(object):
+    def __init__(self,
+                 root_path: str = 'ml_models',
+                 cache_size: int = 30):
+        current_path = os.getcwd()
+        self.root_path = os.path.join(current_path, root_path)
+        if not os.path.exists(self.root_path):
+            os.mkdir(self.root_path)
+        self._loaded_model_cache = lrucache(cache_size)
+
+    def save_model(self,
+                   model: nn.Module,
+                   model_config: dict,
+                   model_id: str,
+                   trial_id: str):
+        """
+        Return: True if successfully saved
+        """
+        fold_path = os.path.join(self.root_path, f'{model_id}')
+        if not os.path.exists(fold_path):
+            os.mkdir(fold_path)
+        sample_input = [torch.randn(1, model_config['input_len'], 
model_config['input_vars'])]

Review Comment:
   Is it only for forecast models?



##########
mlnode/iotdb/mlnode/storage/model_storager.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import os
+import json
+import torch
+import shutil
+import torch.nn as nn
+from pylru import lrucache
+from iotdb.mlnode.constant import (MLNODE_MODEL_STORAGE_DIR,
+                                   MLNODE_MODEL_STORAGE_CACHESIZE)
+
+
+# TODO: Add permission check firstly
+# TODO: Consider concurrency, maybe
+class ModelStorager(object):
+    def __init__(self,
+                 root_path: str = 'ml_models',
+                 cache_size: int = 30):
+        current_path = os.getcwd()
+        self.root_path = os.path.join(current_path, root_path)
+        if not os.path.exists(self.root_path):
+            os.mkdir(self.root_path)
+        self._loaded_model_cache = lrucache(cache_size)

Review Comment:
   ```suggestion
           self.__model_cache = lrucache(cache_size)
   ```



##########
mlnode/iotdb/mlnode/storage/model_storager.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import os
+import json
+import torch
+import shutil
+import torch.nn as nn
+from pylru import lrucache
+from iotdb.mlnode.constant import (MLNODE_MODEL_STORAGE_DIR,
+                                   MLNODE_MODEL_STORAGE_CACHESIZE)
+
+
+# TODO: Add permission check firstly
+# TODO: Consider concurrency, maybe
+class ModelStorager(object):
+    def __init__(self,
+                 root_path: str = 'ml_models',
+                 cache_size: int = 30):
+        current_path = os.getcwd()
+        self.root_path = os.path.join(current_path, root_path)
+        if not os.path.exists(self.root_path):
+            os.mkdir(self.root_path)
+        self._loaded_model_cache = lrucache(cache_size)
+
+    def save_model(self,
+                   model: nn.Module,
+                   model_config: dict,
+                   model_id: str,
+                   trial_id: str):
+        """
+        Return: True if successfully saved
+        """
+        fold_path = os.path.join(self.root_path, f'{model_id}')
+        if not os.path.exists(fold_path):
+            os.mkdir(fold_path)
+        sample_input = [torch.randn(1, model_config['input_len'], 
model_config['input_vars'])]
+        torch.jit.save(torch.jit.trace(model, sample_input),
+                       os.path.join(fold_path, f'{trial_id}.pt'),
+                       _extra_files={'model_config': json.dumps(model_config)})
+        return os.path.exists(os.path.join(fold_path, f'{trial_id}.pt'))
+
+    def load_model(self, model_id: str, trial_id: str):
+        file_path = os.path.join(self.root_path, f'{model_id}', 
f'{trial_id}.pt')
+        if model_id in self._loaded_model_cache:
+            return self._loaded_model_cache[file_path]
+        else:
+            if not os.path.exists(file_path):
+                raise RuntimeError('Model path (%s) is not found' % file_path)
+            else:
+                tmp_dict = {'model_config': ''}
+                jit_model = torch.jit.load(file_path, _extra_files=tmp_dict)
+                model_config = json.loads(tmp_dict['model_config'])
+                self._loaded_model_cache[file_path] = jit_model, model_config
+                return jit_model, model_config
+
+    def _remove_from_cache(self, key: str):
+        if key in self._loaded_model_cache:
+            del self._loaded_model_cache[key]
+
+    def delete_trial(self, model_id: str, trial_id: str):
+        """
+        Return: True if successfully deleted
+        """
+        file_path = os.path.join(self.root_path, f'{model_id}', 
f'{trial_id}.pt')
+        self._remove_from_cache(file_path)
+        if os.path.exists(file_path):
+            os.remove(file_path)
+        return not os.path.exists(file_path)
+
+    def delete_model(self, model_id: str):
+        """
+        Return: True if successfully deleted
+        """
+        folder_path = os.path.join(self.root_path, f'{model_id}')
+        if os.path.exists(folder_path):
+            for file_name in os.listdir(folder_path):
+                self._remove_from_cache(os.path.join(folder_path, file_name))
+            shutil.rmtree(folder_path)
+        return not os.path.exists(folder_path)
+
+    def delete_by_path(self, model_path: str):  # TODO: for test only, remove 
this when thrift has redefined

Review Comment:
   You can just delete this method.



##########
mlnode/iotdb/mlnode/storage/model_storager.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import os
+import json
+import torch
+import shutil
+import torch.nn as nn
+from pylru import lrucache
+from iotdb.mlnode.constant import (MLNODE_MODEL_STORAGE_DIR,
+                                   MLNODE_MODEL_STORAGE_CACHESIZE)
+
+
+# TODO: Add permission check firstly
+# TODO: Consider concurrency, maybe
+class ModelStorager(object):
+    def __init__(self,
+                 root_path: str = 'ml_models',
+                 cache_size: int = 30):
+        current_path = os.getcwd()
+        self.root_path = os.path.join(current_path, root_path)
+        if not os.path.exists(self.root_path):
+            os.mkdir(self.root_path)
+        self._loaded_model_cache = lrucache(cache_size)

Review Comment:
   What is the unit of cache size? We should use memory size as the unit of 
capacity.



##########
mlnode/iotdb/mlnode/storage/model_storager.py:
##########


Review Comment:
   We should declare the return type of each method



##########
mlnode/iotdb/mlnode/storage/model_storager.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import os
+import json
+import torch
+import shutil
+import torch.nn as nn
+from pylru import lrucache
+from iotdb.mlnode.constant import (MLNODE_MODEL_STORAGE_DIR,
+                                   MLNODE_MODEL_STORAGE_CACHESIZE)
+
+
+# TODO: Add permission check firstly
+# TODO: Consider concurrency, maybe
+class ModelStorager(object):
+    def __init__(self,
+                 root_path: str = 'ml_models',
+                 cache_size: int = 30):
+        current_path = os.getcwd()
+        self.root_path = os.path.join(current_path, root_path)

Review Comment:
   ```suggestion
           self.__model_dir = os.path.join(os.getcwd(), root_path)
   ```



##########
mlnode/iotdb/mlnode/storage/model_storager.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import os
+import json
+import torch
+import shutil
+import torch.nn as nn
+from pylru import lrucache
+from iotdb.mlnode.constant import (MLNODE_MODEL_STORAGE_DIR,
+                                   MLNODE_MODEL_STORAGE_CACHESIZE)
+
+
+# TODO: Add permission check firstly
+# TODO: Consider concurrency, maybe
+class ModelStorager(object):
+    def __init__(self,
+                 root_path: str = 'ml_models',
+                 cache_size: int = 30):
+        current_path = os.getcwd()
+        self.root_path = os.path.join(current_path, root_path)
+        if not os.path.exists(self.root_path):
+            os.mkdir(self.root_path)
+        self._loaded_model_cache = lrucache(cache_size)
+
+    def save_model(self,
+                   model: nn.Module,
+                   model_config: dict,
+                   model_id: str,
+                   trial_id: str):
+        """
+        Return: True if successfully saved
+        """
+        fold_path = os.path.join(self.root_path, f'{model_id}')
+        if not os.path.exists(fold_path):
+            os.mkdir(fold_path)
+        sample_input = [torch.randn(1, model_config['input_len'], 
model_config['input_vars'])]

Review Comment:
   What is the meaning of `sample` here?



##########
mlnode/iotdb/mlnode/storage/model_storager.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import os
+import json
+import torch
+import shutil
+import torch.nn as nn
+from pylru import lrucache
+from iotdb.mlnode.constant import (MLNODE_MODEL_STORAGE_DIR,
+                                   MLNODE_MODEL_STORAGE_CACHESIZE)
+
+
+# TODO: Add permission check firstly
+# TODO: Consider concurrency, maybe
+class ModelStorager(object):
+    def __init__(self,
+                 root_path: str = 'ml_models',
+                 cache_size: int = 30):
+        current_path = os.getcwd()
+        self.root_path = os.path.join(current_path, root_path)
+        if not os.path.exists(self.root_path):
+            os.mkdir(self.root_path)
+        self._loaded_model_cache = lrucache(cache_size)
+
+    def save_model(self,
+                   model: nn.Module,
+                   model_config: dict,
+                   model_id: str,
+                   trial_id: str):
+        """
+        Return: True if successfully saved
+        """
+        fold_path = os.path.join(self.root_path, f'{model_id}')
+        if not os.path.exists(fold_path):
+            os.mkdir(fold_path)
+        sample_input = [torch.randn(1, model_config['input_len'], 
model_config['input_vars'])]
+        torch.jit.save(torch.jit.trace(model, sample_input),
+                       os.path.join(fold_path, f'{trial_id}.pt'),
+                       _extra_files={'model_config': json.dumps(model_config)})
+        return os.path.exists(os.path.join(fold_path, f'{trial_id}.pt'))
+
+    def load_model(self, model_id: str, trial_id: str):
+        file_path = os.path.join(self.root_path, f'{model_id}', 
f'{trial_id}.pt')
+        if model_id in self._loaded_model_cache:
+            return self._loaded_model_cache[file_path]
+        else:
+            if not os.path.exists(file_path):
+                raise RuntimeError('Model path (%s) is not found' % file_path)

Review Comment:
   Add an exception type in `exception.py`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to