Compare commits

...

203 Commits

Author SHA1 Message Date
takatost
1387f9b23e version to 0.5.11 (#3038) 2024-03-29 21:09:21 +08:00
takatost
6817eab5f1 fix: api / moderation extension import error (#3037) 2024-03-29 21:07:34 +08:00
zxhlyh
218f591a5d fix: prompt editor linebreak (#3036) 2024-03-29 21:01:04 +08:00
Richards Tu
17af0de7b6 Add New Tool: StackExchange (#3034)
Co-authored-by: crazywoola <427733928@qq.com>
2024-03-29 20:28:21 +08:00
Chenhe Gu
9d962053a2 Fix claude request errors in bedrock (#3015)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-03-29 13:57:45 +08:00
kun321
59909b5ca7 update the invalid Discord invite link (#3028) 2024-03-29 13:16:52 +08:00
Jyong
a6cd0f0e73 fix adding a segment when the dataset and document are empty (#3021)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-29 13:06:00 +08:00
Richards Tu
2c43393bf1 Add New Tool: DevDocs (#2993) 2024-03-29 11:21:02 +08:00
Jyong
669c8c3cca some optimizations for the admin API key, create-tenant and reset-encrypt-key-pair commands (#3013)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-28 17:02:52 +08:00
Jyong
b0b0cc045f add multi-thread document embedding (#3016)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-28 17:02:35 +08:00
crazywoola
20d16d7b31 doc: update helm charts (#3012) 2024-03-28 13:02:41 +08:00
Ricky
714722bb2d fix: 'next' button unresponsive when uploading additional documents before previous batch completes (#2991) 2024-03-28 12:28:15 +08:00
Bowen Liang
830495a607 bump celery from 5.2 to 5.3 (#2478)
Co-authored-by: takatost <takatost@users.noreply.github.com>
2024-03-28 11:53:48 +08:00
Bowen Liang
41a4593b6d bump redis client to 5.0 and enable hiredis support (#2518) 2024-03-28 11:40:21 +08:00
Bowen Liang
08b727833e generalize helper for loading module from source (#2862) 2024-03-28 11:37:26 +08:00
aqachun
c8b82b9d08 fix: missing comma in JSON for /completion-messages request (#2999) 2024-03-27 14:31:06 +08:00
Weaxs
5becb4c43a update wenxin llm (#2929) 2024-03-27 11:36:21 +08:00
Kenny
13694293e3 fix: resolve 'header.uid' length must be less than or equal to 32 on Spark V1.5 (#2983) 2024-03-27 09:58:41 +08:00
Ricky
815beac356 Fix the time format in the annotation from 12-hour clock to 24-hour clock. (#2990) 2024-03-27 09:08:38 +08:00
legao
5e60204832 fix: progress bar issue (#2957) 2024-03-26 17:26:58 +08:00
legao
d2624b13a0 fix: the issue of text overflow in the NavSelector component (#2976) 2024-03-26 17:22:01 +08:00
zxhlyh
61f5de9662 fix: chat scroll (#2981) 2024-03-26 16:19:41 +08:00
Ricky
40dbf30784 feat: support new reranker [jina-colbert-v1-en] (#2975) 2024-03-26 11:34:40 +08:00
Ricky
afd77c4745 fix: the batch annotation btn should also be loading when progress status is waiting (#2974) 2024-03-26 11:05:29 +08:00
listeng
d70bd4aaa4 fix tool_inputs parse error in message that in CoT(ReAct) agent mode (#2949) 2024-03-26 11:05:10 +08:00
Yulong Wang
8e05261588 Fix handling of missing required parameters in ApiTool (#2965) 2024-03-26 10:53:39 +08:00
Weishan-0
a676d4387c fix: Correct image parameter passing in GLM-4v model API calls (#2948) 2024-03-26 10:43:20 +08:00
Kenny
08a5afcf9f feat: update nginx and docker-compose files to support HTTPS. (#2940)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-03-26 10:37:43 +08:00
crazywoola
eeaa3c1643 Fix/2969 add model provider ollama not work (#2973) 2024-03-26 10:26:34 +08:00
Leo Q
7c8c233cf4 Add S3_ADDRESS_STYLE configuration option (#2934) 2024-03-26 10:18:26 +08:00
Bowen Liang
129a9850eb fix: correct response hint for generated image to avoid illusion of regenerated image link (#2962) 2024-03-26 10:13:35 +08:00
Bowen Liang
1f98a4fff3 improve: cache tool icons by setting the max-age HTTP header and enabling gzip compression for SVG icons from the backend (#2971) 2024-03-26 10:11:43 +08:00
Ricky
58e4702b14 fix: white screen when editing annotation in log panel (#2968) 2024-03-26 10:10:14 +08:00
colvin777
c60749678b When disabling the "Annotation Reply" button, the backend reports an error. #2904 (#2933)
Co-authored-by: colvin <colvin.zhang@boaocloud.com>
2024-03-25 22:20:40 +08:00
legao
d5214e4644 reuse layout (#2956) 2024-03-25 15:13:50 +08:00
legao
52804ca6d1 fix: adjust popup panel's z-index value (#2952) 2024-03-25 15:09:01 +08:00
orangeclk
4fb9606361 fix: max_token default help info improved (#2951) 2024-03-25 10:07:32 +08:00
orangeclk
c534d95972 fix: yi model price correction (#2946) 2024-03-24 12:10:57 +08:00
Nanguan Lin
46ccfda493 fix: invalid i18 link in README (#2947) 2024-03-24 12:10:13 +08:00
orangeclk
6dc62334d6 doc: model schema document fix and wording about the model price parameter (#2944) 2024-03-24 12:06:20 +08:00
wangkehan
c7d003d551 fix: upgrade duckduckgo-search to version 5.1.0 & fix document segment API parameter error (#2938) 2024-03-22 19:18:01 +08:00
Leo Q
cc754122fc Authentication is only applied when both the username and password have values. (#2937) 2024-03-22 17:58:21 +08:00
Yeuoly
240a94182e Feat/add triton inference server (#2928) 2024-03-22 15:15:48 +08:00
Kenny
16af509c46 Update docker-compose files version (#2920) 2024-03-21 15:16:30 +08:00
Jyong
86e474fff1 Add azure blob storage support (#2919)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-20 20:49:58 +08:00
Joel
9a3d5729bb fix: suggest service api missed user in query (#2918) 2024-03-20 20:08:26 +08:00
Su Yang
5a1c29fd8c chore: change Yi model SDK to OpenAI (#2910) 2024-03-20 16:02:13 +08:00
Qiwen Tong
180775a0ec fix: init qdrant vector max recursion (#2909) 2024-03-20 14:57:13 +08:00
Bowen Liang
d018e279f8 fix: typo $ mark in logs of vdb migrate command (#2901) 2024-03-19 22:21:58 +08:00
takatost
11636bc7c7 bump version to 0.5.10 (#2902) 2024-03-19 21:35:58 +08:00
Joshua
518c1ceb94 Feat/add-NVIDIA-as-a-new-model-provider (#2900) 2024-03-19 21:08:17 +08:00
listeng
696efe494e fix: ignore some empty page_content when appending to split_documents (#2898) 2024-03-19 20:55:15 +08:00
Su Yang
4419d357c4 chore: update Yi models params (#2895) 2024-03-19 20:54:31 +08:00
takatost
fbbba6db92 feat: optimize ollama model default parameters (#2894) 2024-03-19 18:34:23 +08:00
Lance Mao
53d428907b fix incorrect exception raised by api tool which leads to incorrect L… (#2886)
Co-authored-by: OSS-MAOLONGDONG\kaihong <maolongdong@kaihong.com>
2024-03-19 18:17:12 +08:00
Su Yang
8133ba16b1 chore: update Qwen model params (#2892) 2024-03-19 18:13:32 +08:00
crazywoola
e9aa0e89d3 chore: update pr template (#2893) 2024-03-19 17:24:57 +08:00
Su Yang
7e3c59e53e chore: Update TongYi models prices (#2890) 2024-03-19 16:32:42 +08:00
呆萌闷油瓶
f6314f8e73 feat:support azure openai llm 0125 version (#2889) 2024-03-19 16:32:26 +08:00
Su Yang
3bcfd84fba chore: use API Key instead of APIKey (#2888) 2024-03-19 16:32:06 +08:00
Bowen Liang
7c0ae76cd0 Bump tiktoken to 0.6.0 to support text-embedding-3-* in encoding_for_model (#2891) 2024-03-19 16:31:46 +08:00
Su Yang
2dee8a25d5 fix: anthropic system prompt not working (#2885) 2024-03-19 15:50:02 +08:00
Su Yang
507aa6d949 fix: the problem of system not working (#2884) 2024-03-19 13:56:22 +08:00
crazywoola
59f173f2e6 feat: add icons for 01.ai (#2883) 2024-03-19 13:53:21 +08:00
Su Yang
c3790c239c i18n: update bedrock label (#2879) 2024-03-19 00:57:19 +08:00
Su Yang
45e51e7730 feat: AWS Bedrock Claude3 (#2864)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
2024-03-18 18:16:36 +08:00
Jyong
4834eae887 fix enabling annotation reply when the collection is None (#2877)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-18 17:18:52 +08:00
Yeuoly
01108e6172 fix/Add isModel flag to AgentTools component (#2876) 2024-03-18 17:01:25 +08:00
Yeuoly
95b74c211d Feat/support tool credentials bool schema (#2875) 2024-03-18 16:55:26 +08:00
Onelevenvy
cb79a90031 feat: Add tools for open weather search and image generation using the Spark API. (#2845) 2024-03-18 16:22:48 +08:00
Onelevenvy
4502436c47 feat: embedding model support for the Aliyun dashscope text-embedding-v1 and text-embedding-v2 (#2874) 2024-03-18 15:21:26 +08:00
Jyong
c3d0cf940c add tenant id index for document and document_segment tables (#2873)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-18 14:34:32 +08:00
orangeclk
e7343cc67c add max_tokens parameter rule for zhipuai glm4 and glm4v (#2861) 2024-03-18 13:19:36 +08:00
VoidIsVoid
83145486b0 fix: unstable function call responses with missing arguments (#2872) 2024-03-18 13:17:16 +08:00
Su Yang
6fd1795d25 feat: Allow users to specify AWS Bedrock validation models (#2857) 2024-03-18 00:44:09 +08:00
Su Yang
f770232b63 feat: add model for 01.ai, yi-chat-34b series (#2865) 2024-03-17 21:24:01 +08:00
Bowen Liang
a8e694c235 fix: print exception logs for ValueError and InvokeError (#2823) 2024-03-17 14:34:32 +08:00
Eric Wang
15a6d94953 Refactor: Streamline the build-push and deploy-dev workflow (#2852) 2024-03-17 14:20:34 +08:00
crazywoola
056331981e fix: api doc duplicate symbols (#2853) 2024-03-15 18:17:43 +08:00
Yeuoly
cef16862da fix: charts encoding (#2848) 2024-03-15 14:02:52 +08:00
Rozstone
8a4015722d prevent auto-scrolling to the bottom when the user has already scrolled up (#2813) 2024-03-15 13:19:06 +08:00
crazywoola
156345cb4b fix: use supported languages only for install form (#2844) 2024-03-15 12:05:35 +08:00
Yeuoly
f29280ba5c Fix/compatible to old tool config (#2839) 2024-03-15 11:44:24 +08:00
Yeuoly
742be06ea9 Fix/localai (#2840) 2024-03-15 11:41:51 +08:00
crazywoola
af98954fc1 Feat/add script to check i18n keys (#2835) 2024-03-14 18:03:59 +08:00
David
4d63770189 fix: the generated conversation name was not saved (#2836) 2024-03-14 17:53:55 +08:00
Yeuoly
bbea3a6b84 fix: compatible to old tool config (#2837) 2024-03-14 17:51:11 +08:00
Bowen Liang
19d3a56194 feat: add weekday calculator in time tool (#2822) 2024-03-14 17:01:48 +08:00
ChiayenGu
5cab2b711f fix: doc for datasets (#2831) 2024-03-14 16:41:40 +08:00
Qun
1e5455e266 enhance: use override_settings for concurrent stable diffusion (#2818) 2024-03-14 15:26:07 +08:00
Eric Wang
4fe585acc2 feat(llm/models): add claude-3-haiku-20240307 (#2825) 2024-03-14 10:08:24 +08:00
呆萌闷油瓶
e52448b84b feat:add api-version selection for azure openai APIs (#2821) 2024-03-14 09:14:27 +08:00
crazywoola
1f92b55f58 fix: doc for completion-messages (#2820) 2024-03-13 22:25:18 +08:00
Bowen Liang
8b15b742ad generalize position helper for parsing _position.yaml and sorting objects by name (#2803) 2024-03-13 20:29:38 +08:00
Laurent Magnien
849dc0560b feat: add French fr-FR (#2810)
Co-authored-by: Laurent Magnien <laurent.magnien@adsn.fr>
2024-03-13 18:20:55 +08:00
Phạm Viết Nghĩa
a026c5fd08 feat: add Vietnamese vi-VN (#2807) 2024-03-13 15:54:47 +08:00
Charlie.Wei
fd7aade26b Fix tts api err (#2809)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-03-13 15:38:10 +08:00
Mark Sun
510f8ede10 Improve automatic prompt generation (#2805) 2024-03-13 14:10:47 +08:00
呆萌闷油瓶
8f9125b08a fix:typo (#2808) 2024-03-13 13:00:46 +08:00
呆萌闷油瓶
e5e97c0a0a fix:change azure openai api_version default value to 2024-02-15-preview (#2797) 2024-03-12 22:07:06 +08:00
Yulong Wang
870ca713df Refactor Markdown component to include paragraph after image (#2798) 2024-03-12 22:06:54 +08:00
Joshua
6854a3fd26 Update README.md (#2800) 2024-03-12 18:14:07 +08:00
Joshua
620360d41a Update README.md (#2799) 2024-03-12 17:02:46 +08:00
Weaxs
20bd49285b excel: get keys from every sheet (#2796) 2024-03-12 16:59:25 +08:00
crazywoola
6bd2730317 Fix/2770 suggestions for next steps (#2788) 2024-03-12 16:27:55 +08:00
Yeuoly
f734cca337 enhance: add stable diffusion user guide (#2795) 2024-03-12 14:45:48 +08:00
takatost
ce5b19d011 bump version to 0.5.9 (#2794) 2024-03-12 14:01:24 +08:00
Bowen Liang
f82a64d149 feat: add DingTalk(钉钉) tool for sending message to chat group bot via webhook (#2693) 2024-03-12 13:45:59 +08:00
呆萌闷油瓶
f49b1afd6c feat:support azure tts (#2751) 2024-03-12 12:06:35 +08:00
Jyong
796c5626a7 fix deleting a dataset when the dataset has no documents (#2789)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-11 23:57:38 +08:00
Jyong
e54c9cd401 Feat/open ai compatible functioncall (#2783)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-11 19:48:21 +08:00
Yeuoly
f8951d7f57 fix: api tool provider not found (#2782) 2024-03-11 18:21:41 +08:00
Jyong
6454e1d644 chunk-overlap None check (#2781)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-11 15:36:56 +08:00
crazywoola
e184c8cb42 Update README.md (#2780) 2024-03-11 14:53:40 +08:00
Eric Wang
fdd211e399 debug/chat: increase notify error duration to 3000 (#2778) 2024-03-11 14:16:31 +08:00
Eric Wang
7001e21e7d overview: fix filter today calc start & end (#2777) 2024-03-11 14:11:51 +08:00
Yeuoly
82d0732c12 fix: aippt default styles (#2779) 2024-03-11 14:04:09 +08:00
Yeuoly
53cd125780 fix: deep copy of model-tool label (#2775) 2024-03-11 10:27:00 +08:00
crazywoola
3c91f9b5ab fix: dataset segments api (#2766) 2024-03-11 09:26:15 +08:00
takatost
f073dca22a feat: optimize db connection when llm invoking (#2774) 2024-03-10 15:48:31 +08:00
crazywoola
8b1e35d7dc doc: add suggested questions back (#2771) 2024-03-10 15:40:17 +08:00
Rozstone
b75d8ca621 fix: auto-closing when closing local image upload (#2767) 2024-03-10 13:11:41 +08:00
zxhlyh
9beefd7d5a fix: auto prompt (#2768) 2024-03-09 18:36:58 +08:00
Vikey Chen
88145efa97 fix: app name can be empty in settings modal (#2761) 2024-03-09 09:13:12 +08:00
Laurent Magnien
bdc13f9238 SMTP authentication is optional (#2765)
Co-authored-by: Laurent Magnien <laurent.magnien@adsn.fr>
2024-03-09 09:11:03 +08:00
Yeuoly
ce58f0607b Feat/tool secret parameter (#2760) 2024-03-08 20:31:13 +08:00
crazywoola
bbc0d330a9 chore: rename lastStep to previousStep (#2759) 2024-03-08 19:27:02 +08:00
洪朔
60e7e17c86 feat: Add new Azure OpenAI Embedding models (#2758) 2024-03-08 19:04:20 +08:00
Vikey Chen
237bb8514e replace message content type from list to string when file_objs is empty (#2745) 2024-03-08 18:46:31 +08:00
yoogo
bd26c933d2 fix: valid password on reset-password page (#2753) 2024-03-08 18:44:49 +08:00
Yeuoly
b6b58da2d2 enhance: custom tool timeout (#2754) 2024-03-08 15:26:08 +08:00
Yeuoly
40c646cf7a Feat/model as tool (#2744) 2024-03-08 15:22:55 +08:00
Yeuoly
3231a8c51c fix: image tokenizer (#2752) 2024-03-08 14:50:51 +08:00
Bowen Liang
4170d6a491 use SVG icons for built-in tools (#2748) 2024-03-08 10:21:26 +08:00
Bowen Liang
0b50c525cf feat: support error correction and border size in qrcode tool (#2731) 2024-03-07 20:54:14 +08:00
Jyong
8ba38e8e74 fix overlap and splitter optimization (#2742)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-07 18:25:49 +08:00
Bowen Liang
b163545771 Use python-docx to extract docx files (#2654) 2024-03-07 18:24:55 +08:00
Yash Parmar
c0b82f8e58 UPDATE: Twilio tool credential verification (#2741) 2024-03-07 18:08:52 +08:00
呆萌闷油瓶
b75ff5fa03 fix:missing import (#2739) 2024-03-07 17:31:30 +08:00
crazywoola
9440d7fe88 fix: the behavior of the save action when opening the config panel (#2736) 2024-03-07 16:48:44 +08:00
Yeuoly
24809fce07 fix: missing en_name of aippt (#2737) 2024-03-07 16:37:12 +08:00
呆萌闷油瓶
9819ad347f feat: support azure whisper model and fix: rename text-embedidng-ada-002.yaml to text-embedding-ada-002.yaml (#2732) 2024-03-07 16:36:58 +08:00
Yeuoly
8fe83750b7 Fix/jina tokenizer cache (#2735) 2024-03-07 16:32:37 +08:00
Yeuoly
1809f05904 Feat/add groq (#2733) 2024-03-07 16:00:40 +08:00
Bowen Liang
0ac250a035 fix: check webhook key of Wecom tool in valid UUID form and fix typo (#2719) 2024-03-07 15:51:06 +08:00
taokuizu
405a00bb2c fix: delete the trailing slash at the end of the xinference provider server_url (#2730) 2024-03-07 15:37:05 +08:00
Yeuoly
3a3ca8e6a9 fix: max tokens can only go up to 2048 (#2734) 2024-03-07 15:35:56 +08:00
Yeuoly
27e678480e Feat: AIPPT & DynamicToolParamter (#2725) 2024-03-07 15:04:42 +08:00
Lance Mao
7052565380 fix typo: responsing -> responding (#2718)
Co-authored-by: OSS-MAOLONGDONG\kaihong <maolongdong@kaihong.com>
2024-03-07 10:20:35 +08:00
Jyong
31070ffbca fix qa index processor tenant id is None error (#2713)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-06 16:46:08 +08:00
Jyong
7f3dec7bee fix error msg format issue (#2715)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-06 16:45:40 +08:00
Joel
b1e0db4944 fix: chatbot service api auto generate name default value error (#2709) 2024-03-06 13:19:27 +08:00
Rhon Joe
c439952a41 fix(web): chat input auto resize by window (#2696) 2024-03-06 12:49:22 +08:00
Yash Parmar
2f28afebb6 FEAT: Add twilio tool for sending text and whatsapp messages (#2700)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-03-06 11:35:08 +08:00
Charlie.Wei
fa7ba30ba3 Fix rebuild index & CSV parsing (#2705)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-03-06 11:33:32 +08:00
Bowen Liang
1cf5f510ed feat: add qrcode tool for QR code generation (#2699) 2024-03-06 11:26:16 +08:00
Joshua
526c874caa fix mistralai icon (#2707) 2024-03-06 11:08:22 +08:00
Bowen Liang
f88f744097 make volume folders for milvus docker containers ignored by git (#2694) 2024-03-05 17:26:21 +08:00
Yeuoly
95733796f0 fix: replace os.path.join with yarl (#2690) 2024-03-05 17:25:20 +08:00
Bowen Liang
552f319b9d feat: support HTTP response compression in api server (#2680) 2024-03-05 14:45:22 +08:00
Yeuoly
38e5952417 Fix/agent react output parser (#2689) 2024-03-05 14:02:07 +08:00
Yash Parmar
7f891939f1 FEAT: add tavily tool for searching... A search engine for LLM (#2681) 2024-03-05 10:23:44 +08:00
Charlie.Wei
69a5ce1e31 Fix tts play logic (#2683)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-03-05 09:22:36 +08:00
takatost
534802b761 bump version to 0.5.8 (#2685) 2024-03-05 01:37:53 +08:00
takatost
5c258e212c feat: add Anthropic claude-3 models support (#2684) 2024-03-05 01:37:42 +08:00
Charlie.Wei
6a6133c102 Fix voice selection (#2664)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-03-04 17:50:06 +08:00
Joel
3c1825187a fix: auto generate prompt result not show (#2678) 2024-03-04 17:36:11 +08:00
Joshua
8523b34be7 add jina-reranker-v1-base-en (#2676) 2024-03-04 17:31:01 +08:00
Bowen Liang
65cfd4360a fix: typo in wecom tool (#2674) 2024-03-04 17:25:42 +08:00
Joel
bbf5f42c87 fix: CE edition limits upload file nums (#2677) 2024-03-04 17:25:31 +08:00
Jyong
3631e53ff0 Feat/add annotation migrate (#2675)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-04 17:22:06 +08:00
waltcow
f322d9bddb Fix vdb merge error (#2650) 2024-03-04 16:35:50 +08:00
Yeuoly
05ce7b9d5e fix: deep copy customColletion (#2673) 2024-03-04 15:20:20 +08:00
Yeuoly
72ddedfc5c fix: setup default filters while add credentials (#2669) 2024-03-04 14:17:00 +08:00
Yeuoly
36686d7425 fix: test custom tool already exists without decrypting credentials (#2668) 2024-03-04 14:16:47 +08:00
cola
34387ec0f1 fix typo recale to recalc (#2670) 2024-03-04 14:15:53 +08:00
Chenhe Gu
83a6b0c626 Doc/update license (#2666) 2024-03-04 14:10:39 +08:00
takatost
76da66fb7e fix: import from explore apps error when OpenAI is not initialized (#2671) 2024-03-04 14:06:54 +08:00
nan jiang
607f9eda35 Fix/app runner typo (#2661) 2024-03-04 13:32:17 +08:00
Bowen Liang
f25cec265d feat: add Wecom(企业微信) tool for sending message to chat group bot via webhook (#2638) 2024-03-04 10:27:20 +08:00
Garfield Dai
8e66b96221 Feat: Add documents limitation (#2662) 2024-03-03 12:45:06 +08:00
crazywoola
b5c1bb346c Add PubMed to tools (#2652) 2024-03-03 12:44:13 +08:00
Yeuoly
e94b323e6c fix: use English as the default i18n language (#2663) 2024-03-03 12:35:28 +08:00
nan jiang
bc65ee10c0 bugfix: model str may be empty (#2660) 2024-03-03 11:43:38 +08:00
Rozstone
2001483659 fix: default to all categories when search params are not from recommended (#2653) 2024-03-02 17:11:25 +08:00
crazywoola
444aba55dd Feat/jpn support (#2651) 2024-03-02 13:47:51 +08:00
Joel
3f640b1037 fix: click tool item in app debug page would show detail (#2644) 2024-03-01 18:47:12 +08:00
Yeuoly
b07084711c fix: missing description (#2643) 2024-03-01 18:19:04 +08:00
Joel
fa8ab2134f feat: displaying the tool description when clicking on a custom tool (#2642) 2024-03-01 17:58:38 +08:00
takatost
1a677da792 fix: custom tool max tool (#2641) 2024-03-01 16:43:47 +08:00
taokuizu
b6d61a818e fix: Replace path.join with urljoin. (#2631) 2024-03-01 13:07:15 +08:00
Bowen Liang
8495ffaa45 fix: typo in gaode tool (#2636) 2024-03-01 10:12:48 +08:00
Yash Parmar
dbd1d79770 FEAT: Add arxiv tool for searching scientific papers and articles fro… (#2632) 2024-02-29 19:46:10 +08:00
takatost
1910178199 fix: default mail type invalid in .env.example (#2628) 2024-02-29 17:29:48 +08:00
Bowen Liang
839a6a2c8a add logs for vdb-migrate command (#2626) 2024-02-29 16:24:51 +08:00
Yeuoly
a769edbc89 Fix/custom tool any of (#2625) 2024-02-29 14:39:05 +08:00
Yeuoly
57ffecd0e5 fix: remove unnecessary credentials of custom tool (#2621) 2024-02-29 12:58:12 +08:00
Bowen Liang
801d135390 generalize the generation of new collection name by dataset id (#2620) 2024-02-29 12:47:10 +08:00
Bowen Liang
0428f44113 chore: bump superlinter action from v5 to v6 (#2325) 2024-02-29 12:45:06 +08:00
zxhlyh
7beff3fd5a fix: model parameter load presets config (#2622) 2024-02-29 12:43:46 +08:00
takatost
88a095e40e fix: wrong default model parameters when creating app (#2623) 2024-02-29 12:43:07 +08:00
takatost
dd961985f0 refactor: remove unused code, move core/agent module into dataset retrieval feature (#2614) 2024-02-28 23:32:47 +08:00
Yeuoly
d44b05a9e5 feat: support auth types like basic, bearer and custom (#2613) 2024-02-28 23:19:08 +08:00
532 changed files with 19191 additions and 4005 deletions

View File

@@ -12,6 +12,8 @@ Please delete options that are not relevant.
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
- [ ] Improvement (including but not limited to code refactoring, performance optimization, and UI/UX improvement)
- [ ] Dependency upgrade
# How Has This Been Tested?

View File

@@ -1,17 +1,32 @@
name: Build and Push API Image
name: Build and Push API & Web
on:
push:
branches:
- 'main'
- 'deploy/dev'
- "main"
- "deploy/dev"
release:
types: [ published ]
types: [published]
env:
DOCKERHUB_USER: ${{ secrets.DOCKERHUB_USER }}
DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
DIFY_WEB_IMAGE_NAME: ${{ vars.DIFY_WEB_IMAGE_NAME || 'langgenius/dify-web' }}
DIFY_API_IMAGE_NAME: ${{ vars.DIFY_API_IMAGE_NAME || 'langgenius/dify-api' }}
jobs:
build-and-push:
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
strategy:
matrix:
include:
- service_name: "web"
image_name_env: "DIFY_WEB_IMAGE_NAME"
context: "web"
- service_name: "api"
image_name_env: "DIFY_API_IMAGE_NAME"
context: "api"
steps:
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
@@ -22,14 +37,14 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USER }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v5
with:
images: langgenius/dify-api
images: ${{ env[matrix.image_name_env] }}
tags: |
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
@@ -39,22 +54,11 @@ jobs:
- name: Build and push
uses: docker/build-push-action@v5
with:
context: "{{defaultContext}}:api"
context: "{{defaultContext}}:${{ matrix.context }}"
platforms: ${{ startsWith(github.ref, 'refs/tags/') && 'linux/amd64,linux/arm64' || 'linux/amd64' }}
build-args: |
COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Deploy to server
if: github.ref == 'refs/heads/deploy/dev'
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
${{ secrets.SSH_SCRIPT }}
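The two per-image workflows are consolidated here: a build matrix over `web` and `api` supplies the Docker context and target image name for each job, the Docker Hub credentials move into `env`, and the SSH deploy step is removed from this workflow entirely; it reappears in the new `deploy-dev.yml` below.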

View File

@@ -1,60 +0,0 @@
name: Build and Push WEB Image
on:
push:
branches:
- 'main'
- 'deploy/dev'
release:
types: [ published ]
jobs:
build-and-push:
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USER }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v5
with:
images: langgenius/dify-web
tags: |
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
- name: Build and push
uses: docker/build-push-action@v5
with:
context: "{{defaultContext}}:web"
platforms: ${{ startsWith(github.ref, 'refs/tags/') && 'linux/amd64,linux/arm64' || 'linux/amd64' }}
build-args: |
COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Deploy to server
if: github.ref == 'refs/heads/deploy/dev'
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
${{ secrets.SSH_SCRIPT }}

.github/workflows/deploy-dev.yml (new file, 24 lines)
View File

@@ -0,0 +1,24 @@
name: Deploy Dev
on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/dev"
types:
- completed
jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
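This new workflow picks up the deploy step removed above: it is chained to the consolidated build via `workflow_run`, so the SSH deploy runs only after a successful "Build and Push API & Web" run on `deploy/dev`, and the deploy script can now be supplied as a repository variable (`vars.SSH_SCRIPT`) with the secret as fallback.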

View File

@@ -41,6 +41,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup NodeJS
uses: actions/setup-node@v4
@@ -60,11 +62,10 @@ jobs:
yarn run lint
- name: Super-linter
uses: super-linter/super-linter/slim@v5
uses: super-linter/super-linter/slim@v6
env:
BASH_SEVERITY: warning
DEFAULT_BRANCH: main
ERROR_ON_MISSING_EXEC_BIT: true
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true
IGNORE_GITIGNORED_FILES: true

.gitignore (6 lines changed)
View File

@@ -145,10 +145,14 @@ docker/volumes/db/data/*
docker/volumes/redis/data/*
docker/volumes/weaviate/*
docker/volumes/qdrant/*
docker/volumes/etcd/*
docker/volumes/minio/*
docker/volumes/milvus/*
sdks/python-client/build
sdks/python-client/dist
sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json
!.vscode/launch.json
pyrightconfig.json

View File

@@ -155,4 +155,4 @@ And that's it! Once your PR is merged, you will be featured as a contributor in
## Getting Help
If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/AhzKf7dNgk) for a quick chat.
If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.

View File

@@ -152,4 +152,4 @@ Dify's backend is written in Python, using [Flask](https://flask.palletsprojects.co
## Getting Help
If you run into difficulties or have any questions while contributing, you can raise them through the related GitHub issue, or join our [Discord](https://discord.gg/AhzKf7dNgk) for a quick chat.
If you run into difficulties or have any questions while contributing, you can raise them through the related GitHub issue, or join our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.

LICENSE (22 lines changed)
View File

@@ -1,24 +1,26 @@
# Dify Open Source License
# Open Source License
The Dify project is licensed under the Apache License 2.0, with the following additional conditions:
Dify is licensed under the Apache License 2.0, with the following additional conditions:
1. Dify is permitted to be used for commercialization, such as using Dify as a "backend-as-a-service" for your other applications, or delivering it to enterprises as an application development platform. However, when the following conditions are met, you must contact the producer to obtain a commercial license:
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify.AI source code to operate a multi-tenant SaaS service that is similar to the Dify.AI service edition.
b. LOGO and copyright information: In the process of using Dify, you may not remove or modify the LOGO or copyright information in the Dify console.
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.
Please contact business@dify.ai by email to inquire about licensing matters.
2. As a contributor, you should agree that your contributed code:
2. As a contributor, you should agree that:
a. The producer can adjust the open-source agreement to be more strict or relaxed.
b. Can be used for commercial purposes, such as Dify's cloud business.
a. The producer can adjust the open-source agreement to be more strict or relaxed as deemed necessary.
b. Your contributed code may be used for commercial purposes, including but not limited to its cloud business operations.
Apart from this, all other rights and restrictions follow the Apache License 2.0. If you need more detailed information, you can refer to the full version of Apache License 2.0.
Apart from the specific conditions mentioned above, all other rights and restrictions follow the Apache License 2.0. Detailed information about the Apache License 2.0 can be found at http://www.apache.org/licenses/LICENSE-2.0.
The interactive design of this product is protected by appearance patent.
© 2023 LangGenius, Inc.
© 2024 LangGenius, Inc.
----------

Makefile (new file, 43 lines)
View File

@@ -0,0 +1,43 @@
# Variables
DOCKER_REGISTRY=langgenius
WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
VERSION=latest
# Build Docker images
build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
docker build -t $(WEB_IMAGE):$(VERSION) ./web
@echo "Web Docker image built successfully: $(WEB_IMAGE):$(VERSION)"
build-api:
@echo "Building API Docker image: $(API_IMAGE):$(VERSION)..."
docker build -t $(API_IMAGE):$(VERSION) ./api
@echo "API Docker image built successfully: $(API_IMAGE):$(VERSION)"
# Push Docker images
push-web:
@echo "Pushing web Docker image: $(WEB_IMAGE):$(VERSION)..."
docker push $(WEB_IMAGE):$(VERSION)
@echo "Web Docker image pushed successfully: $(WEB_IMAGE):$(VERSION)"
push-api:
@echo "Pushing API Docker image: $(API_IMAGE):$(VERSION)..."
docker push $(API_IMAGE):$(VERSION)
@echo "API Docker image pushed successfully: $(API_IMAGE):$(VERSION)"
# Build all images
build-all: build-web build-api
# Push all images
push-all: push-web push-api
build-push-api: build-api push-api
build-push-web: build-web push-web
# Build and push all images
build-push-all: build-all push-all
@echo "All Docker images have been built and pushed."
# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all
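With these targets, building and pushing both images for a release is a single command, for example `make build-push-all VERSION=0.5.11` (`VERSION` defaults to `latest`).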

View File

@@ -22,19 +22,8 @@
</p>
<p align="center">
<a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
</a>
<ul align="center" style="text-decoration: none; list-style: none;">
<li> US EST: 09:00 (9:00 AM)</li>
<li> CET: 15:00 (3:00 PM)</li>
<li> CST: 22:00 (10:00 PM)</li>
</ul>
</p>
<p align="center">
<a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
<a href="https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6" target="_blank">
📌 Check out Dify Premium on AWS and deploy it to your own AWS VPC with one-click.
</a>
</p>
@@ -48,6 +37,9 @@
You can try out [Dify.AI Cloud](https://dify.ai) now. It provides all the capabilities of the self-deployed version, and includes 200 free requests to OpenAI GPT-3.5.
### Looking to purchase via AWS?
Check out [Dify Premium on AWS](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click.
## Dify vs. LangChain vs. Assistants API
| Feature | Dify.AI | Assistants API | LangChain |
@@ -108,10 +100,12 @@ docker compose up -d
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization installation process.
### Helm Chart
#### Deploy with Helm Chart
Big thanks to @BorisPolonsky for providing us with a [Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
You can go to https://github.com/BorisPolonsky/dify-helm for deployment information.
Use the [Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
### Configuration
@@ -128,6 +122,7 @@ For those who'd like to contribute code, see our [Contribution Guide](https://gi
At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.
### Contributors
<a href="https://github.com/langgenius/dify/graphs/contributors">
@@ -136,7 +131,7 @@ At the same time, please consider supporting Dify by sharing it on social media
### Translations
We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README_EN.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/AhzKf7dNgk).
We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
## Community & Support

View File

@@ -94,10 +94,12 @@ docker compose up -d
After it is running, you can open [http://localhost/install](http://localhost/install) in your browser to enter the Dify console and start the initialization.
### Helm Chart
#### Deploy with Helm Chart
Big thanks to @BorisPolonsky for providing us with a [Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
You can visit https://github.com/BorisPolonsky/dify-helm for deployment information.
Use the [Helm Chart](https://helm.sh/) version to deploy Dify on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
### Configuration

View File

@@ -39,7 +39,7 @@ DB_DATABASE=dify
# Storage configuration
# use for store upload files, private keys...
# storage type: local, s3
# storage type: local, s3, azure-blob
STORAGE_TYPE=local
STORAGE_LOCAL_PATH=storage
S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com
@@ -47,6 +47,11 @@ S3_BUCKET_NAME=your-bucket-name
S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key
AZURE_BLOB_CONTAINER_NAME=your-container-name
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@@ -82,7 +87,7 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
MULTIMODAL_SEND_IMAGE_FORMAT=base64
# Mail configuration, support: resend, smtp
MAIL_TYPE=resend
MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
@@ -131,4 +136,4 @@ UNSTRUCTURED_API_URL=
SSRF_PROXY_HTTP_URL=
SSRF_PROXY_HTTPS_URL=
BATCH_UPLOAD_LIMIT=10
BATCH_UPLOAD_LIMIT=10
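The new `azure-blob` storage type (#2919) is driven by the four `AZURE_BLOB_*` variables above. A minimal sketch of how a backend can use them with the official `azure-storage-blob` SDK; the class and method names below are illustrative, not Dify's actual storage module:

```python
# Sketch of an Azure Blob storage backend (illustrative names, not Dify's API),
# assuming the azure-storage-blob package is installed.
from azure.storage.blob import BlobServiceClient


class AzureBlobStorage:
    def __init__(self, account_url: str, account_key: str, container: str):
        # account_url looks like https://<your_account_name>.blob.core.windows.net
        self.client = BlobServiceClient(account_url=account_url, credential=account_key)
        self.container = container

    def save(self, filename: str, data: bytes) -> None:
        blob = self.client.get_blob_client(container=self.container, blob=filename)
        blob.upload_blob(data, overwrite=True)

    def load(self, filename: str) -> bytes:
        blob = self.client.get_blob_client(container=self.container, blob=filename)
        return blob.download_blob().readall()
```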

View File

@@ -5,7 +5,7 @@
1. Start the docker-compose stack
The backend requires some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
```bash
cd ../docker
docker-compose -f docker-compose.middleware.yaml -p dify up -d
@@ -15,7 +15,7 @@
3. Generate a `SECRET_KEY` in the `.env` file.
```bash
openssl rand -base64 42
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
```
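The `sed` one-liner replaces the bare `openssl rand -base64 42`: instead of printing a key for you to paste by hand, it generates the key and rewrites the `SECRET_KEY=` line in `.env` in one step.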
3.5 If you use Anaconda, create a new environment and activate it
```bash
@@ -46,7 +46,7 @@
```
pip install -r requirements.txt --upgrade --force-reinstall
```
6. Start backend:
```bash
flask run --host 0.0.0.0 --port=5001 --debug

View File

@@ -26,6 +26,7 @@ from config import CloudEditionConfig, Config
from extensions import (
ext_celery,
ext_code_based_extension,
ext_compress,
ext_database,
ext_hosting_provider,
ext_login,
@@ -96,6 +97,7 @@ def create_app(test_config=None) -> Flask:
def initialize_extensions(app):
# Since the application instance is now created, pass it to each Flask
# extension instance to bind it to the Flask application instance (app)
ext_compress.init_app(app)
ext_code_based_extension.init()
ext_database.init_app(app)
ext_migrate.init(app, db)
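The newly registered `ext_compress` comes from the HTTP response compression feature (#2680). A minimal sketch of what such an extension module can look like, assuming the `flask-compress` package and the `API_COMPRESSION_ENABLED` flag added to the config; the real `ext_compress.py` may differ:

```python
# ext_compress.py -- sketch of an opt-in response-compression extension,
# gated by the API_COMPRESSION_ENABLED config flag (assumed wiring).
from flask import Flask


def init_app(app: Flask) -> None:
    if app.config.get('API_COMPRESSION_ENABLED'):
        from flask_compress import Compress

        # Compress only gzip-friendly API payloads.
        app.config['COMPRESS_MIMETYPES'] = ['application/json']
        Compress(app)
```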

View File

@@ -15,7 +15,7 @@ from libs.rsa import generate_key_pair
from models.account import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account
from models.model import Account, App, AppAnnotationSetting, MessageAnnotation
from models.provider import Provider, ProviderModel
@@ -109,28 +109,138 @@ def reset_encrypt_key_pair():
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
return
tenant = db.session.query(Tenant).first()
if not tenant:
click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
return
tenants = db.session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
return
tenant.encrypt_public_key = generate_key_pair(tenant.id)
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.query(Provider).filter(Provider.provider_type == 'custom').delete()
db.session.query(ProviderModel).delete()
db.session.commit()
db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit()
click.echo(click.style('Congratulations! '
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
click.echo(click.style('Congratulations! '
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
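The old and new bodies of `reset_encrypt_key_pair` are interleaved above; reconstructed, the command now iterates every workspace instead of only the first one, and scopes the provider cleanup to each tenant:

```python
# Reconstructed new logic from the diff above: per-tenant key reset
# with tenant-scoped deletion of custom provider credentials.
tenants = db.session.query(Tenant).all()
for tenant in tenants:
    if not tenant:
        click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
        return

    tenant.encrypt_public_key = generate_key_pair(tenant.id)

    db.session.query(Provider).filter(Provider.provider_type == 'custom',
                                      Provider.tenant_id == tenant.id).delete()
    db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
    db.session.commit()

    click.echo(click.style('Congratulations! the asymmetric key pair of workspace '
                           '{} has been reset.'.format(tenant.id), fg='green'))
```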
@click.command('vdb-migrate', help='migrate vector db.')
def vdb_migrate():
@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
def vdb_migrate(scope: str):
if scope in ['knowledge', 'all']:
migrate_knowledge_vector_database()
if scope in ['annotation', 'all']:
migrate_annotation_vector_database()
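The command also gains a `--scope` option: invoked through the Flask CLI, `flask vdb-migrate --scope annotation` migrates only annotation indexes, `--scope knowledge` only dataset indexes, and the default `all` runs both.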
def migrate_annotation_vector_database():
"""
Migrate annotation datas to target vector database .
"""
click.echo(click.style('Start migrate annotation data.', fg='green'))
create_count = 0
skipped_count = 0
total_count = 0
page = 1
while True:
try:
# get apps info
apps = db.session.query(App).filter(
App.status == 'normal'
).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for app in apps:
total_count = total_count + 1
click.echo(f'Processing the {total_count} app {app.id}. '
+ f'{create_count} created, {skipped_count} skipped.')
try:
click.echo('Create app annotation index: {}'.format(app.id))
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app.id
).first()
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo('App annotation setting is disabled: {}'.format(app.id))
continue
# get dataset_collection_binding info
dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
).first()
if not dataset_collection_binding:
click.echo('App annotation collection binding is not exist: {}'.format(app.id))
continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique='high_quality',
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id
)
documents = []
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
metadata={
"annotation_id": annotation.id,
"app_id": app.id,
"doc_id": annotation.id
}
)
documents.append(document)
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
try:
vector.delete()
click.echo(
click.style(f'Successfully delete vector index for app: {app.id}.',
fg='green'))
except Exception as e:
click.echo(
click.style(f'Failed to delete vector index for app {app.id}.',
fg='red'))
raise e
if documents:
try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
fg='green'))
vector.create(documents)
click.echo(
click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
except Exception as e:
click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
raise e
click.echo(f'Successfully migrated app annotation {app.id}.')
create_count += 1
except Exception as e:
click.echo(
click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(
click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
fg='green'))
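Both migration functions page through their records with the same loop shape: Flask-SQLAlchemy's `paginate` raises werkzeug's `NotFound` once the page number runs past the data, and the loop uses that as its exit condition. A stripped-down sketch of the pattern, with `Model` and `process` as placeholders:

```python
# Sketch of the pagination loop shared by both migrations. paginate()
# raises NotFound when the requested page is out of range (error_out
# defaults to True), which cleanly terminates the while loop.
from werkzeug.exceptions import NotFound

page = 1
while True:
    try:
        records = db.session.query(Model).order_by(
            Model.created_at.desc()
        ).paginate(page=page, per_page=50)  # Model is a placeholder
    except NotFound:
        break
    page += 1
    for record in records:
        process(record)  # per-record migration work goes here
```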
def migrate_knowledge_vector_database():
"""
Migrate vector database datas to target vector database .
"""
click.echo(click.style('Start migrate vector db.', fg='green'))
create_count = 0
skipped_count = 0
total_count = 0
config = current_app.config
vector_type = config.get('VECTOR_STORE')
page = 1
@@ -143,14 +253,19 @@ def vdb_migrate():
page += 1
for dataset in datasets:
total_count = total_count + 1
click.echo(f'Processing the {total_count} dataset {dataset.id}. '
+ f'{create_count} created, {skipped_count} skipped.')
try:
click.echo('Create dataset vdb index: {}'.format(dataset.id))
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] == vector_type:
skipped_count = skipped_count + 1
continue
collection_name = ''
if vector_type == "weaviate":
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'weaviate',
"vector_store": {"class_prefix": collection_name}
@@ -167,7 +282,7 @@ def vdb_migrate():
raise ValueError('Dataset Collection Bindings is not exist!')
else:
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'qdrant',
"vector_store": {"class_prefix": collection_name}
@@ -176,7 +291,7 @@ def vdb_migrate():
elif vector_type == "milvus":
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}
@@ -186,11 +301,17 @@ def vdb_migrate():
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
vector = Vector(dataset)
click.echo(f"vdb_migrate {dataset.id}")
click.echo(f"Start to migrate dataset {dataset.id}.")
try:
vector.delete()
click.echo(
click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
fg='green'))
except Exception as e:
click.echo(
click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
fg='red'))
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
@@ -201,6 +322,7 @@ def vdb_migrate():
).all()
documents = []
segments_count = 0
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
@@ -220,15 +342,22 @@ def vdb_migrate():
)
documents.append(document)
segments_count = segments_count + 1
if documents:
try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
fg='green'))
vector.create(documents)
click.echo(
click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
except Exception as e:
click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
raise e
click.echo(f"Dataset {dataset.id} create successfully.")
db.session.add(dataset)
db.session.commit()
click.echo(f'Successfully migrated dataset {dataset.id}.')
create_count += 1
except Exception as e:
db.session.rollback()
@@ -237,7 +366,9 @@ def vdb_migrate():
fg='red'))
continue
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
click.echo(
click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
fg='green'))
def register_commands(app):

View File

@@ -22,6 +22,7 @@ DEFAULTS = {
'SERVICE_API_URL': 'https://api.dify.ai',
'APP_WEB_URL': 'https://udify.app',
'FILES_URL': '',
'S3_ADDRESS_STYLE': 'auto',
'STORAGE_TYPE': 'local',
'STORAGE_LOCAL_PATH': 'storage',
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
@@ -59,7 +60,8 @@ DEFAULTS = {
'CAN_REPLACE_LOGO': 'False',
'ETL_TYPE': 'dify',
'KEYWORD_STORE': 'jieba',
'BATCH_UPLOAD_LIMIT': 20
'BATCH_UPLOAD_LIMIT': 20,
'TOOL_ICON_CACHE_MAX_AGE': 3600,
}
@@ -90,7 +92,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.5.7"
self.CURRENT_VERSION = "0.5.11"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -180,6 +182,11 @@ class Config:
self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
self.S3_REGION = get_env('S3_REGION')
self.S3_ADDRESS_STYLE = get_env('S3_ADDRESS_STYLE')
self.AZURE_BLOB_ACCOUNT_NAME = get_env('AZURE_BLOB_ACCOUNT_NAME')
self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY')
self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME')
self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL')
# ------------------------
# Vector Store Configurations.
@@ -293,6 +300,9 @@ class Config:
self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
class CloudEditionConfig(Config):
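The new `S3_ADDRESS_STYLE` setting (#2934, default `auto`) corresponds to boto3's S3 addressing-style option. A sketch of how a storage client can apply it; the wiring below is hypothetical, not Dify's actual storage code:

```python
# Sketch: feed S3_ADDRESS_STYLE ('auto', 'path' or 'virtual') into boto3.
import boto3
from botocore.client import Config as BotoConfig


def make_s3_client(app_config) -> object:
    return boto3.client(
        's3',
        aws_access_key_id=app_config['S3_ACCESS_KEY'],
        aws_secret_access_key=app_config['S3_SECRET_KEY'],
        endpoint_url=app_config['S3_ENDPOINT'],
        region_name=app_config['S3_REGION'],
        # addressing_style is a standard botocore S3 config option.
        config=BotoConfig(s3={'addressing_style': app_config['S3_ADDRESS_STYLE']}),
    )
```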

View File

@@ -2,7 +2,7 @@ import json
from models.model import AppModelConfig
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA']
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN']
language_timezone_mapping = {
'en-US': 'America/New_York',
@@ -16,6 +16,7 @@ language_timezone_mapping = {
'ru-RU': 'Europe/Moscow',
'it-IT': 'Europe/Rome',
'uk-UA': 'Europe/Kyiv',
'vi-VN': 'Asia/Ho_Chi_Minh',
}
@@ -79,6 +80,16 @@ user_input_form_template = {
}
}
],
"vi-VN": [
{
"paragraph": {
"label": "Nội dung truy vấn",
"variable": "default_input",
"required": False,
"default": ""
}
}
],
}
demo_model_templates = {
@@ -208,7 +219,6 @@ demo_model_templates = {
)
}
],
'zh-Hans': [
{
'name': '翻译助手',
@@ -335,91 +345,92 @@ demo_model_templates = {
)
}
],
'uk-UA': [{
"name": "Помічник перекладу",
"icon": "",
"icon_background": "",
"description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.",
"mode": "completion",
"model_config": AppModelConfig(
provider="openai",
model_id="gpt-3.5-turbo-instruct",
configs={
"prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n",
"prompt_variables": [
{
"key": "target_language",
"name": "Цільова мова",
"description": "Мова, на яку ви хочете перекласти.",
"type": "select",
"default": "Ukrainian",
"options": [
"Chinese",
"English",
"Japanese",
"French",
"Russian",
"German",
"Spanish",
"Korean",
"Italian",
],
'uk-UA': [
{
"name": "Помічник перекладу",
"icon": "",
"icon_background": "",
"description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.",
"mode": "completion",
"model_config": AppModelConfig(
provider="openai",
model_id="gpt-3.5-turbo-instruct",
configs={
"prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n",
"prompt_variables": [
{
"key": "target_language",
"name": "Цільова мова",
"description": "Мова, на яку ви хочете перекласти.",
"type": "select",
"default": "Ukrainian",
"options": [
"Chinese",
"English",
"Japanese",
"French",
"Russian",
"German",
"Spanish",
"Korean",
"Italian",
],
},
],
"completion_params": {
"max_token": 1000,
"temperature": 0,
"top_p": 0,
"presence_penalty": 0.1,
"frequency_penalty": 0.1,
},
],
"completion_params": {
"max_token": 1000,
"temperature": 0,
"top_p": 0,
"presence_penalty": 0.1,
"frequency_penalty": 0.1,
},
},
opening_statement="",
suggested_questions=None,
pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
"top_p": 0,
"presence_penalty": 0.1,
"frequency_penalty": 0.1,
},
}),
user_input_form=json.dumps([
{
"select": {
"label": "Цільова мова",
"variable": "target_language",
"description": "Мова, на яку ви хочете перекласти.",
"default": "Chinese",
"required": True,
'options': [
'Chinese',
'English',
'Japanese',
'French',
'Russian',
'German',
'Spanish',
'Korean',
'Italian',
]
opening_statement="",
suggested_questions=None,
pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
"top_p": 0,
"presence_penalty": 0.1,
"frequency_penalty": 0.1,
},
}),
user_input_form=json.dumps([
{
"select": {
"label": "Цільова мова",
"variable": "target_language",
"description": "Мова, на яку ви хочете перекласти.",
"default": "Chinese",
"required": True,
'options': [
'Chinese',
'English',
'Japanese',
'French',
'Russian',
'German',
'Spanish',
'Korean',
'Italian',
]
}
}, {
"paragraph": {
"label": "Запит",
"variable": "query",
"required": True,
"default": ""
}
}
}, {
"paragraph": {
"label": "Запит",
"variable": "query",
"required": True,
"default": ""
}
}
])
)
},
])
)
},
{
"name": "AI інтерв’юер фронтенду",
"icon": "",
@@ -460,5 +471,132 @@ demo_model_templates = {
),
}
],
'vi-VN': [
{
'name': 'Trợ lý dịch thuật',
'icon': '',
'icon_background': '',
'description': 'Trình dịch đa ngôn ngữ cung cấp khả năng dịch bằng nhiều ngôn ngữ, dịch thông tin đầu vào của người dùng sang ngôn ngữ họ cần.',
'mode': 'completion',
'model_config': AppModelConfig(
provider='openai',
model_id='gpt-3.5-turbo-instruct',
configs={
'prompt_template': "Hãy dịch đoạn văn bản sau sang ngôn ngữ {{target_language}}:\n",
'prompt_variables': [
{
"key": "target_language",
"name": "Ngôn ngữ đích",
"description": "Ngôn ngữ bạn muốn dịch sang.",
"type": "select",
"default": "Vietnamese",
'options': [
'Chinese',
'English',
'Japanese',
'French',
'Russian',
'German',
'Spanish',
'Korean',
'Italian',
'Vietnamese',
]
}
],
'completion_params': {
'max_token': 1000,
'temperature': 0,
'top_p': 0,
'presence_penalty': 0.1,
'frequency_penalty': 0.1,
}
},
opening_statement='',
suggested_questions=None,
pre_prompt="Hãy dịch đoạn văn bản sau sang {{target_language}}:\n{{query}}\ndịch:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
"top_p": 0,
"presence_penalty": 0.1,
"frequency_penalty": 0.1
}
}),
user_input_form=json.dumps([
{
"select": {
"label": "Ngôn ngữ đích",
"variable": "target_language",
"description": "Ngôn ngữ bạn muốn dịch sang.",
"default": "Vietnamese",
"required": True,
'options': [
'Chinese',
'English',
'Japanese',
'French',
'Russian',
'German',
'Spanish',
'Korean',
'Italian',
'Vietnamese',
]
}
}, {
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": ""
}
}
])
)
},
{
'name': 'Phỏng vấn front-end AI',
'icon': '',
'icon_background': '',
'description': 'Một người phỏng vấn front-end mô phỏng để kiểm tra mức độ kỹ năng phát triển front-end thông qua việc đặt câu hỏi.',
'mode': 'chat',
'model_config': AppModelConfig(
provider='openai',
model_id='gpt-3.5-turbo',
configs={
'introduction': 'Xin chào, chào mừng đến với cuộc phỏng vấn của chúng tôi. Tôi là người phỏng vấn cho công ty công nghệ này và tôi sẽ kiểm tra kỹ năng phát triển web front-end của bạn. Tiếp theo, tôi sẽ hỏi bạn một số câu hỏi kỹ thuật. Hãy trả lời chúng càng kỹ lưỡng càng tốt. ',
'prompt_template': "Bạn sẽ đóng vai người phỏng vấn cho một công ty công nghệ, kiểm tra kỹ năng phát triển web front-end của người dùng và đặt ra 5-10 câu hỏi kỹ thuật sắc bén.\n\nXin lưu ý:\n- Mỗi lần chỉ hỏi một câu hỏi.\n - Sau khi người dùng trả lời một câu hỏi, hãy hỏi trực tiếp câu hỏi tiếp theo mà không cố gắng sửa bất kỳ lỗi nào mà thí sinh mắc phải.\n- Nếu bạn cho rằng người dùng đã không trả lời đúng cho một số câu hỏi liên tiếp, hãy hỏi ít câu hỏi hơn.\n- Sau đặt câu hỏi cuối cùng, bạn có thể hỏi câu hỏi này: Tại sao bạn lại rời bỏ công việc cuối cùng của mình? Sau khi người dùng trả lời câu hỏi này, vui lòng bày tỏ sự hiểu biết và ủng hộ của bạn.\n",
'prompt_variables': [],
'completion_params': {
'max_token': 300,
'temperature': 0.8,
'top_p': 0.9,
'presence_penalty': 0.1,
'frequency_penalty': 0.1,
}
},
opening_statement='Xin chào, chào mừng đến với cuộc phỏng vấn của chúng tôi. Tôi là người phỏng vấn cho công ty công nghệ này và tôi sẽ kiểm tra kỹ năng phát triển web front-end của bạn. Tiếp theo, tôi sẽ hỏi bạn một số câu hỏi kỹ thuật. Hãy trả lời chúng càng kỹ lưỡng càng tốt. ',
suggested_questions=None,
pre_prompt="Bạn sẽ đóng vai người phỏng vấn cho một công ty công nghệ, kiểm tra kỹ năng phát triển web front-end của người dùng và đặt ra 5-10 câu hỏi kỹ thuật sắc bén.\n\nXin lưu ý:\n- Mỗi lần chỉ hỏi một câu hỏi.\n - Sau khi người dùng trả lời một câu hỏi, hãy hỏi trực tiếp câu hỏi tiếp theo mà không cố gắng sửa bất kỳ lỗi nào mà thí sinh mắc phải.\n- Nếu bạn cho rằng người dùng đã không trả lời đúng cho một số câu hỏi liên tiếp, hãy hỏi ít câu hỏi hơn.\n- Sau đặt câu hỏi cuối cùng, bạn có thể hỏi câu hỏi này: Tại sao bạn lại rời bỏ công việc cuối cùng của mình? Sau khi người dùng trả lời câu hỏi này, vui lòng bày tỏ sự hiểu biết và ủng hộ của bạn.\n",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,
"top_p": 0.9,
"presence_penalty": 0.1,
"frequency_penalty": 0.1
}
}),
user_input_form=None
)
}
],
}
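
The templates above embed {{variable}} placeholders in pre_prompt that are filled from the fields declared in user_input_form. A minimal sketch of that substitution, assuming a simple regex-based renderer rather than Dify's own template parser:

import re

def render_pre_prompt(pre_prompt: str, inputs: dict) -> str:
    # replace each {{variable}} placeholder with the matching user input
    return re.sub(r"\{\{(\w+)\}\}", lambda m: str(inputs.get(m.group(1), "")), pre_prompt)

rendered = render_pre_prompt(
    "Hãy dịch đoạn văn bản sau sang {{target_language}}:\n{{query}}\ndịch:",
    {"target_language": "English", "query": "Xin chào"},
)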

View File

@@ -13,30 +13,14 @@ model_templates = {
'status': 'normal'
},
'model_config': {
'provider': 'openai',
'model_id': 'gpt-3.5-turbo-instruct',
'configs': {
'prompt_template': '',
'prompt_variables': [],
'completion_params': {
'max_token': 512,
'temperature': 1,
'top_p': 1,
'presence_penalty': 0,
'frequency_penalty': 0,
}
},
'provider': '',
'model_id': '',
'configs': {},
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
"top_p": 1,
"presence_penalty": 0,
"frequency_penalty": 0
}
"completion_params": {}
}),
'user_input_form': json.dumps([
{
@@ -64,30 +48,14 @@ model_templates = {
'status': 'normal'
},
'model_config': {
'provider': 'openai',
'model_id': 'gpt-3.5-turbo',
'configs': {
'prompt_template': '',
'prompt_variables': [],
'completion_params': {
'max_token': 512,
'temperature': 1,
'top_p': 1,
'presence_penalty': 0,
'frequency_penalty': 0,
}
},
'provider': '',
'model_id': '',
'configs': {},
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
"top_p": 1,
"presence_penalty": 0,
"frequency_penalty": 0
}
"completion_params": {}
})
}
},

View File

@@ -27,7 +27,9 @@ from fields.app_fields import (
from libs.login import login_required
from models.model import App, AppModelConfig, Site
from services.app_model_config_service import AppModelConfigService
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.tool_manager import ToolManager
from core.entities.application_entities import AgentToolEntity
def _get_app(app_id, tenant_id):
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
@@ -129,7 +131,7 @@ class AppListApi(Resource):
"No Default System Reasoning Model available. Please configure "
"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = default_model_entity.provider
model_config_dict["model"]["provider"] = default_model_entity.provider.provider
model_config_dict["model"]["name"] = default_model_entity.model
model_configuration = AppModelConfigService.validate_configuration(
@@ -236,7 +238,44 @@ class AppApi(Resource):
def get(self, app_id):
"""Get app detail"""
app_id = str(app_id)
app = _get_app(app_id, current_user.current_tenant_id)
app: App = _get_app(app_id, current_user.current_tenant_id)
# get original app model config
model_config: AppModelConfig = app.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter
except Exception:
    # tolerate tools that fail to load; leave their parameters untouched
    pass
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
return app
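
The hunk above decrypts each agent tool's secret-input parameters and then masks them before the app detail is returned, so the console can show that a credential is set without exposing it. A minimal sketch of a prefix-preserving mask, assuming the real ToolParameterConfigurationManager behaves along these lines:

def mask_secret(value: str, visible: int = 2) -> str:
    # keep a short prefix so the owner can recognise the credential
    if len(value) <= visible:
        return '*' * len(value)
    return value[:visible] + '*' * (len(value) - visible)

# e.g. mask_secret('sk-abcdef123456') -> 'sk************'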

View File

@@ -88,7 +88,7 @@ class ChatMessageTextApi(Resource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=request.form['text'],
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
voice=request.form.get('voice') or app_model.app_model_config.text_to_speech_dict.get('voice'),  # .get() avoids a BadRequestKeyError when no voice is supplied
streaming=False
)

View File

@@ -1,3 +1,4 @@
import json
from flask import request
from flask_login import current_user
@@ -7,6 +8,9 @@ from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.entities.application_entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.login import login_required
@@ -38,6 +42,91 @@ class ModelConfigResource(Resource):
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
# get original app model config
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
AppModelConfig.id == app.app_model_config_id
).first()
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get('tools') or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
except Exception as e:
continue
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
parameters = {}
masked_parameter = {}
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map:
tool_runtime = tool_map[key]
else:
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
except Exception as e:
continue
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
manager.delete_tool_parameters_cache()
# override parameters if it equals to masked parameters
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]
# encrypt parameters
if agent_tool_entity.tool_parameters:
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)
db.session.add(new_app_model_config)
db.session.flush()
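
The update path above has to cope with clients that echo back the masked values unchanged: if the submitted parameters equal the masked map, the previously stored originals are restored before re-encryption. A small sketch of that comparison, under the assumption that masking is deterministic:

def resolve_submitted_parameters(submitted: dict, masked: dict, original: dict) -> dict:
    # the client sent back exactly the masked values, so it did not edit the
    # secret; keep the stored original instead of encrypting asterisks
    if submitted == masked:
        return original
    return submitted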

View File

@@ -11,7 +11,7 @@ from controllers.console.datasets.error import (
UnsupportedFileTypeError,
)
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from services.file_service import ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS, FileService
@@ -39,6 +39,7 @@ class FileApi(Resource):
@login_required
@account_initialization_required
@marshal_with(file_fields)
@cloud_edition_billing_resource_check(resource='documents')
def post(self):
# get file from request

View File

@@ -85,7 +85,7 @@ class ChatTextApi(InstalledAppResource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=request.form['text'],
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
voice=request.form.get('voice') or app_model.app_model_config.text_to_speech_dict.get('voice'),  # .get() avoids a BadRequestKeyError when no voice is supplied
streaming=False
)
return {'data': response.data.decode('latin1')}

View File

@@ -1,6 +1,6 @@
import io
from flask import send_file
from flask import current_app, send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
@@ -80,8 +80,33 @@ class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider):
icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=minetype)
icon_cache_max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE'))
return send_file(io.BytesIO(icon_bytes), mimetype=minetype, max_age=icon_cache_max_age)
class ToolModelProviderIconApi(Resource):
@setup_required
def get(self, provider):
icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)
class ToolModelProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
return ToolManageService.list_model_tool_provider_tools(
user_id,
tenant_id,
args['provider'],
)
class ToolApiProviderAddApi(Resource):
@setup_required
@@ -259,6 +284,7 @@ class ToolApiProviderPreviousTestApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
@@ -268,6 +294,7 @@ class ToolApiProviderPreviousTestApi(Resource):
return ToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
args['provider_name'] if args['provider_name'] else '',
args['tool_name'],
args['credentials'],
args['parameters'],
@@ -281,6 +308,8 @@ api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provide
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools')
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
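
The icon endpoint above now attaches an HTTP cache lifetime so browsers stop re-fetching static provider icons. A minimal sketch of the pattern, with an assumed default for TOOL_ICON_CACHE_MAX_AGE:

import io
from flask import Flask, current_app, send_file

app = Flask(__name__)
app.config['TOOL_ICON_CACHE_MAX_AGE'] = 3600  # assumed default: one hour

@app.route('/icon')
def icon():
    icon_bytes = b'...'  # loaded from the provider in the real service
    max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE'))
    # send_file emits Cache-Control: max-age=<n> when max_age is given
    return send_file(io.BytesIO(icon_bytes), mimetype='image/png', max_age=max_age)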

View File

@@ -56,6 +56,7 @@ def cloud_edition_billing_resource_check(resource: str,
members = features.members
apps = features.apps
vector_space = features.vector_space
documents_upload_quota = features.documents_upload_quota
annotation_quota_limit = features.annotation_quota_limit
if resource == 'members' and 0 < members.limit <= members.size:
@@ -64,6 +65,13 @@ def cloud_edition_billing_resource_check(resource: str,
abort(403, error_msg)
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
abort(403, error_msg)
elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The file upload API is used in multiple places, so only enforce the documents quota when the request comes from datasets
source = request.args.get('source')
if source == 'datasets':
abort(403, error_msg)
else:
return view(*args, **kwargs)
elif resource == 'workspace_custom' and not features.can_replace_logo:
abort(403, error_msg)
elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
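
Because the upload endpoint is shared by several features, the documents quota is only enforced when the request declares it comes from datasets. A simplified sketch of such a decorator; get_quota is a hypothetical stand-in for the real feature lookup:

from functools import wraps
from flask import abort, request

def get_quota(resource: str) -> tuple[int, int]:
    return 50, 50  # (limit, current size) — assumed values for illustration

def billing_resource_check(resource: str):
    def decorator(view):
        @wraps(view)
        def wrapper(*args, **kwargs):
            limit, size = get_quota(resource)
            if resource == 'documents' and 0 < limit <= size:
                # only block uploads that originate from datasets
                if request.args.get('source') == 'datasets':
                    abort(403, 'documents upload quota exceeded')
            return view(*args, **kwargs)
        return wrapper
    return decorator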

View File

@@ -44,7 +44,7 @@ class AudioApi(Resource):
response = AudioService.transcript_asr(
tenant_id=app_model.tenant_id,
file=file,
end_user=end_user
end_user=end_user.get_id()
)
return response
@@ -75,7 +75,7 @@ class AudioApi(Resource):
class TextApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
@@ -86,7 +86,7 @@ class TextApi(Resource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=args['text'],
end_user=end_user,
end_user=end_user.get_id(),
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=args['streaming']
)

View File

@@ -28,6 +28,7 @@ class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents."""
@cloud_edition_billing_resource_check('vector_space', 'dataset')
@cloud_edition_billing_resource_check('documents', 'dataset')
def post(self, tenant_id, dataset_id):
"""Create document by text."""
parser = reqparse.RequestParser()
@@ -153,6 +154,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
class DocumentAddByFileApi(DatasetApiResource):
"""Resource for documents."""
@cloud_edition_billing_resource_check('vector_space', 'dataset')
@cloud_edition_billing_resource_check('documents', 'dataset')
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
args = {}

View File

@@ -197,11 +197,11 @@ class DatasetSegmentApi(DatasetApiResource):
# validate args
parser = reqparse.RequestParser()
parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
parser.add_argument('segment', type=dict, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args['segments'], document)
segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
SegmentService.segment_create_args_validate(args['segment'], document)
segment = SegmentService.update_segment(args['segment'], segment, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form

View File

@@ -89,6 +89,7 @@ def cloud_edition_billing_resource_check(resource: str,
members = features.members
apps = features.apps
vector_space = features.vector_space
documents_upload_quota = features.documents_upload_quota
if resource == 'members' and 0 < members.limit <= members.size:
raise Unauthorized(error_msg)
@@ -96,6 +97,8 @@ def cloud_edition_billing_resource_check(resource: str,
raise Unauthorized(error_msg)
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
raise Unauthorized(error_msg)
elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
raise Unauthorized(error_msg)
else:
return view(*args, **kwargs)

View File

@@ -84,7 +84,7 @@ class TextApi(WebApiResource):
tenant_id=app_model.tenant_id,
text=request.form['text'],
end_user=end_user.external_user_id,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
voice=request.form.get('voice') or app_model.app_model_config.text_to_speech_dict.get('voice'),  # .get() avoids a BadRequestKeyError when no voice is supplied
streaming=False
)

View File

@@ -1,49 +0,0 @@
from typing import cast
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class CalcTokenMixin:
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
"""
Get the remaining tokens available to the model after excluding the message tokens and the completion max tokens
:param model_config:
:param messages:
:return:
"""
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return 0
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
return rest_tokens
class ExceededLLMTokensLimitError(Exception):
pass
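
The deleted mixin's core computation is worth keeping in mind: the remaining context budget is the model's context size minus the reserved completion tokens minus the prompt tokens. A worked example:

def rest_tokens(context_size: int, max_tokens: int, prompt_tokens: int) -> int:
    # budget left once the completion reservation and the prompt are subtracted
    return context_size - max_tokens - prompt_tokens

# a 4096-token context with max_tokens=1000 and a 2900-token prompt
assert rest_tokens(4096, 1000, 2900) == 196  # positive, so no summarization needed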

View File

@@ -1,361 +0,0 @@
from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import (
AgentAction,
AgentFinish,
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
get_buffer_string,
)
from langchain.tools import BaseTool
from pydantic import root_validator
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_config: ModelConfigEntity = None
model_config: ModelConfigEntity
agent_llm_callback: Optional[AgentLLMCallback] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_config=model_config,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
**kwargs,
)
def should_use_agent(self, query: str):
"""
Return whether an agent should be used for this query.
:param query:
:return:
"""
original_max_tokens = 0
for parameter_rule in self.model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
or self.model_config.parameters.get(parameter_rule.use_template)) or 0
self.model_config.parameters['max_tokens'] = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
except Exception as e:
raise e
self.model_config.parameters['max_tokens'] = original_max_tokens
return bool(result.message.tool_calls)
def plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
# summarize messages if rest_tokens < 0
try:
prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
ai_message = AIMessage(
content=result.message.content or "",
additional_kwargs={
'function_call': {
'id': result.message.tool_calls[0].id,
**result.message.tool_calls[0].function.dict()
} if result.message.tool_calls else None
}
)
agent_decision = _parse_ai_message(ai_message)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
@classmethod
def get_system_message(cls):
return SystemMessage(content="You are a helpful AI assistant.\n"
"The current date or current time you know is wrong.\n"
"Respond directly if appropriate.")
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
try:
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(
self.model_config,
messages,
**kwargs
)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def predict_new_summary(
self, messages: list[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model_config.provider == 'azure_openai':
model = model_config.model
model = model.replace("gpt-35", "gpt-3.5")
else:
model = model_config.credentials.get("base_model_name")
tiktoken_ = _import_tiktoken()
try:
encoding = tiktoken_.encoding_for_model(model)
except KeyError:
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens
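
The token accounting above follows the OpenAI cookbook recipe: a fixed per-message overhead plus the encoded length of every field, with extra bookkeeping for function definitions. A stripped-down sketch that omits the name and function_call special cases handled above:

import tiktoken

def num_tokens_for_messages(messages: list[dict], model: str = "gpt-3.5-turbo") -> int:
    encoding = tiktoken.encoding_for_model(model)
    tokens_per_message = 4 if model.startswith("gpt-3.5-turbo") else 3
    total = 0
    for message in messages:
        total += tokens_per_message
        for value in message.values():
            total += len(encoding.encode(value))
    return total + 3  # every reply is primed with <|im_start|>assistant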

View File

@@ -1,306 +0,0 @@
import re
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import (
AgentAction,
AgentFinish,
AIMessage,
BaseMessage,
HumanMessage,
OutputParserException,
get_buffer_string,
)
from langchain.tools import BaseTool
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_config: ModelConfigEntity = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
Return whether an agent should be used.
Using ReAct just to decide whether an agent is needed is itself costly,
so it is cheaper to hand the query straight to the agent for reasoning.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
messages = []
if prompts:
messages = prompts[0].to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
raise e
try:
agent_decision = self.output_parser.parse(full_output)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_model_config:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
self.moving_summary_index = len(intermediate_steps)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
if self.moving_summary_buffer and 'chat_history' in kwargs:
kwargs["chat_history"].pop()
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
if 'chat_history' in kwargs:
kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
def predict_new_summary(
self, messages: list[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def create_completion_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A PromptTemplate with the template assembled from the pieces here.
"""
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
agent_scratchpad += action.log
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
if not isinstance(agent_scratchpad, str):
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_config.mode == "chat":
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
else:
return agent_scratchpad
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_config.mode == "chat":
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
else:
prompt = cls.create_completion_prompt(
tools,
prefix=prefix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
model_config=model_config,
prompt=prompt,
callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
**kwargs,
)
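
The FORMAT_INSTRUCTIONS above ask the model to answer with a single fenced $JSON_BLOB carrying action and action_input. A simplified parser for that shape (the deleted agent relied on LangChain's structured-chat output parser instead):

import json
import re

def parse_action_blob(text: str) -> tuple[str, object]:
    # pull the first ```-fenced JSON blob out of the model output
    match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
    if not match:
        raise ValueError("no JSON action blob found")
    blob = json.loads(match.group(1))
    return blob["action"], blob["action_input"]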

View File

@@ -84,7 +84,7 @@ class AppRunner:
return rest_tokens
def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
def recalc_llm_max_tokens(self, model_config: ModelConfigEntity,
prompt_messages: list[PromptMessage]):
# recalculate max_tokens if prompt_tokens + max_tokens exceeds the model's token limit
model_type_instance = model_config.provider_model_bundle.model_type_instance

View File

@@ -1,4 +1,3 @@
import json
import logging
from typing import cast
@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
message_chain = self._init_message_chain(
message=message,
query=query
)
# init model instance
model_instance = ModelInstance(
@@ -201,6 +195,10 @@ class AssistantApplicationRunner(AppRunner):
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
db.session.refresh(conversation)
db.session.refresh(message)
db.session.close()
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = AssistantCotApplicationRunner(
@@ -290,38 +288,6 @@ class AssistantApplicationRunner(AppRunner):
'pool': db_variables.variables
})
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
"""
Init MessageChain
:param message: message
:param query: query
:return:
"""
message_chain = MessageChain(
message_id=message.id,
type="AgentExecutor",
input=json.dumps({
"input": query
})
)
db.session.add(message_chain)
db.session.commit()
return message_chain
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
"""
Save MessageChain
:param message_chain: message chain
:param output_text: output text
:return:
"""
message_chain.output = json.dumps({
"output": output_text
})
db.session.commit()
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
message: Message) -> LLMUsage:
"""

View File

@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
@@ -181,7 +181,7 @@ class BasicApplicationRunner(AppRunner):
return
# Re-calculate max_tokens if prompt_tokens + max_tokens exceeds the model's token limit
self.recale_llm_max_tokens(
self.recalc_llm_max_tokens(
model_config=app_orchestration_config.model_config,
prompt_messages=prompt_messages
)
@@ -192,6 +192,8 @@ class BasicApplicationRunner(AppRunner):
model=app_orchestration_config.model_config.model
)
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,

View File

@@ -89,6 +89,10 @@ class GenerateTaskPipeline:
Process generate task pipeline.
:return:
"""
db.session.refresh(self._conversation)
db.session.refresh(self._message)
db.session.close()
if stream:
return self._process_stream_response()
else:
@@ -303,6 +307,7 @@ class GenerateTaskPipeline:
.first()
)
db.session.refresh(agent_thought)
db.session.close()
if agent_thought:
response = {
@@ -330,6 +335,8 @@ class GenerateTaskPipeline:
.filter(MessageFile.id == event.message_file_id)
.first()
)
db.session.close()
# get extension
if '.' in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}'
@@ -413,6 +420,7 @@ class GenerateTaskPipeline:
usage = llm_result.usage
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
self._message.message_tokens = usage.prompt_tokens
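
Several hunks in this file repeat the same session hygiene: refresh the ORM objects to pick up current state, then close the session so no connection is held while the LLM response streams. A sketch of that pattern factored into a helper, reusing the repo's shared session:

from extensions.ext_database import db

def detach_before_streaming(*objects) -> None:
    # re-read each object's current state from the database, then return the
    # connection to the pool before the long-running streaming begins
    for obj in objects:
        db.session.refresh(obj)
    db.session.close()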

View File

@@ -35,7 +35,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.file.file_obj import FileObj
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from core.provider_manager import ProviderManager
@@ -195,13 +195,11 @@ class ApplicationManager:
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
@@ -233,8 +231,6 @@ class ApplicationManager:
else:
logger.exception(e)
raise e
finally:
db.session.remove()
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
-> AppOrchestrationConfigEntity:
@@ -651,6 +647,7 @@ class ApplicationManager:
db.session.add(conversation)
db.session.commit()
db.session.refresh(conversation)
else:
conversation = (
db.session.query(Conversation)
@@ -689,6 +686,7 @@ class ApplicationManager:
db.session.add(message)
db.session.commit()
db.session.refresh(message)
for file in application_generate_entity.files:
message_file = MessageFile(

View File

@@ -0,0 +1,8 @@
from enum import Enum
class PlanningStrategy(Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
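
Elsewhere in this changeset the strategy is picked from the model's advertised features. A hedged sketch of that selection, assuming PlanningStrategy from the new file above is in scope; the feature flag strings are assumptions for illustration:

def choose_strategy(features: set[str]) -> PlanningStrategy:
    # prefer native function calling when the model schema advertises it,
    # otherwise fall back to ReAct-style planning
    if features & {'tool-call', 'multi-tool-call'}:
        return PlanningStrategy.FUNCTION_CALL
    return PlanningStrategy.REACT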

View File

@@ -1,13 +1,14 @@
import enum
import importlib.util
import importlib
import json
import logging
import os
from collections import OrderedDict
from typing import Any, Optional
from pydantic import BaseModel
from core.utils.position_helper import sort_to_dict_by_position_map
class ExtensionModule(enum.Enum):
MODERATION = 'moderation'
@@ -36,7 +37,8 @@ class Extensible:
@classmethod
def scan_extensions(cls):
extensions = {}
extensions: list[ModuleExtension] = []
position_map = {}
# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
@@ -63,6 +65,7 @@ class Extensible:
if os.path.exists(builtin_file_path):
with open(builtin_file_path, encoding='utf-8') as f:
position = int(f.read().strip())
position_map[extension_name] = position
if (extension_name + '.py') not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
@@ -96,16 +99,15 @@ class Extensible:
with open(json_path, encoding='utf-8') as f:
json_data = json.load(f)
extensions[extension_name] = ModuleExtension(
extensions.append(ModuleExtension(
extension_class=extension_class,
name=extension_name,
label=json_data.get('label'),
form_schema=json_data.get('form_schema'),
builtin=builtin,
position=position
)
))
sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
sorted_extensions = OrderedDict(sorted_items)
sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)
return sorted_extensions
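
sort_to_dict_by_position_map replaces the ad-hoc sorting of the old code. A hypothetical equivalent, assuming items without a recorded position sort last and the result maps each extension's name to the extension:

from collections import OrderedDict
from collections.abc import Callable

def sort_to_dict_by_position(position_map: dict, items: list, name_of: Callable) -> OrderedDict:
    ordered = sorted(
        items,
        key=lambda x: (position_map.get(name_of(x)) is None, position_map.get(name_of(x), 0)),
    )
    return OrderedDict((name_of(x), x) for x in ordered)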

View File

@@ -1,199 +0,0 @@
import logging
from typing import Optional, cast
from langchain.tools import BaseTool
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (
AgentEntity,
AppOrchestrationConfigEntity,
InvokeFrom,
ModelConfigEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message
logger = logging.getLogger(__name__)
class AgentRunnerFeature:
def __init__(self, tenant_id: str,
app_orchestration_config: AppOrchestrationConfigEntity,
model_config: ModelConfigEntity,
config: AgentEntity,
queue_manager: ApplicationQueueManager,
message: Message,
user_id: str,
agent_llm_callback: AgentLLMCallback,
callback: AgentLoopGatherCallbackHandler,
memory: Optional[TokenBufferMemory] = None,) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param app_orchestration_config: app orchestration config
:param model_config: model config
:param config: agent config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
"""
self.tenant_id = tenant_id
self.app_orchestration_config = app_orchestration_config
self.model_config = model_config
self.config = config
self.queue_manager = queue_manager
self.message = message
self.user_id = user_id
self.agent_llm_callback = agent_llm_callback
self.callback = callback
self.memory = memory
def run(self, query: str,
invoke_from: InvokeFrom) -> Optional[str]:
"""
Retrieve agent loop result.
:param query: query
:param invoke_from: invoke from
:return:
"""
provider = self.config.provider
model = self.config.model
tool_configs = self.config.tools
# check model is support tool calling
provider_instance = model_provider_factory.get_provider_instance(provider=provider)
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model,
credentials=self.model_config.credentials
)
if not model_schema:
return None
planning_strategy = PlanningStrategy.REACT
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.FUNCTION_CALL
tools = self.to_tools(
tool_configs=tool_configs,
invoke_from=invoke_from,
callbacks=[self.callback, DifyStdOutCallbackHandler()],
)
if len(tools) == 0:
return None
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
model_config=self.model_config,
tools=tools,
memory=self.memory,
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate",
agent_llm_callback=self.agent_llm_callback,
callbacks=[self.callback, DifyStdOutCallbackHandler()]
)
agent_executor = AgentExecutor(agent_configuration)
try:
# check if should use agent
should_use_agent = agent_executor.should_use_agent(query)
if not should_use_agent:
return None
result = agent_executor.run(query)
return result.output
except Exception as ex:
logger.exception("agent_executor run failed")
return None
def to_dataset_retriever_tool(self, tool_config: dict,
invoke_from: InvokeFrom) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tool_config: tool config
:param invoke_from: invoke from
"""
show_retrieve_source = self.app_orchestration_config.show_retrieve_source
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=self.queue_manager,
app_id=self.message.app_id,
message_id=self.message.id,
user_id=self.user_id,
invoke_from=invoke_from
)
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()
# pass if dataset is not available
if not dataset:
return None
# pass if the dataset has no available documents
if dataset and dataset.available_document_count == 0:
return None
# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=show_retrieve_source,
retriever_from=invoke_from.to_source()
)
return tool

View File

@@ -114,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner):
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
MessageAgentThought.message_id == self.message.id,
).count()
db.session.close()
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
@@ -144,7 +145,7 @@ class BaseAssistantApplicationRunner(AppRunner):
result += f"result link: {response.message}. please tell user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result += "image has been created and sent to user already, you should tell user to check it now."
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
else:
result += f"tool response: {response.message}."
@@ -154,9 +155,9 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
convert tool to prompt message tool
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
tenant_id=self.application_generate_entity.tenant_id,
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
agent_tool=tool,
agent_callback=self.agent_callback
)
tool_entity.load_variables(self.variables_pool)
@@ -171,33 +172,11 @@ class BaseAssistantApplicationRunner(AppRunner):
}
)
runtime_parameters = {}
parameters = tool_entity.parameters or []
user_parameters = tool_entity.get_runtime_parameters() or []
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
found = False
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
found = True
break
if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else:
# add new parameter
parameters.append(parameter)
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = 'string'
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
@@ -213,59 +192,16 @@ class BaseAssistantApplicationRunner(AppRunner):
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# get tool parameter from form
tool_parameter_config = tool.tool_parameters.get(parameter.name)
if not tool_parameter_config:
# get default value
tool_parameter_config = parameter.default
if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
if parameter.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
# convert tool parameter config to correct type
try:
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, float):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, str):
if '.' in tool_parameter_config:
tool_parameter_config = float(tool_parameter_config)
else:
tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config)
except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
# save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config
elif parameter.form == ToolParameter.ToolParameterForm.LLM:
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
message_tool.parameters['required'].append(parameter.name)
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
if parameter.required:
message_tool.parameters['required'].append(parameter.name)
return message_tool, tool_entity
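For reference, the `parameters` object assembled above follows the OpenAI-style function schema: LLM-form parameters become `properties` entries, declared options become `enum` lists, and required names are appended to `required`. A minimal sketch of the resulting shape, with made-up parameter names:

# A sketch only -- "query" and "language" are hypothetical parameter names.
message_tool_parameters = {
    "type": "object",
    "properties": {
        "query": {
            "type": "string",
            "description": "What to search for.",
        },
        "language": {
            "type": "string",
            "description": "Result language.",
            "enum": ["en", "zh"],  # present only when the parameter declares options
        },
    },
    "required": ["query"],  # appended for each parameter with required=True
}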
@@ -305,6 +241,9 @@ class BaseAssistantApplicationRunner(AppRunner):
tool_runtime_parameters = tool.get_runtime_parameters() or []
for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = 'string'
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
@@ -320,18 +259,17 @@ class BaseAssistantApplicationRunner(AppRunner):
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParameter.ToolParameterForm.LLM:
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool
@@ -404,13 +342,16 @@ class BaseAssistantApplicationRunner(AppRunner):
created_by=self.user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
result.append((
message_file,
message.save_as
))
db.session.commit()
db.session.close()
return result
def create_agent_thought(self, message_id: str, message: str,
@@ -447,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner):
db.session.add(thought)
db.session.commit()
db.session.refresh(thought)
db.session.close()
self.agent_thought_count += 1
@@ -464,6 +407,10 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).filter(
MessageAgentThought.id == agent_thought.id
).first()
if thought is not None:
agent_thought.thought = thought
@@ -514,6 +461,7 @@ class BaseAssistantApplicationRunner(AppRunner):
agent_thought.tool_labels_str = json.dumps(labels)
db.session.commit()
db.session.close()
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
"""
@@ -586,9 +534,14 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
convert tool variables to db variables
"""
db_variables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
).first()
db_variables.updated_at = datetime.utcnow()
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
@@ -613,7 +566,11 @@ class BaseAssistantApplicationRunner(AppRunner):
tools = tools.split(';')
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
tool_inputs = json.loads(agent_thought.tool_input)
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception as e:
logging.warning("tool execution error: {}, tool_input: {}.".format(str(e), agent_thought.tool_input))
tool_inputs = { agent_thought.tool: agent_thought.tool_input }
for tool in tools:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
@@ -644,4 +601,6 @@ class BaseAssistantApplicationRunner(AppRunner):
if message.answer:
result.append(AssistantPromptMessage(content=message.answer))
return result
db.session.close()
return result

View File

@@ -28,6 +28,9 @@ from models.model import Conversation, Message
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
_is_first_iteration = True
_ignore_observation_providers = ['wenxin']
def run(self, conversation: Conversation,
message: Message,
query: str,
@@ -42,10 +45,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
agent_scratchpad: list[AgentScratchpadUnit] = []
self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
# check model mode
if self.app_orchestration_config.model_config.mode == "completion":
# TODO: stop words
if 'Observation' not in app_orchestration_config.model_config.stop:
if 'Observation' not in app_orchestration_config.model_config.stop:
if app_orchestration_config.model_config.provider not in self._ignore_observation_providers:
app_orchestration_config.model_config.stop.append('Observation')
# override inputs
@@ -130,8 +131,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
input=query
)
# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# recalc llm max tokens
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
@@ -181,7 +182,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=json.dumps(chunk)
content=json.dumps(chunk, ensure_ascii=False) # if ensure_ascii=True, the text in the web UI may be garbled
),
usage=None
)
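The `ensure_ascii=False` change is easy to verify in isolation: with the default `ensure_ascii=True`, non-ASCII characters are emitted as `\uXXXX` escapes, which is what rendered as garbled text in the web UI.

import json

chunk = {"thought": "你好"}
print(json.dumps(chunk))                      # {"thought": "\u4f60\u597d"}
print(json.dumps(chunk, ensure_ascii=False))  # {"thought": "你好"}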
@@ -202,6 +203,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
)
)
scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
agent_scratchpad.append(scratchpad)
# get llm usage
@@ -255,9 +257,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# invoke tool
error_response = None
try:
if isinstance(tool_call_args, str):
try:
tool_call_args = json.loads(tool_call_args)
except json.JSONDecodeError:
pass
tool_response = tool_instance.invoke(
user_id=self.user_id,
tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
tool_parameters=tool_call_args
)
# transform tool response to llm friendly response
tool_response = self.transform_tool_invoke_messages(tool_response)
@@ -466,7 +474,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
if isinstance(message, AssistantPromptMessage):
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content,
thought=message.content or 'I am thinking about how to help you',
action_str='',
action=None,
observation=None,
@@ -546,7 +554,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
result = ''
for scratchpad in agent_scratchpad:
result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"
result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \
next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available')
return result
@@ -621,21 +630,24 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
))
# add assistant message
if len(agent_scratchpad) > 0:
if len(agent_scratchpad) > 0 and not self._is_first_iteration:
prompt_messages.append(AssistantPromptMessage(
content=(agent_scratchpad[-1].thought or '')
content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''),
))
# add user message
if len(agent_scratchpad) > 0:
if len(agent_scratchpad) > 0 and not self._is_first_iteration:
prompt_messages.append(UserPromptMessage(
content=(agent_scratchpad[-1].observation or ''),
content=(agent_scratchpad[-1].observation or 'It seems that no response is available'),
))
self._is_first_iteration = False
return prompt_messages
elif mode == "completion":
# parse agent scratchpad
agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
self._is_first_iteration = False
# parse prompt messages
return [UserPromptMessage(
content=first_prompt.replace("{{instruction}}", instruction)
@@ -655,4 +667,4 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
try:
return json.dumps(tools, ensure_ascii=False)
except json.JSONDecodeError:
return json.dumps(tools)
return json.dumps(tools)

View File

@@ -105,8 +105,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
messages_ids=message_file_ids
)
# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# recalc llm max tokens
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,

View File

@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain):

View File

@@ -12,9 +12,9 @@ from pydantic import root_validator
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):

View File

@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.features.dataset_retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.

View File

@@ -1,4 +1,3 @@
import enum
import logging
from typing import Optional, Union
@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_config: ModelConfigEntity
@@ -62,28 +53,7 @@ class AgentExecutor:
self.agent = self._init_agent()
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None, # used for read chat histories memory
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
if self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]

View File

@@ -2,9 +2,10 @@ from typing import Optional, cast
from langchain.tools import BaseTool
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

View File

@@ -0,0 +1,54 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolParameterCacheType(Enum):
PARAMETER = "tool_parameter"
class ToolParameterCache:
def __init__(self,
tenant_id: str,
provider: str,
tool_name: str,
cache_type: ToolParameterCacheType
):
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
def get(self) -> Optional[dict]:
"""
Get cached tool parameters.
:return: cached parameters as a dict, or None on a miss or invalid JSON
"""
cached_tool_parameter = redis_client.get(self.cache_key)
if cached_tool_parameter:
try:
cached_tool_parameter = cached_tool_parameter.decode('utf-8')
cached_tool_parameter = json.loads(cached_tool_parameter)
except JSONDecodeError:
return None
return cached_tool_parameter
else:
return None
def set(self, parameters: dict) -> None:
"""
Cache tool parameters.
:param parameters: tool parameters to cache
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
def delete(self) -> None:
"""
Delete cached tool parameters.
:return:
"""
redis_client.delete(self.cache_key)
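A hypothetical usage sketch of the new cache (tenant, provider, and tool names are made up); values live in Redis under the composed key with a 24-hour TTL:

cache = ToolParameterCache(
    tenant_id="tenant-123",
    provider="google",
    tool_name="google_search",
    cache_type=ToolParameterCacheType.PARAMETER,
)
cache.set({"api_key": "***", "top_k": 5})  # setex with an 86400-second TTL
params = cache.get()   # -> {"api_key": "***", "top_k": 5}; None on a miss or invalid JSON
cache.delete()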

View File

@@ -82,6 +82,8 @@ class HostingConfiguration:
RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM),
RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING),
RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING),
RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING),
]
)
quotas.append(trial_quota)

View File

@@ -1,3 +1,4 @@
import concurrent.futures
import datetime
import json
import logging
@@ -62,7 +63,8 @@ class IndexingRunner:
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
# transform
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
processing_rule.to_dict())
# save segment
self._load_segments(dataset, dataset_document, documents)
@@ -120,7 +122,8 @@ class IndexingRunner:
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
# transform
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
processing_rule.to_dict())
# save segment
self._load_segments(dataset, dataset_document, documents)
@@ -186,7 +189,7 @@ class IndexingRunner:
first()
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
index_processor = IndexProcessorFactory(index_type).init_index_processor()
self._load(
index_processor=index_processor,
dataset=dataset,
@@ -414,9 +417,14 @@ class IndexingRunner:
if separator:
separator = separator.replace('\\n', '\n')
if 'chunk_overlap' in segmentation and segmentation['chunk_overlap']:
chunk_overlap = segmentation['chunk_overlap']
else:
chunk_overlap = 0
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"],
chunk_overlap=segmentation.get('chunk_overlap', 0),
chunk_overlap=chunk_overlap,
fixed_separator=separator,
separators=["\n\n", "", ".", " ", ""],
embedding_model_instance=embedding_model_instance
@@ -643,17 +651,44 @@ class IndexingRunner:
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
chunk_size = 100
chunk_size = 10
embedding_model_type_instance = None
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for i in range(0, len(documents), chunk_size):
chunk_documents = documents[i:i + chunk_size]
futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
chunk_documents, dataset,
dataset_document, embedding_model_instance,
embedding_model_type_instance))
for i in range(0, len(documents), chunk_size):
for future in futures:
tokens += future.result()
indexing_end_at = time.perf_counter()
# update document status to completed
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="completed",
extra_update_params={
DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: datetime.datetime.utcnow(),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
}
)
def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
embedding_model_instance, embedding_model_type_instance):
with flask_app.app_context():
# check document is paused
self._check_document_paused_status(dataset_document.id)
chunk_documents = documents[i:i + chunk_size]
tokens = 0
if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
tokens += sum(
embedding_model_type_instance.get_num_tokens(
@@ -663,9 +698,9 @@ class IndexingRunner:
)
for document in chunk_documents
)
# load index
index_processor.load(dataset, chunk_documents)
db.session.add(dataset)
document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
@@ -680,18 +715,7 @@ class IndexingRunner:
db.session.commit()
indexing_end_at = time.perf_counter()
# update document status to completed
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="completed",
extra_update_params={
DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: datetime.datetime.utcnow(),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
}
)
return tokens
def _check_document_paused_status(self, document_id: str):
indexing_cache_key = 'document_{}_is_paused'.format(document_id)
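The multi-threaded embedding change above boils down to a fan-out/fan-in pattern: documents are split into chunks of 10, each chunk is submitted to a pool of up to 10 workers, and the per-chunk token counts are summed as the futures resolve. A standalone sketch of the pattern, with a stand-in for `_process_chunk`:

import concurrent.futures

def process_chunk(chunk: list[str]) -> int:
    # stand-in for _process_chunk: embed the chunk and report its token count
    return sum(len(doc.split()) for doc in chunk)

documents = [f"document {i}" for i in range(25)]
chunk_size = 10
tokens = 0
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
    futures = [
        executor.submit(process_chunk, documents[i:i + chunk_size])
        for i in range(0, len(documents), chunk_size)
    ]
    for future in futures:
        tokens += future.result()  # blocks; re-raises any exception from the worker
print(tokens)  # 50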
@@ -750,7 +774,7 @@ class IndexingRunner:
index_processor.load(dataset, documents)
def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
text_docs: list[Document], process_rule: dict) -> list[Document]:
text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]:
# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
@@ -768,7 +792,8 @@ class IndexingRunner:
)
documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
process_rule=process_rule)
process_rule=process_rule, tenant_id=dataset.tenant_id,
doc_language=doc_language)
return documents

View File

@@ -47,11 +47,14 @@ class TokenBufferMemory:
files, message.app_model_config
)
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content)
if not file_objs:
prompt_messages.append(UserPromptMessage(content=message.query))
else:
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=message.query))

View File

@@ -147,7 +147,7 @@
- `input` (float) Input price, i.e., Prompt price
- `output` (float) Output price, i.e., returned content price
- `unit` (float) Pricing unit, e.g., per 100K price is `0.000001`
- `unit` (float) Pricing unit, e.g., if the price is quoted per 1M tokens, the per-token unit is `0.000001`.
- `currency` (string) Currency unit
### ProviderCredentialSchema

View File

@@ -149,7 +149,7 @@
- `input` (float) 输入单价,即 Prompt 单价
- `output` (float) 输出单价,即返回内容单价
- `unit` (float) 价格单位,如:每 100K 的单价`0.000001`
- `unit` (float) 价格单位,如以 1M tokens 计价,则单价对应的单位 token 数`0.000001`
- `currency` (string) 货币单位
### ProviderCredentialSchema

View File

@@ -73,8 +73,8 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
},
'type': 'int',
'help': {
'en_US': 'The maximum number of tokens to generate. Requests can use up to 2048 tokens shared between prompt and completion.',
'zh_Hans': '要生成的标记的最大数量。请求可以使用最多2048个标记这些标记在提示和完成之间共享',
'en_US': 'Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.',
'zh_Hans': '指定生成结果长度的上限。如果生成结果截断,可以调大该参数',
},
'required': False,
'default': 64,

View File

@@ -17,7 +17,7 @@ class ModelType(Enum):
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
TTS = "tts"
# TEXT2IMG = "text2img"
TEXT2IMG = "text2img"
@classmethod
def value_of(cls, origin_model_type: str) -> "ModelType":
@@ -36,6 +36,8 @@ class ModelType(Enum):
return cls.SPEECH2TEXT
elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
return cls.TTS
elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
return cls.TEXT2IMG
elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION
else:
@@ -59,10 +61,11 @@ class ModelType(Enum):
return 'tts'
elif self == self.MODERATION:
return 'moderation'
elif self == self.TEXT2IMG:
return 'text2img'
else:
raise ValueError(f'invalid model type {self}')
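With the member enabled, the round-trip through `value_of` now works for the new type (a quick check):

assert ModelType.value_of("text2img") is ModelType.TEXT2IMG
assert ModelType.value_of(ModelType.TEXT2IMG.value) is ModelType.TEXT2IMG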
class FetchFrom(Enum):
"""
Enum class for fetch from.
@@ -130,7 +133,7 @@ class ModelPropertyKey(Enum):
DEFAULT_VOICE = "default_voice"
VOICES = "voices"
WORD_LIMIT = "word_limit"
AUDOI_TYPE = "audio_type"
AUDIO_TYPE = "audio_type"
MAX_WORKERS = "max_workers"

View File

@@ -18,6 +18,7 @@ from core.model_runtime.entities.model_entities import (
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.utils.position_helper import get_position_map, sort_by_position_map
class AIModel(ABC):
@@ -148,15 +149,7 @@ class AIModel(ABC):
]
# get _position.yaml file path
position_file_path = os.path.join(provider_model_type_path, '_position.yaml')
# read _position.yaml file
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, encoding='utf-8') as f:
positions = yaml.safe_load(f)
# convert list to dict with key as model provider name, value as index
position_map = {position: index for index, position in enumerate(positions)}
position_map = get_position_map(provider_model_type_path)
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
@@ -206,8 +199,7 @@ class AIModel(ABC):
model_schemas.append(model_schema)
# resort model schemas by position
if position_map:
model_schemas.sort(key=lambda x: position_map.get(x.model, 999))
model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)
# cache model schemas
self.model_schemas = model_schemas
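`get_position_map` and `sort_by_position_map` replace the inlined `_position.yaml` handling above. A minimal sketch of what such helpers plausibly look like; the signatures are assumptions inferred from these call sites, not the actual helper module:

import os
import yaml  # PyYAML

def get_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
    """Read _position.yaml (a YAML list of names) into {name: index}; {} if absent."""
    position_file = os.path.join(folder_path, file_name)
    if not os.path.exists(position_file):
        return {}
    with open(position_file, encoding="utf-8") as f:
        positions = yaml.safe_load(f) or []
    return {name: index for index, name in enumerate(positions)}

def sort_by_position_map(position_map: dict[str, int], items: list, name_fn) -> list:
    """Sort items by their mapped position; names missing from the map sink to the end."""
    if not position_map or not items:
        return items
    return sorted(items, key=lambda item: position_map.get(name_fn(item), float("inf")))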

View File

@@ -1,4 +1,3 @@
import importlib
import os
from abc import ABC, abstractmethod
@@ -7,6 +6,7 @@ import yaml
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
class ModelProvider(ABC):
@@ -104,17 +104,10 @@ class ModelProvider(ABC):
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
model_class = None
for name, obj in vars(mod).items():
if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
and obj != AIModel and obj.__module__ == mod.__name__):
model_class = obj
break
mod = import_module_from_source(
f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
get_subclasses_from_module(mod, AIModel)), None)
if not model_class:
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
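A sketch of the generalized helpers, assuming signatures that match the call sites above (`import_module_from_source(module_name, py_path)` returning the loaded module, and `get_subclasses_from_module(module, base)` returning candidate classes; the abstract-method filtering stays with the caller, as in the diff):

import importlib.util
from types import ModuleType

def import_module_from_source(module_name: str, py_file_path: str) -> ModuleType:
    """Load a module directly from a .py file path."""
    spec = importlib.util.spec_from_file_location(module_name, py_file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

def get_subclasses_from_module(module: ModuleType, parent_type: type) -> list[type]:
    """Collect classes visible in the module that subclass parent_type."""
    return [
        obj for obj in vars(module).values()
        if isinstance(obj, type) and issubclass(obj, parent_type) and obj is not parent_type
    ]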

View File

@@ -0,0 +1,48 @@
from abc import abstractmethod
from typing import IO, Optional
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
class Text2ImageModel(AIModel):
"""
Model class for text2img model.
"""
model_type: ModelType = ModelType.TEXT2IMG
def invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
"""
Invoke Text2Image model
:param model: model name
:param credentials: model credentials
:param prompt: prompt for image generation
:param model_parameters: model parameters
:param user: unique user id
:return: list of image byte streams
"""
try:
return self._invoke(model, credentials, prompt, model_parameters, user)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
"""
Invoke Text2Image model
:param model: model name
:param credentials: model credentials
:param prompt: prompt for image generation
:param model_parameters: model parameters
:param user: unique user id
:return: list of image byte streams
"""
raise NotImplementedError

View File

@@ -94,8 +94,8 @@ class TTSModel(AIModel):
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.AUDOI_TYPE in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.AUDOI_TYPE]
if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
"""

View File

@@ -2,13 +2,17 @@
- anthropic
- azure_openai
- google
- nvidia
- cohere
- bedrock
- togetherai
- ollama
- mistralai
- groq
- replicate
- huggingface_hub
- xinference
- triton_inference_server
- zhipuai
- baichuan
- spark
@@ -18,7 +22,7 @@
- moonshot
- jina
- chatglm
- xinference
- yi
- openllm
- localai
- openai_api_compatible

View File

@@ -21,7 +21,7 @@ class AnthropicProvider(ModelProvider):
# Use `claude-instant-1.2` model for validation,
model_instance.validate_credentials(
model='claude-instant-1',
model='claude-instant-1.2',
credentials=credentials
)
except CredentialsValidateFailedError as ex:

View File

@@ -2,8 +2,8 @@ provider: anthropic
label:
en_US: Anthropic
description:
en_US: Anthropic's powerful models, such as Claude 2 and Claude Instant.
zh_Hans: Anthropic 的强大模型,例如 Claude 2 和 Claude Instant
en_US: Anthropic's powerful models, such as Claude 3.
zh_Hans: Anthropic 的强大模型,例如 Claude 3
icon_small:
en_US: icon_s_en.svg
icon_large:

View File

@@ -0,0 +1,6 @@
- claude-3-opus-20240229
- claude-3-sonnet-20240229
- claude-2.1
- claude-instant-1.2
- claude-2
- claude-instant-1

View File

@@ -34,3 +34,4 @@ pricing:
output: '24.00'
unit: '0.000001'
currency: USD
deprecated: true

View File

@@ -0,0 +1,37 @@
model: claude-3-haiku-20240307
label:
en_US: claude-3-haiku-20240307
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '0.25'
output: '1.25'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,37 @@
model: claude-3-opus-20240229
label:
en_US: claude-3-opus-20240229
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '15.00'
output: '75.00'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,37 @@
model: claude-3-sonnet-20240229
label:
en_US: claude-3-sonnet-20240229
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '3.00'
output: '15.00'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,35 @@
model: claude-instant-1.2
label:
en_US: claude-instant-1.2
model_type: llm
features: [ ]
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '1.63'
output: '5.51'
unit: '0.000001'
currency: USD

View File

@@ -33,3 +33,4 @@ pricing:
output: '5.51'
unit: '0.000001'
currency: USD
deprecated: true

View File

@@ -1,18 +1,32 @@
import base64
import mimetypes
from collections.abc import Generator
from typing import Optional, Union
from typing import Optional, Union, cast
import anthropic
import requests
from anthropic import Anthropic, Stream
from anthropic.types import Completion, completion_create_params
from anthropic.types import (
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
MessageStartEvent,
MessageStopEvent,
MessageStreamEvent,
completion_create_params,
)
from httpx import Timeout
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
@@ -35,6 +49,7 @@ if you are not sure about the structure.
</instructions>
"""
class AnthropicLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
@@ -55,54 +70,114 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model
:param model: model name
:param credentials: credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
# transform model parameters from completion api of anthropic to chat api
if 'max_tokens_to_sample' in model_parameters:
model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample')
# init model client
client = Anthropic(**credentials_kwargs)
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop
if user:
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)
if system:
extra_model_kwargs['system'] = system
# chat model
response = client.messages.create(
model=model,
messages=prompt_message_dicts,
stream=stream,
**model_parameters,
**extra_model_kwargs
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if 'response_format' in model_parameters and model_parameters['response_format']:
stop = stop or []
self._transform_json_prompts(
model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
# chat model
self._transform_chat_json_prompts(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
)
model_parameters.pop('response_format')
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
def _transform_chat_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
prompt_messages.append(AssistantPromptMessage(
content=f"```{response_format}\n"
))
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -129,7 +204,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return:
"""
try:
self._generate(
self._chat_generate(
model=model,
credentials=credentials,
prompt_messages=[
@@ -137,58 +212,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
],
model_parameters={
"temperature": 0,
"max_tokens_to_sample": 20,
"max_tokens": 20,
},
stream=False
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Invoke large language model
:param model: model name
:param credentials: credentials kwargs
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
client = Anthropic(**credentials_kwargs)
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop
if user:
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
response = client.completions.create(
model=model,
prompt=self._convert_messages_to_prompt_anthropic(prompt_messages),
stream=stream,
**model_parameters,
**extra_model_kwargs
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm response
Handle llm chat response
:param model: model name
:param credentials: credentials
@@ -198,75 +232,89 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.completion
content=response.content[0].text
)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
if response.usage:
# transform usage
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
response = LLMResult(
model=response.model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
usage=usage
)
return result
return response
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
response: Stream[MessageStreamEvent],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
Handle llm chat stream response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
:return: llm response chunk generator
"""
index = -1
full_assistant_content = ''
return_model = None
input_tokens = 0
output_tokens = 0
finish_reason = None
index = 0
for chunk in response:
content = chunk.completion
if chunk.stop_reason is None and (content is None or content == ''):
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=content if content else '',
)
index += 1
if chunk.stop_reason is not None:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
if isinstance(chunk, MessageStartEvent):
return_model = chunk.message.model
input_tokens = chunk.message.usage.input_tokens
elif isinstance(chunk, MessageDeltaEvent):
output_tokens = chunk.usage.output_tokens
finish_reason = chunk.delta.stop_reason
elif isinstance(chunk, MessageStopEvent):
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
yield LLMResultChunk(
model=chunk.model,
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
finish_reason=chunk.stop_reason,
index=index + 1,
message=AssistantPromptMessage(
content=''
),
finish_reason=finish_reason,
usage=usage
)
)
else:
elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else ''
full_assistant_content += chunk_text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text
)
index = chunk.index
yield LLMResultChunk(
model=chunk.model,
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
index=chunk.index,
message=assistant_prompt_message,
)
)
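For orientation, a hedged sketch of how a caller might drive the streaming path the new handler consumes (API key, model, and prompt are placeholders). Each stream yields, in order: a MessageStartEvent (model name and input token usage), ContentBlockDeltaEvent text deltas, a MessageDeltaEvent (output tokens and stop reason), and a MessageStopEvent.

from anthropic import Anthropic
from anthropic.types import (
    ContentBlockDeltaEvent,
    MessageDeltaEvent,
    MessageStartEvent,
    MessageStopEvent,
)

client = Anthropic(api_key="sk-ant-...")  # placeholder key
stream = client.messages.create(
    model="claude-3-sonnet-20240229",
    max_tokens=64,
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for event in stream:
    if isinstance(event, MessageStartEvent):
        print("model:", event.message.model, "input tokens:", event.message.usage.input_tokens)
    elif isinstance(event, ContentBlockDeltaEvent):
        print(event.delta.text, end="")
    elif isinstance(event, MessageDeltaEvent):
        print("\noutput tokens:", event.usage.output_tokens, "stop:", event.delta.stop_reason)
    elif isinstance(event, MessageStopEvent):
        print("done")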
@@ -289,6 +337,88 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return credentials_kwargs
def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
"""
Convert prompt messages to a system string plus a list of message dicts
"""
system = ""
first_loop = True
for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
message.content = message.content.strip()
if first_loop:
system = message.content
first_loop = False
else:
system += "\n"
system += message.content
prompt_message_dicts = []
for message in prompt_messages:
if not isinstance(message, SystemPromptMessage):
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
return system, prompt_message_dicts
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
raise ValueError(f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp")
sub_message_dict = {
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
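For an image attachment, the method above produces Anthropic's vision content-block format. An illustrative result (prompt text is made up, base64 payload truncated):

image_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/png",  # must be jpeg, png, gif, or webp
                "data": "iVBORw0KGgo...",   # truncated base64 payload
            },
        },
    ],
}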
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
"""
Convert a single message to a string.
@@ -302,8 +432,25 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
if not isinstance(message.content, list):
message_text = f"{ai_prompt} {content}"
else:
message_text = ""
for sub_message in message.content:
if sub_message.type == PromptMessageContentType.TEXT:
message_text += f"{human_prompt} {sub_message.data}"
elif sub_message.type == PromptMessageContentType.IMAGE:
message_text += f"{human_prompt} [IMAGE]"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
if not isinstance(message.content, list):
message_text = f"{ai_prompt} {content}"
else:
message_text = ""
for sub_message in message.content:
if sub_message.type == PromptMessageContentType.TEXT:
message_text += f"{ai_prompt} {sub_message.data}"
elif sub_message.type == PromptMessageContentType.IMAGE:
message_text += f"{ai_prompt} [IMAGE]"
elif isinstance(message, SystemPromptMessage):
message_text = content
else:

View File

@@ -15,10 +15,11 @@ from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPEN
class _CommonAzureOpenAI:
@staticmethod
def _to_credential_kwargs(credentials: dict) -> dict:
api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION)
credentials_kwargs = {
"api_key": credentials['openai_api_key'],
"azure_endpoint": credentials['openai_api_base'],
"api_version": AZURE_OPENAI_API_VERSION,
"api_version": api_version,
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"max_retries": 1,
}
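The effect of the change: a per-credential `openai_api_version` now wins over the constant. For example:

credentials = {
    "openai_api_key": "...",
    "openai_api_base": "https://example.openai.azure.com",
    # no 'openai_api_version' stored
}
api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION)
# -> falls back to AZURE_OPENAI_API_VERSION ('2024-02-15-preview');
# a stored value such as '2023-12-01-preview' would take precedence.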

View File

@@ -14,8 +14,7 @@ from core.model_runtime.entities.model_entities import (
PriceConfig,
)
AZURE_OPENAI_API_VERSION = '2023-12-01-preview'
AZURE_OPENAI_API_VERSION = '2024-02-15-preview'
def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
rule = ParameterRule(
@@ -124,6 +123,65 @@ LLM_BASE_MODELS = [
)
)
),
AzureBaseModel(
base_model_name='gpt-35-turbo-0125',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label',
),
model_type=ModelType.LLM,
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
ModelPropertyKey.CONTEXT_SIZE: 16385,
},
parameter_rules=[
ParameterRule(
name='temperature',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096),
ParameterRule(
name='response_format',
label=I18nObject(
zh_Hans='回复格式',
en_US='response_format'
),
type='string',
help=I18nObject(
zh_Hans='指定模型必须输出的格式',
en_US='specifying the format that the model must output'
),
required=False,
options=['text', 'json_object']
),
],
pricing=PriceConfig(
input=0.0005,
output=0.0015,
unit=0.001,
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='gpt-4',
entity=AIModelEntity(
@@ -274,6 +332,81 @@ LLM_BASE_MODELS = [
)
)
),
AzureBaseModel(
base_model_name='gpt-4-0125-preview',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label',
),
model_type=ModelType.LLM,
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name='temperature',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096),
ParameterRule(
name='seed',
label=I18nObject(
zh_Hans='种子',
en_US='Seed'
),
type='int',
help=I18nObject(
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
),
required=False,
precision=2,
min=0,
max=1,
),
ParameterRule(
name='response_format',
label=I18nObject(
zh_Hans='回复格式',
en_US='response_format'
),
type='string',
help=I18nObject(
zh_Hans='指定模型必须输出的格式',
en_US='specifying the format that the model must output'
),
required=False,
options=['text', 'json_object']
),
],
pricing=PriceConfig(
input=0.01,
output=0.03,
unit=0.001,
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='gpt-4-1106-preview',
entity=AIModelEntity(
@@ -524,5 +657,172 @@ EMBEDDING_BASE_MODELS = [
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='text-embedding-3-small',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: 8191,
ModelPropertyKey.MAX_CHUNKS: 32,
},
pricing=PriceConfig(
input=0.00002,
unit=0.001,
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='text-embedding-3-large',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: 8191,
ModelPropertyKey.MAX_CHUNKS: 32,
},
pricing=PriceConfig(
input=0.00013,
unit=0.001,
currency='USD',
)
)
)
]
SPEECH2TEXT_BASE_MODELS = [
AzureBaseModel(
base_model_name='whisper-1',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.SPEECH2TEXT,
model_properties={
ModelPropertyKey.FILE_UPLOAD_LIMIT: 25,
ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm'
}
)
)
]
TTS_BASE_MODELS = [
AzureBaseModel(
base_model_name='tts-1',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TTS,
model_properties={
ModelPropertyKey.DEFAULT_VOICE: 'alloy',
ModelPropertyKey.VOICES: [
{
'mode': 'alloy',
'name': 'Alloy',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'echo',
'name': 'Echo',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'fable',
'name': 'Fable',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'onyx',
'name': 'Onyx',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'nova',
'name': 'Nova',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'shimmer',
'name': 'Shimmer',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
],
ModelPropertyKey.WORD_LIMIT: 120,
ModelPropertyKey.AUDIO_TYPE: 'mp3',
ModelPropertyKey.MAX_WORKERS: 5
},
pricing=PriceConfig(
input=0.015,
unit=0.001,
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='tts-1-hd',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TTS,
model_properties={
ModelPropertyKey.DEFAULT_VOICE: 'alloy',
ModelPropertyKey.VOICES: [
{
'mode': 'alloy',
'name': 'Alloy',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'echo',
'name': 'Echo',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'fable',
'name': 'Fable',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'onyx',
'name': 'Onyx',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'nova',
'name': 'Nova',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'shimmer',
'name': 'Shimmer',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
],
ModelPropertyKey.WORD_LIMIT: 120,
ModelPropertyKey.AUDIO_TYPE: 'mp3',
ModelPropertyKey.MAX_WORKERS: 5
},
pricing=PriceConfig(
input=0.03,
unit=0.001,
currency='USD',
)
)
)
]
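Each VOICES entry pairs an API voice id ('mode') with a display name and the languages it is offered for. The runtime presumably filters this list by the end user's language before exposing it as options; a minimal sketch of that filtering (the helper name and the {name, value} output shape are assumptions):

VOICES = [
    {'mode': 'alloy', 'name': 'Alloy', 'language': ['zh-Hans', 'en-US']},
    {'mode': 'echo', 'name': 'Echo', 'language': ['en-US']},
]

def voices_for_language(voices: list[dict], language: str) -> list[dict]:
    # keep only voices declaring support for the requested language
    return [
        {'name': v['name'], 'value': v['mode']}
        for v in voices
        if language in v['language']
    ]

assert voices_for_language(VOICES, 'zh-Hans') == [{'name': 'Alloy', 'value': 'alloy'}]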

View File

@@ -15,6 +15,8 @@ help:
supported_model_types:
- llm
- text-embedding
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:
@@ -44,6 +46,22 @@ model_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API key here
- variable: openai_api_version
label:
zh_Hans: API 版本
en_US: API Version
type: select
required: true
options:
- label:
en_US: 2024-02-15-preview
value: 2024-02-15-preview
- label:
en_US: 2023-12-01-preview
value: 2023-12-01-preview
placeholder:
zh_Hans: 在此选择您的 API 版本
en_US: Select your API Version here
- variable: base_model_name
label:
en_US: Base Model
@@ -57,6 +75,12 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-35-turbo-0125
value: gpt-35-turbo-0125
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-35-turbo-16k
value: gpt-35-turbo-16k
@@ -75,6 +99,12 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-0125-preview
value: gpt-4-0125-preview
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-1106-preview
value: gpt-4-1106-preview
@@ -99,6 +129,36 @@ model_credential_schema:
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-small
value: text-embedding-3-small
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-large
value: text-embedding-3-large
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: whisper-1
value: whisper-1
show_on:
- variable: __model_type
value: speech2text
- label:
en_US: tts-1
value: tts-1
show_on:
- variable: __model_type
value: tts
- label:
en_US: tts-1-hd
value: tts-1-hd
show_on:
- variable: __model_type
value: tts
placeholder:
zh_Hans: 在此输入您的模型版本
en_US: Enter your model version
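Once a user fills in this schema, the provider code receives a flat credentials dict keyed by the variable names above. A hedged sketch of what that dict and its translation into AzureOpenAI client kwargs might look like (keys other than openai_api_version and base_model_name are assumed names, and the real mapping lives in _CommonAzureOpenAI._to_credential_kwargs):

# Hypothetical credentials produced by the schema; all values are placeholders.
credentials = {
    'openai_api_base': 'https://my-resource.openai.azure.com',  # assumed key name
    'openai_api_key': 'azure-api-key',                          # assumed key name
    'openai_api_version': '2024-02-15-preview',
    'base_model_name': 'gpt-35-turbo-0125',
}

def to_client_kwargs(credentials: dict) -> dict:
    # one plausible mapping onto the openai.AzureOpenAI constructor
    return {
        'azure_endpoint': credentials['openai_api_base'],
        'api_key': credentials['openai_api_key'],
        'api_version': credentials['openai_api_version'],
    }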

View File

@@ -0,0 +1,82 @@
import copy
from typing import IO, Optional
from openai import AzureOpenAI
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel
class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
"""
Model class for the Azure OpenAI speech-to-text model.
"""
def _invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
return self._speech2text_invoke(model, credentials, file)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
audio_file_path = self._get_demo_file_path()
with open(audio_file_path, 'rb') as audio_file:
self._speech2text_invoke(model, credentials, audio_file)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:return: text for given audio file
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
# init model client
client = AzureOpenAI(**credentials_kwargs)
response = client.audio.transcriptions.create(model=model, file=file)
return response.text
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
return ai_model_entity.entity if ai_model_entity else None
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
ai_model_entity_copy.entity.model = model
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
return None
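Stripped of the runtime plumbing, _speech2text_invoke is a single transcription call against an Azure deployment. A standalone equivalent using the openai SDK directly (endpoint, key, version, deployment name and file name are placeholders):

from openai import AzureOpenAI

client = AzureOpenAI(
    azure_endpoint='https://my-resource.openai.azure.com',
    api_key='azure-api-key',
    api_version='2024-02-15-preview',
)
with open('sample.wav', 'rb') as audio_file:
    # 'model' takes the Azure deployment name, not the base model name
    transcription = client.audio.transcriptions.create(model='my-whisper-deployment', file=audio_file)
print(transcription.text)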

View File

@@ -0,0 +1,174 @@
import concurrent.futures
import copy
from functools import reduce
from io import BytesIO
from typing import Any, Optional
from flask import Response, stream_with_context
from openai import AzureOpenAI
from pydub import AudioSegment
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
from extensions.ext_storage import storage
class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
"""
Model class for the Azure OpenAI text-to-speech model.
"""
def _invoke(self, model: str, tenant_id: str, credentials: dict,
content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> Any:
"""
_invoke text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param content_text: text content to be converted to speech
:param voice: model timbre
:param streaming: output is streaming
:param user: unique user id
:return: audio for the given text, full or streaming
"""
audio_type = self._get_model_audio_type(model, credentials)
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
tenant_id=tenant_id,
voice=voice)),
status=200, mimetype=f'audio/{audio_type}')
else:
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
"""
Validate credentials for the text2speech model
:param model: model name
:param credentials: model credentials
:param user: unique user id
:return: None; raises CredentialsValidateFailedError if validation fails
"""
try:
self._tts_invoke(
model=model,
credentials=credentials,
content_text='Hello Dify!',
voice=self._get_model_default_voice(model, credentials),
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response:
"""
_tts_invoke text2speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be converted to speech
:param voice: model timbre
:return: audio of the given text
"""
audio_type = self._get_model_audio_type(model, credentials)
word_limit = self._get_model_word_limit(model, credentials)
max_workers = self._get_model_workers_limit(model, credentials)
try:
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
audio_bytes_list = list()
# Create a thread pool and map the function to the list of sentences
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice,
credentials=credentials) for sentence in sentences]
for future in futures:
try:
if future.result():
audio_bytes_list.append(future.result())
except Exception as ex:
raise InvokeBadRequestError(str(ex))
if len(audio_bytes_list) > 0:
audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
audio_bytes_list if audio_bytes]
combined_segment = reduce(lambda x, y: x + y, audio_segments)
buffer: BytesIO = BytesIO()
combined_segment.export(buffer, format=audio_type)
buffer.seek(0)
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
except Exception as ex:
raise InvokeBadRequestError(str(ex))
# TODO: improve the streaming function
def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
voice: str) -> Any:
"""
_tts_invoke_streaming text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param content_text: text content to be converted to speech
:param voice: model timbre
:return: audio of the given text, streamed
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
word_limit = self._get_model_word_limit(model, credentials)
audio_type = self._get_model_audio_type(model, credentials)
tts_file_id = self._get_file_name(content_text)
file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
try:
client = AzureOpenAI(**credentials_kwargs)
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
for sentence in sentences:
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
# response.stream_to_file(file_path)
storage.save(file_path, response.read())
except Exception as ex:
raise InvokeBadRequestError(str(ex))
def _process_sentence(self, sentence: str, model: str,
voice, credentials: dict):
"""
Invoke the Azure OpenAI text2speech API for a single sentence
:param model: model name
:param credentials: model credentials
:param voice: model timbre
:param sentence: text content to be converted to speech
:return: audio bytes for the sentence
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
audio_bytes = response.read()  # read the body once and reuse it
if isinstance(audio_bytes, bytes):
return audio_bytes
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
return ai_model_entity.entity if ai_model_entity else None
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
for ai_model_entity in TTS_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
ai_model_entity_copy.entity.model = model
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
return None
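The non-streaming path above fans sentences out to a thread pool and stitches the per-sentence audio back together with pydub. The same pattern in isolation, with synthesize left as a stand-in for the client.audio.speech.create(...).read() call:

import concurrent.futures
from functools import reduce
from io import BytesIO

from pydub import AudioSegment

def synthesize(sentence: str) -> bytes:
    """Stand-in for the per-sentence TTS API call."""
    raise NotImplementedError

def combine(sentences: list[str], max_workers: int = 5) -> BytesIO:
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # the futures list preserves sentence order, so the merged
        # audio plays back in the original text order
        futures = [executor.submit(synthesize, s) for s in sentences]
        audio_bytes_list = [f.result() for f in futures if f.result()]
    segments = [AudioSegment.from_file(BytesIO(b), format='mp3') for b in audio_bytes_list]
    combined = reduce(lambda x, y: x + y, segments)
    buffer = BytesIO()
    combined.export(buffer, format='mp3')
    buffer.seek(0)
    return buffer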

View File

@@ -108,7 +108,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
try:
response = post(url, headers=headers, data=dumps(data))
except Exception as e:
raise InvokeConnectionError(e)
raise InvokeConnectionError(str(e))
if response.status_code != 200:
try:
@@ -124,7 +124,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
elif err == 'insufficient_quota':
raise InsufficientAccountBalance(msg)
elif err == 'invalid_authentication':
raise InvalidAuthenticationError(msg)
raise InvalidAuthenticationError(msg)
elif err and 'rate' in err:
raise RateLimitReachedError(msg)
elif err and 'internal' in err:
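Both hunks normalize the raised error to a plain string. One plausible motivation: a raw exception object breaks any downstream code that builds a message by string concatenation, for example:

exc = TimeoutError('connect timeout')
# 'connection failed: ' + exc        # TypeError: can only concatenate str
message = 'connection failed: ' + str(exc)
assert message == 'connection failed: connect timeout'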

View File

@@ -18,9 +18,10 @@ class BedrockProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)
# Use `gemini-pro` model for validate,
# Use `amazon.titan-text-lite-v1` model by default for validating credentials
model_for_validation = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1')
model_instance.validate_credentials(
model='amazon.titan-text-lite-v1',
model=model_for_validation,
credentials=credentials
)
except CredentialsValidateFailedError as ex:
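The hard-coded validation model becomes a credential-driven choice with the previous model kept as the fallback, which dict.get expresses directly:

# credentials as saved by the provider form; model_for_validation is optional
credentials = {'aws_region': 'us-east-1'}
model = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1')
assert model == 'amazon.titan-text-lite-v1'  # falls back when the field is unset

# since the field is optional it may also arrive as an empty string;
# `credentials.get('model_for_validation') or default` would guard that case too
# (an extra safeguard, not what the diff does)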

View File

@@ -48,24 +48,33 @@ provider_credential_schema:
- value: us-east-1
label:
en_US: US East (N. Virginia)
zh_Hans: US East (N. Virginia)
zh_Hans: 美国东部 (弗吉尼亚北部)
- value: us-west-2
label:
en_US: US West (Oregon)
zh_Hans: US West (Oregon)
zh_Hans: 美国西部 (俄勒冈州)
- value: ap-southeast-1
label:
en_US: Asia Pacific (Singapore)
zh_Hans: Asia Pacific (Singapore)
zh_Hans: 亚太地区 (新加坡)
- value: ap-northeast-1
label:
en_US: Asia Pacific (Tokyo)
zh_Hans: Asia Pacific (Tokyo)
zh_Hans: 亚太地区 (东京)
- value: eu-central-1
label:
en_US: Europe (Frankfurt)
zh_Hans: Europe (Frankfurt)
zh_Hans: 欧洲 (法兰克福)
- value: us-gov-west-1
label:
en_US: AWS GovCloud (US-West)
zh_Hans: AWS GovCloud (US-West)
- variable: model_for_validation
required: false
label:
en_US: Available Model Name
zh_Hans: 可用模型名称
type: secret-input
placeholder:
en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation.
zh_Hans: 为了进行验证，请输入一个您可用的模型名称 (例如：amazon.titan-text-lite-v1)

View File

@@ -4,6 +4,8 @@
- anthropic.claude-v1
- anthropic.claude-v2
- anthropic.claude-v2:1
- anthropic.claude-3-sonnet-v1:0
- anthropic.claude-3-haiku-v1:0
- cohere.command-light-text-v14
- cohere.command-text-v14
- meta.llama2-13b-chat-v1

View File

@@ -0,0 +1,57 @@
model: anthropic.claude-3-haiku-20240307-v1:0
label:
en_US: Claude 3 Haiku
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意，Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或 top_p，但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD
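Rules like these drive both the parameter UI and validation: each user-supplied value is checked against the rule's type and min/max, with the default applied when the value is absent, before being forwarded to Bedrock. A simplified sketch of such a check (not Dify's actual validator):

# numeric rules transcribed from the YAML above
RULES = {
    'max_tokens': {'type': int, 'min': 1, 'max': 4096, 'default': 4096},
    'temperature': {'type': float, 'min': 0.0, 'max': 1.0, 'default': 1.0},
    'top_p': {'type': float, 'min': 0.0, 'max': 1.0, 'default': 0.999},
    'top_k': {'type': int, 'min': 0, 'max': 500, 'default': 0},
}

def validate_parameters(params: dict) -> dict:
    validated = {}
    for name, rule in RULES.items():
        value = params.get(name, rule['default'])
        if not isinstance(value, rule['type']):
            raise ValueError(f'{name} must be of type {rule["type"].__name__}')
        if not rule['min'] <= value <= rule['max']:
            raise ValueError(f'{name} must lie within [{rule["min"]}, {rule["max"]}]')
        validated[name] = value
    return validated

assert validate_parameters({'temperature': 0.5})['max_tokens'] == 4096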

View File

@@ -0,0 +1,56 @@
model: anthropic.claude-3-sonnet-20240229-v1:0
label:
en_US: Claude 3 Sonnet
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意，Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或 top_p，但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.00025'
output: '0.00125'
unit: '0.001'
currency: USD

View File

@@ -1,33 +1,50 @@
model: anthropic.claude-instant-v1
label:
en_US: Claude Instant V1
en_US: Claude Instant 1
model_type: llm
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: topP
use_template: top_p
- name: topK
label:
zh_Hans: 取样数量
en_US: Top K
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 250
min: 0
max: 500
- name: max_tokens_to_sample
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意，Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或 top_p，但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.0008'
output: '0.0024'

View File

@@ -1,33 +1,50 @@
model: anthropic.claude-v1
label:
en_US: Claude V1
en_US: Claude 1
model_type: llm
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top K
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 250
min: 0
max: 500
- name: max_tokens_to_sample
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意，Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或 top_p，但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.008'
output: '0.024'

View File

@@ -1,33 +1,50 @@
model: anthropic.claude-v2:1
label:
en_US: Claude V2.1
en_US: Claude 2.1
model_type: llm
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top K
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 250
min: 0
max: 500
- name: max_tokens_to_sample
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意，Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或 top_p，但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.008'
output: '0.024'

View File

@@ -1,33 +1,50 @@
model: anthropic.claude-v2
label:
en_US: Claude V2
en_US: Claude 2
model_type: llm
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top K
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 250
min: 0
max: 500
- name: max_tokens_to_sample
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意，Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或 top_p，但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.008'
output: '0.024'

View File

@@ -1,9 +1,22 @@
import base64
import json
import logging
import mimetypes
import time
from collections.abc import Generator
from typing import Optional, Union
from typing import Optional, Union, cast
import boto3
import requests
from anthropic import AnthropicBedrock, Stream
from anthropic.types import (
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
MessageStartEvent,
MessageStopEvent,
MessageStreamEvent,
)
from botocore.config import Config
from botocore.exceptions import (
ClientError,
@@ -13,14 +26,18 @@ from botocore.exceptions import (
UnknownServiceError,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@@ -54,9 +71,293 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# invoke model
# invoke anthropic models via anthropic official SDK
if "anthropic" in model:
return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
# invoke other models via boto3 client
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke Anthropic large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:return: full response or stream response chunk generator result
"""
# use Anthropic official SDK references
# - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
# - https://github.com/anthropics/anthropic-sdk-python
client = AnthropicBedrock(
aws_access_key=credentials["aws_access_key_id"],
aws_secret_key=credentials["aws_secret_access_key"],
aws_region=credentials["aws_region"],
)
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop
# Note: with the current version of the SDK, requests that include `metadata`
# are rejected by the Bedrock endpoint with the error below, so wait for the service or SDK to be updated.
# Response: Error code: 400
# {'message': 'Malformed input request: #: subject must not be valid against schema
# {"required":["messages"]}#: extraneous key [metadata] is not permitted, please reformat your input and try again.'}
# TODO: enable this once the interface properly supports it
# if user:
# ref: https://github.com/anthropics/anthropic-sdk-python/blob/e84645b07ca5267066700a104b4d8d6a8da1383d/src/anthropic/resources/messages.py#L465
# extra_model_kwargs['metadata'] = message_create_params.Metadata(user_id=user)
system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages)
if system:
extra_model_kwargs['system'] = system
response = client.messages.create(
model=model,
messages=prompt_message_dicts,
stream=stream,
**model_parameters,
**extra_model_kwargs
)
if stream:
return self._handle_claude_stream_response(model, credentials, response, prompt_messages)
return self._handle_claude_response(model, credentials, response, prompt_messages)
def _handle_claude_response(self, model: str, credentials: dict, response: Message,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm chat response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: full response chunk generator result
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.content[0].text
)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
response = LLMResult(
model=response.model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
)
return response
def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm chat stream response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: full response or stream response chunk generator result
"""
try:
full_assistant_content = ''
return_model = None
input_tokens = 0
output_tokens = 0
finish_reason = None
index = 0
for chunk in response:
if isinstance(chunk, MessageStartEvent):
return_model = chunk.message.model
input_tokens = chunk.message.usage.input_tokens
elif isinstance(chunk, MessageDeltaEvent):
output_tokens = chunk.usage.output_tokens
finish_reason = chunk.delta.stop_reason
elif isinstance(chunk, MessageStopEvent):
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
yield LLMResultChunk(
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index + 1,
message=AssistantPromptMessage(
content=''
),
finish_reason=finish_reason,
usage=usage
)
)
elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else ''
full_assistant_content += chunk_text
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text if chunk_text else '',
)
index = chunk.index
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
)
)
except Exception as ex:
raise InvokeError(str(ex))
def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param prompt_tokens: prompt tokens
:param completion_tokens: completion tokens
:return: usage
"""
# get prompt price info
prompt_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=prompt_tokens,
)
# get completion price info
completion_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.OUTPUT,
tokens=completion_tokens
)
# transform usage
usage = LLMUsage(
prompt_tokens=prompt_tokens,
prompt_unit_price=prompt_price_info.unit_price,
prompt_price_unit=prompt_price_info.unit,
prompt_price=prompt_price_info.total_amount,
completion_tokens=completion_tokens,
completion_unit_price=completion_price_info.unit_price,
completion_price_unit=completion_price_info.unit,
completion_price=completion_price_info.total_amount,
total_tokens=prompt_tokens + completion_tokens,
total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
currency=prompt_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
def _convert_claude_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
"""
Convert prompt messages to dict list and system
"""
system = ""
first_loop = True
for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
message.content = message.content.strip()
if first_loop:
system = message.content
first_loop = False
else:
system += "\n"
system += message.content
prompt_message_dicts = []
for message in prompt_messages:
if not isinstance(message, SystemPromptMessage):
prompt_message_dicts.append(self._convert_claude_prompt_message_to_dict(message))
return system, prompt_message_dicts
def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
raise ValueError(f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp")
sub_message_dict = {
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
@@ -101,7 +402,19 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param credentials: model credentials
:return:
"""
if "anthropic.claude-3" in model:
try:
self._invoke_claude(model=model,
credentials=credentials,
prompt_messages=[{"role": "user", "content": "ping"}],
model_parameters={},
stop=None,
stream=False)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
try:
ping_message = UserPromptMessage(content="ping")
self._generate(model=model,
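_convert_claude_prompt_messages above folds any leading system prompts into a single system string and converts the remaining turns into Messages-API dicts. The same split in a self-contained form (Msg is a stand-in for Dify's prompt message classes):

from dataclasses import dataclass

@dataclass
class Msg:
    role: str
    content: str

def split_system(messages: list[Msg]) -> tuple[str, list[dict]]:
    # fold every system message into one system string, in order,
    # and convert the rest into Anthropic Messages-API dicts
    system_parts = [m.content.strip() for m in messages if m.role == 'system']
    turns = [{'role': m.role, 'content': m.content} for m in messages if m.role != 'system']
    return '\n'.join(system_parts), turns

system, turns = split_system([Msg('system', 'You are terse.'), Msg('user', 'ping')])
assert system == 'You are terse.' and turns == [{'role': 'user', 'content': 'ping'}]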

View File

@@ -472,7 +472,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
else:
raise ValueError(f"Got unknown type {message}")
if message.name is not None:
if message.name:
message_dict["user_name"] = message.name
return message_dict
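The tightened truthiness check skips empty strings as well as None, so a blank display name is never sent to Cohere as user_name:

message_dict = {'role': 'USER', 'message': 'hi'}
name = ''  # an unset display name often arrives as '' rather than None
if name:   # truthiness rejects both None and '' in one check
    message_dict['user_name'] = name
assert 'user_name' not in message_dict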

View File

@@ -0,0 +1,11 @@
<svg width="112" height="24" viewBox="0 0 112 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M57.4336 17.092C56.4746 16.5453 55.7216 15.7924 55.1749 14.8244C54.6283 13.8564 54.3594 12.763 54.3594 11.544C54.3594 10.3251 54.6283 9.2137 55.1749 8.24571C55.7216 7.27772 56.4746 6.52485 57.4336 5.98708C58.3926 5.4493 59.4861 5.18042 60.6961 5.18042C61.6999 5.18042 62.623 5.3776 63.4476 5.77197C64.2722 6.16633 64.9445 6.73995 65.4554 7.49284L64.568 8.13816C64.1199 7.51076 63.5642 7.04469 62.9009 6.731C62.2377 6.41729 61.5027 6.26492 60.705 6.26492C59.7281 6.26492 58.8498 6.48899 58.0789 6.92818C57.2992 7.36736 56.6986 7.98579 56.2505 8.79244C55.8113 9.59014 55.5872 10.5133 55.5872 11.553C55.5872 12.5926 55.8113 13.5159 56.2505 14.3136C56.6896 15.1112 57.2992 15.7297 58.0789 16.1778C58.8587 16.617 59.7281 16.8411 60.705 16.8411C61.5027 16.8411 62.2377 16.6888 62.9009 16.375C63.5642 16.0613 64.1199 15.5953 64.568 14.9678L65.4554 15.6132C64.9445 16.366 64.2722 16.9396 63.4476 17.334C62.623 17.7284 61.7089 17.9255 60.6961 17.9255C59.4771 17.9255 58.3926 17.6568 57.4336 17.11V17.092Z" fill="#F55036"/>
<path d="M67.2754 0H68.4763V17.8181H67.2754V0Z" fill="#F55036"/>
<path d="M73.6754 17.092C72.7254 16.5454 71.9725 15.7924 71.4347 14.8244C70.888 13.8564 70.6191 12.763 70.6191 11.544C70.6191 10.3251 70.888 9.23163 71.4347 8.26364C71.9814 7.29566 72.7254 6.54277 73.6754 5.99604C74.6255 5.4493 75.6921 5.18042 76.8841 5.18042C78.0762 5.18042 79.1338 5.4493 80.0928 5.99604C81.0429 6.54277 81.7957 7.29566 82.3335 8.26364C82.8803 9.23163 83.1492 10.3251 83.1492 11.544C83.1492 12.763 82.8803 13.8564 82.3335 14.8244C81.7868 15.7924 81.0429 16.5454 80.0928 17.092C79.1427 17.6387 78.0673 17.9076 76.8841 17.9076C75.7011 17.9076 74.6344 17.6387 73.6754 17.092ZM79.4655 16.1599C80.2273 15.7118 80.8277 15.0843 81.2669 14.2867C81.7062 13.489 81.9302 12.5747 81.9302 11.553C81.9302 10.5312 81.7062 9.61703 81.2669 8.81933C80.8277 8.02164 80.2273 7.39425 79.4655 6.9461C78.7036 6.49796 77.8431 6.27389 76.8841 6.27389C75.9251 6.27389 75.0646 6.49796 74.3028 6.9461C73.5409 7.39425 72.9405 8.02164 72.5013 8.81933C72.0621 9.61703 71.838 10.5312 71.838 11.553C71.838 12.5747 72.0621 13.489 72.5013 14.2867C72.9405 15.0843 73.5409 15.7118 74.3028 16.1599C75.0646 16.608 75.9251 16.8322 76.8841 16.8322C77.8431 16.8322 78.7036 16.608 79.4655 16.1599Z" fill="#F55036"/>
<path d="M96.2799 5.27905V17.8091H95.1237V15.1203C94.7114 15.9986 94.0929 16.6887 93.2774 17.1728C92.4618 17.6567 91.5027 17.9077 90.4003 17.9077C88.769 17.9077 87.4873 17.4506 86.5553 16.5364C85.6231 15.6222 85.166 14.3136 85.166 12.6017V5.27905H86.367V12.5031C86.367 13.9102 86.7255 14.9858 87.4515 15.7207C88.1775 16.4557 89.1903 16.8232 90.4989 16.8232C91.9061 16.8232 93.0264 16.384 93.851 15.5057C94.6756 14.6272 95.0878 13.4442 95.0878 11.9563V5.27905H96.2889H96.2799Z" fill="#F55036"/>
<path d="M110.952 0V17.8181H109.777V14.8604C109.284 15.8374 108.585 16.5902 107.689 17.119C106.793 17.6479 105.78 17.9077 104.642 17.9077C103.503 17.9077 102.419 17.6389 101.469 17.0922C100.528 16.5454 99.7838 15.7925 99.246 14.8336C98.7083 13.8745 98.4395 12.781 98.4395 11.5441C98.4395 10.3073 98.7083 9.2138 99.246 8.24582C99.7838 7.27783 100.519 6.52496 101.469 5.98718C102.41 5.44941 103.468 5.18053 104.642 5.18053C105.816 5.18053 106.766 5.44044 107.653 5.96925C108.541 6.49807 109.24 7.23301 109.75 8.17411V0H110.952ZM107.295 16.16C108.057 15.7119 108.657 15.0844 109.096 14.2868C109.535 13.4891 109.759 12.5749 109.759 11.5531C109.759 10.5313 109.535 9.61713 109.096 8.81944C108.657 8.02174 108.057 7.39434 107.295 6.9462C106.533 6.49807 105.672 6.27399 104.713 6.27399C103.754 6.27399 102.894 6.49807 102.132 6.9462C101.37 7.39434 100.77 8.02174 100.331 8.81944C99.8914 9.61713 99.6673 10.5313 99.6673 11.5531C99.6673 12.5749 99.8914 13.4891 100.331 14.2868C100.77 15.0844 101.37 15.7119 102.132 16.16C102.894 16.6081 103.754 16.8322 104.713 16.8322C105.672 16.8322 106.533 16.6081 107.295 16.16Z" fill="#F55036"/>
<path d="M30.6085 5.27024C27.077 5.27024 24.209 8.13835 24.209 11.6697C24.209 15.201 27.077 18.0692 30.6085 18.0692C34.1399 18.0692 37.0079 15.201 37.0079 11.6697C37.0079 8.13835 34.1399 5.27921 30.6085 5.27024ZM30.6085 15.6672C28.4036 15.6672 26.611 13.8746 26.611 11.6697C26.611 9.46486 28.4036 7.67228 30.6085 7.67228C32.8133 7.67228 34.6059 9.46486 34.6059 11.6697C34.6059 13.8746 32.8133 15.6672 30.6085 15.6672Z" fill="black"/>
<path d="M6.45358 5.23422C2.92222 5.19837 0.036187 8.0396 0.000335591 11.571C-0.0355158 15.1023 2.80571 17.9974 6.33706 18.0242C6.37292 18.0242 6.41773 18.0242 6.45358 18.0242H8.55986V15.6311H6.45358C4.24873 15.658 2.43823 13.8923 2.41134 11.6785C2.38445 9.47365 4.15014 7.66315 6.36395 7.63626C6.39084 7.63626 6.4267 7.63626 6.45358 7.63626C8.65844 7.63626 10.46 9.42884 10.46 11.6337V17.5222C10.46 19.7092 8.67637 21.4929 6.48943 21.5197C5.44078 21.5197 4.44591 21.0895 3.71095 20.3455L2.01698 22.0395C3.1911 23.2227 4.7865 23.8949 6.45358 23.9128H6.54321C10.0298 23.859 12.8351 21.0357 12.853 17.5491V11.4724C12.7635 8.00374 9.93116 5.23422 6.46254 5.23422H6.45358Z" fill="black"/>
<path d="M51.2406 11.5082C51.151 8.03961 48.3187 5.27009 44.8501 5.27009C41.3187 5.23423 38.4237 8.07545 38.3968 11.6068C38.361 15.1382 41.2022 18.0331 44.7335 18.0601C44.7694 18.0601 44.8143 18.0601 44.8501 18.0601H46.9563V15.667H44.8501C42.6452 15.6939 40.8347 13.9282 40.8078 11.7144C40.7809 9.5095 42.5467 7.69902 44.7604 7.67213C44.7874 7.67213 44.8232 7.67213 44.8501 7.67213C47.055 7.67213 48.8565 9.46469 48.8565 11.6696V23.626L51.2406 23.6528V11.5082Z" fill="black"/>
<path d="M14.6808 18.0602H17.0649V11.6607C17.0649 9.45589 18.8575 7.66332 21.0623 7.66332C21.7883 7.66332 22.4695 7.8605 23.0611 8.2011L24.2621 6.12172C23.3209 5.57498 22.2276 5.27024 21.0713 5.27024C17.5399 5.27024 14.6719 8.13835 14.6719 11.6697V18.0692L14.6808 18.0602Z" fill="black"/>
</svg>


View File

@@ -0,0 +1,4 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="24" height="24" rx="12" fill="#F55036"/>
<path d="M12.146 6.00022C9.87734 5.97718 8.02325 7.80249 8.00022 10.0712C7.97718 12.3398 9.80249 14.1997 12.0712 14.217C12.0942 14.217 12.123 14.217 12.146 14.217H13.4992V12.6796H12.146C10.7295 12.6968 9.56641 11.5625 9.54913 10.1403C9.53186 8.72377 10.6662 7.56065 12.0884 7.54337C12.1057 7.54337 12.1287 7.54337 12.146 7.54337C13.5625 7.54337 14.7199 8.69498 14.7199 10.1115V13.8945C14.7199 15.2995 13.574 16.4453 12.169 16.4626C11.4953 16.4626 10.8562 16.1862 10.384 15.7083L9.29578 16.7965C10.0501 17.5566 11.075 17.9885 12.146 18H12.2036C14.4435 17.9654 16.2457 16.1516 16.2572 13.9117V10.0078C16.1997 7.77945 14.3801 6.00022 12.1518 6.00022H12.146Z" fill="white"/>
</svg>


Some files were not shown because too many files have changed in this diff.