Mirror of https://github.com/langgenius/dify.git (synced 2026-01-04 13:37:22 +00:00)

Compare commits: 639 commits
Commit range spanning 38d895ab5f to 9987774471 (the full table of 639 commit SHAs is omitted; the author, date, and message columns were empty in the listing).
.github/workflows/build-push.yml (vendored): 1 change
@@ -6,6 +6,7 @@ on:
      - "main"
      - "deploy/dev"
      - "deploy/enterprise"
      - "deploy/rag-dev"
    tags:
      - "*"
.github/workflows/deploy-dev.yml (vendored): 7 changes
@@ -4,7 +4,7 @@ on:
  workflow_run:
    workflows: ["Build and Push API & Web"]
    branches:
      - "deploy/dev"
      - "deploy/rag-dev"
    types:
      - completed

@@ -12,12 +12,13 @@ jobs:
  deploy:
    runs-on: ubuntu-latest
    if: |
      github.event.workflow_run.conclusion == 'success'
      github.event.workflow_run.conclusion == 'success' &&
      github.event.workflow_run.head_branch == 'deploy/rag-dev'
    steps:
      - name: Deploy to server
        uses: appleboy/ssh-action@v0.1.8
        with:
          host: ${{ secrets.SSH_HOST }}
          host: ${{ secrets.RAG_SSH_HOST }}
          username: ${{ secrets.SSH_USER }}
          key: ${{ secrets.SSH_PRIVATE_KEY }}
          script: |
api/app.py: 27 changes
@@ -1,4 +1,3 @@
import os
import sys

@@ -17,20 +16,20 @@ else:
    # It seems that JetBrains Python debugger does not work well with gevent,
    # so we need to disable gevent in debug mode.
    # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
    if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
        from gevent import monkey
    # if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
    #     from gevent import monkey
    #
    #     # gevent
    #     monkey.patch_all()
    #
    #     from grpc.experimental import gevent as grpc_gevent  # type: ignore
    #
    #     # grpc gevent
    #     grpc_gevent.init_gevent()

        # gevent
        monkey.patch_all()

        from grpc.experimental import gevent as grpc_gevent  # type: ignore

        # grpc gevent
        grpc_gevent.init_gevent()

    import psycogreen.gevent  # type: ignore

    psycogreen.gevent.patch_psycopg()
    # import psycogreen.gevent  # type: ignore
    #
    # psycogreen.gevent.patch_psycopg()

from app_factory import create_app
@@ -222,11 +222,28 @@ class HostedFetchAppTemplateConfig(BaseSettings):
    )


class HostedFetchPipelineTemplateConfig(BaseSettings):
    """
    Configuration for fetching pipeline templates
    """

    HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
        description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
        default="database",
    )

    HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
        description="Domain for fetching remote pipeline templates",
        default="https://tmpl.dify.ai",
    )


class HostedServiceConfig(
    # place the configs in alphabet order
    HostedAnthropicConfig,
    HostedAzureOpenAiConfig,
    HostedFetchAppTemplateConfig,
    HostedFetchPipelineTemplateConfig,
    HostedMinmaxConfig,
    HostedOpenAiConfig,
    HostedSparkConfig,
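The new settings class follows the usual pydantic-settings pattern: each field resolves from an environment variable of the same name and otherwise falls back to the declared default. A minimal sketch of that behavior, using a stand-alone class rather than Dify's actual config wiring:

```python
# Minimal sketch (not Dify's actual wiring): pydantic-settings resolves these
# fields from environment variables of the same name, falling back to defaults.
from pydantic import Field
from pydantic_settings import BaseSettings


class HostedFetchPipelineTemplateConfigSketch(BaseSettings):
    HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(default="database")
    HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(default="https://tmpl.dify.ai")


if __name__ == "__main__":
    # With no env vars set, the defaults above are used; exporting
    # HOSTED_FETCH_PIPELINE_TEMPLATES_MODE=remote would override the mode.
    cfg = HostedFetchPipelineTemplateConfigSketch()
    print(cfg.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE, cfg.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN)
```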
@@ -3,6 +3,7 @@ from threading import Lock
from typing import TYPE_CHECKING

from contexts.wrapper import RecyclableContextVar
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController

if TYPE_CHECKING:
    from core.model_runtime.entities.model_entities import AIModelEntity

@@ -33,3 +34,11 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
    ContextVar("plugin_model_schemas")
)

datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
    RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)

datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
    ContextVar("datasource_plugin_providers_lock")
)
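The addition mirrors the existing plugin-model registries: a per-request dict of datasource providers plus a lock for lazy population. An illustrative sketch of the underlying pattern with the standard library's contextvars only (RecyclableContextVar is a Dify-internal wrapper and the lock counterpart is omitted here):

```python
# Illustrative sketch of the per-context registry pattern using plain contextvars.
from contextvars import ContextVar

datasource_providers: ContextVar[dict[str, object]] = ContextVar("datasource_providers")


def get_providers() -> dict[str, object]:
    # Initialize the per-context dict on first access within this context.
    try:
        return datasource_providers.get()
    except LookupError:
        registry: dict[str, object] = {}
        datasource_providers.set(registry)
        return registry


if __name__ == "__main__":
    get_providers()["notion"] = object()  # hypothetical provider key
    print(list(get_providers().keys()))
```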
@@ -76,7 +76,6 @@ from .billing import billing, compliance

# Import datasets controllers
from .datasets import (
    data_source,
    datasets,
    datasets_document,
    datasets_segments,
@@ -85,6 +84,14 @@ from .datasets import (
    metadata,
    website,
)
from .datasets.rag_pipeline import (
    datasource_auth,
    datasource_content_preview,
    rag_pipeline,
    rag_pipeline_datasets,
    rag_pipeline_import,
    rag_pipeline_workflow,
)

# Import explore controllers
from .explore import (
@@ -283,6 +283,15 @@ class DatasetApi(Resource):
            location="json",
            help="Invalid external knowledge api id.",
        )

        parser.add_argument(
            "icon_info",
            type=dict,
            required=False,
            nullable=True,
            location="json",
            help="Invalid icon info.",
        )
        args = parser.parse_args()
        data = request.get_json()
@@ -1,3 +1,4 @@
import json
import logging
from argparse import ArgumentTypeError
from datetime import UTC, datetime
@@ -51,6 +52,7 @@ from fields.document_fields import (
)
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig

@@ -661,7 +663,7 @@ class DocumentDetailApi(DocumentResource):
            response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
        elif metadata == "without":
            dataset_process_rules = DatasetService.get_process_rules(dataset_id)
            document_process_rules = document.dataset_process_rule.to_dict()
            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
            data_source_info = document.data_source_detail_dict
            response = {
                "id": document.id,
@@ -1028,6 +1030,41 @@ class WebsiteDocumentSyncApi(DocumentResource):
        return {"result": "success"}, 200


class DocumentPipelineExecutionLogApi(DocumentResource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id):
        dataset_id = str(dataset_id)
        document_id = str(document_id)

        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        document = DocumentService.get_document(dataset.id, document_id)
        if not document:
            raise NotFound("Document not found.")
        log = (
            db.session.query(DocumentPipelineExecutionLog)
            .filter_by(document_id=document_id)
            .order_by(DocumentPipelineExecutionLog.created_at.desc())
            .first()
        )
        if not log:
            return {
                "datasource_info": None,
                "datasource_type": None,
                "input_data": None,
                "datasource_node_id": None,
            }, 200
        return {
            "datasource_info": json.loads(log.datasource_info),
            "datasource_type": log.datasource_type,
            "input_data": log.input_data,
            "datasource_node_id": log.datasource_node_id,
        }, 200


api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DatasetInitApi, "/datasets/init")
@@ -1050,3 +1087,6 @@ api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")

api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
api.add_resource(
    DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
)
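A hypothetical client call against the newly registered pipeline-execution-log route. The path and response fields come from the diff above; the base URL, the auth header, and the IDs are placeholders that depend on the deployment.

```python
# Sketch of calling the new endpoint; base URL, auth, and IDs are placeholders.
import requests

CONSOLE_API = "http://localhost:5001/console/api"  # assumed console API prefix
dataset_id = "00000000-0000-0000-0000-000000000000"
document_id = "00000000-0000-0000-0000-000000000001"

resp = requests.get(
    f"{CONSOLE_API}/datasets/{dataset_id}/documents/{document_id}/pipeline-execution-log",
    headers={"Authorization": "Bearer <console-session-token>"},  # placeholder auth
    timeout=10,
)
resp.raise_for_status()
log = resp.json()
# All four fields are None when no pipeline execution log exists for the document.
print(log["datasource_type"], log["datasource_node_id"])
```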
@@ -101,3 +101,9 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
    error_code = "child_chunk_delete_index_error"
    description = "Delete child chunk index failed: {message}"
    code = 500


class PipelineNotFoundError(BaseHTTPException):
    error_code = "pipeline_not_found"
    description = "Pipeline not found."
    code = 404
api/controllers/console/datasets/rag_pipeline/datasource_auth.py (new file): 197 lines
@@ -0,0 +1,197 @@
|
||||
from flask import redirect, request
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
|
||||
class DatasourcePluginOauthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
provider = args["provider"]
|
||||
plugin_id = args["plugin_id"]
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
# get all plugin oauth configs
|
||||
plugin_oauth_config = (
|
||||
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
|
||||
)
|
||||
if not plugin_oauth_config:
|
||||
raise NotFound()
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_url = (
|
||||
f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
|
||||
)
|
||||
system_credentials = plugin_oauth_config.system_credentials
|
||||
if system_credentials:
|
||||
system_credentials["redirect_url"] = redirect_url
|
||||
response = oauth_handler.get_authorization_url(
|
||||
current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
|
||||
)
|
||||
return response.model_dump()
|
||||
|
||||
|
||||
class DatasourceOauthCallback(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
provider = args["provider"]
|
||||
plugin_id = args["plugin_id"]
|
||||
oauth_handler = OAuthHandler()
|
||||
plugin_oauth_config = (
|
||||
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
|
||||
)
|
||||
if not plugin_oauth_config:
|
||||
raise NotFound()
|
||||
credentials = oauth_handler.get_credentials(
|
||||
current_user.current_tenant.id,
|
||||
current_user.id,
|
||||
plugin_id,
|
||||
provider,
|
||||
system_credentials=plugin_oauth_config.system_credentials,
|
||||
request=request,
|
||||
)
|
||||
datasource_provider = DatasourceProvider(
|
||||
plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
db.session.commit()
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||
|
||||
|
||||
class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
try:
|
||||
datasource_provider_service.datasource_provider_credentials_validate(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=args["provider"],
|
||||
plugin_id=args["plugin_id"],
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"]
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
class DatasourceAuthUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, auth_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=auth_id,
|
||||
provider=args["provider"],
|
||||
plugin_id=args["plugin_id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, auth_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
try:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=auth_id,
|
||||
provider=args["provider"],
|
||||
plugin_id=args["plugin_id"],
|
||||
credentials=args["credentials"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
DatasourcePluginOauthApi,
|
||||
"/oauth/plugin/datasource",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceOauthCallback,
|
||||
"/oauth/plugin/datasource/callback",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceAuth,
|
||||
"/auth/plugin/datasource",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthUpdateDeleteApi,
|
||||
"/auth/plugin/datasource/<string:auth_id>",
|
||||
)
|
||||
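A hypothetical walk-through of the datasource-auth endpoints registered above. The paths and payload keys come from the diff; the base URL, auth header, and the provider/plugin identifiers are placeholders for illustration only.

```python
# Sketch of exercising the datasource auth routes; identifiers are placeholders.
import requests

CONSOLE_API = "http://localhost:5001/console/api"  # assumed console API prefix
HEADERS = {"Authorization": "Bearer <console-session-token>"}  # placeholder auth

# 1. API-key style credentials: validate and store them for the tenant.
requests.post(
    f"{CONSOLE_API}/auth/plugin/datasource",
    headers=HEADERS,
    json={
        "provider": "notion",                # placeholder provider name
        "plugin_id": "langgenius/notion",    # placeholder plugin id
        "name": "my-notion-credential",
        "credentials": {"integration_token": "<secret>"},
    },
    timeout=10,
).raise_for_status()

# 2. List stored credentials for that provider.
resp = requests.get(
    f"{CONSOLE_API}/auth/plugin/datasource",
    headers=HEADERS,
    params={"provider": "notion", "plugin_id": "langgenius/notion"},
    timeout=10,
)
print(resp.json()["result"])

# 3. OAuth flow instead: fetch the authorization URL; the provider later
#    redirects the browser back to /oauth/plugin/datasource/callback.
auth_url = requests.get(
    f"{CONSOLE_API}/oauth/plugin/datasource",
    headers=HEADERS,
    params={"provider": "notion", "plugin_id": "langgenius/notion"},
    timeout=10,
).json()
print(auth_url)
```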
@@ -0,0 +1,55 @@
|
||||
from flask_restful import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run datasource content preview
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
preview_content = rag_pipeline_service.run_datasource_node_preview(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=inputs,
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
is_published=True,
|
||||
)
|
||||
return preview_content, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DataSourceContentPreviewApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
|
||||
)
|
||||
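A hypothetical call to the datasource content-preview endpoint defined above. The route and the required "inputs"/"datasource_type" fields come from the diff; the base URL, auth, pipeline/node IDs, and the input payload are placeholders.

```python
# Sketch of previewing datasource content for one pipeline node; values are placeholders.
import requests

CONSOLE_API = "http://localhost:5001/console/api"  # assumed console API prefix
pipeline_id = "00000000-0000-0000-0000-000000000000"
node_id = "datasource-node-1"  # placeholder node id

resp = requests.post(
    f"{CONSOLE_API}/rag/pipelines/{pipeline_id}/workflows/published/datasource/nodes/{node_id}/preview",
    headers={"Authorization": "Bearer <console-session-token>"},
    json={
        "inputs": {"page_url": "https://example.com"},  # placeholder datasource inputs
        "datasource_type": "website_crawl",             # placeholder datasource type
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```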
api/controllers/console/datasets/rag_pipeline/rag_pipeline.py (new file): 162 lines
@@ -0,0 +1,162 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
return pipeline_templates, 200
|
||||
|
||||
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, template_id: str):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str):
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
return 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, template_id: str):
|
||||
with Session(db.engine) as session:
|
||||
template = (
|
||||
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
if not template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
|
||||
return {"data": template.yaml_content}, 200
|
||||
|
||||
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, pipeline_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
PipelineTemplateListApi,
|
||||
"/rag/pipeline/templates",
|
||||
)
|
||||
api.add_resource(
|
||||
PipelineTemplateDetailApi,
|
||||
"/rag/pipeline/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
CustomizedPipelineTemplateApi,
|
||||
"/rag/pipeline/customized/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/customized/publish",
|
||||
)
|
||||
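A hypothetical usage of the pipeline-template endpoints registered above. Routes and query parameters come from the diff; the base URL, auth, and template ID are placeholders.

```python
# Sketch of listing, reading, and renaming pipeline templates; values are placeholders.
import requests

CONSOLE_API = "http://localhost:5001/console/api"  # assumed console API prefix
HEADERS = {"Authorization": "Bearer <console-session-token>"}

# List built-in pipeline templates in English.
templates = requests.get(
    f"{CONSOLE_API}/rag/pipeline/templates",
    headers=HEADERS,
    params={"type": "built-in", "language": "en-US"},
    timeout=10,
).json()

# Fetch one template's detail, then rename a customized template.
template_id = "00000000-0000-0000-0000-000000000000"  # placeholder id
detail = requests.get(
    f"{CONSOLE_API}/rag/pipeline/templates/{template_id}",
    headers=HEADERS,
    params={"type": "built-in"},
    timeout=10,
).json()
requests.patch(
    f"{CONSOLE_API}/rag/pipeline/customized/templates/{template_id}",
    headers=HEADERS,
    json={"name": "My template", "description": "", "icon_info": None},
    timeout=10,
).raise_for_status()
print(len(templates), detail)
```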
@@ -0,0 +1,171 @@
|
||||
from flask_login import current_user # type: ignore # type: ignore
|
||||
from flask_restful import Resource, marshal, reqparse # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
)
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default={},
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||
nullable=True,
|
||||
required=False,
|
||||
default=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"partial_member_list",
|
||||
type=list,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default=[],
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"yaml_content",
|
||||
type=str,
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="yaml_content is required.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
|
||||
try:
|
||||
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||
)
|
||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
current_user.current_tenant_id,
|
||||
import_info["dataset_id"],
|
||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return import_info, 201
|
||||
|
||||
|
||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default={},
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||
nullable=True,
|
||||
required=False,
|
||||
default=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"partial_member_list",
|
||||
type=list,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default=[],
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args),
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
|
||||
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
|
||||
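A hypothetical request that creates a RAG-pipeline dataset through the two endpoints registered above. The field names mirror the parser arguments in the diff; the base URL, auth, permission value, and the DSL content are placeholders.

```python
# Sketch of creating a pipeline-backed dataset and an empty one; values are placeholders.
import requests

CONSOLE_API = "http://localhost:5001/console/api"  # assumed console API prefix
HEADERS = {"Authorization": "Bearer <console-session-token>"}

payload = {
    "name": "support-docs",
    "description": "Docs ingested through a RAG pipeline",
    "icon_info": {},
    "permission": "only_me",          # assumed DatasetPermissionEnum value
    "partial_member_list": [],
    "yaml_content": "<pipeline DSL exported elsewhere>",  # placeholder DSL
}
resp = requests.post(f"{CONSOLE_API}/rag/pipeline/dataset", headers=HEADERS, json=payload, timeout=30)
resp.raise_for_status()
print(resp.json())  # import info, including the new dataset id

# An empty pipeline dataset takes the same fields minus yaml_content.
requests.post(
    f"{CONSOLE_API}/rag/pipeline/empty-dataset",
    headers=HEADERS,
    json={"name": "empty-pipeline-dataset"},
    timeout=10,
).raise_for_status()
```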
@@ -0,0 +1,146 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class RagPipelineImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("mode", type=str, required=True, location="json")
|
||||
parser.add_argument("yaml_content", type=str, location="json")
|
||||
parser.add_argument("yaml_url", type=str, location="json")
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("pipeline_id", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Import app
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
yaml_content=args.get("yaml_content"),
|
||||
yaml_url=args.get("yaml_url"),
|
||||
pipeline_id=args.get("pipeline_id"),
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Confirm import
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_check_dependencies_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=bool, default=False, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
|
||||
|
||||
return {"data": result}, 200
|
||||
|
||||
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
RagPipelineImportApi,
|
||||
"/rag/pipelines/imports",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportConfirmApi,
|
||||
"/rag/pipelines/imports/<string:import_id>/confirm",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineExportApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/exports",
|
||||
)
|
||||
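A hypothetical import/confirm/export round trip over the routes registered above. The paths come from the diff; the base URL, auth, DSL payload, the "yaml-content" mode string, and the response field names ("status", "id", "pipeline_id") are assumptions modeled on the app DSL import service and may differ.

```python
# Sketch of the pipeline DSL import flow; mode value and response fields are assumptions.
import requests

CONSOLE_API = "http://localhost:5001/console/api"  # assumed console API prefix
HEADERS = {"Authorization": "Bearer <console-session-token>"}

# Kick off an import from inline YAML content.
result = requests.post(
    f"{CONSOLE_API}/rag/pipelines/imports",
    headers=HEADERS,
    json={"mode": "yaml-content", "yaml_content": "<pipeline DSL>"},  # placeholder DSL
    timeout=30,
).json()

# A pending status (HTTP 202) means the import needs explicit confirmation.
if result.get("status") == "pending":
    requests.post(
        f"{CONSOLE_API}/rag/pipelines/imports/{result['id']}/confirm",
        headers=HEADERS,
        timeout=30,
    ).raise_for_status()

# Export the pipeline DSL, optionally including secrets.
pipeline_id = result.get("pipeline_id", "00000000-0000-0000-0000-000000000000")
dsl = requests.get(
    f"{CONSOLE_API}/rag/pipelines/{pipeline_id}/exports",
    headers=HEADERS,
    params={"include_secret": "false"},
    timeout=10,
).json()["data"]
print(dsl[:80])
```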
File diff suppressed because it is too large
api/controllers/console/datasets/wraps.py (new file): 43 lines
@@ -0,0 +1,43 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional

from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.dataset import Pipeline


def get_rag_pipeline(
    view: Optional[Callable] = None,
):
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            if not kwargs.get("pipeline_id"):
                raise ValueError("missing pipeline_id in path parameters")

            pipeline_id = kwargs.get("pipeline_id")
            pipeline_id = str(pipeline_id)

            del kwargs["pipeline_id"]

            pipeline = (
                db.session.query(Pipeline)
                .filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
                .first()
            )

            if not pipeline:
                raise PipelineNotFoundError()

            kwargs["pipeline"] = pipeline

            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    else:
        return decorator(view)
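A minimal sketch of how a console resource might use the decorator above: the `<uuid:pipeline_id>` path parameter is swapped for a loaded, tenant-checked Pipeline instance before the handler runs. The resource class and route here are illustrative, not taken from the diff.

```python
# Illustrative resource using get_rag_pipeline; the class and route are hypothetical.
from flask_restful import Resource  # type: ignore

from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from models.dataset import Pipeline


class ExamplePipelineDetailApi(Resource):
    @get_rag_pipeline
    def get(self, pipeline: Pipeline):
        # By the time we get here, the decorator has verified the pipeline
        # exists and belongs to the current tenant.
        return {"id": pipeline.id}, 200


api.add_resource(ExamplePipelineDetailApi, "/rag/pipelines/<uuid:pipeline_id>/example")
```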
@@ -113,9 +113,9 @@ class VariableEntity(BaseModel):
    hide: bool = False
    max_length: Optional[int] = None
    options: Sequence[str] = Field(default_factory=list)
    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
    allowed_file_extensions: Sequence[str] = Field(default_factory=list)
    allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
    allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
    allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
    allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)

    @field_validator("description", mode="before")
    @classmethod
@@ -128,6 +128,16 @@ class VariableEntity(BaseModel):
        return v or []


class RagPipelineVariableEntity(VariableEntity):
    """
    Rag Pipeline Variable Entity.
    """

    tooltips: Optional[str] = None
    placeholder: Optional[str] = None
    belong_to_node_id: str


class ExternalDataVariableEntity(BaseModel):
    """
    External Data Variable Entity.
@@ -285,7 +295,7 @@ class AppConfig(BaseModel):
    tenant_id: str
    app_id: str
    app_mode: AppMode
    additional_features: AppAdditionalFeatures
    additional_features: Optional[AppAdditionalFeatures] = None
    variables: list[VariableEntity] = []
    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
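A minimal sketch of building the new variable entity with pydantic. The three added field names come from the diff; the concrete values, the inherited field names, and the "text-input" type value are assumptions about VariableEntity's own schema.

```python
# Sketch of constructing RagPipelineVariableEntity; inherited fields/values are assumptions.
from core.app.app_config.entities import RagPipelineVariableEntity

variable = RagPipelineVariableEntity.model_validate(
    {
        "variable": "page_url",            # inherited VariableEntity field (assumed)
        "label": "Page URL",
        "type": "text-input",              # assumed variable type value
        "required": True,
        "tooltips": "URL of the page to crawl",
        "placeholder": "https://example.com",
        "belong_to_node_id": "datasource-node-1",
    }
)
print(variable.belong_to_node_id)
```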
@@ -1,4 +1,4 @@
from core.app.app_config.entities import VariableEntity
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from models.workflow import Workflow


@@ -20,3 +20,19 @@ class WorkflowVariablesConfigManager:
            variables.append(VariableEntity.model_validate(variable))

        return variables

    @classmethod
    def convert_rag_pipeline_variable(cls, workflow: Workflow) -> list[RagPipelineVariableEntity]:
        """
        Convert workflow start variables to variables

        :param workflow: workflow instance
        """
        variables = []

        user_input_form = workflow.rag_pipeline_user_input_form()
        # variables
        for variable in user_input_form:
            variables.append(RagPipelineVariableEntity.model_validate(variable))

        return variables
@@ -43,11 +43,13 @@ from core.app.entities.task_entities import (
    WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
@@ -183,6 +185,14 @@ class WorkflowResponseConverter:
                provider_type=node_data.provider_type,
                provider_id=node_data.provider_id,
            )
        elif event.node_type == NodeType.DATASOURCE:
            node_data = cast(DatasourceNodeData, event.node_data)
            manager = PluginDatasourceManager()
            provider_entity = manager.fetch_datasource_provider(
                self._application_generate_entity.app_config.tenant_id,
                f"{node_data.plugin_id}/{node_data.provider_name}",
            )
            response.data.extras["icon"] = provider_entity.declaration.identity.icon

        return response
api/core/app/apps/pipeline/__init__.py (new file): 0 lines
api/core/app/apps/pipeline/generate_response_converter.py (new file): 95 lines
@@ -0,0 +1,95 @@
from collections.abc import Generator
from typing import cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
    AppStreamResponse,
    ErrorStreamResponse,
    NodeFinishStreamResponse,
    NodeStartStreamResponse,
    PingStreamResponse,
    WorkflowAppBlockingResponse,
    WorkflowAppStreamResponse,
)


class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = WorkflowAppBlockingResponse

    @classmethod
    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
        """
        Convert blocking full response.
        :param blocking_response: blocking response
        :return:
        """
        return dict(blocking_response.to_dict())

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
        :return:
        """
        return cls.convert_blocking_full_response(blocking_response)

    @classmethod
    def convert_stream_full_response(
        cls, stream_response: Generator[AppStreamResponse, None, None]
    ) -> Generator[dict | str, None, None]:
        """
        Convert stream full response.
        :param stream_response: stream response
        :return:
        """
        for chunk in stream_response:
            chunk = cast(WorkflowAppStreamResponse, chunk)
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "workflow_run_id": chunk.workflow_run_id,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.to_dict())
            yield response_chunk

    @classmethod
    def convert_stream_simple_response(
        cls, stream_response: Generator[AppStreamResponse, None, None]
    ) -> Generator[dict | str, None, None]:
        """
        Convert stream simple response.
        :param stream_response: stream response
        :return:
        """
        for chunk in stream_response:
            chunk = cast(WorkflowAppStreamResponse, chunk)
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "workflow_run_id": chunk.workflow_run_id,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
            else:
                response_chunk.update(sub_stream_response.to_dict())
            yield response_chunk
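
The converter above yields either the literal string "ping" or a plain dict per event. A minimal sketch of how a caller might frame that output for SSE follows; the `to_sse` helper and the SSE framing are illustrative and not part of the diff.

```python
import json

def to_sse(converted):
    # `converted` is assumed to be the generator returned by convert_stream_full_response(...)
    for chunk in converted:
        if chunk == "ping":
            yield "event: ping\n\n"                   # keep-alive frame
        else:
            yield f"data: {json.dumps(chunk)}\n\n"    # dict chunks become JSON data lines
```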
64   api/core/app/apps/pipeline/pipeline_config_manager.py    Normal file
@@ -0,0 +1,64 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.dataset import Pipeline
from models.model import AppMode
from models.workflow import Workflow


class PipelineConfig(WorkflowUIBasedAppConfig):
    """
    Pipeline Config Entity.
    """

    rag_pipeline_variables: list[RagPipelineVariableEntity] = []
    pass


class PipelineConfigManager(BaseAppConfigManager):
    @classmethod
    def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig:
        pipeline_config = PipelineConfig(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            app_mode=AppMode.RAG_PIPELINE,
            workflow_id=workflow.id,
            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow),
        )

        return pipeline_config

    @classmethod
    def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
        """
        Validate for pipeline config

        :param tenant_id: tenant id
        :param config: app model config args
        :param only_structure_validate: only validate the structure of the config
        """
        related_config_keys = []

        # file upload validation
        config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
        related_config_keys.extend(current_related_config_keys)

        # text_to_speech
        config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # moderation validation
        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
        )
        related_config_keys.extend(current_related_config_keys)

        related_config_keys = list(set(related_config_keys))

        # Filter out extra parameters
        filtered_config = {key: config.get(key) for key in related_config_keys}

        return filtered_config
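
A minimal usage sketch for the validator above; the tenant id and the shape of the feature sub-configs are placeholders, since their exact schemas live in the individual feature managers and are not shown in this diff.

```python
# Keys outside the "related" set collected by the feature managers are filtered out.
validated = PipelineConfigManager.config_validate(
    tenant_id="tenant-id",                      # placeholder
    config={
        "file_upload": {"enabled": False},      # assumed shape
        "text_to_speech": {"enabled": False},   # assumed shape
        "unknown_key": "dropped",
    },
)
```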
621  api/core/app/apps/pipeline/pipeline_generator.py    Normal file
@@ -0,0 +1,621 @@
import contextvars
import datetime
import json
import logging
import secrets
import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker

import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService

logger = logging.getLogger(__name__)


class PipelineGenerator(BaseAppGenerator):
    @overload
    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[True],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
    ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...

    @overload
    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[False],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
    ) -> Mapping[str, Any]: ...

    @overload
    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool,
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
        # convert to app config
        pipeline_config = PipelineConfigManager.get_pipeline_config(
            pipeline=pipeline,
            workflow=workflow,
        )
        # Add null check for dataset
        dataset = pipeline.dataset
        if not dataset:
            raise ValueError("Pipeline dataset is required")
        inputs: Mapping[str, Any] = args["inputs"]
        start_node_id: str = args["start_node_id"]
        datasource_type: str = args["datasource_type"]
        datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
        batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
        documents = []
        if invoke_from == InvokeFrom.PUBLISHED:
            for datasource_info in datasource_info_list:
                position = DocumentService.get_documents_position(dataset.id)
                document = self._build_document(
                    tenant_id=pipeline.tenant_id,
                    dataset_id=dataset.id,
                    built_in_field_enabled=dataset.built_in_field_enabled,
                    datasource_type=datasource_type,
                    datasource_info=datasource_info,
                    created_from="rag-pipeline",
                    position=position,
                    account=user,
                    batch=batch,
                    document_form=dataset.chunk_structure,
                )
                db.session.add(document)
                documents.append(document)
            db.session.commit()

        # run in child thread
        for i, datasource_info in enumerate(datasource_info_list):
            workflow_run_id = str(uuid.uuid4())
            document_id = None
            if invoke_from == InvokeFrom.PUBLISHED:
                document_id = documents[i].id
                document_pipeline_execution_log = DocumentPipelineExecutionLog(
                    document_id=document_id,
                    datasource_type=datasource_type,
                    datasource_info=json.dumps(datasource_info),
                    datasource_node_id=start_node_id,
                    input_data=inputs,
                    pipeline_id=pipeline.id,
                    created_by=user.id,
                )
                db.session.add(document_pipeline_execution_log)
                db.session.commit()
            application_generate_entity = RagPipelineGenerateEntity(
                task_id=str(uuid.uuid4()),
                app_config=pipeline_config,
                pipeline_config=pipeline_config,
                datasource_type=datasource_type,
                datasource_info=datasource_info,
                dataset_id=dataset.id,
                start_node_id=start_node_id,
                batch=batch,
                document_id=document_id,
                inputs=self._prepare_user_inputs(
                    user_inputs=inputs,
                    variables=pipeline_config.rag_pipeline_variables,
                    tenant_id=pipeline.tenant_id,
                    strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
                ),
                files=[],
                user_id=user.id,
                stream=streaming,
                invoke_from=invoke_from,
                call_depth=call_depth,
                workflow_execution_id=workflow_run_id,
            )

            contexts.plugin_tool_providers.set({})
            contexts.plugin_tool_providers_lock.set(threading.Lock())
            if invoke_from == InvokeFrom.DEBUGGER:
                workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
            else:
                workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
            # Create workflow node execution repository
            session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
            workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
                session_factory=session_factory,
                user=user,
                app_id=application_generate_entity.app_config.app_id,
                triggered_from=workflow_triggered_from,
            )

            workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
                session_factory=session_factory,
                user=user,
                app_id=application_generate_entity.app_config.app_id,
                triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
            )
            if invoke_from == InvokeFrom.DEBUGGER:
                return self._generate(
                    flask_app=current_app._get_current_object(),  # type: ignore
                    context=contextvars.copy_context(),
                    pipeline=pipeline,
                    workflow_id=workflow.id,
                    user=user,
                    application_generate_entity=application_generate_entity,
                    invoke_from=invoke_from,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
            else:
                # run in child thread
                context = contextvars.copy_context()

                worker_thread = threading.Thread(
                    target=self._generate,
                    kwargs={
                        "flask_app": current_app._get_current_object(),  # type: ignore
                        "context": context,
                        "pipeline": pipeline,
                        "workflow_id": workflow.id,
                        "user": user,
                        "application_generate_entity": application_generate_entity,
                        "invoke_from": invoke_from,
                        "workflow_execution_repository": workflow_execution_repository,
                        "workflow_node_execution_repository": workflow_node_execution_repository,
                        "streaming": streaming,
                        "workflow_thread_pool_id": workflow_thread_pool_id,
                    },
                )

                worker_thread.start()
        # return batch, dataset, documents
        return {
            "batch": batch,
            "dataset": PipelineDataset(
                id=dataset.id,
                name=dataset.name,
                description=dataset.description,
                chunk_structure=dataset.chunk_structure,
            ).model_dump(),
            "documents": [
                PipelineDocument(
                    id=document.id,
                    position=document.position,
                    data_source_type=document.data_source_type,
                    data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
                    name=document.name,
                    indexing_status=document.indexing_status,
                    error=document.error,
                    enabled=document.enabled,
                ).model_dump()
                for document in documents
            ],
        }

    def _generate(
        self,
        *,
        flask_app: Flask,
        context: contextvars.Context,
        pipeline: Pipeline,
        workflow_id: str,
        user: Union[Account, EndUser],
        application_generate_entity: RagPipelineGenerateEntity,
        invoke_from: InvokeFrom,
        workflow_execution_repository: WorkflowExecutionRepository,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        streaming: bool = True,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
        """
        Generate App response.

        :param pipeline: Pipeline
        :param workflow: Workflow
        :param user: account or end user
        :param application_generate_entity: application generate entity
        :param invoke_from: invoke from source
        :param workflow_execution_repository: repository for workflow execution
        :param workflow_node_execution_repository: repository for workflow node execution
        :param streaming: is stream
        :param workflow_thread_pool_id: workflow thread pool id
        """
        with preserve_flask_contexts(flask_app, context_vars=context):
            # init queue manager
            workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first()
            if not workflow:
                raise ValueError(f"Workflow not found: {workflow_id}")
            queue_manager = PipelineQueueManager(
                task_id=application_generate_entity.task_id,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
                app_mode=AppMode.RAG_PIPELINE,
            )
            context = contextvars.copy_context()

            # new thread
            worker_thread = threading.Thread(
                target=self._generate_worker,
                kwargs={
                    "flask_app": current_app._get_current_object(),  # type: ignore
                    "context": context,
                    "queue_manager": queue_manager,
                    "application_generate_entity": application_generate_entity,
                    "workflow_thread_pool_id": workflow_thread_pool_id,
                },
            )

            worker_thread.start()

            # return response or stream generator
            response = self._handle_response(
                application_generate_entity=application_generate_entity,
                workflow=workflow,
                queue_manager=queue_manager,
                user=user,
                workflow_execution_repository=workflow_execution_repository,
                workflow_node_execution_repository=workflow_node_execution_repository,
                stream=streaming,
            )

            return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def single_iteration_generate(
        self,
        pipeline: Pipeline,
        workflow: Workflow,
        node_id: str,
        user: Account | EndUser,
        args: Mapping[str, Any],
        streaming: bool = True,
    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
        """
        Generate App response.

        :param app_model: App
        :param workflow: Workflow
        :param node_id: the node id
        :param user: account or end user
        :param args: request args
        :param streaming: is streamed
        """
        if not node_id:
            raise ValueError("node_id is required")

        if args.get("inputs") is None:
            raise ValueError("inputs is required")

        # convert to app config
        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)

        dataset = pipeline.dataset
        if not dataset:
            raise ValueError("Pipeline dataset is required")

        # init application generate entity - use RagPipelineGenerateEntity instead
        application_generate_entity = RagPipelineGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=pipeline_config,
            pipeline_config=pipeline_config,
            datasource_type=args.get("datasource_type", ""),
            datasource_info=args.get("datasource_info", {}),
            dataset_id=dataset.id,
            batch=args.get("batch", ""),
            document_id=args.get("document_id"),
            inputs={},
            files=[],
            user_id=user.id,
            stream=streaming,
            invoke_from=InvokeFrom.DEBUGGER,
            call_depth=0,
            workflow_execution_id=str(uuid.uuid4()),
        )
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        # Create workflow node execution repository
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
        )

        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )

        return self._generate(
            flask_app=current_app._get_current_object(),  # type: ignore
            pipeline=pipeline,
            workflow_id=workflow.id,
            user=user,
            invoke_from=InvokeFrom.DEBUGGER,
            application_generate_entity=application_generate_entity,
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
            streaming=streaming,
        )

    def single_loop_generate(
        self,
        pipeline: Pipeline,
        workflow: Workflow,
        node_id: str,
        user: Account | EndUser,
        args: Mapping[str, Any],
        streaming: bool = True,
    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
        """
        Generate App response.

        :param app_model: App
        :param workflow: Workflow
        :param node_id: the node id
        :param user: account or end user
        :param args: request args
        :param streaming: is streamed
        """
        if not node_id:
            raise ValueError("node_id is required")

        if args.get("inputs") is None:
            raise ValueError("inputs is required")

        dataset = pipeline.dataset
        if not dataset:
            raise ValueError("Pipeline dataset is required")

        # convert to app config
        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)

        # init application generate entity
        application_generate_entity = RagPipelineGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=pipeline_config,
            pipeline_config=pipeline_config,
            datasource_type=args.get("datasource_type", ""),
            datasource_info=args.get("datasource_info", {}),
            batch=args.get("batch", ""),
            document_id=args.get("document_id"),
            dataset_id=dataset.id,
            inputs={},
            files=[],
            user_id=user.id,
            stream=streaming,
            invoke_from=InvokeFrom.DEBUGGER,
            extras={"auto_generate_conversation_name": False},
            single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
            workflow_execution_id=str(uuid.uuid4()),
        )
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())

        # Create workflow node execution repository
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
        )

        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )

        return self._generate(
            flask_app=current_app._get_current_object(),  # type: ignore
            pipeline=pipeline,
            workflow=workflow,
            user=user,
            invoke_from=InvokeFrom.DEBUGGER,
            application_generate_entity=application_generate_entity,
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
            streaming=streaming,
        )

    def _generate_worker(
        self,
        flask_app: Flask,
        application_generate_entity: RagPipelineGenerateEntity,
        queue_manager: AppQueueManager,
        context: contextvars.Context,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        Generate worker in a new thread.
        :param flask_app: Flask app
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param workflow_thread_pool_id: workflow thread pool id
        :return:
        """

        with preserve_flask_contexts(flask_app, context_vars=context):
            try:
                # workflow app
                runner = PipelineRunner(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )

                runner.run()
            except GenerateTaskStoppedError:
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(
                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                )
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except ValueError as e:
                if dify_config.DEBUG:
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:
                logger.exception("Unknown Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            finally:
                db.session.close()

    def _handle_response(
        self,
        application_generate_entity: RagPipelineGenerateEntity,
        workflow: Workflow,
        queue_manager: AppQueueManager,
        user: Union[Account, EndUser],
        workflow_execution_repository: WorkflowExecutionRepository,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        stream: bool = False,
    ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
        """
        Handle response.
        :param application_generate_entity: application generate entity
        :param workflow: workflow
        :param queue_manager: queue manager
        :param user: account or end user
        :param stream: is stream
        :param workflow_node_execution_repository: optional repository for workflow node execution
        :return:
        """
        # init generate task pipeline
        generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
            application_generate_entity=application_generate_entity,
            workflow=workflow,
            queue_manager=queue_manager,
            user=user,
            stream=stream,
            workflow_node_execution_repository=workflow_node_execution_repository,
            workflow_execution_repository=workflow_execution_repository,
        )

        try:
            return generate_task_pipeline.process()
        except ValueError as e:
            if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.":  # ignore this error
                raise GenerateTaskStoppedError()
            else:
                logger.exception(
                    f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
                )
                raise e

    def _build_document(
        self,
        tenant_id: str,
        dataset_id: str,
        built_in_field_enabled: bool,
        datasource_type: str,
        datasource_info: Mapping[str, Any],
        created_from: str,
        position: int,
        account: Union[Account, EndUser],
        batch: str,
        document_form: str,
    ):
        if datasource_type == "local_file":
            name = datasource_info["name"]
        elif datasource_type == "online_document":
            name = datasource_info["page"]["page_name"]
        elif datasource_type == "website_crawl":
            name = datasource_info["title"]
        else:
            raise ValueError(f"Unsupported datasource type: {datasource_type}")

        document = Document(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            position=position,
            data_source_type=datasource_type,
            data_source_info=json.dumps(datasource_info),
            batch=batch,
            name=name,
            created_from=created_from,
            created_by=account.id,
            doc_form=document_form,
        )
        doc_metadata = {}
        if built_in_field_enabled:
            doc_metadata = {
                BuiltInField.document_name: name,
                BuiltInField.uploader: account.name,
                BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
                BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
                BuiltInField.source: datasource_type,
            }
        if doc_metadata:
            document.doc_metadata = doc_metadata
        return document
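
A minimal sketch of how the entry point above might be invoked for a published run; `pipeline`, `workflow`, and `account` are assumed to be loaded elsewhere, and the datasource payload simply mirrors the args read at the top of generate().

```python
result = PipelineGenerator().generate(
    pipeline=pipeline,          # models.dataset.Pipeline, loaded by the caller
    workflow=workflow,          # published Workflow for that pipeline
    user=account,               # Account or EndUser
    args={
        "inputs": {},
        "start_node_id": "datasource-node-id",      # placeholder node id
        "datasource_type": "local_file",
        "datasource_info_list": [{"name": "report.pdf"}],
    },
    invoke_from=InvokeFrom.PUBLISHED,
    streaming=False,
)
# For PUBLISHED runs the workers run in background threads and the call returns
# the batch id plus dataset/document summaries, as built at the end of generate().
```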
44   api/core/app/apps/pipeline/pipeline_queue_manager.py    Normal file
@@ -0,0 +1,44 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
    AppQueueEvent,
    QueueErrorEvent,
    QueueMessageEndEvent,
    QueueStopEvent,
    QueueWorkflowFailedEvent,
    QueueWorkflowPartialSuccessEvent,
    QueueWorkflowSucceededEvent,
    WorkflowQueueMessage,
)


class PipelineQueueManager(AppQueueManager):
    def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
        super().__init__(task_id, user_id, invoke_from)

        self._app_mode = app_mode

    def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
        """
        Publish event to queue
        :param event:
        :param pub_from:
        :return:
        """
        message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)

        self._q.put(message)

        if isinstance(
            event,
            QueueStopEvent
            | QueueErrorEvent
            | QueueMessageEndEvent
            | QueueWorkflowSucceededEvent
            | QueueWorkflowFailedEvent
            | QueueWorkflowPartialSuccessEvent,
        ):
            self.stop_listen()

        if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
            raise GenerateTaskStoppedError()
221  api/core/app/apps/pipeline/pipeline_runner.py    Normal file
@@ -0,0 +1,221 @@
import logging
from collections.abc import Mapping
from typing import Any, Optional, cast

from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
    InvokeFrom,
    RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow, WorkflowType

logger = logging.getLogger(__name__)


class PipelineRunner(WorkflowBasedAppRunner):
    """
    Pipeline Application Runner
    """

    def __init__(
        self,
        application_generate_entity: RagPipelineGenerateEntity,
        queue_manager: AppQueueManager,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param workflow_thread_pool_id: workflow thread pool id
        """
        self.application_generate_entity = application_generate_entity
        self.queue_manager = queue_manager
        self.workflow_thread_pool_id = workflow_thread_pool_id

    def _get_app_id(self) -> str:
        return self.application_generate_entity.app_config.app_id

    def run(self) -> None:
        """
        Run application
        """
        app_config = self.application_generate_entity.app_config
        app_config = cast(PipelineConfig, app_config)

        user_id = None
        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
            if end_user:
                user_id = end_user.session_id
        else:
            user_id = self.application_generate_entity.user_id

        pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")

        workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
        if not workflow:
            raise ValueError("Workflow not initialized")

        db.session.close()

        workflow_callbacks: list[WorkflowCallback] = []
        if dify_config.DEBUG:
            workflow_callbacks.append(WorkflowLoggingCallback())

        # if only single iteration run is requested
        if self.application_generate_entity.single_iteration_run:
            # if only single iteration run is requested
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                workflow=workflow,
                node_id=self.application_generate_entity.single_iteration_run.node_id,
                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
            )
        elif self.application_generate_entity.single_loop_run:
            # if only single loop run is requested
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
                workflow=workflow,
                node_id=self.application_generate_entity.single_loop_run.node_id,
                user_inputs=self.application_generate_entity.single_loop_run.inputs,
            )
        else:
            inputs = self.application_generate_entity.inputs
            files = self.application_generate_entity.files

            # Create a variable pool.
            system_inputs = {
                SystemVariableKey.FILES: files,
                SystemVariableKey.USER_ID: user_id,
                SystemVariableKey.APP_ID: app_config.app_id,
                SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
                SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
                SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
                SystemVariableKey.BATCH: self.application_generate_entity.batch,
                SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
                SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
                SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
                SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
            }
            rag_pipeline_variables = []
            if workflow.rag_pipeline_variables:
                for v in workflow.rag_pipeline_variables:
                    rag_pipeline_variable = RAGPipelineVariable(**v)
                    if (
                        rag_pipeline_variable.belong_to_node_id
                        in (self.application_generate_entity.start_node_id, "shared")
                    ) and rag_pipeline_variable.variable in inputs:
                        rag_pipeline_variables.append(
                            RAGPipelineVariableInput(
                                variable=rag_pipeline_variable,
                                value=inputs[rag_pipeline_variable.variable],
                            )
                        )

            variable_pool = VariablePool(
                system_variables=system_inputs,
                user_inputs=inputs,
                environment_variables=workflow.environment_variables,
                conversation_variables=[],
                rag_pipeline_variables=rag_pipeline_variables,
            )

            # init graph
            graph = self._init_rag_pipeline_graph(
                graph_config=workflow.graph_dict,
                start_node_id=self.application_generate_entity.start_node_id,
            )

        # RUN WORKFLOW
        workflow_entry = WorkflowEntry(
            tenant_id=workflow.tenant_id,
            app_id=workflow.app_id,
            workflow_id=workflow.id,
            workflow_type=WorkflowType.value_of(workflow.type),
            graph=graph,
            graph_config=workflow.graph_dict,
            user_id=self.application_generate_entity.user_id,
            user_from=(
                UserFrom.ACCOUNT
                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                else UserFrom.END_USER
            ),
            invoke_from=self.application_generate_entity.invoke_from,
            call_depth=self.application_generate_entity.call_depth,
            variable_pool=variable_pool,
            thread_pool_id=self.workflow_thread_pool_id,
        )

        generator = workflow_entry.run(callbacks=workflow_callbacks)

        for event in generator:
            self._handle_event(workflow_entry, event)

    def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
        """
        Get workflow
        """
        # fetch workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
            )
            .first()
        )

        # return workflow
        return workflow

    def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
        """
        Init pipeline graph
        """
        if "nodes" not in graph_config or "edges" not in graph_config:
            raise ValueError("nodes or edges not found in workflow graph")

        if not isinstance(graph_config.get("nodes"), list):
            raise ValueError("nodes in workflow graph must be a list")

        if not isinstance(graph_config.get("edges"), list):
            raise ValueError("edges in workflow graph must be a list")
        nodes = graph_config.get("nodes", [])
        edges = graph_config.get("edges", [])
        real_run_nodes = []
        real_edges = []
        exclude_node_ids = []
        for node in nodes:
            node_id = node.get("id")
            node_type = node.get("data", {}).get("type", "")
            if node_type == "datasource":
                if start_node_id != node_id:
                    exclude_node_ids.append(node_id)
                    continue
            real_run_nodes.append(node)
        for edge in edges:
            if edge.get("source") in exclude_node_ids:
                continue
            real_edges.append(edge)
        graph_config = dict(graph_config)
        graph_config["nodes"] = real_run_nodes
        graph_config["edges"] = real_edges
        # init graph
        graph = Graph.init(graph_config=graph_config)

        if not graph:
            raise ValueError("graph not found in workflow")

        return graph
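
A standalone sketch of the pruning rule that _init_rag_pipeline_graph applies before building the graph: every datasource node other than the selected start node is dropped, along with edges that originate from the dropped nodes. The helper name and types below are illustrative, not part of the diff.

```python
def prune_graph(nodes: list[dict], edges: list[dict], start_node_id: str) -> tuple[list[dict], list[dict]]:
    # datasource nodes other than the chosen start node are excluded
    excluded = {
        n["id"]
        for n in nodes
        if n.get("data", {}).get("type") == "datasource" and n["id"] != start_node_id
    }
    kept_nodes = [n for n in nodes if n["id"] not in excluded]
    # edges whose source was excluded are dropped as well
    kept_edges = [e for e in edges if e.get("source") not in excluded]
    return kept_nodes, kept_edges
```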
@@ -36,6 +36,7 @@ class InvokeFrom(Enum):
    # DEBUGGER indicates that this invocation is from
    # the workflow (or chatflow) edit page.
    DEBUGGER = "debugger"
    PUBLISHED = "published"

    @classmethod
    def value_of(cls, value: str):
@@ -240,3 +241,38 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
        inputs: dict

    single_loop_run: Optional[SingleLoopRunEntity] = None


class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
    """
    RAG Pipeline Application Generate Entity.
    """

    # pipeline config
    pipeline_config: WorkflowUIBasedAppConfig
    datasource_type: str
    datasource_info: Mapping[str, Any]
    dataset_id: str
    batch: str
    document_id: Optional[str] = None
    start_node_id: Optional[str] = None

    class SingleIterationRunEntity(BaseModel):
        """
        Single Iteration Run Entity.
        """

        node_id: str
        inputs: dict

    single_iteration_run: Optional[SingleIterationRunEntity] = None

    class SingleLoopRunEntity(BaseModel):
        """
        Single Loop Run Entity.
        """

        node_id: str
        inputs: dict

    single_loop_run: Optional[SingleLoopRunEntity] = None
@@ -105,6 +105,14 @@ class DifyAgentCallbackHandler(BaseModel):
        self.current_loop += 1

    def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
        """Run on datasource start."""
        if dify_config.DEBUG:
            print_text(
                "\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + str(datasource_inputs) + "\n",
                color=self.color,
            )

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
33   api/core/datasource/__base/datasource_plugin.py    Normal file
@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod

from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceProviderType,
)


class DatasourcePlugin(ABC):
    entity: DatasourceEntity
    runtime: DatasourceRuntime

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
    ) -> None:
        self.entity = entity
        self.runtime = runtime

    @abstractmethod
    def datasource_provider_type(self) -> str:
        """
        returns the type of the datasource provider
        """
        return DatasourceProviderType.LOCAL_FILE

    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
        return self.__class__(
            entity=self.entity.model_copy(),
            runtime=runtime,
        )
118  api/core/datasource/__base/datasource_provider.py    Normal file
@@ -0,0 +1,118 @@
from abc import ABC, abstractmethod
from typing import Any

from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.entities.provider_entities import ProviderConfig
from core.plugin.impl.tool import PluginToolManager
from core.tools.errors import ToolProviderCredentialValidationError


class DatasourcePluginProviderController(ABC):
    entity: DatasourceProviderEntityWithPlugin
    tenant_id: str

    def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
        self.entity = entity
        self.tenant_id = tenant_id

    @property
    def need_credentials(self) -> bool:
        """
        returns whether the provider needs credentials

        :return: whether the provider needs credentials
        """
        return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0

    def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
        """
        validate the credentials of the provider
        """
        manager = PluginToolManager()
        if not manager.validate_datasource_credentials(
            tenant_id=self.tenant_id,
            user_id=user_id,
            provider=self.entity.identity.name,
            credentials=credentials,
        ):
            raise ToolProviderCredentialValidationError("Invalid credentials")

    @property
    def provider_type(self) -> DatasourceProviderType:
        """
        returns the type of the provider
        """
        return DatasourceProviderType.LOCAL_FILE

    @abstractmethod
    def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
        """
        return datasource with given name
        """
        pass

    def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
        """
        validate the format of the credentials of the provider and set the default value if needed

        :param credentials: the credentials of the tool
        """
        credentials_schema = dict[str, ProviderConfig]()
        if credentials_schema is None:
            return

        for credential in self.entity.credentials_schema:
            credentials_schema[credential.name] = credential

        credentials_need_to_validate: dict[str, ProviderConfig] = {}
        for credential_name in credentials_schema:
            credentials_need_to_validate[credential_name] = credentials_schema[credential_name]

        for credential_name in credentials:
            if credential_name not in credentials_need_to_validate:
                raise ToolProviderCredentialValidationError(
                    f"credential {credential_name} not found in provider {self.entity.identity.name}"
                )

            # check type
            credential_schema = credentials_need_to_validate[credential_name]
            if not credential_schema.required and credentials[credential_name] is None:
                continue

            if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
                if not isinstance(credentials[credential_name], str):
                    raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")

            elif credential_schema.type == ProviderConfig.Type.SELECT:
                if not isinstance(credentials[credential_name], str):
                    raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")

                options = credential_schema.options
                if not isinstance(options, list):
                    raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")

                if credentials[credential_name] not in [x.value for x in options]:
                    raise ToolProviderCredentialValidationError(
                        f"credential {credential_name} should be one of {options}"
                    )

            credentials_need_to_validate.pop(credential_name)

        for credential_name in credentials_need_to_validate:
            credential_schema = credentials_need_to_validate[credential_name]
            if credential_schema.required:
                raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")

            # the credential is not set currently, set the default value if needed
            if credential_schema.default is not None:
                default_value = credential_schema.default
                # parse default value into the correct type
                if credential_schema.type in {
                    ProviderConfig.Type.SECRET_INPUT,
                    ProviderConfig.Type.TEXT_INPUT,
                    ProviderConfig.Type.SELECT,
                }:
                    default_value = str(default_value)

                credentials[credential_name] = default_value
36   api/core/datasource/__base/datasource_runtime.py    Normal file
@@ -0,0 +1,36 @@
from typing import Any, Optional

from openai import BaseModel
from pydantic import Field

from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom


class DatasourceRuntime(BaseModel):
    """
    Meta data of a datasource call processing
    """

    tenant_id: str
    datasource_id: Optional[str] = None
    invoke_from: Optional[InvokeFrom] = None
    datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
    credentials: dict[str, Any] = Field(default_factory=dict)
    runtime_parameters: dict[str, Any] = Field(default_factory=dict)


class FakeDatasourceRuntime(DatasourceRuntime):
    """
    Fake datasource runtime for testing
    """

    def __init__(self):
        super().__init__(
            tenant_id="fake_tenant_id",
            datasource_id="fake_datasource_id",
            invoke_from=InvokeFrom.DEBUGGER,
            datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
            credentials={},
            runtime_parameters={},
        )
0    api/core/datasource/__init__.py    Normal file
244  api/core/datasource/datasource_file_manager.py    Normal file
@@ -0,0 +1,244 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4

import httpx

from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.enums import CreatorUserRole
from models.model import MessageFile, UploadFile
from models.tools import ToolFile

logger = logging.getLogger(__name__)


class DatasourceFileManager:
    @staticmethod
    def sign_file(datasource_file_id: str, extension: str) -> str:
        """
        sign file to get a temporary url
        """
        base_url = dify_config.FILES_URL
        file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}"

        timestamp = str(int(time.time()))
        nonce = os.urandom(16).hex()
        data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        encoded_sign = base64.urlsafe_b64encode(sign).decode()

        return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"

    @staticmethod
    def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
        """
        verify signature
        """
        data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
        recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()

        # verify signature
        if sign != recalculated_encoded_sign:
            return False

        current_time = int(time.time())
        return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

    @staticmethod
    def create_file_by_raw(
        *,
        user_id: str,
        tenant_id: str,
        conversation_id: Optional[str],
        file_binary: bytes,
        mimetype: str,
        filename: Optional[str] = None,
    ) -> UploadFile:
        extension = guess_extension(mimetype) or ".bin"
        unique_name = uuid4().hex
        unique_filename = f"{unique_name}{extension}"
        # default just as before
        present_filename = unique_filename
        if filename is not None:
            has_extension = len(filename.split(".")) > 1
            # Add extension flexibly
            present_filename = filename if has_extension else f"{filename}{extension}"
        filepath = f"datasources/{tenant_id}/{unique_filename}"
        storage.save(filepath, file_binary)

        upload_file = UploadFile(
            tenant_id=tenant_id,
            storage_type=dify_config.STORAGE_TYPE,
            key=filepath,
            name=present_filename,
            size=len(file_binary),
            extension=extension,
            mime_type=mimetype,
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=user_id,
            used=False,
            hash=hashlib.sha3_256(file_binary).hexdigest(),
            source_url="",
        )

        db.session.add(upload_file)
        db.session.commit()
        db.session.refresh(upload_file)

        return upload_file

    @staticmethod
    def create_file_by_url(
        user_id: str,
        tenant_id: str,
        file_url: str,
        conversation_id: Optional[str] = None,
    ) -> UploadFile:
        # try to download image
        try:
            response = ssrf_proxy.get(file_url)
            response.raise_for_status()
            blob = response.content
        except httpx.TimeoutException:
            raise ValueError(f"timeout when downloading file from {file_url}")

        mimetype = (
            guess_type(file_url)[0]
            or response.headers.get("Content-Type", "").split(";")[0].strip()
            or "application/octet-stream"
        )
        extension = guess_extension(mimetype) or ".bin"
        unique_name = uuid4().hex
        filename = f"{unique_name}{extension}"
        filepath = f"tools/{tenant_id}/{filename}"
        storage.save(filepath, blob)

        upload_file = UploadFile(
            tenant_id=tenant_id,
            storage_type=dify_config.STORAGE_TYPE,
            key=filepath,
            name=filename,
            size=len(blob),
            extension=extension,
            mime_type=mimetype,
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=user_id,
            used=False,
            hash=hashlib.sha3_256(blob).hexdigest(),
            source_url=file_url,
        )

        db.session.add(upload_file)
        db.session.commit()

        return upload_file

    @staticmethod
    def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
        """
        get file binary

        :param id: the id of the file

        :return: the binary of the file, mime type
        """
        upload_file: UploadFile | None = (
            db.session.query(UploadFile)
            .filter(
                UploadFile.id == id,
            )
            .first()
        )

        if not upload_file:
            return None

        blob = storage.load_once(upload_file.key)

        return blob, upload_file.mime_type

    @staticmethod
    def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
        """
        get file binary

        :param id: the id of the file

        :return: the binary of the file, mime type
        """
        message_file: MessageFile | None = (
            db.session.query(MessageFile)
            .filter(
                MessageFile.id == id,
            )
            .first()
        )

        # Check if message_file is not None
        if message_file is not None:
            # get tool file id
            if message_file.url is not None:
                tool_file_id = message_file.url.split("/")[-1]
                # trim extension
                tool_file_id = tool_file_id.split(".")[0]
            else:
                tool_file_id = None
        else:
            tool_file_id = None

        tool_file: ToolFile | None = (
            db.session.query(ToolFile)
            .filter(
                ToolFile.id == tool_file_id,
            )
            .first()
        )

        if not tool_file:
            return None

        blob = storage.load_once(tool_file.file_key)

        return blob, tool_file.mimetype

    @staticmethod
    def get_file_generator_by_upload_file_id(upload_file_id: str):
        """
        get file binary

        :param tool_file_id: the id of the tool file

        :return: the binary of the file, mime type
        """
        upload_file: UploadFile | None = (
            db.session.query(UploadFile)
            .filter(
                UploadFile.id == upload_file_id,
            )
            .first()
        )

        if not upload_file:
            return None, None

        stream = storage.load_stream(upload_file.key)

        return stream, upload_file.mime_type


# init tool_file_parser
# from core.file.datasource_file_parser import datasource_file_manager
#
# datasource_file_manager["manager"] = DatasourceFileManager
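
A minimal sketch of the sign/verify round trip defined above; the datasource file id is a placeholder, and the query-string parsing is only there to pull the generated parameters back out of the signed URL.

```python
from urllib.parse import parse_qs, urlparse

url = DatasourceFileManager.sign_file("datasource-file-id", ".pdf")   # placeholder id
params = {k: v[0] for k, v in parse_qs(urlparse(url).query).items()}
# The same HMAC is recomputed server-side and the timestamp is checked against FILES_ACCESS_TIMEOUT.
assert DatasourceFileManager.verify_file(
    "datasource-file-id", params["timestamp"], params["nonce"], params["sign"]
)
```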
100  api/core/datasource/datasource_manager.py    Normal file
@@ -0,0 +1,100 @@
|
||||
import logging
|
||||
from threading import Lock
|
||||
from typing import Union
|
||||
|
||||
import contexts
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.entities.common_entities import I18nObject
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
||||
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasourceManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_datasource_plugin_provider(
|
||||
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
|
||||
) -> DatasourcePluginProviderController:
|
||||
"""
|
||||
get the datasource plugin provider
|
||||
"""
|
||||
# check if context is set
|
||||
try:
|
||||
contexts.datasource_plugin_providers.get()
|
||||
except LookupError:
|
||||
contexts.datasource_plugin_providers.set({})
|
||||
contexts.datasource_plugin_providers_lock.set(Lock())
|
||||
|
||||
with contexts.datasource_plugin_providers_lock.get():
|
||||
datasource_plugin_providers = contexts.datasource_plugin_providers.get()
|
||||
if provider_id in datasource_plugin_providers:
|
||||
return datasource_plugin_providers[provider_id]
|
||||
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
|
||||
if not provider_entity:
|
||||
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
|
||||
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
|
||||
datasource_plugin_providers[provider_id] = controller
|
||||
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def get_datasource_runtime(
|
||||
cls,
|
||||
provider_id: str,
|
||||
datasource_name: str,
|
||||
tenant_id: str,
|
||||
datasource_type: DatasourceProviderType,
|
||||
) -> DatasourcePlugin:
|
||||
"""
|
||||
get the datasource runtime
|
||||
|
||||
        :param datasource_type: the type of the datasource provider
|
||||
:param provider_id: the id of the provider
|
||||
:param datasource_name: the name of the datasource
|
||||
:param tenant_id: the tenant id
|
||||
|
||||
:return: the datasource plugin
|
||||
"""
|
||||
return cls.get_datasource_plugin_provider(
|
||||
provider_id,
|
||||
tenant_id,
|
||||
datasource_type,
|
||||
).get_datasource(datasource_name)
|
||||
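A minimal sketch (not part of this diff) of resolving a datasource runtime through the manager above; the provider id, datasource name, and tenant id values are made up for illustration.

from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import DatasourceProviderType

datasource = DatasourceManager.get_datasource_runtime(
    provider_id="langgenius/notion_datasource/notion",  # assumed provider id
    datasource_name="notion_pages",                     # assumed datasource name
    tenant_id="your-tenant-id",                         # assumed tenant id
    datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
# For ONLINE_DOCUMENT the controller built above returns an OnlineDocumentDatasourcePlugin;
# note that ONLINE_DRIVE is not handled by the match statement yet and would raise ValueError.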
api/core/datasource/entities/api_entities.py (new file, 71 lines)
@@ -0,0 +1,71 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceParameter
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class DatasourceApiEntity(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[DatasourceParameter]] = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
|
||||
|
||||
class DatasourceProviderApiEntity(BaseModel):
|
||||
id: str
|
||||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str | dict
|
||||
label: I18nObject # label
|
||||
type: str
|
||||
masked_credentials: Optional[dict] = None
|
||||
original_credentials: Optional[dict] = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
|
||||
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("datasources", mode="before")
|
||||
@classmethod
|
||||
def convert_none_to_empty_list(cls, v):
|
||||
return v if v is not None else []
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# -------------
|
||||
# overwrite datasource parameter types for temp fix
|
||||
datasources = jsonable_encoder(self.datasources)
|
||||
for datasource in datasources:
|
||||
if datasource.get("parameters"):
|
||||
for parameter in datasource.get("parameters"):
|
||||
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
|
||||
parameter["type"] = "files"
|
||||
# -------------
|
||||
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
"name": self.name,
|
||||
"plugin_id": self.plugin_id,
|
||||
"plugin_unique_identifier": self.plugin_unique_identifier,
|
||||
"description": self.description.to_dict(),
|
||||
"icon": self.icon,
|
||||
"label": self.label.to_dict(),
|
||||
"type": self.type.value,
|
||||
"team_credentials": self.masked_credentials,
|
||||
"is_team_authorization": self.is_team_authorization,
|
||||
"allow_delete": self.allow_delete,
|
||||
"datasources": datasources,
|
||||
"labels": self.labels,
|
||||
}
|
||||
api/core/datasource/entities/common_entities.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from typing import Optional

from pydantic import BaseModel, Field


class I18nObject(BaseModel):
    """
    Model class for i18n object.
    """

    en_US: str
    zh_Hans: Optional[str] = Field(default=None)
    pt_BR: Optional[str] = Field(default=None)
    ja_JP: Optional[str] = Field(default=None)

    def __init__(self, **data):
        super().__init__(**data)
        self.zh_Hans = self.zh_Hans or self.en_US
        self.pt_BR = self.pt_BR or self.en_US
        self.ja_JP = self.ja_JP or self.en_US

    def to_dict(self) -> dict:
        return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
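A quick illustration of the locale fallback implemented in __init__ above: any locale left unset is backfilled with the en_US value.

from core.datasource.entities.common_entities import I18nObject

label = I18nObject(en_US="Web Crawler")
# zh_Hans, pt_BR and ja_JP were not provided, so they fall back to en_US
assert label.to_dict() == {
    "zh_Hans": "Web Crawler",
    "en_US": "Web Crawler",
    "pt_BR": "Web Crawler",
    "ja_JP": "Web Crawler",
}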
api/core/datasource/entities/datasource_entities.py (new file, 361 lines)
@@ -0,0 +1,361 @@
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.oauth import OAuthSchema
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
PluginParameterType,
|
||||
as_normal_type,
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum
|
||||
|
||||
|
||||
class DatasourceProviderType(enum.StrEnum):
|
||||
"""
|
||||
Enum class for datasource provider
|
||||
"""
|
||||
|
||||
ONLINE_DOCUMENT = "online_document"
|
||||
LOCAL_FILE = "local_file"
|
||||
WEBSITE_CRAWL = "website_crawl"
|
||||
ONLINE_DRIVE = "online_drive"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "DatasourceProviderType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class DatasourceParameter(PluginParameter):
|
||||
"""
|
||||
Overrides type
|
||||
"""
|
||||
|
||||
class DatasourceParameterType(enum.StrEnum):
|
||||
"""
|
||||
removes TOOLS_SELECTOR from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = PluginParameterType.STRING.value
|
||||
NUMBER = PluginParameterType.NUMBER.value
|
||||
BOOLEAN = PluginParameterType.BOOLEAN.value
|
||||
SELECT = PluginParameterType.SELECT.value
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
|
||||
FILE = PluginParameterType.FILE.value
|
||||
FILES = PluginParameterType.FILES.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
type: DatasourceParameterType = Field(..., description="The type of the parameter")
|
||||
description: I18nObject = Field(..., description="The description of the parameter")
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(
|
||||
cls,
|
||||
name: str,
|
||||
typ: DatasourceParameterType,
|
||||
required: bool,
|
||||
options: Optional[list[str]] = None,
|
||||
) -> "DatasourceParameter":
|
||||
"""
|
||||
get a simple datasource parameter
|
||||
|
||||
:param name: the name of the parameter
|
||||
|
||||
:param typ: the type of the parameter
|
||||
:param required: if the parameter is required
|
||||
:param options: the options of the parameter
|
||||
"""
|
||||
# convert options to ToolParameterOption
|
||||
# FIXME fix the type error
|
||||
if options:
|
||||
option_objs = [
|
||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
]
|
||||
else:
|
||||
option_objs = []
|
||||
|
||||
return cls(
|
||||
name=name,
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
placeholder=None,
|
||||
type=typ,
|
||||
required=required,
|
||||
options=option_objs,
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
|
||||
def init_frontend_parameter(self, value: Any):
|
||||
return init_frontend_parameter(self, self.type, value)
|
||||
|
||||
|
||||
class DatasourceIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the datasource")
|
||||
name: str = Field(..., description="The name of the datasource")
|
||||
label: I18nObject = Field(..., description="The label of the datasource")
|
||||
provider: str = Field(..., description="The provider of the datasource")
|
||||
icon: Optional[str] = None
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
identity: DatasourceIdentity
|
||||
parameters: list[DatasourceParameter] = Field(default_factory=list)
|
||||
description: I18nObject = Field(..., description="The label of the datasource")
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
|
||||
return v or []
|
||||
|
||||
|
||||
class DatasourceProviderIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
description: I18nObject = Field(..., description="The description of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
tags: Optional[list[ToolLabelEnum]] = Field(
|
||||
default=[],
|
||||
description="The tags of the tool",
|
||||
)
|
||||
|
||||
|
||||
class DatasourceProviderEntity(BaseModel):
|
||||
"""
|
||||
Datasource provider entity
|
||||
"""
|
||||
|
||||
identity: DatasourceProviderIdentity
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
oauth_schema: Optional[OAuthSchema] = None
|
||||
provider_type: DatasourceProviderType
|
||||
|
||||
|
||||
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
|
||||
datasources: list[DatasourceEntity] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DatasourceInvokeMeta(BaseModel):
|
||||
"""
|
||||
Datasource invoke meta
|
||||
"""
|
||||
|
||||
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||
error: Optional[str] = None
|
||||
tool_config: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "DatasourceInvokeMeta":
|
||||
"""
|
||||
Get an empty instance of DatasourceInvokeMeta
|
||||
"""
|
||||
return cls(time_cost=0.0, error=None, tool_config={})
|
||||
|
||||
@classmethod
|
||||
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
|
||||
"""
|
||||
Get an instance of DatasourceInvokeMeta with error
|
||||
"""
|
||||
return cls(time_cost=0.0, error=error, tool_config={})
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"time_cost": self.time_cost,
|
||||
"error": self.error,
|
||||
"tool_config": self.tool_config,
|
||||
}
|
||||
|
||||
|
||||
class DatasourceLabel(BaseModel):
|
||||
"""
|
||||
Datasource label
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class DatasourceInvokeFrom(Enum):
|
||||
"""
|
||||
Enum class for datasource invoke
|
||||
"""
|
||||
|
||||
RAG_PIPELINE = "rag_pipeline"
|
||||
|
||||
|
||||
class OnlineDocumentPage(BaseModel):
|
||||
"""
|
||||
Online document page
|
||||
"""
|
||||
|
||||
page_id: str = Field(..., description="The page id")
|
||||
page_name: str = Field(..., description="The page title")
|
||||
page_icon: Optional[dict] = Field(None, description="The page icon")
|
||||
type: str = Field(..., description="The type of the page")
|
||||
last_edited_time: str = Field(..., description="The last edited time")
|
||||
parent_id: Optional[str] = Field(None, description="The parent page id")
|
||||
|
||||
|
||||
class OnlineDocumentInfo(BaseModel):
|
||||
"""
|
||||
Online document info
|
||||
"""
|
||||
|
||||
workspace_id: str = Field(..., description="The workspace id")
|
||||
workspace_name: str = Field(..., description="The workspace name")
|
||||
workspace_icon: str = Field(..., description="The workspace icon")
|
||||
total: int = Field(..., description="The total number of documents")
|
||||
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
|
||||
|
||||
|
||||
class OnlineDocumentPagesMessage(BaseModel):
|
||||
"""
|
||||
Get online document pages response
|
||||
"""
|
||||
|
||||
result: list[OnlineDocumentInfo]
|
||||
|
||||
|
||||
class GetOnlineDocumentPageContentRequest(BaseModel):
|
||||
"""
|
||||
Get online document page content request
|
||||
"""
|
||||
|
||||
workspace_id: str = Field(..., description="The workspace id")
|
||||
page_id: str = Field(..., description="The page id")
|
||||
type: str = Field(..., description="The type of the page")
|
||||
|
||||
|
||||
class OnlineDocumentPageContent(BaseModel):
|
||||
"""
|
||||
Online document page content
|
||||
"""
|
||||
|
||||
workspace_id: str = Field(..., description="The workspace id")
|
||||
page_id: str = Field(..., description="The page id")
|
||||
content: str = Field(..., description="The content of the page")
|
||||
|
||||
|
||||
class GetOnlineDocumentPageContentResponse(BaseModel):
|
||||
"""
|
||||
Get online document page content response
|
||||
"""
|
||||
|
||||
result: OnlineDocumentPageContent
|
||||
|
||||
|
||||
class GetWebsiteCrawlRequest(BaseModel):
|
||||
"""
|
||||
Get website crawl request
|
||||
"""
|
||||
|
||||
crawl_parameters: dict = Field(..., description="The crawl parameters")
|
||||
|
||||
|
||||
class WebSiteInfoDetail(BaseModel):
|
||||
source_url: str = Field(..., description="The url of the website")
|
||||
content: str = Field(..., description="The content of the website")
|
||||
title: str = Field(..., description="The title of the website")
|
||||
description: str = Field(..., description="The description of the website")
|
||||
|
||||
|
||||
class WebSiteInfo(BaseModel):
|
||||
"""
|
||||
Website info
|
||||
"""
|
||||
|
||||
status: Optional[str] = Field(..., description="crawl job status")
|
||||
web_info_list: Optional[list[WebSiteInfoDetail]] = []
|
||||
total: Optional[int] = Field(default=0, description="The total number of websites")
|
||||
completed: Optional[int] = Field(default=0, description="The number of completed websites")
|
||||
|
||||
|
||||
class WebsiteCrawlMessage(BaseModel):
|
||||
"""
|
||||
Get website crawl response
|
||||
"""
|
||||
|
||||
result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
|
||||
|
||||
|
||||
class DatasourceMessage(ToolInvokeMessage):
|
||||
pass
|
||||
|
||||
|
||||
#########################
|
||||
# Online drive file
|
||||
#########################
|
||||
|
||||
|
||||
class OnlineDriveFile(BaseModel):
|
||||
"""
|
||||
    Online drive file
|
||||
"""
|
||||
|
||||
key: str = Field(..., description="The key of the file")
|
||||
size: int = Field(..., description="The size of the file")
|
||||
|
||||
|
||||
class OnlineDriveFileBucket(BaseModel):
|
||||
"""
|
||||
    Online drive file bucket
|
||||
"""
|
||||
|
||||
bucket: Optional[str] = Field(None, description="The bucket of the file")
|
||||
files: list[OnlineDriveFile] = Field(..., description="The files of the bucket")
|
||||
is_truncated: bool = Field(False, description="Whether the bucket has more files")
|
||||
|
||||
|
||||
class OnlineDriveBrowseFilesRequest(BaseModel):
|
||||
"""
|
||||
    Get online drive file list request
|
||||
"""
|
||||
|
||||
prefix: Optional[str] = Field(None, description="File path prefix for filtering eg: 'docs/dify/'")
|
||||
bucket: Optional[str] = Field(None, description="Storage bucket name")
|
||||
max_keys: int = Field(20, description="Maximum number of files to return")
|
||||
start_after: Optional[str] = Field(
|
||||
None, description="Pagination token for continuing from a specific file eg: 'docs/dify/1.txt'"
|
||||
)
|
||||
|
||||
|
||||
class OnlineDriveBrowseFilesResponse(BaseModel):
|
||||
"""
|
||||
    Get online drive file list response
|
||||
"""
|
||||
|
||||
result: list[OnlineDriveFileBucket] = Field(..., description="The bucket of the files")
|
||||
|
||||
|
||||
class OnlineDriveDownloadFileRequest(BaseModel):
|
||||
"""
|
||||
    Get online drive file
|
||||
"""
|
||||
|
||||
key: str = Field(..., description="The name of the file")
|
||||
bucket: Optional[str] = Field(None, description="The name of the bucket")
|
||||
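A small sketch of the entity helpers above; the parameter name is made up, and the sketch assumes the parent PluginParameter accepts the fields that get_simple_instance passes, as the classmethod implies.

from core.datasource.entities.datasource_entities import DatasourceParameter, DatasourceProviderType

# value_of maps a raw string to the enum member and raises ValueError for unknown values
assert DatasourceProviderType.value_of("website_crawl") is DatasourceProviderType.WEBSITE_CRAWL

# get_simple_instance builds a parameter with empty labels/descriptions and optional select options
param = DatasourceParameter.get_simple_instance(
    name="url",  # assumed parameter name
    typ=DatasourceParameter.DatasourceParameterType.STRING,
    required=True,
)
assert param.type == DatasourceParameter.DatasourceParameterType.STRING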
api/core/datasource/errors.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta


class DatasourceProviderNotFoundError(ValueError):
    pass


class DatasourceNotFoundError(ValueError):
    pass


class DatasourceParameterValidationError(ValueError):
    pass


class DatasourceProviderCredentialValidationError(ValueError):
    pass


class DatasourceNotSupportedError(ValueError):
    pass


class DatasourceInvokeError(ValueError):
    pass


class DatasourceApiSchemaError(ValueError):
    pass


class DatasourceEngineInvokeError(Exception):
    meta: DatasourceInvokeMeta

    def __init__(self, meta, **kwargs):
        self.meta = meta
        super().__init__(**kwargs)
api/core/datasource/local_file/local_file_plugin.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceProviderType,
)


class LocalFileDatasourcePlugin(DatasourcePlugin):
    tenant_id: str
    icon: str
    plugin_unique_identifier: str

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
        tenant_id: str,
        icon: str,
        plugin_unique_identifier: str,
    ) -> None:
        super().__init__(entity, runtime)
        self.tenant_id = tenant_id
        self.icon = icon
        self.plugin_unique_identifier = plugin_unique_identifier

    def datasource_provider_type(self) -> str:
        return DatasourceProviderType.LOCAL_FILE
api/core/datasource/local_file/local_file_provider.py (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
|
||||
|
||||
|
||||
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return LocalFileDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
api/core/datasource/online_document/online_document_plugin.py (new file, 73 lines)
@@ -0,0 +1,73 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceMessage,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDocumentPagesMessage,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
|
||||
class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def get_online_document_pages(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDocumentPagesMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.get_online_document_pages(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def get_online_document_page_content(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: GetOnlineDocumentPageContentRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.get_online_document_page_content(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
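A usage sketch (not part of this diff) for the generator methods above, assuming `plugin` is an OnlineDocumentDatasourcePlugin obtained as in the DatasourceManager sketch earlier; the user id and parameters are placeholders.

from core.datasource.entities.datasource_entities import DatasourceProviderType

for message in plugin.get_online_document_pages(
    user_id="user-id",          # assumed value
    datasource_parameters={},   # provider-specific parameters, assumed empty here
    provider_type=DatasourceProviderType.ONLINE_DOCUMENT.value,
):
    # each message is an OnlineDocumentPagesMessage listing workspaces and their pages
    for workspace in message.result:
        for page in workspace.pages:
            print(workspace.workspace_name, page.page_name, page.page_id)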
api/core/datasource/online_document/online_document_provider.py (new file, 48 lines)
@@ -0,0 +1,48 @@
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
|
||||
|
||||
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return OnlineDocumentDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
api/core/datasource/online_drive/online_drive_plugin.py (new file, 73 lines)
@@ -0,0 +1,73 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceMessage,
|
||||
DatasourceProviderType,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
OnlineDriveBrowseFilesResponse,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
|
||||
class OnlineDriveDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def online_drive_browse_files(
|
||||
self,
|
||||
user_id: str,
|
||||
request: OnlineDriveBrowseFilesRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.online_drive_browse_files(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
request=request,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def online_drive_download_file(
|
||||
self,
|
||||
user_id: str,
|
||||
request: OnlineDriveDownloadFileRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.online_drive_download_file(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
request=request,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.ONLINE_DRIVE
|
||||
api/core/datasource/online_drive/online_drive_provider.py (new file, 48 lines)
@@ -0,0 +1,48 @@
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
|
||||
|
||||
class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> OnlineDriveDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return OnlineDriveDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
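A paging sketch (not part of this diff) for the browse API above, assuming `plugin` is an OnlineDriveDatasourcePlugin instance; how the plugin is obtained is left open here, since the DatasourceManager match statement does not branch on ONLINE_DRIVE yet. The prefix and ids are placeholders.

from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,
    OnlineDriveBrowseFilesRequest,
)

request = OnlineDriveBrowseFilesRequest(prefix="docs/", max_keys=20)  # assumed prefix
for response in plugin.online_drive_browse_files(
    user_id="user-id",  # assumed value
    request=request,
    provider_type=DatasourceProviderType.ONLINE_DRIVE.value,
):
    for bucket in response.result:
        for file in bucket.files:
            print(bucket.bucket, file.key, file.size)
        if bucket.is_truncated and bucket.files:
            # a follow-up request would pass start_after=bucket.files[-1].key to continue listing
            pass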
api/core/datasource/utils/__init__.py (new file, 0 lines)
api/core/datasource/utils/configuration.py (new file, 265 lines)
@@ -0,0 +1,265 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
|
||||
|
||||
class ProviderConfigEncrypter(BaseModel):
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_type: str
|
||||
provider_identity: str
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
        # get fields that need to be encrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
        # get fields that need to be masked
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
data = self._deep_copy(data)
|
||||
        # get fields that need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
try:
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cache.set(data)
|
||||
return data
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
|
||||
class ToolParameterConfigurationManager:
|
||||
"""
|
||||
Tool parameter configuration manager
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
tool_runtime: Tool
|
||||
provider_name: str
|
||||
provider_type: ToolProviderType
|
||||
identity_id: str
|
||||
|
||||
def __init__(
|
||||
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.tool_runtime = tool_runtime
|
||||
self.provider_name = provider_name
|
||||
self.provider_type = provider_type
|
||||
self.identity_id = identity_id
|
||||
|
||||
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
deep copy parameters
|
||||
"""
|
||||
return deepcopy(parameters)
|
||||
|
||||
def _merge_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
merge parameters
|
||||
"""
|
||||
# get tool parameters
|
||||
tool_parameters = self.tool_runtime.entity.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = self.tool_runtime.get_runtime_parameters()
|
||||
# override parameters
|
||||
current_parameters = tool_parameters.copy()
|
||||
for runtime_parameter in runtime_parameters:
|
||||
found = False
|
||||
for index, parameter in enumerate(current_parameters):
|
||||
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||
current_parameters[index] = runtime_parameter
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
return current_parameters
|
||||
|
||||
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool parameters
|
||||
|
||||
return a deep copy of parameters with masked values
|
||||
"""
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
if len(parameters[parameter.name]) > 6:
|
||||
parameters[parameter.name] = (
|
||||
parameters[parameter.name][:2]
|
||||
+ "*" * (len(parameters[parameter.name]) - 4)
|
||||
+ parameters[parameter.name][-2:]
|
||||
)
|
||||
else:
|
||||
parameters[parameter.name] = "*" * len(parameters[parameter.name])
|
||||
|
||||
return parameters
|
||||
|
||||
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
encrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with encrypted values
|
||||
"""
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
parameters[parameter.name] = encrypted
|
||||
|
||||
return parameters
|
||||
|
||||
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with decrypted values
|
||||
"""
|
||||
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
cached_parameters = cache.get()
|
||||
if cached_parameters:
|
||||
return cached_parameters
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
has_secret_input = False
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
try:
|
||||
has_secret_input = True
|
||||
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if has_secret_input:
|
||||
cache.set(parameters)
|
||||
|
||||
return parameters
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
cache.delete()
|
||||
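A sketch of the masking rule implemented above: secret values longer than six characters keep their first and last two characters. The BasicProviderConfig constructor arguments used here are assumptions for illustration.

from core.datasource.utils.configuration import ProviderConfigEncrypter
from core.entities.provider_entities import BasicProviderConfig

encrypter_util = ProviderConfigEncrypter(
    tenant_id="your-tenant-id",  # assumed tenant id
    config=[BasicProviderConfig(name="api_key", type=BasicProviderConfig.Type.SECRET_INPUT)],  # assumed ctor args
    provider_type="online_document",
    provider_identity="notion",  # assumed identity
)
masked = encrypter_util.mask_tool_credentials({"api_key": "sk-1234567890"})
# -> {"api_key": "sk*********90"}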
api/core/datasource/utils/message_transformer.py (new file, 121 lines)
@@ -0,0 +1,121 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension
|
||||
from typing import Optional
|
||||
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasourceFileMessageTransformer:
|
||||
@classmethod
|
||||
def transform_datasource_invoke_messages(
|
||||
cls,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
"""
|
||||
Transform datasource message and handle file download
|
||||
"""
|
||||
for message in messages:
|
||||
if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}:
|
||||
yield message
|
||||
elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance(
|
||||
message.message, DatasourceMessage.TextMessage
|
||||
):
|
||||
# try to download image
|
||||
try:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
file = DatasourceFileManager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=message.message.text,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"
|
||||
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
except Exception as e:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.TEXT,
|
||||
message=DatasourceMessage.TextMessage(
|
||||
text=f"Failed to download image: {message.message.text}: {e}"
|
||||
),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
meta = message.meta or {}
|
||||
|
||||
mimetype = meta.get("mime_type", "application/octet-stream")
|
||||
# get filename from meta
|
||||
filename = meta.get("file_name", None)
|
||||
|
||||
|
||||
if not isinstance(message.message, DatasourceMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
file = DatasourceFileManager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_binary=message.message.blob,
|
||||
mimetype=mimetype,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type))
|
||||
|
||||
# check if file is image
|
||||
if "image" in mimetype:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BINARY_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
meta = message.meta or {}
|
||||
file = meta.get("file", None)
|
||||
if isinstance(file, File):
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file.related_id is not None
|
||||
url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
|
||||
if file.type == FileType.IMAGE:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield message
|
||||
else:
|
||||
yield message
|
||||
|
||||
@classmethod
|
||||
def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
|
||||
return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"
|
||||
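The URL helper above simply concatenates the datasource file id and its extension, falling back to ".bin" when no extension is known; for example:

from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer

url = DatasourceFileMessageTransformer.get_datasource_file_url(datasource_file_id="abc123", extension=".pdf")
assert url == "/files/datasources/abc123.pdf"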
api/core/datasource/utils/parser.py (new file, 389 lines)
@@ -0,0 +1,389 @@
|
||||
import re
|
||||
import uuid
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
from yaml import YAMLError, safe_load # type: ignore
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
# set description to extra_info
|
||||
extra_info["description"] = openapi["info"].get("description", "")
|
||||
|
||||
if len(openapi["servers"]) == 0:
|
||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||
|
||||
server_url = openapi["servers"][0]["url"]
|
||||
request_env = request.headers.get("X-Request-Env")
|
||||
if request_env:
|
||||
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
|
||||
server_url = matched_servers[0] if matched_servers else server_url
|
||||
|
||||
# list all interfaces
|
||||
interfaces = []
|
||||
for path, path_item in openapi["paths"].items():
|
||||
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||
for method in methods:
|
||||
if method in path_item:
|
||||
interfaces.append(
|
||||
{
|
||||
"path": path,
|
||||
"method": method,
|
||||
"operation": path_item[method],
|
||||
}
|
||||
)
|
||||
|
||||
# get all parameters
|
||||
bundles = []
|
||||
for interface in interfaces:
|
||||
# convert parameters
|
||||
parameters = []
|
||||
if "parameters" in interface["operation"]:
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=parameter.get("required", False),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=parameter.get("description"),
|
||||
default=parameter["schema"]["default"]
|
||||
if "schema" in parameter and "default" in parameter["schema"]
|
||||
else None,
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
|
||||
if typ:
|
||||
tool_parameter.type = typ
|
||||
|
||||
parameters.append(tool_parameter)
|
||||
# create tool bundle
|
||||
# check if there is a request body
|
||||
if "requestBody" in interface["operation"]:
|
||||
request_body = interface["operation"]["requestBody"]
|
||||
if "content" in request_body:
|
||||
for content_type, content in request_body["content"].items():
|
||||
# if there is a reference, get the reference and overwrite the content
|
||||
if "schema" not in content:
|
||||
continue
|
||||
|
||||
if "$ref" in content["schema"]:
|
||||
# get the reference
|
||||
root = openapi
|
||||
reference = content["schema"]["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
# overwrite the content
|
||||
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
|
||||
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
|
||||
parameters.append(tool)
|
||||
|
||||
# check if parameters is duplicated
|
||||
parameters_count = {}
|
||||
for parameter in parameters:
|
||||
if parameter.name not in parameters_count:
|
||||
parameters_count[parameter.name] = 0
|
||||
parameters_count[parameter.name] += 1
|
||||
for name, count in parameters_count.items():
|
||||
if count > 1:
|
||||
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
|
||||
|
||||
            # check if there is an operation id, use $path_$method as the operation id if not
|
||||
if "operationId" not in interface["operation"]:
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = interface["path"]
|
||||
if interface["path"].startswith("/"):
|
||||
path = interface["path"][1:]
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
|
||||
if not path:
|
||||
path = str(uuid.uuid4())
|
||||
|
||||
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
|
||||
|
||||
bundles.append(
|
||||
ApiToolBundle(
|
||||
server_url=server_url + interface["path"],
|
||||
method=interface["method"],
|
||||
summary=interface["operation"]["description"]
|
||||
if "description" in interface["operation"]
|
||||
else interface["operation"].get("summary", None),
|
||||
operation_id=interface["operation"]["operationId"],
|
||||
parameters=parameters,
|
||||
author="",
|
||||
icon=None,
|
||||
openapi=interface["operation"],
|
||||
)
|
||||
)
|
||||
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
|
||||
parameter = parameter or {}
|
||||
typ: Optional[str] = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
if "type" in parameter:
|
||||
typ = parameter["type"]
|
||||
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||
typ = parameter["schema"]["type"]
|
||||
|
||||
if typ in {"integer", "number"}:
|
||||
return ToolParameter.ToolParameterType.NUMBER
|
||||
elif typ == "boolean":
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
elif typ == "string":
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
elif typ == "array":
|
||||
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
||||
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
|
||||
else:
|
||||
return None
|
||||
|
||||
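    # Illustration (not part of this file): the type mapping above, summarized with a few
    # representative OpenAPI fragments; calling the private helper directly is only a sketch.
    #   _get_tool_parameter_type({"type": "integer"})                          -> ToolParameter.ToolParameterType.NUMBER
    #   _get_tool_parameter_type({"type": "boolean"})                          -> ToolParameter.ToolParameterType.BOOLEAN
    #   _get_tool_parameter_type({"schema": {"type": "string"}})               -> ToolParameter.ToolParameterType.STRING
    #   _get_tool_parameter_type({"format": "binary"})                         -> ToolParameter.ToolParameterType.FILE
    #   _get_tool_parameter_type({"type": "array", "items": {"format": "binary"}}) -> ToolParameter.ToolParameterType.FILES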
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
yaml: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
:param yaml: the yaml string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
openapi: dict = safe_load(yaml)
|
||||
if openapi is None:
|
||||
raise ToolApiSchemaError("Invalid openapi yaml.")
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
|
||||
        """
        parse swagger to openapi

        :param swagger: the swagger dict
        :return: the openapi dict
        """
        warning = warning or {}
|
||||
# convert swagger to openapi
|
||||
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
|
||||
|
||||
servers = swagger.get("servers", [])
|
||||
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
openapi = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
"description": info.get("description", "Swagger"),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
},
|
||||
"servers": swagger["servers"],
|
||||
"paths": {},
|
||||
"components": {"schemas": {}},
|
||||
}
|
||||
|
||||
# check paths
|
||||
if "paths" not in swagger or len(swagger["paths"]) == 0:
|
||||
raise ToolApiSchemaError("No paths found in the swagger yaml.")
|
||||
|
||||
# convert paths
|
||||
for path, path_item in swagger["paths"].items():
|
||||
openapi["paths"][path] = {}
|
||||
for method, operation in path_item.items():
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
|
||||
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||
"description" not in operation or len(operation["description"]) == 0
|
||||
):
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = {
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": operation.get("parameters", []),
|
||||
"responses": operation.get("responses", {}),
|
||||
}
|
||||
|
||||
if "requestBody" in operation:
|
||||
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
|
||||
# convert definitions
|
||||
for name, definition in swagger["definitions"].items():
|
||||
openapi["components"]["schemas"][name] = definition
|
||||
|
||||
return openapi
|
||||
|
||||
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(
    json: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
    """
    parse openai plugin json to tool bundle

    :param json: the json string
    :param extra_info: the extra info
    :param warning: the warning message
    :return: the tool bundle
    """
    warning = warning if warning is not None else {}
    extra_info = extra_info if extra_info is not None else {}

    try:
        openai_plugin = json_loads(json)
        api = openai_plugin["api"]
        api_url = api["url"]
        api_type = api["type"]
    except JSONDecodeError:
        raise ToolProviderNotFoundError("Invalid openai plugin json.")

    if api_type != "openapi":
        raise ToolNotSupportedError("Only openapi is supported now.")

    # get openapi yaml
    response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)

    if response.status_code != 200:
        raise ToolProviderNotFoundError("cannot get openapi yaml from url.")

    return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
        response.text, extra_info=extra_info, warning=warning
    )

@staticmethod
def auto_parse_to_tool_bundle(
    content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
    """
    auto parse to tool bundle

    :param content: the content
    :param extra_info: the extra info
    :param warning: the warning message
    :return: tools bundle, schema_type
    """
    warning = warning if warning is not None else {}
    extra_info = extra_info if extra_info is not None else {}

    content = content.strip()
    loaded_content = None
    json_error = None
    yaml_error = None

    try:
        loaded_content = json_loads(content)
    except JSONDecodeError as e:
        json_error = e

    if loaded_content is None:
        try:
            loaded_content = safe_load(content)
        except YAMLError as e:
            yaml_error = e
        if loaded_content is None:
            raise ToolApiSchemaError(
                f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
                f" yaml error: {str(yaml_error)}"
            )

    swagger_error = None
    openapi_error = None
    openapi_plugin_error = None
    schema_type = None

    try:
        openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
            loaded_content, extra_info=extra_info, warning=warning
        )
        schema_type = ApiProviderSchemaType.OPENAPI.value
        return openapi, schema_type
    except ToolApiSchemaError as e:
        openapi_error = e

    # openapi parse error, fall back to swagger
    try:
        converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
            loaded_content, extra_info=extra_info, warning=warning
        )
        schema_type = ApiProviderSchemaType.SWAGGER.value
        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
            converted_swagger, extra_info=extra_info, warning=warning
        ), schema_type
    except ToolApiSchemaError as e:
        swagger_error = e

    # swagger parse error, fall back to openai plugin
    try:
        openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
            json_dumps(loaded_content), extra_info=extra_info, warning=warning
        )
        return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
    except ToolNotSupportedError as e:
        # maybe it's not a plugin schema at all
        openapi_plugin_error = e

    raise ToolApiSchemaError(
        f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
        f" openapi plugin error: {str(openapi_plugin_error)}"
    )

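# Illustrative usage of the fallback chain above; the inline schema is a made-up minimal
# OpenAPI 3.0 document. auto_parse_to_tool_bundle tries OpenAPI first, then Swagger 2.0
# (converted via parse_swagger_to_openapi), then an OpenAI plugin manifest, and raises only
# after all three parsers have failed.
schema_text = """
openapi: 3.0.0
info:
  title: Echo
  version: 1.0.0
servers:
  - url: https://example.com
paths:
  /echo:
    get:
      operationId: echo
      summary: Echo a message
      parameters:
        - name: q
          in: query
          schema:
            type: string
      responses:
        "200":
          description: ok
"""
bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema_text)
# schema_type == ApiProviderSchemaType.OPENAPI.value; each bundle describes one operation.
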
api/core/datasource/utils/text_processing_utils.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import re


def remove_leading_symbols(text: str) -> str:
    """
    Remove leading punctuation or symbols from the given text.

    Args:
        text (str): The input text to process.

    Returns:
        str: The text with leading punctuation or symbols removed.
    """
    # Match Unicode ranges for punctuation and symbols
    # FIXME this pattern is a confusing quick fix for #11868, maybe refactor it later
    pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
    return re.sub(pattern, "", text)

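# Illustrative checks: the regex strips leading punctuation/symbols only, so later
# punctuation and plain text are left untouched.
assert remove_leading_symbols("...hello, world") == "hello, world"
assert remove_leading_symbols("#$%title") == "title"
assert remove_leading_symbols("plain text") == "plain text"
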
api/core/datasource/utils/uuid_utils.py (new file, 9 lines)
@@ -0,0 +1,9 @@
import uuid


def is_valid_uuid(uuid_str: str) -> bool:
    try:
        uuid.UUID(uuid_str)
        return True
    except Exception:
        return False

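# Illustrative checks: uuid.UUID accepts both the hyphenated and the bare 32-hex-digit
# forms, so both validate here; arbitrary strings do not.
assert is_valid_uuid("123e4567-e89b-12d3-a456-426614174000")
assert is_valid_uuid("123e4567e89b12d3a456426614174000")
assert not is_valid_uuid("not-a-uuid")
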
api/core/datasource/utils/workflow_configuration_sync.py (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
|
||||
|
||||
if not start_node:
|
||||
return []
|
||||
|
||||
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
):
|
||||
"""
|
||||
check is synced
|
||||
|
||||
raise ValueError if not synced
|
||||
"""
|
||||
variable_names = [variable.variable for variable in variables]
|
||||
|
||||
if len(tool_configurations) != len(variables):
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
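# Sketch of the intended call pattern; the graph and configuration literals are minimal,
# made-up examples, and real records carry more fields than shown here.
graph = {
    "nodes": [
        {"data": {"type": "start", "variables": [{"variable": "query", "label": "Query", "type": "text-input"}]}}
    ]
}
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
configurations = [{"name": "query", "description": "user question", "form": "llm"}]
WorkflowToolConfigurationUtils.check_parameter_configurations(configurations)
WorkflowToolConfigurationUtils.check_is_synced(
    list(variables), [WorkflowToolParameterConfiguration.model_validate(c) for c in configurations]
)
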
api/core/datasource/utils/yaml_utils.py (new file, 35 lines)
@@ -0,0 +1,35 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml # type: ignore
|
||||
from yaml import YAMLError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
|
||||
"""
|
||||
Safe loading a YAML file
|
||||
:param file_path: the path of the YAML file
|
||||
:param ignore_error:
|
||||
if True, return default_value if error occurs and the error will be logged in debug level
|
||||
if False, raise error if error occurs
|
||||
:param default_value: the value returned when errors ignored
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
if not file_path or not Path(file_path).exists():
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, encoding="utf-8") as yaml_file:
|
||||
try:
|
||||
yaml_content = yaml.safe_load(yaml_file)
|
||||
return yaml_content or default_value
|
||||
except Exception as e:
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||
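# Illustrative calls; "provider.yaml" is a placeholder path. With ignore_error=True a missing
# or malformed file yields default_value instead of raising.
manifest = load_yaml_file("provider.yaml", ignore_error=True, default_value={})
strict_manifest = load_yaml_file("provider.yaml", ignore_error=False)  # raises on any problem
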
api/core/datasource/website_crawl/website_crawl_plugin.py (new file, 53 lines)
@@ -0,0 +1,53 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
WebsiteCrawlMessage,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
|
||||
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def get_website_crawl(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[WebsiteCrawlMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.get_website_crawl(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
api/core/datasource/website_crawl/website_crawl_provider.py (new file, 52 lines)
@@ -0,0 +1,52 @@
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||
|
||||
|
||||
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceProviderEntityWithPlugin,
|
||||
plugin_id: str,
|
||||
plugin_unique_identifier: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return WebsiteCrawlDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
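# Sketch of how a concrete datasource is resolved from a provider controller; the entity,
# the ids and the "crawl" name are placeholders, and in practice the controller is built by
# the datasource manager rather than by hand.
controller = WebsiteCrawlDatasourcePluginProviderController(
    entity=provider_entity,
    plugin_id="langgenius/some_crawler",
    plugin_unique_identifier="langgenius/some_crawler:0.0.1@dify",
    tenant_id=tenant_id,
)
crawler = controller.get_datasource("crawl")  # raises ValueError if "crawl" is not declared
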
@@ -17,3 +17,27 @@ class IndexingEstimate(BaseModel):
|
||||
total_segments: int
|
||||
preview: list[PreviewDetail]
|
||||
qa_preview: Optional[list[QAPreviewDetail]] = None
|
||||
|
||||
|
||||
class PipelineDataset(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
chunk_structure: str
|
||||
|
||||
|
||||
class PipelineDocument(BaseModel):
|
||||
id: str
|
||||
position: int
|
||||
data_source_type: str
|
||||
data_source_info: Optional[dict] = None
|
||||
name: str
|
||||
indexing_status: str
|
||||
error: Optional[str] = None
|
||||
enabled: bool
|
||||
|
||||
|
||||
class PipelineGenerateResponse(BaseModel):
|
||||
batch: str
|
||||
dataset: PipelineDataset
|
||||
documents: list[PipelineDocument]
|
||||
|
||||
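# Minimal construction sketch for the response models above; all values are made up.
response = PipelineGenerateResponse(
    batch="20240101000000000001",
    dataset=PipelineDataset(id="ds-1", name="demo", description="", chunk_structure="text_model"),
    documents=[
        PipelineDocument(
            id="doc-1",
            position=1,
            data_source_type="upload_file",
            name="a.md",
            indexing_status="waiting",
            enabled=True,
        )
    ],
)
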
api/core/file/datasource_file_parser.py (new file, 15 lines)
@@ -0,0 +1,15 @@
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from core.datasource import datasource_file_manager
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
|
||||
tool_file_manager: dict[str, Any] = {"manager": None}
|
||||
|
||||
|
||||
class DatasourceFileParser:
|
||||
@staticmethod
|
||||
def get_datasource_file_manager() -> "DatasourceFileManager":
|
||||
return cast("DatasourceFileManager", datasource_file_manager["manager"])
|
||||
@@ -20,6 +20,7 @@ class FileTransferMethod(StrEnum):
|
||||
REMOTE_URL = "remote_url"
|
||||
LOCAL_FILE = "local_file"
|
||||
TOOL_FILE = "tool_file"
|
||||
DATASOURCE_FILE = "datasource_file"
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
|
||||
@@ -135,3 +135,4 @@ class TraceTaskName(StrEnum):
|
||||
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
|
||||
TOOL_TRACE = "tool"
|
||||
GENERATE_NAME_TRACE = "generate_conversation_name"
|
||||
DATASOURCE_TRACE = "datasource"
|
||||
|
||||
api/core/plugin/entities/oauth.py (new file, 21 lines)
@@ -0,0 +1,21 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
"""
|
||||
OAuth schema
|
||||
"""
|
||||
|
||||
client_schema: Sequence[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="client schema like client_id, client_secret, etc.",
|
||||
)
|
||||
|
||||
credentials_schema: Sequence[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="credentials schema like access_token, refresh_token, etc.",
|
||||
)
|
||||
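# Illustrative OAuthSchema instance; this assumes ProviderConfig accepts name/type/required
# keyword arguments as it does for other provider credential schemas.
schema = OAuthSchema(
    client_schema=[ProviderConfig(name="client_id", type=ProviderConfig.Type.TEXT_INPUT, required=True)],
    credentials_schema=[ProviderConfig(name="access_token", type=ProviderConfig.Type.SECRET_INPUT, required=True)],
)
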
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, model_validator
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.agent.plugin_entities import AgentStrategyProviderEntity
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
@@ -62,6 +63,7 @@ class PluginCategory(enum.StrEnum):
|
||||
Model = "model"
|
||||
Extension = "extension"
|
||||
AgentStrategy = "agent-strategy"
|
||||
Datasource = "datasource"
|
||||
|
||||
|
||||
class PluginDeclaration(BaseModel):
|
||||
@@ -69,6 +71,7 @@ class PluginDeclaration(BaseModel):
|
||||
tools: Optional[list[str]] = Field(default_factory=list[str])
|
||||
models: Optional[list[str]] = Field(default_factory=list[str])
|
||||
endpoints: Optional[list[str]] = Field(default_factory=list[str])
|
||||
datasources: Optional[list[str]] = Field(default_factory=list[str])
|
||||
|
||||
class Meta(BaseModel):
|
||||
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
|
||||
@@ -90,6 +93,7 @@ class PluginDeclaration(BaseModel):
|
||||
model: Optional[ProviderEntity] = None
|
||||
endpoint: Optional[EndpointProviderDeclaration] = None
|
||||
agent_strategy: Optional[AgentStrategyProviderEntity] = None
|
||||
datasource: Optional[DatasourceProviderEntity] = None
|
||||
meta: Meta
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -100,6 +104,8 @@ class PluginDeclaration(BaseModel):
|
||||
values["category"] = PluginCategory.Tool
|
||||
elif values.get("model"):
|
||||
values["category"] = PluginCategory.Model
|
||||
elif values.get("datasource"):
|
||||
values["category"] = PluginCategory.Datasource
|
||||
elif values.get("agent_strategy"):
|
||||
values["category"] = PluginCategory.AgentStrategy
|
||||
else:
|
||||
@@ -193,6 +199,11 @@ class ToolProviderID(GenericProviderID):
|
||||
self.plugin_name = f"{self.provider_name}_tool"
|
||||
|
||||
|
||||
class DatasourceProviderID(GenericProviderID):
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
super().__init__(value, is_hardcoded)
|
||||
|
||||
|
||||
class PluginDependency(BaseModel):
|
||||
class Type(enum.StrEnum):
|
||||
Github = PluginInstallationSource.Github.value
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Generic, Optional, TypeVar
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.agent.plugin_entities import AgentProviderEntityWithPlugin
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
@@ -48,6 +49,14 @@ class PluginToolProviderEntity(BaseModel):
|
||||
declaration: ToolProviderEntityWithPlugin
|
||||
|
||||
|
||||
class PluginDatasourceProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
is_authorized: bool = False
|
||||
declaration: DatasourceProviderEntityWithPlugin
|
||||
|
||||
|
||||
class PluginAgentProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
api/core/plugin/impl/datasource.py (new file, 329 lines)
@@ -0,0 +1,329 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDocumentPagesMessage,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
OnlineDriveBrowseFilesResponse,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
WebsiteCrawlMessage,
|
||||
)
|
||||
from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
PluginDatasourceProviderEntity,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class PluginDatasourceManager(BasePluginClient):
|
||||
def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
|
||||
"""
|
||||
Fetch datasource providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
if json_response.get("data"):
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for datasource in declaration.get("datasources", []):
|
||||
datasource["identity"]["provider"] = provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/datasources",
|
||||
list[PluginDatasourceProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||
|
||||
for provider in response:
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
|
||||
all_response = [local_file_datasource_provider] + response
|
||||
|
||||
for provider in all_response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in provider.declaration.datasources:
|
||||
tool.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return all_response
|
||||
|
||||
def fetch_datasource_provider(self, tenant_id: str, provider_id: str) -> PluginDatasourceProviderEntity:
|
||||
"""
|
||||
Fetch datasource provider for the given tenant and plugin.
|
||||
"""
|
||||
if provider_id == "langgenius/file/file":
|
||||
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||
|
||||
tool_provider_id = DatasourceProviderID(provider_id)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for datasource in data.get("declaration", {}).get("datasources", []):
|
||||
datasource["identity"]["provider"] = tool_provider_id.provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/datasource",
|
||||
PluginDatasourceProviderEntity,
|
||||
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for datasource in response.declaration.datasources:
|
||||
datasource.identity.provider = response.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def get_website_crawl(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[WebsiteCrawlMessage, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
return self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl",
|
||||
WebsiteCrawlMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"datasource_parameters": datasource_parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_online_document_pages(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDocumentPagesMessage, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
return self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
|
||||
OnlineDocumentPagesMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"datasource_parameters": datasource_parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_online_document_page_content(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
datasource_parameters: GetOnlineDocumentPageContentRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
return self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content",
|
||||
DatasourceMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"page": datasource_parameters.model_dump(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def online_drive_browse_files(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
request: OnlineDriveBrowseFilesRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/online_drive_browse_files",
|
||||
OnlineDriveBrowseFilesResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"request": request.model_dump(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
yield from response
|
||||
|
||||
def online_drive_download_file(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
request: OnlineDriveDownloadFileRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/online_drive_download_file",
|
||||
DatasourceMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"request": request.model_dump(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
yield from response
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
# datasource_provider_id = GenericProviderID(provider_id)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def _get_local_file_datasource_provider(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": "langgenius/file/file",
|
||||
"plugin_id": "langgenius/file",
|
||||
"provider": "file",
|
||||
"plugin_unique_identifier": "langgenius/file:0.0.1@dify",
|
||||
"declaration": {
|
||||
"identity": {
|
||||
"author": "langgenius",
|
||||
"name": "file",
|
||||
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
"icon": "https://assets.dify.ai/images/File%20Upload.svg",
|
||||
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
},
|
||||
"credentials_schema": [],
|
||||
"provider_type": "local_file",
|
||||
"datasources": [
|
||||
{
|
||||
"identity": {
|
||||
"author": "langgenius",
|
||||
"name": "upload-file",
|
||||
"provider": "file",
|
||||
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
},
|
||||
"parameters": [],
|
||||
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
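# Illustrative call into the manager; the tenant id is a placeholder. Requesting the special
# "langgenius/file/file" id returns the hard-coded local-file provider above without calling
# the plugin daemon.
manager = PluginDatasourceManager()
local_provider = manager.fetch_datasource_provider(tenant_id="tenant-1", provider_id="langgenius/file/file")
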
|
||||
@@ -4,7 +4,10 @@ from typing import Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
PluginToolProviderEntity,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
|
||||
@@ -197,6 +200,36 @@ class PluginToolManager(BasePluginClient):
|
||||
|
||||
return False
|
||||
|
||||
def validate_datasource_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the datasource
|
||||
"""
|
||||
tool_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": tool_provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": tool_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
tenant_id: str,
|
||||
|
||||
@@ -28,10 +28,12 @@ class Jieba(BaseKeyword):
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_number = (
|
||||
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(
|
||||
text.page_content, self._config.max_keywords_per_chunk
|
||||
)
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
|
||||
if text.metadata is not None:
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(
|
||||
@@ -49,18 +51,17 @@ class Jieba(BaseKeyword):
|
||||
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keywords_list = kwargs.get("keywords_list")
|
||||
keyword_number = (
|
||||
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
|
||||
)
|
||||
for i in range(len(texts)):
|
||||
text = texts[i]
|
||||
if keywords_list:
|
||||
keywords = keywords_list[i]
|
||||
if not keywords:
|
||||
keywords = keyword_table_handler.extract_keywords(
|
||||
text.page_content, self._config.max_keywords_per_chunk
|
||||
)
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
|
||||
else:
|
||||
keywords = keyword_table_handler.extract_keywords(
|
||||
text.page_content, self._config.max_keywords_per_chunk
|
||||
)
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
|
||||
if text.metadata is not None:
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(
|
||||
@@ -239,7 +240,11 @@ class Jieba(BaseKeyword):
|
||||
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
|
||||
)
|
||||
else:
|
||||
keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
|
||||
keyword_number = (
|
||||
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
|
||||
)
|
||||
|
||||
keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
|
||||
segment.keywords = list(keywords)
|
||||
keyword_table = self._add_text_to_keyword_table(
|
||||
keyword_table or {}, segment.index_node_id, list(keywords)
|
||||
|
||||
api/core/rag/entities/event.py (new file, 38 lines)
@@ -0,0 +1,38 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DatasourceStreamEvent(Enum):
|
||||
"""
|
||||
Datasource Stream event
|
||||
"""
|
||||
|
||||
PROCESSING = "datasource_processing"
|
||||
COMPLETED = "datasource_completed"
|
||||
ERROR = "datasource_error"
|
||||
|
||||
|
||||
class BaseDatasourceEvent(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceErrorEvent(BaseDatasourceEvent):
|
||||
event: str = DatasourceStreamEvent.ERROR.value
|
||||
error: str = Field(..., description="error message")
|
||||
|
||||
|
||||
class DatasourceCompletedEvent(BaseDatasourceEvent):
|
||||
event: str = DatasourceStreamEvent.COMPLETED.value
|
||||
data: Mapping[str, Any] | list = Field(..., description="result")
|
||||
total: Optional[int] = Field(default=0, description="total")
|
||||
completed: Optional[int] = Field(default=0, description="completed")
|
||||
time_consuming: Optional[float] = Field(default=0.0, description="time consuming")
|
||||
|
||||
|
||||
class DatasourceProcessingEvent(BaseDatasourceEvent):
|
||||
event: str = DatasourceStreamEvent.PROCESSING.value
|
||||
total: Optional[int] = Field(..., description="total")
|
||||
completed: Optional[int] = Field(..., description="completed")
|
||||
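# Illustrative producer/consumer sketch for the events above; the payloads are made up.
def crawl_events():
    yield DatasourceProcessingEvent(total=10, completed=3)
    yield DatasourceCompletedEvent(data=[{"source_url": "https://example.com", "content": "..."}], total=10, completed=10)


for event in crawl_events():
    if isinstance(event, DatasourceErrorEvent):
        raise RuntimeError(event.error)
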
@@ -13,3 +13,5 @@ class MetadataDataSource(Enum):
|
||||
upload_file = "file_upload"
|
||||
website_crawl = "website"
|
||||
notion_import = "notion"
|
||||
local_file = "file_upload"
|
||||
online_document = "online_document"
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_manager import ModelInstance
|
||||
@@ -13,6 +14,7 @@ from core.rag.splitter.fixed_text_splitter import (
|
||||
)
|
||||
from core.rag.splitter.text_splitter import TextSplitter
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
|
||||
class BaseIndexProcessor(ABC):
|
||||
@@ -33,6 +35,14 @@ class BaseIndexProcessor(ABC):
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def retrieve(
|
||||
self,
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
"""Paragraph index processor."""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.models.document import Document, GeneralStructureChunk
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
|
||||
|
||||
@@ -127,3 +130,34 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
|
||||
paragraph = GeneralStructureChunk(**chunks)
|
||||
documents = []
|
||||
for content in paragraph.general_chunks:
|
||||
metadata = {
|
||||
"dataset_id": dataset.id,
|
||||
"document_id": document.id,
|
||||
"doc_id": str(uuid.uuid4()),
|
||||
"doc_hash": helper.generate_text_hash(content),
|
||||
}
|
||||
doc = Document(page_content=content, metadata=metadata)
|
||||
documents.append(doc)
|
||||
if documents:
|
||||
# save node to document segment
|
||||
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
|
||||
# add document segments
|
||||
doc_store.add_documents(docs=documents, save_child=False)
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
elif dataset.indexing_technique == "economy":
|
||||
keyword = Keyword(dataset)
|
||||
keyword.add_texts(documents)
|
||||
|
||||
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
paragraph = GeneralStructureChunk(**chunks)
|
||||
preview = []
|
||||
for content in paragraph.general_chunks:
|
||||
preview.append({"content": content})
|
||||
return {"preview": preview, "total_segments": len(paragraph.general_chunks)}
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
"""Paragraph index processor."""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.dataset import ChildChunk, Dataset, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
||||
|
||||
|
||||
@@ -202,3 +205,40 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
child_document.page_content = child_page_content
|
||||
child_nodes.append(child_document)
|
||||
return child_nodes
|
||||
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
|
||||
parent_childs = ParentChildStructureChunk(**chunks)
|
||||
documents = []
|
||||
for parent_child in parent_childs.parent_child_chunks:
|
||||
metadata = {
|
||||
"dataset_id": dataset.id,
|
||||
"document_id": document.id,
|
||||
"doc_id": str(uuid.uuid4()),
|
||||
"doc_hash": helper.generate_text_hash(parent_child.parent_content),
|
||||
}
|
||||
child_documents = []
|
||||
for child in parent_child.child_contents:
|
||||
child_metadata = {
|
||||
"dataset_id": dataset.id,
|
||||
"document_id": document.id,
|
||||
"doc_id": str(uuid.uuid4()),
|
||||
"doc_hash": helper.generate_text_hash(child),
|
||||
}
|
||||
child_documents.append(ChildDocument(page_content=child, metadata=child_metadata))
|
||||
doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents)
|
||||
documents.append(doc)
|
||||
if documents:
|
||||
# save node to document segment
|
||||
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
|
||||
# add document segments
|
||||
doc_store.add_documents(docs=documents, save_child=True)
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
|
||||
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
parent_childs = ParentChildStructureChunk(**chunks)
|
||||
preview = []
|
||||
for parent_child in parent_childs.parent_child_chunks:
|
||||
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
|
||||
return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)}
|
||||
|
||||
@@ -4,7 +4,8 @@ import logging
|
||||
import re
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
import pandas as pd
|
||||
from flask import Flask, current_app
|
||||
@@ -14,13 +15,15 @@ from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.models.document import Document, QAStructureChunk
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
|
||||
|
||||
@@ -161,6 +164,36 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
|
||||
qa_chunks = QAStructureChunk(**chunks)
|
||||
documents = []
|
||||
for qa_chunk in qa_chunks.qa_chunks:
|
||||
metadata = {
|
||||
"dataset_id": dataset.id,
|
||||
"document_id": document.id,
|
||||
"doc_id": str(uuid.uuid4()),
|
||||
"doc_hash": helper.generate_text_hash(qa_chunk.question),
|
||||
"answer": qa_chunk.answer,
|
||||
}
|
||||
doc = Document(page_content=qa_chunk.question, metadata=metadata)
|
||||
documents.append(doc)
|
||||
if documents:
|
||||
# save node to document segment
|
||||
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
|
||||
doc_store.add_documents(docs=documents, save_child=False)
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
else:
|
||||
raise ValueError("Indexing technique must be high quality.")
|
||||
|
||||
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
qa_chunks = QAStructureChunk(**chunks)
|
||||
preview = []
|
||||
for qa_chunk in qa_chunks.qa_chunks:
|
||||
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
|
||||
return {"qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks)}
|
||||
|
||||
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
|
||||
format_documents = []
|
||||
if document_node.page_content is None or not document_node.page_content.strip():
|
||||
|
||||
@@ -35,6 +35,48 @@ class Document(BaseModel):
|
||||
children: Optional[list[ChildDocument]] = None
|
||||
|
||||
|
||||
class GeneralStructureChunk(BaseModel):
|
||||
"""
|
||||
General Structure Chunk.
|
||||
"""
|
||||
|
||||
general_chunks: list[str]
|
||||
|
||||
|
||||
class ParentChildChunk(BaseModel):
|
||||
"""
|
||||
Parent Child Chunk.
|
||||
"""
|
||||
|
||||
parent_content: str
|
||||
child_contents: list[str]
|
||||
|
||||
|
||||
class ParentChildStructureChunk(BaseModel):
|
||||
"""
|
||||
Parent Child Structure Chunk.
|
||||
"""
|
||||
|
||||
parent_child_chunks: list[ParentChildChunk]
|
||||
|
||||
|
||||
class QAChunk(BaseModel):
|
||||
"""
|
||||
QA Chunk.
|
||||
"""
|
||||
|
||||
question: str
|
||||
answer: str
|
||||
|
||||
|
||||
class QAStructureChunk(BaseModel):
|
||||
"""
|
||||
QAStructureChunk.
|
||||
"""
|
||||
|
||||
qa_chunks: list[QAChunk]
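# Illustrative chunk payloads matching the models above; the contents are made up.
general = GeneralStructureChunk(general_chunks=["first paragraph", "second paragraph"])
parent_child = ParentChildStructureChunk(
    parent_child_chunks=[ParentChildChunk(parent_content="full section", child_contents=["sentence 1", "sentence 2"])]
)
qa = QAStructureChunk(qa_chunks=[QAChunk(question="What is Dify?", answer="An LLM app development platform.")])
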
|
||||
|
||||
|
||||
class BaseDocumentTransformer(ABC):
|
||||
"""Abstract base class for document transformation systems.
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ class RetrievalMethod(Enum):
|
||||
SEMANTIC_SEARCH = "semantic_search"
|
||||
FULL_TEXT_SEARCH = "full_text_search"
|
||||
HYBRID_SEARCH = "hybrid_search"
|
||||
KEYWORD_SEARCH = "keyword_search"
|
||||
|
||||
@staticmethod
|
||||
def is_support_semantic_search(retrieval_method: str) -> bool:
|
||||
|
||||
@@ -262,6 +262,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution database models for a specific workflow run.
|
||||
@@ -283,7 +284,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
WorkflowNodeExecutionModel.triggered_from == triggered_from,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
@@ -317,6 +318,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all NodeExecution instances for a specific workflow run.
|
||||
@@ -334,7 +336,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
A list of NodeExecution instances
|
||||
"""
|
||||
# Get the database models using the new method
|
||||
db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config)
|
||||
db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from)
|
||||
|
||||
# Convert database models to domain models
|
||||
domain_models = []
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.helper import encrypter
|
||||
|
||||
@@ -93,3 +93,32 @@ class FileVariable(FileSegment, Variable):
|
||||
|
||||
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
||||
class RAGPipelineVariable(BaseModel):
|
||||
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
|
||||
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
|
||||
label: str = Field(description="label")
|
||||
description: str | None = Field(description="description", default="")
|
||||
variable: str = Field(description="variable key", default="")
|
||||
max_length: int | None = Field(
|
||||
description="max length, applicable to text-input, paragraph, and file-list", default=0
|
||||
)
|
||||
default_value: Any = Field(description="default value", default="")
|
||||
placeholder: str | None = Field(description="placeholder", default="")
|
||||
unit: str | None = Field(description="unit, applicable to Number", default="")
|
||||
tooltips: str | None = Field(description="helpful text", default="")
|
||||
allowed_file_types: list[str] | None = Field(
|
||||
description="image, document, audio, video, custom.", default_factory=list
|
||||
)
|
||||
allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
|
||||
allowed_file_upload_methods: list[str] | None = Field(
|
||||
description="remote_url, local_file, tool_file.", default_factory=list
|
||||
)
|
||||
required: bool = Field(description="optional, default false", default=False)
|
||||
options: list[str] | None = Field(default_factory=list)
|
||||
|
||||
|
||||
class RAGPipelineVariableInput(BaseModel):
|
||||
variable: RAGPipelineVariable
|
||||
value: Any
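# Sketch of how a pipeline variable reaches the variable pool; the node id, key and value are
# placeholders. VariablePool.model_post_init later exposes it under ("rag", "shared", "city").
rag_var = RAGPipelineVariable(belong_to_node_id="shared", type="text-input", label="City", variable="city")
pool_input = RAGPipelineVariableInput(variable=rag_var, value="Berlin")
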
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
SYSTEM_VARIABLE_NODE_ID = "sys"
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"
|
||||
|
||||
@@ -9,7 +9,13 @@ from core.file import File, FileAttribute, file_manager
|
||||
from core.variables import Segment, SegmentGroup, Variable
|
||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||
from core.variables.segments import FileSegment, NoneSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.variables.variables import RAGPipelineVariableInput
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
RAG_PIPELINE_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
)
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from factories import variable_factory
|
||||
|
||||
@@ -44,6 +50,10 @@ class VariablePool(BaseModel):
|
||||
description="Conversation variables.",
|
||||
default_factory=list,
|
||||
)
|
||||
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
|
||||
description="RAG pipeline variables.",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
def model_post_init(self, context: Any, /) -> None:
|
||||
for key, value in self.system_variables.items():
|
||||
@@ -54,6 +64,9 @@ class VariablePool(BaseModel):
|
||||
# Add conversation variables to the variable pool
|
||||
for var in self.conversation_variables:
|
||||
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
|
||||
# Add rag pipeline variables to the variable pool
|
||||
for var in self.rag_pipeline_variables:
|
||||
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var.variable.belong_to_node_id, var.variable.variable), var.value)
|
||||
|
||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||
"""
|
||||
|
||||
@@ -20,6 +20,7 @@ class WorkflowType(StrEnum):
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
CHAT = "chat"
|
||||
RAG_PIPELINE = "rag-pipeline"
|
||||
|
||||
|
||||
class WorkflowExecutionStatus(StrEnum):
|
||||
|
||||
@@ -28,6 +28,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
|
||||
@@ -14,3 +14,10 @@ class SystemVariableKey(StrEnum):
|
||||
APP_ID = "app_id"
|
||||
WORKFLOW_ID = "workflow_id"
|
||||
WORKFLOW_EXECUTION_ID = "workflow_run_id"
|
||||
# RAG Pipeline
|
||||
DOCUMENT_ID = "document_id"
|
||||
BATCH = "batch"
|
||||
DATASET_ID = "dataset_id"
|
||||
DATASOURCE_TYPE = "datasource_type"
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
INVOKE_FROM = "invoke_from"
|
||||
|
||||
@@ -121,6 +121,7 @@ class Graph(BaseModel):
|
||||
# fetch nodes that have no predecessor node
|
||||
root_node_configs = []
|
||||
all_node_id_config_mapping: dict[str, dict] = {}
|
||||
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id:
|
||||
@@ -141,6 +142,7 @@ class Graph(BaseModel):
|
||||
node_config.get("id")
|
||||
for node_config in root_node_configs
|
||||
if node_config.get("data", {}).get("type", "") == NodeType.START.value
|
||||
or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -175,7 +175,7 @@ class GraphEngine:
|
||||
)
|
||||
return
|
||||
elif isinstance(item, NodeRunSucceededEvent):
|
||||
if item.node_type == NodeType.END:
|
||||
if item.node_type in (NodeType.END, NodeType.KNOWLEDGE_INDEX):
|
||||
self.graph_runtime_state.outputs = (
|
||||
dict(item.route_node_state.node_run_result.outputs)
|
||||
if item.route_node_state.node_run_result
|
||||
@@ -320,10 +320,10 @@ class GraphEngine:
|
||||
raise e
|
||||
|
||||
# It may not be necessary, but it is necessary. :)
|
||||
if (
|
||||
self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
|
||||
== NodeType.END.value
|
||||
):
|
||||
if self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() in [
|
||||
NodeType.END.value,
|
||||
NodeType.KNOWLEDGE_INDEX.value,
|
||||
]:
|
||||
break
|
||||
|
||||
previous_route_node_state = route_node_state
|
||||
|
||||
api/core/workflow/nodes/datasource/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
|
||||
from .datasource_node import DatasourceNode
|
||||
|
||||
__all__ = ["DatasourceNode"]
|
||||
api/core/workflow/nodes/datasource/datasource_node.py (new file, 468 lines)
@@ -0,0 +1,468 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
DatasourceParameter,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||
from core.file import File
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.tool.exc import ToolFileError
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from .entities import DatasourceNodeData
|
||||
from .exc import DatasourceNodeError, DatasourceParameterError
|
||||
|
||||
|
||||
class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
"""
|
||||
Datasource Node
|
||||
"""
|
||||
|
||||
_node_data_cls = DatasourceNodeData
|
||||
_node_type = NodeType.DATASOURCE
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the datasource node
|
||||
"""
|
||||
|
||||
node_data = cast(DatasourceNodeData, self.node_data)
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
|
||||
if not datasource_type:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
datasource_type = datasource_type.value
|
||||
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
|
||||
if not datasource_info:
|
||||
raise DatasourceNodeError("Datasource info is not set")
|
||||
datasource_info = datasource_info.value
|
||||
# get datasource runtime
|
||||
try:
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
)
|
||||
except DatasourceNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to get datasource runtime: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
# get parameters
|
||||
datasource_parameters = datasource_runtime.entity.parameters
|
||||
parameters = self._generate_parameters(
|
||||
datasource_parameters=datasource_parameters,
|
||||
variable_pool=variable_pool,
|
||||
node_data=self.node_data,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
datasource_parameters=datasource_parameters,
|
||||
variable_pool=variable_pool,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
)
|
||||
|
||||
try:
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_real_datasource_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
|
||||
online_document_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_page_content(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(
|
||||
workspace_id=datasource_info.get("workspace_id"),
|
||||
page_id=datasource_info.get("page").get("page_id"),
|
||||
type=datasource_info.get("page").get("type"),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
)
|
||||
yield from self._transform_message(
|
||||
messages=online_document_result,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
)
|
||||
case DatasourceProviderType.ONLINE_DRIVE:
|
||||
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_real_datasource_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
|
||||
online_drive_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.online_drive_download_file(
|
||||
user_id=self.user_id,
|
||||
request=OnlineDriveDownloadFileRequest(
|
||||
key=datasource_info.get("key"),
|
||||
bucket=datasource_info.get("bucket"),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
)
|
||||
yield from self._transform_message(
|
||||
messages=online_drive_result,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
)
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
**datasource_info,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
related_id = datasource_info.get("related_id")
|
||||
if not related_id:
|
||||
raise DatasourceNodeError("File is not exist")
|
||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
|
||||
if not upload_file:
|
||||
raise ValueError("Invalid upload file Info")
|
||||
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
)
|
||||
variable_pool.add([self.node_id, "file"], [file_info])
|
||||
for key, value in datasource_info.items():
|
||||
# construct new key list
|
||||
new_key_list = ["file", key]
|
||||
self._append_variables_recursively(
|
||||
variable_pool=variable_pool,
|
||||
node_id=self.node_id,
|
||||
variable_key_list=new_key_list,
|
||||
variable_value=value,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file_info": datasource_info,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to transform datasource message: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
except DatasourceNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to invoke datasource: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
datasource_parameters: Sequence[DatasourceParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: DatasourceNodeData,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given datasource parameters, variable pool, and node data.

Args:
datasource_parameters (Sequence[DatasourceParameter]): The list of datasource parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (DatasourceNodeData): The data associated with the datasource node.

Returns:
dict[str, Any]: A dictionary containing the generated parameters.
"""
|
||||
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
if node_data.datasource_parameters:
|
||||
for parameter_name in node_data.datasource_parameters:
|
||||
parameter = datasource_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
datasource_input = node_data.datasource_parameters[parameter_name]
|
||||
if datasource_input.type == "variable":
|
||||
variable = variable_pool.get(datasource_input.value)
|
||||
if variable is None:
|
||||
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif datasource_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(datasource_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _append_variables_recursively(
|
||||
self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
|
||||
):
|
||||
"""
|
||||
Append variables recursively
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list
|
||||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
variable_pool.add([node_id] + variable_key_list, variable_value)
|
||||
|
||||
# if variable_value is a dict, then recursively append variables
|
||||
if isinstance(variable_value, dict):
|
||||
for key, value in variable_value.items():
|
||||
# construct new key list
|
||||
new_key_list = variable_key_list + [key]
|
||||
self._append_variables_recursively(
|
||||
variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: DatasourceNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
result = {}
|
||||
if node_data.datasource_parameters:
|
||||
for parameter_name in node_data.datasource_parameters:
|
||||
input = node_data.datasource_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceMessage.MessageType.BINARY_LINK,
|
||||
DatasourceMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.DATASOURCE_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.DATASOURCE_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"datasource_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"datasource_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.DATASOURCE_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
if self.node_type == NodeType.AGENT:
|
||||
msg_metadata = message.message.json_object.pop("execution_metadata", {})
|
||||
agent_execution_metadata = {
|
||||
key: value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
|
||||
elif message.type == DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"json": json, "files": files, **variables, "text": text},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
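The LOCAL_FILE branch above relies on _append_variables_recursively to make every nested key of datasource_info addressable by selector. A standalone sketch of that flattening, with a plain dict in place of the variable pool:

from typing import Any

def append_recursively(pool: dict[tuple[str, ...], Any], node_id: str, keys: list[str], value: Any) -> None:
    # register the value at (node_id, *keys), then recurse into dict values
    pool[(node_id, *keys)] = value
    if isinstance(value, dict):
        for k, v in value.items():
            append_recursively(pool, node_id, keys + [k], v)

pool: dict[tuple[str, ...], Any] = {}
append_recursively(pool, "node_1", ["file"], {"related_id": "abc", "page": {"page_id": "p1"}})
for selector in sorted(pool):
    print(selector)
# ('node_1', 'file')
# ('node_1', 'file', 'page')
# ('node_1', 'file', 'page', 'page_id')
# ('node_1', 'file', 'related_id')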
api/core/workflow/nodes/datasource/entities.py (new file, 41 lines)
@@ -0,0 +1,41 @@
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
plugin_id: str
|
||||
provider_name: str # redundancy
|
||||
provider_type: str
|
||||
datasource_name: Optional[str] = "local_file"
|
||||
datasource_configurations: dict[str, Any] | None = None
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
|
||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
class DatasourceInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Optional[Literal["mixed", "variable", "constant"]] = None
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError("value must be a string, int, float, or bool")
|
||||
return typ
|
||||
|
||||
datasource_parameters: dict[str, DatasourceInput] | None = None
|
||||
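The check_type validator above enforces a shape per input type. A standalone restatement of those rules as a plain function (no pydantic), useful for seeing what each type accepts:

from typing import Any

def validate_input(typ: str, value: Any) -> None:
    if typ == "mixed" and not isinstance(value, str):
        raise ValueError("value must be a string")
    if typ == "variable" and not (isinstance(value, list) and all(isinstance(v, str) for v in value)):
        raise ValueError("value must be a list of strings")
    if typ == "constant" and not isinstance(value, (str, int, float, bool)):
        raise ValueError("value must be a string, int, float, or bool")

validate_input("mixed", "{{#start.query#}}")    # template string: ok
validate_input("variable", ["start", "query"])  # selector list: ok
validate_input("constant", 3)                   # scalar: ok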
api/core/workflow/nodes/datasource/exc.py (new file, 16 lines)
@@ -0,0 +1,16 @@
class DatasourceNodeError(ValueError):
    """Base exception for datasource node errors."""

    pass


class DatasourceParameterError(DatasourceNodeError):
    """Exception raised for errors in datasource parameters."""

    pass


class DatasourceFileError(DatasourceNodeError):
    """Exception raised for errors related to datasource files."""

    pass
@@ -7,12 +7,14 @@ class NodeType(StrEnum):
    ANSWER = "answer"
    LLM = "llm"
    KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
    KNOWLEDGE_INDEX = "knowledge-index"
    IF_ELSE = "if-else"
    CODE = "code"
    TEMPLATE_TRANSFORM = "template-transform"
    QUESTION_CLASSIFIER = "question-classifier"
    HTTP_REQUEST = "http-request"
    TOOL = "tool"
    DATASOURCE = "datasource"
    VARIABLE_AGGREGATOR = "variable-aggregator"
    LEGACY_VARIABLE_AGGREGATOR = "variable-assigner"  # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
    LOOP = "loop"
api/core/workflow/nodes/knowledge_index/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .knowledge_index_node import KnowledgeIndexNode

__all__ = ["KnowledgeIndexNode"]
api/core/workflow/nodes/knowledge_index/entities.py (new file, 159 lines)
@@ -0,0 +1,159 @@
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
"""
|
||||
Reranking Model Config.
|
||||
"""
|
||||
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
"""
|
||||
Vector Setting.
|
||||
"""
|
||||
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
"""
|
||||
Keyword Setting.
|
||||
"""
|
||||
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
"""
|
||||
|
||||
vector_setting: VectorSetting
|
||||
keyword_setting: KeywordSetting
|
||||
|
||||
|
||||
class EmbeddingSetting(BaseModel):
|
||||
"""
|
||||
Embedding Setting.
|
||||
"""
|
||||
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class EconomySetting(BaseModel):
|
||||
"""
|
||||
Economy Setting.
|
||||
"""
|
||||
|
||||
keyword_number: int
|
||||
|
||||
|
||||
class RetrievalSetting(BaseModel):
|
||||
"""
|
||||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
|
||||
top_k: int
|
||||
score_threshold: Optional[float] = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
reranking_mode: str = "reranking_model"
|
||||
reranking_enable: bool = True
|
||||
reranking_model: Optional[RerankingModelConfig] = None
|
||||
weights: Optional[WeightedScoreConfig] = None
|
||||
|
||||
|
||||
class IndexMethod(BaseModel):
|
||||
"""
|
||||
Knowledge Index Setting.
|
||||
"""
|
||||
|
||||
indexing_technique: Literal["high_quality", "economy"]
|
||||
embedding_setting: EmbeddingSetting
|
||||
economy_setting: EconomySetting
|
||||
|
||||
|
||||
class FileInfo(BaseModel):
|
||||
"""
|
||||
File Info.
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
|
||||
|
||||
class OnlineDocumentIcon(BaseModel):
|
||||
"""
|
||||
Document Icon.
|
||||
"""
|
||||
|
||||
icon_url: str
|
||||
icon_type: str
|
||||
icon_emoji: str
|
||||
|
||||
|
||||
class OnlineDocumentInfo(BaseModel):
|
||||
"""
|
||||
Online document info.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
workspace_id: str
|
||||
page_id: str
|
||||
page_type: str
|
||||
icon: OnlineDocumentIcon
|
||||
|
||||
|
||||
class WebsiteInfo(BaseModel):
|
||||
"""
|
||||
website import info.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
url: str
|
||||
|
||||
|
||||
class GeneralStructureChunk(BaseModel):
|
||||
"""
|
||||
General Structure Chunk.
|
||||
"""
|
||||
|
||||
general_chunks: list[str]
|
||||
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
|
||||
|
||||
|
||||
class ParentChildChunk(BaseModel):
|
||||
"""
|
||||
Parent Child Chunk.
|
||||
"""
|
||||
|
||||
parent_content: str
|
||||
child_contents: list[str]
|
||||
|
||||
|
||||
class ParentChildStructureChunk(BaseModel):
|
||||
"""
|
||||
Parent Child Structure Chunk.
|
||||
"""
|
||||
|
||||
parent_child_chunks: list[ParentChildChunk]
|
||||
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
|
||||
|
||||
|
||||
class KnowledgeIndexNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge index Node Data.
|
||||
"""
|
||||
|
||||
type: str = "knowledge-index"
|
||||
chunk_structure: str
|
||||
index_chunk_variable_selector: list[str]
|
||||
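These entities are plain pydantic models, so a retrieval configuration can be built and dumped directly. A minimal sketch with a trimmed local copy of RetrievalSetting (same field names and defaults as the diff, fewer fields):

from typing import Literal, Optional
from pydantic import BaseModel

class RetrievalSetting(BaseModel):  # trimmed local copy for illustration
    search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
    top_k: int
    score_threshold: Optional[float] = 0.5
    score_threshold_enabled: bool = False
    reranking_mode: str = "reranking_model"
    reranking_enable: bool = True

setting = RetrievalSetting(search_method="semantic_search", top_k=2)
print(setting.model_dump())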
api/core/workflow/nodes/knowledge_index/exc.py (new file, 22 lines)
@@ -0,0 +1,22 @@
class KnowledgeIndexNodeError(ValueError):
    """Base class for KnowledgeIndexNode errors."""


class ModelNotExistError(KnowledgeIndexNodeError):
    """Raised when the model does not exist."""


class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
    """Raised when the model credentials are not initialized."""


class ModelNotSupportedError(KnowledgeIndexNodeError):
    """Raised when the model is not supported."""


class ModelQuotaExceededError(KnowledgeIndexNodeError):
    """Raised when the model provider quota is exceeded."""


class InvalidModelTypeError(KnowledgeIndexNodeError):
    """Raised when the model is not a Large Language Model."""
api/core/workflow/nodes/knowledge_index/knowledge_index_node.py (new file, 165 lines)
@@ -0,0 +1,165 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import func
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
from ..base import BaseNode
|
||||
from .entities import KnowledgeIndexNodeData
|
||||
from .exc import (
|
||||
KnowledgeIndexNodeError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
|
||||
_node_data_cls = KnowledgeIndexNodeData # type: ignore
|
||||
_node_type = NodeType.KNOWLEDGE_INDEX
|
||||
|
||||
def _run(self) -> NodeRunResult: # type: ignore
|
||||
node_data = cast(KnowledgeIndexNodeData, self.node_data)
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
|
||||
if not dataset_id:
|
||||
raise KnowledgeIndexNodeError("Dataset ID is required.")
|
||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
|
||||
if not dataset:
|
||||
raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")
|
||||
|
||||
# extract variables
|
||||
variable = variable_pool.get(node_data.index_chunk_variable_selector)
|
||||
if not variable:
|
||||
raise KnowledgeIndexNodeError("Index chunk variable is required.")
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
if invoke_from:
|
||||
is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value
|
||||
else:
|
||||
is_preview = False
|
||||
chunks = variable.value
|
||||
variables = {"chunks": chunks}
|
||||
if not chunks:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks are required."
|
||||
)
|
||||
|
||||
# index knowledge
|
||||
try:
|
||||
if is_preview:
|
||||
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
process_data=None,
|
||||
outputs=outputs,
|
||||
)
|
||||
results = self._invoke_knowledge_index(
|
||||
dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
|
||||
)
|
||||
|
||||
except KnowledgeIndexNodeError as e:
|
||||
logger.warning("Error when running knowledge index node")
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
# Temporary handle all exceptions from DatasetRetrieval class here.
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
||||
def _invoke_knowledge_index(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
node_data: KnowledgeIndexNodeData,
|
||||
chunks: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
) -> Any:
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if not document_id:
|
||||
raise KnowledgeIndexNodeError("Document ID is required.")
|
||||
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
|
||||
if not batch:
|
||||
raise KnowledgeIndexNodeError("Batch is required.")
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if not document:
|
||||
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
|
||||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
|
||||
index_processor.index(dataset, document, chunks)
|
||||
indexing_end_at = time.perf_counter()
|
||||
document.indexing_latency = indexing_end_at - indexing_start_at
|
||||
# update document status
|
||||
document.indexing_status = "completed"
|
||||
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.word_count = (
|
||||
db.session.query(func.sum(DocumentSegment.word_count))
|
||||
.filter(
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
db.session.add(document)
|
||||
# update document segment status
|
||||
db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||
}
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"batch": batch.value,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"display_status": document.indexing_status,
|
||||
}
|
||||
|
||||
def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
return index_processor.format_preview(chunks)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
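The node's control flow is: preview runs (invoked from the debugger) only format the chunks, while real runs index them and record the latency with time.perf_counter. A standalone sketch of that branch, with placeholders for the index processor:

import time

def run(chunks: dict, is_preview: bool) -> dict:
    if is_preview:
        return {"preview": chunks}  # stands in for index_processor.format_preview(chunks)
    start = time.perf_counter()
    # index_processor.index(dataset, document, chunks) would run here
    latency = time.perf_counter() - start
    return {"indexing_latency": latency, "display_status": "completed"}

print(run({"general_chunks": ["a", "b"]}, is_preview=True))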
@@ -57,10 +57,6 @@ class MultipleRetrievalConfig(BaseModel):


class ModelConfig(BaseModel):
    """
    Model Config.
    """

    provider: str
    name: str
    mode: str
@@ -4,12 +4,14 @@ from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end import EndNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
@@ -124,4 +126,12 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
        LATEST_VERSION: AgentNode,
        "1": AgentNode,
    },
    NodeType.DATASOURCE: {
        LATEST_VERSION: DatasourceNode,
        "1": DatasourceNode,
    },
    NodeType.KNOWLEDGE_INDEX: {
        LATEST_VERSION: KnowledgeIndexNode,
        "1": KnowledgeIndexNode,
    },
}
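A standalone sketch of how a mapping shaped like NODE_TYPE_CLASSES_MAPPING resolves a node class by type and version, falling back to the latest version (class names here are stand-ins, not imports of the real nodes):

LATEST_VERSION = "latest"  # assumed sentinel value for illustration

class DatasourceNode: ...
class KnowledgeIndexNode: ...

MAPPING = {
    "datasource": {LATEST_VERSION: DatasourceNode, "1": DatasourceNode},
    "knowledge-index": {LATEST_VERSION: KnowledgeIndexNode, "1": KnowledgeIndexNode},
}

def resolve(node_type: str, version: str | None):
    versions = MAPPING[node_type]
    return versions.get(version or LATEST_VERSION, versions[LATEST_VERSION])

print(resolve("datasource", "1").__name__)        # DatasourceNode
print(resolve("knowledge-index", None).__name__)  # KnowledgeIndexNode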
@@ -61,6 +61,7 @@ def build_from_mapping(
        FileTransferMethod.LOCAL_FILE: _build_from_local_file,
        FileTransferMethod.REMOTE_URL: _build_from_remote_url,
        FileTransferMethod.TOOL_FILE: _build_from_tool_file,
        FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
    }

    build_func = build_functions.get(transfer_method)
@@ -305,6 +306,53 @@ def _build_from_tool_file(
    )


def _build_from_datasource_file(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
datasource_file = (
|
||||
db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == mapping.get("datasource_file_id"),
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if datasource_file is None:
|
||||
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
|
||||
|
||||
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
|
||||
|
||||
detected_file_type = _standardize_file_type(extension=extension, mime_type=datasource_file.mime_type)  # extension already carries the leading dot
|
||||
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
file_type = (
|
||||
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
tenant_id=tenant_id,
|
||||
filename=datasource_file.name,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=datasource_file.source_url,
|
||||
related_id=datasource_file.id,
|
||||
extension=extension,
|
||||
mime_type=datasource_file.mime_type,
|
||||
size=datasource_file.size,
|
||||
storage_key=datasource_file.key,
|
||||
)
|
||||
|
||||
|
||||
def _is_file_valid_with_config(
|
||||
*,
|
||||
input_file_type: str,
|
||||
|
||||
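One detail of _build_from_datasource_file worth calling out is how the file extension is derived from the stored object key, with ".bin" as the fallback when the key has no dot. A standalone restatement of just that expression:

def extension_from_key(key: str) -> str:
    # mirrors: "." + key.split(".")[-1] if "." in key else ".bin"
    return "." + key.split(".")[-1] if "." in key else ".bin"

print(extension_from_key("datasets/abc/report.pdf"))  # .pdf
print(extension_from_key("datasets/abc/blob"))        # .bin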
@@ -36,7 +36,10 @@ from core.variables.variables import (
    StringVariable,
    Variable,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.constants import (
    CONVERSATION_VARIABLE_NODE_ID,
    ENVIRONMENT_VARIABLE_NODE_ID,
)


class UnsupportedSegmentTypeError(Exception):
@@ -75,6 +78,12 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va
    return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])


def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
    if not mapping.get("variable"):
        raise VariableError("missing variable")
    return mapping["variable"]


def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
    """
    This factory function is used to create the environment variable or the conversation variable,
@@ -56,6 +56,13 @@ external_knowledge_info_fields = {

doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String}

icon_info_fields = {
    "icon_type": fields.String,
    "icon": fields.String,
    "icon_background": fields.String,
    "icon_url": fields.String,
}

dataset_detail_fields = {
    "id": fields.String,
    "name": fields.String,
@@ -81,6 +88,13 @@ dataset_detail_fields = {
    "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
    "doc_metadata": fields.List(fields.Nested(doc_metadata_fields)),
    "built_in_field_enabled": fields.Boolean,
    "pipeline_id": fields.String,
    "runtime_mode": fields.String,
    "chunk_structure": fields.String,
    "icon_info": fields.Nested(icon_info_fields),
    "is_published": fields.Boolean,
    "total_documents": fields.Integer,
    "total_available_documents": fields.Integer,
}

dataset_query_detail_fields = {
api/fields/rag_pipeline_fields.py (new file, 164 lines)
@@ -0,0 +1,164 @@
|
||||
from flask_restful import fields # type: ignore
|
||||
|
||||
from fields.workflow_fields import workflow_partial_fields
|
||||
from libs.helper import AppIconUrlField, TimestampField
|
||||
|
||||
pipeline_detail_kernel_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
}
|
||||
|
||||
related_app_list = {
|
||||
"data": fields.List(fields.Nested(pipeline_detail_kernel_fields)),
|
||||
"total": fields.Integer,
|
||||
}
|
||||
|
||||
app_detail_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
|
||||
"tracing": fields.Raw,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
|
||||
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
|
||||
|
||||
app_partial_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String(attribute="desc_or_prompt"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"tags": fields.List(fields.Nested(tag_fields)),
|
||||
}
|
||||
|
||||
|
||||
app_pagination_fields = {
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(app_partial_fields), attribute="items"),
|
||||
}
|
||||
|
||||
template_fields = {
|
||||
"name": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String,
|
||||
}
|
||||
|
||||
template_list_fields = {
|
||||
"data": fields.List(fields.Nested(template_fields)),
|
||||
}
|
||||
|
||||
site_fields = {
|
||||
"access_token": fields.String(attribute="code"),
|
||||
"code": fields.String,
|
||||
"title": fields.String,
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"description": fields.String,
|
||||
"default_language": fields.String,
|
||||
"chat_color_theme": fields.String,
|
||||
"chat_color_theme_inverted": fields.Boolean,
|
||||
"customize_domain": fields.String,
|
||||
"copyright": fields.String,
|
||||
"privacy_policy": fields.String,
|
||||
"custom_disclaimer": fields.String,
|
||||
"customize_token_strategy": fields.String,
|
||||
"prompt_public": fields.Boolean,
|
||||
"app_base_url": fields.String,
|
||||
"show_workflow_steps": fields.Boolean,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
deleted_tool_fields = {
|
||||
"type": fields.String,
|
||||
"tool_name": fields.String,
|
||||
"provider_id": fields.String,
|
||||
}
|
||||
|
||||
app_detail_fields_with_site = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"enable_site": fields.Boolean,
|
||||
"enable_api": fields.Boolean,
|
||||
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
|
||||
"site": fields.Nested(site_fields),
|
||||
"api_base_url": fields.String,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
|
||||
app_site_fields = {
|
||||
"app_id": fields.String,
|
||||
"access_token": fields.String(attribute="code"),
|
||||
"code": fields.String,
|
||||
"title": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"description": fields.String,
|
||||
"default_language": fields.String,
|
||||
"customize_domain": fields.String,
|
||||
"copyright": fields.String,
|
||||
"privacy_policy": fields.String,
|
||||
"custom_disclaimer": fields.String,
|
||||
"customize_token_strategy": fields.String,
|
||||
"prompt_public": fields.Boolean,
|
||||
"show_workflow_steps": fields.Boolean,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
}
|
||||
|
||||
leaked_dependency_fields = {"type": fields.String, "value": fields.Raw, "current_identifier": fields.String}
|
||||
|
||||
pipeline_import_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"pipeline_id": fields.String,
|
||||
"dataset_id": fields.String,
|
||||
"current_dsl_version": fields.String,
|
||||
"imported_dsl_version": fields.String,
|
||||
"error": fields.String,
|
||||
}
|
||||
|
||||
pipeline_import_check_dependencies_fields = {
|
||||
"leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
|
||||
}
|
||||
@@ -40,6 +40,23 @@ conversation_variable_fields = {
    "description": fields.String,
}

pipeline_variable_fields = {
    "label": fields.String,
    "variable": fields.String,
    "type": fields.String,
    "belong_to_node_id": fields.String,
    "max_length": fields.Integer,
    "required": fields.Boolean,
    "unit": fields.String,
    "default_value": fields.Raw,
    "options": fields.List(fields.String),
    "placeholder": fields.String,
    "tooltips": fields.String,
    "allowed_file_types": fields.List(fields.String),
    "allow_file_extension": fields.List(fields.String),
    "allow_file_upload_methods": fields.List(fields.String),
}

workflow_fields = {
    "id": fields.String,
    "graph": fields.Raw(attribute="graph_dict"),
@@ -55,6 +72,7 @@ workflow_fields = {
    "tool_published": fields.Boolean,
    "environment_variables": fields.List(EnvironmentVariableField()),
    "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
    "rag_pipeline_variables": fields.List(fields.Nested(pipeline_variable_fields)),
}

workflow_partial_fields = {
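The new rag_pipeline_variables entry is serialized through flask_restful like the other workflow fields. A minimal sketch of marshalling a plain dict through a trimmed copy of pipeline_variable_fields (assumes flask_restful is installed, as it is for the rest of api/fields):

from flask_restful import fields, marshal

pipeline_variable_fields = {  # trimmed copy for illustration
    "label": fields.String,
    "variable": fields.String,
    "type": fields.String,
    "belong_to_node_id": fields.String,
    "required": fields.Boolean,
}

raw = {"label": "Chunk size", "variable": "chunk_size", "type": "number",
       "belong_to_node_id": "datasource_1", "required": True, "extra": "dropped"}
print(marshal(raw, pipeline_variable_fields))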
api/installed_plugins.jsonl (new file, 1 line)
@@ -0,0 +1 @@
{"not_installed": [], "plugin_install_failed": []}
@@ -0,0 +1,113 @@
|
||||
"""add_pipeline_info
|
||||
|
||||
Revision ID: b35c3db83d09
|
||||
Revises: d28f2004b072
|
||||
Create Date: 2025-05-15 15:58:05.179877
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b35c3db83d09'
|
||||
down_revision = '0ab65e1cc7fa'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('pipeline_built_in_templates',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=False),
|
||||
sa.Column('icon', sa.JSON(), nullable=False),
|
||||
sa.Column('copyright', sa.String(length=255), nullable=False),
|
||||
sa.Column('privacy_policy', sa.String(length=255), nullable=False),
|
||||
sa.Column('position', sa.Integer(), nullable=False),
|
||||
sa.Column('install_count', sa.Integer(), nullable=False),
|
||||
sa.Column('language', sa.String(length=255), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey')
|
||||
)
|
||||
op.create_table('pipeline_customized_templates',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=False),
|
||||
sa.Column('icon', sa.JSON(), nullable=False),
|
||||
sa.Column('position', sa.Integer(), nullable=False),
|
||||
sa.Column('install_count', sa.Integer(), nullable=False),
|
||||
sa.Column('language', sa.String(length=255), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
|
||||
)
|
||||
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
|
||||
batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False)
|
||||
|
||||
op.create_table('pipelines',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), server_default=sa.text("''::character varying"), nullable=False),
|
||||
sa.Column('mode', sa.String(length=255), nullable=False),
|
||||
sa.Column('workflow_id', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('is_public', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||
sa.Column('is_published', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||
sa.Column('created_by', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='pipeline_pkey')
|
||||
)
|
||||
op.create_table('tool_builtin_datasource_providers',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('user_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('provider', sa.String(length=256), nullable=False),
|
||||
sa.Column('encrypted_credentials', sa.Text(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tool_builtin_datasource_provider_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_datasource_provider')
|
||||
)
|
||||
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True))
|
||||
batch_op.add_column(sa.Column('icon_info', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
|
||||
batch_op.add_column(sa.Column('runtime_mode', sa.String(length=255), server_default=sa.text("'general'::character varying"), nullable=True))
|
||||
batch_op.add_column(sa.Column('pipeline_id', models.types.StringUUID(), nullable=True))
|
||||
batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=True))
|
||||
|
||||
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||
batch_op.drop_column('rag_pipeline_variables')
|
||||
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.drop_column('chunk_structure')
|
||||
batch_op.drop_column('pipeline_id')
|
||||
batch_op.drop_column('runtime_mode')
|
||||
batch_op.drop_column('icon_info')
|
||||
batch_op.drop_column('keyword_number')
|
||||
|
||||
op.drop_table('tool_builtin_datasource_providers')
|
||||
op.drop_table('pipelines')
|
||||
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
|
||||
batch_op.drop_index('pipeline_customized_template_tenant_idx')
|
||||
|
||||
op.drop_table('pipeline_customized_templates')
|
||||
op.drop_table('pipeline_built_in_templates')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,33 @@
|
||||
"""add_pipeline_info_2
|
||||
|
||||
Revision ID: abb18a379e62
|
||||
Revises: b35c3db83d09
|
||||
Create Date: 2025-05-16 16:59:16.423127
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'abb18a379e62'
|
||||
down_revision = 'b35c3db83d09'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('pipelines', schema=None) as batch_op:
|
||||
batch_op.drop_column('mode')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('pipelines', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('mode', sa.VARCHAR(length=255), autoincrement=False, nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,70 @@
|
||||
"""add_pipeline_info_3
|
||||
|
||||
Revision ID: c459994abfa8
|
||||
Revises: abb18a379e62
|
||||
Create Date: 2025-05-30 00:33:14.068312
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'c459994abfa8'
|
||||
down_revision = 'abb18a379e62'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('datasource_oauth_params',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||
sa.Column('system_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'),
|
||||
sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx')
|
||||
)
|
||||
op.create_table('datasource_providers',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||
sa.Column('auth_type', sa.String(length=255), nullable=False),
|
||||
sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'),
|
||||
sa.UniqueConstraint('plugin_id', 'provider', name='datasource_provider_plugin_id_provider_idx')
|
||||
)
|
||||
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=False))
|
||||
batch_op.add_column(sa.Column('yaml_content', sa.Text(), nullable=False))
|
||||
batch_op.drop_column('pipeline_id')
|
||||
|
||||
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=False))
|
||||
batch_op.add_column(sa.Column('yaml_content', sa.Text(), nullable=False))
|
||||
batch_op.drop_column('pipeline_id')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('pipeline_id', sa.UUID(), autoincrement=False, nullable=False))
|
||||
batch_op.drop_column('yaml_content')
|
||||
batch_op.drop_column('chunk_structure')
|
||||
|
||||
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('pipeline_id', sa.UUID(), autoincrement=False, nullable=False))
|
||||
batch_op.drop_column('yaml_content')
|
||||
batch_op.drop_column('chunk_structure')
|
||||
|
||||
op.drop_table('datasource_providers')
|
||||
op.drop_table('datasource_oauth_params')
|
||||
# ### end Alembic commands ###
|
||||
Some files were not shown because too many files have changed in this diff.