Copilot commented on code in PR #16595:
URL: https://github.com/apache/iotdb/pull/16595#discussion_r2435375647
##########
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java:
##########
@@ -284,16 +284,7 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
String.format("%s should never be null or empty",
MODEL_ID_PARAMETER_NAME));
}
- // make sure modelId exists
- ModelInferenceDescriptor descriptor = getModelInfo(modelId);
- if (descriptor == null || !descriptor.getModelInformation().available()) {
- throw new IoTDBRuntimeException(
- String.format("model [%s] is not available", modelId),
- TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
- }
-
- int maxInputLength = descriptor.getModelInformation().getInputShape()[0];
- TEndPoint targetAINode = descriptor.getTargetAINode();
+ TEndPoint targetAINode = getModelInfo(modelId).getTargetAINode();
Review Comment:
Potential null pointer dereference: `getModelInfo(modelId)` may return null,
but the result is directly dereferenced without a null check. The removed
validation code previously checked for this condition.
```suggestion
ModelInferenceDescriptor modelInfo = getModelInfo(modelId);
if (modelInfo == null) {
throw new SemanticException(
String.format("Model with id '%s' not found.", modelId));
}
TEndPoint targetAINode = modelInfo.getTargetAINode();
```
##########
iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py:
##########
@@ -119,39 +120,41 @@ def schedule(self, model_id: str) -> List[ScaleAction]:
return [ScaleAction(ScaleActionType.SCALE_UP, pool_num, model_id)]
def schedule_load_model_to_device(
- self, model_id: str, device_id: str
+ self, model_info: ModelInfo, device_id: str
) -> List[ScaleAction]:
- existing_model_ids = [
- existing_model_id
+ existing_model_infos = [
+ MODEL_MANAGER.get_model_info(existing_model_id)
for existing_model_id, pool_group_map in
self._request_pool_map.items()
- if existing_model_id != model_id and device_id in pool_group_map
+ if existing_model_id != model_info.model_id and device_id in
pool_group_map
]
allocation_result = _estimate_shared_pool_size_by_total_mem(
device=convert_device_id_to_torch_device(device_id),
- existing_model_ids=existing_model_ids,
- new_model_id=model_id,
+ existing_model_infos=existing_model_infos,
+ new_model_info=model_info,
)
return self._convert_allocation_result_to_scale_actions(
allocation_result, device_id
)
def schedule_unload_model_from_device(
- self, model_id: str, device_id: str
+ self, model_info: ModelInfo, device_id: str
) -> List[ScaleAction]:
- existing_model_ids = [
- existing_model_id
+ existing_model_infos = [
+ MODEL_MANAGER.get_model_info(existing_model_id)
for existing_model_id, pool_group_map in
self._request_pool_map.items()
- if existing_model_id != model_id and device_id in pool_group_map
+ if existing_model_id != model_info.model_id and device_id in
pool_group_map
]
allocation_result = (
_estimate_shared_pool_size_by_total_mem(
device=convert_device_id_to_torch_device(device_id),
- existing_model_ids=existing_model_ids,
- new_model_id=None,
+ existing_model_infos=existing_model_infos,
+ new_model_info=None,
)
- if len(existing_model_ids) > 0
- else {model_id: 0}
+ if len(existing_model_infos) > 0
+ else {model_info.model_id: 0}
)
+ if len(existing_model_infos) > 0:
+ allocation_result[model_info.model_id] = 0
Review Comment:
The allocation result for the unloaded model is set to 0 after already being
computed in the conditional expression. This overwrites the computed value and
may cause incorrect behavior. Lines 156-157 (the trailing
`if len(existing_model_infos) > 0:` block) should be removed, as the conditional
expression at lines 147-155 already handles this case correctly.
```suggestion
```
##########
iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java:
##########
@@ -186,33 +184,15 @@ public TShowAIDevicesResp showAIDevices() {
}
public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
- try {
- GetModelInfoResp response =
- (GetModelInfoResp) configManager.getConsensusManager().read(new
GetModelInfoPlan(req));
- if (response.getStatus().getCode() !=
TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- return new TGetModelInfoResp(response.getStatus());
- }
- int aiNodeId = response.getTargetAINodeId();
- if (aiNodeId != 0) {
- response.setTargetAINodeAddress(
- configManager.getNodeManager().getRegisteredAINode(aiNodeId));
- } else {
- if (configManager.getNodeManager().getRegisteredAINodes().isEmpty()) {
- return new TGetModelInfoResp(
- new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode())
- .setMessage("There is no AINode available"));
- }
- response.setTargetAINodeAddress(
- configManager.getNodeManager().getRegisteredAINodes().get(0));
- }
- return response.convertToThriftResponse();
- } catch (ConsensusException e) {
- LOGGER.warn("Unexpected error happened while getting model: ", e);
- // consensus layer related errors
- TSStatus res = new
TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode());
- res.setMessage(e.getMessage());
- return new TGetModelInfoResp(res);
- }
+ return new TGetModelInfoResp()
+ .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()))
+ .setAiNodeAddress(
+ configManager
+ .getNodeManager()
+ .getRegisteredAINodes()
+ .get(0)
Review Comment:
Direct access to `.get(0)` without checking if the list is empty will throw
`IndexOutOfBoundsException` if no AINodes are registered. The removed code
previously handled this case with an appropriate error message.
```suggestion
List<TAINodeInfo> aiNodes =
configManager.getNodeManager().getRegisteredAINodes();
if (aiNodes.isEmpty()) {
return new TGetModelInfoResp()
.setStatus(
new
TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode())
.setMessage("No AINodes are registered."));
}
return new TGetModelInfoResp()
.setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()))
.setAiNodeAddress(
aiNodes.get(0)
```
##########
iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py:
##########
@@ -253,13 +256,19 @@ def _expand_pools_on_device(self, model_id: str, device_id: str, count: int):
def _expand_pool_on_device(*_):
result_queue = mp.Queue()
pool_id = self._new_pool_id.get_and_increment()
- if model_id == "sundial":
+ model_info = MODEL_MANAGER.get_model_info(model_id)
+ model_type = model_info.model_type
+ if model_type == BuiltInModelType.SUNDIAL.value:
config = SundialConfig()
- elif model_id == "timer_xl":
+ elif model_id == BuiltInModelType.TIMER_XL.value:
Review Comment:
Incorrect comparison: comparing `model_id` (a string) with
`BuiltInModelType.TIMER_XL.value` instead of `model_type`. This should be
`model_type == BuiltInModelType.TIMER_XL.value` to match the
`model_type == BuiltInModelType.SUNDIAL.value` check on line 261.
```suggestion
elif model_type == BuiltInModelType.TIMER_XL.value:
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]