This is an automated email from the ASF dual-hosted git repository.
shreemaanabhishek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix.git
The following commit(s) were added to refs/heads/master by this push:
new e775640f7 feat: ai-prompt-template plugin (#11517)
e775640f7 is described below
commit e775640f79923b4480283a3aea6486c3208dff82
Author: Shreemaan Abhishek <[email protected]>
AuthorDate: Thu Aug 29 13:28:53 2024 +0545
feat: ai-prompt-template plugin (#11517)
---
apisix/cli/config.lua | 1 +
apisix/plugins/ai-prompt-template.lua | 146 ++++++++++
conf/config.yaml.example | 1 +
docs/en/latest/config.json | 1 +
docs/en/latest/plugins/ai-prompt-template.md | 102 +++++++
t/admin/plugins.t | 1 +
t/plugin/ai-prompt-template.t | 403 +++++++++++++++++++++++++++
7 files changed, 655 insertions(+)
diff --git a/apisix/cli/config.lua b/apisix/cli/config.lua
index 94843621a..7f15542b1 100644
--- a/apisix/cli/config.lua
+++ b/apisix/cli/config.lua
@@ -213,6 +213,7 @@ local _M = {
"authz-keycloak",
"proxy-cache",
"body-transformer",
+ "ai-prompt-template",
"proxy-mirror",
"proxy-rewrite",
"workflow",
diff --git a/apisix/plugins/ai-prompt-template.lua
b/apisix/plugins/ai-prompt-template.lua
new file mode 100644
index 000000000..0a092c3f7
--- /dev/null
+++ b/apisix/plugins/ai-prompt-template.lua
@@ -0,0 +1,146 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+local core = require("apisix.core")
+local body_transformer = require("apisix.plugins.body-transformer")
+local ipairs = ipairs
+local type = type
+
+-- Schema of a single chat message inside a template: an object with a
+-- `role` (one of system/user/assistant) and a non-empty `content` string.
+-- `type = "object"` is required so that non-object array items (e.g. a bare
+-- string) are rejected: in JSON Schema, `properties` and `required` only
+-- constrain object instances, so any other value would pass unchecked.
+local prompt_schema = {
+    type = "object",
+    properties = {
+        role = {
+            type = "string",
+            enum = { "system", "user", "assistant" }
+        },
+        content = {
+            type = "string",
+            minLength = 1,
+        }
+    },
+    required = { "role", "content" }
+}
+
+-- Non-empty list of chat messages; used as a template's `messages` field.
+local prompts = {
+    items = prompt_schema,
+    minItems = 1,
+    type = "array",
+}
+
+-- One named template. Requests pick it through the `template_name` field of
+-- their JSON body; `template` is the request-body skeleton that is
+-- JSON-encoded and passed to body-transformer.
+-- NOTE(review): the plugin docs list `template.model` as required, but it is
+-- not in any `required` list here — confirm which behavior is intended.
+local template_item_schema = {
+    type = "object",
+    properties = {
+        name = {
+            type = "string",
+            minLength = 1,
+        },
+        template = {
+            type = "object",
+            properties = {
+                model = {
+                    type = "string",
+                    minLength = 1,
+                },
+                messages = prompts
+            }
+        }
+    },
+    required = {"name", "template"}
+}
+
+-- Plugin configuration: a non-empty array of templates.
+local schema = {
+    type = "object",
+    properties = {
+        templates = {
+            type = "array",
+            minItems = 1,
+            items = template_item_schema,
+        },
+    },
+    required = {"templates"},
+}
+
+
+-- Plugin descriptor. Priority 1060 sits between body-transformer (1080) and
+-- proxy-mirror (1010) in the plugin list of conf/config.yaml.example.
+local _M = {
+    name = "ai-prompt-template",
+    version = 0.1,
+    priority = 1060,
+    schema = schema,
+}
+
+-- Two caches, each configured with ttl = 300 and count = 256:
+-- one maps a template name to its template table, the other maps a template
+-- table to its JSON encoding.
+local templates_lrucache = core.lrucache.new({ count = 256, ttl = 300 })
+local templates_json_lrucache = core.lrucache.new({ count = 256, ttl = 300 })
+
+-- Validate a plugin configuration against `schema`, returning whatever
+-- core.schema.check returns.
+_M.check_schema = function(conf)
+    return core.schema.check(schema, conf)
+end
+
+
+-- Read the request body and decode it as a JSON object.
+-- Returns the decoded table on success, or nil plus an error table of the
+-- form { message = "..." }, which the caller sends back as a 400 body.
+local function get_request_body_table()
+    local body, err = core.request.get_body()
+    if not body then
+        -- `err` may be nil here (presumably when the request has no body —
+        -- confirm), so never concatenate it directly.
+        return nil, { message = "could not get body: "
+                                .. (err or "request body is empty") }
+    end
+
+    local body_tab, decode_err = core.json.decode(body)
+    if not body_tab then
+        -- Concatenate the decoder error into the message (a `,` here would
+        -- silently drop it into the table's array part).
+        return nil, { message = "could not parse JSON request body: "
+                                .. (decode_err or "unknown error") }
+    end
+
+    -- Only a JSON object can carry `template_name`; numbers, booleans or
+    -- JSON null would raise an error when indexed by the caller.
+    if type(body_tab) ~= "table" then
+        return nil, { message = "request body must be a JSON object" }
+    end
+
+    return body_tab
+end
+
+
+-- Return the `template` table of the first configured entry whose `name`
+-- equals `template_name`, or nil when no entry matches.
+local function find_template(conf, template_name)
+    local found
+    for _, item in ipairs(conf.templates) do
+        if item.name == template_name then
+            found = item.template
+            break
+        end
+    end
+    return found
+end
+
+-- Rewrite phase.
+-- Looks up the template named by the request body's `template_name` field.
+-- It then hands that template, JSON-encoded, to body-transformer with
+-- `input_format = "json"`, so the incoming body supplies the values for the
+-- template's `{{ ... }}` variables (as shown in the plugin docs and tests).
+--
+-- Returns 400 with a { message = ... } body when:
+--   * the body cannot be read or is not a JSON object;
+--   * `template_name` is missing or not a string;
+--   * no template with that name is configured.
+-- Otherwise it returns whatever body_transformer.rewrite returns.
+--
+-- NOTE(review): escaping of substituted values inside the JSON template is
+-- up to body-transformer's template engine. Confirm that quotes in user
+-- input cannot break or alter the generated JSON.
+function _M.rewrite(conf, ctx)
+    local body_tab, err = get_request_body_table()
+    if not body_tab then
+        return 400, err
+    end
+
+    local template_name = body_tab.template_name
+    if not template_name then
+        return 400, { message = "template name is missing in request." }
+    end
+    -- Configured names are strings (see schema). Other values can never
+    -- match, and tables / JSON null would raise when concatenated into the
+    -- "not configured" message below.
+    if type(template_name) ~= "string" then
+        return 400, { message = "template name must be a string." }
+    end
+
+    -- `conf` is passed as the cache version, presumably so routes that reuse
+    -- a template name with different configs do not share an entry.
+    local template = templates_lrucache(template_name, conf, find_template,
+                                        conf, template_name)
+    if not template then
+        return 400, { message = "template: " .. template_name .. " not configured." }
+    end
+
+    local template_json = templates_json_lrucache(template, template,
+                                                  core.json.encode, template)
+    core.log.info("sending template to body_transformer: ", template_json)
+    return body_transformer.rewrite(
+        {
+            request = {
+                template = template_json,
+                input_format = "json"
+            }
+        },
+        ctx
+    )
+end
+
+
+return _M
diff --git a/conf/config.yaml.example b/conf/config.yaml.example
index 5a490a4bb..5d22418ca 100644
--- a/conf/config.yaml.example
+++ b/conf/config.yaml.example
@@ -476,6 +476,7 @@ plugins: # plugin list (sorted by
priority)
#- error-log-logger # priority: 1091
- proxy-cache # priority: 1085
- body-transformer # priority: 1080
+ - ai-prompt-template # priority: 1060
- proxy-mirror # priority: 1010
- proxy-rewrite # priority: 1008
- workflow # priority: 1006
diff --git a/docs/en/latest/config.json b/docs/en/latest/config.json
index 928aec3b2..0998ec730 100644
--- a/docs/en/latest/config.json
+++ b/docs/en/latest/config.json
@@ -91,6 +91,7 @@
"plugins/proxy-rewrite",
"plugins/grpc-transcode",
"plugins/grpc-web",
+ "plugins/ai-prompt-template",
"plugins/fault-injection",
"plugins/mocking",
"plugins/degraphql",
diff --git a/docs/en/latest/plugins/ai-prompt-template.md
b/docs/en/latest/plugins/ai-prompt-template.md
new file mode 100644
index 000000000..9ca4e1f70
--- /dev/null
+++ b/docs/en/latest/plugins/ai-prompt-template.md
@@ -0,0 +1,102 @@
+---
+title: ai-prompt-template
+keywords:
+ - Apache APISIX
+ - API Gateway
+ - Plugin
+ - ai-prompt-template
+description: This document contains information about the Apache APISIX ai-prompt-template Plugin.
+---
+
+<!--
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+-->
+
+## Description
+
+The `ai-prompt-template` plugin simplifies access to LLM providers, such as OpenAI and Anthropic, and their models by predefining the request format
+in a template, so that users only pass custom values for the template variables.
+
+## Plugin Attributes
+
+| **Field**                             | **Required** | **Type** | **Description**                                                                                                             |
+| ------------------------------------- | ------------ | -------- | --------------------------------------------------------------------------------------------------------------------------- |
+| `templates`                           | Yes          | Array    | An array of template objects.                                                                                               |
+| `templates.name`                      | Yes          | String   | Name of the template.                                                                                                       |
+| `templates.template.model`            | Yes          | String   | Name of the LLM model, for example `gpt-4` or `gpt-3.5`. See your LLM provider's API documentation for more available models. |
+| `templates.template.messages.role`    | Yes          | String   | Role of the message (`system`, `user`, `assistant`).                                                                        |
+| `templates.template.messages.content` | Yes          | String   | Content of the message.                                                                                                     |
+
+## Example usage
+
+Create a route with the `ai-prompt-template` plugin like so:
+
+```shell
+curl "http://127.0.0.1:9180/apisix/admin/routes/1" -X PUT \
+ -H "X-API-KEY: ${ADMIN_API_KEY}" \
+ -d '{
+ "uri": "/v1/chat/completions",
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "api.openai.com:443": 1
+ },
+ "scheme": "https",
+ "pass_host": "node"
+ },
+ "plugins": {
+ "ai-prompt-template": {
+ "templates": [
+ {
+ "name": "level of detail",
+ "template": {
+ "model": "gpt-4",
+ "messages": [
+ {
+ "role": "user",
+ "content": "Explain about {{ topic }} in {{ level }}."
+ }
+ ]
+ }
+ }
+ ]
+ }
+ }
+ }'
+```
+
+Now send a request:
+
+```shell
+curl http://127.0.0.1:9080/v1/chat/completions -i -XPOST -H 'Content-Type: application/json' -d '{
+ "template_name": "level of detail",
+ "topic": "psychology",
+ "level": "brief"
+}' -H "Authorization: Bearer <your token here>"
+```
+
+Then the request body will be modified to something like this:
+
+```json
+{
+ "model": "gpt-4",
+ "messages": [
+ { "role": "user", "content": "Explain about psychology in brief." }
+ ]
+}
+```
diff --git a/t/admin/plugins.t b/t/admin/plugins.t
index 911205f48..547b1a316 100644
--- a/t/admin/plugins.t
+++ b/t/admin/plugins.t
@@ -93,6 +93,7 @@ opa
authz-keycloak
proxy-cache
body-transformer
+ai-prompt-template
proxy-mirror
proxy-rewrite
workflow
diff --git a/t/plugin/ai-prompt-template.t b/t/plugin/ai-prompt-template.t
new file mode 100644
index 000000000..050e0f246
--- /dev/null
+++ b/t/plugin/ai-prompt-template.t
@@ -0,0 +1,403 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+use t::APISIX 'no_plan';
+
+repeat_each(1);
+log_level('info');
+no_root_location();
+no_shuffle();
+
+# Give every test block that lacks an explicit `--- request` section the
+# default request `GET /t`, i.e. the location the blocks' `--- config` defines.
+add_block_preprocessor(sub {
+    my $block = shift;
+
+    $block->set_value("request", "GET /t")
+        unless $block->request;
+});
+
+run_tests();
+
+__DATA__
+
+=== TEST 1: sanity
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/echo",
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "127.0.0.1:1980": 1
+ }
+ },
+ "plugins": {
+ "ai-prompt-template": {
+ "templates":[
+ {
+ "name": "programming question",
+ "template": {
+ "model": "some model",
+ "messages": [
+ { "role": "system", "content":
"You are a {{ language }} programmer." },
+ { "role": "user", "content":
"Write a {{ program_name }} program." }
+ ]
+ }
+ },
+ {
+ "name": "level of detail",
+ "template": {
+ "model": "some model",
+ "messages": [
+ { "role": "user", "content":
"Explain about {{ topic }} in {{ level }}." }
+ ]
+ }
+ }
+ ]
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+}
+--- response_body
+passed
+
+
+
+=== TEST 2: no templates
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/echo",
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "127.0.0.1:1980": 1
+ }
+ },
+ "plugins": {
+ "ai-prompt-template": {
+ "templates":[]
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+}
+--- error_code: 400
+--- response_body eval
+qr/.*property \\"templates\\" validation failed: expect array to have at least
1 items.*/
+
+
+
+=== TEST 3: test template insertion
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local json = require("apisix.core.json")
+ local code, body, actual_resp = t('/echo',
+ ngx.HTTP_POST,
+ [[{
+ "template_name": "programming question",
+ "language": "python",
+ "program_name": "quick sort"
+ }]],
+ [[{
+ "model": "some model",
+ "messages": [
+ { "role": "system", "content": "You are a python
programmer." },
+ { "role": "user", "content": "Write a quick sort
program." }
+ ]
+ }]]
+ )
+ if code >= 300 then
+ ngx.status = code
+ ngx.say(body)
+ return
+ end
+ ngx.say("passed")
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 4: multiple templates
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/echo",
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "127.0.0.1:1980": 1
+ }
+ },
+ "plugins": {
+ "ai-prompt-template": {
+ "templates":[
+ {
+ "name": "programming question",
+ "template": {
+ "model": "some model",
+ "messages": [
+ { "role": "system", "content":
"You are a {{ language }} programmer." },
+ { "role": "user", "content":
"Write a {{ program_name }} program." }
+ ]
+ }
+ },
+ {
+ "name": "level of detail",
+ "template": {
+ "model": "some model",
+ "messages": [
+ { "role": "user", "content":
"Explain about {{ topic }} in {{ level }}." }
+ ]
+ }
+ }
+ ]
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+}
+--- response_body
+passed
+
+
+
+=== TEST 5: test second template
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local json = require("apisix.core.json")
+ local code, body, actual_resp = t('/echo',
+ ngx.HTTP_POST,
+ [[{
+ "template_name": "level of detail",
+ "topic": "psychology",
+ "level": "brief"
+ }]],
+ [[{
+ "model": "some model",
+ "messages": [
+ { "role": "user", "content": "Explain about
psychology in brief." }
+ ]
+ }]]
+ )
+ if code >= 300 then
+ ngx.status = code
+ ngx.say(body)
+ return
+ end
+ ngx.say("passed")
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 6: missing template items
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local json = require("apisix.core.json")
+ local code, body, actual_resp = t('/echo',
+ ngx.HTTP_POST,
+ [[{
+ "template_name": "level of detail",
+ "topic-missing": "psychology",
+ "level-missing": "brief"
+ }]],
+ [[{
+ "model": "some model",
+ "messages": [
+ { "role": "user", "content": "Explain about in ."
}
+ ]
+ }]]
+ )
+ if code >= 300 then
+ ngx.status = code
+ ngx.say(body)
+ return
+ end
+ ngx.say("passed")
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 7: request body contains non-existent template
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local json = require("apisix.core.json")
+ local code, body, actual_resp = t('/echo',
+ ngx.HTTP_POST,
+ [[{
+ "template_name": "random",
+ "some-key": "some-value"
+ }]]
+ )
+ if code >= 300 then
+ ngx.status = code
+ ngx.say(body)
+ return
+ end
+ ngx.say("passed")
+ }
+ }
+--- error_code: 400
+--- response_body eval
+qr/.*template: random not configured.*/
+
+
+
+=== TEST 8: request body is missing template_name
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local json = require("apisix.core.json")
+ local code, body, actual_resp = t('/echo',
+ ngx.HTTP_POST,
+ [[{
+ "missing-template-name": "haha"
+ }]]
+ )
+ if code >= 300 then
+ ngx.status = code
+ ngx.say(body)
+ return
+ end
+ ngx.say("passed")
+ }
+ }
+--- error_code: 400
+--- response_body eval
+qr/.*template name is missing in request.*/
+
+
+
+=== TEST 9: (cache test) same template name in different routes
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ for i = 1, 5, 1 do
+ local code = t('/apisix/admin/routes/' .. i,
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/]] .. i .. [[",
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "127.0.0.1:1980": 1
+ }
+ },
+ "plugins": {
+ "ai-prompt-template": {
+ "templates":[
+ {
+ "name": "same name",
+ "template": {
+ "model": "some model",
+ "messages": [
+ { "role": "system", "content":
"Field: {{ field }} in route]] .. i .. [[." }
+ ]
+ }
+ }
+ ]
+ },
+ "proxy-rewrite": {
+ "uri": "/echo"
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ ngx.say("failed")
+ return
+ end
+ end
+
+ for i = 1, 5, 1 do
+ local code, body = t('/' .. i,
+ ngx.HTTP_POST,
+ [[{
+ "template_name": "same name",
+ "field": "foo"
+ }]],
+ [[{
+ "model": "some model",
+ "messages": [
+ { "role": "system", "content": "Field: foo in
route]] .. i .. [[." }
+ ]
+ }]]
+ )
+ if code >= 300 then
+ ngx.status = code
+ ngx.say(body)
+ return
+ end
+ end
+ ngx.status = 200
+ ngx.say("passed")
+ }
+ }
+
+--- response_body
+passed