This is an automated email from the ASF dual-hosted git repository.

bowenliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/master by this push:
     new 24e87ef21 [KYUUBI #4556] [CHAT] Refactor ChatGPTProvider to use `openai-java` client
24e87ef21 is described below

commit 24e87ef21a4aeb3eaf2a951533b591b56d5f9a52
Author: liangbowen <[email protected]>
AuthorDate: Sun Mar 19 19:24:41 2023 +0800

    [KYUUBI #4556] [CHAT] Refactor ChatGPTProvider to use `openai-java` client
    
    ### _Why are the changes needed?_
    
    - Use the Java SDK `openai-java` for ChatGPT, which is popular and listed on the official OpenAI website: https://github.com/TheoKanning/openai-java
    - Keep ChatGPTProvider focused on the lifecycle, and avoid handling lower-level details such as POJO mapping and HTTP request handling.
    - Follow upstream changes from OpenAI.
    
    ### _How was this patch tested?_
    - [ ] Add some test cases that check the changes thoroughly, including negative and positive cases if possible
    
    - [x] Add screenshots for manual tests if appropriate
    
    - [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request
    
    Closes #4556 from bowenliang123/chatgpt-third.
    
    Closes #4556
    
    ecf1e2cf6 [liangbowen] manually add `openai-gpt3-java:*` and its dependencies to LICENSE-binary
    53b8375a5 [liangbowen] refactor ChatGPTProvider to use `openai-java` SDK
    
    Authored-by: liangbowen <[email protected]>
    Signed-off-by: liangbowen <[email protected]>
---
 LICENSE-binary                                     |  5 ++
 externals/kyuubi-chat-engine/pom.xml               |  5 +-
 .../engine/chat/provider/ChatGPTProvider.scala     | 87 ++++++++++------------
 pom.xml                                            |  7 ++
 4 files changed, 53 insertions(+), 51 deletions(-)

diff --git a/LICENSE-binary b/LICENSE-binary
index 92daf62ab..feab9965e 100644
--- a/LICENSE-binary
+++ b/LICENSE-binary
@@ -319,6 +319,8 @@ io.swagger.core.v3:swagger-models
 io.vertx:vertx-core
 io.vertx:vertx-grpc
 org.apache.zookeeper:zookeeper
+com.squareup.retrofit2:retrofit
+com.squareup.okhttp3:okhttp
 
 BSD
 ------------
@@ -356,6 +358,9 @@ org.codehaus.mojo:animal-sniffer-annotations
 org.slf4j:slf4j-api
 org.slf4j:jcl-over-slf4j
 org.slf4j:jul-over-slf4j
+com.theokanning.openai-gpt3-java:api
+com.theokanning.openai-gpt3-java:client
+com.theokanning.openai-gpt3-java:service
 
 kyuubi-server/src/main/resources/org/apache/kyuubi/ui/static/assets/fonts/*
 kyuubi-server/src/main/resources/org/apache/kyuubi/ui/static/icon.min.css
diff --git a/externals/kyuubi-chat-engine/pom.xml 
b/externals/kyuubi-chat-engine/pom.xml
index 7e2178918..28779f450 100644
--- a/externals/kyuubi-chat-engine/pom.xml
+++ b/externals/kyuubi-chat-engine/pom.xml
@@ -45,8 +45,9 @@
         </dependency>
 
         <dependency>
-            <groupId>org.apache.httpcomponents</groupId>
-            <artifactId>httpclient</artifactId>
+            <groupId>com.theokanning.openai-gpt3-java</groupId>
+            <artifactId>service</artifactId>
+            <version>${openai.java.version}</version>
         </dependency>
 
         <dependency>
diff --git 
a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
 
b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
index a4cdb7c94..2e4bf3f8d 100644
--- 
a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
+++ 
b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala
@@ -17,22 +17,20 @@
 
 package org.apache.kyuubi.engine.chat.provider
 
+import java.net.{InetSocketAddress, Proxy, URL}
+import java.time.Duration
 import java.util
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
 
 import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
-import org.apache.http.{HttpHeaders, HttpHost, HttpStatus}
-import org.apache.http.client.config.RequestConfig
-import org.apache.http.client.methods.HttpPost
-import org.apache.http.entity.{ContentType, StringEntity}
-import org.apache.http.impl.client.{CloseableHttpClient, HttpClientBuilder}
-import org.apache.http.message.BasicHeader
-import org.apache.http.util.EntityUtils
+import com.theokanning.openai.OpenAiApi
+import com.theokanning.openai.completion.chat.{ChatCompletionRequest, 
ChatMessage}
+import com.theokanning.openai.service.OpenAiService
+import com.theokanning.openai.service.OpenAiService.{defaultClient, 
defaultObjectMapper, defaultRetrofit}
 
 import org.apache.kyuubi.config.KyuubiConf
-import org.apache.kyuubi.engine.chat.provider.ChatProvider.mapper
 
 class ChatGPTProvider(conf: KyuubiConf) extends ChatProvider {
 
@@ -42,31 +40,32 @@ class ChatGPTProvider(conf: KyuubiConf) extends 
ChatProvider {
         s"which could be got at https://platform.openai.com/account/api-keys";)
   }
 
-  private val httpClient: CloseableHttpClient = {
-    HttpClientBuilder.create()
-      .setDefaultHeaders(List(
-        new BasicHeader(HttpHeaders.AUTHORIZATION, s"Bearer 
$gptApiKey")).asJava)
-      .build()
-  }
+  private val openAiService: OpenAiService = {
+    val builder = defaultClient(
+      gptApiKey,
+      
Duration.ofMillis(conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_SOCKET_TIMEOUT)))
+      .newBuilder
+      
.connectTimeout(Duration.ofMillis(conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_CONNECT_TIMEOUT)))
 
-  private val requestConfig: RequestConfig = {
-    val connectTimeout = 
conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_CONNECT_TIMEOUT).intValue()
-    val socketTimeout = 
conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_SOCKET_TIMEOUT).intValue()
-    val builder: RequestConfig.Builder = RequestConfig.custom()
-      .setConnectTimeout(connectTimeout)
-      .setSocketTimeout(socketTimeout)
-    conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_PROXY).foreach { url =>
-      builder.setProxy(HttpHost.create(url))
+    conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_PROXY) match {
+      case Some(httpProxyUrl) =>
+        val url = new URL(httpProxyUrl)
+        val proxy = new Proxy(Proxy.Type.HTTP, new 
InetSocketAddress(url.getHost, url.getPort))
+        builder.proxy(proxy)
+      case _ =>
     }
-    builder.build()
+
+    val retrofit = defaultRetrofit(builder.build(), defaultObjectMapper)
+    val api = retrofit.create(classOf[OpenAiApi])
+    new OpenAiService(api)
   }
 
-  private val chatHistory: LoadingCache[String, util.ArrayDeque[Message]] =
+  private val chatHistory: LoadingCache[String, util.ArrayDeque[ChatMessage]] =
     CacheBuilder.newBuilder()
       .expireAfterWrite(10, TimeUnit.MINUTES)
-      .build(new CacheLoader[String, util.ArrayDeque[Message]] {
-        override def load(sessionId: String): util.ArrayDeque[Message] =
-          new util.ArrayDeque[Message]
+      .build(new CacheLoader[String, util.ArrayDeque[ChatMessage]] {
+        override def load(sessionId: String): util.ArrayDeque[ChatMessage] =
+          new util.ArrayDeque[ChatMessage]
       })
 
   override def open(sessionId: String): Unit = {
@@ -75,29 +74,19 @@ class ChatGPTProvider(conf: KyuubiConf) extends 
ChatProvider {
 
   override def ask(sessionId: String, q: String): String = {
     val messages = chatHistory.get(sessionId)
-    messages.addLast(Message("user", q))
-
-    val request = new HttpPost("https://api.openai.com/v1/chat/completions";)
-    val req = Map(
-      "messages" -> messages,
-      "model" -> "gpt-3.5-turbo",
-      "max_tokens" -> 200,
-      "temperature" -> 0.5,
-      "top_p" -> 1)
-    val entity = new StringEntity(mapper.writeValueAsString(req), 
ContentType.APPLICATION_JSON)
-    request.setEntity(entity)
-    request.setConfig(requestConfig)
-    val response = httpClient.execute(request)
-    val respJson = mapper.readTree(EntityUtils.toString(response.getEntity))
-    response.getStatusLine.getStatusCode match {
-      case HttpStatus.SC_OK =>
-        val replyMessage = mapper.treeToValue[Message](
-          respJson.get("choices").get(0).get("message"))
-        messages.addLast(replyMessage)
-        replyMessage.content
-      case errorStatusCode =>
+    try {
+      messages.addLast(new ChatMessage("user", q))
+      val completionRequest = ChatCompletionRequest.builder()
+        .messages(messages.asScala.toList.asJava)
+        .model("gpt-3.5-turbo")
+        .build()
+      val responseText = 
openAiService.createChatCompletion(completionRequest).getChoices.asScala
+        .map(c => c.getMessage.getContent).mkString
+      responseText
+    } catch {
+      case e: Throwable =>
         messages.removeLast()
-        s"Chat failed. Status: $errorStatusCode. 
${respJson.get("error").get("message").asText}"
+        s"Chat failed. Error: ${e.getMessage}"
     }
   }
 
diff --git a/pom.xml b/pom.xml
index a2de5d29d..7f6d1ed71 100644
--- a/pom.xml
+++ b/pom.xml
@@ -174,6 +174,7 @@
         <log4j.version>2.20.0</log4j.version>
         <mysql.jdbc.version>8.0.32</mysql.jdbc.version>
         <netty.version>4.1.89.Final</netty.version>
+        <openai.java.version>0.11.1</openai.java.version>
         <parquet.version>1.10.1</parquet.version>
         <phoenix.version>6.0.0</phoenix.version>
         <prometheus.version>0.16.0</prometheus.version>
@@ -1658,6 +1659,12 @@
                 <artifactId>py4j</artifactId>
                 <version>${py4j.version}</version>
             </dependency>
+
+            <dependency>
+                <groupId>com.theokanning.openai-gpt3-java</groupId>
+                <artifactId>service</artifactId>
+                <version>${openai.java.version}</version>
+            </dependency>
         </dependencies>
     </dependencyManagement>
 

Reply via email to