This is an automated email from the ASF dual-hosted git repository.
wuzhiguo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/bigtop-manager.git
The following commit(s) were added to refs/heads/main by this push:
new 70dbb6ce BIGTOP-4260: Add chatbot command and tools (#101)
70dbb6ce is described below
commit 70dbb6cee24289e40484cc2155600b8fa6556195
Author: haopeng <[email protected]>
AuthorDate: Wed Dec 11 13:49:41 2024 +0800
BIGTOP-4260: Add chatbot command and tools (#101)
---
.../bigtop-manager-ai-dashscope/pom.xml | 5 --
.../server/controller/ChatbotController.java | 15 +++++-
.../manager/server/enums/ChatbotCommand.java | 62 ++++++++++++++++++++++
.../manager/server/service/ChatbotService.java | 5 +-
.../server/service/impl/ChatbotServiceImpl.java | 44 +++++++++++----
.../server/service/impl/LLMConfigServiceImpl.java | 48 +++++++++++++++--
.../server/tools/AiServiceToolsProvider.java | 47 ++++++++++++++++
.../manager/server/tools/ClusterInfoTools.java | 51 ++++++++++++++++++
.../server/controller/ChatbotControllerTest.java | 3 +-
9 files changed, 257 insertions(+), 23 deletions(-)
diff --git a/bigtop-manager-ai/bigtop-manager-ai-dashscope/pom.xml
b/bigtop-manager-ai/bigtop-manager-ai-dashscope/pom.xml
index 794a5b11..d062fa1f 100644
--- a/bigtop-manager-ai/bigtop-manager-ai-dashscope/pom.xml
+++ b/bigtop-manager-ai/bigtop-manager-ai-dashscope/pom.xml
@@ -40,10 +40,5 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-dashscope</artifactId>
</dependency>
-
- <dependency>
- <groupId>dev.langchain4j</groupId>
- <artifactId>langchain4j-dashscope</artifactId>
- </dependency>
</dependencies>
</project>
diff --git
a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/controller/ChatbotController.java
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/controller/ChatbotController.java
index 8e3a2e67..14ef4cba 100644
---
a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/controller/ChatbotController.java
+++
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/controller/ChatbotController.java
@@ -18,6 +18,7 @@
*/
package org.apache.bigtop.manager.server.controller;
+import org.apache.bigtop.manager.server.enums.ChatbotCommand;
import org.apache.bigtop.manager.server.model.converter.ChatThreadConverter;
import org.apache.bigtop.manager.server.model.dto.ChatThreadDTO;
import org.apache.bigtop.manager.server.model.req.ChatbotMessageReq;
@@ -88,7 +89,13 @@ public class ChatbotController {
@Operation(summary = "talk", description = "Talk with Chatbot")
@PostMapping("/threads/{threadId}/talk")
public SseEmitter talk(@PathVariable Long threadId, @RequestBody
ChatbotMessageReq messageReq) {
- return chatbotService.talk(threadId, messageReq.getMessage());
+        String message = messageReq.getMessage();
+        // A leading "/<cmd>" selects a chatbot command (and therefore its tools); the rest is the question.
+        ChatbotCommand command = message == null ? null : ChatbotCommand.getCommandFromMessage(message);
+        if (command != null) {
+            // Drop the "/" and the command name, then any separating whitespace. This never indexes past
+            // the end, so a bare "/info" with no trailing text yields an empty question instead of throwing.
+            message = message.substring(command.getCmd().length() + 1).strip();
+            messageReq.setMessage(message);
+        }
+        return chatbotService.talk(threadId, command, message);
}
@Operation(summary = "history", description = "Get chat records")
@@ -96,4 +103,10 @@ public class ChatbotController {
public ResponseEntity<List<ChatMessageVO>> history(@PathVariable Long
threadId) {
return ResponseEntity.success(chatbotService.history(threadId));
}
+
+    /**
+     * Returns the names of all chatbot commands, as reported by {@code chatbotService.getChatbotCommands()}.
+     *
+     * @return the command names wrapped in a success response
+     */
+    @Operation(summary = "get commands", description = "Get all commands")
+    @GetMapping("/commands")
+    public ResponseEntity<List<String>> getCommands() {
+        List<String> commandNames = chatbotService.getChatbotCommands();
+        return ResponseEntity.success(commandNames);
+    }
}
diff --git
a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/enums/ChatbotCommand.java
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/enums/ChatbotCommand.java
new file mode 100644
index 00000000..8fba4e67
--- /dev/null
+++
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/enums/ChatbotCommand.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.bigtop.manager.server.enums;
+
+import lombok.Getter;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Commands a user can put at the start of a chatbot message, e.g. {@code "/info which clusters exist?"}.
+ *
+ * <p>Each constant carries its command name as typed after the leading {@code "/"}.
+ */
+@Getter
+public enum ChatbotCommand {
+    INFO("info"),
+    SEARCH("search"),
+    HELP("help");
+
+    /** Marker a message must start with for it to be treated as a command. */
+    private static final String COMMAND_PREFIX = "/";
+
+    /** Command name without the leading prefix. */
+    private final String cmd;
+
+    ChatbotCommand(String cmd) {
+        this.cmd = cmd;
+    }
+
+    /**
+     * Returns the names of all commands, in declaration order.
+     *
+     * @return a new mutable list of command names
+     */
+    public static List<String> getAllCommands() {
+        List<String> commands = new ArrayList<>();
+        for (ChatbotCommand command : ChatbotCommand.values()) {
+            commands.add(command.cmd);
+        }
+        return commands;
+    }
+
+    /**
+     * Looks up a command by its name (without the leading "/").
+     *
+     * @param cmd the command name; may be {@code null}
+     * @return the matching command, or {@code null} if none matches
+     */
+    public static ChatbotCommand getCommand(String cmd) {
+        for (ChatbotCommand command : ChatbotCommand.values()) {
+            if (command.cmd.equals(cmd)) {
+                return command;
+            }
+        }
+        return null;
+    }
+
+    /**
+     * Extracts the command from a message of the form {@code "/<cmd>[ text]"}.
+     *
+     * <p>The command name is everything between the leading "/" and the first space (or the end of the
+     * message).
+     *
+     * @param message the raw user message; may be {@code null}
+     * @return the command, or {@code null} if the message is null, does not start with "/", or names an
+     *     unknown command
+     */
+    public static ChatbotCommand getCommandFromMessage(String message) {
+        if (message == null || !message.startsWith(COMMAND_PREFIX)) {
+            return null;
+        }
+        // Only the first token matters, so locate it directly instead of splitting the whole message.
+        int firstSpace = message.indexOf(' ');
+        int nameEnd = firstSpace < 0 ? message.length() : firstSpace;
+        return getCommand(message.substring(COMMAND_PREFIX.length(), nameEnd));
+    }
+}
diff --git
a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/ChatbotService.java
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/ChatbotService.java
index 95fa0643..d7a52fc2 100644
---
a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/ChatbotService.java
+++
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/ChatbotService.java
@@ -18,6 +18,7 @@
*/
package org.apache.bigtop.manager.server.service;
+import org.apache.bigtop.manager.server.enums.ChatbotCommand;
import org.apache.bigtop.manager.server.model.dto.ChatThreadDTO;
import org.apache.bigtop.manager.server.model.vo.ChatMessageVO;
import org.apache.bigtop.manager.server.model.vo.ChatThreadVO;
@@ -34,11 +35,13 @@ public interface ChatbotService {
List<ChatThreadVO> getAllChatThreads();
- SseEmitter talk(Long threadId, String message);
+ SseEmitter talk(Long threadId, ChatbotCommand command, String message);
List<ChatMessageVO> history(Long threadId);
ChatThreadVO updateChatThread(ChatThreadDTO chatThreadDTO);
+ List<String> getChatbotCommands();
+
ChatThreadVO getChatThread(Long threadId);
}
diff --git
a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/ChatbotServiceImpl.java
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/ChatbotServiceImpl.java
index 2ab8acde..a27b64a9 100644
---
a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/ChatbotServiceImpl.java
+++
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/ChatbotServiceImpl.java
@@ -35,6 +35,7 @@ import org.apache.bigtop.manager.dao.repository.ChatThreadDao;
import org.apache.bigtop.manager.dao.repository.PlatformDao;
import org.apache.bigtop.manager.server.enums.ApiExceptionEnum;
import org.apache.bigtop.manager.server.enums.AuthPlatformStatus;
+import org.apache.bigtop.manager.server.enums.ChatbotCommand;
import org.apache.bigtop.manager.server.exception.ApiException;
import org.apache.bigtop.manager.server.holder.SessionUserHolder;
import org.apache.bigtop.manager.server.model.converter.AuthPlatformConverter;
@@ -46,6 +47,7 @@ import
org.apache.bigtop.manager.server.model.vo.ChatMessageVO;
import org.apache.bigtop.manager.server.model.vo.ChatThreadVO;
import org.apache.bigtop.manager.server.model.vo.TalkVO;
import org.apache.bigtop.manager.server.service.ChatbotService;
+import org.apache.bigtop.manager.server.tools.AiServiceToolsProvider;
import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.stereotype.Service;
@@ -118,10 +120,19 @@ public class ChatbotServiceImpl implements ChatbotService
{
}
private AIAssistant buildAIAssistant(
- String platformName, String model, Map<String, String>
credentials, Long threadId) {
- return getAIAssistantFactory()
- .createAiService(
- getPlatformType(platformName),
getAIAssistantConfig(model, credentials), threadId, null);
+            String platformName,
+            String model,
+            Map<String, String> credentials,
+            Long threadId,
+            ChatbotCommand command) {
+        // Tools are attached only when the message carried a command; plain chat runs without a tool provider.
+        AiServiceToolsProvider toolProvider = (command != null) ? new AiServiceToolsProvider(command) : null;
+        return getAIAssistantFactory()
+                .createAiService(
+                        getPlatformType(platformName),
+                        getAIAssistantConfig(model, credentials),
+                        threadId,
+                        toolProvider);
}
@Override
@@ -182,7 +193,7 @@ public class ChatbotServiceImpl implements ChatbotService {
return chatThreads;
}
- private AIAssistant prepareTalk(Long threadId) {
+ private AIAssistant prepareTalk(Long threadId, ChatbotCommand command,
String message) {
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
Long userId = SessionUserHolder.getUserId();
if (!Objects.equals(userId, chatThreadPO.getUserId()) ||
chatThreadPO.getIsDeleted()) {
@@ -199,14 +210,22 @@ public class ChatbotServiceImpl implements ChatbotService
{
PlatformPO platformPO =
platformDao.findById(authPlatformPO.getPlatformId());
return buildAIAssistant(
- platformPO.getName(), authPlatformDTO.getModel(),
authPlatformDTO.getAuthCredentials(), threadId);
+ platformPO.getName(),
+ authPlatformDTO.getModel(),
+ authPlatformDTO.getAuthCredentials(),
+ threadId,
+ command);
}
@Override
- public SseEmitter talk(Long threadId, String message) {
- AIAssistant aiAssistant = prepareTalk(threadId);
- Flux<String> stringFlux = aiAssistant.streamAsk(message);
-
+ public SseEmitter talk(Long threadId, ChatbotCommand command, String
message) {
+ AIAssistant aiAssistant = prepareTalk(threadId, command, message);
+ Flux<String> stringFlux;
+ if (command == null) {
+ stringFlux = aiAssistant.streamAsk(message);
+ } else {
+ stringFlux = Flux.just(aiAssistant.ask(message));
+ }
SseEmitter emitter = new SseEmitter();
stringFlux.subscribe(
s -> {
@@ -290,6 +309,11 @@ public class ChatbotServiceImpl implements ChatbotService {
chatThreadPO, authPlatformPO,
platformDao.findById(authPlatformPO.getPlatformId()));
}
+    /**
+     * Returns the names of every {@link ChatbotCommand}.
+     *
+     * @return the list produced by {@code ChatbotCommand.getAllCommands()}
+     */
+    @Override
+    public List<String> getChatbotCommands() {
+        List<String> names = ChatbotCommand.getAllCommands();
+        return names;
+    }
+
@Override
public ChatThreadVO getChatThread(Long threadId) {
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
diff --git
a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/LLMConfigServiceImpl.java
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/LLMConfigServiceImpl.java
index e26dfc71..7b1e1ead 100644
---
a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/LLMConfigServiceImpl.java
+++
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/service/impl/LLMConfigServiceImpl.java
@@ -24,6 +24,7 @@ import
org.apache.bigtop.manager.ai.assistant.store.ChatMemoryStoreProvider;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.factory.AIAssistantFactory;
+import org.apache.bigtop.manager.common.utils.JsonUtils;
import org.apache.bigtop.manager.dao.po.AuthPlatformPO;
import org.apache.bigtop.manager.dao.po.ChatMessagePO;
import org.apache.bigtop.manager.dao.po.ChatThreadPO;
@@ -48,6 +49,11 @@ import org.jetbrains.annotations.NotNull;
import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.stereotype.Service;
+import dev.langchain4j.agent.tool.JsonSchemaProperty;
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.service.tool.ToolExecutor;
+import dev.langchain4j.service.tool.ToolProvider;
+import dev.langchain4j.service.tool.ToolProviderResult;
import lombok.extern.slf4j.Slf4j;
import jakarta.annotation.Resource;
@@ -73,6 +79,9 @@ public class LLMConfigServiceImpl implements LLMConfigService
{
private AIAssistantFactory aiAssistantFactory;
+ // Sentinel key/value pair used by testFuncCalling(): its "getFlag" tool returns TEST_FLAG when asked
+ // for TEST_KEY, and the probe passes only if the model's answer contains TEST_FLAG.
+ // TEST_FLAG is the Base64 encoding of "flag"; presumably chosen so the model cannot guess it without
+ // actually calling the tool.
+ private static final String TEST_FLAG = "ZmxhZw==";
+ private static final String TEST_KEY = "bm";
+
public AIAssistantFactory getAIAssistantFactory() {
if (aiAssistantFactory == null) {
aiAssistantFactory =
@@ -96,11 +105,9 @@ public class LLMConfigServiceImpl implements
LLMConfigService {
}
private Boolean testAuthorization(String platformName, String model,
Map<String, String> credentials) {
- AIAssistantConfig aiAssistantConfig = AIAssistantConfig.builder()
- .setModel(model)
- .setLanguage(LocaleContextHolder.getLocale().toString())
- .addCredentials(credentials)
- .build();
+ Boolean result = testFuncCalling(platformName, model, credentials);
+ log.info("Test func calling result: {}", result);
+ AIAssistantConfig aiAssistantConfig = getAIAssistantConfig(model,
credentials, null);
AIAssistant aiAssistant =
getAIAssistantFactory().create(getPlatformType(platformName),
aiAssistantConfig);
try {
return aiAssistant.test();
@@ -109,6 +116,37 @@ public class LLMConfigServiceImpl implements
LLMConfigService {
}
}
+    /**
+     * Probes whether the configured model can perform function (tool) calling.
+     *
+     * <p>The model is given a single {@code getFlag} tool and asked for the flag of {@code TEST_KEY}.
+     * The probe succeeds only if the final answer contains {@code TEST_FLAG}, the value that the tool
+     * returns for that key.
+     *
+     * @param platformName name of the AI platform to use
+     * @param model model identifier
+     * @param credentials platform credentials
+     * @return {@code true} if the model's answer contains the expected flag, {@code false} otherwise
+     * @throws ApiException with {@code CREDIT_INCORRECT} if calling the model fails
+     */
+    private Boolean testFuncCalling(String platformName, String model, Map<String, String> credentials) {
+        ToolSpecification toolSpecification = ToolSpecification.builder()
+                .name("getFlag")
+                .description("Get flag based on key")
+                .addParameter("key", JsonSchemaProperty.STRING, JsonSchemaProperty.description("Lowercase key"))
+                .build();
+        ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> {
+            Map<String, Object> arguments = JsonUtils.readFromString(toolExecutionRequest.arguments());
+            Object key = arguments == null ? null : arguments.get("key");
+            // Models do not always honour the "lowercase" hint or may omit the argument; match
+            // case-insensitively and avoid a NullPointerException that would be reported as bad credentials.
+            if (key != null && TEST_KEY.equalsIgnoreCase(key.toString().strip())) {
+                return TEST_FLAG;
+            }
+            // Always return text: a null tool result is not a valid tool-execution message.
+            return "No flag found for key: " + key;
+        };
+        // The spec and executor never change, so build them once and reuse them for every request.
+        ToolProvider toolProvider = toolProviderRequest -> ToolProviderResult.builder()
+                .add(toolSpecification, toolExecutor)
+                .build();
+
+        AIAssistantConfig aiAssistantConfig = getAIAssistantConfig(model, credentials, null);
+        AIAssistant aiAssistant = getAIAssistantFactory()
+                .createAiService(getPlatformType(platformName), aiAssistantConfig, null, toolProvider);
+        try {
+            String answer = aiAssistant.ask("What is the flag of " + TEST_KEY);
+            return answer != null && answer.contains(TEST_FLAG);
+        } catch (Exception e) {
+            throw new ApiException(ApiExceptionEnum.CREDIT_INCORRECT, e.getMessage());
+        }
+    }
+
private void switchActivePlatform(Long id) {
List<AuthPlatformPO> authPlatformPOS = authPlatformDao.findAll();
for (AuthPlatformPO authPlatformPO : authPlatformPOS) {
diff --git
a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/AiServiceToolsProvider.java
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/AiServiceToolsProvider.java
new file mode 100644
index 00000000..ffcabfcc
--- /dev/null
+++
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/AiServiceToolsProvider.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.bigtop.manager.server.tools;
+
+import org.apache.bigtop.manager.server.enums.ChatbotCommand;
+
+import dev.langchain4j.service.tool.ToolProvider;
+import dev.langchain4j.service.tool.ToolProviderRequest;
+import dev.langchain4j.service.tool.ToolProviderResult;
+
+/**
+ * Supplies langchain4j tools to an AI service according to the chatbot command in use.
+ *
+ * <p>Only {@link ChatbotCommand#INFO} currently exposes tools (from {@link ClusterInfoTools}). For any
+ * other command, or when no command was given, {@link #provideTools} returns {@code null}.
+ */
+public class AiServiceToolsProvider implements ToolProvider {
+
+    /** Command that selects the tool set; {@code null} when built with the no-arg constructor. */
+    private final ChatbotCommand chatbotCommand;
+
+    public AiServiceToolsProvider(ChatbotCommand chatbotCommand) {
+        this.chatbotCommand = chatbotCommand;
+    }
+
+    public AiServiceToolsProvider() {
+        this(null);
+    }
+
+    @Override
+    public ToolProviderResult provideTools(ToolProviderRequest toolProviderRequest) {
+        // Compare enum constants with == so a provider built without a command cannot throw an NPE here.
+        if (chatbotCommand == ChatbotCommand.INFO) {
+            ClusterInfoTools clusterInfoTools = new ClusterInfoTools();
+            return ToolProviderResult.builder().addAll(clusterInfoTools.list()).build();
+        }
+        // NOTE(review): null means "no tools"; confirm the AI service factory tolerates a null result.
+        return null;
+    }
+}
diff --git
a/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/ClusterInfoTools.java
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/ClusterInfoTools.java
new file mode 100644
index 00000000..3f654891
--- /dev/null
+++
b/bigtop-manager-server/src/main/java/org/apache/bigtop/manager/server/tools/ClusterInfoTools.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.bigtop.manager.server.tools;
+
+import org.apache.bigtop.manager.dao.po.ClusterPO;
+import org.apache.bigtop.manager.server.model.converter.ClusterConverter;
+
+import dev.langchain4j.agent.tool.Tool;
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.service.tool.ToolExecutor;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Builds the langchain4j tool definitions used to answer cluster-information questions.
+ */
+public class ClusterInfoTools {
+
+    /**
+     * Creates the {@code "list"} tool: its specification and an executor that renders the cluster list.
+     *
+     * <p>NOTE(review): {@code @Tool} normally marks a method the model invokes directly, but this method
+     * returns tool definitions instead. Confirm it is never registered via {@code AiServices#tools}.
+     *
+     * @return a single-entry map from the tool specification to its executor
+     */
+    @Tool("Get cluster list")
+    public Map<ToolSpecification, ToolExecutor> list() {
+        ToolSpecification spec = ToolSpecification.builder()
+                .name("list")
+                .description("Get cluster list")
+                .build();
+        ToolExecutor executor = (request, memoryId) ->
+                ClusterConverter.INSTANCE.fromPO2VO(placeholderClusters()).toString();
+        return Map.of(spec, executor);
+    }
+
+    /** Returns one hard-coded placeholder cluster (id 1, "mock-cluster"); no database is queried. */
+    private static List<ClusterPO> placeholderClusters() {
+        ClusterPO cluster = new ClusterPO();
+        cluster.setId(1L);
+        cluster.setName("mock-cluster");
+        List<ClusterPO> clusters = new ArrayList<>();
+        clusters.add(cluster);
+        return clusters;
+    }
+}
diff --git
a/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/controller/ChatbotControllerTest.java
b/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/controller/ChatbotControllerTest.java
index 71082b7f..8bc0cfc5 100644
---
a/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/controller/ChatbotControllerTest.java
+++
b/bigtop-manager-server/src/test/java/org/apache/bigtop/manager/server/controller/ChatbotControllerTest.java
@@ -110,7 +110,8 @@ class ChatbotControllerTest {
messageReq.setMessage("Hello");
SseEmitter emitter = new SseEmitter();
- when(chatbotService.talk(eq(threadId),
eq(messageReq.getMessage()))).thenReturn(emitter);
+ when(chatbotService.talk(eq(threadId), any(),
eq(messageReq.getMessage())))
+ .thenReturn(emitter);
SseEmitter result = chatbotController.talk(threadId, messageReq);