mirror of
https://github.com/langgenius/dify.git
synced 2026-03-22 16:27:07 +00:00
Compare commits
738 Commits
fix/reset-
...
feat/ga-tr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8adb882c4b | ||
|
|
56abf7313a | ||
|
|
c7a63ed01a | ||
|
|
2b8be2c869 | ||
|
|
897458fbe1 | ||
|
|
f31519067d | ||
|
|
01cfe2c197 | ||
|
|
fa7a0fd3d1 | ||
|
|
2853ee4a68 | ||
|
|
2a20a96123 | ||
|
|
0ae2d2b631 | ||
|
|
da2421f378 | ||
|
|
02b11c2fe3 | ||
|
|
ddd643961f | ||
|
|
eef4223063 | ||
|
|
5824ab766c | ||
|
|
613b6cb035 | ||
|
|
197ba5f6ef | ||
|
|
692cd37402 | ||
|
|
7b424d6270 | ||
|
|
89447930b6 | ||
|
|
461078a0cb | ||
|
|
cb9148a39a | ||
|
|
cdd85ff736 | ||
|
|
07a4d7a182 | ||
|
|
56dd17cdaa | ||
|
|
92ab453a5a | ||
|
|
2cd322af21 | ||
|
|
97f0dff9e1 | ||
|
|
1205960c59 | ||
|
|
8254a40e75 | ||
|
|
7e21e8ab01 | ||
|
|
a5ec05acbc | ||
|
|
61a0fcc2ea | ||
|
|
f627348b11 | ||
|
|
87fb9a6b69 | ||
|
|
97a2e2ec2e | ||
|
|
68d357d7f6 | ||
|
|
a103ad3ee7 | ||
|
|
f65d5a9761 | ||
|
|
6e0a5f5bbd | ||
|
|
22f858152f | ||
|
|
83f2e14249 | ||
|
|
d8fe0c916b | ||
|
|
775d2e14fc | ||
|
|
744b287e67 | ||
|
|
c0fc5d98f0 | ||
|
|
08ea79d730 | ||
|
|
f31b821cc0 | ||
|
|
34be16874f | ||
|
|
e9738b891f | ||
|
|
829796514a | ||
|
|
ef1db35f80 | ||
|
|
f9c67621ca | ||
|
|
e29e8e3180 | ||
|
|
7a81e720d4 | ||
|
|
55600c0eb1 | ||
|
|
35e41d7d68 | ||
|
|
b610cf9a11 | ||
|
|
c8e9edc024 | ||
|
|
471cd760d7 | ||
|
|
7f48c57edf | ||
|
|
6569801162 | ||
|
|
9dd83f50a7 | ||
|
|
59c56b1b0d | ||
|
|
94cd2de940 | ||
|
|
3c23375607 | ||
|
|
56047f638f | ||
|
|
9c01d3e775 | ||
|
|
c85c87f3da | ||
|
|
eaa02e3d55 | ||
|
|
0219222a60 | ||
|
|
dba659b220 | ||
|
|
ee6458768e | ||
|
|
ed3d02dc6d | ||
|
|
95471b1188 | ||
|
|
6190cfbfd8 | ||
|
|
11f2f95103 | ||
|
|
2abbc14703 | ||
|
|
b2b2816ade | ||
|
|
4461df1bd9 | ||
|
|
f7f6b4a8b0 | ||
|
|
41be581594 | ||
|
|
20ad5b7ac2 | ||
|
|
a1c0bd7a1c | ||
|
|
fd7c4e8a6d | ||
|
|
41e549af14 | ||
|
|
b7360140ee | ||
|
|
c71f7c7613 | ||
|
|
c905c47775 | ||
|
|
4ca7ba000c | ||
|
|
f260627660 | ||
|
|
1e9142c213 | ||
|
|
82890fe38e | ||
|
|
7dc7c8af98 | ||
|
|
addebc465a | ||
|
|
69e3ccaef0 | ||
|
|
5ab315aeaf | ||
|
|
f092bc1912 | ||
|
|
23b49b8304 | ||
|
|
9e97248ede | ||
|
|
d532b06310 | ||
|
|
07a2281730 | ||
|
|
42385f3ffa | ||
|
|
c597234374 | ||
|
|
3de73f07c6 | ||
|
|
0caeaf6e5c | ||
|
|
2c2b3092f6 | ||
|
|
3395297c3e | ||
|
|
e60a7c7143 | ||
|
|
0e62a66cc2 | ||
|
|
ff32dff163 | ||
|
|
543c5236e7 | ||
|
|
341b3ae7c9 | ||
|
|
f01907aac2 | ||
|
|
a7c855cab8 | ||
|
|
29afc0657d | ||
|
|
d9860b8907 | ||
|
|
4a797ab2d8 | ||
|
|
63b74c17df | ||
|
|
dc1ae57dc6 | ||
|
|
d6bd2a9bdb | ||
|
|
c9eed67cf6 | ||
|
|
0ded6303c1 | ||
|
|
1058961fd8 | ||
|
|
b340234711 | ||
|
|
79200d89d8 | ||
|
|
a93cbc0461 | ||
|
|
b6e0abadab | ||
|
|
4736819dd9 | ||
|
|
b5e23b84ce | ||
|
|
5a4afcf8fd | ||
|
|
43bcf40f80 | ||
|
|
b407cf0189 | ||
|
|
74c0afe5f5 | ||
|
|
8b7fe6add7 | ||
|
|
f06025a342 | ||
|
|
b827bcdd47 | ||
|
|
0b021273bc | ||
|
|
4590f7daa5 | ||
|
|
7d19ca6a03 | ||
|
|
74511c6fdc | ||
|
|
3bcadbe673 | ||
|
|
24fb95b050 | ||
|
|
49fca63927 | ||
|
|
ce5fe86430 | ||
|
|
666586b59c | ||
|
|
8a2851551a | ||
|
|
a2fe4a28c3 | ||
|
|
417ebd160b | ||
|
|
82be305680 | ||
|
|
03002f4971 | ||
|
|
1e7e8a8988 | ||
|
|
873de82d01 | ||
|
|
a715d5ac23 | ||
|
|
398c8117fe | ||
|
|
9eb46c297b | ||
|
|
24fbbbb07b | ||
|
|
7551d14256 | ||
|
|
f45c18ee35 | ||
|
|
48db4ab271 | ||
|
|
15c1db42dd | ||
|
|
a31c01f8d9 | ||
|
|
3d08c79c3e | ||
|
|
62753cdf13 | ||
|
|
2e0a7857f0 | ||
|
|
dc7ce125ad | ||
|
|
eabdb09f8e | ||
|
|
fa6d03c979 | ||
|
|
634fb192ef | ||
|
|
a4b38e7521 | ||
|
|
8ff6de91b0 | ||
|
|
7fa0ad3161 | ||
|
|
53b21eea61 | ||
|
|
2f3a61b51b | ||
|
|
cf6742e9f2 | ||
|
|
b48a7c7cda | ||
|
|
a799493bf7 | ||
|
|
dc21ab5f59 | ||
|
|
e7a575a33c | ||
|
|
ffd3a461f6 | ||
|
|
8bca7814f4 | ||
|
|
92c81b1833 | ||
|
|
f3b63e5126 | ||
|
|
44553d412c | ||
|
|
b5f25c85a5 | ||
|
|
95ce224df0 | ||
|
|
8555635967 | ||
|
|
e843fe8aa6 | ||
|
|
b198c9474a | ||
|
|
6feeef96c8 | ||
|
|
ffc551f003 | ||
|
|
2b64ff1af2 | ||
|
|
1a351d2832 | ||
|
|
5d839497fc | ||
|
|
572abae37a | ||
|
|
cdf3cffb30 | ||
|
|
daf474dd3a | ||
|
|
d380b2f431 | ||
|
|
01daef21e2 | ||
|
|
e8cb4a5aa7 | ||
|
|
e554790a03 | ||
|
|
7222b1f36f | ||
|
|
a29f3107e9 | ||
|
|
43ccac4197 | ||
|
|
784633a0dc | ||
|
|
3f1e4a201e | ||
|
|
0326786cad | ||
|
|
5fb6b51998 | ||
|
|
753234fdfe | ||
|
|
a23bf53d2b | ||
|
|
6087a80455 | ||
|
|
21f7ccf67d | ||
|
|
b3316e3755 | ||
|
|
d9ef302259 | ||
|
|
7434460b5c | ||
|
|
698a94cc3e | ||
|
|
1f4c541c0d | ||
|
|
faefc3492e | ||
|
|
50bbd5fd65 | ||
|
|
c79977ab57 | ||
|
|
4905c26379 | ||
|
|
28c5d3898f | ||
|
|
40e3d4dc7d | ||
|
|
8cf4a0d3ad | ||
|
|
c22ba3f537 | ||
|
|
c200bbb9fc | ||
|
|
66bca831cc | ||
|
|
a2b1a148d3 | ||
|
|
d7905fb5fe | ||
|
|
1705c07967 | ||
|
|
69eb38060d | ||
|
|
ab6e00a6c1 | ||
|
|
a85b7dac33 | ||
|
|
3d3e09d54d | ||
|
|
af68599ce5 | ||
|
|
977188505e | ||
|
|
8aa4db0c77 | ||
|
|
65a3646ce7 | ||
|
|
c7b82f2236 | ||
|
|
37c28c5edb | ||
|
|
6af9aeb345 | ||
|
|
946f0b00e4 | ||
|
|
b1a8db77db | ||
|
|
7193333b75 | ||
|
|
8a0f14fde4 | ||
|
|
106f0fba0b | ||
|
|
cb73335599 | ||
|
|
f4fa57dac9 | ||
|
|
df1859f6b8 | ||
|
|
b6b1140a21 | ||
|
|
a12350f0a0 | ||
|
|
8555dd8154 | ||
|
|
06f364f2c8 | ||
|
|
979d35d87f | ||
|
|
7ca06931ec | ||
|
|
52b708aaad | ||
|
|
7275485bc1 | ||
|
|
6405228f3f | ||
|
|
fa1b497e0f | ||
|
|
23d073f16c | ||
|
|
40af17fdac | ||
|
|
f4d62baaca | ||
|
|
584921081b | ||
|
|
8a68c2877b | ||
|
|
c12b1f4abd | ||
|
|
ca46718a88 | ||
|
|
d14413f3b0 | ||
|
|
4cb42499ae | ||
|
|
3c6035490d | ||
|
|
4a9fe55976 | ||
|
|
7d91f4783b | ||
|
|
5c6a2af448 | ||
|
|
4fd968270c | ||
|
|
708a7dd362 | ||
|
|
cd85b75312 | ||
|
|
d685da377e | ||
|
|
8583992d23 | ||
|
|
7877706b5f | ||
|
|
7a54607219 | ||
|
|
e3f3cac10f | ||
|
|
8090342f60 | ||
|
|
0e60b30d23 | ||
|
|
651814b78e | ||
|
|
8ee6ae5a25 | ||
|
|
54fc8f4390 | ||
|
|
4e5405dce2 | ||
|
|
e54131e95b | ||
|
|
23fec75c90 | ||
|
|
d98f375926 | ||
|
|
ebe7303894 | ||
|
|
79fb977f10 | ||
|
|
d5a7a537e5 | ||
|
|
c0af3414a3 | ||
|
|
0a6da0bf2f | ||
|
|
341ee17042 | ||
|
|
7ac1629797 | ||
|
|
ed97dce06e | ||
|
|
9c6d059227 | ||
|
|
064075ab5f | ||
|
|
1857d37fae | ||
|
|
60fdbb56a9 | ||
|
|
4c7853164d | ||
|
|
1f0cbfbdc4 | ||
|
|
1a699cb52d | ||
|
|
b211147ff0 | ||
|
|
2096aa61e8 | ||
|
|
7a6bf12453 | ||
|
|
b95f4822c3 | ||
|
|
ae2bd7d116 | ||
|
|
6d2673ca3f | ||
|
|
d1f34cc44c | ||
|
|
4a05b9ab90 | ||
|
|
d5bb050567 | ||
|
|
ce728bcff4 | ||
|
|
572619017a | ||
|
|
0a6cc130b2 | ||
|
|
af0e109f00 | ||
|
|
0d73133982 | ||
|
|
ec064ccdd8 | ||
|
|
cc5291f6af | ||
|
|
7ac496af7b | ||
|
|
04e69be990 | ||
|
|
6c7a3ce4bb | ||
|
|
b4d412fead | ||
|
|
a9e74b21f1 | ||
|
|
150250454e | ||
|
|
a538f80e95 | ||
|
|
e6730f7164 | ||
|
|
3344723393 | ||
|
|
c571185a91 | ||
|
|
325c1cfa41 | ||
|
|
1069421753 | ||
|
|
b33a97ea5b | ||
|
|
d2c1d4c337 | ||
|
|
67762cf1d8 | ||
|
|
eadce0287c | ||
|
|
7fc822379d | ||
|
|
b6d9360f72 | ||
|
|
9fc2a0a3a1 | ||
|
|
ecaff5b63f | ||
|
|
44dcad0d24 | ||
|
|
a300c9ef96 | ||
|
|
44fe71e4db | ||
|
|
0ac32188c5 | ||
|
|
9aaace706b | ||
|
|
b22de5a824 | ||
|
|
97463661c1 | ||
|
|
239a11855a | ||
|
|
0632557d91 | ||
|
|
44be7d4c51 | ||
|
|
efb4a9d327 | ||
|
|
a077a3f609 | ||
|
|
0d8bc70601 | ||
|
|
7eab5b794d | ||
|
|
2bab4e7164 | ||
|
|
3ccec0aab0 | ||
|
|
ca0a149f06 | ||
|
|
db83b54a88 | ||
|
|
f4567fbf9e | ||
|
|
8fd088754a | ||
|
|
05751ac492 | ||
|
|
e7d63a9fa3 | ||
|
|
3006133f0e | ||
|
|
a1e3a72274 | ||
|
|
c5566d707a | ||
|
|
79beb25530 | ||
|
|
b47b228164 | ||
|
|
be91db14d9 | ||
|
|
120893209e | ||
|
|
1a4600ce77 | ||
|
|
27e8a076ff | ||
|
|
f19630bcf5 | ||
|
|
9d93fda471 | ||
|
|
d986659add | ||
|
|
00dab7ca5f | ||
|
|
a4add403fb | ||
|
|
e9cdc96c74 | ||
|
|
6af1fea232 | ||
|
|
45d5d9e44f | ||
|
|
376a084aca | ||
|
|
d1f42d47fe | ||
|
|
64b8fd87ad | ||
|
|
364be48248 | ||
|
|
2bce046278 | ||
|
|
1120d552b6 | ||
|
|
4055fcd891 | ||
|
|
857d82109d | ||
|
|
bf20f3aa8b | ||
|
|
048a2c7979 | ||
|
|
d0b99170f6 | ||
|
|
1d16528dff | ||
|
|
c0ed353c10 | ||
|
|
b786fdf484 | ||
|
|
69cab0817f | ||
|
|
26b63693aa | ||
|
|
5883dc8876 | ||
|
|
3e14077e8d | ||
|
|
3a7f729176 | ||
|
|
fa11f0e488 | ||
|
|
8673234f12 | ||
|
|
ac3bd5161c | ||
|
|
8c55f99740 | ||
|
|
6e0aa7766c | ||
|
|
c4d03bf378 | ||
|
|
61d9428064 | ||
|
|
f6038a4557 | ||
|
|
c90e564d99 | ||
|
|
6c039be2ca | ||
|
|
832dabc8a4 | ||
|
|
c09d205776 | ||
|
|
468ec438ab | ||
|
|
eba703083e | ||
|
|
560b026523 | ||
|
|
b4f9698289 | ||
|
|
c9a2dc0b13 | ||
|
|
b6114266af | ||
|
|
33b576b9d5 | ||
|
|
1da2028d9d | ||
|
|
63de2cc7a0 | ||
|
|
c9c3aa0810 | ||
|
|
841b7fa7ce | ||
|
|
dee4399060 | ||
|
|
7c3f6dcc8d | ||
|
|
1472884eb5 | ||
|
|
d3ea98037e | ||
|
|
2c408445ff | ||
|
|
8d2b5c5464 | ||
|
|
97b5d4bba1 | ||
|
|
68c7e43d8c | ||
|
|
e0af930acd | ||
|
|
dfc8bc4aec | ||
|
|
63b4bca7d8 | ||
|
|
517f8aafdc | ||
|
|
d05ba90779 | ||
|
|
b6620c1f42 | ||
|
|
b60e1f4222 | ||
|
|
4df606f439 | ||
|
|
2a5a497f15 | ||
|
|
3f1da39aee | ||
|
|
740f970041 | ||
|
|
ec22b1c706 | ||
|
|
a1712df7c2 | ||
|
|
a40e11cb3e | ||
|
|
61c46bea40 | ||
|
|
1c5c28a82c | ||
|
|
2310145937 | ||
|
|
6a9c9cadd0 | ||
|
|
a52edc6cc1 | ||
|
|
7774ff9944 | ||
|
|
c367f80ec5 | ||
|
|
09998612e7 | ||
|
|
f71ad55d58 | ||
|
|
5b81397054 | ||
|
|
da353a42da | ||
|
|
e056e0835a | ||
|
|
e1819fb7e5 | ||
|
|
9b0f172f91 | ||
|
|
85dfc013ea | ||
|
|
5a1fae1171 | ||
|
|
33d4c95470 | ||
|
|
659cbc05a9 | ||
|
|
6ce65de2cd | ||
|
|
93b2eb3ff6 | ||
|
|
bf71300635 | ||
|
|
37ecd4a0bc | ||
|
|
827a1b181b | ||
|
|
c4e7cb75cd | ||
|
|
98e4bfcda8 | ||
|
|
ee48ca7671 | ||
|
|
4ba6de1116 | ||
|
|
bfbe636555 | ||
|
|
930fdc8fb4 | ||
|
|
fc90a8fb32 | ||
|
|
791f33fd0b | ||
|
|
1e0a3b163e | ||
|
|
bb1f1a56a5 | ||
|
|
15be85514d | ||
|
|
80cb85b845 | ||
|
|
0b131f1a8c | ||
|
|
0360f0b33b | ||
|
|
e51f2d68cb | ||
|
|
a6d4bf3399 | ||
|
|
113aa4ae08 | ||
|
|
5b40bf6d4e | ||
|
|
06ad8efd89 | ||
|
|
8513afbcf6 | ||
|
|
54ae43ef47 | ||
|
|
7a74b5ee3e | ||
|
|
560fe8a0f6 | ||
|
|
da27d261b0 | ||
|
|
c3e3a18ab4 | ||
|
|
ab34cea714 | ||
|
|
db0780cfa8 | ||
|
|
2b51fc23d9 | ||
|
|
0641773395 | ||
|
|
d745c2e8e3 | ||
|
|
e974c696f7 | ||
|
|
2ff280c4bf | ||
|
|
0e9d43d605 | ||
|
|
cc54363c27 | ||
|
|
c41140d654 | ||
|
|
89affe3139 | ||
|
|
f7fd065bee | ||
|
|
96c7c86e9d | ||
|
|
2c4977dbb1 | ||
|
|
e240175116 | ||
|
|
2398ed6fe8 | ||
|
|
a8420ac33c | ||
|
|
8470be6411 | ||
|
|
3d6295c622 | ||
|
|
ff2f7206f3 | ||
|
|
b937fc8978 | ||
|
|
86a9a51952 | ||
|
|
4188c9a1dd | ||
|
|
8833fee232 | ||
|
|
5bf642c3f9 | ||
|
|
8c00f89e36 | ||
|
|
9e8ac5c96b | ||
|
|
3d7d4182a6 | ||
|
|
75c221038d | ||
|
|
b7b5b0b8d0 | ||
|
|
05a67f4716 | ||
|
|
6eab6a675c | ||
|
|
f49476a206 | ||
|
|
c1e9c56e25 | ||
|
|
d5dd73cacf | ||
|
|
21f7a49b4e | ||
|
|
716ac04e13 | ||
|
|
c28a32fc47 | ||
|
|
31cba28e8a | ||
|
|
48cd7e6481 | ||
|
|
47aba1c9f9 | ||
|
|
d94e598a89 | ||
|
|
0f3f8bc0d9 | ||
|
|
e0df12c212 | ||
|
|
eb448d9bb8 | ||
|
|
0ba77f13db | ||
|
|
f0a2eb843c | ||
|
|
28acb70118 | ||
|
|
7c35aaa99d | ||
|
|
82fcf1da64 | ||
|
|
c662b95b80 | ||
|
|
a8c2a300f6 | ||
|
|
d654d9d8b1 | ||
|
|
eedc0ca6ea | ||
|
|
e4c3213978 | ||
|
|
520bc55da5 | ||
|
|
ebbb82178f | ||
|
|
d60a9ac63a | ||
|
|
f1486224e9 | ||
|
|
88b02d5a07 | ||
|
|
394b7d09b8 | ||
|
|
e9313b9c1b | ||
|
|
5cf3d9e4d9 | ||
|
|
41958f55cd | ||
|
|
600ad232e1 | ||
|
|
7a3825cfce | ||
|
|
9519653422 | ||
|
|
efa2307c73 | ||
|
|
068fa3d0e3 | ||
|
|
13d8dbd542 | ||
|
|
e59cc3311d | ||
|
|
4f4e94f753 | ||
|
|
258970f489 | ||
|
|
781abe5f8f | ||
|
|
b442ba8b2b | ||
|
|
10e36d2355 | ||
|
|
13c53fedad | ||
|
|
4bda1bd884 | ||
|
|
3abe7850d6 | ||
|
|
b50284d864 | ||
|
|
81c6e52401 | ||
|
|
847d257366 | ||
|
|
687662cf1f | ||
|
|
6432d98469 | ||
|
|
088ccf8b8d | ||
|
|
e8683bf957 | ||
|
|
4653981b6b | ||
|
|
e2547413d3 | ||
|
|
ea17f41b5b | ||
|
|
29178d8adf | ||
|
|
7e86ead574 | ||
|
|
72debcb228 | ||
|
|
72737dabc7 | ||
|
|
f6e5cb4381 | ||
|
|
ffad3b5fb1 | ||
|
|
cba9fc3020 | ||
|
|
e776accaf3 | ||
|
|
3592240d14 | ||
|
|
3eac26929a | ||
|
|
4d3adec738 | ||
|
|
ac5dd1f45a | ||
|
|
3005cf3282 | ||
|
|
54b272206e | ||
|
|
685f199f91 | ||
|
|
89bed479e4 | ||
|
|
5547247aa9 | ||
|
|
faa2d00cc6 | ||
|
|
fb8a356616 | ||
|
|
0c3aa8f5ec | ||
|
|
e2fd3f2983 | ||
|
|
f137af4ec5 | ||
|
|
fdd673a3a9 | ||
|
|
aed9955105 | ||
|
|
22f6d285c7 | ||
|
|
10aa16b471 | ||
|
|
3d761a3189 | ||
|
|
e3903f34e4 | ||
|
|
f4f055fb36 | ||
|
|
b3838581fd | ||
|
|
affbe7ccdb | ||
|
|
8563ae5511 | ||
|
|
2c765ccfae | ||
|
|
626e7b2211 | ||
|
|
516b6b0fa8 | ||
|
|
613d086f1e | ||
|
|
9e0630f012 | ||
|
|
d6d9554954 | ||
|
|
dd8577f832 | ||
|
|
2a532ab729 | ||
|
|
03eef65b25 | ||
|
|
ad07d63994 | ||
|
|
8685f055ea | ||
|
|
3b868a1cec | ||
|
|
ab389eaa8e | ||
|
|
008f778e8f | ||
|
|
d7f5da5df4 | ||
|
|
9fda130b3a | ||
|
|
72cdbdba0f | ||
|
|
b92a153902 | ||
|
|
9f2927979b | ||
|
|
75257232c3 | ||
|
|
1721314c62 | ||
|
|
fc230bcc59 | ||
|
|
b4636ddf44 | ||
|
|
f16151ea29 | ||
|
|
b1140301a4 | ||
|
|
58cd785da6 | ||
|
|
2035186cd2 | ||
|
|
53ba6aadff | ||
|
|
aa44c38b58 | ||
|
|
f091868b7c | ||
|
|
89bedae0d3 | ||
|
|
c8acc48976 | ||
|
|
21fee59b22 | ||
|
|
957a8253f8 | ||
|
|
d5fc3e7bed | ||
|
|
ab438b42da | ||
|
|
3867fece4a | ||
|
|
2b908d4fbe | ||
|
|
8ff062ec8b | ||
|
|
294fc41aec | ||
|
|
684f7df158 | ||
|
|
c3287755e3 | ||
|
|
9f97f4d79e | ||
|
|
34eb421649 | ||
|
|
850b05573e | ||
|
|
6ec8bfdfee | ||
|
|
81638c248e | ||
|
|
2e11b1298e | ||
|
|
20320f3a27 | ||
|
|
4019c12d26 | ||
|
|
cf72184ce4 | ||
|
|
ca8d15bc64 | ||
|
|
a91c897fd3 | ||
|
|
816bdf0320 | ||
|
|
d4a6acbd99 | ||
|
|
e421db4005 | ||
|
|
6af168cb31 | ||
|
|
29f56cf0cf | ||
|
|
11b6ea742d | ||
|
|
05d231ad33 | ||
|
|
48f3c69c69 | ||
|
|
9067c2a9c1 | ||
|
|
8b68020453 | ||
|
|
9f7321ca1a | ||
|
|
5fa01132b9 | ||
|
|
4d2fc66a8d | ||
|
|
f72ed4898c | ||
|
|
e082b6d599 | ||
|
|
d44be2d835 | ||
|
|
85a73181cc | ||
|
|
e31e4ab677 | ||
|
|
0d95c2192e | ||
|
|
7dc8557033 | ||
|
|
1fa8b26e55 | ||
|
|
4b085d46f6 | ||
|
|
72037a1865 | ||
|
|
635c4ed4ce | ||
|
|
7ffcf8dd6f | ||
|
|
97cd21d3be | ||
|
|
a13cb7e1c5 | ||
|
|
7b602e9003 | ||
|
|
5a26ebec8f | ||
|
|
8341b8b1c1 | ||
|
|
bbb640c9a2 | ||
|
|
0c97bbf137 | ||
|
|
45fddc70d5 | ||
|
|
f977dc410a | ||
|
|
d535818505 | ||
|
|
fcf4e1f37d | ||
|
|
38130c8502 | ||
|
|
f284c91988 | ||
|
|
584b2cefa3 | ||
|
|
42091b4a79 | ||
|
|
2d1621c43d | ||
|
|
d1a5db3310 | ||
|
|
ad8fd8fecc | ||
|
|
be74b76079 | ||
|
|
dd64af728f | ||
|
|
e43b46786d | ||
|
|
3f3b37b843 | ||
|
|
2ecf9f6ddf | ||
|
|
48c069fe68 | ||
|
|
9c5c597c85 | ||
|
|
c2eec8545d | ||
|
|
2395d4be26 | ||
|
|
9455476705 | ||
|
|
494e223706 | ||
|
|
348fd18230 | ||
|
|
7233b4de55 | ||
|
|
af6df05685 | ||
|
|
965b65db6e | ||
|
|
4cc01c8aa8 | ||
|
|
41372168b6 | ||
|
|
f4438b0a08 | ||
|
|
897c842637 | ||
|
|
ee86ceb906 | ||
|
|
e298732499 | ||
|
|
4081937e22 | ||
|
|
f9aedb2118 | ||
|
|
74b4719af8 | ||
|
|
2f35cc9188 | ||
|
|
2f966d8c38 | ||
|
|
b0868d9136 | ||
|
|
37440e9416 | ||
|
|
0d7d27ec0b |
@@ -11,7 +11,7 @@
|
||||
"nodeGypDependencies": true,
|
||||
"version": "lts"
|
||||
},
|
||||
"ghcr.io/devcontainers-contrib/features/npm-package:1": {
|
||||
"ghcr.io/devcontainers-extra/features/npm-package:1": {
|
||||
"package": "typescript",
|
||||
"version": "latest"
|
||||
},
|
||||
|
||||
@@ -6,7 +6,7 @@ cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
|
||||
5
.github/workflows/style.yml
vendored
5
.github/workflows/style.yml
vendored
@@ -103,6 +103,11 @@ jobs:
|
||||
run: |
|
||||
pnpm run lint
|
||||
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run type-check
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -97,6 +97,7 @@ __pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat-schedule.db
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
|
||||
9
.vscode/launch.json.template
vendored
9
.vscode/launch.json.template
vendored
@@ -8,8 +8,7 @@
|
||||
"module": "flask",
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"FLASK_ENV": "development",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
"FLASK_ENV": "development"
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
@@ -28,9 +27,7 @@
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"env": {
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"app.celery",
|
||||
@@ -40,7 +37,7 @@
|
||||
"-c",
|
||||
"1",
|
||||
"-Q",
|
||||
"dataset,generation,mail,ops_trace",
|
||||
"dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline",
|
||||
"--loglevel",
|
||||
"INFO"
|
||||
],
|
||||
|
||||
10
README.md
10
README.md
@@ -63,7 +63,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i
|
||||
> - CPU >= 2 Core
|
||||
> - RAM >= 4 GiB
|
||||
|
||||
</br>
|
||||
<br/>
|
||||
|
||||
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
|
||||
|
||||
@@ -109,15 +109,15 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
|
||||
|
||||
## Using Dify
|
||||
|
||||
- **Cloud </br>**
|
||||
- **Cloud <br/>**
|
||||
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
||||
|
||||
- **Self-hosting Dify Community Edition</br>**
|
||||
- **Self-hosting Dify Community Edition<br/>**
|
||||
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
||||
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
||||
|
||||
- **Dify for enterprise / organizations</br>**
|
||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. </br>
|
||||
- **Dify for enterprise / organizations<br/>**
|
||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. <br/>
|
||||
|
||||
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
|
||||
|
||||
|
||||
@@ -30,6 +30,9 @@ INTERNAL_FILES_URL=http://127.0.0.1:5001
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
# Collaboration mode toggle
|
||||
ENABLE_COLLABORATION_MODE=false
|
||||
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
@@ -156,6 +159,9 @@ SUPABASE_URL=your-server-url
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
# Set COOKIE_DOMAIN when the console frontend and API are on different subdomains.
|
||||
# Provide the registrable domain (e.g. example.com); leading dots are optional.
|
||||
COOKIE_DOMAIN=
|
||||
|
||||
# Vector database configuration
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
@@ -368,6 +374,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
||||
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||
|
||||
# Comma-separated list of file extensions blocked from upload for security reasons.
|
||||
# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
|
||||
# Empty by default to allow all file types.
|
||||
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
|
||||
UPLOAD_FILE_EXTENSION_BLACKLIST=
|
||||
|
||||
# Model configuration
|
||||
MULTIMODAL_SEND_FORMAT=base64
|
||||
PROMPT_GENERATION_MAX_TOKENS=512
|
||||
@@ -605,3 +617,6 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
||||
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
|
||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||
|
||||
# Maximum number of segments for dataset segments API (0 for unlimited)
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
||||
|
||||
2
api/.vscode/launch.json.example
vendored
2
api/.vscode/launch.json.example
vendored
@@ -54,7 +54,7 @@
|
||||
"--loglevel",
|
||||
"DEBUG",
|
||||
"-Q",
|
||||
"dataset,generation,mail,ops_trace,app_deletion"
|
||||
"dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
@@ -15,7 +15,11 @@ FROM base AS packages
|
||||
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
g++ \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
# Install Python dependencies
|
||||
COPY pyproject.toml uv.lock ./
|
||||
@@ -49,7 +53,9 @@ RUN \
|
||||
# Install dependencies
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
|
||||
curl nodejs \
|
||||
# for gmpy2 \
|
||||
libgmp-dev libmpfr-dev libmpc-dev \
|
||||
# For Security
|
||||
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||
# install fonts to support the use of tools like pypdfium2
|
||||
|
||||
@@ -80,7 +80,7 @@
|
||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
|
||||
```
|
||||
|
||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||
|
||||
39
api/app.py
39
api/app.py
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
@@ -8,33 +9,35 @@ def is_db_command():
|
||||
|
||||
|
||||
# create app
|
||||
celery = None
|
||||
flask_app = None
|
||||
socketio_app = None
|
||||
|
||||
if is_db_command():
|
||||
from app_factory import create_migrations_app
|
||||
|
||||
app = create_migrations_app()
|
||||
socketio_app = app
|
||||
flask_app = app
|
||||
else:
|
||||
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||
# so we need to disable gevent in debug mode.
|
||||
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
||||
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
# from gevent import monkey
|
||||
# Gunicorn and Celery handle monkey patching automatically in production by
|
||||
# specifying the `gevent` worker class. Manual monkey patching is not required here.
|
||||
#
|
||||
# # gevent
|
||||
# monkey.patch_all()
|
||||
# See `api/docker/entrypoint.sh` (lines 33 and 47) for details.
|
||||
#
|
||||
# from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
#
|
||||
# # grpc gevent
|
||||
# grpc_gevent.init_gevent()
|
||||
|
||||
# import psycogreen.gevent # type: ignore
|
||||
#
|
||||
# psycogreen.gevent.patch_psycopg()
|
||||
# For third-party library patching, refer to `gunicorn.conf.py` and `celery_entrypoint.py`.
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
app = create_app()
|
||||
celery = app.extensions["celery"]
|
||||
socketio_app, flask_app = create_app()
|
||||
app = flask_app
|
||||
celery = flask_app.extensions["celery"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=5001)
|
||||
from gevent import pywsgi
|
||||
from geventwebsocket.handler import WebSocketHandler
|
||||
|
||||
host = os.environ.get("HOST", "0.0.0.0")
|
||||
port = int(os.environ.get("PORT", 5001))
|
||||
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
|
||||
server.serve_forever()
|
||||
|
||||
@@ -31,14 +31,22 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
return dify_app
|
||||
|
||||
|
||||
def create_app() -> DifyApp:
|
||||
def create_app() -> tuple[any, DifyApp]:
|
||||
start_time = time.perf_counter()
|
||||
app = create_flask_app_with_configs()
|
||||
initialize_extensions(app)
|
||||
|
||||
import socketio
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
|
||||
sio.app = app
|
||||
socketio_app = socketio.WSGIApp(sio, app)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
|
||||
return app
|
||||
return socketio_app, app
|
||||
|
||||
|
||||
def initialize_extensions(app: DifyApp):
|
||||
|
||||
@@ -321,6 +321,8 @@ def migrate_knowledge_vector_database():
|
||||
)
|
||||
|
||||
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
|
||||
if not datasets.items:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
|
||||
@@ -1599,7 +1601,7 @@ def transform_datasource_credentials():
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
provider="jinareader",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
|
||||
@@ -331,12 +331,42 @@ class FileUploadConfig(BaseSettings):
|
||||
default=10,
|
||||
)
|
||||
|
||||
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
|
||||
description=(
|
||||
"Comma-separated list of file extensions that are blocked from upload. "
|
||||
"Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
|
||||
"Empty by default to allow all file types."
|
||||
),
|
||||
validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
|
||||
"""
|
||||
Parse and return the blacklist as a set of lowercase extensions.
|
||||
Returns an empty set if no blacklist is configured.
|
||||
"""
|
||||
if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
|
||||
return set()
|
||||
return {
|
||||
ext.strip().lower().strip(".")
|
||||
for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
|
||||
if ext.strip()
|
||||
}
|
||||
|
||||
|
||||
class HttpConfig(BaseSettings):
|
||||
"""
|
||||
HTTP-related configurations for the application
|
||||
"""
|
||||
|
||||
COOKIE_DOMAIN: str = Field(
|
||||
description="Explicit cookie domain for console/service cookies when sharing across subdomains",
|
||||
default="",
|
||||
)
|
||||
|
||||
API_COMPRESSION_ENABLED: bool = Field(
|
||||
description="Enable or disable gzip compression for HTTP responses",
|
||||
default=False,
|
||||
@@ -841,6 +871,16 @@ class MailConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
ENABLE_TRIAL_APP: bool = Field(
|
||||
description="Enable trial app",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENABLE_EXPLORE_BANNER: bool = Field(
|
||||
description="Enable explore banner",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class RagEtlConfig(BaseSettings):
|
||||
"""
|
||||
@@ -915,6 +955,11 @@ class DataSetConfig(BaseSettings):
|
||||
default=True,
|
||||
)
|
||||
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
|
||||
description="Maximum number of segments for dataset segments API (0 for unlimited)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceConfig(BaseSettings):
|
||||
"""
|
||||
@@ -1048,6 +1093,13 @@ class PositionConfig(BaseSettings):
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
|
||||
|
||||
class CollaborationConfig(BaseSettings):
|
||||
ENABLE_COLLABORATION_MODE: bool = Field(
|
||||
description="Whether to enable collaboration mode features across the workspace",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class LoginConfig(BaseSettings):
|
||||
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
|
||||
description="whether to enable email code login",
|
||||
@@ -1107,6 +1159,13 @@ class SwaggerUIConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TenantSelfTaskQueueConfig(BaseSettings):
|
||||
TENANT_SELF_TASK_QUEUE_PULL_SIZE: int = Field(
|
||||
description="Default batch size for tenant self task queue pull operations",
|
||||
default=1,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@@ -1131,11 +1190,13 @@ class FeatureConfig(
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SecurityConfig,
|
||||
TenantSelfTaskQueueConfig,
|
||||
ToolConfig,
|
||||
UpdateConfig,
|
||||
WorkflowConfig,
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
CollaborationConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
|
||||
@@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
HOSTED_POOL_CREDITS: int = Field(
|
||||
description="Pool credits for hosted service",
|
||||
default=200,
|
||||
)
|
||||
|
||||
def get_model_credits(self, model_name: str) -> int:
|
||||
"""
|
||||
Get credit value for a specific model name.
|
||||
@@ -60,19 +65,46 @@ class HostedOpenAiConfig(BaseSettings):
|
||||
|
||||
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="gpt-3.5-turbo,"
|
||||
"gpt-3.5-turbo-1106,"
|
||||
"gpt-3.5-turbo-instruct,"
|
||||
default="gpt-4,"
|
||||
"gpt-4-turbo-preview,"
|
||||
"gpt-4-turbo-2024-04-09,"
|
||||
"gpt-4-1106-preview,"
|
||||
"gpt-4-0125-preview,"
|
||||
"gpt-4-turbo,"
|
||||
"gpt-4.1,"
|
||||
"gpt-4.1-2025-04-14,"
|
||||
"gpt-4.1-mini,"
|
||||
"gpt-4.1-mini-2025-04-14,"
|
||||
"gpt-4.1-nano,"
|
||||
"gpt-4.1-nano-2025-04-14,"
|
||||
"gpt-3.5-turbo,"
|
||||
"gpt-3.5-turbo-16k,"
|
||||
"gpt-3.5-turbo-16k-0613,"
|
||||
"gpt-3.5-turbo-1106,"
|
||||
"gpt-3.5-turbo-0613,"
|
||||
"gpt-3.5-turbo-0125,"
|
||||
"text-davinci-003",
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="Quota limit for hosted OpenAI service usage",
|
||||
default=200,
|
||||
"gpt-3.5-turbo-instruct,"
|
||||
"text-davinci-003,"
|
||||
"chatgpt-4o-latest,"
|
||||
"gpt-4o,"
|
||||
"gpt-4o-2024-05-13,"
|
||||
"gpt-4o-2024-08-06,"
|
||||
"gpt-4o-2024-11-20,"
|
||||
"gpt-4o-audio-preview,"
|
||||
"gpt-4o-audio-preview-2025-06-03,"
|
||||
"gpt-4o-mini,"
|
||||
"gpt-4o-mini-2024-07-18,"
|
||||
"o3-mini,"
|
||||
"o3-mini-2025-01-31,"
|
||||
"gpt-5-mini-2025-08-07,"
|
||||
"gpt-5-mini,"
|
||||
"o4-mini,"
|
||||
"o4-mini-2025-04-16,"
|
||||
"gpt-5-chat-latest,"
|
||||
"gpt-5,"
|
||||
"gpt-5-2025-08-07,"
|
||||
"gpt-5-nano,"
|
||||
"gpt-5-nano-2025-08-07",
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
|
||||
@@ -87,6 +119,13 @@ class HostedOpenAiConfig(BaseSettings):
|
||||
"gpt-4-turbo-2024-04-09,"
|
||||
"gpt-4-1106-preview,"
|
||||
"gpt-4-0125-preview,"
|
||||
"gpt-4-turbo,"
|
||||
"gpt-4.1,"
|
||||
"gpt-4.1-2025-04-14,"
|
||||
"gpt-4.1-mini,"
|
||||
"gpt-4.1-mini-2025-04-14,"
|
||||
"gpt-4.1-nano,"
|
||||
"gpt-4.1-nano-2025-04-14,"
|
||||
"gpt-3.5-turbo,"
|
||||
"gpt-3.5-turbo-16k,"
|
||||
"gpt-3.5-turbo-16k-0613,"
|
||||
@@ -94,7 +133,150 @@ class HostedOpenAiConfig(BaseSettings):
|
||||
"gpt-3.5-turbo-0613,"
|
||||
"gpt-3.5-turbo-0125,"
|
||||
"gpt-3.5-turbo-instruct,"
|
||||
"text-davinci-003",
|
||||
"text-davinci-003,"
|
||||
"chatgpt-4o-latest,"
|
||||
"gpt-4o,"
|
||||
"gpt-4o-2024-05-13,"
|
||||
"gpt-4o-2024-08-06,"
|
||||
"gpt-4o-2024-11-20,"
|
||||
"gpt-4o-audio-preview,"
|
||||
"gpt-4o-audio-preview-2025-06-03,"
|
||||
"gpt-4o-mini,"
|
||||
"gpt-4o-mini-2024-07-18,"
|
||||
"o3-mini,"
|
||||
"o3-mini-2025-01-31,"
|
||||
"gpt-5-mini-2025-08-07,"
|
||||
"gpt-5-mini,"
|
||||
"o4-mini,"
|
||||
"o4-mini-2025-04-16,"
|
||||
"gpt-5-chat-latest,"
|
||||
"gpt-5,"
|
||||
"gpt-5-2025-08-07,"
|
||||
"gpt-5-nano,"
|
||||
"gpt-5-nano-2025-08-07",
|
||||
)
|
||||
|
||||
|
||||
class HostedGeminiConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching Gemini service
|
||||
"""
|
||||
|
||||
HOSTED_GEMINI_API_KEY: str | None = Field(
|
||||
description="API key for hosted Gemini service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_API_BASE: str | None = Field(
|
||||
description="Base URL for hosted Gemini API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
|
||||
description="Organization ID for hosted Gemini service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted Gemini service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted gemini service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
|
||||
)
|
||||
|
||||
|
||||
class HostedXAIConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching XAI service
|
||||
"""
|
||||
|
||||
HOSTED_XAI_API_KEY: str | None = Field(
|
||||
description="API key for hosted XAI service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_XAI_API_BASE: str | None = Field(
|
||||
description="Base URL for hosted XAI API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_XAI_API_ORGANIZATION: str | None = Field(
|
||||
description="Organization ID for hosted XAI service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_XAI_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_XAI_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="grok-3,grok-3-mini,grok-3-mini-fast",
|
||||
)
|
||||
|
||||
HOSTED_XAI_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_XAI_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="grok-3,grok-3-mini,grok-3-mini-fast",
|
||||
)
|
||||
|
||||
|
||||
class HostedDeepseekConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching Deepseek service
|
||||
"""
|
||||
|
||||
HOSTED_DEEPSEEK_API_KEY: str | None = Field(
|
||||
description="API key for hosted Deepseek service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_API_BASE: str | None = Field(
|
||||
description="Base URL for hosted Deepseek API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
|
||||
description="Organization ID for hosted Deepseek service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted Deepseek service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="deepseek-chat,deepseek-reasoner",
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="deepseek-chat,deepseek-reasoner",
|
||||
)
|
||||
|
||||
|
||||
@@ -144,16 +326,30 @@ class HostedAnthropicConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="Quota limit for hosted Anthropic service usage",
|
||||
default=600000,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted Anthropic service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="claude-opus-4-20250514,"
|
||||
"claude-sonnet-4-20250514,"
|
||||
"claude-3-5-haiku-20241022,"
|
||||
"claude-3-opus-20240229,"
|
||||
"claude-3-7-sonnet-20250219,"
|
||||
"claude-3-haiku-20240307",
|
||||
)
|
||||
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="claude-opus-4-20250514,"
|
||||
"claude-sonnet-4-20250514,"
|
||||
"claude-3-5-haiku-20241022,"
|
||||
"claude-3-opus-20240229,"
|
||||
"claude-3-7-sonnet-20250219,"
|
||||
"claude-3-haiku-20240307",
|
||||
)
|
||||
|
||||
|
||||
class HostedMinmaxConfig(BaseSettings):
|
||||
"""
|
||||
@@ -250,5 +446,8 @@ class HostedServiceConfig(
|
||||
HostedModerationConfig,
|
||||
# credit config
|
||||
HostedCreditConfig,
|
||||
HostedGeminiConfig,
|
||||
HostedXAIConfig,
|
||||
HostedDeepseekConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -22,6 +22,11 @@ class WeaviateConfig(BaseSettings):
|
||||
default=True,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
|
||||
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Number of objects to be processed in a single batch operation (default is 100)",
|
||||
default=100,
|
||||
|
||||
@@ -56,11 +56,15 @@ else:
|
||||
}
|
||||
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
||||
|
||||
# console
|
||||
COOKIE_NAME_ACCESS_TOKEN = "access_token"
|
||||
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
|
||||
COOKIE_NAME_PASSPORT = "passport"
|
||||
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
|
||||
|
||||
# webapp
|
||||
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
|
||||
COOKIE_NAME_PASSPORT = "passport"
|
||||
|
||||
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
|
||||
HEADER_NAME_APP_CODE = "X-App-Code"
|
||||
HEADER_NAME_PASSPORT = "X-App-Passport"
|
||||
|
||||
@@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
|
||||
code = 415
|
||||
|
||||
|
||||
class BlockedFileExtensionError(BaseHTTPException):
|
||||
error_code = "file_extension_blocked"
|
||||
description = "The file extension is blocked for security reasons."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = "too_many_files"
|
||||
description = "Only one file is allowed."
|
||||
|
||||
@@ -58,11 +58,13 @@ from .app import (
|
||||
mcp_server,
|
||||
message,
|
||||
model_config,
|
||||
online_user,
|
||||
ops_trace,
|
||||
site,
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
workflow_comment,
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
@@ -106,10 +108,12 @@ from .datasets.rag_pipeline import (
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
banner,
|
||||
installed_app,
|
||||
parameter,
|
||||
recommended_app,
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
|
||||
# Import tag controllers
|
||||
@@ -143,6 +147,7 @@ __all__ = [
|
||||
"apikey",
|
||||
"app",
|
||||
"audio",
|
||||
"banner",
|
||||
"billing",
|
||||
"bp",
|
||||
"completion",
|
||||
@@ -196,6 +201,7 @@ __all__ = [
|
||||
"statistic",
|
||||
"tags",
|
||||
"tool_providers",
|
||||
"trial",
|
||||
"version",
|
||||
"website",
|
||||
"workflow",
|
||||
|
||||
@@ -16,7 +16,7 @@ from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
@@ -52,6 +52,8 @@ class InsertExploreAppListApi(Resource):
|
||||
"language": fields.String(required=True, description="Language code"),
|
||||
"category": fields.String(required=True, description="App category"),
|
||||
"position": fields.Integer(required=True, description="Display position"),
|
||||
"can_trial": fields.Boolean(required=True, description="Can trial"),
|
||||
"trial_limit": fields.Integer(required=True, description="Trial limit"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -71,6 +73,8 @@ class InsertExploreAppListApi(Resource):
|
||||
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
.add_argument("can_trial", type=bool, required=True, nullable=False, location="json")
|
||||
.add_argument("trial_limit", type=int, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -108,6 +112,20 @@ class InsertExploreAppListApi(Resource):
|
||||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
if args["can_trial"]:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == args["app_id"])
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=args["app_id"],
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=args["trial_limit"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = args["trial_limit"]
|
||||
|
||||
app.is_public = True
|
||||
db.session.commit()
|
||||
@@ -122,6 +140,20 @@ class InsertExploreAppListApi(Resource):
|
||||
recommended_app.category = args["category"]
|
||||
recommended_app.position = args["position"]
|
||||
|
||||
if args["can_trial"]:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == args["app_id"])
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=args["app_id"],
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=args["trial_limit"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = args["trial_limit"]
|
||||
app.is_public = True
|
||||
|
||||
db.session.commit()
|
||||
@@ -167,7 +199,83 @@ class InsertExploreAppApi(Resource):
|
||||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
|
||||
trial_app = session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
|
||||
).scalar_one_or_none()
|
||||
if trial_app:
|
||||
session.delete(trial_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-banner")
|
||||
class InsertExploreBanner(Resource):
|
||||
@api.doc("insert_explore_banner")
|
||||
@api.doc(description="Insert an explore banner")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InsertExploreBannerRequest",
|
||||
{
|
||||
"content": fields.String(required=True, description="Banner content"),
|
||||
"link": fields.String(required=True, description="Banner link"),
|
||||
"sort": fields.Integer(required=True, description="Banner sort"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Banner inserted successfully")
|
||||
@admin_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("title", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("img-src", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
parser.add_argument("language", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("link", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("sort", type=int, required=True, nullable=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
content = {
|
||||
"category": args["category"],
|
||||
"title": args["title"],
|
||||
"description": args["description"],
|
||||
"img-src": args["img-src"],
|
||||
}
|
||||
|
||||
if not args["language"]:
|
||||
args["language"] = "en-US"
|
||||
|
||||
banner = ExporleBanner(
|
||||
content=content,
|
||||
link=args["link"],
|
||||
sort=args["sort"],
|
||||
language=args["language"],
|
||||
)
|
||||
db.session.add(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
|
||||
class DeleteExploreBanner(Resource):
|
||||
@api.doc("delete_explore_banner")
|
||||
@api.doc(description="Delete an explore banner")
|
||||
@api.response(204, "Banner deleted successfully")
|
||||
@admin_required
|
||||
@only_edition_cloud
|
||||
def delete(self, banner_id):
|
||||
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
|
||||
if not banner:
|
||||
raise NotFound(f"Banner '{banner_id}' is not found")
|
||||
|
||||
db.session.delete(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@@ -16,6 +16,7 @@ from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
annotation_hit_history_fields,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
@@ -175,8 +176,10 @@ class AnnotationApi(Resource):
|
||||
api.model(
|
||||
"CreateAnnotationRequest",
|
||||
{
|
||||
"question": fields.String(required=True, description="Question text"),
|
||||
"answer": fields.String(required=True, description="Answer text"),
|
||||
"message_id": fields.String(description="Message ID (optional)"),
|
||||
"question": fields.String(description="Question text (required when message_id not provided)"),
|
||||
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
|
||||
"content": fields.String(description="Content text (use 'answer' or 'content')"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
||||
},
|
||||
)
|
||||
@@ -193,11 +196,14 @@ class AnnotationApi(Resource):
|
||||
app_id = str(app_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
.add_argument("question", required=False, type=str, location="json")
|
||||
.add_argument("answer", required=False, type=str, location="json")
|
||||
.add_argument("content", required=False, type=str, location="json")
|
||||
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import abort
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import func, or_
|
||||
@@ -19,7 +17,7 @@ from fields.conversation_fields import (
|
||||
conversation_pagination_fields,
|
||||
conversation_with_summary_pagination_fields,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
@@ -90,25 +88,17 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
account = current_user
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||
|
||||
# FIXME, the type ignore in this file
|
||||
@@ -270,29 +260,21 @@ class ChatConversationApi(Resource):
|
||||
|
||||
account = current_user
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
||||
|
||||
@@ -115,3 +115,9 @@ class InvokeRateLimitError(BaseHTTPException):
|
||||
error_code = "rate_limit_error"
|
||||
description = "Rate Limit Error"
|
||||
code = 429
|
||||
|
||||
|
||||
class NeedAddIdsError(BaseHTTPException):
|
||||
error_code = "need_add_ids"
|
||||
description = "Need to add ids."
|
||||
code = 400
|
||||
@@ -16,7 +16,6 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
@@ -24,12 +23,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import annotation_fields, message_detail_fields
|
||||
from fields.conversation_fields import message_detail_fields
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
@@ -194,45 +192,6 @@ class MessageFeedbackApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
||||
class MessageAnnotationApi(Resource):
|
||||
@api.doc("create_message_annotation")
|
||||
@api.doc(description="Create message annotation")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MessageAnnotationRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
"question": fields.String(required=True, description="Question text"),
|
||||
"answer": fields.String(required=True, description="Answer text"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Annotation created successfully", annotation_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@marshal_with(annotation_fields)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
|
||||
|
||||
return annotation
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
|
||||
class MessageAnnotationCountApi(Resource):
|
||||
@api.doc("get_annotation_count")
|
||||
|
||||
339
api/controllers/console/app/online_user.py
Normal file
339
api/controllers/console/app/online_user.py
Normal file
@@ -0,0 +1,339 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
from werkzeug.wrappers import Request as WerkzeugRequest
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from services.account_service import AccountService
|
||||
|
||||
SESSION_STATE_TTL_SECONDS = 3600
|
||||
WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:"
|
||||
WORKFLOW_LEADER_PREFIX = "workflow_leader:"
|
||||
WS_SID_MAP_PREFIX = "ws_sid_map:"
|
||||
|
||||
|
||||
def _workflow_key(workflow_id: str) -> str:
|
||||
return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}"
|
||||
|
||||
|
||||
def _leader_key(workflow_id: str) -> str:
|
||||
return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}"
|
||||
|
||||
|
||||
def _sid_key(sid: str) -> str:
|
||||
return f"{WS_SID_MAP_PREFIX}{sid}"
|
||||
|
||||
|
||||
def _refresh_session_state(workflow_id: str, sid: str) -> None:
|
||||
"""
|
||||
Refresh TTLs for workflow + session keys so healthy sessions do not linger forever after crashes.
|
||||
"""
|
||||
workflow_key = _workflow_key(workflow_id)
|
||||
sid_key = _sid_key(sid)
|
||||
if redis_client.exists(workflow_key):
|
||||
redis_client.expire(workflow_key, SESSION_STATE_TTL_SECONDS)
|
||||
if redis_client.exists(sid_key):
|
||||
redis_client.expire(sid_key, SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
|
||||
@sio.on("connect")
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
token = None
|
||||
if auth and isinstance(auth, dict):
|
||||
token = auth.get("token")
|
||||
|
||||
if not token:
|
||||
try:
|
||||
request_environ = WerkzeugRequest(environ)
|
||||
token = extract_access_token(request_environ)
|
||||
except Exception:
|
||||
token = None
|
||||
|
||||
if not token:
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
sio.save_session(sid, {"user_id": user.id, "username": user.name, "avatar": user.avatar})
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@sio.on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event. Each session (tab) is treated as an independent collaborator.
|
||||
"""
|
||||
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
session = sio.get_session(sid)
|
||||
user_id = session.get("user_id")
|
||||
|
||||
if not user_id:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
# Each session is stored independently with sid as key
|
||||
session_info = {
|
||||
"user_id": user_id,
|
||||
"username": session.get("username", "Unknown"),
|
||||
"avatar": session.get("avatar", None),
|
||||
"sid": sid,
|
||||
"connected_at": int(time.time()), # Add timestamp to differentiate tabs
|
||||
}
|
||||
|
||||
workflow_key = _workflow_key(workflow_id)
|
||||
# Store session info with sid as key
|
||||
redis_client.hset(workflow_key, sid, json.dumps(session_info))
|
||||
redis_client.set(
|
||||
_sid_key(sid),
|
||||
json.dumps({"workflow_id": workflow_id, "user_id": user_id}),
|
||||
ex=SESSION_STATE_TTL_SECONDS,
|
||||
)
|
||||
_refresh_session_state(workflow_id, sid)
|
||||
|
||||
# Leader election: first session becomes the leader
|
||||
leader_sid = get_or_set_leader(workflow_id, sid)
|
||||
is_leader = leader_sid == sid
|
||||
|
||||
sio.enter_room(sid, workflow_id)
|
||||
broadcast_online_users(workflow_id)
|
||||
|
||||
# Notify this session of their leader status
|
||||
sio.emit("status", {"isLeader": is_leader}, room=sid)
|
||||
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
|
||||
|
||||
|
||||
@sio.on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle session disconnect event. Remove the specific session from online users.
|
||||
"""
|
||||
mapping = redis_client.get(_sid_key(sid))
|
||||
if mapping:
|
||||
data = json.loads(mapping)
|
||||
workflow_id = data["workflow_id"]
|
||||
|
||||
# Remove this specific session
|
||||
redis_client.hdel(_workflow_key(workflow_id), sid)
|
||||
redis_client.delete(_sid_key(sid))
|
||||
|
||||
# Handle leader re-election if the leader session disconnected
|
||||
handle_leader_disconnect(workflow_id, sid)
|
||||
|
||||
broadcast_online_users(workflow_id)
|
||||
|
||||
|
||||
def _clear_session_state(workflow_id: str, sid: str) -> None:
|
||||
redis_client.hdel(_workflow_key(workflow_id), sid)
|
||||
redis_client.delete(_sid_key(sid))
|
||||
|
||||
|
||||
def _is_session_active(workflow_id: str, sid: str) -> bool:
|
||||
if not sid:
|
||||
return False
|
||||
|
||||
try:
|
||||
if not sio.manager.is_connected(sid, "/"):
|
||||
return False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
if not redis_client.hexists(_workflow_key(workflow_id), sid):
|
||||
return False
|
||||
|
||||
if not redis_client.exists(_sid_key(sid)):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_or_set_leader(workflow_id: str, sid: str) -> str:
|
||||
"""
|
||||
Get current leader session or set this session as leader if no valid leader exists.
|
||||
Returns the leader session id (sid).
|
||||
"""
|
||||
raw_leader = redis_client.get(_leader_key(workflow_id))
|
||||
current_leader = raw_leader.decode("utf-8") if isinstance(raw_leader, bytes) else raw_leader
|
||||
leader_replaced = False
|
||||
|
||||
if current_leader and not _is_session_active(workflow_id, current_leader):
|
||||
_clear_session_state(workflow_id, current_leader)
|
||||
redis_client.delete(_leader_key(workflow_id))
|
||||
current_leader = None
|
||||
leader_replaced = True
|
||||
|
||||
if not current_leader:
|
||||
redis_client.set(_leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS) # Expire in 1 hour
|
||||
if leader_replaced:
|
||||
broadcast_leader_change(workflow_id, sid)
|
||||
return sid
|
||||
|
||||
return current_leader
|
||||
|
||||
|
||||
def handle_leader_disconnect(workflow_id, disconnected_sid):
|
||||
"""
|
||||
Handle leader re-election when a session disconnects.
|
||||
If the disconnected session was the leader, elect a new leader from remaining sessions.
|
||||
"""
|
||||
current_leader = redis_client.get(_leader_key(workflow_id))
|
||||
|
||||
if current_leader:
|
||||
current_leader = current_leader.decode("utf-8") if isinstance(current_leader, bytes) else current_leader
|
||||
|
||||
if current_leader == disconnected_sid:
|
||||
# Leader session disconnected, elect a new leader
|
||||
sessions_json = redis_client.hgetall(_workflow_key(workflow_id))
|
||||
|
||||
if sessions_json:
|
||||
# Get the first remaining session as new leader
|
||||
new_leader_sid = list(sessions_json.keys())[0]
|
||||
if isinstance(new_leader_sid, bytes):
|
||||
new_leader_sid = new_leader_sid.decode("utf-8")
|
||||
|
||||
redis_client.set(_leader_key(workflow_id), new_leader_sid, ex=SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
# Notify all sessions about the new leader
|
||||
broadcast_leader_change(workflow_id, new_leader_sid)
|
||||
else:
|
||||
# No sessions left, remove leader
|
||||
redis_client.delete(_leader_key(workflow_id))
|
||||
|
||||
|
||||
def broadcast_leader_change(workflow_id, new_leader_sid):
|
||||
"""
|
||||
Broadcast leader change to all sessions in the workflow.
|
||||
"""
|
||||
sessions_json = redis_client.hgetall(_workflow_key(workflow_id))
|
||||
|
||||
for sid, session_info_json in sessions_json.items():
|
||||
try:
|
||||
sid_str = sid.decode("utf-8") if isinstance(sid, bytes) else sid
|
||||
is_leader = sid_str == new_leader_sid
|
||||
# Emit to each session whether they are the new leader
|
||||
sio.emit("status", {"isLeader": is_leader}, room=sid_str)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
def get_current_leader(workflow_id):
|
||||
"""
|
||||
Get the current leader for a workflow.
|
||||
"""
|
||||
leader = redis_client.get(_leader_key(workflow_id))
|
||||
return leader.decode("utf-8") if leader and isinstance(leader, bytes) else leader
|
||||
|
||||
|
||||
def broadcast_online_users(workflow_id):
|
||||
"""
|
||||
Broadcast online users to the workflow room.
|
||||
Each session is shown as a separate user (even if same person has multiple tabs).
|
||||
"""
|
||||
sessions_json = redis_client.hgetall(_workflow_key(workflow_id))
|
||||
users = []
|
||||
|
||||
for sid, session_info_json in sessions_json.items():
|
||||
try:
|
||||
session_info = json.loads(session_info_json)
|
||||
# Each session appears as a separate "user" in the UI
|
||||
users.append(
|
||||
{
|
||||
"user_id": session_info["user_id"],
|
||||
"username": session_info["username"],
|
||||
"avatar": session_info.get("avatar"),
|
||||
"sid": session_info["sid"],
|
||||
"connected_at": session_info.get("connected_at"),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Sort by connection time to maintain consistent order
|
||||
users.sort(key=lambda x: x.get("connected_at") or 0)
|
||||
|
||||
# Get current leader session
|
||||
leader_sid = get_current_leader(workflow_id)
|
||||
|
||||
sio.emit("online_users", {"workflow_id": workflow_id, "users": users, "leader": leader_sid}, room=workflow_id)
|
||||
|
||||
|
||||
@sio.on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, include:
|
||||
1. mouse_move
|
||||
2. vars_and_features_update
|
||||
3. sync_request (ask leader to update graph)
|
||||
4. app_state_update
|
||||
5. mcp_server_update
|
||||
6. workflow_update
|
||||
7. comments_update
|
||||
8. node_panel_presence
|
||||
|
||||
"""
|
||||
mapping = redis_client.get(_sid_key(sid))
|
||||
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
mapping_data = json.loads(mapping)
|
||||
workflow_id = mapping_data["workflow_id"]
|
||||
user_id = mapping_data["user_id"]
|
||||
_refresh_session_state(workflow_id, sid)
|
||||
|
||||
event_type = data.get("type")
|
||||
event_data = data.get("data")
|
||||
timestamp = data.get("timestamp", int(time.time()))
|
||||
|
||||
if not event_type:
|
||||
return {"msg": "invalid event type"}, 400
|
||||
|
||||
sio.emit(
|
||||
"collaboration_update",
|
||||
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
|
||||
room=workflow_id,
|
||||
skip_sid=sid,
|
||||
)
|
||||
|
||||
return {"msg": "event_broadcasted"}
|
||||
|
||||
|
||||
@sio.on("graph_event")
|
||||
def handle_graph_event(sid, data):
|
||||
"""
|
||||
Handle graph events - simple broadcast relay.
|
||||
"""
|
||||
mapping = redis_client.get(_sid_key(sid))
|
||||
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
mapping_data = json.loads(mapping)
|
||||
workflow_id = mapping_data["workflow_id"]
|
||||
_refresh_session_state(workflow_id, sid)
|
||||
|
||||
sio.emit("graph_update", data, room=workflow_id, skip_sid=sid)
|
||||
|
||||
return {"msg": "graph_update_broadcasted"}
|
||||
@@ -1,9 +1,7 @@
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import jsonify
|
||||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
@@ -11,6 +9,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode, Message
|
||||
@@ -56,26 +55,16 @@ WHERE
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@@ -120,8 +109,11 @@ class DailyConversationStatistic(Resource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
stmt = (
|
||||
sa.select(
|
||||
@@ -134,18 +126,10 @@ class DailyConversationStatistic(Resource):
|
||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
|
||||
)
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
if start_datetime_utc:
|
||||
stmt = stmt.where(Message.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
if end_datetime_utc:
|
||||
stmt = stmt.where(Message.created_at < end_datetime_utc)
|
||||
|
||||
stmt = stmt.group_by("date").order_by("date")
|
||||
@@ -198,26 +182,17 @@ WHERE
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@@ -273,26 +248,17 @@ WHERE
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@@ -357,26 +323,17 @@ FROM
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND c.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND c.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@@ -446,26 +403,17 @@ WHERE
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND m.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND m.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@@ -525,26 +473,17 @@ WHERE
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@@ -602,26 +541,17 @@ WHERE
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
|
||||
@@ -5,10 +5,12 @@ from typing import cast
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
|
||||
from pydantic_core import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@@ -21,7 +23,9 @@ from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.online_user_fields import online_user_list_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
@@ -99,10 +103,22 @@ class DraftWorkflowApi(Resource):
|
||||
"hash": fields.String(description="Workflow hash for validation"),
|
||||
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
|
||||
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
|
||||
"memory_blocks": fields.List(fields.Raw, description="Memory blocks"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Draft workflow synced successfully", workflow_fields)
|
||||
@api.response(
|
||||
200,
|
||||
"Draft workflow synced successfully",
|
||||
api.model(
|
||||
"SyncDraftWorkflowResponse",
|
||||
{
|
||||
"result": fields.String,
|
||||
"hash": fields.String,
|
||||
"updated_at": fields.String,
|
||||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid workflow configuration")
|
||||
@api.response(403, "Permission denied")
|
||||
@edit_permission_required
|
||||
@@ -122,6 +138,8 @@ class DraftWorkflowApi(Resource):
|
||||
.add_argument("hash", type=str, required=False, location="json")
|
||||
.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
.add_argument("force_upload", type=bool, required=False, default=False, location="json")
|
||||
.add_argument("memory_blocks", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
elif "text/plain" in content_type:
|
||||
@@ -139,6 +157,8 @@ class DraftWorkflowApi(Resource):
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
"memory_blocks": data.get("memory_blocks"),
|
||||
"force_upload": data.get("force_upload", False),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
@@ -155,6 +175,10 @@ class DraftWorkflowApi(Resource):
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
memory_blocks_list = args.get("memory_blocks") or []
|
||||
from core.memory.entities import MemoryBlockSpec
|
||||
|
||||
memory_blocks = [MemoryBlockSpec.model_validate(obj) for obj in memory_blocks_list]
|
||||
workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
graph=args["graph"],
|
||||
@@ -163,9 +187,13 @@ class DraftWorkflowApi(Resource):
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
force_upload=args.get("force_upload", False),
|
||||
memory_blocks=memory_blocks,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
except ValidationError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
@@ -734,6 +762,45 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
|
||||
class WorkflowConfigApi(Resource):
|
||||
"""Resource for workflow configuration."""
|
||||
|
||||
@api.doc("get_workflow_config")
|
||||
@api.doc(description="Get workflow configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Workflow configuration retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
return {
|
||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowFeaturesApi(Resource):
|
||||
"""Update draft workflow features."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("features", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
features = args.get("features")
|
||||
|
||||
# Update draft workflow features
|
||||
workflow_service = WorkflowService()
|
||||
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@api.doc("get_all_published_workflows")
|
||||
@@ -915,3 +982,105 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
if node_exec is None:
|
||||
raise NotFound("last run not found")
|
||||
return node_exec
|
||||
|
||||
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workflow_ids", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_ids = [id.strip() for id in args["workflow_ids"].split(",")]
|
||||
|
||||
results = []
|
||||
for workflow_id in workflow_ids:
|
||||
users_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
try:
|
||||
users.append(json.loads(user_info_json))
|
||||
except Exception:
|
||||
continue
|
||||
results.append({"workflow_id": workflow_id, "users": users})
|
||||
|
||||
return {"data": results}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DraftWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowConfigApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/config",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowFeaturesApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/features",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftWorkflowRunApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowTaskStopApi,
|
||||
"/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowNodeRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftRunIterationNodeApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowDraftRunIterationNodeApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftRunLoopNodeApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowDraftRunLoopNodeApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows/publish",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedAllWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultBlockConfigsApi,
|
||||
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultBlockConfigApi,
|
||||
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
|
||||
)
|
||||
api.add_resource(
|
||||
ConvertToWorkflowApi,
|
||||
"/apps/<uuid:app_id>/convert-to-workflow",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowByIdApi,
|
||||
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowNodeLastRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run",
|
||||
)
|
||||
api.add_resource(WorkflowOnlineUsersApi, "/apps/workflows/online-users")
|
||||
|
||||
240
api/controllers/console/app/workflow_comment.py
Normal file
240
api/controllers/console/app/workflow_comment.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.member_fields import account_with_role_fields
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_basic_fields, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_create_fields)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("position_x", type=float, required=True, location="json")
|
||||
parser.add_argument("position_y", type=float, required=True, location="json")
|
||||
parser.add_argument("content", type=str, required=True, location="json")
|
||||
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=args.content,
|
||||
position_x=args.position_x,
|
||||
position_y=args.position_y,
|
||||
mentioned_user_ids=args.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_detail_fields)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_update_fields)
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, location="json")
|
||||
parser.add_argument("position_x", type=float, required=False, location="json")
|
||||
parser.add_argument("position_y", type=float, required=False, location="json")
|
||||
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=args.content,
|
||||
position_x=args.position_x,
|
||||
position_y=args.position_y,
|
||||
mentioned_user_ids=args.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_resolve_fields)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_reply_create_fields)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, location="json")
|
||||
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=args.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=args.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(workflow_comment_reply_update_fields)
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, location="json")
|
||||
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
|
||||
args = parser.parse_args()
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
reply_id=reply_id, user_id=current_user.id, content=args.content, mentioned_user_ids=args.mentioned_user_ids
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with({"users": fields.List(fields.Nested(account_with_role_fields))})
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
return {"users": members}
|
||||
|
||||
|
||||
# Register API routes
|
||||
api.add_resource(WorkflowCommentListApi, "/apps/<uuid:app_id>/workflow/comments")
|
||||
api.add_resource(WorkflowCommentDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
api.add_resource(WorkflowCommentResolveApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
api.add_resource(WorkflowCommentReplyApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
api.add_resource(
|
||||
WorkflowCommentReplyDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>"
|
||||
)
|
||||
api.add_resource(WorkflowCommentMentionUsersApi, "/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
@@ -19,8 +19,8 @@ from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
@@ -355,7 +355,7 @@ class VariableApi(Resource):
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
@@ -448,8 +448,35 @@ class ConversationVariableCollectionApi(Resource):
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def post(self, app_model: App):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_variables", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
conversation_variables_list = args.get("conversation_variables") or []
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_conversation_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@api.doc("get_system_variables")
|
||||
@api.doc(description="Get system variables for workflow")
|
||||
@@ -499,3 +526,44 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_environment_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
|
||||
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||
|
||||
@@ -1,23 +1,26 @@
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import jsonify
|
||||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||
class WorkflowDailyRunsStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_daily_runs_statistic")
|
||||
@api.doc(description="Get workflow daily runs statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@@ -37,57 +40,32 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(id) AS runs
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "runs": i.runs})
|
||||
response_data = self._workflow_run_repo.get_daily_runs_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
|
||||
class WorkflowDailyTerminalsStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_daily_terminals_statistic")
|
||||
@api.doc(description="Get workflow daily terminals statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@@ -107,57 +85,32 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
|
||||
class WorkflowDailyTokenCostStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_daily_token_cost_statistic")
|
||||
@api.doc(description="Get workflow daily token cost statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@@ -177,62 +130,32 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
SUM(workflow_runs.total_tokens) AS token_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{
|
||||
"date": str(i.date),
|
||||
"token_count": i.token_count,
|
||||
}
|
||||
)
|
||||
response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
|
||||
class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_average_app_interaction_statistic")
|
||||
@api.doc(description="Get workflow average app interaction statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@@ -252,67 +175,20 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
AVG(sub.interactions) AS interactions,
|
||||
sub.date
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
c.created_by,
|
||||
COUNT(c.id) AS interactions
|
||||
FROM
|
||||
workflow_runs c
|
||||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND c.triggered_from = :triggered_from
|
||||
{{start}}
|
||||
{{end}}
|
||||
GROUP BY
|
||||
date, c.created_by
|
||||
) sub
|
||||
GROUP BY
|
||||
sub.date"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
else:
|
||||
sql_query = sql_query.replace("{{start}}", "")
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
else:
|
||||
sql_query = sql_query.replace("{{end}}", "")
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
||||
)
|
||||
response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
@@ -23,6 +23,15 @@ def _load_app_model(app_id: str) -> App | None:
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P1, R1]):
|
||||
@wraps(view_func)
|
||||
@@ -62,3 +71,44 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
app_id = kwargs.get("app_id")
|
||||
app_id = str(app_id)
|
||||
|
||||
del kwargs["app_id"]
|
||||
|
||||
app_model = _load_app_model_with_trial(app_id)
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
if mode is not None:
|
||||
if isinstance(mode, list):
|
||||
modes = mode
|
||||
else:
|
||||
modes = [mode]
|
||||
|
||||
if app_mode not in modes:
|
||||
mode_values = {m.value for m in modes}
|
||||
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
|
||||
|
||||
kwargs["app_model"] = app_model
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
@@ -25,10 +25,14 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.passport import PassportService
|
||||
from libs.token import (
|
||||
check_csrf_token,
|
||||
clear_access_token_from_cookie,
|
||||
clear_csrf_token_from_cookie,
|
||||
clear_refresh_token_from_cookie,
|
||||
extract_access_token,
|
||||
extract_refresh_token,
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
@@ -270,7 +274,7 @@ class EmailCodeLoginApi(Resource):
|
||||
class RefreshTokenApi(Resource):
|
||||
def post(self):
|
||||
# Get refresh token from cookie instead of request body
|
||||
refresh_token = request.cookies.get("refresh_token")
|
||||
refresh_token = extract_refresh_token(request)
|
||||
|
||||
if not refresh_token:
|
||||
return {"result": "fail", "message": "No refresh token provided"}, 401
|
||||
@@ -288,3 +292,18 @@ class RefreshTokenApi(Resource):
|
||||
return response
|
||||
except Exception as e:
|
||||
return {"result": "fail", "message": str(e)}, 401
|
||||
|
||||
|
||||
# this api helps frontend to check whether user is authenticated
|
||||
# TODO: remove in the future. frontend should redirect to login page by catching 401 status
|
||||
@console_ns.route("/login/status")
|
||||
class LoginStatus(Resource):
|
||||
def get(self):
|
||||
token = extract_access_token(request)
|
||||
res = True
|
||||
try:
|
||||
validated = PassportService().verify(token=token)
|
||||
check_csrf_token(request=request, user_id=validated.get("user_id"))
|
||||
except:
|
||||
res = False
|
||||
return {"logged_in": res}
|
||||
|
||||
@@ -2,6 +2,7 @@ from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
@@ -16,7 +17,13 @@ class Subscription(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||
.add_argument(
|
||||
"plan",
|
||||
type=str,
|
||||
required=True,
|
||||
location="args",
|
||||
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
|
||||
)
|
||||
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -746,7 +746,7 @@ class DocumentApi(DocumentResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
@@ -779,7 +779,7 @@ class DocumentApi(DocumentResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
|
||||
)
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
@@ -12,9 +12,17 @@ from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@api.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -26,12 +34,6 @@ class DataSourceContentPreviewApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
|
||||
43
api/controllers/console/explore/banner.py
Normal file
43
api/controllers/console/explore/banner.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import explore_banner_enabled
|
||||
from extensions.ext_database import db
|
||||
from models.model import ExporleBanner
|
||||
|
||||
|
||||
class BannerApi(Resource):
|
||||
"""Resource for banner list."""
|
||||
|
||||
@explore_banner_enabled
|
||||
def get(self):
|
||||
"""Get banner list."""
|
||||
language = request.args.get("language", "en-US")
|
||||
|
||||
# Build base query for enabled banners
|
||||
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
|
||||
|
||||
# Try to get banners in the requested language
|
||||
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
|
||||
|
||||
# Fallback to en-US if no banners found and language is not en-US
|
||||
if not banners and language != "en-US":
|
||||
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
|
||||
# Convert banners to serializable format
|
||||
result = []
|
||||
for banner in banners:
|
||||
banner_data = {
|
||||
"id": banner.id,
|
||||
"content": banner.content, # Already parsed as JSON by SQLAlchemy
|
||||
"link": banner.link,
|
||||
"sort": banner.sort,
|
||||
"status": banner.status,
|
||||
"created_at": banner.created_at.isoformat() if banner.created_at else None,
|
||||
}
|
||||
result.append(banner_data)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
api.add_resource(BannerApi, "/explore/banners")
|
||||
@@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
|
||||
error_code = "access_denied"
|
||||
description = "App access denied."
|
||||
code = 403
|
||||
|
||||
|
||||
class TrialAppNotAllowed(BaseHTTPException):
|
||||
"""*403* `Trial App Not Allowed`
|
||||
|
||||
Raise if the user has reached the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_not_allowed"
|
||||
code = 403
|
||||
description = "the app is not allowed to be trial."
|
||||
|
||||
|
||||
class TrialAppLimitExceeded(BaseHTTPException):
|
||||
"""*403* `Trial App Limit Exceeded`
|
||||
|
||||
Raise if the user has exceeded the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_limit_exceeded"
|
||||
code = 403
|
||||
description = "The user has exceeded the trial app limit."
|
||||
|
||||
@@ -27,6 +27,7 @@ recommended_app_fields = {
|
||||
"category": fields.String,
|
||||
"position": fields.Integer,
|
||||
"is_listed": fields.Boolean,
|
||||
"can_trial": fields.Boolean,
|
||||
}
|
||||
|
||||
recommended_app_list_fields = {
|
||||
|
||||
514
api/controllers/console/explore/trial.py
Normal file
514
api/controllers/console/explore/trial.py
Normal file
@@ -0,0 +1,514 @@
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common import fields
|
||||
from controllers.common.fields import build_site_model
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NeedAddIdsError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model_with_trial
|
||||
from controllers.console.explore.error import (
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
NotWorkflowAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields_with_site
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.workflow_fields import workflow_fields
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.account import TenantStatus
|
||||
from models.model import AppMode, Site
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_service import AppService
|
||||
from services.audio_service import AudioService
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.errors.message import (
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
Run workflow
|
||||
"""
|
||||
app_model = trial_app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
assert current_user is not None
|
||||
try:
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
def post(self, trial_app, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
app_model = trial_app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def get(self, trial_app, message_id):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
|
||||
class TrialChatAudioApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||
)
|
||||
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialSitApi(Resource):
|
||||
"""Resource for trial app sites."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@service_api_ns.marshal_with(build_site_model(service_api_ns))
|
||||
def get(self, app_model):
|
||||
"""Retrieve app site info.
|
||||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
assert app_model.tenant
|
||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return site
|
||||
|
||||
|
||||
class TrialAppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model):
|
||||
"""Retrieve app parameters."""
|
||||
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.get_app(app_model)
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
class AppWorkflowApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(workflow_fields)
|
||||
def get(self, app_model):
|
||||
"""Get workflow detail"""
|
||||
if not app_model.workflow_id:
|
||||
raise AppUnavailableError()
|
||||
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return workflow
|
||||
|
||||
|
||||
class DatasetListApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
ids = request.args.getlist("ids")
|
||||
|
||||
tenant_id = app_model.tenant_id
|
||||
if ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
|
||||
else:
|
||||
raise NeedAddIdsError()
|
||||
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
|
||||
|
||||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
|
||||
|
||||
api.add_resource(
|
||||
TrialMessageSuggestedQuestionApi,
|
||||
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="trial_app_suggested_question",
|
||||
)
|
||||
|
||||
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
|
||||
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
|
||||
|
||||
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
|
||||
|
||||
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
|
||||
|
||||
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
|
||||
|
||||
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
|
||||
|
||||
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
|
||||
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
|
||||
|
||||
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
|
||||
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")
|
||||
@@ -22,7 +22,7 @@ from core.errors.error import (
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from libs.login import current_user as current_user_
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@@ -31,8 +31,6 @@ from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
current_user = current_user_._get_current_object() # type: ignore
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
@@ -40,6 +38,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
"""
|
||||
Run workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
@@ -53,7 +52,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
assert current_user is not None
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
@@ -89,7 +87,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
|
||||
@@ -2,14 +2,16 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import abort
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import InstalledApp
|
||||
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@@ -71,6 +73,59 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[App, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
|
||||
|
||||
if trial_app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
app = trial_app.app
|
||||
|
||||
if app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
|
||||
account_trial_app_record = (
|
||||
db.session.query(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
|
||||
.first()
|
||||
)
|
||||
if account_trial_app_record:
|
||||
if account_trial_app_record.count >= trial_app.trial_limit:
|
||||
raise TrialAppLimitExceeded()
|
||||
|
||||
return view(app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_trial_app:
|
||||
abort(403, "Trial app feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_explore_banner:
|
||||
abort(403, "Explore banner feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
class InstalledAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
@@ -80,3 +135,13 @@ class InstalledAppResource(Resource):
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
||||
|
||||
class TrialAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
method_decorators = [
|
||||
trial_app_required,
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
||||
@@ -66,13 +66,7 @@ class APIBasedExtensionAPI(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, location="json")
|
||||
.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||
.add_argument("api_key", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = api.payload
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
@@ -125,13 +119,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, location="json")
|
||||
.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||
.add_argument("api_key", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = api.payload
|
||||
|
||||
extension_data_from_db.name = args["name"]
|
||||
extension_data_from_db.api_endpoint = args["api_endpoint"]
|
||||
|
||||
@@ -8,6 +8,7 @@ import services
|
||||
from configs import dify_config
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
@@ -39,6 +40,7 @@ class FileApi(Resource):
|
||||
return {
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
|
||||
"file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
@@ -82,6 +84,8 @@ class FileApi(Resource):
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
|
||||
raise BlockedFileExtensionError(blocked_extension_error.description)
|
||||
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from controllers.console.wraps import (
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -128,6 +129,17 @@ class AccountNameApi(Resource):
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("avatar", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(args["avatar"])
|
||||
return {"avatar_url": avatar_url}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@@ -6,6 +6,7 @@ from flask_restx import (
|
||||
Resource,
|
||||
reqparse,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@@ -15,20 +16,21 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
|
||||
from services.tools.tool_labels_service import ToolLabelsService
|
||||
from services.tools.tools_manage_service import ToolCommonService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
@@ -42,7 +44,9 @@ def is_valid_url(url: str) -> bool:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
|
||||
except Exception:
|
||||
except (ValueError, TypeError):
|
||||
# ValueError: Invalid URL format
|
||||
# TypeError: url is not a string
|
||||
return False
|
||||
|
||||
|
||||
@@ -886,29 +890,34 @@ class ToolProviderMCPApi(Resource):
|
||||
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
|
||||
.add_argument("sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300)
|
||||
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
)
|
||||
args = parser.parse_args()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not is_valid_url(args["server_url"]):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
return jsonable_encoder(
|
||||
MCPToolManageService.create_mcp_provider(
|
||||
|
||||
# Parse and validate models
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
|
||||
# Create provider
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
result = service.create_provider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
icon=args["icon"],
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
user_id=user.id,
|
||||
server_identifier=args["server_identifier"],
|
||||
timeout=args["timeout"],
|
||||
sse_read_timeout=args["sse_read_timeout"],
|
||||
headers=args["headers"],
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
)
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -923,31 +932,43 @@ class ToolProviderMCPApi(Resource):
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("timeout", type=float, required=False, nullable=True, location="json")
|
||||
.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
|
||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if not is_valid_url(args["server_url"]):
|
||||
if "[__HIDDEN__]" in args["server_url"]:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Server URL is not valid.")
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
MCPToolManageService.update_mcp_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=args["provider_id"],
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
icon=args["icon"],
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
server_identifier=args["server_identifier"],
|
||||
timeout=args.get("timeout"),
|
||||
sse_read_timeout=args.get("sse_read_timeout"),
|
||||
headers=args.get("headers"),
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
|
||||
validation_result = None
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
validation_result = service.validate_server_url_change(
|
||||
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
|
||||
)
|
||||
|
||||
# No need to check for errors here, exceptions will be raised directly
|
||||
|
||||
# Step 2: Perform database update in a transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=args["provider_id"],
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
icon=args["icon"],
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
server_identifier=args["server_identifier"],
|
||||
headers=args["headers"],
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
validation_result=validation_result,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -958,8 +979,11 @@ class ToolProviderMCPApi(Resource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
||||
return {"result": "success"}
|
||||
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
|
||||
@@ -976,37 +1000,53 @@ class ToolMCPAuthApi(Resource):
|
||||
args = parser.parse_args()
|
||||
provider_id = args["provider_id"]
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if not provider:
|
||||
raise ValueError("provider not found")
|
||||
try:
|
||||
with MCPClient(
|
||||
provider.decrypted_server_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
authed=False,
|
||||
authorization_code=args["authorization_code"],
|
||||
for_list=True,
|
||||
headers=provider.decrypted_headers,
|
||||
timeout=provider.timeout,
|
||||
sse_read_timeout=provider.sse_read_timeout,
|
||||
):
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
mcp_provider=provider,
|
||||
credentials=provider.decrypted_credentials,
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
except MCPAuthError:
|
||||
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
|
||||
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
if not db_provider:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
# Convert to entity
|
||||
provider_entity = db_provider.to_entity()
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
headers = provider_entity.decrypt_authentication()
|
||||
|
||||
# Try to connect without active transaction
|
||||
try:
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
with MCPClient(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
):
|
||||
# Update credentials in new transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider_credentials(
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
credentials=provider_entity.credentials,
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
try:
|
||||
auth_result = auth(provider_entity, args.get("authorization_code"))
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
response = service.execute_auth_actions(auth_result)
|
||||
return response
|
||||
except MCPRefreshTokenError as e:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except MCPError as e:
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
mcp_provider=provider,
|
||||
credentials={},
|
||||
authed=False,
|
||||
)
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
|
||||
@@ -1017,8 +1057,10 @@ class ToolMCPDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tools/mcp")
|
||||
@@ -1029,9 +1071,12 @@ class ToolMCPListAllApi(Resource):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
# Skip sensitive data decryption for list view to improve performance
|
||||
tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False)
|
||||
|
||||
return [tool.to_dict() for tool in tools]
|
||||
return [tool.to_dict() for tool in tools]
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
|
||||
@@ -1041,11 +1086,13 @@ class ToolMCPUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
return jsonable_encoder(tools)
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
tools = service.list_provider_tools(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
return jsonable_encoder(tools)
|
||||
|
||||
|
||||
@console_ns.route("/mcp/oauth/callback")
|
||||
@@ -1059,5 +1106,15 @@ class ToolMCPCallbackApi(Resource):
|
||||
args = parser.parse_args()
|
||||
state_key = args["state"]
|
||||
authorization_code = args["code"]
|
||||
handle_callback(state_key, authorization_code)
|
||||
|
||||
# Create service instance for handle_callback
|
||||
with Session(db.engine) as session, session.begin():
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
# handle_callback now returns state data and tokens
|
||||
state_data, tokens = handle_callback(state_key, authorization_code)
|
||||
# Save tokens using the service layer
|
||||
mcp_service.save_oauth_data(
|
||||
state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
|
||||
)
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
@@ -21,6 +21,7 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@@ -50,6 +51,9 @@ tenant_fields = {
|
||||
"in_trial": fields.Boolean,
|
||||
"trial_end_reason": fields.String,
|
||||
"custom_config": fields.Raw(attribute="custom_config"),
|
||||
"trial_credits": fields.Integer,
|
||||
"trial_credits_used": fields.Integer,
|
||||
"next_credit_reset_date": fields.Integer,
|
||||
}
|
||||
|
||||
tenants_fields = {
|
||||
@@ -83,7 +87,7 @@ class TenantListApi(Resource):
|
||||
"name": tenant.name,
|
||||
"status": tenant.status,
|
||||
"created_at": tenant.created_at,
|
||||
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
|
||||
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
|
||||
"current": tenant.id == current_tenant_id if current_tenant_id else False,
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import current_account_with_tenant
|
||||
@@ -133,7 +134,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
abort(
|
||||
403,
|
||||
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||
|
||||
@@ -14,10 +14,25 @@ from services.file_service import FileService
|
||||
|
||||
@files_ns.route("/<uuid:file_id>/image-preview")
|
||||
class ImagePreviewApi(Resource):
|
||||
"""
|
||||
Deprecated
|
||||
"""
|
||||
"""Deprecated endpoint for retrieving image previews."""
|
||||
|
||||
@files_ns.doc("get_image_preview")
|
||||
@files_ns.doc(description="Retrieve a signed image preview for a file")
|
||||
@files_ns.doc(
|
||||
params={
|
||||
"file_id": "ID of the file to preview",
|
||||
"timestamp": "Unix timestamp used in the signature",
|
||||
"nonce": "Random string used in the signature",
|
||||
"sign": "HMAC signature verifying the request",
|
||||
}
|
||||
)
|
||||
@files_ns.doc(
|
||||
responses={
|
||||
200: "Image preview returned successfully",
|
||||
400: "Missing or invalid signature parameters",
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
@@ -43,6 +58,25 @@ class ImagePreviewApi(Resource):
|
||||
|
||||
@files_ns.route("/<uuid:file_id>/file-preview")
|
||||
class FilePreviewApi(Resource):
|
||||
@files_ns.doc("get_file_preview")
|
||||
@files_ns.doc(description="Download a file preview or attachment using signed parameters")
|
||||
@files_ns.doc(
|
||||
params={
|
||||
"file_id": "ID of the file to preview",
|
||||
"timestamp": "Unix timestamp used in the signature",
|
||||
"nonce": "Random string used in the signature",
|
||||
"sign": "HMAC signature verifying the request",
|
||||
"as_attachment": "Whether to download the file as an attachment",
|
||||
}
|
||||
)
|
||||
@files_ns.doc(
|
||||
responses={
|
||||
200: "File stream returned successfully",
|
||||
400: "Missing or invalid signature parameters",
|
||||
404: "File not found",
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
@@ -101,6 +135,20 @@ class FilePreviewApi(Resource):
|
||||
|
||||
@files_ns.route("/workspaces/<uuid:workspace_id>/webapp-logo")
|
||||
class WorkspaceWebappLogoApi(Resource):
|
||||
@files_ns.doc("get_workspace_webapp_logo")
|
||||
@files_ns.doc(description="Fetch the custom webapp logo for a workspace")
|
||||
@files_ns.doc(
|
||||
params={
|
||||
"workspace_id": "Workspace identifier",
|
||||
}
|
||||
)
|
||||
@files_ns.doc(
|
||||
responses={
|
||||
200: "Logo returned successfully",
|
||||
404: "Webapp logo not configured",
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def get(self, workspace_id):
|
||||
workspace_id = str(workspace_id)
|
||||
|
||||
|
||||
@@ -13,6 +13,26 @@ from extensions.ext_database import db as global_db
|
||||
|
||||
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
|
||||
class ToolFileApi(Resource):
|
||||
@files_ns.doc("get_tool_file")
|
||||
@files_ns.doc(description="Download a tool file by ID using signed parameters")
|
||||
@files_ns.doc(
|
||||
params={
|
||||
"file_id": "Tool file identifier",
|
||||
"extension": "Expected file extension",
|
||||
"timestamp": "Unix timestamp used in the signature",
|
||||
"nonce": "Random string used in the signature",
|
||||
"sign": "HMAC signature verifying the request",
|
||||
"as_attachment": "Whether to download the file as an attachment",
|
||||
}
|
||||
)
|
||||
@files_ns.doc(
|
||||
responses={
|
||||
200: "Tool file stream returned successfully",
|
||||
403: "Forbidden - invalid signature",
|
||||
404: "File not found",
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def get(self, file_id, extension):
|
||||
file_id = str(file_id)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from .app import (
|
||||
annotation,
|
||||
app,
|
||||
audio,
|
||||
chatflow_memory,
|
||||
completion,
|
||||
conversation,
|
||||
file,
|
||||
@@ -40,6 +41,7 @@ __all__ = [
|
||||
"annotation",
|
||||
"app",
|
||||
"audio",
|
||||
"chatflow_memory",
|
||||
"completion",
|
||||
"conversation",
|
||||
"dataset",
|
||||
|
||||
109
api/controllers/service_api/app/chatflow_memory.py
Normal file
109
api/controllers/service_api/app/chatflow_memory.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.memory.entities import MemoryBlock, MemoryCreatedBy
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
from models import App, EndUser
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class MemoryListApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=False, type=str | None, default=None)
|
||||
parser.add_argument("memory_id", required=False, type=str | None, default=None)
|
||||
parser.add_argument("version", required=False, type=int | None, default=None)
|
||||
args = parser.parse_args()
|
||||
conversation_id: str | None = args.get("conversation_id")
|
||||
memory_id = args.get("memory_id")
|
||||
version = args.get("version")
|
||||
|
||||
if conversation_id:
|
||||
result = ChatflowMemoryService.get_persistent_memories_with_conversation(
|
||||
app_model, MemoryCreatedBy(end_user_id=end_user.id), conversation_id, version
|
||||
)
|
||||
session_memories = ChatflowMemoryService.get_session_memories_with_conversation(
|
||||
app_model, MemoryCreatedBy(end_user_id=end_user.id), conversation_id, version
|
||||
)
|
||||
result = [*result, *session_memories]
|
||||
else:
|
||||
result = ChatflowMemoryService.get_persistent_memories(
|
||||
app_model, MemoryCreatedBy(end_user_id=end_user.id), version
|
||||
)
|
||||
|
||||
if memory_id:
|
||||
result = [it for it in result if it.spec.id == memory_id]
|
||||
return [it for it in result if it.spec.end_user_visible]
|
||||
|
||||
|
||||
class MemoryEditApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def put(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True)
|
||||
parser.add_argument("conversation_id", type=str | None, required=False, default=None)
|
||||
parser.add_argument("node_id", type=str | None, required=False, default=None)
|
||||
parser.add_argument("update", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
workflow = WorkflowService().get_published_workflow(app_model)
|
||||
update = args.get("update")
|
||||
conversation_id = args.get("conversation_id")
|
||||
node_id = args.get("node_id")
|
||||
if not isinstance(update, str):
|
||||
return {"error": "Invalid update"}, 400
|
||||
if not workflow:
|
||||
return {"error": "Workflow not found"}, 404
|
||||
memory_spec = next((it for it in workflow.memory_blocks if it.id == args["id"]), None)
|
||||
if not memory_spec:
|
||||
return {"error": "Memory not found"}, 404
|
||||
|
||||
# First get existing memory
|
||||
existing_memory = ChatflowMemoryService.get_memory_by_spec(
|
||||
spec=memory_spec,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=MemoryCreatedBy(end_user_id=end_user.id),
|
||||
conversation_id=conversation_id,
|
||||
node_id=node_id,
|
||||
is_draft=False,
|
||||
)
|
||||
|
||||
# Create updated memory instance with incremented version
|
||||
updated_memory = MemoryBlock(
|
||||
spec=existing_memory.spec,
|
||||
tenant_id=existing_memory.tenant_id,
|
||||
app_id=existing_memory.app_id,
|
||||
conversation_id=existing_memory.conversation_id,
|
||||
node_id=existing_memory.node_id,
|
||||
value=update, # New value
|
||||
version=existing_memory.version + 1, # Increment version for update
|
||||
edited_by_user=True,
|
||||
created_by=existing_memory.created_by,
|
||||
)
|
||||
|
||||
ChatflowMemoryService.save_memory(updated_memory, VariablePool(), False)
|
||||
return "", 204
|
||||
|
||||
|
||||
class MemoryDeleteApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def delete(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=False, default=None)
|
||||
args = parser.parse_args()
|
||||
memory_id = args.get("id")
|
||||
|
||||
if memory_id:
|
||||
ChatflowMemoryService.delete_memory(app_model, memory_id, MemoryCreatedBy(end_user_id=end_user.id))
|
||||
return "", 204
|
||||
else:
|
||||
ChatflowMemoryService.delete_all_user_memories(app_model, MemoryCreatedBy(end_user_id=end_user.id))
|
||||
return "", 200
|
||||
|
||||
|
||||
api.add_resource(MemoryListApi, "/memories")
|
||||
api.add_resource(MemoryEditApi, "/memory-edit")
|
||||
api.add_resource(MemoryDeleteApi, "/memories")
|
||||
@@ -592,7 +592,7 @@ class DocumentApi(DatasetApiResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
@@ -625,7 +625,7 @@ class DocumentApi(DatasetApiResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
|
||||
@@ -2,6 +2,7 @@ from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import (
|
||||
@@ -107,6 +108,10 @@ class SegmentApi(DatasetApiResource):
|
||||
# validate args
|
||||
args = segment_create_parser.parse_args()
|
||||
if args["segments"] is not None:
|
||||
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
|
||||
if segments_limit > 0 and len(args["segments"]) > segments_limit:
|
||||
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
|
||||
|
||||
for args_item in args["segments"]:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -66,6 +67,7 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
|
||||
|
||||
kwargs["app_model"] = app_model
|
||||
|
||||
# If caller needs end-user context, attach EndUser to current_user
|
||||
if fetch_user_arg:
|
||||
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
|
||||
user_id = request.args.get("user")
|
||||
@@ -74,7 +76,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
|
||||
user_id = request.form.get("user")
|
||||
else:
|
||||
# use default-user
|
||||
user_id = None
|
||||
|
||||
if not user_id and fetch_user_arg.required:
|
||||
@@ -89,6 +90,28 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
|
||||
# Set EndUser as current logged-in user for flask_login.current_user
|
||||
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore
|
||||
else:
|
||||
# For service API without end-user context, ensure an Account is logged in
|
||||
# so services relying on current_account_with_tenant() work correctly.
|
||||
tenant_owner_info = (
|
||||
db.session.query(Tenant, Account)
|
||||
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.join(Account, TenantAccountJoin.account_id == Account.id)
|
||||
.where(
|
||||
Tenant.id == app_model.tenant_id,
|
||||
TenantAccountJoin.role == "owner",
|
||||
Tenant.status == TenantStatus.NORMAL,
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if tenant_owner_info:
|
||||
tenant_model, account = tenant_owner_info
|
||||
account.current_tenant = tenant_model
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account not found or tenant is not active.")
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
@@ -138,7 +161,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
raise Forbidden(
|
||||
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ web_ns = Namespace("web", description="Web application API operations", path="/"
|
||||
from . import (
|
||||
app,
|
||||
audio,
|
||||
chatflow_memory,
|
||||
completion,
|
||||
conversation,
|
||||
feature,
|
||||
@@ -39,6 +40,7 @@ __all__ = [
|
||||
"app",
|
||||
"audio",
|
||||
"bp",
|
||||
"chatflow_memory",
|
||||
"completion",
|
||||
"conversation",
|
||||
"feature",
|
||||
|
||||
108
api/controllers/web/chatflow_memory.py
Normal file
108
api/controllers/web/chatflow_memory.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from flask_restx import reqparse
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.memory.entities import MemoryBlock, MemoryCreatedBy
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
from models import App, EndUser
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class MemoryListApi(WebApiResource):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=False, type=str | None, default=None)
|
||||
parser.add_argument("memory_id", required=False, type=str | None, default=None)
|
||||
parser.add_argument("version", required=False, type=int | None, default=None)
|
||||
args = parser.parse_args()
|
||||
conversation_id: str | None = args.get("conversation_id")
|
||||
memory_id = args.get("memory_id")
|
||||
version = args.get("version")
|
||||
|
||||
if conversation_id:
|
||||
result = ChatflowMemoryService.get_persistent_memories_with_conversation(
|
||||
app_model, MemoryCreatedBy(end_user_id=end_user.id), conversation_id, version
|
||||
)
|
||||
session_memories = ChatflowMemoryService.get_session_memories_with_conversation(
|
||||
app_model, MemoryCreatedBy(end_user_id=end_user.id), conversation_id, version
|
||||
)
|
||||
result = [*result, *session_memories]
|
||||
else:
|
||||
result = ChatflowMemoryService.get_persistent_memories(
|
||||
app_model, MemoryCreatedBy(end_user_id=end_user.id), version
|
||||
)
|
||||
|
||||
if memory_id:
|
||||
result = [it for it in result if it.spec.id == memory_id]
|
||||
return [it for it in result if it.spec.end_user_visible]
|
||||
|
||||
|
||||
class MemoryEditApi(WebApiResource):
|
||||
def put(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True)
|
||||
parser.add_argument("conversation_id", type=str | None, required=False, default=None)
|
||||
parser.add_argument("node_id", type=str | None, required=False, default=None)
|
||||
parser.add_argument("update", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
workflow = WorkflowService().get_published_workflow(app_model)
|
||||
update = args.get("update")
|
||||
conversation_id = args.get("conversation_id")
|
||||
node_id = args.get("node_id")
|
||||
if not isinstance(update, str):
|
||||
return {"error": "Update must be a string"}, 400
|
||||
if not workflow:
|
||||
return {"error": "Workflow not found"}, 404
|
||||
memory_spec = next((it for it in workflow.memory_blocks if it.id == args["id"]), None)
|
||||
if not memory_spec:
|
||||
return {"error": "Memory not found"}, 404
|
||||
if not memory_spec.end_user_editable:
|
||||
return {"error": "Memory not editable"}, 403
|
||||
|
||||
# First get existing memory
|
||||
existing_memory = ChatflowMemoryService.get_memory_by_spec(
|
||||
spec=memory_spec,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=MemoryCreatedBy(end_user_id=end_user.id),
|
||||
conversation_id=conversation_id,
|
||||
node_id=node_id,
|
||||
is_draft=False,
|
||||
)
|
||||
|
||||
# Create updated memory instance with incremented version
|
||||
updated_memory = MemoryBlock(
|
||||
spec=existing_memory.spec,
|
||||
tenant_id=existing_memory.tenant_id,
|
||||
app_id=existing_memory.app_id,
|
||||
conversation_id=existing_memory.conversation_id,
|
||||
node_id=existing_memory.node_id,
|
||||
value=update, # New value
|
||||
version=existing_memory.version + 1, # Increment version for update
|
||||
edited_by_user=True,
|
||||
created_by=existing_memory.created_by,
|
||||
)
|
||||
|
||||
ChatflowMemoryService.save_memory(updated_memory, VariablePool(), False)
|
||||
return "", 204
|
||||
|
||||
|
||||
class MemoryDeleteApi(WebApiResource):
|
||||
def delete(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=False, default=None)
|
||||
args = parser.parse_args()
|
||||
memory_id = args.get("id")
|
||||
|
||||
if memory_id:
|
||||
ChatflowMemoryService.delete_memory(app_model, memory_id, MemoryCreatedBy(end_user_id=end_user.id))
|
||||
return "", 204
|
||||
else:
|
||||
ChatflowMemoryService.delete_all_user_memories(app_model, MemoryCreatedBy(end_user_id=end_user.id))
|
||||
return "", 200
|
||||
|
||||
|
||||
api.add_resource(MemoryListApi, "/memories")
|
||||
api.add_resource(MemoryEditApi, "/memory-edit")
|
||||
api.add_resource(MemoryDeleteApi, "/memories")
|
||||
@@ -17,8 +17,8 @@ from libs.helper import email
|
||||
from libs.passport import PassportService
|
||||
from libs.password import valid_password
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
extract_access_token,
|
||||
clear_webapp_access_token_from_cookie,
|
||||
extract_webapp_access_token,
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
from services.app_service import AppService
|
||||
@@ -81,7 +81,7 @@ class LoginStatusApi(Resource):
|
||||
)
|
||||
def get(self):
|
||||
app_code = request.args.get("app_code")
|
||||
token = extract_access_token(request)
|
||||
token = extract_webapp_access_token(request)
|
||||
if not app_code:
|
||||
return {
|
||||
"logged_in": bool(token),
|
||||
@@ -128,7 +128,7 @@ class LogoutApi(Resource):
|
||||
response = make_response({"result": "success"})
|
||||
# enterprise SSO sets same site to None in https deployment
|
||||
# so we need to logout by calling api
|
||||
clear_access_token_from_cookie(response, samesite="None")
|
||||
clear_webapp_access_token_from_cookie(response, samesite="None")
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from controllers.web import web_ns
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from libs.token import extract_webapp_access_token
|
||||
from models.model import App, EndUser, Site
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
|
||||
@@ -35,7 +35,7 @@ class PassportResource(Resource):
|
||||
system_features = FeatureService.get_system_features()
|
||||
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||
user_id = request.args.get("user_id")
|
||||
access_token = extract_access_token(request)
|
||||
access_token = extract_webapp_access_token(request)
|
||||
if app_code is None:
|
||||
raise Unauthorized("X-App-Code header is missing.")
|
||||
if system_features.webapp_auth.enabled:
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@@ -20,12 +21,16 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.memory.entities import MemoryCreatedBy, MemoryScope
|
||||
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.graph_events import GraphRunSucceededEvent
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
@@ -38,6 +43,8 @@ from models import Workflow
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
from models.workflow import ConversationVariable
|
||||
from services.chatflow_history_service import ChatflowHistoryService
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -61,11 +68,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
app: App,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=variable_loader,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
)
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.conversation = conversation
|
||||
@@ -78,6 +87,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
def run(self):
|
||||
ChatflowMemoryService.wait_for_sync_memory_completion(
|
||||
workflow=self._workflow, conversation_id=self.conversation.id
|
||||
)
|
||||
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
@@ -140,6 +153,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
conversation_variables=conversation_variables,
|
||||
memory_blocks=self._fetch_memory_blocks(),
|
||||
)
|
||||
|
||||
# init graph
|
||||
@@ -195,12 +209,39 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
try:
|
||||
self._check_app_memory_updates(variable_pool)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to check app memory updates", exc_info=e)
|
||||
|
||||
@override
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: Any) -> None:
|
||||
super()._handle_event(workflow_entry, event)
|
||||
if isinstance(event, GraphRunSucceededEvent):
|
||||
workflow_outputs = event.outputs
|
||||
if not workflow_outputs:
|
||||
logger.warning("Chatflow output is empty.")
|
||||
return
|
||||
assistant_message = workflow_outputs.get("answer")
|
||||
if not assistant_message:
|
||||
logger.warning("Chatflow output does not contain 'answer'.")
|
||||
return
|
||||
if not isinstance(assistant_message, str):
|
||||
logger.warning("Chatflow output 'answer' is not a string.")
|
||||
return
|
||||
try:
|
||||
self._sync_conversation_to_chatflow_tables(assistant_message)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to sync conversation to memory tables", exc_info=e)
|
||||
|
||||
def handle_input_moderation(
|
||||
self,
|
||||
app_record: App,
|
||||
@@ -398,3 +439,67 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
# Return combined list
|
||||
return existing_variables + new_variables
|
||||
|
||||
def _fetch_memory_blocks(self) -> Mapping[str, str]:
|
||||
"""fetch all memory blocks for current app"""
|
||||
|
||||
memory_blocks_dict: MutableMapping[str, str] = {}
|
||||
is_draft = self.application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
conversation_id = self.conversation.id
|
||||
memory_block_specs = self._workflow.memory_blocks
|
||||
# Get runtime memory values
|
||||
memories = ChatflowMemoryService.get_memories_by_specs(
|
||||
memory_block_specs=memory_block_specs,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
node_id=None,
|
||||
conversation_id=conversation_id,
|
||||
is_draft=is_draft,
|
||||
created_by=self._get_created_by(),
|
||||
)
|
||||
|
||||
# Build memory_id -> value mapping
|
||||
for memory in memories:
|
||||
if memory.spec.scope == MemoryScope.APP:
|
||||
# App level: use memory_id directly
|
||||
memory_blocks_dict[memory.spec.id] = memory.value
|
||||
else: # NODE scope
|
||||
node_id = memory.node_id
|
||||
if not node_id:
|
||||
logger.warning("Memory block %s has no node_id, skip.", memory.spec.id)
|
||||
continue
|
||||
key = f"{node_id}.{memory.spec.id}"
|
||||
memory_blocks_dict[key] = memory.value
|
||||
|
||||
return memory_blocks_dict
|
||||
|
||||
def _sync_conversation_to_chatflow_tables(self, assistant_message: str):
|
||||
ChatflowHistoryService.save_app_message(
|
||||
prompt_message=UserPromptMessage(content=(self.application_generate_entity.query)),
|
||||
conversation_id=self.conversation.id,
|
||||
app_id=self._workflow.app_id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
)
|
||||
ChatflowHistoryService.save_app_message(
|
||||
prompt_message=AssistantPromptMessage(content=assistant_message),
|
||||
conversation_id=self.conversation.id,
|
||||
app_id=self._workflow.app_id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
)
|
||||
|
||||
def _check_app_memory_updates(self, variable_pool: VariablePool):
|
||||
is_draft = self.application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
|
||||
ChatflowMemoryService.update_app_memory_if_needed(
|
||||
workflow=self._workflow,
|
||||
conversation_id=self.conversation.id,
|
||||
variable_pool=variable_pool,
|
||||
is_draft=is_draft,
|
||||
created_by=self._get_created_by(),
|
||||
)
|
||||
|
||||
def _get_created_by(self) -> MemoryCreatedBy:
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
|
||||
return MemoryCreatedBy(account_id=self.application_generate_entity.user_id)
|
||||
else:
|
||||
return MemoryCreatedBy(end_user_id=self.application_generate_entity.user_id)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
@@ -60,6 +61,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
@@ -391,6 +393,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if should_direct_answer:
|
||||
return
|
||||
|
||||
current_time = time.perf_counter()
|
||||
if self._task_state.first_token_time is None and delta_text.strip():
|
||||
self._task_state.first_token_time = current_time
|
||||
self._task_state.is_streaming_response = True
|
||||
|
||||
if delta_text.strip():
|
||||
self._task_state.last_token_time = current_time
|
||||
|
||||
# Only publish tts message at text chunk streaming
|
||||
if tts_publisher and queue_message:
|
||||
tts_publisher.publish(queue_message)
|
||||
@@ -772,7 +782,33 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
message.answer = answer_text
|
||||
message.updated_at = naive_utc_now()
|
||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
# Set usage first before dumping metadata
|
||||
if graph_runtime_state and graph_runtime_state.llm_usage:
|
||||
usage = graph_runtime_state.llm_usage
|
||||
message.message_tokens = usage.prompt_tokens
|
||||
message.message_unit_price = usage.prompt_unit_price
|
||||
message.message_price_unit = usage.prompt_price_unit
|
||||
message.answer_tokens = usage.completion_tokens
|
||||
message.answer_unit_price = usage.completion_unit_price
|
||||
message.answer_price_unit = usage.completion_price_unit
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
self._task_state.metadata.usage = usage
|
||||
else:
|
||||
usage = LLMUsage.empty_usage()
|
||||
self._task_state.metadata.usage = usage
|
||||
|
||||
# Add streaming metrics to usage if available
|
||||
if self._task_state.is_streaming_response and self._task_state.first_token_time:
|
||||
start_time = self._base_task_pipeline.start_at
|
||||
first_token_time = self._task_state.first_token_time
|
||||
last_token_time = self._task_state.last_token_time or first_token_time
|
||||
usage.time_to_first_token = round(first_token_time - start_time, 3)
|
||||
usage.time_to_generate = round(last_token_time - first_token_time, 3)
|
||||
|
||||
metadata = self._task_state.metadata.model_dump()
|
||||
message.message_metadata = json.dumps(jsonable_encoder(metadata))
|
||||
message_files = [
|
||||
MessageFile(
|
||||
message_id=message.id,
|
||||
@@ -790,20 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
]
|
||||
session.add_all(message_files)
|
||||
|
||||
if graph_runtime_state and graph_runtime_state.llm_usage:
|
||||
usage = graph_runtime_state.llm_usage
|
||||
message.message_tokens = usage.prompt_tokens
|
||||
message.message_unit_price = usage.prompt_unit_price
|
||||
message.message_price_unit = usage.prompt_price_unit
|
||||
message.answer_tokens = usage.completion_tokens
|
||||
message.answer_unit_price = usage.completion_unit_price
|
||||
message.answer_price_unit = usage.completion_price_unit
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
self._task_state.metadata.usage = usage
|
||||
else:
|
||||
self._task_state.metadata.usage = LLMUsage.empty_usage()
|
||||
|
||||
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
|
||||
"""Bootstrap the cached runtime state from the queue manager when present."""
|
||||
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
|
||||
@@ -144,7 +144,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=dict(inputs),
|
||||
files=list(files),
|
||||
query=query or "",
|
||||
query=query,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
@@ -172,7 +172,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=dict(inputs),
|
||||
files=list(files),
|
||||
query=query or "",
|
||||
query=query,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ class AppRunner:
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: Mapping[str, str],
|
||||
files: Sequence["File"],
|
||||
query: str | None = None,
|
||||
query: str = "",
|
||||
context: str | None = None,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
@@ -105,7 +105,7 @@ class AppRunner:
|
||||
app_mode=AppMode.value_of(app_record.mode),
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
query=query or "",
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
|
||||
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NewType, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
@@ -51,7 +51,7 @@ from core.workflow.workflow_entry import WorkflowEntry
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
from services.variable_truncator import VariableTruncator
|
||||
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
|
||||
|
||||
NodeExecutionId = NewType("NodeExecutionId", str)
|
||||
|
||||
@@ -70,6 +70,8 @@ class _NodeSnapshot:
|
||||
|
||||
|
||||
class WorkflowResponseConverter:
|
||||
_truncator: BaseTruncator
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -81,7 +83,13 @@ class WorkflowResponseConverter:
|
||||
self._user = user
|
||||
self._system_variables = system_variables
|
||||
self._workflow_inputs = self._prepare_workflow_inputs()
|
||||
self._truncator = VariableTruncator.default()
|
||||
|
||||
# Disable truncation for SERVICE_API calls to keep backward compatibility.
|
||||
if application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
|
||||
self._truncator = DummyVariableTruncator()
|
||||
else:
|
||||
self._truncator = VariableTruncator.default()
|
||||
|
||||
self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {}
|
||||
self._workflow_execution_id: str | None = None
|
||||
self._workflow_started_at: datetime | None = None
|
||||
|
||||
@@ -190,7 +190,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
conversation_id=conversation.id,
|
||||
inputs=application_generate_entity.inputs,
|
||||
query=application_generate_entity.query or "",
|
||||
query=application_generate_entity.query,
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
|
||||
@@ -40,19 +40,15 @@ from core.workflow.repositories.draft_variable_repository import DraftVariableSa
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
|
||||
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -255,7 +251,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
json_text = json.dumps(text)
|
||||
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
if features.billing.enabled and features.billing.subscription.plan == "sandbox":
|
||||
if features.billing.enabled and features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
|
||||
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
|
||||
|
||||
|
||||
@@ -135,6 +135,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
generator = workflow_entry.run()
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@@ -27,6 +27,7 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
@@ -69,10 +70,12 @@ class WorkflowBasedAppRunner:
|
||||
queue_manager: AppQueueManager,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
app_id: str,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
):
|
||||
self._queue_manager = queue_manager
|
||||
self._variable_loader = variable_loader
|
||||
self._app_id = app_id
|
||||
self._graph_engine_layers = graph_engine_layers
|
||||
|
||||
def _init_graph(
|
||||
self,
|
||||
|
||||
@@ -129,7 +129,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
app_config: EasyUIBasedAppConfig = None # type: ignore
|
||||
model_conf: ModelConfigWithCredentialsEntity
|
||||
|
||||
query: str | None = None
|
||||
query: str = ""
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@@ -48,6 +48,9 @@ class WorkflowTaskState(TaskState):
|
||||
"""
|
||||
|
||||
answer: str = ""
|
||||
first_token_time: float | None = None
|
||||
last_token_time: float | None = None
|
||||
is_streaming_response: bool = False
|
||||
|
||||
|
||||
class StreamEvent(StrEnum):
|
||||
|
||||
71
api/core/app/layers/pause_state_persist_layer.py
Normal file
71
api/core/app/layers/pause_state_persist_layer.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def __init__(self, session_factory: Engine | sessionmaker, state_owner_user_id: str):
|
||||
"""Create a PauseStatePersistenceLayer.
|
||||
|
||||
The `state_owner_user_id` is used when creating state file for pause.
|
||||
It generally should id of the creator of workflow.
|
||||
"""
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(session_factory)
|
||||
self._session_maker = session_factory
|
||||
self._state_owner_user_id = state_owner_user_id
|
||||
|
||||
def _get_repo(self) -> APIWorkflowRunRepository:
|
||||
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
|
||||
|
||||
def on_graph_start(self) -> None:
|
||||
"""
|
||||
Called when graph execution starts.
|
||||
|
||||
This is called after the engine has been initialized but before any nodes
|
||||
are executed. Layers can use this to set up resources or log start information.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
|
||||
This method receives all events generated during graph execution, including:
|
||||
- Graph lifecycle events (start, success, failure)
|
||||
- Node execution events (start, success, failure, retry)
|
||||
- Stream events for response nodes
|
||||
- Container events (iteration, loop)
|
||||
|
||||
Args:
|
||||
event: The event emitted by the engine
|
||||
"""
|
||||
if not isinstance(event, GraphRunPausedEvent):
|
||||
return
|
||||
|
||||
assert self.graph_runtime_state is not None
|
||||
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
|
||||
assert workflow_run_id is not None
|
||||
repo = self._get_repo()
|
||||
repo.create_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
state_owner_user_id=self._state_owner_user_id,
|
||||
state=self.graph_runtime_state.dumps(),
|
||||
)
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when graph execution ends.
|
||||
|
||||
This is called after all nodes have been executed or when execution is
|
||||
aborted. Layers can use this to clean up resources or log final state.
|
||||
|
||||
Args:
|
||||
error: The exception that caused execution to fail, or None if successful
|
||||
"""
|
||||
pass
|
||||
@@ -121,7 +121,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
||||
)
|
||||
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
|
||||
@@ -140,7 +140,27 @@ class MessageCycleManager:
|
||||
if not self._application_generate_entity.app_config.additional_features:
|
||||
raise ValueError("Additional features not found")
|
||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
self._task_state.metadata.retriever_resources = event.retriever_resources
|
||||
merged_resources = [r for r in self._task_state.metadata.retriever_resources or [] if r]
|
||||
existing_ids = {(r.dataset_id, r.document_id) for r in merged_resources if r.dataset_id and r.document_id}
|
||||
|
||||
# Add new unique resources from the event
|
||||
for resource in event.retriever_resources or []:
|
||||
if not resource:
|
||||
continue
|
||||
|
||||
is_duplicate = (
|
||||
resource.dataset_id
|
||||
and resource.document_id
|
||||
and (resource.dataset_id, resource.document_id) in existing_ids
|
||||
)
|
||||
|
||||
if not is_duplicate:
|
||||
merged_resources.append(resource)
|
||||
|
||||
for i, resource in enumerate(merged_resources, 1):
|
||||
resource.position = i
|
||||
|
||||
self._task_state.metadata.retriever_resources = merged_resources
|
||||
|
||||
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
|
||||
"""
|
||||
|
||||
13
api/core/entities/document_task.py
Normal file
13
api/core/entities/document_task.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentTask:
|
||||
"""Document task entity for document indexing operations.
|
||||
|
||||
This class represents a document indexing task that can be queued
|
||||
and processed by the document indexing system.
|
||||
"""
|
||||
tenant_id: str
|
||||
dataset_id: str
|
||||
document_ids: list[str]
|
||||
329
api/core/entities/mcp_provider.py
Normal file
329
api/core/entities/mcp_provider.py
Normal file
@@ -0,0 +1,329 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.tools import MCPToolProvider
|
||||
|
||||
# Constants
|
||||
CLIENT_NAME = "Dify"
|
||||
CLIENT_URI = "https://github.com/langgenius/dify"
|
||||
DEFAULT_TOKEN_TYPE = "Bearer"
|
||||
DEFAULT_EXPIRES_IN = 3600
|
||||
MASK_CHAR = "*"
|
||||
MIN_UNMASK_LENGTH = 6
|
||||
|
||||
|
||||
class MCPSupportGrantType(StrEnum):
|
||||
"""The supported grant types for MCP"""
|
||||
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
CLIENT_CREDENTIALS = "client_credentials"
|
||||
REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
|
||||
class MCPAuthentication(BaseModel):
|
||||
client_id: str
|
||||
client_secret: str | None = None
|
||||
|
||||
|
||||
class MCPConfiguration(BaseModel):
|
||||
timeout: float = 30
|
||||
sse_read_timeout: float = 300
|
||||
|
||||
|
||||
class MCPProviderEntity(BaseModel):
|
||||
"""MCP Provider domain entity for business logic operations"""
|
||||
|
||||
# Basic identification
|
||||
id: str
|
||||
provider_id: str # server_identifier
|
||||
name: str
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
|
||||
# Server connection info
|
||||
server_url: str # encrypted URL
|
||||
headers: dict[str, str] # encrypted headers
|
||||
timeout: float
|
||||
sse_read_timeout: float
|
||||
|
||||
# Authentication related
|
||||
authed: bool
|
||||
credentials: dict[str, Any] # encrypted credentials
|
||||
code_verifier: str | None = None # for OAuth
|
||||
|
||||
# Tools and display info
|
||||
tools: list[dict[str, Any]] # parsed tools list
|
||||
icon: str | dict[str, str] # parsed icon
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
|
||||
"""Create entity from database model with decryption"""
|
||||
|
||||
return cls(
|
||||
id=db_provider.id,
|
||||
provider_id=db_provider.server_identifier,
|
||||
name=db_provider.name,
|
||||
tenant_id=db_provider.tenant_id,
|
||||
user_id=db_provider.user_id,
|
||||
server_url=db_provider.server_url,
|
||||
headers=db_provider.headers,
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
authed=db_provider.authed,
|
||||
credentials=db_provider.credentials,
|
||||
tools=db_provider.tool_dict,
|
||||
icon=db_provider.icon or "",
|
||||
created_at=db_provider.created_at,
|
||||
updated_at=db_provider.updated_at,
|
||||
)
|
||||
|
||||
@property
|
||||
def redirect_url(self) -> str:
|
||||
"""OAuth redirect URL"""
|
||||
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
||||
|
||||
@property
|
||||
def client_metadata(self) -> OAuthClientMetadata:
|
||||
"""Metadata about this OAuth client."""
|
||||
# Get grant type from credentials
|
||||
credentials = self.decrypt_credentials()
|
||||
|
||||
# Try to get grant_type from different locations
|
||||
grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE)
|
||||
|
||||
# For nested structure, check if client_information has grant_types
|
||||
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
|
||||
client_info = credentials["client_information"]
|
||||
# If grant_types is specified in client_information, use it to determine grant_type
|
||||
if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
|
||||
if "client_credentials" in client_info["grant_types"]:
|
||||
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS
|
||||
elif "authorization_code" in client_info["grant_types"]:
|
||||
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE
|
||||
|
||||
# Configure based on grant type
|
||||
is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS
|
||||
|
||||
grant_types = ["refresh_token"]
|
||||
grant_types.append("client_credentials" if is_client_credentials else "authorization_code")
|
||||
|
||||
response_types = [] if is_client_credentials else ["code"]
|
||||
redirect_uris = [] if is_client_credentials else [self.redirect_url]
|
||||
|
||||
return OAuthClientMetadata(
|
||||
redirect_uris=redirect_uris,
|
||||
token_endpoint_auth_method="none",
|
||||
grant_types=grant_types,
|
||||
response_types=response_types,
|
||||
client_name=CLIENT_NAME,
|
||||
client_uri=CLIENT_URI,
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_icon(self) -> dict[str, str] | str:
|
||||
"""Get provider icon, handling both dict and string formats"""
|
||||
if isinstance(self.icon, dict):
|
||||
return self.icon
|
||||
try:
|
||||
return json.loads(self.icon)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If not JSON, assume it's a file path
|
||||
return file_helpers.get_signed_file_url(self.icon)
|
||||
|
||||
def to_api_response(self, user_name: str | None = None, include_sensitive: bool = True) -> dict[str, Any]:
|
||||
"""Convert to API response format
|
||||
|
||||
Args:
|
||||
user_name: User name to display
|
||||
include_sensitive: If False, skip expensive decryption operations (for list view optimization)
|
||||
"""
|
||||
response = {
|
||||
"id": self.id,
|
||||
"author": user_name or "Anonymous",
|
||||
"name": self.name,
|
||||
"icon": self.provider_icon,
|
||||
"type": ToolProviderType.MCP.value,
|
||||
"is_team_authorization": self.authed,
|
||||
"server_url": self.masked_server_url(),
|
||||
"server_identifier": self.provider_id,
|
||||
"updated_at": int(self.updated_at.timestamp()),
|
||||
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
|
||||
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
|
||||
}
|
||||
|
||||
# Add configuration
|
||||
response["configuration"] = {
|
||||
"timeout": str(self.timeout),
|
||||
"sse_read_timeout": str(self.sse_read_timeout),
|
||||
}
|
||||
|
||||
# Skip expensive operations when sensitive data is not needed (e.g., list view)
|
||||
if not include_sensitive:
|
||||
response["masked_headers"] = {}
|
||||
response["is_dynamic_registration"] = True
|
||||
else:
|
||||
# Add masked headers
|
||||
response["masked_headers"] = self.masked_headers()
|
||||
|
||||
# Add authentication info if available
|
||||
masked_creds = self.masked_credentials()
|
||||
if masked_creds:
|
||||
response["authentication"] = masked_creds
|
||||
response["is_dynamic_registration"] = self.credentials.get("client_information", {}).get(
|
||||
"is_dynamic_registration", True
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def retrieve_client_information(self) -> OAuthClientInformation | None:
|
||||
"""OAuth client information if available"""
|
||||
credentials = self.decrypt_credentials()
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
# Check if we have nested client_information structure
|
||||
if "client_information" not in credentials:
|
||||
return None
|
||||
client_info_data = credentials["client_information"]
|
||||
if isinstance(client_info_data, dict):
|
||||
if "encrypted_client_secret" in client_info_data:
|
||||
client_info_data["client_secret"] = encrypter.decrypt_token(
|
||||
self.tenant_id, client_info_data["encrypted_client_secret"]
|
||||
)
|
||||
return OAuthClientInformation.model_validate(client_info_data)
|
||||
return None
|
||||
|
||||
def retrieve_tokens(self) -> OAuthTokens | None:
|
||||
"""OAuth tokens if available"""
|
||||
if not self.credentials:
|
||||
return None
|
||||
credentials = self.decrypt_credentials()
|
||||
return OAuthTokens(
|
||||
access_token=credentials.get("access_token", ""),
|
||||
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
|
||||
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
|
||||
refresh_token=credentials.get("refresh_token", ""),
|
||||
)
|
||||
|
||||
def masked_server_url(self) -> str:
|
||||
"""Masked server URL for display"""
|
||||
parsed = urlparse(self.decrypt_server_url())
|
||||
if parsed.path and parsed.path != "/":
|
||||
masked = parsed._replace(path="/******")
|
||||
return masked.geturl()
|
||||
return parsed.geturl()
|
||||
|
||||
def _mask_value(self, value: str) -> str:
|
||||
"""Mask a sensitive value for display"""
|
||||
if len(value) > MIN_UNMASK_LENGTH:
|
||||
return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:]
|
||||
else:
|
||||
return MASK_CHAR * len(value)
|
||||
|
||||
def masked_headers(self) -> dict[str, str]:
|
||||
"""Masked headers for display"""
|
||||
return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()}
|
||||
|
||||
def masked_credentials(self) -> dict[str, str]:
|
||||
"""Masked credentials for display"""
|
||||
credentials = self.decrypt_credentials()
|
||||
if not credentials:
|
||||
return {}
|
||||
|
||||
masked = {}
|
||||
|
||||
if "client_information" not in credentials or not isinstance(credentials["client_information"], dict):
|
||||
return {}
|
||||
client_info = credentials["client_information"]
|
||||
# Mask sensitive fields from nested structure
|
||||
if client_info.get("client_id"):
|
||||
masked["client_id"] = self._mask_value(client_info["client_id"])
|
||||
if client_info.get("encrypted_client_secret"):
|
||||
masked["client_secret"] = self._mask_value(
|
||||
encrypter.decrypt_token(self.tenant_id, client_info["encrypted_client_secret"])
|
||||
)
|
||||
if client_info.get("client_secret"):
|
||||
masked["client_secret"] = self._mask_value(client_info["client_secret"])
|
||||
return masked
|
||||
|
||||
def decrypt_server_url(self) -> str:
|
||||
"""Decrypt server URL"""
|
||||
return encrypter.decrypt_token(self.tenant_id, self.server_url)
|
||||
|
||||
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Generic method to decrypt dictionary fields"""
|
||||
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
# Only decrypt fields that are actually encrypted
|
||||
# For nested structures, client_information is not encrypted as a whole
|
||||
encrypted_fields = []
|
||||
for key, value in data.items():
|
||||
# Skip nested objects - they are not encrypted
|
||||
if isinstance(value, dict):
|
||||
continue
|
||||
# Only process string values that might be encrypted
|
||||
if isinstance(value, str) and value:
|
||||
encrypted_fields.append(key)
|
||||
|
||||
if not encrypted_fields:
|
||||
return data
|
||||
|
||||
# Create dynamic config only for encrypted fields
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
# Decrypt only the encrypted fields
|
||||
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
|
||||
|
||||
# Merge decrypted data with original data (preserving non-encrypted fields)
|
||||
result = data.copy()
|
||||
result.update(decrypted_data)
|
||||
|
||||
return result
|
||||
|
||||
def decrypt_headers(self) -> dict[str, Any]:
|
||||
"""Decrypt headers"""
|
||||
return self._decrypt_dict(self.headers)
|
||||
|
||||
def decrypt_credentials(self) -> dict[str, Any]:
|
||||
"""Decrypt credentials"""
|
||||
return self._decrypt_dict(self.credentials)
|
||||
|
||||
def decrypt_authentication(self) -> dict[str, Any]:
|
||||
"""Decrypt authentication"""
|
||||
# Option 1: if headers is provided, use it and don't need to get token
|
||||
headers = self.decrypt_headers()
|
||||
|
||||
# Option 2: Add OAuth token if authed and no headers provided
|
||||
if not self.headers and self.authed:
|
||||
token = self.retrieve_tokens()
|
||||
if token:
|
||||
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||
return headers
|
||||
@@ -74,6 +74,10 @@ class File(BaseModel):
|
||||
storage_key: str | None = None,
|
||||
dify_model_identity: str | None = FILE_MODEL_IDENTITY,
|
||||
url: str | None = None,
|
||||
# Legacy compatibility fields - explicitly handle known extra fields
|
||||
tool_file_id: str | None = None,
|
||||
upload_file_id: str | None = None,
|
||||
datasource_file_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
|
||||
@@ -6,10 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
class NodeJsTemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(
|
||||
f"""
|
||||
// declare main function
|
||||
{cls._code_placeholder}
|
||||
runner_script = dedent(f""" {cls._code_placeholder}
|
||||
|
||||
// decode and prepare input object
|
||||
var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
|
||||
@@ -21,6 +18,5 @@ class NodeJsTemplateTransformer(TemplateTransformer):
|
||||
var output_json = JSON.stringify(output_obj)
|
||||
var result = `<<RESULT>>${{output_json}}<<RESULT>>`
|
||||
console.log(result)
|
||||
"""
|
||||
)
|
||||
""")
|
||||
return runner_script
|
||||
|
||||
@@ -6,9 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
class Python3TemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f"""
|
||||
# declare main function
|
||||
{cls._code_placeholder}
|
||||
runner_script = dedent(f""" {cls._code_placeholder}
|
||||
|
||||
import json
|
||||
from base64 import b64decode
|
||||
|
||||
@@ -29,6 +29,18 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
|
||||
return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
|
||||
|
||||
def batch_fetch_plugin_by_ids(plugin_ids: list[str]) -> list[dict]:
|
||||
if not plugin_ids:
|
||||
return []
|
||||
|
||||
url = str(marketplace_api_url / "api/v1/plugins/batch")
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
return data.get("data", {}).get("plugins", [])
|
||||
|
||||
|
||||
def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
||||
plugin_ids: list[str],
|
||||
) -> Sequence[MarketplacePluginDeclaration]:
|
||||
|
||||
@@ -56,6 +56,9 @@ class HostingConfiguration:
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek()
|
||||
|
||||
self.moderation_config = self.init_moderation_config()
|
||||
|
||||
@@ -128,7 +131,7 @@ class HostingConfiguration:
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
|
||||
hosted_quota_limit = 0
|
||||
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
|
||||
quotas.append(trial_quota)
|
||||
@@ -156,18 +159,49 @@ class HostingConfiguration:
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_anthropic() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
def init_gemini(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_GEMINI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_GEMINI_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
|
||||
}
|
||||
|
||||
if dify_config.HOSTED_GEMINI_API_BASE:
|
||||
credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_anthropic(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
|
||||
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
|
||||
paid_quota = PaidHostingQuota()
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
@@ -185,6 +219,66 @@ class HostingConfiguration:
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_xai(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_XAI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_XAI_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"api_key": dify_config.HOSTED_XAI_API_KEY,
|
||||
}
|
||||
|
||||
if dify_config.HOSTED_XAI_API_BASE:
|
||||
credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_deepseek(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
|
||||
}
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_API_BASE:
|
||||
credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_minimax() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
|
||||
@@ -415,7 +415,6 @@ class IndexingRunner:
|
||||
document_id=dataset_document.id,
|
||||
after_indexing_status="splitting",
|
||||
extra_update_params={
|
||||
DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
|
||||
DatasetDocument.parsing_completed_at: naive_utc_now(),
|
||||
},
|
||||
)
|
||||
@@ -755,6 +754,7 @@ class IndexingRunner:
|
||||
extra_update_params={
|
||||
DatasetDocument.cleaning_completed_at: cur_time,
|
||||
DatasetDocument.splitting_completed_at: cur_time,
|
||||
DatasetDocument.word_count: sum(len(doc.page_content) for doc in documents),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -14,10 +14,12 @@ from core.llm_generator.prompts import (
|
||||
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
LLM_MODIFY_CODE_SYSTEM,
|
||||
LLM_MODIFY_PROMPT_SYSTEM,
|
||||
MEMORY_UPDATE_PROMPT,
|
||||
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
|
||||
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||
)
|
||||
from core.memory.entities import MemoryBlock, MemoryBlockSpec
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
@@ -28,6 +30,7 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models import App, Message, WorkflowNodeExecutionModel
|
||||
@@ -562,3 +565,35 @@ class LLMGenerator:
|
||||
"Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
|
||||
)
|
||||
return {"error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
def update_memory_block(
|
||||
tenant_id: str,
|
||||
visible_history: Sequence[tuple[str, str]],
|
||||
variable_pool: VariablePool,
|
||||
memory_block: MemoryBlock,
|
||||
memory_spec: MemoryBlockSpec,
|
||||
) -> str:
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=memory_spec.model.provider,
|
||||
model=memory_spec.model.name,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
formatted_history = ""
|
||||
for sender, message in visible_history:
|
||||
formatted_history += f"{sender}: {message}\n"
|
||||
filled_instruction = variable_pool.convert_template(memory_spec.instruction).text
|
||||
formatted_prompt = PromptTemplateParser(MEMORY_UPDATE_PROMPT).format(
|
||||
inputs={
|
||||
"formatted_history": formatted_history,
|
||||
"current_value": memory_block.value,
|
||||
"instruction": filled_instruction,
|
||||
}
|
||||
)
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=[UserPromptMessage(content=formatted_prompt)],
|
||||
model_parameters=memory_spec.model.completion_params,
|
||||
stream=False,
|
||||
)
|
||||
return llm_result.message.get_text_content()
|
||||
|
||||
@@ -422,3 +422,18 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
|
||||
You should edit the prompt according to the IDEAL OUTPUT."""
|
||||
|
||||
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
|
||||
|
||||
MEMORY_UPDATE_PROMPT = """
|
||||
Based on the following conversation history, update the memory content:
|
||||
|
||||
Conversation history:
|
||||
{{formatted_history}}
|
||||
|
||||
Current memory:
|
||||
{{current_value}}
|
||||
|
||||
Update instruction:
|
||||
{{instruction}}
|
||||
|
||||
Please output only the updated memory content, no other text like greeting:
|
||||
"""
|
||||
|
||||
@@ -6,11 +6,15 @@ import secrets
|
||||
import urllib.parse
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from httpx import ConnectError, HTTPStatusError, RequestError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
|
||||
from core.helper import ssrf_proxy
|
||||
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
|
||||
from core.mcp.error import MCPRefreshTokenError
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
@@ -19,21 +23,10 @@ from core.mcp.types import (
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||
|
||||
|
||||
class OAuthCallbackState(BaseModel):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
server_url: str
|
||||
metadata: OAuthMetadata | None = None
|
||||
client_information: OAuthClientInformation
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
def generate_pkce_challenge() -> tuple[str, str]:
|
||||
"""Generate PKCE challenge and verifier."""
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
|
||||
@@ -80,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
||||
raise ValueError(f"Invalid state parameter: {str(e)}")
|
||||
|
||||
|
||||
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
|
||||
"""Handle the callback from the OAuth provider."""
|
||||
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
|
||||
"""
|
||||
Handle the callback from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
A tuple of (callback_state, tokens) that can be used by the caller to save data.
|
||||
"""
|
||||
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
||||
full_state_data = _retrieve_redis_state(state_key)
|
||||
|
||||
@@ -93,30 +91,32 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
|
||||
full_state_data.code_verifier,
|
||||
full_state_data.redirect_uri,
|
||||
)
|
||||
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
|
||||
provider.save_tokens(tokens)
|
||||
return full_state_data
|
||||
|
||||
return full_state_data, tokens
|
||||
|
||||
|
||||
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
"""Check if the server supports OAuth 2.0 Resource Discovery."""
|
||||
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
|
||||
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
|
||||
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
|
||||
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
|
||||
if b_query:
|
||||
url_for_resource_discovery += f"?{b_query}"
|
||||
if b_fragment:
|
||||
url_for_resource_discovery += f"#{b_fragment}"
|
||||
try:
|
||||
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
|
||||
response = httpx.get(url_for_resource_discovery, headers=headers)
|
||||
response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
|
||||
if 200 <= response.status_code < 300:
|
||||
body = response.json()
|
||||
if "authorization_server_url" in body:
|
||||
# Support both singular and plural forms
|
||||
if body.get("authorization_servers"):
|
||||
return True, body["authorization_servers"][0]
|
||||
elif body.get("authorization_server_url"):
|
||||
return True, body["authorization_server_url"][0]
|
||||
else:
|
||||
return False, ""
|
||||
return False, ""
|
||||
except httpx.RequestError:
|
||||
except RequestError:
|
||||
# Not support resource discovery, fall back to well-known OAuth metadata
|
||||
return False, ""
|
||||
|
||||
@@ -126,27 +126,37 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
|
||||
# First check if the server supports OAuth 2.0 Resource Discovery
|
||||
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
|
||||
if support_resource_discovery:
|
||||
url = oauth_discovery_url
|
||||
# The oauth_discovery_url is the authorization server base URL
|
||||
# Try OpenID Connect discovery first (more common), then OAuth 2.0
|
||||
urls_to_try = [
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
|
||||
]
|
||||
else:
|
||||
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
||||
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
|
||||
|
||||
try:
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
response = httpx.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
except httpx.RequestError as e:
|
||||
if isinstance(e, httpx.ConnectError):
|
||||
response = httpx.get(url)
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
|
||||
for url in urls_to_try:
|
||||
try:
|
||||
response = ssrf_proxy.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
continue
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
response.raise_for_status()
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
raise
|
||||
except (RequestError, HTTPStatusError) as e:
|
||||
if isinstance(e, ConnectError):
|
||||
response = ssrf_proxy.get(url)
|
||||
if response.status_code == 404:
|
||||
continue # Try next URL
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
# For other errors, try next URL
|
||||
continue
|
||||
|
||||
return None # No metadata found
|
||||
|
||||
|
||||
def start_authorization(
|
||||
@@ -213,7 +223,7 @@ def exchange_authorization(
|
||||
redirect_uri: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchanges an authorization code for an access token."""
|
||||
grant_type = "authorization_code"
|
||||
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
@@ -233,7 +243,7 @@ def exchange_authorization(
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = httpx.post(token_url, data=params)
|
||||
response = ssrf_proxy.post(token_url, data=params)
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
@@ -246,7 +256,7 @@ def refresh_authorization(
|
||||
refresh_token: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange a refresh token for an updated access token."""
|
||||
grant_type = "refresh_token"
|
||||
grant_type = MCPSupportGrantType.REFRESH_TOKEN.value
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
@@ -263,10 +273,55 @@ def refresh_authorization(
|
||||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = httpx.post(token_url, data=params)
|
||||
try:
|
||||
response = ssrf_proxy.post(token_url, data=params)
|
||||
except ssrf_proxy.MaxRetriesExceededError as e:
|
||||
raise MCPRefreshTokenError(e) from e
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
|
||||
raise MCPRefreshTokenError(response.text)
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
|
||||
def client_credentials_flow(
|
||||
server_url: str,
|
||||
metadata: OAuthMetadata | None,
|
||||
client_information: OAuthClientInformation,
|
||||
scope: str | None = None,
|
||||
) -> OAuthTokens:
|
||||
"""Execute Client Credentials Flow to get access token."""
|
||||
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
# Support both Basic Auth and body parameters for client authentication
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
data = {"grant_type": grant_type}
|
||||
|
||||
if scope:
|
||||
data["scope"] = scope
|
||||
|
||||
# If client_secret is provided, use Basic Auth (preferred method)
|
||||
if client_information.client_secret:
|
||||
credentials = f"{client_information.client_id}:{client_information.client_secret}"
|
||||
encoded_credentials = base64.b64encode(credentials.encode()).decode()
|
||||
headers["Authorization"] = f"Basic {encoded_credentials}"
|
||||
else:
|
||||
# Fall back to including credentials in the body
|
||||
data["client_id"] = client_information.client_id
|
||||
if client_information.client_secret:
|
||||
data["client_secret"] = client_information.client_secret
|
||||
|
||||
response = ssrf_proxy.post(token_url, headers=headers, data=data)
|
||||
if not response.is_success:
|
||||
raise ValueError(
|
||||
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
|
||||
)
|
||||
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
|
||||
@@ -283,7 +338,7 @@ def register_client(
|
||||
else:
|
||||
registration_url = urljoin(server_url, "/register")
|
||||
|
||||
response = httpx.post(
|
||||
response = ssrf_proxy.post(
|
||||
registration_url,
|
||||
json=client_metadata.model_dump(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
@@ -294,28 +349,111 @@ def register_client(
|
||||
|
||||
|
||||
def auth(
|
||||
provider: OAuthClientProvider,
|
||||
server_url: str,
|
||||
provider: MCPProviderEntity,
|
||||
authorization_code: str | None = None,
|
||||
state_param: str | None = None,
|
||||
for_list: bool = False,
|
||||
) -> dict[str, str]:
|
||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
||||
metadata = discover_oauth_metadata(server_url)
|
||||
) -> AuthResult:
|
||||
"""
|
||||
Orchestrates the full auth flow with a server using secure Redis state storage.
|
||||
|
||||
This function performs only network operations and returns actions that need
|
||||
to be performed by the caller (such as saving data to database).
|
||||
|
||||
Args:
|
||||
provider: The MCP provider entity
|
||||
authorization_code: Optional authorization code from OAuth callback
|
||||
state_param: Optional state parameter from OAuth callback
|
||||
|
||||
Returns:
|
||||
AuthResult containing actions to be performed and response data
|
||||
"""
|
||||
actions: list[AuthAction] = []
|
||||
server_url = provider.decrypt_server_url()
|
||||
server_metadata = discover_oauth_metadata(server_url)
|
||||
client_metadata = provider.client_metadata
|
||||
provider_id = provider.id
|
||||
tenant_id = provider.tenant_id
|
||||
client_information = provider.retrieve_client_information()
|
||||
redirect_url = provider.redirect_url
|
||||
|
||||
# Determine grant type based on server metadata
|
||||
if not server_metadata:
|
||||
raise ValueError("Failed to discover OAuth metadata from server")
|
||||
|
||||
supported_grant_types = server_metadata.grant_types_supported or []
|
||||
|
||||
# Convert to lowercase for comparison
|
||||
supported_grant_types_lower = [gt.lower() for gt in supported_grant_types]
|
||||
|
||||
# Determine which grant type to use
|
||||
effective_grant_type = None
|
||||
if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
|
||||
effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
|
||||
else:
|
||||
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
# Get stored credentials
|
||||
credentials = provider.decrypt_credentials()
|
||||
|
||||
# Handle client registration if needed
|
||||
client_information = provider.client_information()
|
||||
if not client_information:
|
||||
if authorization_code is not None:
|
||||
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
||||
|
||||
# For client credentials flow, we don't need to register client dynamically
|
||||
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
|
||||
# Client should provide client_id and client_secret directly
|
||||
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
|
||||
|
||||
try:
|
||||
full_information = register_client(server_url, metadata, provider.client_metadata)
|
||||
except httpx.RequestError as e:
|
||||
full_information = register_client(server_url, server_metadata, client_metadata)
|
||||
except RequestError as e:
|
||||
raise ValueError(f"Could not register OAuth client: {e}")
|
||||
provider.save_client_information(full_information)
|
||||
|
||||
# Return action to save client information
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_CLIENT_INFO,
|
||||
data={"client_information": full_information.model_dump()},
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
client_information = full_information
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
# Handle client credentials flow
|
||||
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
|
||||
# Direct token request without user interaction
|
||||
try:
|
||||
scope = credentials.get("scope")
|
||||
tokens = client_credentials_flow(
|
||||
server_url,
|
||||
server_metadata,
|
||||
client_information,
|
||||
scope,
|
||||
)
|
||||
|
||||
# Return action to save tokens and grant type
|
||||
token_data = tokens.model_dump()
|
||||
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_TOKENS,
|
||||
data=token_data,
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"result": "success"})
|
||||
except (RequestError, ValueError, KeyError) as e:
|
||||
# RequestError: HTTP request failed
|
||||
# ValueError: Invalid response data
|
||||
# KeyError: Missing required fields in response
|
||||
raise ValueError(f"Client credentials flow failed: {e}")
|
||||
|
||||
# Exchange authorization code for tokens (Authorization Code flow)
|
||||
if authorization_code is not None:
|
||||
if not state_param:
|
||||
raise ValueError("State parameter is required when exchanging authorization code")
|
||||
@@ -335,35 +473,69 @@ def auth(
|
||||
|
||||
tokens = exchange_authorization(
|
||||
server_url,
|
||||
metadata,
|
||||
server_metadata,
|
||||
client_information,
|
||||
authorization_code,
|
||||
code_verifier,
|
||||
redirect_uri,
|
||||
)
|
||||
provider.save_tokens(tokens)
|
||||
return {"result": "success"}
|
||||
|
||||
provider_tokens = provider.tokens()
|
||||
# Return action to save tokens
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_TOKENS,
|
||||
data=tokens.model_dump(),
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"result": "success"})
|
||||
|
||||
provider_tokens = provider.retrieve_tokens()
|
||||
|
||||
# Handle token refresh or new authorization
|
||||
if provider_tokens and provider_tokens.refresh_token:
|
||||
try:
|
||||
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
|
||||
provider.save_tokens(new_tokens)
|
||||
return {"result": "success"}
|
||||
except Exception as e:
|
||||
new_tokens = refresh_authorization(
|
||||
server_url, server_metadata, client_information, provider_tokens.refresh_token
|
||||
)
|
||||
|
||||
# Return action to save new tokens
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_TOKENS,
|
||||
data=new_tokens.model_dump(),
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"result": "success"})
|
||||
except (RequestError, ValueError, KeyError) as e:
|
||||
# RequestError: HTTP request failed
|
||||
# ValueError: Invalid response data
|
||||
# KeyError: Missing required fields in response
|
||||
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
||||
|
||||
# Start new authorization flow
|
||||
# Start new authorization flow (only for authorization code flow)
|
||||
authorization_url, code_verifier = start_authorization(
|
||||
server_url,
|
||||
metadata,
|
||||
server_metadata,
|
||||
client_information,
|
||||
provider.redirect_url,
|
||||
provider.mcp_provider.id,
|
||||
provider.mcp_provider.tenant_id,
|
||||
redirect_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
provider.save_code_verifier(code_verifier)
|
||||
return {"authorization_url": authorization_url}
|
||||
# Return action to save code verifier
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_CODE_VERIFIER,
|
||||
data={"code_verifier": code_verifier},
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"authorization_url": authorization_url})
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
from configs import dify_config
|
||||
from core.mcp.types import (
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
|
||||
class OAuthClientProvider:
|
||||
mcp_provider: MCPToolProvider
|
||||
|
||||
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
|
||||
if for_list:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
else:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
|
||||
|
||||
@property
|
||||
def redirect_url(self) -> str:
|
||||
"""The URL to redirect the user agent to after authorization."""
|
||||
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
||||
|
||||
@property
|
||||
def client_metadata(self) -> OAuthClientMetadata:
|
||||
"""Metadata about this OAuth client."""
|
||||
return OAuthClientMetadata(
|
||||
redirect_uris=[self.redirect_url],
|
||||
token_endpoint_auth_method="none",
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
client_name="Dify",
|
||||
client_uri="https://github.com/langgenius/dify",
|
||||
)
|
||||
|
||||
def client_information(self) -> OAuthClientInformation | None:
|
||||
"""Loads information about this OAuth client."""
|
||||
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
|
||||
if not client_information:
|
||||
return None
|
||||
return OAuthClientInformation.model_validate(client_information)
|
||||
|
||||
def save_client_information(self, client_information: OAuthClientInformationFull):
|
||||
"""Saves client information after dynamic registration."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
self.mcp_provider,
|
||||
{"client_information": client_information.model_dump()},
|
||||
)
|
||||
|
||||
def tokens(self) -> OAuthTokens | None:
|
||||
"""Loads any existing OAuth tokens for the current session."""
|
||||
credentials = self.mcp_provider.decrypted_credentials
|
||||
if not credentials:
|
||||
return None
|
||||
return OAuthTokens(
|
||||
access_token=credentials.get("access_token", ""),
|
||||
token_type=credentials.get("token_type", "Bearer"),
|
||||
expires_in=int(credentials.get("expires_in", "3600") or 3600),
|
||||
refresh_token=credentials.get("refresh_token", ""),
|
||||
)
|
||||
|
||||
def save_tokens(self, tokens: OAuthTokens):
|
||||
"""Stores new OAuth tokens for the current session."""
|
||||
# update mcp provider credentials
|
||||
token_dict = tokens.model_dump()
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
||||
|
||||
def save_code_verifier(self, code_verifier: str):
|
||||
"""Saves a PKCE code verifier for the current session."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
||||
|
||||
def code_verifier(self) -> str:
|
||||
"""Loads the PKCE code verifier for the current session."""
|
||||
# get code verifier from mcp provider credentials
|
||||
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
|
||||
191
api/core/mcp/auth_client.py
Normal file
191
api/core/mcp/auth_client.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
MCP Client with Authentication Retry Support
|
||||
|
||||
This module provides an enhanced MCPClient that automatically handles
|
||||
authentication failures and retries operations after refreshing tokens.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.error import MCPAuthError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import CallToolResult, Tool
|
||||
from extensions.ext_database import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClientWithAuthRetry(MCPClient):
|
||||
"""
|
||||
An enhanced MCPClient that provides automatic authentication retry.
|
||||
|
||||
This class extends MCPClient and intercepts MCPAuthError exceptions
|
||||
to refresh authentication before retrying failed operations.
|
||||
|
||||
Note: This class uses lazy session creation - database sessions are only
|
||||
created when authentication retry is actually needed, not on every request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
provider_entity: MCPProviderEntity | None = None,
|
||||
authorization_code: str | None = None,
|
||||
by_server_id: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the MCP client with auth retry capability.
|
||||
|
||||
Args:
|
||||
server_url: The MCP server URL
|
||||
headers: Optional headers for requests
|
||||
timeout: Request timeout
|
||||
sse_read_timeout: SSE read timeout
|
||||
provider_entity: Provider entity for authentication
|
||||
authorization_code: Optional authorization code for initial auth
|
||||
by_server_id: Whether to look up provider by server ID
|
||||
"""
|
||||
super().__init__(server_url, headers, timeout, sse_read_timeout)
|
||||
|
||||
self.provider_entity = provider_entity
|
||||
self.authorization_code = authorization_code
|
||||
self.by_server_id = by_server_id
|
||||
self._has_retried = False
|
||||
|
||||
def _handle_auth_error(self, error: MCPAuthError) -> None:
|
||||
"""
|
||||
Handle authentication error by refreshing tokens.
|
||||
|
||||
This method creates a short-lived database session only when authentication
|
||||
retry is needed, minimizing database connection hold time.
|
||||
|
||||
Args:
|
||||
error: The authentication error
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails or max retries reached
|
||||
"""
|
||||
if not self.provider_entity:
|
||||
raise error
|
||||
if self._has_retried:
|
||||
raise error
|
||||
|
||||
self._has_retried = True
|
||||
|
||||
try:
|
||||
# Create a temporary session only for auth retry
|
||||
# This session is short-lived and only exists during the auth operation
|
||||
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
with Session(db.engine) as session, session.begin():
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
|
||||
# Perform authentication using the service's auth method
|
||||
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
|
||||
|
||||
# Retrieve new tokens
|
||||
self.provider_entity = mcp_service.get_provider_entity(
|
||||
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
|
||||
)
|
||||
|
||||
# Session is closed here, before we update headers
|
||||
token = self.provider_entity.retrieve_tokens()
|
||||
if not token:
|
||||
raise MCPAuthError("Authentication failed - no token received")
|
||||
|
||||
# Update headers with new token
|
||||
self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||
|
||||
# Clear authorization code after first use
|
||||
self.authorization_code = None
|
||||
|
||||
except MCPAuthError:
|
||||
# Re-raise MCPAuthError as is
|
||||
raise
|
||||
except Exception as e:
|
||||
# Catch all exceptions during auth retry
|
||||
logger.exception("Authentication retry failed")
|
||||
raise MCPAuthError(f"Authentication retry failed: {e}") from e
|
||||
|
||||
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
|
||||
"""
|
||||
Execute a function with authentication retry logic.
|
||||
|
||||
Args:
|
||||
func: The function to execute
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
The result of the function call
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails after retries
|
||||
Any other exceptions from the function
|
||||
"""
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except MCPAuthError as e:
|
||||
self._handle_auth_error(e)
|
||||
|
||||
# Re-initialize the connection with new headers
|
||||
if self._initialized:
|
||||
# Clean up existing connection
|
||||
self._exit_stack.close()
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
|
||||
# Re-initialize with new headers
|
||||
self._initialize()
|
||||
self._initialized = True
|
||||
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
# Reset retry flag after operation completes
|
||||
self._has_retried = False
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter the context manager with retry support."""
|
||||
|
||||
def initialize_with_retry():
|
||||
super(MCPClientWithAuthRetry, self).__enter__()
|
||||
return self
|
||||
|
||||
return self._execute_with_retry(initialize_with_retry)
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""
|
||||
List available tools from the MCP server with auth retry.
|
||||
|
||||
Returns:
|
||||
List of available tools
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails after retries
|
||||
"""
|
||||
return self._execute_with_retry(super().list_tools)
|
||||
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||
"""
|
||||
Invoke a tool on the MCP server with auth retry.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to invoke
|
||||
tool_args: Arguments for the tool
|
||||
|
||||
Returns:
|
||||
Result of the tool invocation
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails after retries
|
||||
"""
|
||||
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)
|
||||
0
api/core/mcp/auth_client_comparison.md
Normal file
0
api/core/mcp/auth_client_comparison.md
Normal file
@@ -46,7 +46,7 @@ class SSETransport:
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
sse_read_timeout: float = 1 * 60,
|
||||
):
|
||||
"""Initialize the SSE transport.
|
||||
|
||||
@@ -255,7 +255,7 @@ def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
sse_read_timeout: float = 1 * 60,
|
||||
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
|
||||
"""
|
||||
Client transport for SSE.
|
||||
@@ -276,31 +276,34 @@ def sse_client(
|
||||
read_queue: ReadQueue | None = None
|
||||
write_queue: WriteQueue | None = None
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
executor = ThreadPoolExecutor()
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
|
||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||
|
||||
yield read_queue, write_queue
|
||||
yield read_queue, write_queue
|
||||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise
|
||||
finally:
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
if write_queue:
|
||||
write_queue.put(None)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise
|
||||
finally:
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
if write_queue:
|
||||
write_queue.put(None)
|
||||
|
||||
# Shutdown executor without waiting to prevent hanging
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
|
||||
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
|
||||
|
||||
@@ -434,45 +434,48 @@ def streamablehttp_client(
|
||||
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
|
||||
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
|
||||
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream():
|
||||
"""Start a worker thread to handle server-initiated messages."""
|
||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||
executor = ThreadPoolExecutor(max_workers=2)
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream():
|
||||
"""Start a worker thread to handle server-initiated messages."""
|
||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||
|
||||
# Start the post_writer worker thread
|
||||
executor.submit(
|
||||
transport.post_writer,
|
||||
client,
|
||||
client_to_server_queue, # Queue for messages FROM client TO server
|
||||
server_to_client_queue, # Queue for messages FROM server TO client
|
||||
start_get_stream,
|
||||
)
|
||||
# Start the post_writer worker thread
|
||||
executor.submit(
|
||||
transport.post_writer,
|
||||
client,
|
||||
client_to_server_queue, # Queue for messages FROM client TO server
|
||||
server_to_client_queue, # Queue for messages FROM server TO client
|
||||
start_get_stream,
|
||||
)
|
||||
|
||||
try:
|
||||
yield (
|
||||
server_to_client_queue, # Queue for receiving messages FROM server
|
||||
client_to_server_queue, # Queue for sending messages TO server
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
transport.terminate_session(client)
|
||||
|
||||
# Signal threads to stop
|
||||
client_to_server_queue.put(None)
|
||||
finally:
|
||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||
try:
|
||||
while not client_to_server_queue.empty():
|
||||
client_to_server_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
yield (
|
||||
server_to_client_queue, # Queue for receiving messages FROM server
|
||||
client_to_server_queue, # Queue for sending messages TO server
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
transport.terminate_session(client)
|
||||
|
||||
client_to_server_queue.put(None)
|
||||
server_to_client_queue.put(None)
|
||||
# Signal threads to stop
|
||||
client_to_server_queue.put(None)
|
||||
finally:
|
||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||
try:
|
||||
while not client_to_server_queue.empty():
|
||||
client_to_server_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
client_to_server_queue.put(None)
|
||||
server_to_client_queue.put(None)
|
||||
|
||||
# Shutdown executor without waiting to prevent hanging
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
|
||||
from pydantic import BaseModel
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
||||
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
@@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: LifespanContextT
|
||||
|
||||
|
||||
class AuthActionType(StrEnum):
|
||||
"""Types of actions that can be performed during auth flow."""
|
||||
|
||||
SAVE_CLIENT_INFO = "save_client_info"
|
||||
SAVE_TOKENS = "save_tokens"
|
||||
SAVE_CODE_VERIFIER = "save_code_verifier"
|
||||
START_AUTHORIZATION = "start_authorization"
|
||||
SUCCESS = "success"
|
||||
|
||||
|
||||
class AuthAction(BaseModel):
|
||||
"""Represents an action that needs to be performed as a result of auth flow."""
|
||||
|
||||
action_type: AuthActionType
|
||||
data: dict[str, Any]
|
||||
provider_id: str | None = None
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class AuthResult(BaseModel):
|
||||
"""Result of auth function containing actions to be performed and response data."""
|
||||
|
||||
actions: list[AuthAction]
|
||||
response: dict[str, str]
|
||||
|
||||
|
||||
class OAuthCallbackState(BaseModel):
|
||||
"""State data stored in Redis during OAuth callback flow."""
|
||||
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
server_url: str
|
||||
metadata: OAuthMetadata | None = None
|
||||
client_information: OAuthClientInformation
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
@@ -8,3 +8,7 @@ class MCPConnectionError(MCPError):
|
||||
|
||||
class MCPAuthError(MCPConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class MCPRefreshTokenError(MCPError):
|
||||
pass
|
||||
|
||||
@@ -7,9 +7,9 @@ from urllib.parse import urlparse
|
||||
|
||||
from core.mcp.client.sse_client import sse_client
|
||||
from core.mcp.client.streamable_client import streamablehttp_client
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.session.client_session import ClientSession
|
||||
from core.mcp.types import Tool
|
||||
from core.mcp.types import CallToolResult, Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -18,40 +18,18 @@ class MCPClient:
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
authed: bool = True,
|
||||
authorization_code: str | None = None,
|
||||
for_list: bool = False,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
):
|
||||
# Initialize info
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
self.client_type = "streamable"
|
||||
self.server_url = server_url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
# Authentication info
|
||||
self.authed = authed
|
||||
self.authorization_code = authorization_code
|
||||
if authed:
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
|
||||
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
|
||||
self.token = self.provider.tokens()
|
||||
|
||||
# Initialize session and client objects
|
||||
self._session: ClientSession | None = None
|
||||
self._streams_context: AbstractContextManager[Any] | None = None
|
||||
self._session_context: ClientSession | None = None
|
||||
self._exit_stack = ExitStack()
|
||||
|
||||
# Whether the client has been initialized
|
||||
self._initialized = False
|
||||
|
||||
def __enter__(self):
|
||||
@@ -85,61 +63,42 @@ class MCPClient:
|
||||
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
|
||||
self.connect_server(streamablehttp_client, "mcp")
|
||||
|
||||
def connect_server(
|
||||
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
|
||||
):
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
|
||||
"""
|
||||
Connect to the MCP server using streamable http or sse.
|
||||
Default to streamable http.
|
||||
Args:
|
||||
client_factory: The client factory to use(streamablehttp_client or sse_client).
|
||||
method_name: The method name to use(mcp or sse).
|
||||
"""
|
||||
streams_context = client_factory(
|
||||
url=self.server_url,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
try:
|
||||
headers = (
|
||||
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
||||
if self.authed and self.token
|
||||
else self.headers
|
||||
)
|
||||
self._streams_context = client_factory(
|
||||
url=self.server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
if not self._streams_context:
|
||||
raise MCPConnectionError("Failed to create connection context")
|
||||
# Use exit_stack to manage context managers properly
|
||||
if method_name == "mcp":
|
||||
read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context)
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self._exit_stack.enter_context(streams_context)
|
||||
|
||||
# Use exit_stack to manage context managers properly
|
||||
if method_name == "mcp":
|
||||
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self._exit_stack.enter_context(self._streams_context)
|
||||
|
||||
self._session_context = ClientSession(*streams)
|
||||
self._session = self._exit_stack.enter_context(self._session_context)
|
||||
self._session.initialize()
|
||||
return
|
||||
|
||||
except MCPAuthError:
|
||||
if not self.authed:
|
||||
raise
|
||||
try:
|
||||
auth(self.provider, self.server_url, self.authorization_code)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to authenticate: {e}")
|
||||
self.token = self.provider.tokens()
|
||||
if first_try:
|
||||
return self.connect_server(client_factory, method_name, first_try=False)
|
||||
session_context = ClientSession(*streams)
|
||||
self._session = self._exit_stack.enter_context(session_context)
|
||||
self._session.initialize()
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""Connect to an MCP server running with SSE transport"""
|
||||
# List available tools to verify connection
|
||||
if not self._initialized or not self._session:
|
||||
"""List available tools from the MCP server"""
|
||||
if not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
response = self._session.list_tools()
|
||||
tools = response.tools
|
||||
return tools
|
||||
return response.tools
|
||||
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict):
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||
"""Call a tool"""
|
||||
if not self._initialized or not self._session:
|
||||
if not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
return self._session.call_tool(tool_name, tool_args)
|
||||
|
||||
@@ -153,6 +112,4 @@ class MCPClient:
|
||||
raise ValueError(f"Error during cleanup: {e}")
|
||||
finally:
|
||||
self._session = None
|
||||
self._session_context = None
|
||||
self._streams_context = None
|
||||
self._initialized = False
|
||||
|
||||
@@ -201,11 +201,14 @@ class BaseSession(
|
||||
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
|
||||
except TimeoutError:
|
||||
# If the receiver loop is still running after timeout, we'll force shutdown
|
||||
pass
|
||||
# Cancel the future to interrupt the receiver loop
|
||||
self._receiver_future.cancel()
|
||||
|
||||
# Shutdown the executor
|
||||
if self._executor:
|
||||
self._executor.shutdown(wait=True)
|
||||
# Use non-blocking shutdown to prevent hanging
|
||||
# The receiver thread should have already exited due to the None message in the queue
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
def send_request(
|
||||
self,
|
||||
|
||||
@@ -109,12 +109,16 @@ class ClientSession(
|
||||
self._message_handler = message_handler or _default_message_handler
|
||||
|
||||
def initialize(self) -> types.InitializeResult:
|
||||
sampling = types.SamplingCapability()
|
||||
roots = types.RootsCapability(
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
listChanged=True,
|
||||
# Only set capabilities if non-default callbacks are provided
|
||||
# This prevents servers from attempting callbacks when we don't actually support them
|
||||
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
|
||||
roots = (
|
||||
types.RootsCapability(
|
||||
# Only enable listChanged if we have a custom callback
|
||||
listChanged=True,
|
||||
)
|
||||
if self._list_roots_callback is not _default_list_roots_callback
|
||||
else None
|
||||
)
|
||||
|
||||
result = self.send_request(
|
||||
@@ -284,7 +288,7 @@ class ClientSession(
|
||||
|
||||
def complete(
|
||||
self,
|
||||
ref: types.ResourceReference | types.PromptReference,
|
||||
ref: types.ResourceTemplateReference | types.PromptReference,
|
||||
argument: dict[str, str],
|
||||
) -> types.CompleteResult:
|
||||
"""Send a completion/complete request."""
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
|
||||
from pydantic.networks import AnyUrl, UrlConstraints
|
||||
@@ -33,6 +26,7 @@ for reference.
|
||||
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
||||
# Server support 2024-11-05 to allow claude to use.
|
||||
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
|
||||
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
|
||||
ProgressToken = str | int
|
||||
Cursor = str
|
||||
Role = Literal["user", "assistant"]
|
||||
@@ -55,14 +49,22 @@ class RequestParams(BaseModel):
|
||||
meta: Meta | None = Field(alias="_meta", default=None)
|
||||
|
||||
|
||||
class PaginatedRequestParams(RequestParams):
|
||||
cursor: Cursor | None = None
|
||||
"""
|
||||
An opaque token representing the current pagination position.
|
||||
If provided, the server should return results starting after this cursor.
|
||||
"""
|
||||
|
||||
|
||||
class NotificationParams(BaseModel):
|
||||
class Meta(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
meta: Meta | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
This parameter name is reserved by MCP to allow clients and servers to attach
|
||||
additional metadata to their notifications.
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
|
||||
|
||||
@@ -79,12 +81,11 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class PaginatedRequest(Request[RequestParamsT, MethodT]):
|
||||
cursor: Cursor | None = None
|
||||
"""
|
||||
An opaque token representing the current pagination position.
|
||||
If provided, the server should return results starting after this cursor.
|
||||
"""
|
||||
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
|
||||
"""Base class for paginated requests,
|
||||
matching the schema's PaginatedRequest interface."""
|
||||
|
||||
params: PaginatedRequestParams | None = None
|
||||
|
||||
|
||||
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
|
||||
@@ -98,13 +99,12 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
|
||||
class Result(BaseModel):
|
||||
"""Base class for JSON-RPC results."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
This result property is reserved by the protocol to allow clients and servers to
|
||||
attach additional metadata to their responses.
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class PaginatedResult(Result):
|
||||
@@ -186,10 +186,26 @@ class EmptyResult(Result):
|
||||
"""A response that indicates success but carries no data."""
|
||||
|
||||
|
||||
class Implementation(BaseModel):
|
||||
"""Describes the name and version of an MCP implementation."""
|
||||
class BaseMetadata(BaseModel):
|
||||
"""Base class for entities with name and optional title fields."""
|
||||
|
||||
name: str
|
||||
"""The programmatic name of the entity."""
|
||||
|
||||
title: str | None = None
|
||||
"""
|
||||
Intended for UI and end-user contexts — optimized to be human-readable and easily understood,
|
||||
even by those unfamiliar with domain-specific terminology.
|
||||
|
||||
If not provided, the name should be used for display (except for Tool,
|
||||
where `annotations.title` should be given precedence over using `name`,
|
||||
if present).
|
||||
"""
|
||||
|
||||
|
||||
class Implementation(BaseMetadata):
|
||||
"""Describes the name and version of an MCP implementation."""
|
||||
|
||||
version: str
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@@ -203,7 +219,7 @@ class RootsCapability(BaseModel):
|
||||
|
||||
|
||||
class SamplingCapability(BaseModel):
|
||||
"""Capability for logging operations."""
|
||||
"""Capability for sampling operations."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@@ -252,6 +268,12 @@ class LoggingCapability(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CompletionsCapability(BaseModel):
|
||||
"""Capability for completions operations."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ServerCapabilities(BaseModel):
|
||||
"""Capabilities that a server may support."""
|
||||
|
||||
@@ -265,6 +287,8 @@ class ServerCapabilities(BaseModel):
|
||||
"""Present if the server offers any resources to read."""
|
||||
tools: ToolsCapability | None = None
|
||||
"""Present if the server offers any tools to call."""
|
||||
completions: CompletionsCapability | None = None
|
||||
"""Present if the server offers autocompletion suggestions for prompts and resources."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -284,7 +308,7 @@ class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]])
|
||||
to begin initialization.
|
||||
"""
|
||||
|
||||
method: Literal["initialize"]
|
||||
method: Literal["initialize"] = "initialize"
|
||||
params: InitializeRequestParams
|
||||
|
||||
|
||||
@@ -305,7 +329,7 @@ class InitializedNotification(Notification[NotificationParams | None, Literal["n
|
||||
finished.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/initialized"]
|
||||
method: Literal["notifications/initialized"] = "notifications/initialized"
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
@@ -315,7 +339,7 @@ class PingRequest(Request[RequestParams | None, Literal["ping"]]):
|
||||
still alive.
|
||||
"""
|
||||
|
||||
method: Literal["ping"]
|
||||
method: Literal["ping"] = "ping"
|
||||
params: RequestParams | None = None
|
||||
|
||||
|
||||
@@ -334,6 +358,11 @@ class ProgressNotificationParams(NotificationParams):
|
||||
"""
|
||||
total: float | None = None
|
||||
"""Total number of items to process (or total progress required), if known."""
|
||||
message: str | None = None
|
||||
"""
|
||||
Message related to progress. This should provide relevant human readable
|
||||
progress information.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -343,15 +372,14 @@ class ProgressNotification(Notification[ProgressNotificationParams, Literal["not
|
||||
long-running request.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/progress"]
|
||||
method: Literal["notifications/progress"] = "notifications/progress"
|
||||
params: ProgressNotificationParams
|
||||
|
||||
|
||||
class ListResourcesRequest(PaginatedRequest[RequestParams | None, Literal["resources/list"]]):
|
||||
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
|
||||
"""Sent from the client to request a list of resources the server has."""
|
||||
|
||||
method: Literal["resources/list"]
|
||||
params: RequestParams | None = None
|
||||
method: Literal["resources/list"] = "resources/list"
|
||||
|
||||
|
||||
class Annotations(BaseModel):
|
||||
@@ -360,13 +388,11 @@ class Annotations(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class Resource(BaseModel):
|
||||
class Resource(BaseMetadata):
|
||||
"""A known resource that the server is capable of reading."""
|
||||
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
|
||||
"""The URI of this resource."""
|
||||
name: str
|
||||
"""A human-readable name for this resource."""
|
||||
description: str | None = None
|
||||
"""A description of what this resource represents."""
|
||||
mimeType: str | None = None
|
||||
@@ -379,10 +405,15 @@ class Resource(BaseModel):
|
||||
This can be used by Hosts to display file sizes and estimate context window usage.
|
||||
"""
|
||||
annotations: Annotations | None = None
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ResourceTemplate(BaseModel):
|
||||
class ResourceTemplate(BaseMetadata):
|
||||
"""A template description for resources available on the server."""
|
||||
|
||||
uriTemplate: str
|
||||
@@ -390,8 +421,6 @@ class ResourceTemplate(BaseModel):
|
||||
A URI template (according to RFC 6570) that can be used to construct resource
|
||||
URIs.
|
||||
"""
|
||||
name: str
|
||||
"""A human-readable name for the type of resource this template refers to."""
|
||||
description: str | None = None
|
||||
"""A human-readable description of what this template is for."""
|
||||
mimeType: str | None = None
|
||||
@@ -400,6 +429,11 @@ class ResourceTemplate(BaseModel):
|
||||
included if all resources matching this template have the same type.
|
||||
"""
|
||||
annotations: Annotations | None = None
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -409,11 +443,10 @@ class ListResourcesResult(PaginatedResult):
|
||||
resources: list[Resource]
|
||||
|
||||
|
||||
class ListResourceTemplatesRequest(PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]):
|
||||
class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]):
|
||||
"""Sent from the client to request a list of resource templates the server has."""
|
||||
|
||||
method: Literal["resources/templates/list"]
|
||||
params: RequestParams | None = None
|
||||
method: Literal["resources/templates/list"] = "resources/templates/list"
|
||||
|
||||
|
||||
class ListResourceTemplatesResult(PaginatedResult):
|
||||
@@ -436,7 +469,7 @@ class ReadResourceRequestParams(RequestParams):
|
||||
class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]):
|
||||
"""Sent from the client to the server, to read a specific resource URI."""
|
||||
|
||||
method: Literal["resources/read"]
|
||||
method: Literal["resources/read"] = "resources/read"
|
||||
params: ReadResourceRequestParams
|
||||
|
||||
|
||||
@@ -447,6 +480,11 @@ class ResourceContents(BaseModel):
|
||||
"""The URI of this resource."""
|
||||
mimeType: str | None = None
|
||||
"""The MIME type of this resource, if known."""
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -481,7 +519,7 @@ class ResourceListChangedNotification(
|
||||
of resources it can read from has changed.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/resources/list_changed"]
|
||||
method: Literal["notifications/resources/list_changed"] = "notifications/resources/list_changed"
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
@@ -502,7 +540,7 @@ class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscr
|
||||
whenever a particular resource changes.
|
||||
"""
|
||||
|
||||
method: Literal["resources/subscribe"]
|
||||
method: Literal["resources/subscribe"] = "resources/subscribe"
|
||||
params: SubscribeRequestParams
|
||||
|
||||
|
||||
@@ -520,7 +558,7 @@ class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/un
|
||||
the server.
|
||||
"""
|
||||
|
||||
method: Literal["resources/unsubscribe"]
|
||||
method: Literal["resources/unsubscribe"] = "resources/unsubscribe"
|
||||
params: UnsubscribeRequestParams
|
||||
|
||||
|
||||
@@ -543,15 +581,14 @@ class ResourceUpdatedNotification(
|
||||
changed and may need to be read again.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/resources/updated"]
|
||||
method: Literal["notifications/resources/updated"] = "notifications/resources/updated"
|
||||
params: ResourceUpdatedNotificationParams
|
||||
|
||||
|
||||
class ListPromptsRequest(PaginatedRequest[RequestParams | None, Literal["prompts/list"]]):
|
||||
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
|
||||
"""Sent from the client to request a list of prompts and prompt templates."""
|
||||
|
||||
method: Literal["prompts/list"]
|
||||
params: RequestParams | None = None
|
||||
method: Literal["prompts/list"] = "prompts/list"
|
||||
|
||||
|
||||
class PromptArgument(BaseModel):
|
||||
@@ -566,15 +603,18 @@ class PromptArgument(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
class Prompt(BaseMetadata):
|
||||
"""A prompt or prompt template that the server offers."""
|
||||
|
||||
name: str
|
||||
"""The name of the prompt or prompt template."""
|
||||
description: str | None = None
|
||||
"""An optional description of what this prompt provides."""
|
||||
arguments: list[PromptArgument] | None = None
|
||||
"""A list of arguments to use for templating the prompt."""
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -597,7 +637,7 @@ class GetPromptRequestParams(RequestParams):
|
||||
class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
|
||||
"""Used by the client to get a prompt provided by the server."""
|
||||
|
||||
method: Literal["prompts/get"]
|
||||
method: Literal["prompts/get"] = "prompts/get"
|
||||
params: GetPromptRequestParams
|
||||
|
||||
|
||||
@@ -608,6 +648,11 @@ class TextContent(BaseModel):
|
||||
text: str
|
||||
"""The text content of the message."""
|
||||
annotations: Annotations | None = None
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -623,6 +668,31 @@ class ImageContent(BaseModel):
|
||||
image types.
|
||||
"""
|
||||
annotations: Annotations | None = None
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class AudioContent(BaseModel):
|
||||
"""Audio content for a message."""
|
||||
|
||||
type: Literal["audio"]
|
||||
data: str
|
||||
"""The base64-encoded audio data."""
|
||||
mimeType: str
|
||||
"""
|
||||
The MIME type of the audio. Different providers may support different
|
||||
audio types.
|
||||
"""
|
||||
annotations: Annotations | None = None
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -630,7 +700,7 @@ class SamplingMessage(BaseModel):
|
||||
"""Describes a message issued to or received from an LLM API."""
|
||||
|
||||
role: Role
|
||||
content: TextContent | ImageContent
|
||||
content: TextContent | ImageContent | AudioContent
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -645,14 +715,36 @@ class EmbeddedResource(BaseModel):
|
||||
type: Literal["resource"]
|
||||
resource: TextResourceContents | BlobResourceContents
|
||||
annotations: Annotations | None = None
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ResourceLink(Resource):
|
||||
"""
|
||||
A resource that the server is capable of reading, included in a prompt or tool call result.
|
||||
|
||||
Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.
|
||||
"""
|
||||
|
||||
type: Literal["resource_link"]
|
||||
|
||||
|
||||
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
|
||||
"""A content block that can be used in prompts and tool results."""
|
||||
|
||||
Content: TypeAlias = ContentBlock
|
||||
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""
|
||||
|
||||
|
||||
class PromptMessage(BaseModel):
|
||||
"""Describes a message returned as part of a prompt."""
|
||||
|
||||
role: Role
|
||||
content: TextContent | ImageContent | EmbeddedResource
|
||||
content: ContentBlock
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -672,15 +764,14 @@ class PromptListChangedNotification(
|
||||
of prompts it offers has changed.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/prompts/list_changed"]
|
||||
method: Literal["notifications/prompts/list_changed"] = "notifications/prompts/list_changed"
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
|
||||
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
|
||||
"""Sent from the client to request a list of tools the server has."""
|
||||
|
||||
method: Literal["tools/list"]
|
||||
params: RequestParams | None = None
|
||||
method: Literal["tools/list"] = "tools/list"
|
||||
|
||||
|
||||
class ToolAnnotations(BaseModel):
|
||||
@@ -731,17 +822,25 @@ class ToolAnnotations(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
class Tool(BaseMetadata):
|
||||
"""Definition for a tool the client can call."""
|
||||
|
||||
name: str
|
||||
"""The name of the tool."""
|
||||
description: str | None = None
|
||||
"""A human-readable description of the tool."""
|
||||
inputSchema: dict[str, Any]
|
||||
"""A JSON Schema object defining the expected parameters for the tool."""
|
||||
outputSchema: dict[str, Any] | None = None
|
||||
"""
|
||||
An optional JSON Schema object defining the structure of the tool's output
|
||||
returned in the structuredContent field of a CallToolResult.
|
||||
"""
|
||||
annotations: ToolAnnotations | None = None
|
||||
"""Optional additional tool information."""
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -762,14 +861,16 @@ class CallToolRequestParams(RequestParams):
|
||||
class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
|
||||
"""Used by the client to invoke a tool provided by the server."""
|
||||
|
||||
method: Literal["tools/call"]
|
||||
method: Literal["tools/call"] = "tools/call"
|
||||
params: CallToolRequestParams
|
||||
|
||||
|
||||
class CallToolResult(Result):
|
||||
"""The server's response to a tool call."""
|
||||
|
||||
content: list[TextContent | ImageContent | EmbeddedResource]
|
||||
content: list[ContentBlock]
|
||||
structuredContent: dict[str, Any] | None = None
|
||||
"""An optional JSON object that represents the structured result of the tool call."""
|
||||
isError: bool = False
|
||||
|
||||
|
||||
@@ -779,7 +880,7 @@ class ToolListChangedNotification(Notification[NotificationParams | None, Litera
|
||||
of tools it offers has changed.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/tools/list_changed"]
|
||||
method: Literal["notifications/tools/list_changed"] = "notifications/tools/list_changed"
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
@@ -797,7 +898,7 @@ class SetLevelRequestParams(RequestParams):
|
||||
class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
|
||||
"""A request from the client to the server, to enable or adjust logging."""
|
||||
|
||||
method: Literal["logging/setLevel"]
|
||||
method: Literal["logging/setLevel"] = "logging/setLevel"
|
||||
params: SetLevelRequestParams
|
||||
|
||||
|
||||
@@ -808,7 +909,7 @@ class LoggingMessageNotificationParams(NotificationParams):
|
||||
"""The severity of this log message."""
|
||||
logger: str | None = None
|
||||
"""An optional name of the logger issuing this message."""
|
||||
data: Any = None
|
||||
data: Any
|
||||
"""
|
||||
The data to be logged, such as a string message or an object. Any JSON serializable
|
||||
type is allowed here.
|
||||
@@ -819,7 +920,7 @@ class LoggingMessageNotificationParams(NotificationParams):
|
||||
class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]):
|
||||
"""Notification of a log message passed from server to client."""
|
||||
|
||||
method: Literal["notifications/message"]
|
||||
method: Literal["notifications/message"] = "notifications/message"
|
||||
params: LoggingMessageNotificationParams
|
||||
|
||||
|
||||
@@ -914,7 +1015,7 @@ class CreateMessageRequestParams(RequestParams):
|
||||
class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]):
|
||||
"""A request from the server to sample an LLM via the client."""
|
||||
|
||||
method: Literal["sampling/createMessage"]
|
||||
method: Literal["sampling/createMessage"] = "sampling/createMessage"
|
||||
params: CreateMessageRequestParams
|
||||
|
||||
|
||||
@@ -925,14 +1026,14 @@ class CreateMessageResult(Result):
|
||||
"""The client's response to a sampling/create_message request from the server."""
|
||||
|
||||
role: Role
|
||||
content: TextContent | ImageContent
|
||||
content: TextContent | ImageContent | AudioContent
|
||||
model: str
|
||||
"""The name of the model that generated the message."""
|
||||
stopReason: StopReason | None = None
|
||||
"""The reason why sampling stopped, if known."""
|
||||
|
||||
|
||||
class ResourceReference(BaseModel):
|
||||
class ResourceTemplateReference(BaseModel):
|
||||
"""A reference to a resource or resource template definition."""
|
||||
|
||||
type: Literal["ref/resource"]
|
||||
@@ -960,18 +1061,28 @@ class CompletionArgument(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CompletionContext(BaseModel):
|
||||
"""Additional, optional context for completions."""
|
||||
|
||||
arguments: dict[str, str] | None = None
|
||||
"""Previously-resolved variables in a URI template or prompt."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CompleteRequestParams(RequestParams):
|
||||
"""Parameters for completion requests."""
|
||||
|
||||
ref: ResourceReference | PromptReference
|
||||
ref: ResourceTemplateReference | PromptReference
|
||||
argument: CompletionArgument
|
||||
context: CompletionContext | None = None
|
||||
"""Additional, optional context for completions"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
|
||||
"""A request from the client to the server, to ask for completion options."""
|
||||
|
||||
method: Literal["completion/complete"]
|
||||
method: Literal["completion/complete"] = "completion/complete"
|
||||
params: CompleteRequestParams
|
||||
|
||||
|
||||
@@ -1010,7 +1121,7 @@ class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
|
||||
structure or access specific locations that the client has permission to read from.
|
||||
"""
|
||||
|
||||
method: Literal["roots/list"]
|
||||
method: Literal["roots/list"] = "roots/list"
|
||||
params: RequestParams | None = None
|
||||
|
||||
|
||||
@@ -1029,6 +1140,11 @@ class Root(BaseModel):
|
||||
identifier for the root, which may be useful for display purposes or for
|
||||
referencing the root in other parts of the application.
|
||||
"""
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -1054,7 +1170,7 @@ class RootsListChangedNotification(
|
||||
using the ListRootsRequest.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/roots/list_changed"]
|
||||
method: Literal["notifications/roots/list_changed"] = "notifications/roots/list_changed"
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
@@ -1074,7 +1190,7 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n
|
||||
previously-issued request.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/cancelled"]
|
||||
method: Literal["notifications/cancelled"] = "notifications/cancelled"
|
||||
params: CancelledNotificationParams
|
||||
|
||||
|
||||
|
||||
119
api/core/memory/entities.py
Normal file
119
api/core/memory/entities.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
|
||||
|
||||
class MemoryScope(StrEnum):
|
||||
"""Memory scope determined by node_id field"""
|
||||
APP = "app" # node_id is None
|
||||
NODE = "node" # node_id is not None
|
||||
|
||||
|
||||
class MemoryTerm(StrEnum):
|
||||
"""Memory term determined by conversation_id field"""
|
||||
SESSION = "session" # conversation_id is not None
|
||||
PERSISTENT = "persistent" # conversation_id is None
|
||||
|
||||
|
||||
class MemoryStrategy(StrEnum):
|
||||
ON_TURNS = "on_turns"
|
||||
|
||||
|
||||
class MemoryScheduleMode(StrEnum):
|
||||
SYNC = "sync"
|
||||
ASYNC = "async"
|
||||
|
||||
|
||||
class MemoryBlockSpec(BaseModel):
|
||||
"""Memory block specification for workflow configuration"""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid4()),
|
||||
description="Unique identifier for the memory block",
|
||||
)
|
||||
name: str = Field(description="Display name of the memory block")
|
||||
description: str = Field(default="", description="Description of the memory block")
|
||||
template: str = Field(description="Initial template content for the memory")
|
||||
instruction: str = Field(description="Instructions for updating the memory")
|
||||
scope: MemoryScope = Field(description="Scope of the memory (app or node level)")
|
||||
term: MemoryTerm = Field(description="Term of the memory (session or persistent)")
|
||||
strategy: MemoryStrategy = Field(description="Update strategy for the memory")
|
||||
update_turns: int = Field(gt=0, description="Number of turns between updates")
|
||||
preserved_turns: int = Field(gt=0, description="Number of conversation turns to preserve")
|
||||
schedule_mode: MemoryScheduleMode = Field(description="Synchronous or asynchronous update mode")
|
||||
model: ModelConfig = Field(description="Model configuration for memory updates")
|
||||
end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users")
|
||||
end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users")
|
||||
|
||||
|
||||
class MemoryCreatedBy(BaseModel):
|
||||
end_user_id: str | None = None
|
||||
account_id: str | None = None
|
||||
|
||||
|
||||
class MemoryBlock(BaseModel):
|
||||
"""Runtime memory block instance
|
||||
|
||||
Design Rules:
|
||||
- app_id = None: Global memory (future feature, not implemented yet)
|
||||
- app_id = str: App-specific memory
|
||||
- conversation_id = None: Persistent memory (cross-conversation)
|
||||
- conversation_id = str: Session memory (conversation-specific)
|
||||
- node_id = None: App-level scope
|
||||
- node_id = str: Node-level scope
|
||||
|
||||
These rules implicitly determine scope and term without redundant storage.
|
||||
"""
|
||||
spec: MemoryBlockSpec
|
||||
tenant_id: str
|
||||
value: str
|
||||
app_id: str
|
||||
conversation_id: Optional[str] = None
|
||||
node_id: Optional[str] = None
|
||||
edited_by_user: bool = False
|
||||
created_by: MemoryCreatedBy
|
||||
version: int = Field(description="Memory block version number")
|
||||
|
||||
|
||||
class MemoryValueData(BaseModel):
|
||||
value: str
|
||||
edited_by_user: bool = False
|
||||
|
||||
|
||||
class ChatflowConversationMetadata(BaseModel):
|
||||
"""Metadata for chatflow conversation with visible message count"""
|
||||
type: str = "mutable_visible_window"
|
||||
visible_count: int = Field(gt=0, description="Number of visible messages to keep")
|
||||
|
||||
|
||||
class MemoryBlockWithConversation(MemoryBlock):
|
||||
"""MemoryBlock with optional conversation metadata for session memories"""
|
||||
conversation_metadata: ChatflowConversationMetadata = Field(
|
||||
description="Conversation metadata, only present for session memories"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_memory_block(
|
||||
cls,
|
||||
memory_block: MemoryBlock,
|
||||
conversation_metadata: ChatflowConversationMetadata
|
||||
) -> MemoryBlockWithConversation:
|
||||
"""Create MemoryBlockWithConversation from MemoryBlock"""
|
||||
return cls(
|
||||
spec=memory_block.spec,
|
||||
tenant_id=memory_block.tenant_id,
|
||||
value=memory_block.value,
|
||||
app_id=memory_block.app_id,
|
||||
conversation_id=memory_block.conversation_id,
|
||||
node_id=memory_block.node_id,
|
||||
edited_by_user=memory_block.edited_by_user,
|
||||
created_by=memory_block.created_by,
|
||||
version=memory_block.version,
|
||||
conversation_metadata=conversation_metadata
|
||||
)
|
||||
6
api/core/memory/errors.py
Normal file
6
api/core/memory/errors.py
Normal file
@@ -0,0 +1,6 @@
|
||||
class MemorySyncTimeoutError(Exception):
|
||||
def __init__(self, app_id: str, conversation_id: str):
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.message = "Memory synchronization timeout after 50 seconds"
|
||||
super().__init__(self.message)
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.file import file_manager
|
||||
@@ -18,7 +19,9 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import AppMode, Conversation, Message, MessageFile
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from models.workflow import Workflow
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
class TokenBufferMemory:
|
||||
@@ -29,6 +32,14 @@ class TokenBufferMemory:
|
||||
):
|
||||
self.conversation = conversation
|
||||
self.model_instance = model_instance
|
||||
self._workflow_run_repo: APIWorkflowRunRepository | None = None
|
||||
|
||||
@property
|
||||
def workflow_run_repo(self) -> APIWorkflowRunRepository:
|
||||
if self._workflow_run_repo is None:
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
return self._workflow_run_repo
|
||||
|
||||
def _build_prompt_message_with_files(
|
||||
self,
|
||||
@@ -50,7 +61,16 @@ class TokenBufferMemory:
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_run = db.session.scalar(select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id))
|
||||
app = self.conversation.app
|
||||
if not app:
|
||||
raise ValueError("App not found for conversation")
|
||||
|
||||
if not message.workflow_run_id:
|
||||
raise ValueError("Workflow run ID not found")
|
||||
|
||||
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
|
||||
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
|
||||
)
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
|
||||
@@ -38,6 +38,8 @@ class LLMUsageMetadata(TypedDict, total=False):
|
||||
prompt_price: Union[float, str]
|
||||
completion_price: Union[float, str]
|
||||
latency: float
|
||||
time_to_first_token: float
|
||||
time_to_generate: float
|
||||
|
||||
|
||||
class LLMUsage(ModelUsage):
|
||||
@@ -57,6 +59,8 @@ class LLMUsage(ModelUsage):
|
||||
total_price: Decimal
|
||||
currency: str
|
||||
latency: float
|
||||
time_to_first_token: float | None = None
|
||||
time_to_generate: float | None = None
|
||||
|
||||
@classmethod
|
||||
def empty_usage(cls):
|
||||
@@ -73,6 +77,8 @@ class LLMUsage(ModelUsage):
|
||||
total_price=Decimal("0.0"),
|
||||
currency="USD",
|
||||
latency=0.0,
|
||||
time_to_first_token=None,
|
||||
time_to_generate=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -108,6 +114,8 @@ class LLMUsage(ModelUsage):
|
||||
prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
|
||||
completion_price=Decimal(str(metadata.get("completion_price", 0))),
|
||||
latency=metadata.get("latency", 0.0),
|
||||
time_to_first_token=metadata.get("time_to_first_token"),
|
||||
time_to_generate=metadata.get("time_to_generate"),
|
||||
)
|
||||
|
||||
def plus(self, other: LLMUsage) -> LLMUsage:
|
||||
@@ -133,6 +141,8 @@ class LLMUsage(ModelUsage):
|
||||
total_price=self.total_price + other.total_price,
|
||||
currency=other.currency,
|
||||
latency=self.latency + other.latency,
|
||||
time_to_first_token=other.time_to_first_token,
|
||||
time_to_generate=other.time_to_generate,
|
||||
)
|
||||
|
||||
def __add__(self, other: LLMUsage) -> LLMUsage:
|
||||
|
||||
@@ -62,6 +62,9 @@ class MessageTraceInfo(BaseTraceInfo):
|
||||
file_list: Union[str, dict[str, Any], list] | None = None
|
||||
message_file_data: Any | None = None
|
||||
conversation_mode: str
|
||||
gen_ai_server_time_to_first_token: float | None = None
|
||||
llm_streaming_time_to_generate: float | None = None
|
||||
is_streaming_request: bool = False
|
||||
|
||||
|
||||
class ModerationTraceInfo(BaseTraceInfo):
|
||||
|
||||
@@ -12,9 +12,9 @@ from uuid import UUID, uuid4
|
||||
from cachetools import LRUCache
|
||||
from flask import current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
|
||||
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
||||
from core.ops.entities.config_entity import (
|
||||
OPS_FILE_PATH,
|
||||
TracingProviderEnum,
|
||||
@@ -34,7 +34,8 @@ from core.ops.utils import get_message_data
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
from models.workflow import WorkflowAppLog, WorkflowRun
|
||||
from models.workflow import WorkflowAppLog
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -140,6 +141,8 @@ provider_config_map = OpsTraceProviderConfigMap()
|
||||
|
||||
class OpsTraceManager:
|
||||
ops_trace_instances_cache: LRUCache = LRUCache(maxsize=128)
|
||||
decrypted_configs_cache: LRUCache = LRUCache(maxsize=128)
|
||||
_decryption_cache_lock = threading.RLock()
|
||||
|
||||
@classmethod
|
||||
def encrypt_tracing_config(
|
||||
@@ -160,7 +163,7 @@ class OpsTraceManager:
|
||||
provider_config_map[tracing_provider]["other_keys"],
|
||||
)
|
||||
|
||||
new_config = {}
|
||||
new_config: dict[str, Any] = {}
|
||||
# Encrypt necessary keys
|
||||
for key in secret_keys:
|
||||
if key in tracing_config:
|
||||
@@ -190,20 +193,41 @@ class OpsTraceManager:
|
||||
:param tracing_config: tracing config
|
||||
:return:
|
||||
"""
|
||||
config_class, secret_keys, other_keys = (
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["secret_keys"],
|
||||
provider_config_map[tracing_provider]["other_keys"],
|
||||
config_json = json.dumps(tracing_config, sort_keys=True)
|
||||
decrypted_config_key = (
|
||||
tenant_id,
|
||||
tracing_provider,
|
||||
config_json,
|
||||
)
|
||||
new_config = {}
|
||||
for key in secret_keys:
|
||||
if key in tracing_config:
|
||||
new_config[key] = decrypt_token(tenant_id, tracing_config[key])
|
||||
|
||||
for key in other_keys:
|
||||
new_config[key] = tracing_config.get(key, "")
|
||||
# First check without lock for performance
|
||||
cached_config = cls.decrypted_configs_cache.get(decrypted_config_key)
|
||||
if cached_config is not None:
|
||||
return dict(cached_config)
|
||||
|
||||
return config_class(**new_config).model_dump()
|
||||
with cls._decryption_cache_lock:
|
||||
# Second check (double-checked locking) to prevent race conditions
|
||||
cached_config = cls.decrypted_configs_cache.get(decrypted_config_key)
|
||||
if cached_config is not None:
|
||||
return dict(cached_config)
|
||||
|
||||
config_class, secret_keys, other_keys = (
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["secret_keys"],
|
||||
provider_config_map[tracing_provider]["other_keys"],
|
||||
)
|
||||
new_config: dict[str, Any] = {}
|
||||
keys_to_decrypt = [key for key in secret_keys if key in tracing_config]
|
||||
if keys_to_decrypt:
|
||||
decrypted_values = batch_decrypt_token(tenant_id, [tracing_config[key] for key in keys_to_decrypt])
|
||||
new_config.update(zip(keys_to_decrypt, decrypted_values))
|
||||
|
||||
for key in other_keys:
|
||||
new_config[key] = tracing_config.get(key, "")
|
||||
|
||||
decrypted_config = config_class(**new_config).model_dump()
|
||||
cls.decrypted_configs_cache[decrypted_config_key] = decrypted_config
|
||||
return dict(decrypted_config)
|
||||
|
||||
@classmethod
|
||||
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
|
||||
@@ -218,7 +242,7 @@ class OpsTraceManager:
|
||||
provider_config_map[tracing_provider]["secret_keys"],
|
||||
provider_config_map[tracing_provider]["other_keys"],
|
||||
)
|
||||
new_config = {}
|
||||
new_config: dict[str, Any] = {}
|
||||
for key in secret_keys:
|
||||
if key in decrypt_tracing_config:
|
||||
new_config[key] = obfuscated_token(decrypt_tracing_config[key])
|
||||
@@ -419,6 +443,18 @@ class OpsTraceManager:
|
||||
|
||||
|
||||
class TraceTask:
|
||||
_workflow_run_repo = None
|
||||
_repo_lock = threading.Lock()
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_run_repo(cls):
|
||||
if cls._workflow_run_repo is None:
|
||||
with cls._repo_lock:
|
||||
if cls._workflow_run_repo is None:
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
return cls._workflow_run_repo
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trace_type: Any,
|
||||
@@ -486,27 +522,27 @@ class TraceTask:
|
||||
if not workflow_run_id:
|
||||
return {}
|
||||
|
||||
workflow_run_repo = self._get_workflow_run_repo()
|
||||
workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(run_id=workflow_run_id)
|
||||
if not workflow_run:
|
||||
raise ValueError("Workflow run not found")
|
||||
|
||||
workflow_id = workflow_run.workflow_id
|
||||
tenant_id = workflow_run.tenant_id
|
||||
workflow_run_id = workflow_run.id
|
||||
workflow_run_elapsed_time = workflow_run.elapsed_time
|
||||
workflow_run_status = workflow_run.status
|
||||
workflow_run_inputs = workflow_run.inputs_dict
|
||||
workflow_run_outputs = workflow_run.outputs_dict
|
||||
workflow_run_version = workflow_run.version
|
||||
error = workflow_run.error or ""
|
||||
|
||||
total_tokens = workflow_run.total_tokens
|
||||
|
||||
file_list = workflow_run_inputs.get("sys.file") or []
|
||||
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
||||
workflow_run = session.scalars(workflow_run_stmt).first()
|
||||
if not workflow_run:
|
||||
raise ValueError("Workflow run not found")
|
||||
|
||||
workflow_id = workflow_run.workflow_id
|
||||
tenant_id = workflow_run.tenant_id
|
||||
workflow_run_id = workflow_run.id
|
||||
workflow_run_elapsed_time = workflow_run.elapsed_time
|
||||
workflow_run_status = workflow_run.status
|
||||
workflow_run_inputs = workflow_run.inputs_dict
|
||||
workflow_run_outputs = workflow_run.outputs_dict
|
||||
workflow_run_version = workflow_run.version
|
||||
error = workflow_run.error or ""
|
||||
|
||||
total_tokens = workflow_run.total_tokens
|
||||
|
||||
file_list = workflow_run_inputs.get("sys.file") or []
|
||||
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
|
||||
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
|
||||
WorkflowAppLog.tenant_id == tenant_id,
|
||||
@@ -523,43 +559,43 @@ class TraceTask:
|
||||
)
|
||||
message_id = session.scalar(message_data_stmt)
|
||||
|
||||
metadata = {
|
||||
"workflow_id": workflow_id,
|
||||
"conversation_id": conversation_id,
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"tenant_id": tenant_id,
|
||||
"elapsed_time": workflow_run_elapsed_time,
|
||||
"status": workflow_run_status,
|
||||
"version": workflow_run_version,
|
||||
"total_tokens": total_tokens,
|
||||
"file_list": file_list,
|
||||
"triggered_from": workflow_run.triggered_from,
|
||||
"user_id": user_id,
|
||||
"app_id": workflow_run.app_id,
|
||||
}
|
||||
metadata = {
|
||||
"workflow_id": workflow_id,
|
||||
"conversation_id": conversation_id,
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"tenant_id": tenant_id,
|
||||
"elapsed_time": workflow_run_elapsed_time,
|
||||
"status": workflow_run_status,
|
||||
"version": workflow_run_version,
|
||||
"total_tokens": total_tokens,
|
||||
"file_list": file_list,
|
||||
"triggered_from": workflow_run.triggered_from,
|
||||
"user_id": user_id,
|
||||
"app_id": workflow_run.app_id,
|
||||
}
|
||||
|
||||
workflow_trace_info = WorkflowTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
workflow_data=workflow_run.to_dict(),
|
||||
conversation_id=conversation_id,
|
||||
workflow_id=workflow_id,
|
||||
tenant_id=tenant_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_elapsed_time=workflow_run_elapsed_time,
|
||||
workflow_run_status=workflow_run_status,
|
||||
workflow_run_inputs=workflow_run_inputs,
|
||||
workflow_run_outputs=workflow_run_outputs,
|
||||
workflow_run_version=workflow_run_version,
|
||||
error=error,
|
||||
total_tokens=total_tokens,
|
||||
file_list=file_list,
|
||||
query=query,
|
||||
metadata=metadata,
|
||||
workflow_app_log_id=workflow_app_log_id,
|
||||
message_id=message_id,
|
||||
start_time=workflow_run.created_at,
|
||||
end_time=workflow_run.finished_at,
|
||||
)
|
||||
workflow_trace_info = WorkflowTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
workflow_data=workflow_run.to_dict(),
|
||||
conversation_id=conversation_id,
|
||||
workflow_id=workflow_id,
|
||||
tenant_id=tenant_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_elapsed_time=workflow_run_elapsed_time,
|
||||
workflow_run_status=workflow_run_status,
|
||||
workflow_run_inputs=workflow_run_inputs,
|
||||
workflow_run_outputs=workflow_run_outputs,
|
||||
workflow_run_version=workflow_run_version,
|
||||
error=error,
|
||||
total_tokens=total_tokens,
|
||||
file_list=file_list,
|
||||
query=query,
|
||||
metadata=metadata,
|
||||
workflow_app_log_id=workflow_app_log_id,
|
||||
message_id=message_id,
|
||||
start_time=workflow_run.created_at,
|
||||
end_time=workflow_run.finished_at,
|
||||
)
|
||||
return workflow_trace_info
|
||||
|
||||
def message_trace(self, message_id: str | None):
|
||||
@@ -583,6 +619,8 @@ class TraceTask:
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
|
||||
streaming_metrics = self._extract_streaming_metrics(message_data)
|
||||
|
||||
metadata = {
|
||||
"conversation_id": message_data.conversation_id,
|
||||
"ls_provider": message_data.model_provider,
|
||||
@@ -615,6 +653,9 @@ class TraceTask:
|
||||
metadata=metadata,
|
||||
message_file_data=message_file_data,
|
||||
conversation_mode=conversation_mode,
|
||||
gen_ai_server_time_to_first_token=streaming_metrics.get("gen_ai_server_time_to_first_token"),
|
||||
llm_streaming_time_to_generate=streaming_metrics.get("llm_streaming_time_to_generate"),
|
||||
is_streaming_request=streaming_metrics.get("is_streaming_request", False),
|
||||
)
|
||||
|
||||
return message_trace_info
|
||||
@@ -840,6 +881,24 @@ class TraceTask:
|
||||
|
||||
return generate_name_trace_info
|
||||
|
||||
def _extract_streaming_metrics(self, message_data) -> dict:
|
||||
if not message_data.message_metadata:
|
||||
return {}
|
||||
|
||||
try:
|
||||
metadata = json.loads(message_data.message_metadata)
|
||||
usage = metadata.get("usage", {})
|
||||
time_to_first_token = usage.get("time_to_first_token")
|
||||
time_to_generate = usage.get("time_to_generate")
|
||||
|
||||
return {
|
||||
"gen_ai_server_time_to_first_token": time_to_first_token,
|
||||
"llm_streaming_time_to_generate": time_to_generate,
|
||||
"is_streaming_request": time_to_first_token is not None,
|
||||
}
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
return {}
|
||||
|
||||
|
||||
trace_manager_timer: threading.Timer | None = None
|
||||
trace_manager_queue: queue.Queue = queue.Queue()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user