Mirror of https://github.com/langgenius/dify.git (synced 2026-03-12 02:27:05 +00:00).
Compare commits: deploy/dev...feat/r2 (639 commits).
.github/workflows/build-push.yml (vendored, 1 line changed)

@@ -6,6 +6,7 @@ on:
       - "main"
       - "deploy/dev"
       - "deploy/enterprise"
+      - "deploy/rag-dev"
     tags:
       - "*"
.github/workflows/deploy-dev.yml (vendored, 7 lines changed)

@@ -4,7 +4,7 @@ on:
   workflow_run:
     workflows: ["Build and Push API & Web"]
     branches:
-      - "deploy/dev"
+      - "deploy/rag-dev"
     types:
      - completed

@@ -12,12 +12,13 @@ jobs:
   deploy:
     runs-on: ubuntu-latest
     if: |
-      github.event.workflow_run.conclusion == 'success'
+      github.event.workflow_run.conclusion == 'success' &&
+      github.event.workflow_run.head_branch == 'deploy/rag-dev'
     steps:
       - name: Deploy to server
         uses: appleboy/ssh-action@v0.1.8
         with:
-          host: ${{ secrets.SSH_HOST }}
+          host: ${{ secrets.RAG_SSH_HOST }}
           username: ${{ secrets.SSH_USER }}
           key: ${{ secrets.SSH_PRIVATE_KEY }}
           script: |
api/app.py (27 lines changed)

@@ -1,4 +1,3 @@
-import os
 import sys

@@ -17,20 +16,20 @@ else:
 # It seems that JetBrains Python debugger does not work well with gevent,
 # so we need to disable gevent in debug mode.
 # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
-if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
-    from gevent import monkey
-
-    # gevent
-    monkey.patch_all()
-
-    from grpc.experimental import gevent as grpc_gevent  # type: ignore
-
-    # grpc gevent
-    grpc_gevent.init_gevent()
-
-    import psycogreen.gevent  # type: ignore
-
-    psycogreen.gevent.patch_psycopg()
+# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
+#     from gevent import monkey
+#
+#     # gevent
+#     monkey.patch_all()
+#
+#     from grpc.experimental import gevent as grpc_gevent  # type: ignore
+#
+#     # grpc gevent
+#     grpc_gevent.init_gevent()
+#
+#     import psycogreen.gevent  # type: ignore
+#
+#     psycogreen.gevent.patch_psycopg()

 from app_factory import create_app
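The block that is commented out above was previously gated on FLASK_DEBUG being falsy. A standalone sketch of that gate, with the values and env handling taken from the removed line:

```python
# Standalone sketch of the condition the removed code used: gevent/grpc/psycopg
# monkey patching only ran when FLASK_DEBUG was "false", "0", or "no".
import os

os.environ["FLASK_DEBUG"] = "1"  # e.g. when debugging under JetBrains or debugpy
flask_debug = os.environ.get("FLASK_DEBUG", "0")
should_patch = bool(flask_debug) and flask_debug.lower() in {"false", "0", "no"}
print(should_patch)  # False: with FLASK_DEBUG=1 the patches were already skipped before this change
```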
@@ -222,11 +222,28 @@ class HostedFetchAppTemplateConfig(BaseSettings):
     )


+class HostedFetchPipelineTemplateConfig(BaseSettings):
+    """
+    Configuration for fetching pipeline templates
+    """
+
+    HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
+        description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
+        default="database",
+    )
+
+    HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
+        description="Domain for fetching remote pipeline templates",
+        default="https://tmpl.dify.ai",
+    )
+
+
 class HostedServiceConfig(
     # place the configs in alphabet order
     HostedAnthropicConfig,
     HostedAzureOpenAiConfig,
     HostedFetchAppTemplateConfig,
+    HostedFetchPipelineTemplateConfig,
     HostedMinmaxConfig,
     HostedOpenAiConfig,
     HostedSparkConfig,
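The two new settings follow the same pydantic-settings pattern as the surrounding Hosted* config classes, so both fields can be overridden from the environment. A minimal standalone sketch (assuming pydantic-settings, which the config module already builds on; the class below only mirrors the two fields added above):

```python
# Minimal sketch of how the new pipeline-template settings resolve from env vars.
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class HostedFetchPipelineTemplateConfig(BaseSettings):
    HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(default="database")
    HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(default="https://tmpl.dify.ai")


os.environ["HOSTED_FETCH_PIPELINE_TEMPLATES_MODE"] = "remote"
config = HostedFetchPipelineTemplateConfig()
print(config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE)           # "remote", taken from the environment
print(config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN)  # falls back to the default domain
```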
@@ -3,6 +3,7 @@ from threading import Lock
 from typing import TYPE_CHECKING

 from contexts.wrapper import RecyclableContextVar
+from core.datasource.__base.datasource_provider import DatasourcePluginProviderController

 if TYPE_CHECKING:
     from core.model_runtime.entities.model_entities import AIModelEntity

@@ -33,3 +34,11 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
 plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
     ContextVar("plugin_model_schemas")
 )
+
+datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
+    RecyclableContextVar(ContextVar("datasource_plugin_providers"))
+)
+
+datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
+    ContextVar("datasource_plugin_providers_lock")
+)
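The two added context variables follow the same pattern as plugin_model_schemas above: a per-context dict registry plus a Lock. A standalone sketch of that pattern using plain ContextVar (RecyclableContextVar is dify's wrapper around it; the "notion" key and the provider object are placeholders for illustration):

```python
# Registry-plus-lock pattern behind the new context variables, shown with plain ContextVar.
from contextvars import ContextVar
from threading import Lock

datasource_plugin_providers: ContextVar[dict[str, object]] = ContextVar("datasource_plugin_providers")
datasource_plugin_providers_lock: ContextVar[Lock] = ContextVar("datasource_plugin_providers_lock")

# Seed fresh values at the start of a request/context.
datasource_plugin_providers.set({})
datasource_plugin_providers_lock.set(Lock())

# Register a provider controller under the lock, then read it back.
with datasource_plugin_providers_lock.get():
    datasource_plugin_providers.get()["notion"] = object()  # placeholder for a provider controller

print(list(datasource_plugin_providers.get()))  # ['notion']
```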
@@ -76,7 +76,6 @@ from .billing import billing, compliance

 # Import datasets controllers
 from .datasets import (
-    data_source,
     datasets,
     datasets_document,
     datasets_segments,
@@ -85,6 +84,14 @@ from .datasets import (
     metadata,
     website,
 )
+from .datasets.rag_pipeline import (
+    datasource_auth,
+    datasource_content_preview,
+    rag_pipeline,
+    rag_pipeline_datasets,
+    rag_pipeline_import,
+    rag_pipeline_workflow,
+)

 # Import explore controllers
 from .explore import (
@@ -283,6 +283,15 @@ class DatasetApi(Resource):
             location="json",
             help="Invalid external knowledge api id.",
         )
+        parser.add_argument(
+            "icon_info",
+            type=dict,
+            required=False,
+            nullable=True,
+            location="json",
+            help="Invalid icon info.",
+        )
         args = parser.parse_args()
         data = request.get_json()
@@ -1,3 +1,4 @@
+import json
 import logging
 from argparse import ArgumentTypeError
 from datetime import UTC, datetime

@@ -51,6 +52,7 @@ from fields.document_fields import (
 )
 from libs.login import login_required
 from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
+from models.dataset import DocumentPipelineExecutionLog
 from services.dataset_service import DatasetService, DocumentService
 from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig

@@ -661,7 +663,7 @@ class DocumentDetailApi(DocumentResource):
             response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
         elif metadata == "without":
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,

@@ -1028,6 +1030,41 @@ class WebsiteDocumentSyncApi(DocumentResource):
         return {"result": "success"}, 200


+class DocumentPipelineExecutionLogApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, document_id):
+        dataset_id = str(dataset_id)
+        document_id = str(document_id)
+
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        document = DocumentService.get_document(dataset.id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        log = (
+            db.session.query(DocumentPipelineExecutionLog)
+            .filter_by(document_id=document_id)
+            .order_by(DocumentPipelineExecutionLog.created_at.desc())
+            .first()
+        )
+        if not log:
+            return {
+                "datasource_info": None,
+                "datasource_type": None,
+                "input_data": None,
+                "datasource_node_id": None,
+            }, 200
+        return {
+            "datasource_info": json.loads(log.datasource_info),
+            "datasource_type": log.datasource_type,
+            "input_data": log.input_data,
+            "datasource_node_id": log.datasource_node_id,
+        }, 200
+
+
 api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
 api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
 api.add_resource(DatasetInitApi, "/datasets/init")

@@ -1050,3 +1087,6 @@ api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
 api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")

 api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
+api.add_resource(
+    DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
+)
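The new DocumentPipelineExecutionLogApi returns the most recent pipeline execution log for a document. A client-side sketch of calling it; the base URL, IDs, and bearer token are placeholders, not values from the diff:

```python
# Hypothetical console-API call for the new pipeline-execution-log route.
import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
headers = {"Authorization": "Bearer <console-access-token>"}

resp = requests.get(
    f"{BASE}/datasets/<dataset_id>/documents/<document_id>/pipeline-execution-log",
    headers=headers,
)
resp.raise_for_status()
log = resp.json()
# All four keys come back as None when no DocumentPipelineExecutionLog row exists yet.
print(log["datasource_type"], log["datasource_node_id"])
```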
@@ -101,3 +101,9 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
     error_code = "child_chunk_delete_index_error"
     description = "Delete child chunk index failed: {message}"
     code = 500
+
+
+class PipelineNotFoundError(BaseHTTPException):
+    error_code = "pipeline_not_found"
+    description = "Pipeline not found."
+    code = 404
api/controllers/console/datasets/rag_pipeline/datasource_auth.py (new file, 197 lines)

@@ -0,0 +1,197 @@
from flask import redirect, request
from flask_login import current_user  # type: ignore
from flask_restful import (  # type: ignore
    Resource,  # type: ignore
    reqparse,
)
from werkzeug.exceptions import Forbidden, NotFound

from configs import dify_config
from controllers.console import api
from controllers.console.wraps import (
    account_initialization_required,
    setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db
from libs.login import login_required
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from services.datasource_provider_service import DatasourceProviderService


class DatasourcePluginOauthApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        parser = reqparse.RequestParser()
        parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
        parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
        args = parser.parse_args()
        provider = args["provider"]
        plugin_id = args["plugin_id"]
        # Check user role first
        if not current_user.is_editor:
            raise Forbidden()
        # get all plugin oauth configs
        plugin_oauth_config = (
            db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
        )
        if not plugin_oauth_config:
            raise NotFound()
        oauth_handler = OAuthHandler()
        redirect_url = (
            f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
        )
        system_credentials = plugin_oauth_config.system_credentials
        if system_credentials:
            system_credentials["redirect_url"] = redirect_url
        response = oauth_handler.get_authorization_url(
            current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
        )
        return response.model_dump()


class DatasourceOauthCallback(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        parser = reqparse.RequestParser()
        parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
        parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
        args = parser.parse_args()
        provider = args["provider"]
        plugin_id = args["plugin_id"]
        oauth_handler = OAuthHandler()
        plugin_oauth_config = (
            db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
        )
        if not plugin_oauth_config:
            raise NotFound()
        credentials = oauth_handler.get_credentials(
            current_user.current_tenant.id,
            current_user.id,
            plugin_id,
            provider,
            system_credentials=plugin_oauth_config.system_credentials,
            request=request,
        )
        datasource_provider = DatasourceProvider(
            plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
        )
        db.session.add(datasource_provider)
        db.session.commit()
        return redirect(f"{dify_config.CONSOLE_WEB_URL}")


class DatasourceAuth(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        if not current_user.is_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
        parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test")
        parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
        args = parser.parse_args()

        datasource_provider_service = DatasourceProviderService()

        try:
            datasource_provider_service.datasource_provider_credentials_validate(
                tenant_id=current_user.current_tenant_id,
                provider=args["provider"],
                plugin_id=args["plugin_id"],
                credentials=args["credentials"],
                name=args["name"],
            )
        except CredentialsValidateFailedError as ex:
            raise ValueError(str(ex))

        return {"result": "success"}, 201

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        parser = reqparse.RequestParser()
        parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
        parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
        args = parser.parse_args()
        datasource_provider_service = DatasourceProviderService()
        datasources = datasource_provider_service.get_datasource_credentials(
            tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"]
        )
        return {"result": datasources}, 200


class DatasourceAuthUpdateDeleteApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, auth_id: str):
        parser = reqparse.RequestParser()
        parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
        parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
        args = parser.parse_args()
        if not current_user.is_editor:
            raise Forbidden()
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.remove_datasource_credentials(
            tenant_id=current_user.current_tenant_id,
            auth_id=auth_id,
            provider=args["provider"],
            plugin_id=args["plugin_id"],
        )
        return {"result": "success"}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, auth_id: str):
        parser = reqparse.RequestParser()
        parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
        parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
        args = parser.parse_args()
        if not current_user.is_editor:
            raise Forbidden()
        try:
            datasource_provider_service = DatasourceProviderService()
            datasource_provider_service.update_datasource_credentials(
                tenant_id=current_user.current_tenant_id,
                auth_id=auth_id,
                provider=args["provider"],
                plugin_id=args["plugin_id"],
                credentials=args["credentials"],
            )
        except CredentialsValidateFailedError as ex:
            raise ValueError(str(ex))

        return {"result": "success"}, 201


# Import Rag Pipeline
api.add_resource(
    DatasourcePluginOauthApi,
    "/oauth/plugin/datasource",
)
api.add_resource(
    DatasourceOauthCallback,
    "/oauth/plugin/datasource/callback",
)
api.add_resource(
    DatasourceAuth,
    "/auth/plugin/datasource",
)

api.add_resource(
    DatasourceAuthUpdateDeleteApi,
    "/auth/plugin/datasource/<string:auth_id>",
)
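The controller above wires four routes: OAuth authorization start, the OAuth callback, credential create/list, and per-credential update/delete. A sketch of storing and then listing credentials; the host, token, provider name, and plugin id are placeholders, and the JSON fields follow the request parsers above:

```python
# Hypothetical calls against the new datasource-auth routes.
import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
headers = {"Authorization": "Bearer <console-access-token>"}

# Validate and store credentials for a datasource plugin.
create = requests.post(
    f"{BASE}/auth/plugin/datasource",
    headers=headers,
    json={
        "provider": "notion",                         # placeholder provider
        "plugin_id": "langgenius/notion-datasource",  # placeholder plugin id
        "name": "my-notion",
        "credentials": {"integration_secret": "..."},
    },
)
create.raise_for_status()  # 201 on success

# List the stored credentials for that provider.
listing = requests.get(
    f"{BASE}/auth/plugin/datasource",
    headers=headers,
    params={"provider": "notion", "plugin_id": "langgenius/notion-datasource"},
)
print(listing.json()["result"])
```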
@@ -0,0 +1,55 @@
from flask_restful import (  # type: ignore
    Resource,  # type: ignore
    reqparse,
)
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService


class DataSourceContentPreviewApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_rag_pipeline
    def post(self, pipeline: Pipeline, node_id: str):
        """
        Run datasource content preview
        """
        if not isinstance(current_user, Account):
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("datasource_type", type=str, required=True, location="json")
        args = parser.parse_args()

        inputs = args.get("inputs")
        if inputs is None:
            raise ValueError("missing inputs")
        datasource_type = args.get("datasource_type")
        if datasource_type is None:
            raise ValueError("missing datasource_type")

        rag_pipeline_service = RagPipelineService()
        preview_content = rag_pipeline_service.run_datasource_node_preview(
            pipeline=pipeline,
            node_id=node_id,
            user_inputs=inputs,
            account=current_user,
            datasource_type=datasource_type,
            is_published=True,
        )
        return preview_content, 200


api.add_resource(
    DataSourceContentPreviewApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
)
api/controllers/console/datasets/rag_pipeline/rag_pipeline.py (new file, 162 lines)

@@ -0,0 +1,162 @@
import logging

from flask import request
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session

from controllers.console import api
from controllers.console.wraps import (
    account_initialization_required,
    enterprise_license_required,
    setup_required,
)
from extensions.ext_database import db
from libs.login import login_required
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService

logger = logging.getLogger(__name__)


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 to 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class PipelineTemplateListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
        type = request.args.get("type", default="built-in", type=str)
        language = request.args.get("language", default="en-US", type=str)
        # get pipeline templates
        pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
        return pipeline_templates, 200


class PipelineTemplateDetailApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self, template_id: str):
        type = request.args.get("type", default="built-in", type=str)
        rag_pipeline_service = RagPipelineService()
        pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
        return pipeline_template, 200


class CustomizedPipelineTemplateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def patch(self, template_id: str):
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "icon_info",
            type=dict,
            location="json",
            nullable=True,
        )
        args = parser.parse_args()
        pipeline_template_info = PipelineTemplateInfoEntity(**args)
        RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
        return 200

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def delete(self, template_id: str):
        RagPipelineService.delete_customized_pipeline_template(template_id)
        return 200

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, template_id: str):
        with Session(db.engine) as session:
            template = (
                session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
            )
            if not template:
                raise ValueError("Customized pipeline template not found.")

        return {"data": template.yaml_content}, 200


class PublishCustomizedPipelineTemplateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, pipeline_id: str):
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "icon_info",
            type=dict,
            location="json",
            nullable=True,
        )
        args = parser.parse_args()
        rag_pipeline_service = RagPipelineService()
        rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
        return {"result": "success"}


api.add_resource(
    PipelineTemplateListApi,
    "/rag/pipeline/templates",
)
api.add_resource(
    PipelineTemplateDetailApi,
    "/rag/pipeline/templates/<string:template_id>",
)
api.add_resource(
    CustomizedPipelineTemplateApi,
    "/rag/pipeline/customized/templates/<string:template_id>",
)
api.add_resource(
    PublishCustomizedPipelineTemplateApi,
    "/rag/pipelines/<string:pipeline_id>/customized/publish",
)
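PipelineTemplateListApi and PipelineTemplateDetailApi expose read-only template browsing. A sketch of fetching the built-in templates and one template's detail; the host, token, and template id are placeholders:

```python
# Hypothetical calls to the new pipeline-template routes.
import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
headers = {"Authorization": "Bearer <console-access-token>"}

templates = requests.get(
    f"{BASE}/rag/pipeline/templates",
    headers=headers,
    params={"type": "built-in", "language": "en-US"},
).json()

detail = requests.get(
    f"{BASE}/rag/pipeline/templates/<template_id>",
    headers=headers,
    params={"type": "built-in"},
).json()
print(detail)
```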
@@ -0,0 +1,171 @@
from flask_login import current_user  # type: ignore # type: ignore
from flask_restful import Resource, marshal, reqparse  # type: ignore
from werkzeug.exceptions import Forbidden

import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_rate_limit_check,
    setup_required,
)
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 to 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class CreateRagPipelineDatasetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_rate_limit_check("knowledge")
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="type is required. Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )

        parser.add_argument(
            "icon_info",
            type=dict,
            nullable=True,
            required=False,
            default={},
        )

        parser.add_argument(
            "permission",
            type=str,
            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
            nullable=True,
            required=False,
            default=DatasetPermissionEnum.ONLY_ME,
        )

        parser.add_argument(
            "partial_member_list",
            type=list,
            nullable=True,
            required=False,
            default=[],
        )

        parser.add_argument(
            "yaml_content",
            type=str,
            nullable=False,
            required=True,
            help="yaml_content is required.",
        )

        args = parser.parse_args()

        # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
        if not current_user.is_dataset_editor:
            raise Forbidden()
        rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
        try:
            import_info = RagPipelineDslService.create_rag_pipeline_dataset(
                tenant_id=current_user.current_tenant_id,
                rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
            )
            if rag_pipeline_dataset_create_entity.permission == "partial_members":
                DatasetPermissionService.update_partial_member_list(
                    current_user.current_tenant_id,
                    import_info["dataset_id"],
                    rag_pipeline_dataset_create_entity.partial_member_list,
                )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()

        return import_info, 201


class CreateEmptyRagPipelineDatasetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_rate_limit_check("knowledge")
    def post(self):
        # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
        if not current_user.is_dataset_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="type is required. Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )

        parser.add_argument(
            "icon_info",
            type=dict,
            nullable=True,
            required=False,
            default={},
        )

        parser.add_argument(
            "permission",
            type=str,
            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
            nullable=True,
            required=False,
            default=DatasetPermissionEnum.ONLY_ME,
        )

        parser.add_argument(
            "partial_member_list",
            type=list,
            nullable=True,
            required=False,
            default=[],
        )

        args = parser.parse_args()
        dataset = DatasetService.create_empty_rag_pipeline_dataset(
            tenant_id=current_user.current_tenant_id,
            rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args),
        )
        return marshal(dataset, dataset_detail_fields), 201


api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
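CreateRagPipelineDatasetApi accepts a pipeline DSL (yaml_content) together with dataset metadata and permissions. A sketch of the request body; the host and token are placeholders, and the permission string assumes DatasetPermissionEnum.ONLY_ME serializes as "only_me":

```python
# Hypothetical request creating a dataset from a RAG pipeline DSL.
import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
headers = {"Authorization": "Bearer <console-access-token>"}

payload = {
    "name": "support-kb",
    "description": "Docs ingested through a RAG pipeline",
    "icon_info": {},
    "permission": "only_me",          # assumed value of DatasetPermissionEnum.ONLY_ME
    "partial_member_list": [],
    "yaml_content": "<pipeline DSL here>",
}
resp = requests.post(f"{BASE}/rag/pipeline/dataset", headers=headers, json=payload)
print(resp.status_code, resp.json())  # 201 plus the import info on success
```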
@@ -0,0 +1,146 @@
from typing import cast

from flask_login import current_user  # type: ignore
from flask_restful import Resource, marshal_with, reqparse  # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
    account_initialization_required,
    setup_required,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required
from models import Account
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService


class RagPipelineImportApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(pipeline_import_fields)
    def post(self):
        # Check user role first
        if not current_user.is_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("mode", type=str, required=True, location="json")
        parser.add_argument("yaml_content", type=str, location="json")
        parser.add_argument("yaml_url", type=str, location="json")
        parser.add_argument("name", type=str, location="json")
        parser.add_argument("description", type=str, location="json")
        parser.add_argument("icon_type", type=str, location="json")
        parser.add_argument("icon", type=str, location="json")
        parser.add_argument("icon_background", type=str, location="json")
        parser.add_argument("pipeline_id", type=str, location="json")
        args = parser.parse_args()

        # Create service with session
        with Session(db.engine) as session:
            import_service = RagPipelineDslService(session)
            # Import app
            account = cast(Account, current_user)
            result = import_service.import_rag_pipeline(
                account=account,
                import_mode=args["mode"],
                yaml_content=args.get("yaml_content"),
                yaml_url=args.get("yaml_url"),
                pipeline_id=args.get("pipeline_id"),
            )
            session.commit()

        # Return appropriate status code based on result
        status = result.status
        if status == ImportStatus.FAILED.value:
            return result.model_dump(mode="json"), 400
        elif status == ImportStatus.PENDING.value:
            return result.model_dump(mode="json"), 202
        return result.model_dump(mode="json"), 200


class RagPipelineImportConfirmApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(pipeline_import_fields)
    def post(self, import_id):
        # Check user role first
        if not current_user.is_editor:
            raise Forbidden()

        # Create service with session
        with Session(db.engine) as session:
            import_service = RagPipelineDslService(session)
            # Confirm import
            account = cast(Account, current_user)
            result = import_service.confirm_import(import_id=import_id, account=account)
            session.commit()

        # Return appropriate status code based on result
        if result.status == ImportStatus.FAILED.value:
            return result.model_dump(mode="json"), 400
        return result.model_dump(mode="json"), 200


class RagPipelineImportCheckDependenciesApi(Resource):
    @setup_required
    @login_required
    @get_rag_pipeline
    @account_initialization_required
    @marshal_with(pipeline_import_check_dependencies_fields)
    def get(self, pipeline: Pipeline):
        if not current_user.is_editor:
            raise Forbidden()

        with Session(db.engine) as session:
            import_service = RagPipelineDslService(session)
            result = import_service.check_dependencies(pipeline=pipeline)

        return result.model_dump(mode="json"), 200


class RagPipelineExportApi(Resource):
    @setup_required
    @login_required
    @get_rag_pipeline
    @account_initialization_required
    def get(self, pipeline: Pipeline):
        if not current_user.is_editor:
            raise Forbidden()

        # Add include_secret params
        parser = reqparse.RequestParser()
        parser.add_argument("include_secret", type=bool, default=False, location="args")
        args = parser.parse_args()

        with Session(db.engine) as session:
            export_service = RagPipelineDslService(session)
            result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])

        return {"data": result}, 200


# Import Rag Pipeline
api.add_resource(
    RagPipelineImportApi,
    "/rag/pipelines/imports",
)
api.add_resource(
    RagPipelineImportConfirmApi,
    "/rag/pipelines/imports/<string:import_id>/confirm",
)
api.add_resource(
    RagPipelineImportCheckDependenciesApi,
    "/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
)
api.add_resource(
    RagPipelineExportApi,
    "/rag/pipelines/<string:pipeline_id>/exports",
)
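The import flow mirrors the existing app DSL import: POST the DSL, then confirm if the result comes back pending (HTTP 202). A sketch of that two-step flow; host and token are placeholders, and the "yaml-content" mode string is an assumption carried over from the app DSL service, not shown in this diff:

```python
# Hypothetical two-step pipeline import against the new routes.
import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
headers = {"Authorization": "Bearer <console-access-token>"}

result = requests.post(
    f"{BASE}/rag/pipelines/imports",
    headers=headers,
    json={"mode": "yaml-content", "yaml_content": "<pipeline DSL here>"},  # mode value assumed
).json()

# A pending result (e.g. missing plugin dependencies) must be confirmed explicitly.
if result.get("status") == "pending":
    requests.post(f"{BASE}/rag/pipelines/imports/{result['id']}/confirm", headers=headers)
```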
(File diff suppressed because it is too large.)
api/controllers/console/datasets/wraps.py (new file, 43 lines)

@@ -0,0 +1,43 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional

from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.dataset import Pipeline


def get_rag_pipeline(
    view: Optional[Callable] = None,
):
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            if not kwargs.get("pipeline_id"):
                raise ValueError("missing pipeline_id in path parameters")

            pipeline_id = kwargs.get("pipeline_id")
            pipeline_id = str(pipeline_id)

            del kwargs["pipeline_id"]

            pipeline = (
                db.session.query(Pipeline)
                .filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
                .first()
            )

            if not pipeline:
                raise PipelineNotFoundError()

            kwargs["pipeline"] = pipeline

            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    else:
        return decorator(view)
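get_rag_pipeline follows the dual-use decorator idiom: the trailing check on view lets it be applied bare (@get_rag_pipeline) or as a called decorator. A standalone sketch of that idiom with a toy resolver (all names below are illustrative, not dify APIs):

```python
# Dual-use decorator idiom: works both as @get_item and as get_item(view_func).
from functools import wraps
from typing import Callable, Optional


def get_item(view: Optional[Callable] = None):
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            # Resolve the raw path parameter into an object and inject it.
            kwargs["item"] = {"id": str(kwargs.pop("item_id"))}
            return view_func(*args, **kwargs)

        return decorated_view

    # Bare usage receives the wrapped function directly; otherwise return the decorator.
    return decorator if view is None else decorator(view)


@get_item
def show(item):
    return item


print(show(item_id=42))  # {'id': '42'}
```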
@@ -113,9 +113,9 @@ class VariableEntity(BaseModel):
     hide: bool = False
     max_length: Optional[int] = None
     options: Sequence[str] = Field(default_factory=list)
-    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
-    allowed_file_extensions: Sequence[str] = Field(default_factory=list)
-    allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
+    allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
+    allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
+    allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)

     @field_validator("description", mode="before")
     @classmethod

@@ -128,6 +128,16 @@ class VariableEntity(BaseModel):
         return v or []


+class RagPipelineVariableEntity(VariableEntity):
+    """
+    Rag Pipeline Variable Entity.
+    """
+
+    tooltips: Optional[str] = None
+    placeholder: Optional[str] = None
+    belong_to_node_id: str
+
+
 class ExternalDataVariableEntity(BaseModel):
     """
     External Data Variable Entity.

@@ -285,7 +295,7 @@ class AppConfig(BaseModel):
     tenant_id: str
     app_id: str
     app_mode: AppMode
-    additional_features: AppAdditionalFeatures
+    additional_features: Optional[AppAdditionalFeatures] = None
     variables: list[VariableEntity] = []
     sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
@@ -1,4 +1,4 @@
-from core.app.app_config.entities import VariableEntity
+from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
 from models.workflow import Workflow
@@ -20,3 +20,19 @@ class WorkflowVariablesConfigManager:
             variables.append(VariableEntity.model_validate(variable))

         return variables
+
+    @classmethod
+    def convert_rag_pipeline_variable(cls, workflow: Workflow) -> list[RagPipelineVariableEntity]:
+        """
+        Convert workflow start variables to variables
+
+        :param workflow: workflow instance
+        """
+        variables = []
+
+        user_input_form = workflow.rag_pipeline_user_input_form()
+        # variables
+        for variable in user_input_form:
+            variables.append(RagPipelineVariableEntity.model_validate(variable))
+
+        return variables
@@ -43,11 +43,13 @@ from core.app.entities.task_entities import (
     WorkflowStartStreamResponse,
 )
 from core.file import FILE_MODEL_IDENTITY, File
+from core.plugin.impl.datasource import PluginDatasourceManager
 from core.tools.tool_manager import ToolManager
 from core.variables.segments import ArrayFileSegment, FileSegment, Segment
 from core.workflow.entities.workflow_execution import WorkflowExecution
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
 from core.workflow.nodes import NodeType
+from core.workflow.nodes.datasource.entities import DatasourceNodeData
 from core.workflow.nodes.tool.entities import ToolNodeData
 from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from models import (
@@ -183,6 +185,14 @@ class WorkflowResponseConverter:
                 provider_type=node_data.provider_type,
                 provider_id=node_data.provider_id,
             )
+        elif event.node_type == NodeType.DATASOURCE:
+            node_data = cast(DatasourceNodeData, event.node_data)
+            manager = PluginDatasourceManager()
+            provider_entity = manager.fetch_datasource_provider(
+                self._application_generate_entity.app_config.tenant_id,
+                f"{node_data.plugin_id}/{node_data.provider_name}",
+            )
+            response.data.extras["icon"] = provider_entity.declaration.identity.icon

         return response
0     api/core/app/apps/pipeline/__init__.py    Normal file

95    api/core/app/apps/pipeline/generate_response_converter.py    Normal file
@@ -0,0 +1,95 @@
from collections.abc import Generator
from typing import cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
    AppStreamResponse,
    ErrorStreamResponse,
    NodeFinishStreamResponse,
    NodeStartStreamResponse,
    PingStreamResponse,
    WorkflowAppBlockingResponse,
    WorkflowAppStreamResponse,
)


class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = WorkflowAppBlockingResponse

    @classmethod
    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
        """
        Convert blocking full response.
        :param blocking_response: blocking response
        :return:
        """
        return dict(blocking_response.to_dict())

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
        :return:
        """
        return cls.convert_blocking_full_response(blocking_response)

    @classmethod
    def convert_stream_full_response(
        cls, stream_response: Generator[AppStreamResponse, None, None]
    ) -> Generator[dict | str, None, None]:
        """
        Convert stream full response.
        :param stream_response: stream response
        :return:
        """
        for chunk in stream_response:
            chunk = cast(WorkflowAppStreamResponse, chunk)
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "workflow_run_id": chunk.workflow_run_id,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.to_dict())
            yield response_chunk

    @classmethod
    def convert_stream_simple_response(
        cls, stream_response: Generator[AppStreamResponse, None, None]
    ) -> Generator[dict | str, None, None]:
        """
        Convert stream simple response.
        :param stream_response: stream response
        :return:
        """
        for chunk in stream_response:
            chunk = cast(WorkflowAppStreamResponse, chunk)
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "workflow_run_id": chunk.workflow_run_id,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
            else:
                response_chunk.update(sub_stream_response.to_dict())
            yield response_chunk
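For context (not part of the diff): convert_stream_full_response yields plain dicts plus the literal string "ping", which callers typically serialize as server-sent events. A rough, hypothetical consumption sketch:

import json

def to_sse(chunks):  # hypothetical helper, not in this diff
    for chunk in chunks:
        if chunk == "ping":
            yield "event: ping\n\n"
        else:
            yield f"data: {json.dumps(chunk)}\n\n"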
64    api/core/app/apps/pipeline/pipeline_config_manager.py    Normal file
@@ -0,0 +1,64 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.dataset import Pipeline
from models.model import AppMode
from models.workflow import Workflow


class PipelineConfig(WorkflowUIBasedAppConfig):
    """
    Pipeline Config Entity.
    """

    rag_pipeline_variables: list[RagPipelineVariableEntity] = []
    pass


class PipelineConfigManager(BaseAppConfigManager):
    @classmethod
    def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig:
        pipeline_config = PipelineConfig(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            app_mode=AppMode.RAG_PIPELINE,
            workflow_id=workflow.id,
            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow),
        )

        return pipeline_config

    @classmethod
    def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
        """
        Validate for pipeline config

        :param tenant_id: tenant id
        :param config: app model config args
        :param only_structure_validate: only validate the structure of the config
        """
        related_config_keys = []

        # file upload validation
        config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
        related_config_keys.extend(current_related_config_keys)

        # text_to_speech
        config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # moderation validation
        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
        )
        related_config_keys.extend(current_related_config_keys)

        related_config_keys = list(set(related_config_keys))

        # Filter out extra parameters
        filtered_config = {key: config.get(key) for key in related_config_keys}

        return filtered_config
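A usage sketch of the config manager above (not part of the diff), assuming a Pipeline row and its Workflow have already been loaded:

pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
# app_mode is AppMode.RAG_PIPELINE; rag_pipeline_variables come from the workflow's user input form
print(pipeline_config.app_mode, len(pipeline_config.rag_pipeline_variables))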
621   api/core/app/apps/pipeline/pipeline_generator.py    Normal file
@@ -0,0 +1,621 @@
import contextvars
import datetime
import json
import logging
import secrets
import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker

import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService

logger = logging.getLogger(__name__)


class PipelineGenerator(BaseAppGenerator):
    @overload
    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[True],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
    ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...

    @overload
    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[False],
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
    ) -> Mapping[str, Any]: ...

    @overload
    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool,
        call_depth: int,
        workflow_thread_pool_id: Optional[str],
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
        call_depth: int = 0,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
        # convert to app config
        pipeline_config = PipelineConfigManager.get_pipeline_config(
            pipeline=pipeline,
            workflow=workflow,
        )
        # Add null check for dataset
        dataset = pipeline.dataset
        if not dataset:
            raise ValueError("Pipeline dataset is required")
        inputs: Mapping[str, Any] = args["inputs"]
        start_node_id: str = args["start_node_id"]
        datasource_type: str = args["datasource_type"]
        datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
        batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
        documents = []
        if invoke_from == InvokeFrom.PUBLISHED:
            for datasource_info in datasource_info_list:
                position = DocumentService.get_documents_position(dataset.id)
                document = self._build_document(
                    tenant_id=pipeline.tenant_id,
                    dataset_id=dataset.id,
                    built_in_field_enabled=dataset.built_in_field_enabled,
                    datasource_type=datasource_type,
                    datasource_info=datasource_info,
                    created_from="rag-pipeline",
                    position=position,
                    account=user,
                    batch=batch,
                    document_form=dataset.chunk_structure,
                )
                db.session.add(document)
                documents.append(document)
            db.session.commit()

        # run in child thread
        for i, datasource_info in enumerate(datasource_info_list):
            workflow_run_id = str(uuid.uuid4())
            document_id = None
            if invoke_from == InvokeFrom.PUBLISHED:
                document_id = documents[i].id
                document_pipeline_execution_log = DocumentPipelineExecutionLog(
                    document_id=document_id,
                    datasource_type=datasource_type,
                    datasource_info=json.dumps(datasource_info),
                    datasource_node_id=start_node_id,
                    input_data=inputs,
                    pipeline_id=pipeline.id,
                    created_by=user.id,
                )
                db.session.add(document_pipeline_execution_log)
                db.session.commit()
            application_generate_entity = RagPipelineGenerateEntity(
                task_id=str(uuid.uuid4()),
                app_config=pipeline_config,
                pipeline_config=pipeline_config,
                datasource_type=datasource_type,
                datasource_info=datasource_info,
                dataset_id=dataset.id,
                start_node_id=start_node_id,
                batch=batch,
                document_id=document_id,
                inputs=self._prepare_user_inputs(
                    user_inputs=inputs,
                    variables=pipeline_config.rag_pipeline_variables,
                    tenant_id=pipeline.tenant_id,
                    strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
                ),
                files=[],
                user_id=user.id,
                stream=streaming,
                invoke_from=invoke_from,
                call_depth=call_depth,
                workflow_execution_id=workflow_run_id,
            )

            contexts.plugin_tool_providers.set({})
            contexts.plugin_tool_providers_lock.set(threading.Lock())
            if invoke_from == InvokeFrom.DEBUGGER:
                workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
            else:
                workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
            # Create workflow node execution repository
            session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
            workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
                session_factory=session_factory,
                user=user,
                app_id=application_generate_entity.app_config.app_id,
                triggered_from=workflow_triggered_from,
            )

            workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
                session_factory=session_factory,
                user=user,
                app_id=application_generate_entity.app_config.app_id,
                triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
            )
            if invoke_from == InvokeFrom.DEBUGGER:
                return self._generate(
                    flask_app=current_app._get_current_object(),  # type: ignore
                    context=contextvars.copy_context(),
                    pipeline=pipeline,
                    workflow_id=workflow.id,
                    user=user,
                    application_generate_entity=application_generate_entity,
                    invoke_from=invoke_from,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
            else:
                # run in child thread
                context = contextvars.copy_context()

                worker_thread = threading.Thread(
                    target=self._generate,
                    kwargs={
                        "flask_app": current_app._get_current_object(),  # type: ignore
                        "context": context,
                        "pipeline": pipeline,
                        "workflow_id": workflow.id,
                        "user": user,
                        "application_generate_entity": application_generate_entity,
                        "invoke_from": invoke_from,
                        "workflow_execution_repository": workflow_execution_repository,
                        "workflow_node_execution_repository": workflow_node_execution_repository,
                        "streaming": streaming,
                        "workflow_thread_pool_id": workflow_thread_pool_id,
                    },
                )

                worker_thread.start()
        # return batch, dataset, documents
        return {
            "batch": batch,
            "dataset": PipelineDataset(
                id=dataset.id,
                name=dataset.name,
                description=dataset.description,
                chunk_structure=dataset.chunk_structure,
            ).model_dump(),
            "documents": [
                PipelineDocument(
                    id=document.id,
                    position=document.position,
                    data_source_type=document.data_source_type,
                    data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
                    name=document.name,
                    indexing_status=document.indexing_status,
                    error=document.error,
                    enabled=document.enabled,
                ).model_dump()
                for document in documents
            ],
        }

    def _generate(
        self,
        *,
        flask_app: Flask,
        context: contextvars.Context,
        pipeline: Pipeline,
        workflow_id: str,
        user: Union[Account, EndUser],
        application_generate_entity: RagPipelineGenerateEntity,
        invoke_from: InvokeFrom,
        workflow_execution_repository: WorkflowExecutionRepository,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        streaming: bool = True,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
        """
        Generate App response.

        :param pipeline: Pipeline
        :param workflow: Workflow
        :param user: account or end user
        :param application_generate_entity: application generate entity
        :param invoke_from: invoke from source
        :param workflow_execution_repository: repository for workflow execution
        :param workflow_node_execution_repository: repository for workflow node execution
        :param streaming: is stream
        :param workflow_thread_pool_id: workflow thread pool id
        """
        with preserve_flask_contexts(flask_app, context_vars=context):
            # init queue manager
            workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first()
            if not workflow:
                raise ValueError(f"Workflow not found: {workflow_id}")
            queue_manager = PipelineQueueManager(
                task_id=application_generate_entity.task_id,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
                app_mode=AppMode.RAG_PIPELINE,
            )
            context = contextvars.copy_context()

            # new thread
            worker_thread = threading.Thread(
                target=self._generate_worker,
                kwargs={
                    "flask_app": current_app._get_current_object(),  # type: ignore
                    "context": context,
                    "queue_manager": queue_manager,
                    "application_generate_entity": application_generate_entity,
                    "workflow_thread_pool_id": workflow_thread_pool_id,
                },
            )

            worker_thread.start()

            # return response or stream generator
            response = self._handle_response(
                application_generate_entity=application_generate_entity,
                workflow=workflow,
                queue_manager=queue_manager,
                user=user,
                workflow_execution_repository=workflow_execution_repository,
                workflow_node_execution_repository=workflow_node_execution_repository,
                stream=streaming,
            )

            return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def single_iteration_generate(
        self,
        pipeline: Pipeline,
        workflow: Workflow,
        node_id: str,
        user: Account | EndUser,
        args: Mapping[str, Any],
        streaming: bool = True,
    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
        """
        Generate App response.

        :param app_model: App
        :param workflow: Workflow
        :param node_id: the node id
        :param user: account or end user
        :param args: request args
        :param streaming: is streamed
        """
        if not node_id:
            raise ValueError("node_id is required")

        if args.get("inputs") is None:
            raise ValueError("inputs is required")

        # convert to app config
        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)

        dataset = pipeline.dataset
        if not dataset:
            raise ValueError("Pipeline dataset is required")

        # init application generate entity - use RagPipelineGenerateEntity instead
        application_generate_entity = RagPipelineGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=pipeline_config,
            pipeline_config=pipeline_config,
            datasource_type=args.get("datasource_type", ""),
            datasource_info=args.get("datasource_info", {}),
            dataset_id=dataset.id,
            batch=args.get("batch", ""),
            document_id=args.get("document_id"),
            inputs={},
            files=[],
            user_id=user.id,
            stream=streaming,
            invoke_from=InvokeFrom.DEBUGGER,
            call_depth=0,
            workflow_execution_id=str(uuid.uuid4()),
        )
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        # Create workflow node execution repository
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
        )

        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )

        return self._generate(
            flask_app=current_app._get_current_object(),  # type: ignore
            pipeline=pipeline,
            workflow_id=workflow.id,
            user=user,
            invoke_from=InvokeFrom.DEBUGGER,
            application_generate_entity=application_generate_entity,
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
            streaming=streaming,
        )

    def single_loop_generate(
        self,
        pipeline: Pipeline,
        workflow: Workflow,
        node_id: str,
        user: Account | EndUser,
        args: Mapping[str, Any],
        streaming: bool = True,
    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
        """
        Generate App response.

        :param app_model: App
        :param workflow: Workflow
        :param node_id: the node id
        :param user: account or end user
        :param args: request args
        :param streaming: is streamed
        """
        if not node_id:
            raise ValueError("node_id is required")

        if args.get("inputs") is None:
            raise ValueError("inputs is required")

        dataset = pipeline.dataset
        if not dataset:
            raise ValueError("Pipeline dataset is required")

        # convert to app config
        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)

        # init application generate entity
        application_generate_entity = RagPipelineGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=pipeline_config,
            pipeline_config=pipeline_config,
            datasource_type=args.get("datasource_type", ""),
            datasource_info=args.get("datasource_info", {}),
            batch=args.get("batch", ""),
            document_id=args.get("document_id"),
            dataset_id=dataset.id,
            inputs={},
            files=[],
            user_id=user.id,
            stream=streaming,
            invoke_from=InvokeFrom.DEBUGGER,
            extras={"auto_generate_conversation_name": False},
            single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
            workflow_execution_id=str(uuid.uuid4()),
        )
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())

        # Create workflow node execution repository
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
        )

        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )

        return self._generate(
            flask_app=current_app._get_current_object(),  # type: ignore
            pipeline=pipeline,
            workflow=workflow,
            user=user,
            invoke_from=InvokeFrom.DEBUGGER,
            application_generate_entity=application_generate_entity,
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
            streaming=streaming,
        )

    def _generate_worker(
        self,
        flask_app: Flask,
        application_generate_entity: RagPipelineGenerateEntity,
        queue_manager: AppQueueManager,
        context: contextvars.Context,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        Generate worker in a new thread.
        :param flask_app: Flask app
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param workflow_thread_pool_id: workflow thread pool id
        :return:
        """

        with preserve_flask_contexts(flask_app, context_vars=context):
            try:
                # workflow app
                runner = PipelineRunner(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )

                runner.run()
            except GenerateTaskStoppedError:
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(
                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                )
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except ValueError as e:
                if dify_config.DEBUG:
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:
                logger.exception("Unknown Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            finally:
                db.session.close()

    def _handle_response(
        self,
        application_generate_entity: RagPipelineGenerateEntity,
        workflow: Workflow,
        queue_manager: AppQueueManager,
        user: Union[Account, EndUser],
        workflow_execution_repository: WorkflowExecutionRepository,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        stream: bool = False,
    ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
        """
        Handle response.
        :param application_generate_entity: application generate entity
        :param workflow: workflow
        :param queue_manager: queue manager
        :param user: account or end user
        :param stream: is stream
        :param workflow_node_execution_repository: optional repository for workflow node execution
        :return:
        """
        # init generate task pipeline
        generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
            application_generate_entity=application_generate_entity,
            workflow=workflow,
            queue_manager=queue_manager,
            user=user,
            stream=stream,
            workflow_node_execution_repository=workflow_node_execution_repository,
            workflow_execution_repository=workflow_execution_repository,
        )

        try:
            return generate_task_pipeline.process()
        except ValueError as e:
            if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.":  # ignore this error
                raise GenerateTaskStoppedError()
            else:
                logger.exception(
                    f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
                )
                raise e

    def _build_document(
        self,
        tenant_id: str,
        dataset_id: str,
        built_in_field_enabled: bool,
        datasource_type: str,
        datasource_info: Mapping[str, Any],
        created_from: str,
        position: int,
        account: Union[Account, EndUser],
        batch: str,
        document_form: str,
    ):
        if datasource_type == "local_file":
            name = datasource_info["name"]
        elif datasource_type == "online_document":
            name = datasource_info["page"]["page_name"]
        elif datasource_type == "website_crawl":
            name = datasource_info["title"]
        else:
            raise ValueError(f"Unsupported datasource type: {datasource_type}")

        document = Document(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            position=position,
            data_source_type=datasource_type,
            data_source_info=json.dumps(datasource_info),
            batch=batch,
            name=name,
            created_from=created_from,
            created_by=account.id,
            doc_form=document_form,
        )
        doc_metadata = {}
        if built_in_field_enabled:
            doc_metadata = {
                BuiltInField.document_name: name,
                BuiltInField.uploader: account.name,
                BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
                BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
                BuiltInField.source: datasource_type,
            }
        if doc_metadata:
            document.doc_metadata = doc_metadata
        return document
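For orientation (not part of the diff): PipelineGenerator.generate reads inputs, start_node_id, datasource_type, and datasource_info_list out of args. A hypothetical debugger-mode call, assuming pipeline, workflow, and account objects are already loaded; all literal values below are illustrative:

generator = PipelineGenerator()
result = generator.generate(
    pipeline=pipeline,
    workflow=workflow,
    user=account,
    args={
        "inputs": {"chunk_size": 512},
        "start_node_id": "datasource-node-1",
        "datasource_type": "website_crawl",
        "datasource_info_list": [{"title": "Example page", "url": "https://example.com"}],
    },
    invoke_from=InvokeFrom.DEBUGGER,
    streaming=True,
)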
44    api/core/app/apps/pipeline/pipeline_queue_manager.py    Normal file
@@ -0,0 +1,44 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
    AppQueueEvent,
    QueueErrorEvent,
    QueueMessageEndEvent,
    QueueStopEvent,
    QueueWorkflowFailedEvent,
    QueueWorkflowPartialSuccessEvent,
    QueueWorkflowSucceededEvent,
    WorkflowQueueMessage,
)


class PipelineQueueManager(AppQueueManager):
    def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
        super().__init__(task_id, user_id, invoke_from)

        self._app_mode = app_mode

    def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
        """
        Publish event to queue
        :param event:
        :param pub_from:
        :return:
        """
        message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)

        self._q.put(message)

        if isinstance(
            event,
            QueueStopEvent
            | QueueErrorEvent
            | QueueMessageEndEvent
            | QueueWorkflowSucceededEvent
            | QueueWorkflowFailedEvent
            | QueueWorkflowPartialSuccessEvent,
        ):
            self.stop_listen()

        if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
            raise GenerateTaskStoppedError()
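A small sketch of how the queue manager above is constructed (not part of the diff; ids are hypothetical). Terminal events such as QueueWorkflowSucceededEvent make _publish call stop_listen():

queue_manager = PipelineQueueManager(
    task_id="task-123",
    user_id="user-456",
    invoke_from=InvokeFrom.DEBUGGER,
    app_mode=AppMode.RAG_PIPELINE,
)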
221   api/core/app/apps/pipeline/pipeline_runner.py    Normal file
@@ -0,0 +1,221 @@
import logging
from collections.abc import Mapping
from typing import Any, Optional, cast

from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
    InvokeFrom,
    RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow, WorkflowType

logger = logging.getLogger(__name__)


class PipelineRunner(WorkflowBasedAppRunner):
    """
    Pipeline Application Runner
    """

    def __init__(
        self,
        application_generate_entity: RagPipelineGenerateEntity,
        queue_manager: AppQueueManager,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param workflow_thread_pool_id: workflow thread pool id
        """
        self.application_generate_entity = application_generate_entity
        self.queue_manager = queue_manager
        self.workflow_thread_pool_id = workflow_thread_pool_id

    def _get_app_id(self) -> str:
        return self.application_generate_entity.app_config.app_id

    def run(self) -> None:
        """
        Run application
        """
        app_config = self.application_generate_entity.app_config
        app_config = cast(PipelineConfig, app_config)

        user_id = None
        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
            if end_user:
                user_id = end_user.session_id
        else:
            user_id = self.application_generate_entity.user_id

        pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")

        workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
        if not workflow:
            raise ValueError("Workflow not initialized")

        db.session.close()

        workflow_callbacks: list[WorkflowCallback] = []
        if dify_config.DEBUG:
            workflow_callbacks.append(WorkflowLoggingCallback())

        # if only single iteration run is requested
        if self.application_generate_entity.single_iteration_run:
            # if only single iteration run is requested
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                workflow=workflow,
                node_id=self.application_generate_entity.single_iteration_run.node_id,
                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
            )
        elif self.application_generate_entity.single_loop_run:
            # if only single loop run is requested
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
                workflow=workflow,
                node_id=self.application_generate_entity.single_loop_run.node_id,
                user_inputs=self.application_generate_entity.single_loop_run.inputs,
            )
        else:
            inputs = self.application_generate_entity.inputs
            files = self.application_generate_entity.files

            # Create a variable pool.
            system_inputs = {
                SystemVariableKey.FILES: files,
                SystemVariableKey.USER_ID: user_id,
                SystemVariableKey.APP_ID: app_config.app_id,
                SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
                SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
                SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
                SystemVariableKey.BATCH: self.application_generate_entity.batch,
                SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
                SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
                SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
                SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
            }
            rag_pipeline_variables = []
            if workflow.rag_pipeline_variables:
                for v in workflow.rag_pipeline_variables:
                    rag_pipeline_variable = RAGPipelineVariable(**v)
                    if (
                        rag_pipeline_variable.belong_to_node_id
                        in (self.application_generate_entity.start_node_id, "shared")
                    ) and rag_pipeline_variable.variable in inputs:
                        rag_pipeline_variables.append(
                            RAGPipelineVariableInput(
                                variable=rag_pipeline_variable,
                                value=inputs[rag_pipeline_variable.variable],
                            )
                        )

            variable_pool = VariablePool(
                system_variables=system_inputs,
                user_inputs=inputs,
                environment_variables=workflow.environment_variables,
                conversation_variables=[],
                rag_pipeline_variables=rag_pipeline_variables,
            )

            # init graph
            graph = self._init_rag_pipeline_graph(
                graph_config=workflow.graph_dict,
                start_node_id=self.application_generate_entity.start_node_id,
            )

        # RUN WORKFLOW
        workflow_entry = WorkflowEntry(
            tenant_id=workflow.tenant_id,
            app_id=workflow.app_id,
            workflow_id=workflow.id,
            workflow_type=WorkflowType.value_of(workflow.type),
            graph=graph,
            graph_config=workflow.graph_dict,
            user_id=self.application_generate_entity.user_id,
            user_from=(
                UserFrom.ACCOUNT
                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                else UserFrom.END_USER
            ),
            invoke_from=self.application_generate_entity.invoke_from,
            call_depth=self.application_generate_entity.call_depth,
            variable_pool=variable_pool,
            thread_pool_id=self.workflow_thread_pool_id,
        )

        generator = workflow_entry.run(callbacks=workflow_callbacks)

        for event in generator:
            self._handle_event(workflow_entry, event)

    def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
        """
        Get workflow
        """
        # fetch workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
            )
            .first()
        )

        # return workflow
        return workflow

    def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
        """
        Init pipeline graph
        """
        if "nodes" not in graph_config or "edges" not in graph_config:
            raise ValueError("nodes or edges not found in workflow graph")

        if not isinstance(graph_config.get("nodes"), list):
            raise ValueError("nodes in workflow graph must be a list")

        if not isinstance(graph_config.get("edges"), list):
            raise ValueError("edges in workflow graph must be a list")
        nodes = graph_config.get("nodes", [])
        edges = graph_config.get("edges", [])
        real_run_nodes = []
        real_edges = []
        exclude_node_ids = []
        for node in nodes:
            node_id = node.get("id")
            node_type = node.get("data", {}).get("type", "")
            if node_type == "datasource":
                if start_node_id != node_id:
                    exclude_node_ids.append(node_id)
                    continue
            real_run_nodes.append(node)
        for edge in edges:
            if edge.get("source") in exclude_node_ids:
                continue
            real_edges.append(edge)
        graph_config = dict(graph_config)
        graph_config["nodes"] = real_run_nodes
        graph_config["edges"] = real_edges
        # init graph
        graph = Graph.init(graph_config=graph_config)

        if not graph:
            raise ValueError("graph not found in workflow")

        return graph
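For orientation (not part of the diff): _init_rag_pipeline_graph keeps only the selected datasource node and drops edges that leave the excluded ones before calling Graph.init(). A tiny illustrative graph_config with hypothetical node ids:

graph_config = {
    "nodes": [
        {"id": "ds-a", "data": {"type": "datasource"}},
        {"id": "ds-b", "data": {"type": "datasource"}},
        {"id": "extract", "data": {"type": "document-extractor"}},
    ],
    "edges": [
        {"source": "ds-a", "target": "extract"},
        {"source": "ds-b", "target": "extract"},
    ],
}
# With start_node_id="ds-a", node "ds-b" is excluded and its outgoing edge is dropped.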
@@ -36,6 +36,7 @@ class InvokeFrom(Enum):
     # DEBUGGER indicates that this invocation is from
     # the workflow (or chatflow) edit page.
     DEBUGGER = "debugger"
+    PUBLISHED = "published"

     @classmethod
     def value_of(cls, value: str):
@@ -240,3 +241,38 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
         inputs: dict

     single_loop_run: Optional[SingleLoopRunEntity] = None
+
+
+class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
+    """
+    RAG Pipeline Application Generate Entity.
+    """
+
+    # pipeline config
+    pipeline_config: WorkflowUIBasedAppConfig
+    datasource_type: str
+    datasource_info: Mapping[str, Any]
+    dataset_id: str
+    batch: str
+    document_id: Optional[str] = None
+    start_node_id: Optional[str] = None
+
+    class SingleIterationRunEntity(BaseModel):
+        """
+        Single Iteration Run Entity.
+        """
+
+        node_id: str
+        inputs: dict
+
+    single_iteration_run: Optional[SingleIterationRunEntity] = None
+
+    class SingleLoopRunEntity(BaseModel):
+        """
+        Single Loop Run Entity.
+        """
+
+        node_id: str
+        inputs: dict
+
+    single_loop_run: Optional[SingleLoopRunEntity] = None
@@ -105,6 +105,14 @@ class DifyAgentCallbackHandler(BaseModel):

         self.current_loop += 1

+    def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
+        """Run on datasource start."""
+        if dify_config.DEBUG:
+            print_text(
+                "\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + str(datasource_inputs) + "\n",
+                color=self.color,
+            )
+
     @property
     def ignore_agent(self) -> bool:
         """Whether to ignore agent callbacks."""
33    api/core/datasource/__base/datasource_plugin.py    Normal file
@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod

from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceProviderType,
)


class DatasourcePlugin(ABC):
    entity: DatasourceEntity
    runtime: DatasourceRuntime

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
    ) -> None:
        self.entity = entity
        self.runtime = runtime

    @abstractmethod
    def datasource_provider_type(self) -> str:
        """
        returns the type of the datasource provider
        """
        return DatasourceProviderType.LOCAL_FILE

    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
        return self.__class__(
            entity=self.entity.model_copy(),
            runtime=runtime,
        )
118   api/core/datasource/__base/datasource_provider.py    Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
|
from core.entities.provider_entities import ProviderConfig
|
||||||
|
from core.plugin.impl.tool import PluginToolManager
|
||||||
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourcePluginProviderController(ABC):
|
||||||
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
|
tenant_id: str
|
||||||
|
|
||||||
|
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
|
||||||
|
self.entity = entity
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def need_credentials(self) -> bool:
|
||||||
|
"""
|
||||||
|
returns whether the provider needs credentials
|
||||||
|
|
||||||
|
:return: whether the provider needs credentials
|
||||||
|
"""
|
||||||
|
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
|
||||||
|
|
||||||
|
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
validate the credentials of the provider
|
||||||
|
"""
|
||||||
|
manager = PluginToolManager()
|
||||||
|
if not manager.validate_datasource_credentials(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider=self.entity.identity.name,
|
||||||
|
credentials=credentials,
|
||||||
|
):
|
||||||
|
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_type(self) -> DatasourceProviderType:
|
||||||
|
"""
|
||||||
|
returns the type of the provider
|
||||||
|
"""
|
||||||
|
return DatasourceProviderType.LOCAL_FILE
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
|
||||||
|
"""
|
||||||
|
return datasource with given name
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
validate the format of the credentials of the provider and set the default value if needed
|
||||||
|
|
||||||
|
:param credentials: the credentials of the tool
|
||||||
|
"""
|
||||||
|
credentials_schema = dict[str, ProviderConfig]()
|
||||||
|
if credentials_schema is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
for credential in self.entity.credentials_schema:
|
||||||
|
credentials_schema[credential.name] = credential
|
||||||
|
|
||||||
|
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||||
|
for credential_name in credentials_schema:
|
||||||
|
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||||
|
|
||||||
|
for credential_name in credentials:
|
||||||
|
if credential_name not in credentials_need_to_validate:
|
||||||
|
raise ToolProviderCredentialValidationError(
|
||||||
|
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# check type
|
||||||
|
credential_schema = credentials_need_to_validate[credential_name]
|
||||||
|
if not credential_schema.required and credentials[credential_name] is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||||
|
if not isinstance(credentials[credential_name], str):
|
||||||
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||||
|
|
||||||
|
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||||
|
if not isinstance(credentials[credential_name], str):
|
||||||
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||||
|
|
||||||
|
options = credential_schema.options
|
||||||
|
if not isinstance(options, list):
|
||||||
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||||
|
|
||||||
|
if credentials[credential_name] not in [x.value for x in options]:
|
||||||
|
raise ToolProviderCredentialValidationError(
|
||||||
|
f"credential {credential_name} should be one of {options}"
|
||||||
|
)
|
||||||
|
|
||||||
|
credentials_need_to_validate.pop(credential_name)
|
||||||
|
|
||||||
|
for credential_name in credentials_need_to_validate:
|
||||||
|
credential_schema = credentials_need_to_validate[credential_name]
|
||||||
|
if credential_schema.required:
|
||||||
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||||
|
|
||||||
|
# the credential is not set currently, set the default value if needed
|
||||||
|
if credential_schema.default is not None:
|
||||||
|
default_value = credential_schema.default
|
||||||
|
# parse default value into the correct type
|
||||||
|
if credential_schema.type in {
|
||||||
|
ProviderConfig.Type.SECRET_INPUT,
|
||||||
|
ProviderConfig.Type.TEXT_INPUT,
|
||||||
|
ProviderConfig.Type.SELECT,
|
||||||
|
}:
|
||||||
|
default_value = str(default_value)
|
||||||
|
|
||||||
|
credentials[credential_name] = default_value
|
||||||
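A minimal usage sketch of the credential checks above (the controller instance and the "api_key" entry are illustrative placeholders, not part of the diff): validate_credentials_format fills defaults in place and raises ToolProviderCredentialValidationError on unknown or mis-typed keys, while _validate_credentials delegates the actual check to the plugin daemon.

    # hypothetical illustration: a concrete controller subclass and its
    # credentials_schema are assumed, not defined in this diff
    credentials = {"api_key": "sk-xxx"}  # user-supplied input
    controller.validate_credentials_format(credentials)  # format check + default filling, raises on bad keys
    controller._validate_credentials(user_id="user-1", credentials=credentials)  # remote check via PluginToolManager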
api/core/datasource/__base/datasource_runtime.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from typing import Any, Optional

from pydantic import BaseModel, Field

from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom


class DatasourceRuntime(BaseModel):
    """
    Meta data of a datasource call processing
    """

    tenant_id: str
    datasource_id: Optional[str] = None
    invoke_from: Optional[InvokeFrom] = None
    datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
    credentials: dict[str, Any] = Field(default_factory=dict)
    runtime_parameters: dict[str, Any] = Field(default_factory=dict)


class FakeDatasourceRuntime(DatasourceRuntime):
    """
    Fake datasource runtime for testing
    """

    def __init__(self):
        super().__init__(
            tenant_id="fake_tenant_id",
            datasource_id="fake_datasource_id",
            invoke_from=InvokeFrom.DEBUGGER,
            datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
            credentials={},
            runtime_parameters={},
        )
api/core/datasource/__init__.py (new file, empty)

api/core/datasource/datasource_file_manager.py (new file, 244 lines)
@@ -0,0 +1,244 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4

import httpx

from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.enums import CreatorUserRole
from models.model import MessageFile, UploadFile
from models.tools import ToolFile

logger = logging.getLogger(__name__)


class DatasourceFileManager:
    @staticmethod
    def sign_file(datasource_file_id: str, extension: str) -> str:
        """
        sign file to get a temporary url
        """
        base_url = dify_config.FILES_URL
        file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}"

        timestamp = str(int(time.time()))
        nonce = os.urandom(16).hex()
        data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        encoded_sign = base64.urlsafe_b64encode(sign).decode()

        return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"

    @staticmethod
    def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
        """
        verify signature
        """
        data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
        recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()

        # verify signature
        if sign != recalculated_encoded_sign:
            return False

        current_time = int(time.time())
        return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

    @staticmethod
    def create_file_by_raw(
        *,
        user_id: str,
        tenant_id: str,
        conversation_id: Optional[str],
        file_binary: bytes,
        mimetype: str,
        filename: Optional[str] = None,
    ) -> UploadFile:
        extension = guess_extension(mimetype) or ".bin"
        unique_name = uuid4().hex
        unique_filename = f"{unique_name}{extension}"
        # default just as before
        present_filename = unique_filename
        if filename is not None:
            has_extension = len(filename.split(".")) > 1
            # Add extension flexibly
            present_filename = filename if has_extension else f"{filename}{extension}"
        filepath = f"datasources/{tenant_id}/{unique_filename}"
        storage.save(filepath, file_binary)

        upload_file = UploadFile(
            tenant_id=tenant_id,
            storage_type=dify_config.STORAGE_TYPE,
            key=filepath,
            name=present_filename,
            size=len(file_binary),
            extension=extension,
            mime_type=mimetype,
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=user_id,
            used=False,
            hash=hashlib.sha3_256(file_binary).hexdigest(),
            source_url="",
        )

        db.session.add(upload_file)
        db.session.commit()
        db.session.refresh(upload_file)

        return upload_file

    @staticmethod
    def create_file_by_url(
        user_id: str,
        tenant_id: str,
        file_url: str,
        conversation_id: Optional[str] = None,
    ) -> UploadFile:
        # try to download image
        try:
            response = ssrf_proxy.get(file_url)
            response.raise_for_status()
            blob = response.content
        except httpx.TimeoutException:
            raise ValueError(f"timeout when downloading file from {file_url}")

        mimetype = (
            guess_type(file_url)[0]
            or response.headers.get("Content-Type", "").split(";")[0].strip()
            or "application/octet-stream"
        )
        extension = guess_extension(mimetype) or ".bin"
        unique_name = uuid4().hex
        filename = f"{unique_name}{extension}"
        filepath = f"tools/{tenant_id}/{filename}"
        storage.save(filepath, blob)

        upload_file = UploadFile(
            tenant_id=tenant_id,
            storage_type=dify_config.STORAGE_TYPE,
            key=filepath,
            name=filename,
            size=len(blob),
            extension=extension,
            mime_type=mimetype,
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=user_id,
            used=False,
            hash=hashlib.sha3_256(blob).hexdigest(),
            source_url=file_url,
        )

        db.session.add(upload_file)
        db.session.commit()

        return upload_file

    @staticmethod
    def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
        """
        get file binary

        :param id: the id of the file

        :return: the binary of the file, mime type
        """
        upload_file: UploadFile | None = (
            db.session.query(UploadFile)
            .filter(
                UploadFile.id == id,
            )
            .first()
        )

        if not upload_file:
            return None

        blob = storage.load_once(upload_file.key)

        return blob, upload_file.mime_type

    @staticmethod
    def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
        """
        get file binary

        :param id: the id of the file

        :return: the binary of the file, mime type
        """
        message_file: MessageFile | None = (
            db.session.query(MessageFile)
            .filter(
                MessageFile.id == id,
            )
            .first()
        )

        # Check if message_file is not None
        if message_file is not None:
            # get tool file id
            if message_file.url is not None:
                tool_file_id = message_file.url.split("/")[-1]
                # trim extension
                tool_file_id = tool_file_id.split(".")[0]
            else:
                tool_file_id = None
        else:
            tool_file_id = None

        tool_file: ToolFile | None = (
            db.session.query(ToolFile)
            .filter(
                ToolFile.id == tool_file_id,
            )
            .first()
        )

        if not tool_file:
            return None

        blob = storage.load_once(tool_file.file_key)

        return blob, tool_file.mimetype

    @staticmethod
    def get_file_generator_by_upload_file_id(upload_file_id: str):
        """
        get file binary

        :param upload_file_id: the id of the upload file

        :return: the binary of the file, mime type
        """
        upload_file: UploadFile | None = (
            db.session.query(UploadFile)
            .filter(
                UploadFile.id == upload_file_id,
            )
            .first()
        )

        if not upload_file:
            return None, None

        stream = storage.load_stream(upload_file.key)

        return stream, upload_file.mime_type


# init tool_file_parser
# from core.file.datasource_file_parser import datasource_file_manager
#
# datasource_file_manager["manager"] = DatasourceFileManager
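sign_file and verify_file above implement an HMAC-signed, time-limited preview URL. A minimal sketch of the round trip, assuming the SECRET_KEY and FILES_URL settings are configured; the file id is a made-up placeholder:

    from urllib.parse import urlparse, parse_qs

    url = DatasourceFileManager.sign_file("0b1d5a6e-file-id", ".pdf")  # placeholder id
    qs = parse_qs(urlparse(url).query)
    # verification recomputes the HMAC over the same "file-preview|id|timestamp|nonce" payload
    # and additionally enforces FILES_ACCESS_TIMEOUT on the timestamp
    ok = DatasourceFileManager.verify_file(
        "0b1d5a6e-file-id", qs["timestamp"][0], qs["nonce"][0], qs["sign"][0]
    )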
api/core/datasource/datasource_manager.py (new file, 100 lines)
@@ -0,0 +1,100 @@
import logging
from threading import Lock
from typing import Union

import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.common_entities import I18nObject
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.plugin.impl.datasource import PluginDatasourceManager

logger = logging.getLogger(__name__)


class DatasourceManager:
    _builtin_provider_lock = Lock()
    _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
    _builtin_providers_loaded = False
    _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}

    @classmethod
    def get_datasource_plugin_provider(
        cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
    ) -> DatasourcePluginProviderController:
        """
        get the datasource plugin provider
        """
        # check if context is set
        try:
            contexts.datasource_plugin_providers.get()
        except LookupError:
            contexts.datasource_plugin_providers.set({})
            contexts.datasource_plugin_providers_lock.set(Lock())

        with contexts.datasource_plugin_providers_lock.get():
            datasource_plugin_providers = contexts.datasource_plugin_providers.get()
            if provider_id in datasource_plugin_providers:
                return datasource_plugin_providers[provider_id]

            manager = PluginDatasourceManager()
            provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
            if not provider_entity:
                raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")

            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    controller = OnlineDocumentDatasourcePluginProviderController(
                        entity=provider_entity.declaration,
                        plugin_id=provider_entity.plugin_id,
                        plugin_unique_identifier=provider_entity.plugin_unique_identifier,
                        tenant_id=tenant_id,
                    )
                case DatasourceProviderType.WEBSITE_CRAWL:
                    controller = WebsiteCrawlDatasourcePluginProviderController(
                        entity=provider_entity.declaration,
                        plugin_id=provider_entity.plugin_id,
                        plugin_unique_identifier=provider_entity.plugin_unique_identifier,
                        tenant_id=tenant_id,
                    )
                case DatasourceProviderType.LOCAL_FILE:
                    controller = LocalFileDatasourcePluginProviderController(
                        entity=provider_entity.declaration,
                        plugin_id=provider_entity.plugin_id,
                        plugin_unique_identifier=provider_entity.plugin_unique_identifier,
                        tenant_id=tenant_id,
                    )
                case _:
                    raise ValueError(f"Unsupported datasource type: {datasource_type}")

            datasource_plugin_providers[provider_id] = controller

            return controller

    @classmethod
    def get_datasource_runtime(
        cls,
        provider_id: str,
        datasource_name: str,
        tenant_id: str,
        datasource_type: DatasourceProviderType,
    ) -> DatasourcePlugin:
        """
        get the datasource runtime

        :param provider_id: the id of the provider
        :param datasource_name: the name of the datasource
        :param tenant_id: the tenant id
        :param datasource_type: the type of the datasource provider

        :return: the datasource plugin
        """
        return cls.get_datasource_plugin_provider(
            provider_id,
            tenant_id,
            datasource_type,
        ).get_datasource(datasource_name)
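A sketch of how a caller might resolve a runtime through DatasourceManager; the provider id and datasource name below are placeholders, not identifiers defined in this diff:

    datasource = DatasourceManager.get_datasource_runtime(
        provider_id="langgenius/example_datasource/example",  # placeholder provider id
        datasource_name="example_search",                     # placeholder datasource name
        tenant_id=tenant_id,
        datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
    )
    # the returned object is an OnlineDocumentDatasourcePlugin bound to a fresh
    # DatasourceRuntime for the tenant; providers are cached per provider_id in contexts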
api/core/datasource/entities/api_entities.py (new file, 71 lines)
@@ -0,0 +1,71 @@
from typing import Literal, Optional

from pydantic import BaseModel, Field, field_validator

from core.datasource.entities.datasource_entities import DatasourceParameter
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject


class DatasourceApiEntity(BaseModel):
    author: str
    name: str  # identifier
    label: I18nObject  # label
    description: I18nObject
    parameters: Optional[list[DatasourceParameter]] = None
    labels: list[str] = Field(default_factory=list)
    output_schema: Optional[dict] = None


ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]


class DatasourceProviderApiEntity(BaseModel):
    id: str
    author: str
    name: str  # identifier
    description: I18nObject
    icon: str | dict
    label: I18nObject  # label
    type: str
    masked_credentials: Optional[dict] = None
    original_credentials: Optional[dict] = None
    is_team_authorization: bool = False
    allow_delete: bool = True
    plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
    plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
    datasources: list[DatasourceApiEntity] = Field(default_factory=list)
    labels: list[str] = Field(default_factory=list)

    @field_validator("datasources", mode="before")
    @classmethod
    def convert_none_to_empty_list(cls, v):
        return v if v is not None else []

    def to_dict(self) -> dict:
        # -------------
        # overwrite datasource parameter types for temp fix
        datasources = jsonable_encoder(self.datasources)
        for datasource in datasources:
            if datasource.get("parameters"):
                for parameter in datasource.get("parameters"):
                    if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
                        parameter["type"] = "files"
        # -------------

        return {
            "id": self.id,
            "author": self.author,
            "name": self.name,
            "plugin_id": self.plugin_id,
            "plugin_unique_identifier": self.plugin_unique_identifier,
            "description": self.description.to_dict(),
            "icon": self.icon,
            "label": self.label.to_dict(),
            "type": self.type.value,
            "team_credentials": self.masked_credentials,
            "is_team_authorization": self.is_team_authorization,
            "allow_delete": self.allow_delete,
            "datasources": datasources,
            "labels": self.labels,
        }
api/core/datasource/entities/common_entities.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from typing import Optional

from pydantic import BaseModel, Field


class I18nObject(BaseModel):
    """
    Model class for i18n object.
    """

    en_US: str
    zh_Hans: Optional[str] = Field(default=None)
    pt_BR: Optional[str] = Field(default=None)
    ja_JP: Optional[str] = Field(default=None)

    def __init__(self, **data):
        super().__init__(**data)
        self.zh_Hans = self.zh_Hans or self.en_US
        self.pt_BR = self.pt_BR or self.en_US
        self.ja_JP = self.ja_JP or self.en_US

    def to_dict(self) -> dict:
        return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
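I18nObject backfills the optional locales with en_US in __init__, so to_dict always returns all four keys. A small sketch of that behavior:

    label = I18nObject(en_US="Documents")
    label.to_dict()
    # {"zh_Hans": "Documents", "en_US": "Documents", "pt_BR": "Documents", "ja_JP": "Documents"}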
api/core/datasource/entities/datasource_entities.py (new file, 361 lines)
@@ -0,0 +1,361 @@
import enum
from enum import Enum
from typing import Any, Optional

from pydantic import BaseModel, Field, ValidationInfo, field_validator

from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities.parameters import (
    PluginParameter,
    PluginParameterOption,
    PluginParameterType,
    as_normal_type,
    cast_parameter_value,
    init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum


class DatasourceProviderType(enum.StrEnum):
    """
    Enum class for datasource provider
    """

    ONLINE_DOCUMENT = "online_document"
    LOCAL_FILE = "local_file"
    WEBSITE_CRAWL = "website_crawl"
    ONLINE_DRIVE = "online_drive"

    @classmethod
    def value_of(cls, value: str) -> "DatasourceProviderType":
        """
        Get value of given mode.

        :param value: mode value
        :return: mode
        """
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f"invalid mode value {value}")


class DatasourceParameter(PluginParameter):
    """
    Overrides type
    """

    class DatasourceParameterType(enum.StrEnum):
        """
        removes TOOLS_SELECTOR from PluginParameterType
        """

        STRING = PluginParameterType.STRING.value
        NUMBER = PluginParameterType.NUMBER.value
        BOOLEAN = PluginParameterType.BOOLEAN.value
        SELECT = PluginParameterType.SELECT.value
        SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
        FILE = PluginParameterType.FILE.value
        FILES = PluginParameterType.FILES.value

        # deprecated, should not use.
        SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value

        def as_normal_type(self):
            return as_normal_type(self)

        def cast_value(self, value: Any):
            return cast_parameter_value(self, value)

    type: DatasourceParameterType = Field(..., description="The type of the parameter")
    description: I18nObject = Field(..., description="The description of the parameter")

    @classmethod
    def get_simple_instance(
        cls,
        name: str,
        typ: DatasourceParameterType,
        required: bool,
        options: Optional[list[str]] = None,
    ) -> "DatasourceParameter":
        """
        get a simple datasource parameter

        :param name: the name of the parameter
        :param llm_description: the description presented to the LLM
        :param typ: the type of the parameter
        :param required: if the parameter is required
        :param options: the options of the parameter
        """
        # convert options to ToolParameterOption
        # FIXME fix the type error
        if options:
            option_objs = [
                PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
                for option in options
            ]
        else:
            option_objs = []

        return cls(
            name=name,
            label=I18nObject(en_US="", zh_Hans=""),
            placeholder=None,
            type=typ,
            required=required,
            options=option_objs,
            description=I18nObject(en_US="", zh_Hans=""),
        )

    def init_frontend_parameter(self, value: Any):
        return init_frontend_parameter(self, self.type, value)


class DatasourceIdentity(BaseModel):
    author: str = Field(..., description="The author of the datasource")
    name: str = Field(..., description="The name of the datasource")
    label: I18nObject = Field(..., description="The label of the datasource")
    provider: str = Field(..., description="The provider of the datasource")
    icon: Optional[str] = None


class DatasourceEntity(BaseModel):
    identity: DatasourceIdentity
    parameters: list[DatasourceParameter] = Field(default_factory=list)
    description: I18nObject = Field(..., description="The label of the datasource")

    @field_validator("parameters", mode="before")
    @classmethod
    def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
        return v or []


class DatasourceProviderIdentity(BaseModel):
    author: str = Field(..., description="The author of the tool")
    name: str = Field(..., description="The name of the tool")
    description: I18nObject = Field(..., description="The description of the tool")
    icon: str = Field(..., description="The icon of the tool")
    label: I18nObject = Field(..., description="The label of the tool")
    tags: Optional[list[ToolLabelEnum]] = Field(
        default=[],
        description="The tags of the tool",
    )


class DatasourceProviderEntity(BaseModel):
    """
    Datasource provider entity
    """

    identity: DatasourceProviderIdentity
    credentials_schema: list[ProviderConfig] = Field(default_factory=list)
    oauth_schema: Optional[OAuthSchema] = None
    provider_type: DatasourceProviderType


class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
    datasources: list[DatasourceEntity] = Field(default_factory=list)


class DatasourceInvokeMeta(BaseModel):
    """
    Datasource invoke meta
    """

    time_cost: float = Field(..., description="The time cost of the tool invoke")
    error: Optional[str] = None
    tool_config: Optional[dict] = None

    @classmethod
    def empty(cls) -> "DatasourceInvokeMeta":
        """
        Get an empty instance of DatasourceInvokeMeta
        """
        return cls(time_cost=0.0, error=None, tool_config={})

    @classmethod
    def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
        """
        Get an instance of DatasourceInvokeMeta with error
        """
        return cls(time_cost=0.0, error=error, tool_config={})

    def to_dict(self) -> dict:
        return {
            "time_cost": self.time_cost,
            "error": self.error,
            "tool_config": self.tool_config,
        }


class DatasourceLabel(BaseModel):
    """
    Datasource label
    """

    name: str = Field(..., description="The name of the tool")
    label: I18nObject = Field(..., description="The label of the tool")
    icon: str = Field(..., description="The icon of the tool")


class DatasourceInvokeFrom(Enum):
    """
    Enum class for datasource invoke
    """

    RAG_PIPELINE = "rag_pipeline"


class OnlineDocumentPage(BaseModel):
    """
    Online document page
    """

    page_id: str = Field(..., description="The page id")
    page_name: str = Field(..., description="The page title")
    page_icon: Optional[dict] = Field(None, description="The page icon")
    type: str = Field(..., description="The type of the page")
    last_edited_time: str = Field(..., description="The last edited time")
    parent_id: Optional[str] = Field(None, description="The parent page id")


class OnlineDocumentInfo(BaseModel):
    """
    Online document info
    """

    workspace_id: str = Field(..., description="The workspace id")
    workspace_name: str = Field(..., description="The workspace name")
    workspace_icon: str = Field(..., description="The workspace icon")
    total: int = Field(..., description="The total number of documents")
    pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")


class OnlineDocumentPagesMessage(BaseModel):
    """
    Get online document pages response
    """

    result: list[OnlineDocumentInfo]


class GetOnlineDocumentPageContentRequest(BaseModel):
    """
    Get online document page content request
    """

    workspace_id: str = Field(..., description="The workspace id")
    page_id: str = Field(..., description="The page id")
    type: str = Field(..., description="The type of the page")


class OnlineDocumentPageContent(BaseModel):
    """
    Online document page content
    """

    workspace_id: str = Field(..., description="The workspace id")
    page_id: str = Field(..., description="The page id")
    content: str = Field(..., description="The content of the page")


class GetOnlineDocumentPageContentResponse(BaseModel):
    """
    Get online document page content response
    """

    result: OnlineDocumentPageContent


class GetWebsiteCrawlRequest(BaseModel):
    """
    Get website crawl request
    """

    crawl_parameters: dict = Field(..., description="The crawl parameters")


class WebSiteInfoDetail(BaseModel):
    source_url: str = Field(..., description="The url of the website")
    content: str = Field(..., description="The content of the website")
    title: str = Field(..., description="The title of the website")
    description: str = Field(..., description="The description of the website")


class WebSiteInfo(BaseModel):
    """
    Website info
    """

    status: Optional[str] = Field(..., description="crawl job status")
    web_info_list: Optional[list[WebSiteInfoDetail]] = []
    total: Optional[int] = Field(default=0, description="The total number of websites")
    completed: Optional[int] = Field(default=0, description="The number of completed websites")


class WebsiteCrawlMessage(BaseModel):
    """
    Get website crawl response
    """

    result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)


class DatasourceMessage(ToolInvokeMessage):
    pass


#########################
# Online drive file
#########################


class OnlineDriveFile(BaseModel):
    """
    Online drive file
    """

    key: str = Field(..., description="The key of the file")
    size: int = Field(..., description="The size of the file")


class OnlineDriveFileBucket(BaseModel):
    """
    Online drive file bucket
    """

    bucket: Optional[str] = Field(None, description="The bucket of the file")
    files: list[OnlineDriveFile] = Field(..., description="The files of the bucket")
    is_truncated: bool = Field(False, description="Whether the bucket has more files")


class OnlineDriveBrowseFilesRequest(BaseModel):
    """
    Get online drive file list request
    """

    prefix: Optional[str] = Field(None, description="File path prefix for filtering eg: 'docs/dify/'")
    bucket: Optional[str] = Field(None, description="Storage bucket name")
    max_keys: int = Field(20, description="Maximum number of files to return")
    start_after: Optional[str] = Field(
        None, description="Pagination token for continuing from a specific file eg: 'docs/dify/1.txt'"
    )


class OnlineDriveBrowseFilesResponse(BaseModel):
    """
    Get online drive file list response
    """

    result: list[OnlineDriveFileBucket] = Field(..., description="The bucket of the files")


class OnlineDriveDownloadFileRequest(BaseModel):
    """
    Get online drive file
    """

    key: str = Field(..., description="The name of the file")
    bucket: Optional[str] = Field(None, description="The name of the bucket")
api/core/datasource/errors.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta


class DatasourceProviderNotFoundError(ValueError):
    pass


class DatasourceNotFoundError(ValueError):
    pass


class DatasourceParameterValidationError(ValueError):
    pass


class DatasourceProviderCredentialValidationError(ValueError):
    pass


class DatasourceNotSupportedError(ValueError):
    pass


class DatasourceInvokeError(ValueError):
    pass


class DatasourceApiSchemaError(ValueError):
    pass


class DatasourceEngineInvokeError(Exception):
    meta: DatasourceInvokeMeta

    def __init__(self, meta, **kwargs):
        self.meta = meta
        super().__init__(**kwargs)
api/core/datasource/local_file/local_file_plugin.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceProviderType,
)


class LocalFileDatasourcePlugin(DatasourcePlugin):
    tenant_id: str
    icon: str
    plugin_unique_identifier: str

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
        tenant_id: str,
        icon: str,
        plugin_unique_identifier: str,
    ) -> None:
        super().__init__(entity, runtime)
        self.tenant_id = tenant_id
        self.icon = icon
        self.plugin_unique_identifier = plugin_unique_identifier

    def datasource_provider_type(self) -> str:
        return DatasourceProviderType.LOCAL_FILE
api/core/datasource/local_file/local_file_provider.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from typing import Any

from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin


class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
    entity: DatasourceProviderEntityWithPlugin
    plugin_id: str
    plugin_unique_identifier: str

    def __init__(
        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
    ) -> None:
        super().__init__(entity, tenant_id)
        self.plugin_id = plugin_id
        self.plugin_unique_identifier = plugin_unique_identifier

    @property
    def provider_type(self) -> DatasourceProviderType:
        """
        returns the type of the provider
        """
        return DatasourceProviderType.LOCAL_FILE

    def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
        """
        validate the credentials of the provider
        """
        pass

    def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin:  # type: ignore
        """
        return datasource with given name
        """
        datasource_entity = next(
            (
                datasource_entity
                for datasource_entity in self.entity.datasources
                if datasource_entity.identity.name == datasource_name
            ),
            None,
        )

        if not datasource_entity:
            raise ValueError(f"Datasource with name {datasource_name} not found")

        return LocalFileDatasourcePlugin(
            entity=datasource_entity,
            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
            tenant_id=self.tenant_id,
            icon=self.entity.identity.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
api/core/datasource/online_document/online_document_plugin.py (new file, 73 lines)
@@ -0,0 +1,73 @@
from collections.abc import Generator, Mapping
from typing import Any

from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceMessage,
    DatasourceProviderType,
    GetOnlineDocumentPageContentRequest,
    OnlineDocumentPagesMessage,
)
from core.plugin.impl.datasource import PluginDatasourceManager


class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
    tenant_id: str
    icon: str
    plugin_unique_identifier: str
    entity: DatasourceEntity
    runtime: DatasourceRuntime

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
        tenant_id: str,
        icon: str,
        plugin_unique_identifier: str,
    ) -> None:
        super().__init__(entity, runtime)
        self.tenant_id = tenant_id
        self.icon = icon
        self.plugin_unique_identifier = plugin_unique_identifier

    def get_online_document_pages(
        self,
        user_id: str,
        datasource_parameters: Mapping[str, Any],
        provider_type: str,
    ) -> Generator[OnlineDocumentPagesMessage, None, None]:
        manager = PluginDatasourceManager()

        return manager.get_online_document_pages(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            datasource_parameters=datasource_parameters,
            provider_type=provider_type,
        )

    def get_online_document_page_content(
        self,
        user_id: str,
        datasource_parameters: GetOnlineDocumentPageContentRequest,
        provider_type: str,
    ) -> Generator[DatasourceMessage, None, None]:
        manager = PluginDatasourceManager()

        return manager.get_online_document_page_content(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            datasource_parameters=datasource_parameters,
            provider_type=provider_type,
        )

    def datasource_provider_type(self) -> str:
        return DatasourceProviderType.ONLINE_DOCUMENT
api/core/datasource/online_document/online_document_provider.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin


class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
    entity: DatasourceProviderEntityWithPlugin
    plugin_id: str
    plugin_unique_identifier: str

    def __init__(
        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
    ) -> None:
        super().__init__(entity, tenant_id)
        self.plugin_id = plugin_id
        self.plugin_unique_identifier = plugin_unique_identifier

    @property
    def provider_type(self) -> DatasourceProviderType:
        """
        returns the type of the provider
        """
        return DatasourceProviderType.ONLINE_DOCUMENT

    def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin:  # type: ignore
        """
        return datasource with given name
        """
        datasource_entity = next(
            (
                datasource_entity
                for datasource_entity in self.entity.datasources
                if datasource_entity.identity.name == datasource_name
            ),
            None,
        )

        if not datasource_entity:
            raise ValueError(f"Datasource with name {datasource_name} not found")

        return OnlineDocumentDatasourcePlugin(
            entity=datasource_entity,
            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
            tenant_id=self.tenant_id,
            icon=self.entity.identity.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
api/core/datasource/online_drive/online_drive_plugin.py (new file, 73 lines)
@@ -0,0 +1,73 @@
from collections.abc import Generator

from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceMessage,
    DatasourceProviderType,
    OnlineDriveBrowseFilesRequest,
    OnlineDriveBrowseFilesResponse,
    OnlineDriveDownloadFileRequest,
)
from core.plugin.impl.datasource import PluginDatasourceManager


class OnlineDriveDatasourcePlugin(DatasourcePlugin):
    tenant_id: str
    icon: str
    plugin_unique_identifier: str
    entity: DatasourceEntity
    runtime: DatasourceRuntime

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
        tenant_id: str,
        icon: str,
        plugin_unique_identifier: str,
    ) -> None:
        super().__init__(entity, runtime)
        self.tenant_id = tenant_id
        self.icon = icon
        self.plugin_unique_identifier = plugin_unique_identifier

    def online_drive_browse_files(
        self,
        user_id: str,
        request: OnlineDriveBrowseFilesRequest,
        provider_type: str,
    ) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
        manager = PluginDatasourceManager()

        return manager.online_drive_browse_files(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            request=request,
            provider_type=provider_type,
        )

    def online_drive_download_file(
        self,
        user_id: str,
        request: OnlineDriveDownloadFileRequest,
        provider_type: str,
    ) -> Generator[DatasourceMessage, None, None]:
        manager = PluginDatasourceManager()

        return manager.online_drive_download_file(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            request=request,
            provider_type=provider_type,
        )

    def datasource_provider_type(self) -> str:
        return DatasourceProviderType.ONLINE_DRIVE
api/core/datasource/online_drive/online_drive_provider.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin


class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController):
    entity: DatasourceProviderEntityWithPlugin
    plugin_id: str
    plugin_unique_identifier: str

    def __init__(
        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
    ) -> None:
        super().__init__(entity, tenant_id)
        self.plugin_id = plugin_id
        self.plugin_unique_identifier = plugin_unique_identifier

    @property
    def provider_type(self) -> DatasourceProviderType:
        """
        returns the type of the provider
        """
        return DatasourceProviderType.ONLINE_DRIVE

    def get_datasource(self, datasource_name: str) -> OnlineDriveDatasourcePlugin:  # type: ignore
        """
        return datasource with given name
        """
        datasource_entity = next(
            (
                datasource_entity
                for datasource_entity in self.entity.datasources
                if datasource_entity.identity.name == datasource_name
            ),
            None,
        )

        if not datasource_entity:
            raise ValueError(f"Datasource with name {datasource_name} not found")

        return OnlineDriveDatasourcePlugin(
            entity=datasource_entity,
            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
            tenant_id=self.tenant_id,
            icon=self.entity.identity.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
api/core/datasource/utils/__init__.py (new file, empty)

api/core/datasource/utils/configuration.py (new file, 265 lines)
@@ -0,0 +1,265 @@
from copy import deepcopy
from typing import Any

from pydantic import BaseModel

from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
    ToolParameter,
    ToolProviderType,
)


class ProviderConfigEncrypter(BaseModel):
    tenant_id: str
    config: list[BasicProviderConfig]
    provider_type: str
    provider_identity: str

    def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
        """
        deep copy data
        """
        return deepcopy(data)

    def encrypt(self, data: dict[str, str]) -> dict[str, str]:
        """
        encrypt tool credentials with tenant id

        return a deep copy of credentials with encrypted values
        """
        data = self._deep_copy(data)

        # get fields need to be decrypted
        fields = dict[str, BasicProviderConfig]()
        for credential in self.config:
            fields[credential.name] = credential

        for field_name, field in fields.items():
            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                if field_name in data:
                    encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
                    data[field_name] = encrypted

        return data

    def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
        """
        mask tool credentials

        return a deep copy of credentials with masked values
        """
        data = self._deep_copy(data)

        # get fields need to be decrypted
        fields = dict[str, BasicProviderConfig]()
        for credential in self.config:
            fields[credential.name] = credential

        for field_name, field in fields.items():
            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                if field_name in data:
                    if len(data[field_name]) > 6:
                        data[field_name] = (
                            data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
                        )
                    else:
                        data[field_name] = "*" * len(data[field_name])

        return data

    def decrypt(self, data: dict[str, str]) -> dict[str, str]:
        """
        decrypt tool credentials with tenant id

        return a deep copy of credentials with decrypted values
        """
        cache = ToolProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=f"{self.provider_type}.{self.provider_identity}",
            cache_type=ToolProviderCredentialsCacheType.PROVIDER,
        )
        cached_credentials = cache.get()
        if cached_credentials:
            return cached_credentials
        data = self._deep_copy(data)
        # get fields need to be decrypted
        fields = dict[str, BasicProviderConfig]()
        for credential in self.config:
            fields[credential.name] = credential

        for field_name, field in fields.items():
            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                if field_name in data:
                    try:
                        # if the value is None or empty string, skip decrypt
                        if not data[field_name]:
                            continue

                        data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
                    except Exception:
                        pass

        cache.set(data)
        return data

    def delete_tool_credentials_cache(self):
        cache = ToolProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=f"{self.provider_type}.{self.provider_identity}",
            cache_type=ToolProviderCredentialsCacheType.PROVIDER,
        )
        cache.delete()


class ToolParameterConfigurationManager:
    """
    Tool parameter configuration manager
    """

    tenant_id: str
    tool_runtime: Tool
    provider_name: str
    provider_type: ToolProviderType
    identity_id: str

    def __init__(
        self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
    ) -> None:
        self.tenant_id = tenant_id
        self.tool_runtime = tool_runtime
        self.provider_name = provider_name
        self.provider_type = provider_type
        self.identity_id = identity_id

    def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        deep copy parameters
        """
        return deepcopy(parameters)

    def _merge_parameters(self) -> list[ToolParameter]:
        """
        merge parameters
        """
        # get tool parameters
        tool_parameters = self.tool_runtime.entity.parameters or []
        # get tool runtime parameters
        runtime_parameters = self.tool_runtime.get_runtime_parameters()
        # override parameters
        current_parameters = tool_parameters.copy()
        for runtime_parameter in runtime_parameters:
            found = False
            for index, parameter in enumerate(current_parameters):
                if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
                    current_parameters[index] = runtime_parameter
                    found = True
                    break

            if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
                current_parameters.append(runtime_parameter)

        return current_parameters

    def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        mask tool parameters

        return a deep copy of parameters with masked values
        """
        parameters = self._deep_copy(parameters)

        # override parameters
        current_parameters = self._merge_parameters()

        for parameter in current_parameters:
            if (
                parameter.form == ToolParameter.ToolParameterForm.FORM
                and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
            ):
                if parameter.name in parameters:
                    if len(parameters[parameter.name]) > 6:
                        parameters[parameter.name] = (
                            parameters[parameter.name][:2]
                            + "*" * (len(parameters[parameter.name]) - 4)
                            + parameters[parameter.name][-2:]
                        )
                    else:
                        parameters[parameter.name] = "*" * len(parameters[parameter.name])

        return parameters

    def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        encrypt tool parameters with tenant id

        return a deep copy of parameters with encrypted values
        """
        # override parameters
        current_parameters = self._merge_parameters()

        parameters = self._deep_copy(parameters)

        for parameter in current_parameters:
            if (
                parameter.form == ToolParameter.ToolParameterForm.FORM
                and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
            ):
                if parameter.name in parameters:
                    encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
                    parameters[parameter.name] = encrypted

        return parameters

    def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        decrypt tool parameters with tenant id

        return a deep copy of parameters with decrypted values
        """

        cache = ToolParameterCache(
            tenant_id=self.tenant_id,
            provider=f"{self.provider_type.value}.{self.provider_name}",
            tool_name=self.tool_runtime.entity.identity.name,
            cache_type=ToolParameterCacheType.PARAMETER,
            identity_id=self.identity_id,
        )
        cached_parameters = cache.get()
        if cached_parameters:
            return cached_parameters

        # override parameters
        current_parameters = self._merge_parameters()
        has_secret_input = False

        for parameter in current_parameters:
            if (
                parameter.form == ToolParameter.ToolParameterForm.FORM
                and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
            ):
                if parameter.name in parameters:
                    try:
                        has_secret_input = True
                        parameters[parameter.name] = encrypter.decrypt_token(
                            self.tenant_id, parameters[parameter.name]
                        )
                    except Exception:
                        pass

        if has_secret_input:
            cache.set(parameters)

        return parameters

    def delete_tool_parameters_cache(self):
        cache = ToolParameterCache(
            tenant_id=self.tenant_id,
            provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||||
|
tool_name=self.tool_runtime.entity.identity.name,
|
||||||
|
cache_type=ToolParameterCacheType.PARAMETER,
|
||||||
|
identity_id=self.identity_id,
|
||||||
|
)
|
||||||
|
cache.delete()
|
||||||
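The masking rule in mask_tool_parameters keeps the first and last two characters of any secret longer than six characters and stars everything in between; shorter secrets are fully starred. A minimal, self-contained sketch of that rule (illustrative only, applied to a bare string outside the manager):

def mask_secret(value: str) -> str:
    # same slicing rule as mask_tool_parameters, applied to a plain string
    if len(value) > 6:
        return value[:2] + "*" * (len(value) - 4) + value[-2:]
    return "*" * len(value)

assert mask_secret("sk-1234567890") == "sk*********90"
assert mask_secret("abc") == "***"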
api/core/datasource/utils/message_transformer.py  (new file, 121 lines)
@@ -0,0 +1,121 @@

import logging
from collections.abc import Generator
from mimetypes import guess_extension
from typing import Optional

from core.datasource.datasource_file_manager import DatasourceFileManager
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.file import File, FileTransferMethod, FileType

logger = logging.getLogger(__name__)


class DatasourceFileMessageTransformer:
    @classmethod
    def transform_datasource_invoke_messages(
        cls,
        messages: Generator[DatasourceMessage, None, None],
        user_id: str,
        tenant_id: str,
        conversation_id: Optional[str] = None,
    ) -> Generator[DatasourceMessage, None, None]:
        """
        Transform datasource message and handle file download
        """
        for message in messages:
            if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}:
                yield message
            elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance(
                message.message, DatasourceMessage.TextMessage
            ):
                # try to download image
                try:
                    assert isinstance(message.message, DatasourceMessage.TextMessage)

                    file = DatasourceFileManager.create_file_by_url(
                        user_id=user_id,
                        tenant_id=tenant_id,
                        file_url=message.message.text,
                        conversation_id=conversation_id,
                    )

                    url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"

                    yield DatasourceMessage(
                        type=DatasourceMessage.MessageType.IMAGE_LINK,
                        message=DatasourceMessage.TextMessage(text=url),
                        meta=message.meta.copy() if message.meta is not None else {},
                    )
                except Exception as e:
                    yield DatasourceMessage(
                        type=DatasourceMessage.MessageType.TEXT,
                        message=DatasourceMessage.TextMessage(
                            text=f"Failed to download image: {message.message.text}: {e}"
                        ),
                        meta=message.meta.copy() if message.meta is not None else {},
                    )
            elif message.type == DatasourceMessage.MessageType.BLOB:
                # get mime type and save blob to storage
                meta = message.meta or {}

                mimetype = meta.get("mime_type", "application/octet-stream")
                # get filename from meta
                filename = meta.get("file_name", None)
                # if message is str, encode it to bytes

                if not isinstance(message.message, DatasourceMessage.BlobMessage):
                    raise ValueError("unexpected message type")

                # FIXME: should do a type check here.
                assert isinstance(message.message.blob, bytes)
                file = DatasourceFileManager.create_file_by_raw(
                    user_id=user_id,
                    tenant_id=tenant_id,
                    conversation_id=conversation_id,
                    file_binary=message.message.blob,
                    mimetype=mimetype,
                    filename=filename,
                )

                url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type))

                # check if file is image
                if "image" in mimetype:
                    yield DatasourceMessage(
                        type=DatasourceMessage.MessageType.IMAGE_LINK,
                        message=DatasourceMessage.TextMessage(text=url),
                        meta=meta.copy() if meta is not None else {},
                    )
                else:
                    yield DatasourceMessage(
                        type=DatasourceMessage.MessageType.BINARY_LINK,
                        message=DatasourceMessage.TextMessage(text=url),
                        meta=meta.copy() if meta is not None else {},
                    )
            elif message.type == DatasourceMessage.MessageType.FILE:
                meta = message.meta or {}
                file = meta.get("file", None)
                if isinstance(file, File):
                    if file.transfer_method == FileTransferMethod.TOOL_FILE:
                        assert file.related_id is not None
                        url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
                        if file.type == FileType.IMAGE:
                            yield DatasourceMessage(
                                type=DatasourceMessage.MessageType.IMAGE_LINK,
                                message=DatasourceMessage.TextMessage(text=url),
                                meta=meta.copy() if meta is not None else {},
                            )
                        else:
                            yield DatasourceMessage(
                                type=DatasourceMessage.MessageType.LINK,
                                message=DatasourceMessage.TextMessage(text=url),
                                meta=meta.copy() if meta is not None else {},
                            )
                    else:
                        yield message
                else:
                    yield message
            else:
                yield message

    @classmethod
    def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
        return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"
api/core/datasource/utils/parser.py  (new file, 389 lines)
@@ -0,0 +1,389 @@

import re
import uuid
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Optional

from flask import request
from requests import get
from yaml import YAMLError, safe_load  # type: ignore

from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError


class ApiBasedToolSchemaParser:
    @staticmethod
    def parse_openapi_to_tool_bundle(
        openapi: dict, extra_info: dict | None = None, warning: dict | None = None
    ) -> list[ApiToolBundle]:
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        # set description to extra_info
        extra_info["description"] = openapi["info"].get("description", "")

        if len(openapi["servers"]) == 0:
            raise ToolProviderNotFoundError("No server found in the openapi yaml.")

        server_url = openapi["servers"][0]["url"]
        request_env = request.headers.get("X-Request-Env")
        if request_env:
            matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
            server_url = matched_servers[0] if matched_servers else server_url

        # list all interfaces
        interfaces = []
        for path, path_item in openapi["paths"].items():
            methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
            for method in methods:
                if method in path_item:
                    interfaces.append(
                        {
                            "path": path,
                            "method": method,
                            "operation": path_item[method],
                        }
                    )

        # get all parameters
        bundles = []
        for interface in interfaces:
            # convert parameters
            parameters = []
            if "parameters" in interface["operation"]:
                for parameter in interface["operation"]["parameters"]:
                    tool_parameter = ToolParameter(
                        name=parameter["name"],
                        label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
                        human_description=I18nObject(
                            en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
                        ),
                        type=ToolParameter.ToolParameterType.STRING,
                        required=parameter.get("required", False),
                        form=ToolParameter.ToolParameterForm.LLM,
                        llm_description=parameter.get("description"),
                        default=parameter["schema"]["default"]
                        if "schema" in parameter and "default" in parameter["schema"]
                        else None,
                        placeholder=I18nObject(
                            en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
                        ),
                    )

                    # check if there is a type
                    typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
                    if typ:
                        tool_parameter.type = typ

                    parameters.append(tool_parameter)
            # create tool bundle
            # check if there is a request body
            if "requestBody" in interface["operation"]:
                request_body = interface["operation"]["requestBody"]
                if "content" in request_body:
                    for content_type, content in request_body["content"].items():
                        # if there is a reference, get the reference and overwrite the content
                        if "schema" not in content:
                            continue

                        if "$ref" in content["schema"]:
                            # get the reference
                            root = openapi
                            reference = content["schema"]["$ref"].split("/")[1:]
                            for ref in reference:
                                root = root[ref]
                            # overwrite the content
                            interface["operation"]["requestBody"]["content"][content_type]["schema"] = root

                        # parse body parameters
                        if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
                            body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
                            required = body_schema.get("required", [])
                            properties = body_schema.get("properties", {})
                            for name, property in properties.items():
                                tool = ToolParameter(
                                    name=name,
                                    label=I18nObject(en_US=name, zh_Hans=name),
                                    human_description=I18nObject(
                                        en_US=property.get("description", ""), zh_Hans=property.get("description", "")
                                    ),
                                    type=ToolParameter.ToolParameterType.STRING,
                                    required=name in required,
                                    form=ToolParameter.ToolParameterForm.LLM,
                                    llm_description=property.get("description", ""),
                                    default=property.get("default", None),
                                    placeholder=I18nObject(
                                        en_US=property.get("description", ""), zh_Hans=property.get("description", "")
                                    ),
                                )

                                # check if there is a type
                                typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
                                if typ:
                                    tool.type = typ

                                parameters.append(tool)

            # check if parameters is duplicated
            parameters_count = {}
            for parameter in parameters:
                if parameter.name not in parameters_count:
                    parameters_count[parameter.name] = 0
                parameters_count[parameter.name] += 1
            for name, count in parameters_count.items():
                if count > 1:
                    warning["duplicated_parameter"] = f"Parameter {name} is duplicated."

            # check if there is a operation id, use $path_$method as operation id if not
            if "operationId" not in interface["operation"]:
                # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
                path = interface["path"]
                if interface["path"].startswith("/"):
                    path = interface["path"][1:]
                # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
                path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
                if not path:
                    path = str(uuid.uuid4())

                interface["operation"]["operationId"] = f"{path}_{interface['method']}"

            bundles.append(
                ApiToolBundle(
                    server_url=server_url + interface["path"],
                    method=interface["method"],
                    summary=interface["operation"]["description"]
                    if "description" in interface["operation"]
                    else interface["operation"].get("summary", None),
                    operation_id=interface["operation"]["operationId"],
                    parameters=parameters,
                    author="",
                    icon=None,
                    openapi=interface["operation"],
                )
            )

        return bundles
    @staticmethod
    def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
        parameter = parameter or {}
        typ: Optional[str] = None
        if parameter.get("format") == "binary":
            return ToolParameter.ToolParameterType.FILE

        if "type" in parameter:
            typ = parameter["type"]
        elif "schema" in parameter and "type" in parameter["schema"]:
            typ = parameter["schema"]["type"]

        if typ in {"integer", "number"}:
            return ToolParameter.ToolParameterType.NUMBER
        elif typ == "boolean":
            return ToolParameter.ToolParameterType.BOOLEAN
        elif typ == "string":
            return ToolParameter.ToolParameterType.STRING
        elif typ == "array":
            items = parameter.get("items") or parameter.get("schema", {}).get("items")
            return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
        else:
            return None

    @staticmethod
    def parse_openapi_yaml_to_tool_bundle(
        yaml: str, extra_info: dict | None = None, warning: dict | None = None
    ) -> list[ApiToolBundle]:
        """
        parse openapi yaml to tool bundle

        :param yaml: the yaml string
        :param extra_info: the extra info
        :param warning: the warning message
        :return: the tool bundle
        """
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        openapi: dict = safe_load(yaml)
        if openapi is None:
            raise ToolApiSchemaError("Invalid openapi yaml.")
        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)

    @staticmethod
    def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
        warning = warning or {}
        """
        parse swagger to openapi

        :param swagger: the swagger dict
        :return: the openapi dict
        """
        # convert swagger to openapi
        info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})

        servers = swagger.get("servers", [])

        if len(servers) == 0:
            raise ToolApiSchemaError("No server found in the swagger yaml.")

        openapi = {
            "openapi": "3.0.0",
            "info": {
                "title": info.get("title", "Swagger"),
                "description": info.get("description", "Swagger"),
                "version": info.get("version", "1.0.0"),
            },
            "servers": swagger["servers"],
            "paths": {},
            "components": {"schemas": {}},
        }

        # check paths
        if "paths" not in swagger or len(swagger["paths"]) == 0:
            raise ToolApiSchemaError("No paths found in the swagger yaml.")

        # convert paths
        for path, path_item in swagger["paths"].items():
            openapi["paths"][path] = {}
            for method, operation in path_item.items():
                if "operationId" not in operation:
                    raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")

                if ("summary" not in operation or len(operation["summary"]) == 0) and (
                    "description" not in operation or len(operation["description"]) == 0
                ):
                    if warning is not None:
                        warning["missing_summary"] = f"No summary or description found in operation {method} {path}."

                openapi["paths"][path][method] = {
                    "operationId": operation["operationId"],
                    "summary": operation.get("summary", ""),
                    "description": operation.get("description", ""),
                    "parameters": operation.get("parameters", []),
                    "responses": operation.get("responses", {}),
                }

                if "requestBody" in operation:
                    openapi["paths"][path][method]["requestBody"] = operation["requestBody"]

        # convert definitions
        for name, definition in swagger["definitions"].items():
            openapi["components"]["schemas"][name] = definition

        return openapi
    @staticmethod
    def parse_openai_plugin_json_to_tool_bundle(
        json: str, extra_info: dict | None = None, warning: dict | None = None
    ) -> list[ApiToolBundle]:
        """
        parse openapi plugin yaml to tool bundle

        :param json: the json string
        :param extra_info: the extra info
        :param warning: the warning message
        :return: the tool bundle
        """
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        try:
            openai_plugin = json_loads(json)
            api = openai_plugin["api"]
            api_url = api["url"]
            api_type = api["type"]
        except JSONDecodeError:
            raise ToolProviderNotFoundError("Invalid openai plugin json.")

        if api_type != "openapi":
            raise ToolNotSupportedError("Only openapi is supported now.")

        # get openapi yaml
        response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)

        if response.status_code != 200:
            raise ToolProviderNotFoundError("cannot get openapi yaml from url.")

        return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
            response.text, extra_info=extra_info, warning=warning
        )

    @staticmethod
    def auto_parse_to_tool_bundle(
        content: str, extra_info: dict | None = None, warning: dict | None = None
    ) -> tuple[list[ApiToolBundle], str]:
        """
        auto parse to tool bundle

        :param content: the content
        :param extra_info: the extra info
        :param warning: the warning message
        :return: tools bundle, schema_type
        """
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        content = content.strip()
        loaded_content = None
        json_error = None
        yaml_error = None

        try:
            loaded_content = json_loads(content)
        except JSONDecodeError as e:
            json_error = e

        if loaded_content is None:
            try:
                loaded_content = safe_load(content)
            except YAMLError as e:
                yaml_error = e
        if loaded_content is None:
            raise ToolApiSchemaError(
                f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
                f" yaml error: {str(yaml_error)}"
            )

        swagger_error = None
        openapi_error = None
        openapi_plugin_error = None
        schema_type = None

        try:
            openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
                loaded_content, extra_info=extra_info, warning=warning
            )
            schema_type = ApiProviderSchemaType.OPENAPI.value
            return openapi, schema_type
        except ToolApiSchemaError as e:
            openapi_error = e

        # openapi parse error, fallback to swagger
        try:
            converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
                loaded_content, extra_info=extra_info, warning=warning
            )
            schema_type = ApiProviderSchemaType.SWAGGER.value
            return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
                converted_swagger, extra_info=extra_info, warning=warning
            ), schema_type
        except ToolApiSchemaError as e:
            swagger_error = e

        # swagger parse error, fallback to openai plugin
        try:
            openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
                json_dumps(loaded_content), extra_info=extra_info, warning=warning
            )
            return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
        except ToolNotSupportedError as e:
            # maybe it's not plugin at all
            openapi_plugin_error = e

        raise ToolApiSchemaError(
            f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
            f" openapi plugin error: {str(openapi_plugin_error)}"
        )
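As a rough usage sketch of auto_parse_to_tool_bundle: the content is tried as JSON, then YAML, and the result is parsed as OpenAPI, then Swagger, then an OpenAI plugin manifest. The snippet below assumes the api package is importable as shown and runs inside a Flask request context, since parse_openapi_to_tool_bundle reads the X-Request-Env header; the schema is a made-up minimal OpenAPI document.

from flask import Flask

from core.datasource.utils.parser import ApiBasedToolSchemaParser

schema = """
openapi: 3.0.0
info:
  title: Demo
  description: Demo API
  version: 1.0.0
servers:
  - url: https://example.com
paths:
  /ping:
    get:
      operationId: ping
      summary: Ping the service
      responses:
        "200":
          description: ok
"""

app = Flask(__name__)
with app.test_request_context():
    bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
    print(schema_type, [bundle.operation_id for bundle in bundles])  # expect "openapi" and ["ping"]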
api/core/datasource/utils/text_processing_utils.py  (new file, 17 lines)
@@ -0,0 +1,17 @@

import re


def remove_leading_symbols(text: str) -> str:
    """
    Remove leading punctuation or symbols from the given text.

    Args:
        text (str): The input text to process.

    Returns:
        str: The text with leading punctuation or symbols removed.
    """
    # Match Unicode ranges for punctuation and symbols
    # FIXME this pattern is confused quick fix for #11868 maybe refactor it later
    pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
    return re.sub(pattern, "", text)
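For context, a quick illustration of remove_leading_symbols (illustrative values, not part of the changeset): leading ASCII and CJK punctuation is stripped, while symbols inside the text are left untouched.

from core.datasource.utils.text_processing_utils import remove_leading_symbols

print(remove_leading_symbols("...hello, world"))  # "hello, world"
print(remove_leading_symbols("、第一章"))          # "第一章" (U+3001 falls in the \u3000-\u303F range)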
api/core/datasource/utils/uuid_utils.py  (new file, 9 lines)
@@ -0,0 +1,9 @@

import uuid


def is_valid_uuid(uuid_str: str) -> bool:
    try:
        uuid.UUID(uuid_str)
        return True
    except Exception:
        return False
api/core/datasource/utils/workflow_configuration_sync.py  (new file, 43 lines)
@@ -0,0 +1,43 @@

from collections.abc import Mapping, Sequence
from typing import Any

from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration


class WorkflowToolConfigurationUtils:
    @classmethod
    def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
        for configuration in configurations:
            WorkflowToolParameterConfiguration.model_validate(configuration)

    @classmethod
    def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
        """
        get workflow graph variables
        """
        nodes = graph.get("nodes", [])
        start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)

        if not start_node:
            return []

        return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]

    @classmethod
    def check_is_synced(
        cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
    ):
        """
        check is synced

        raise ValueError if not synced
        """
        variable_names = [variable.variable for variable in variables]

        if len(tool_configurations) != len(variables):
            raise ValueError("parameter configuration mismatch, please republish the tool to update")

        for parameter in tool_configurations:
            if parameter.name not in variable_names:
                raise ValueError("parameter configuration mismatch, please republish the tool to update")
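A small sketch of the graph shape get_workflow_graph_variables expects: it finds the start node and validates its variables. The graph dict and the variable fields below are hypothetical and assume VariableEntity accepts this shape.

from core.datasource.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils

graph = {
    "nodes": [
        {
            "data": {
                "type": "start",
                "variables": [
                    {"variable": "query", "label": "Query", "type": "text-input", "required": True},
                ],
            }
        }
    ]
}

variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
print([v.variable for v in variables])  # ["query"]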
api/core/datasource/utils/yaml_utils.py  (new file, 35 lines)
@@ -0,0 +1,35 @@

import logging
from pathlib import Path
from typing import Any

import yaml  # type: ignore
from yaml import YAMLError

logger = logging.getLogger(__name__)


def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
    """
    Safe loading a YAML file

    :param file_path: the path of the YAML file
    :param ignore_error:
        if True, return default_value if error occurs and the error will be logged in debug level
        if False, raise error if error occurs
    :param default_value: the value returned when errors ignored
    :return: an object of the YAML content
    """
    if not file_path or not Path(file_path).exists():
        if ignore_error:
            return default_value
        else:
            raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, encoding="utf-8") as yaml_file:
        try:
            yaml_content = yaml.safe_load(yaml_file)
            return yaml_content or default_value
        except Exception as e:
            if ignore_error:
                return default_value
            else:
                raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
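Usage sketch for load_yaml_file (the path is hypothetical): with ignore_error=True a missing or unparsable file silently falls back to default_value; with ignore_error=False the same call raises instead.

from core.datasource.utils.yaml_utils import load_yaml_file

config = load_yaml_file("/tmp/does_not_exist.yaml", ignore_error=True, default_value={"provider": "none"})
print(config)  # {'provider': 'none'}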
api/core/datasource/website_crawl/website_crawl_plugin.py  (new file, 53 lines)
@@ -0,0 +1,53 @@

from collections.abc import Generator, Mapping
from typing import Any

from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceProviderType,
    WebsiteCrawlMessage,
)
from core.plugin.impl.datasource import PluginDatasourceManager


class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
    tenant_id: str
    icon: str
    plugin_unique_identifier: str
    entity: DatasourceEntity
    runtime: DatasourceRuntime

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
        tenant_id: str,
        icon: str,
        plugin_unique_identifier: str,
    ) -> None:
        super().__init__(entity, runtime)
        self.tenant_id = tenant_id
        self.icon = icon
        self.plugin_unique_identifier = plugin_unique_identifier

    def get_website_crawl(
        self,
        user_id: str,
        datasource_parameters: Mapping[str, Any],
        provider_type: str,
    ) -> Generator[WebsiteCrawlMessage, None, None]:
        manager = PluginDatasourceManager()

        return manager.get_website_crawl(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            datasource_parameters=datasource_parameters,
            provider_type=provider_type,
        )

    def datasource_provider_type(self) -> str:
        return DatasourceProviderType.WEBSITE_CRAWL
api/core/datasource/website_crawl/website_crawl_provider.py  (new file, 52 lines)
@@ -0,0 +1,52 @@

from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin


class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
    entity: DatasourceProviderEntityWithPlugin
    plugin_id: str
    plugin_unique_identifier: str

    def __init__(
        self,
        entity: DatasourceProviderEntityWithPlugin,
        plugin_id: str,
        plugin_unique_identifier: str,
        tenant_id: str,
    ) -> None:
        super().__init__(entity, tenant_id)
        self.plugin_id = plugin_id
        self.plugin_unique_identifier = plugin_unique_identifier

    @property
    def provider_type(self) -> DatasourceProviderType:
        """
        returns the type of the provider
        """
        return DatasourceProviderType.WEBSITE_CRAWL

    def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin:  # type: ignore
        """
        return datasource with given name
        """
        datasource_entity = next(
            (
                datasource_entity
                for datasource_entity in self.entity.datasources
                if datasource_entity.identity.name == datasource_name
            ),
            None,
        )

        if not datasource_entity:
            raise ValueError(f"Datasource with name {datasource_name} not found")

        return WebsiteCrawlDatasourcePlugin(
            entity=datasource_entity,
            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
            tenant_id=self.tenant_id,
            icon=self.entity.identity.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
@@ -17,3 +17,27 @@ class IndexingEstimate(BaseModel):
     total_segments: int
     preview: list[PreviewDetail]
     qa_preview: Optional[list[QAPreviewDetail]] = None
+
+
+class PipelineDataset(BaseModel):
+    id: str
+    name: str
+    description: str
+    chunk_structure: str
+
+
+class PipelineDocument(BaseModel):
+    id: str
+    position: int
+    data_source_type: str
+    data_source_info: Optional[dict] = None
+    name: str
+    indexing_status: str
+    error: Optional[str] = None
+    enabled: bool
+
+
+class PipelineGenerateResponse(BaseModel):
+    batch: str
+    dataset: PipelineDataset
+    documents: list[PipelineDocument]
api/core/file/datasource_file_parser.py  (new file, 15 lines)
@@ -0,0 +1,15 @@

from typing import TYPE_CHECKING, Any, cast

from core.datasource import datasource_file_manager
from core.datasource.datasource_file_manager import DatasourceFileManager

if TYPE_CHECKING:
    from core.datasource.datasource_file_manager import DatasourceFileManager

tool_file_manager: dict[str, Any] = {"manager": None}


class DatasourceFileParser:
    @staticmethod
    def get_datasource_file_manager() -> "DatasourceFileManager":
        return cast("DatasourceFileManager", datasource_file_manager["manager"])
@@ -20,6 +20,7 @@ class FileTransferMethod(StrEnum):
     REMOTE_URL = "remote_url"
     LOCAL_FILE = "local_file"
     TOOL_FILE = "tool_file"
+    DATASOURCE_FILE = "datasource_file"
 
     @staticmethod
     def value_of(value):
@@ -135,3 +135,4 @@ class TraceTaskName(StrEnum):
     DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
     TOOL_TRACE = "tool"
     GENERATE_NAME_TRACE = "generate_conversation_name"
+    DATASOURCE_TRACE = "datasource"
api/core/plugin/entities/oauth.py  (new file, 21 lines)
@@ -0,0 +1,21 @@

from collections.abc import Sequence

from pydantic import BaseModel, Field

from core.entities.provider_entities import ProviderConfig


class OAuthSchema(BaseModel):
    """
    OAuth schema
    """

    client_schema: Sequence[ProviderConfig] = Field(
        default_factory=list,
        description="client schema like client_id, client_secret, etc.",
    )

    credentials_schema: Sequence[ProviderConfig] = Field(
        default_factory=list,
        description="credentials schema like access_token, refresh_token, etc.",
    )
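A small sketch of the new OAuthSchema model (illustrative only): both fields default to empty sequences, so a bare instance is already valid.

from core.plugin.entities.oauth import OAuthSchema

schema = OAuthSchema()
print(schema.model_dump())  # {'client_schema': [], 'credentials_schema': []}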
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, model_validator
 from werkzeug.exceptions import NotFound
 
 from core.agent.plugin_entities import AgentStrategyProviderEntity
+from core.datasource.entities.datasource_entities import DatasourceProviderEntity
 from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.plugin.entities.base import BasePluginEntity
 from core.plugin.entities.endpoint import EndpointProviderDeclaration

@@ -62,6 +63,7 @@ class PluginCategory(enum.StrEnum):
     Model = "model"
     Extension = "extension"
     AgentStrategy = "agent-strategy"
+    Datasource = "datasource"
 
 
 class PluginDeclaration(BaseModel):

@@ -69,6 +71,7 @@ class PluginDeclaration(BaseModel):
     tools: Optional[list[str]] = Field(default_factory=list[str])
     models: Optional[list[str]] = Field(default_factory=list[str])
     endpoints: Optional[list[str]] = Field(default_factory=list[str])
+    datasources: Optional[list[str]] = Field(default_factory=list[str])
 
     class Meta(BaseModel):
         minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")

@@ -90,6 +93,7 @@ class PluginDeclaration(BaseModel):
     model: Optional[ProviderEntity] = None
     endpoint: Optional[EndpointProviderDeclaration] = None
     agent_strategy: Optional[AgentStrategyProviderEntity] = None
+    datasource: Optional[DatasourceProviderEntity] = None
     meta: Meta
 
     @model_validator(mode="before")

@@ -100,6 +104,8 @@ class PluginDeclaration(BaseModel):
             values["category"] = PluginCategory.Tool
         elif values.get("model"):
             values["category"] = PluginCategory.Model
+        elif values.get("datasource"):
+            values["category"] = PluginCategory.Datasource
         elif values.get("agent_strategy"):
             values["category"] = PluginCategory.AgentStrategy
         else:

@@ -193,6 +199,11 @@ class ToolProviderID(GenericProviderID):
         self.plugin_name = f"{self.provider_name}_tool"
 
 
+class DatasourceProviderID(GenericProviderID):
+    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
+        super().__init__(value, is_hardcoded)
+
+
 class PluginDependency(BaseModel):
     class Type(enum.StrEnum):
         Github = PluginInstallationSource.Github.value
@@ -6,6 +6,7 @@ from typing import Any, Generic, Optional, TypeVar
 from pydantic import BaseModel, ConfigDict, Field
 
 from core.agent.plugin_entities import AgentProviderEntityWithPlugin
+from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.plugin.entities.base import BasePluginEntity

@@ -48,6 +49,14 @@ class PluginToolProviderEntity(BaseModel):
     declaration: ToolProviderEntityWithPlugin
 
 
+class PluginDatasourceProviderEntity(BaseModel):
+    provider: str
+    plugin_unique_identifier: str
+    plugin_id: str
+    is_authorized: bool = False
+    declaration: DatasourceProviderEntityWithPlugin
+
+
 class PluginAgentProviderEntity(BaseModel):
     provider: str
     plugin_unique_identifier: str
api/core/plugin/impl/datasource.py  (new file, 329 lines)
@@ -0,0 +1,329 @@

from collections.abc import Generator, Mapping
from typing import Any

from core.datasource.entities.datasource_entities import (
    DatasourceMessage,
    GetOnlineDocumentPageContentRequest,
    OnlineDocumentPagesMessage,
    OnlineDriveBrowseFilesRequest,
    OnlineDriveBrowseFilesResponse,
    OnlineDriveDownloadFileRequest,
    WebsiteCrawlMessage,
)
from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID
from core.plugin.entities.plugin_daemon import (
    PluginBasicBooleanResponse,
    PluginDatasourceProviderEntity,
)
from core.plugin.impl.base import BasePluginClient
from services.tools.tools_transform_service import ToolTransformService


class PluginDatasourceManager(BasePluginClient):
    def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
        """
        Fetch datasource providers for the given tenant.
        """

        def transformer(json_response: dict[str, Any]) -> dict:
            if json_response.get("data"):
                for provider in json_response.get("data", []):
                    declaration = provider.get("declaration", {}) or {}
                    provider_name = declaration.get("identity", {}).get("name")
                    for datasource in declaration.get("datasources", []):
                        datasource["identity"]["provider"] = provider_name

            return json_response

        response = self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/datasources",
            list[PluginDatasourceProviderEntity],
            params={"page": 1, "page_size": 256},
            transformer=transformer,
        )
        local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())

        for provider in response:
            ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
        all_response = [local_file_datasource_provider] + response

        for provider in all_response:
            provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"

            # override the provider name for each tool to plugin_id/provider_name
            for tool in provider.declaration.datasources:
                tool.identity.provider = provider.declaration.identity.name

        return all_response

    def fetch_datasource_provider(self, tenant_id: str, provider_id: str) -> PluginDatasourceProviderEntity:
        """
        Fetch datasource provider for the given tenant and plugin.
        """
        if provider_id == "langgenius/file/file":
            return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())

        tool_provider_id = DatasourceProviderID(provider_id)

        def transformer(json_response: dict[str, Any]) -> dict:
            data = json_response.get("data")
            if data:
                for datasource in data.get("declaration", {}).get("datasources", []):
                    datasource["identity"]["provider"] = tool_provider_id.provider_name

            return json_response

        response = self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/datasource",
            PluginDatasourceProviderEntity,
            params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
            transformer=transformer,
        )

        response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"

        # override the provider name for each tool to plugin_id/provider_name
        for datasource in response.declaration.datasources:
            datasource.identity.provider = response.declaration.identity.name

        return response

    def get_website_crawl(
        self,
        tenant_id: str,
        user_id: str,
        datasource_provider: str,
        datasource_name: str,
        credentials: dict[str, Any],
        datasource_parameters: Mapping[str, Any],
        provider_type: str,
    ) -> Generator[WebsiteCrawlMessage, None, None]:
        """
        Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
        """

        datasource_provider_id = GenericProviderID(datasource_provider)

        return self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl",
            WebsiteCrawlMessage,
            data={
                "user_id": user_id,
                "data": {
                    "provider": datasource_provider_id.provider_name,
                    "datasource": datasource_name,
                    "credentials": credentials,
                    "datasource_parameters": datasource_parameters,
                },
            },
            headers={
                "X-Plugin-ID": datasource_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

    def get_online_document_pages(
        self,
        tenant_id: str,
        user_id: str,
        datasource_provider: str,
        datasource_name: str,
        credentials: dict[str, Any],
        datasource_parameters: Mapping[str, Any],
        provider_type: str,
    ) -> Generator[OnlineDocumentPagesMessage, None, None]:
        """
        Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
        """

        datasource_provider_id = GenericProviderID(datasource_provider)

        return self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
            OnlineDocumentPagesMessage,
            data={
                "user_id": user_id,
                "data": {
                    "provider": datasource_provider_id.provider_name,
                    "datasource": datasource_name,
                    "credentials": credentials,
                    "datasource_parameters": datasource_parameters,
                },
            },
            headers={
                "X-Plugin-ID": datasource_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

    def get_online_document_page_content(
        self,
        tenant_id: str,
        user_id: str,
        datasource_provider: str,
        datasource_name: str,
        credentials: dict[str, Any],
        datasource_parameters: GetOnlineDocumentPageContentRequest,
        provider_type: str,
    ) -> Generator[DatasourceMessage, None, None]:
        """
        Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
        """

        datasource_provider_id = GenericProviderID(datasource_provider)

        return self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content",
            DatasourceMessage,
            data={
                "user_id": user_id,
                "data": {
                    "provider": datasource_provider_id.provider_name,
                    "datasource": datasource_name,
                    "credentials": credentials,
                    "page": datasource_parameters.model_dump(),
                },
            },
            headers={
                "X-Plugin-ID": datasource_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )

    def online_drive_browse_files(
        self,
        tenant_id: str,
        user_id: str,
        datasource_provider: str,
        datasource_name: str,
        credentials: dict[str, Any],
        request: OnlineDriveBrowseFilesRequest,
        provider_type: str,
    ) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
        """
        Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
        """

        datasource_provider_id = GenericProviderID(datasource_provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/datasource/online_drive_browse_files",
            OnlineDriveBrowseFilesResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": datasource_provider_id.provider_name,
                    "datasource": datasource_name,
                    "credentials": credentials,
                    "request": request.model_dump(),
                },
            },
            headers={
                "X-Plugin-ID": datasource_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )
        yield from response

    def online_drive_download_file(
        self,
        tenant_id: str,
        user_id: str,
        datasource_provider: str,
        datasource_name: str,
        credentials: dict[str, Any],
        request: OnlineDriveDownloadFileRequest,
        provider_type: str,
    ) -> Generator[DatasourceMessage, None, None]:
        """
        Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
        """

        datasource_provider_id = GenericProviderID(datasource_provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/datasource/online_drive_download_file",
            DatasourceMessage,
            data={
                "user_id": user_id,
                "data": {
                    "provider": datasource_provider_id.provider_name,
                    "datasource": datasource_name,
                    "credentials": credentials,
                    "request": request.model_dump(),
                },
            },
            headers={
                "X-Plugin-ID": datasource_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )
        yield from response

    def validate_provider_credentials(
        self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any]
    ) -> bool:
        """
        validate the credentials of the provider
        """
        # datasource_provider_id = GenericProviderID(provider_id)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
            PluginBasicBooleanResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": provider,
                    "credentials": credentials,
                },
            },
            headers={
                "X-Plugin-ID": plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return resp.result

        return False

    def _get_local_file_datasource_provider(self) -> dict[str, Any]:
        return {
            "id": "langgenius/file/file",
            "plugin_id": "langgenius/file",
            "provider": "file",
            "plugin_unique_identifier": "langgenius/file:0.0.1@dify",
            "declaration": {
                "identity": {
                    "author": "langgenius",
                    "name": "file",
                    "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
                    "icon": "https://assets.dify.ai/images/File%20Upload.svg",
                    "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
                },
                "credentials_schema": [],
                "provider_type": "local_file",
                "datasources": [
                    {
                        "identity": {
                            "author": "langgenius",
                            "name": "upload-file",
                            "provider": "file",
                            "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
                        },
                        "parameters": [],
                        "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
                    }
                ],
            },
        }
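A hedged sketch of the hard-coded local-file fallback in PluginDatasourceManager: it assumes the plugin-daemon settings that BasePluginClient expects are configured, and only reads the dict returned by _get_local_file_datasource_provider.

from core.plugin.impl.datasource import PluginDatasourceManager

manager = PluginDatasourceManager()
local_provider = manager._get_local_file_datasource_provider()
print(local_provider["plugin_unique_identifier"])  # "langgenius/file:0.0.1@dify"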
@@ -4,7 +4,10 @@ from typing import Any, Optional
 from pydantic import BaseModel
 
 from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
-from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
+from core.plugin.entities.plugin_daemon import (
+    PluginBasicBooleanResponse,
+    PluginToolProviderEntity,
+)
 from core.plugin.impl.base import BasePluginClient
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter

@@ -197,6 +200,36 @@ class PluginToolManager(BasePluginClient):
 
         return False
 
+    def validate_datasource_credentials(
+        self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
+    ) -> bool:
+        """
+        validate the credentials of the datasource
+        """
+        tool_provider_id = GenericProviderID(provider)
+
+        response = self._request_with_plugin_daemon_response_stream(
+            "POST",
+            f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
+            PluginBasicBooleanResponse,
+            data={
+                "user_id": user_id,
+                "data": {
+                    "provider": tool_provider_id.provider_name,
+                    "credentials": credentials,
+                },
+            },
+            headers={
+                "X-Plugin-ID": tool_provider_id.plugin_id,
+                "Content-Type": "application/json",
+            },
+        )
+
+        for resp in response:
+            return resp.result
+
+        return False
+
     def get_runtime_parameters(
         self,
         tenant_id: str,
@@ -28,10 +28,12 @@ class Jieba(BaseKeyword):
         with redis_client.lock(lock_name, timeout=600):
             keyword_table_handler = JiebaKeywordTableHandler()
             keyword_table = self._get_dataset_keyword_table()
+            keyword_number = (
+                self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+            )

             for text in texts:
-                keywords = keyword_table_handler.extract_keywords(
-                    text.page_content, self._config.max_keywords_per_chunk
-                )
+                keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
                 if text.metadata is not None:
                     self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
                 keyword_table = self._add_text_to_keyword_table(

@@ -49,18 +51,17 @@ class Jieba(BaseKeyword):

         keyword_table = self._get_dataset_keyword_table()
         keywords_list = kwargs.get("keywords_list")
+        keyword_number = (
+            self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+        )
         for i in range(len(texts)):
             text = texts[i]
             if keywords_list:
                 keywords = keywords_list[i]
                 if not keywords:
-                    keywords = keyword_table_handler.extract_keywords(
-                        text.page_content, self._config.max_keywords_per_chunk
-                    )
+                    keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
             else:
-                keywords = keyword_table_handler.extract_keywords(
-                    text.page_content, self._config.max_keywords_per_chunk
-                )
+                keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
             if text.metadata is not None:
                 self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
             keyword_table = self._add_text_to_keyword_table(

@@ -239,7 +240,11 @@ class Jieba(BaseKeyword):
                     keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
                 )
             else:
-                keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
+                keyword_number = (
+                    self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+                )
+
+                keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
                 segment.keywords = list(keywords)
             keyword_table = self._add_text_to_keyword_table(
                 keyword_table or {}, segment.index_node_id, list(keywords)
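
For clarity (outside the patch), the fallback these Jieba hunks introduce, isolated as a tiny sketch with illustrative names:

def resolve_keyword_number(dataset_keyword_number, max_keywords_per_chunk):
    # a per-dataset keyword_number now wins; otherwise the processor-wide default applies
    return dataset_keyword_number if dataset_keyword_number else max_keywords_per_chunk

assert resolve_keyword_number(5, 10) == 5
assert resolve_keyword_number(None, 10) == 10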
api/core/rag/entities/event.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional

from pydantic import BaseModel, Field


class DatasourceStreamEvent(Enum):
    """
    Datasource Stream event
    """

    PROCESSING = "datasource_processing"
    COMPLETED = "datasource_completed"
    ERROR = "datasource_error"


class BaseDatasourceEvent(BaseModel):
    pass


class DatasourceErrorEvent(BaseDatasourceEvent):
    event: str = DatasourceStreamEvent.ERROR.value
    error: str = Field(..., description="error message")


class DatasourceCompletedEvent(BaseDatasourceEvent):
    event: str = DatasourceStreamEvent.COMPLETED.value
    data: Mapping[str, Any] | list = Field(..., description="result")
    total: Optional[int] = Field(default=0, description="total")
    completed: Optional[int] = Field(default=0, description="completed")
    time_consuming: Optional[float] = Field(default=0.0, description="time consuming")


class DatasourceProcessingEvent(BaseDatasourceEvent):
    event: str = DatasourceStreamEvent.PROCESSING.value
    total: Optional[int] = Field(..., description="total")
    completed: Optional[int] = Field(..., description="completed")
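
A minimal usage sketch of the new event models (values are illustrative; since these are pydantic models, model_dump() yields the serialized shape):

processing = DatasourceProcessingEvent(total=12, completed=3)
completed = DatasourceCompletedEvent(data=[{"content": "chunk text"}], total=12, completed=12, time_consuming=1.5)
error = DatasourceErrorEvent(error="credentials expired")

for event in (processing, completed, error):
    # e.g. {"event": "datasource_processing", "total": 12, "completed": 3}
    print(event.model_dump())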
@@ -13,3 +13,5 @@ class MetadataDataSource(Enum):
     upload_file = "file_upload"
     website_crawl = "website"
     notion_import = "notion"
+    local_file = "file_upload"
+    online_document = "online_document"
@@ -1,7 +1,8 @@
 """Abstract interface for document loader implementations."""

 from abc import ABC, abstractmethod
-from typing import Optional
+from collections.abc import Mapping
+from typing import Any, Optional

 from configs import dify_config
 from core.model_manager import ModelInstance

@@ -13,6 +14,7 @@ from core.rag.splitter.fixed_text_splitter import (
 )
 from core.rag.splitter.text_splitter import TextSplitter
 from models.dataset import Dataset, DatasetProcessRule
+from models.dataset import Document as DatasetDocument


 class BaseIndexProcessor(ABC):

@@ -33,6 +35,14 @@ class BaseIndexProcessor(ABC):
     def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
         raise NotImplementedError

+    @abstractmethod
+    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
+        raise NotImplementedError
+
+    @abstractmethod
+    def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
+        raise NotImplementedError
+
     @abstractmethod
     def retrieve(
         self,
@@ -1,19 +1,22 @@
 """Paragraph index processor."""

 import uuid
-from typing import Optional
+from collections.abc import Mapping
+from typing import Any, Optional

 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
-from core.rag.models.document import Document
+from core.rag.models.document import Document, GeneralStructureChunk
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset, DatasetProcessRule
+from models.dataset import Document as DatasetDocument
 from services.entities.knowledge_entities.knowledge_entities import Rule


@@ -127,3 +130,34 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
             doc = Document(page_content=result.page_content, metadata=metadata)
             docs.append(doc)
         return docs
+
+    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
+        paragraph = GeneralStructureChunk(**chunks)
+        documents = []
+        for content in paragraph.general_chunks:
+            metadata = {
+                "dataset_id": dataset.id,
+                "document_id": document.id,
+                "doc_id": str(uuid.uuid4()),
+                "doc_hash": helper.generate_text_hash(content),
+            }
+            doc = Document(page_content=content, metadata=metadata)
+            documents.append(doc)
+        if documents:
+            # save node to document segment
+            doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
+            # add document segments
+            doc_store.add_documents(docs=documents, save_child=False)
+            if dataset.indexing_technique == "high_quality":
+                vector = Vector(dataset)
+                vector.create(documents)
+            elif dataset.indexing_technique == "economy":
+                keyword = Keyword(dataset)
+                keyword.add_texts(documents)
+
+    def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
+        paragraph = GeneralStructureChunk(**chunks)
+        preview = []
+        for content in paragraph.general_chunks:
+            preview.append({"content": content})
+        return {"preview": preview, "total_segments": len(paragraph.general_chunks)}
@@ -1,20 +1,23 @@
 """Paragraph index processor."""

 import uuid
-from typing import Optional
+from collections.abc import Mapping
+from typing import Any, Optional

 from configs import dify_config
 from core.model_manager import ModelInstance
 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
-from core.rag.models.document import ChildDocument, Document
+from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import ChildChunk, Dataset, DocumentSegment
+from models.dataset import Document as DatasetDocument
 from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule


@@ -202,3 +205,40 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
             child_document.page_content = child_page_content
             child_nodes.append(child_document)
         return child_nodes
+
+    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
+        parent_childs = ParentChildStructureChunk(**chunks)
+        documents = []
+        for parent_child in parent_childs.parent_child_chunks:
+            metadata = {
+                "dataset_id": dataset.id,
+                "document_id": document.id,
+                "doc_id": str(uuid.uuid4()),
+                "doc_hash": helper.generate_text_hash(parent_child.parent_content),
+            }
+            child_documents = []
+            for child in parent_child.child_contents:
+                child_metadata = {
+                    "dataset_id": dataset.id,
+                    "document_id": document.id,
+                    "doc_id": str(uuid.uuid4()),
+                    "doc_hash": helper.generate_text_hash(child),
+                }
+                child_documents.append(ChildDocument(page_content=child, metadata=child_metadata))
+            doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents)
+            documents.append(doc)
+        if documents:
+            # save node to document segment
+            doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
+            # add document segments
+            doc_store.add_documents(docs=documents, save_child=True)
+            if dataset.indexing_technique == "high_quality":
+                vector = Vector(dataset)
+                vector.create(documents)
+
+    def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
+        parent_childs = ParentChildStructureChunk(**chunks)
+        preview = []
+        for parent_child in parent_childs.parent_child_chunks:
+            preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
+        return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)}
@@ -4,7 +4,8 @@ import logging
 import re
 import threading
 import uuid
-from typing import Optional
+from collections.abc import Mapping
+from typing import Any, Optional

 import pandas as pd
 from flask import Flask, current_app

@@ -14,13 +15,15 @@ from core.llm_generator.llm_generator import LLMGenerator
 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.extractor.extract_processor import ExtractProcessor
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
-from core.rag.models.document import Document
+from core.rag.models.document import Document, QAStructureChunk
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset
+from models.dataset import Document as DatasetDocument
 from services.entities.knowledge_entities.knowledge_entities import Rule


@@ -161,6 +164,36 @@ class QAIndexProcessor(BaseIndexProcessor):
             docs.append(doc)
         return docs

+    def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
+        qa_chunks = QAStructureChunk(**chunks)
+        documents = []
+        for qa_chunk in qa_chunks.qa_chunks:
+            metadata = {
+                "dataset_id": dataset.id,
+                "document_id": document.id,
+                "doc_id": str(uuid.uuid4()),
+                "doc_hash": helper.generate_text_hash(qa_chunk.question),
+                "answer": qa_chunk.answer,
+            }
+            doc = Document(page_content=qa_chunk.question, metadata=metadata)
+            documents.append(doc)
+        if documents:
+            # save node to document segment
+            doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
+            doc_store.add_documents(docs=documents, save_child=False)
+            if dataset.indexing_technique == "high_quality":
+                vector = Vector(dataset)
+                vector.create(documents)
+            else:
+                raise ValueError("Indexing technique must be high quality.")
+
+    def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
+        qa_chunks = QAStructureChunk(**chunks)
+        preview = []
+        for qa_chunk in qa_chunks.qa_chunks:
+            preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
+        return {"qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks)}
+
     def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
         format_documents = []
         if document_node.page_content is None or not document_node.page_content.strip():
@@ -35,6 +35,48 @@ class Document(BaseModel):
     children: Optional[list[ChildDocument]] = None


+class GeneralStructureChunk(BaseModel):
+    """
+    General Structure Chunk.
+    """
+
+    general_chunks: list[str]
+
+
+class ParentChildChunk(BaseModel):
+    """
+    Parent Child Chunk.
+    """
+
+    parent_content: str
+    child_contents: list[str]
+
+
+class ParentChildStructureChunk(BaseModel):
+    """
+    Parent Child Structure Chunk.
+    """
+
+    parent_child_chunks: list[ParentChildChunk]
+
+
+class QAChunk(BaseModel):
+    """
+    QA Chunk.
+    """
+
+    question: str
+    answer: str
+
+
+class QAStructureChunk(BaseModel):
+    """
+    QAStructureChunk.
+    """
+
+    qa_chunks: list[QAChunk]
+
+
 class BaseDocumentTransformer(ABC):
     """Abstract base class for document transformation systems.
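
For reference (not part of the patch), a sketch of the chunk payloads the new index()/format_preview() implementations consume, expressed with the models just added; all values are illustrative:

general = GeneralStructureChunk(general_chunks=["First paragraph.", "Second paragraph."])
parent_child = ParentChildStructureChunk(
    parent_child_chunks=[ParentChildChunk(parent_content="Parent passage.", child_contents=["child a", "child b"])]
)
qa = QAStructureChunk(qa_chunks=[QAChunk(question="What is Dify?", answer="An LLM app development platform.")])

# The processors receive these as plain mappings, e.g.
# paragraph_processor.index(dataset, document, chunks=general.model_dump())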
@@ -5,6 +5,7 @@ class RetrievalMethod(Enum):
     SEMANTIC_SEARCH = "semantic_search"
     FULL_TEXT_SEARCH = "full_text_search"
     HYBRID_SEARCH = "hybrid_search"
+    KEYWORD_SEARCH = "keyword_search"

     @staticmethod
     def is_support_semantic_search(retrieval_method: str) -> bool:
@@ -262,6 +262,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
         self,
         workflow_run_id: str,
         order_config: Optional[OrderConfig] = None,
+        triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
     ) -> Sequence[WorkflowNodeExecutionModel]:
         """
         Retrieve all WorkflowNodeExecution database models for a specific workflow run.

@@ -283,7 +284,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
         stmt = select(WorkflowNodeExecutionModel).where(
             WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
             WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
-            WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+            WorkflowNodeExecutionModel.triggered_from == triggered_from,
         )

         if self._app_id:

@@ -317,6 +318,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
         self,
         workflow_run_id: str,
         order_config: Optional[OrderConfig] = None,
+        triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
     ) -> Sequence[WorkflowNodeExecution]:
         """
         Retrieve all NodeExecution instances for a specific workflow run.

@@ -334,7 +336,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
            A list of NodeExecution instances
         """
         # Get the database models using the new method
-        db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config)
+        db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from)

         # Convert database models to domain models
         domain_models = []
@@ -1,8 +1,8 @@
 from collections.abc import Sequence
-from typing import cast
+from typing import Any, cast
 from uuid import uuid4

-from pydantic import Field
+from pydantic import BaseModel, Field

 from core.helper import encrypter

@@ -93,3 +93,32 @@ class FileVariable(FileSegment, Variable):

 class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
     pass
+
+
+class RAGPipelineVariable(BaseModel):
+    belong_to_node_id: str = Field(description="belong to which node id, shared means public")
+    type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
+    label: str = Field(description="label")
+    description: str | None = Field(description="description", default="")
+    variable: str = Field(description="variable key", default="")
+    max_length: int | None = Field(
+        description="max length, applicable to text-input, paragraph, and file-list", default=0
+    )
+    default_value: Any = Field(description="default value", default="")
+    placeholder: str | None = Field(description="placeholder", default="")
+    unit: str | None = Field(description="unit, applicable to Number", default="")
+    tooltips: str | None = Field(description="helpful text", default="")
+    allowed_file_types: list[str] | None = Field(
+        description="image, document, audio, video, custom.", default_factory=list
+    )
+    allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
+    allowed_file_upload_methods: list[str] | None = Field(
+        description="remote_url, local_file, tool_file.", default_factory=list
+    )
+    required: bool = Field(description="optional, default false", default=False)
+    options: list[str] | None = Field(default_factory=list)
+
+
+class RAGPipelineVariableInput(BaseModel):
+    variable: RAGPipelineVariable
+    value: Any
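
Illustrative construction of the new variable models (all values hypothetical; unset fields fall back to the defaults declared above):

workspace_var = RAGPipelineVariable(
    belong_to_node_id="shared",
    type="text-input",
    label="Workspace ID",
    variable="workspace_id",
    required=True,
)
user_input = RAGPipelineVariableInput(variable=workspace_var, value="ws-123")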
@@ -1,3 +1,4 @@
 SYSTEM_VARIABLE_NODE_ID = "sys"
 ENVIRONMENT_VARIABLE_NODE_ID = "env"
 CONVERSATION_VARIABLE_NODE_ID = "conversation"
+RAG_PIPELINE_VARIABLE_NODE_ID = "rag"
@@ -9,7 +9,13 @@ from core.file import File, FileAttribute, file_manager
 from core.variables import Segment, SegmentGroup, Variable
 from core.variables.consts import MIN_SELECTORS_LENGTH
 from core.variables.segments import FileSegment, NoneSegment
-from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
+from core.variables.variables import RAGPipelineVariableInput
+from core.workflow.constants import (
+    CONVERSATION_VARIABLE_NODE_ID,
+    ENVIRONMENT_VARIABLE_NODE_ID,
+    RAG_PIPELINE_VARIABLE_NODE_ID,
+    SYSTEM_VARIABLE_NODE_ID,
+)
 from core.workflow.enums import SystemVariableKey
 from factories import variable_factory

@@ -44,6 +50,10 @@ class VariablePool(BaseModel):
         description="Conversation variables.",
         default_factory=list,
     )
+    rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
+        description="RAG pipeline variables.",
+        default_factory=list,
+    )

     def model_post_init(self, context: Any, /) -> None:
         for key, value in self.system_variables.items():

@@ -54,6 +64,9 @@ class VariablePool(BaseModel):
         # Add conversation variables to the variable pool
         for var in self.conversation_variables:
             self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
+        # Add rag pipeline variables to the variable pool
+        for var in self.rag_pipeline_variables:
+            self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var.variable.belong_to_node_id, var.variable.variable), var.value)

     def add(self, selector: Sequence[str], value: Any, /) -> None:
         """
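
As a reading aid (outside the patch): with the hunk above, each RAG pipeline input registered in the pool becomes addressable under the "rag" namespace, so a downstream node can resolve it with a selector of the form sketched below; variable_pool is an assumed existing instance.

selector = (RAG_PIPELINE_VARIABLE_NODE_ID, "shared", "workspace_id")  # ("rag", belong_to_node_id, variable key)
segment = variable_pool.get(selector)  # wraps the submitted value, e.g. "ws-123"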
@@ -20,6 +20,7 @@ class WorkflowType(StrEnum):

     WORKFLOW = "workflow"
     CHAT = "chat"
+    RAG_PIPELINE = "rag-pipeline"


 class WorkflowExecutionStatus(StrEnum):
@@ -28,6 +28,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
     AGENT_LOG = "agent_log"
     ITERATION_ID = "iteration_id"
     ITERATION_INDEX = "iteration_index"
+    DATASOURCE_INFO = "datasource_info"
     LOOP_ID = "loop_id"
     LOOP_INDEX = "loop_index"
     PARALLEL_ID = "parallel_id"
@@ -14,3 +14,10 @@ class SystemVariableKey(StrEnum):
     APP_ID = "app_id"
     WORKFLOW_ID = "workflow_id"
     WORKFLOW_EXECUTION_ID = "workflow_run_id"
+    # RAG Pipeline
+    DOCUMENT_ID = "document_id"
+    BATCH = "batch"
+    DATASET_ID = "dataset_id"
+    DATASOURCE_TYPE = "datasource_type"
+    DATASOURCE_INFO = "datasource_info"
+    INVOKE_FROM = "invoke_from"
@@ -121,6 +121,7 @@ class Graph(BaseModel):
         # fetch nodes that have no predecessor node
         root_node_configs = []
         all_node_id_config_mapping: dict[str, dict] = {}
+
         for node_config in node_configs:
             node_id = node_config.get("id")
             if not node_id:

@@ -141,6 +142,7 @@ class Graph(BaseModel):
                 node_config.get("id")
                 for node_config in root_node_configs
                 if node_config.get("data", {}).get("type", "") == NodeType.START.value
+                or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
             ),
             None,
         )
@@ -175,7 +175,7 @@ class GraphEngine:
                     )
                     return
                 elif isinstance(item, NodeRunSucceededEvent):
-                    if item.node_type == NodeType.END:
+                    if item.node_type in (NodeType.END, NodeType.KNOWLEDGE_INDEX):
                         self.graph_runtime_state.outputs = (
                             dict(item.route_node_state.node_run_result.outputs)
                             if item.route_node_state.node_run_result

@@ -320,10 +320,10 @@ class GraphEngine:
                 raise e

             # It may not be necessary, but it is necessary. :)
-            if (
-                self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
-                == NodeType.END.value
-            ):
+            if self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() in [
+                NodeType.END.value,
+                NodeType.KNOWLEDGE_INDEX.value,
+            ]:
                 break

             previous_route_node_state = route_node_state
api/core/workflow/nodes/datasource/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .datasource_node import DatasourceNode

__all__ = ["DatasourceNode"]
api/core/workflow/nodes/datasource/datasource_node.py (new file, 468 lines)
@@ -0,0 +1,468 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.datasource.entities.datasource_entities import (
    DatasourceMessage,
    DatasourceParameter,
    DatasourceProviderType,
    GetOnlineDocumentPageContentRequest,
    OnlineDriveDownloadFileRequest,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.file import File
from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile
from services.datasource_provider_service import DatasourceProviderService

from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError


class DatasourceNode(BaseNode[DatasourceNodeData]):
    """
    Datasource Node
    """

    _node_data_cls = DatasourceNodeData
    _node_type = NodeType.DATASOURCE

    def _run(self) -> Generator:
        """
        Run the datasource node
        """

        node_data = cast(DatasourceNodeData, self.node_data)
        variable_pool = self.graph_runtime_state.variable_pool
        datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
        if not datasource_type:
            raise DatasourceNodeError("Datasource type is not set")
        datasource_type = datasource_type.value
        datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
        if not datasource_info:
            raise DatasourceNodeError("Datasource info is not set")
        datasource_info = datasource_info.value
        # get datasource runtime
        try:
            from core.datasource.datasource_manager import DatasourceManager

            if datasource_type is None:
                raise DatasourceNodeError("Datasource type is not set")

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
                datasource_name=node_data.datasource_name or "",
                tenant_id=self.tenant_id,
                datasource_type=DatasourceProviderType.value_of(datasource_type),
            )
        except DatasourceNodeError as e:
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs={},
                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to get datasource runtime: {str(e)}",
                    error_type=type(e).__name__,
                )
            )

        # get parameters
        datasource_parameters = datasource_runtime.entity.parameters
        parameters = self._generate_parameters(
            datasource_parameters=datasource_parameters,
            variable_pool=variable_pool,
            node_data=self.node_data,
        )
        parameters_for_log = self._generate_parameters(
            datasource_parameters=datasource_parameters,
            variable_pool=variable_pool,
            node_data=self.node_data,
            for_log=True,
        )

        try:
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    datasource_provider_service = DatasourceProviderService()
                    credentials = datasource_provider_service.get_real_datasource_credentials(
                        tenant_id=self.tenant_id,
                        provider=node_data.provider_name,
                        plugin_id=node_data.plugin_id,
                    )
                    if credentials:
                        datasource_runtime.runtime.credentials = credentials[0].get("credentials")
                    online_document_result: Generator[DatasourceMessage, None, None] = (
                        datasource_runtime.get_online_document_page_content(
                            user_id=self.user_id,
                            datasource_parameters=GetOnlineDocumentPageContentRequest(
                                workspace_id=datasource_info.get("workspace_id"),
                                page_id=datasource_info.get("page").get("page_id"),
                                type=datasource_info.get("page").get("type"),
                            ),
                            provider_type=datasource_type,
                        )
                    )
                    yield from self._transform_message(
                        messages=online_document_result,
                        parameters_for_log=parameters_for_log,
                        datasource_info=datasource_info,
                    )
                case DatasourceProviderType.ONLINE_DRIVE:
                    datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
                    datasource_provider_service = DatasourceProviderService()
                    credentials = datasource_provider_service.get_real_datasource_credentials(
                        tenant_id=self.tenant_id,
                        provider=node_data.provider_name,
                        plugin_id=node_data.plugin_id,
                    )
                    if credentials:
                        datasource_runtime.runtime.credentials = credentials[0].get("credentials")
                    online_drive_result: Generator[DatasourceMessage, None, None] = (
                        datasource_runtime.online_drive_download_file(
                            user_id=self.user_id,
                            request=OnlineDriveDownloadFileRequest(
                                key=datasource_info.get("key"),
                                bucket=datasource_info.get("bucket"),
                            ),
                            provider_type=datasource_type,
                        )
                    )
                    yield from self._transform_message(
                        messages=online_drive_result,
                        parameters_for_log=parameters_for_log,
                        datasource_info=datasource_info,
                    )
                case DatasourceProviderType.WEBSITE_CRAWL:
                    yield RunCompletedEvent(
                        run_result=NodeRunResult(
                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
                            inputs=parameters_for_log,
                            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                            outputs={
                                **datasource_info,
                                "datasource_type": datasource_type,
                            },
                        )
                    )
                case DatasourceProviderType.LOCAL_FILE:
                    related_id = datasource_info.get("related_id")
                    if not related_id:
                        raise DatasourceNodeError("File is not exist")
                    upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
                    if not upload_file:
                        raise ValueError("Invalid upload file Info")

                    file_info = File(
                        id=upload_file.id,
                        filename=upload_file.name,
                        extension="." + upload_file.extension,
                        mime_type=upload_file.mime_type,
                        tenant_id=self.tenant_id,
                        type=FileType.CUSTOM,
                        transfer_method=FileTransferMethod.LOCAL_FILE,
                        remote_url=upload_file.source_url,
                        related_id=upload_file.id,
                        size=upload_file.size,
                        storage_key=upload_file.key,
                    )
                    variable_pool.add([self.node_id, "file"], [file_info])
                    for key, value in datasource_info.items():
                        # construct new key list
                        new_key_list = ["file", key]
                        self._append_variables_recursively(
                            variable_pool=variable_pool,
                            node_id=self.node_id,
                            variable_key_list=new_key_list,
                            variable_value=value,
                        )
                    yield RunCompletedEvent(
                        run_result=NodeRunResult(
                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
                            inputs=parameters_for_log,
                            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                            outputs={
                                "file_info": datasource_info,
                                "datasource_type": datasource_type,
                            },
                        )
                    )
                case _:
                    raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
        except PluginDaemonClientSideError as e:
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to transform datasource message: {str(e)}",
                    error_type=type(e).__name__,
                )
            )
        except DatasourceNodeError as e:
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to invoke datasource: {str(e)}",
                    error_type=type(e).__name__,
                )
            )

    def _generate_parameters(
        self,
        *,
        datasource_parameters: Sequence[DatasourceParameter],
        variable_pool: VariablePool,
        node_data: DatasourceNodeData,
        for_log: bool = False,
    ) -> dict[str, Any]:
        """
        Generate parameters based on the given tool parameters, variable pool, and node data.

        Args:
            tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
            variable_pool (VariablePool): The variable pool containing the variables.
            node_data (ToolNodeData): The data associated with the tool node.

        Returns:
            Mapping[str, Any]: A dictionary containing the generated parameters.

        """
        datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}

        result: dict[str, Any] = {}
        if node_data.datasource_parameters:
            for parameter_name in node_data.datasource_parameters:
                parameter = datasource_parameters_dictionary.get(parameter_name)
                if not parameter:
                    result[parameter_name] = None
                    continue
                datasource_input = node_data.datasource_parameters[parameter_name]
                if datasource_input.type == "variable":
                    variable = variable_pool.get(datasource_input.value)
                    if variable is None:
                        raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
                    parameter_value = variable.value
                elif datasource_input.type in {"mixed", "constant"}:
                    segment_group = variable_pool.convert_template(str(datasource_input.value))
                    parameter_value = segment_group.log if for_log else segment_group.text
                else:
                    raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
                result[parameter_name] = parameter_value

        return result

    def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
        variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
        assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
        return list(variable.value) if variable else []

    def _append_variables_recursively(
        self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
    ):
        """
        Append variables recursively
        :param node_id: node id
        :param variable_key_list: variable key list
        :param variable_value: variable value
        :return:
        """
        variable_pool.add([node_id] + variable_key_list, variable_value)

        # if variable_value is a dict, then recursively append variables
        if isinstance(variable_value, dict):
            for key, value in variable_value.items():
                # construct new key list
                new_key_list = variable_key_list + [key]
                self._append_variables_recursively(
                    variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
                )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: DatasourceNodeData,
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return:
        """
        result = {}
        if node_data.datasource_parameters:
            for parameter_name in node_data.datasource_parameters:
                input = node_data.datasource_parameters[parameter_name]
                if input.type == "mixed":
                    assert isinstance(input.value, str)
                    selectors = VariableTemplateParser(input.value).extract_variable_selectors()
                    for selector in selectors:
                        result[selector.variable] = selector.value_selector
                elif input.type == "variable":
                    result[parameter_name] = input.value
                elif input.type == "constant":
                    pass

        result = {node_id + "." + key: value for key, value in result.items()}

        return result

    def _transform_message(
        self,
        messages: Generator[DatasourceMessage, None, None],
        parameters_for_log: dict[str, Any],
        datasource_info: dict[str, Any],
    ) -> Generator:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]
        """
        # transform message and handle file storage
        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
        )

        text = ""
        files: list[File] = []
        json: list[dict] = []

        variables: dict[str, Any] = {}

        for message in message_stream:
            if message.type in {
                DatasourceMessage.MessageType.IMAGE_LINK,
                DatasourceMessage.MessageType.BINARY_LINK,
                DatasourceMessage.MessageType.IMAGE,
            }:
                assert isinstance(message.message, DatasourceMessage.TextMessage)

                url = message.message.text
                if message.meta:
                    transfer_method = message.meta.get("transfer_method", FileTransferMethod.DATASOURCE_FILE)
                else:
                    transfer_method = FileTransferMethod.DATASOURCE_FILE

                datasource_file_id = str(url).split("/")[-1].split(".")[0]

                with Session(db.engine) as session:
                    stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
                    datasource_file = session.scalar(stmt)
                    if datasource_file is None:
                        raise ToolFileError(f"Tool file {datasource_file_id} does not exist")

                mapping = {
                    "datasource_file_id": datasource_file_id,
                    "type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type),
                    "transfer_method": transfer_method,
                    "url": url,
                }
                file = file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=self.tenant_id,
                )
                files.append(file)
            elif message.type == DatasourceMessage.MessageType.BLOB:
                # get tool file id
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                assert message.meta

                datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
                with Session(db.engine) as session:
                    stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
                    datasource_file = session.scalar(stmt)
                    if datasource_file is None:
                        raise ToolFileError(f"datasource file {datasource_file_id} not exists")

                mapping = {
                    "datasource_file_id": datasource_file_id,
                    "transfer_method": FileTransferMethod.DATASOURCE_FILE,
                }

                files.append(
                    file_factory.build_from_mapping(
                        mapping=mapping,
                        tenant_id=self.tenant_id,
                    )
                )
            elif message.type == DatasourceMessage.MessageType.TEXT:
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                text += message.message.text
                yield RunStreamChunkEvent(
                    chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
                )
            elif message.type == DatasourceMessage.MessageType.JSON:
                assert isinstance(message.message, DatasourceMessage.JsonMessage)
                if self.node_type == NodeType.AGENT:
                    msg_metadata = message.message.json_object.pop("execution_metadata", {})
                    agent_execution_metadata = {
                        key: value
                        for key, value in msg_metadata.items()
                        if key in WorkflowNodeExecutionMetadataKey.__members__.values()
                    }
                json.append(message.message.json_object)
            elif message.type == DatasourceMessage.MessageType.LINK:
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                stream_text = f"Link: {message.message.text}\n"
                text += stream_text
                yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
            elif message.type == DatasourceMessage.MessageType.VARIABLE:
                assert isinstance(message.message, DatasourceMessage.VariableMessage)
                variable_name = message.message.variable_name
                variable_value = message.message.variable_value
                if message.message.stream:
                    if not isinstance(variable_value, str):
                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
                    if variable_name not in variables:
                        variables[variable_name] = ""
                    variables[variable_name] += variable_value

                    yield RunStreamChunkEvent(
                        chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
                    )
                else:
                    variables[variable_name] = variable_value
            elif message.type == DatasourceMessage.MessageType.FILE:
                assert message.meta is not None
                files.append(message.meta["file"])

        yield RunCompletedEvent(
            run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={"json": json, "files": files, **variables, "text": text},
                metadata={
                    WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
                },
                inputs=parameters_for_log,
            )
        )

    @classmethod
    def version(cls) -> str:
        return "1"
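
As a reading aid (outside the patch), the selector mapping _extract_variable_selector_to_variable_mapping above would produce for a hypothetical "variable"-typed parameter on a node with id "datasource_1":

# node_data.datasource_parameters = {"page_id": DatasourceInput(type="variable", value=["start", "page_id"])}
# result before prefixing: {"page_id": ["start", "page_id"]}
# returned mapping:        {"datasource_1.page_id": ["start", "page_id"]}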
41  api/core/workflow/nodes/datasource/entities.py  Normal file
@@ -0,0 +1,41 @@
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo

from core.workflow.nodes.base.entities import BaseNodeData


class DatasourceEntity(BaseModel):
    plugin_id: str
    provider_name: str  # redundancy
    provider_type: str
    datasource_name: Optional[str] = "local_file"
    datasource_configurations: dict[str, Any] | None = None
    plugin_unique_identifier: str | None = None  # redundancy


class DatasourceNodeData(BaseNodeData, DatasourceEntity):
    class DatasourceInput(BaseModel):
        # TODO: check this type
        value: Union[Any, list[str]]
        type: Optional[Literal["mixed", "variable", "constant"]] = None

        @field_validator("type", mode="before")
        @classmethod
        def check_type(cls, value, validation_info: ValidationInfo):
            typ = value
            value = validation_info.data.get("value")
            if typ == "mixed" and not isinstance(value, str):
                raise ValueError("value must be a string")
            elif typ == "variable":
                if not isinstance(value, list):
                    raise ValueError("value must be a list")
                for val in value:
                    if not isinstance(val, str):
                        raise ValueError("value must be a list of strings")
            elif typ == "constant" and not isinstance(value, str | int | float | bool):
                raise ValueError("value must be a string, int, float, or bool")
            return typ

    datasource_parameters: dict[str, DatasourceInput] | None = None
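
For orientation (readers' note, not part of the diff): a minimal sketch of how the DatasourceInput validator above behaves. It assumes the repository's api/ directory is on the import path; the plugin id and selector values are made up.

from core.workflow.nodes.datasource.entities import DatasourceNodeData

node_data = DatasourceNodeData(
    title="import pages",  # BaseNodeData field (assumed required here)
    plugin_id="langgenius/notion-datasource",  # hypothetical plugin id
    provider_name="notion",
    provider_type="online_document",
    datasource_parameters={
        # type="variable" requires a list of selector strings
        "page_id": DatasourceNodeData.DatasourceInput(value=["start", "page_id"], type="variable"),
        # type="mixed" requires a plain string
        "query": DatasourceNodeData.DatasourceInput(value="release notes", type="mixed"),
    },
)
# Passing e.g. value=123 with type="mixed" would make check_type raise
# "value must be a string" during validation.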
16  api/core/workflow/nodes/datasource/exc.py  Normal file
@@ -0,0 +1,16 @@
class DatasourceNodeError(ValueError):
    """Base exception for datasource node errors."""

    pass


class DatasourceParameterError(DatasourceNodeError):
    """Exception raised for errors in datasource parameters."""

    pass


class DatasourceFileError(DatasourceNodeError):
    """Exception raised for errors related to datasource files."""

    pass
@@ -7,12 +7,14 @@ class NodeType(StrEnum):
     ANSWER = "answer"
     LLM = "llm"
     KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
+    KNOWLEDGE_INDEX = "knowledge-index"
     IF_ELSE = "if-else"
     CODE = "code"
     TEMPLATE_TRANSFORM = "template-transform"
     QUESTION_CLASSIFIER = "question-classifier"
     HTTP_REQUEST = "http-request"
     TOOL = "tool"
+    DATASOURCE = "datasource"
     VARIABLE_AGGREGATOR = "variable-aggregator"
     LEGACY_VARIABLE_AGGREGATOR = "variable-assigner"  # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
     LOOP = "loop"
3  api/core/workflow/nodes/knowledge_index/__init__.py  Normal file
@@ -0,0 +1,3 @@
from .knowledge_index_node import KnowledgeIndexNode

__all__ = ["KnowledgeIndexNode"]
159  api/core/workflow/nodes/knowledge_index/entities.py  Normal file
@@ -0,0 +1,159 @@
from typing import Literal, Optional, Union

from pydantic import BaseModel

from core.workflow.nodes.base import BaseNodeData


class RerankingModelConfig(BaseModel):
    """
    Reranking Model Config.
    """

    reranking_provider_name: str
    reranking_model_name: str


class VectorSetting(BaseModel):
    """
    Vector Setting.
    """

    vector_weight: float
    embedding_provider_name: str
    embedding_model_name: str


class KeywordSetting(BaseModel):
    """
    Keyword Setting.
    """

    keyword_weight: float


class WeightedScoreConfig(BaseModel):
    """
    Weighted score Config.
    """

    vector_setting: VectorSetting
    keyword_setting: KeywordSetting


class EmbeddingSetting(BaseModel):
    """
    Embedding Setting.
    """

    embedding_provider_name: str
    embedding_model_name: str


class EconomySetting(BaseModel):
    """
    Economy Setting.
    """

    keyword_number: int


class RetrievalSetting(BaseModel):
    """
    Retrieval Setting.
    """

    search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
    top_k: int
    score_threshold: Optional[float] = 0.5
    score_threshold_enabled: bool = False
    reranking_mode: str = "reranking_model"
    reranking_enable: bool = True
    reranking_model: Optional[RerankingModelConfig] = None
    weights: Optional[WeightedScoreConfig] = None


class IndexMethod(BaseModel):
    """
    Knowledge Index Setting.
    """

    indexing_technique: Literal["high_quality", "economy"]
    embedding_setting: EmbeddingSetting
    economy_setting: EconomySetting


class FileInfo(BaseModel):
    """
    File Info.
    """

    file_id: str


class OnlineDocumentIcon(BaseModel):
    """
    Document Icon.
    """

    icon_url: str
    icon_type: str
    icon_emoji: str


class OnlineDocumentInfo(BaseModel):
    """
    Online document info.
    """

    provider: str
    workspace_id: str
    page_id: str
    page_type: str
    icon: OnlineDocumentIcon


class WebsiteInfo(BaseModel):
    """
    website import info.
    """

    provider: str
    url: str


class GeneralStructureChunk(BaseModel):
    """
    General Structure Chunk.
    """

    general_chunks: list[str]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]


class ParentChildChunk(BaseModel):
    """
    Parent Child Chunk.
    """

    parent_content: str
    child_contents: list[str]


class ParentChildStructureChunk(BaseModel):
    """
    Parent Child Structure Chunk.
    """

    parent_child_chunks: list[ParentChildChunk]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]


class KnowledgeIndexNodeData(BaseNodeData):
    """
    Knowledge index Node Data.
    """

    type: str = "knowledge-index"
    chunk_structure: str
    index_chunk_variable_selector: list[str]
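
For orientation (readers' note, not part of the diff): a minimal sketch of how the chunk and retrieval models above compose, assuming the api/ directory is importable; the ids and text are made up.

from core.workflow.nodes.knowledge_index.entities import (
    FileInfo,
    GeneralStructureChunk,
    RetrievalSetting,
)

chunk = GeneralStructureChunk(
    general_chunks=["Dify supports RAG pipelines.", "Chunks are indexed per document."],
    data_source_info=FileInfo(file_id="hypothetical-upload-file-id"),
)

retrieval = RetrievalSetting(search_method="semantic_search", top_k=2)
# score_threshold, reranking_mode and the other fields fall back to the defaults declared above.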
22  api/core/workflow/nodes/knowledge_index/exc.py  Normal file
@@ -0,0 +1,22 @@
class KnowledgeIndexNodeError(ValueError):
    """Base class for KnowledgeIndexNode errors."""


class ModelNotExistError(KnowledgeIndexNodeError):
    """Raised when the model does not exist."""


class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
    """Raised when the model credentials are not initialized."""


class ModelNotSupportedError(KnowledgeIndexNodeError):
    """Raised when the model is not supported."""


class ModelQuotaExceededError(KnowledgeIndexNodeError):
    """Raised when the model provider quota is exceeded."""


class InvalidModelTypeError(KnowledgeIndexNodeError):
    """Raised when the model is not a Large Language Model."""
165  api/core/workflow/nodes/knowledge_index/knowledge_index_node.py  Normal file
@@ -0,0 +1,165 @@
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any, cast

from sqlalchemy import func

from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

from ..base import BaseNode
from .entities import KnowledgeIndexNodeData
from .exc import (
    KnowledgeIndexNodeError,
)

logger = logging.getLogger(__name__)

default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 2,
    "score_threshold_enabled": False,
}


class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
    _node_data_cls = KnowledgeIndexNodeData  # type: ignore
    _node_type = NodeType.KNOWLEDGE_INDEX

    def _run(self) -> NodeRunResult:  # type: ignore
        node_data = cast(KnowledgeIndexNodeData, self.node_data)
        variable_pool = self.graph_runtime_state.variable_pool
        dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
        if not dataset_id:
            raise KnowledgeIndexNodeError("Dataset ID is required.")
        dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
        if not dataset:
            raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")

        # extract variables
        variable = variable_pool.get(node_data.index_chunk_variable_selector)
        if not variable:
            raise KnowledgeIndexNodeError("Index chunk variable is required.")
        invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
        if invoke_from:
            is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value
        else:
            is_preview = False
        chunks = variable.value
        variables = {"chunks": chunks}
        if not chunks:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
            )

        # index knowledge
        try:
            if is_preview:
                outputs = self._get_preview_output(node_data.chunk_structure, chunks)
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs=variables,
                    process_data=None,
                    outputs=outputs,
                )
            results = self._invoke_knowledge_index(
                dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
            )
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
            )

        except KnowledgeIndexNodeError as e:
            logger.warning("Error when running knowledge index node")
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
            )
        # Temporarily handle all exceptions from DatasetRetrieval class here.
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
            )

    def _invoke_knowledge_index(
        self,
        dataset: Dataset,
        node_data: KnowledgeIndexNodeData,
        chunks: Mapping[str, Any],
        variable_pool: VariablePool,
    ) -> Any:
        document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
        if not document_id:
            raise KnowledgeIndexNodeError("Document ID is required.")
        batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
        if not batch:
            raise KnowledgeIndexNodeError("Batch is required.")
        document = db.session.query(Document).filter_by(id=document_id.value).first()
        if not document:
            raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
        index_processor.index(dataset, document, chunks)
        indexing_end_at = time.perf_counter()
        document.indexing_latency = indexing_end_at - indexing_start_at
        # update document status
        document.indexing_status = "completed"
        document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        document.word_count = (
            db.session.query(func.sum(DocumentSegment.word_count))
            .filter(
                DocumentSegment.document_id == document.id,
                DocumentSegment.dataset_id == dataset.id,
            )
            .scalar()
        )
        db.session.add(document)
        # update document segment status
        db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == document.id,
            DocumentSegment.dataset_id == dataset.id,
        ).update(
            {
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            }
        )

        db.session.commit()

        return {
            "dataset_id": dataset.id,
            "dataset_name": dataset.name,
            "batch": batch.value,
            "document_id": document.id,
            "document_name": document.name,
            "created_at": document.created_at.timestamp(),
            "display_status": document.indexing_status,
        }

    def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
        return index_processor.format_preview(chunks)

    @classmethod
    def version(cls) -> str:
        return "1"
@@ -57,10 +57,6 @@ class MultipleRetrievalConfig(BaseModel):


 class ModelConfig(BaseModel):
-    """
-    Model Config.
-    """
-
     provider: str
     name: str
     mode: str
@@ -4,12 +4,14 @@ from core.workflow.nodes.agent.agent_node import AgentNode
 from core.workflow.nodes.answer import AnswerNode
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.code import CodeNode
+from core.workflow.nodes.datasource.datasource_node import DatasourceNode
 from core.workflow.nodes.document_extractor import DocumentExtractorNode
 from core.workflow.nodes.end import EndNode
 from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.http_request import HttpRequestNode
 from core.workflow.nodes.if_else import IfElseNode
 from core.workflow.nodes.iteration import IterationNode, IterationStartNode
+from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
 from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
 from core.workflow.nodes.list_operator import ListOperatorNode
 from core.workflow.nodes.llm import LLMNode
@@ -124,4 +126,12 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
         LATEST_VERSION: AgentNode,
         "1": AgentNode,
     },
+    NodeType.DATASOURCE: {
+        LATEST_VERSION: DatasourceNode,
+        "1": DatasourceNode,
+    },
+    NodeType.KNOWLEDGE_INDEX: {
+        LATEST_VERSION: KnowledgeIndexNode,
+        "1": KnowledgeIndexNode,
+    },
 }
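
For orientation (readers' note, not part of the diff): with the two entries above registered, a node class can be looked up by type and version roughly as sketched below. resolve_node_class is a hypothetical helper; the module path and the LATEST_VERSION constant are assumed from their unqualified use in this hunk.

from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING

def resolve_node_class(node_type: NodeType, version: str | None = None):
    # Fall back to the latest registered version when the graph does not pin one.
    versions = NODE_TYPE_CLASSES_MAPPING[node_type]
    return versions.get(version or LATEST_VERSION, versions[LATEST_VERSION])

datasource_cls = resolve_node_class(NodeType.DATASOURCE)  # DatasourceNode
knowledge_index_cls = resolve_node_class(NodeType.KNOWLEDGE_INDEX, "1")  # KnowledgeIndexNode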
@@ -61,6 +61,7 @@ def build_from_mapping(
         FileTransferMethod.LOCAL_FILE: _build_from_local_file,
         FileTransferMethod.REMOTE_URL: _build_from_remote_url,
         FileTransferMethod.TOOL_FILE: _build_from_tool_file,
+        FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
     }

     build_func = build_functions.get(transfer_method)
@@ -305,6 +306,53 @@ def _build_from_tool_file(
     )


+def _build_from_datasource_file(
+    *,
+    mapping: Mapping[str, Any],
+    tenant_id: str,
+    transfer_method: FileTransferMethod,
+    strict_type_validation: bool = False,
+) -> File:
+    datasource_file = (
+        db.session.query(UploadFile)
+        .filter(
+            UploadFile.id == mapping.get("datasource_file_id"),
+            UploadFile.tenant_id == tenant_id,
+        )
+        .first()
+    )
+
+    if datasource_file is None:
+        raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
+
+    extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
+
+    detected_file_type = _standardize_file_type(extension=extension, mime_type=datasource_file.mime_type)
+
+    specified_type = mapping.get("type")
+
+    if strict_type_validation and specified_type and detected_file_type.value != specified_type:
+        raise ValueError("Detected file type does not match the specified type. Please verify the file.")
+
+    file_type = (
+        FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
+    )
+
+    return File(
+        id=mapping.get("id"),
+        tenant_id=tenant_id,
+        filename=datasource_file.name,
+        type=file_type,
+        transfer_method=transfer_method,
+        remote_url=datasource_file.source_url,
+        related_id=datasource_file.id,
+        extension=extension,
+        mime_type=datasource_file.mime_type,
+        size=datasource_file.size,
+        storage_key=datasource_file.key,
+    )
+
+
 def _is_file_valid_with_config(
     *,
     input_file_type: str,
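
For orientation (readers' note, not part of the diff): the mapping emitted by the datasource node earlier in this diff is what routes into the new branch above. A hedged sketch of the call; the import paths and the id values are assumptions, only build_from_mapping and FileTransferMethod.DATASOURCE_FILE come from the diff itself.

from core.file.enums import FileTransferMethod
from factories import file_factory

mapping = {
    "datasource_file_id": "00000000-0000-0000-0000-000000000000",  # hypothetical UploadFile id
    "transfer_method": FileTransferMethod.DATASOURCE_FILE,
}
# Dispatches through build_functions to _build_from_datasource_file and returns a File
# wrapping the stored upload (filename, mime type, size, storage key).
file = file_factory.build_from_mapping(mapping=mapping, tenant_id="tenant-id")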
@@ -36,7 +36,10 @@ from core.variables.variables import (
     StringVariable,
     Variable,
 )
-from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
+from core.workflow.constants import (
+    CONVERSATION_VARIABLE_NODE_ID,
+    ENVIRONMENT_VARIABLE_NODE_ID,
+)


 class UnsupportedSegmentTypeError(Exception):
@@ -75,6 +78,12 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va
     return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])


+def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+    if not mapping.get("variable"):
+        raise VariableError("missing variable")
+    return mapping["variable"]
+
+
 def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
     """
     This factory function is used to create the environment variable or the conversation variable,
@@ -56,6 +56,13 @@ external_knowledge_info_fields = {

 doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String}

+icon_info_fields = {
+    "icon_type": fields.String,
+    "icon": fields.String,
+    "icon_background": fields.String,
+    "icon_url": fields.String,
+}
+
 dataset_detail_fields = {
     "id": fields.String,
     "name": fields.String,
@@ -81,6 +88,13 @@ dataset_detail_fields = {
     "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
     "doc_metadata": fields.List(fields.Nested(doc_metadata_fields)),
     "built_in_field_enabled": fields.Boolean,
+    "pipeline_id": fields.String,
+    "runtime_mode": fields.String,
+    "chunk_structure": fields.String,
+    "icon_info": fields.Nested(icon_info_fields),
+    "is_published": fields.Boolean,
+    "total_documents": fields.Integer,
+    "total_available_documents": fields.Integer,
 }

 dataset_query_detail_fields = {
164  api/fields/rag_pipeline_fields.py  Normal file
@@ -0,0 +1,164 @@
from flask_restful import fields  # type: ignore

from fields.workflow_fields import workflow_partial_fields
from libs.helper import AppIconUrlField, TimestampField

pipeline_detail_kernel_fields = {
    "id": fields.String,
    "name": fields.String,
    "description": fields.String,
    "icon_type": fields.String,
    "icon": fields.String,
    "icon_background": fields.String,
    "icon_url": AppIconUrlField,
}

related_app_list = {
    "data": fields.List(fields.Nested(pipeline_detail_kernel_fields)),
    "total": fields.Integer,
}

app_detail_fields = {
    "id": fields.String,
    "name": fields.String,
    "description": fields.String,
    "mode": fields.String(attribute="mode_compatible_with_agent"),
    "icon": fields.String,
    "icon_background": fields.String,
    "workflow": fields.Nested(workflow_partial_fields, allow_null=True),
    "tracing": fields.Raw,
    "created_by": fields.String,
    "created_at": TimestampField,
    "updated_by": fields.String,
    "updated_at": TimestampField,
}


tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}

app_partial_fields = {
    "id": fields.String,
    "name": fields.String,
    "description": fields.String(attribute="desc_or_prompt"),
    "icon_type": fields.String,
    "icon": fields.String,
    "icon_background": fields.String,
    "icon_url": AppIconUrlField,
    "workflow": fields.Nested(workflow_partial_fields, allow_null=True),
    "created_by": fields.String,
    "created_at": TimestampField,
    "updated_by": fields.String,
    "updated_at": TimestampField,
    "tags": fields.List(fields.Nested(tag_fields)),
}


app_pagination_fields = {
    "page": fields.Integer,
    "limit": fields.Integer(attribute="per_page"),
    "total": fields.Integer,
    "has_more": fields.Boolean(attribute="has_next"),
    "data": fields.List(fields.Nested(app_partial_fields), attribute="items"),
}

template_fields = {
    "name": fields.String,
    "icon": fields.String,
    "icon_background": fields.String,
    "description": fields.String,
    "mode": fields.String,
}

template_list_fields = {
    "data": fields.List(fields.Nested(template_fields)),
}

site_fields = {
    "access_token": fields.String(attribute="code"),
    "code": fields.String,
    "title": fields.String,
    "icon_type": fields.String,
    "icon": fields.String,
    "icon_background": fields.String,
    "icon_url": AppIconUrlField,
    "description": fields.String,
    "default_language": fields.String,
    "chat_color_theme": fields.String,
    "chat_color_theme_inverted": fields.Boolean,
    "customize_domain": fields.String,
    "copyright": fields.String,
    "privacy_policy": fields.String,
    "custom_disclaimer": fields.String,
    "customize_token_strategy": fields.String,
    "prompt_public": fields.Boolean,
    "app_base_url": fields.String,
    "show_workflow_steps": fields.Boolean,
    "use_icon_as_answer_icon": fields.Boolean,
    "created_by": fields.String,
    "created_at": TimestampField,
    "updated_by": fields.String,
    "updated_at": TimestampField,
}

deleted_tool_fields = {
    "type": fields.String,
    "tool_name": fields.String,
    "provider_id": fields.String,
}

app_detail_fields_with_site = {
    "id": fields.String,
    "name": fields.String,
    "description": fields.String,
    "mode": fields.String(attribute="mode_compatible_with_agent"),
    "icon_type": fields.String,
    "icon": fields.String,
    "icon_background": fields.String,
    "icon_url": AppIconUrlField,
    "enable_site": fields.Boolean,
    "enable_api": fields.Boolean,
    "workflow": fields.Nested(workflow_partial_fields, allow_null=True),
    "site": fields.Nested(site_fields),
    "api_base_url": fields.String,
    "use_icon_as_answer_icon": fields.Boolean,
    "created_by": fields.String,
    "created_at": TimestampField,
    "updated_by": fields.String,
    "updated_at": TimestampField,
}


app_site_fields = {
    "app_id": fields.String,
    "access_token": fields.String(attribute="code"),
    "code": fields.String,
    "title": fields.String,
    "icon": fields.String,
    "icon_background": fields.String,
    "description": fields.String,
    "default_language": fields.String,
    "customize_domain": fields.String,
    "copyright": fields.String,
    "privacy_policy": fields.String,
    "custom_disclaimer": fields.String,
    "customize_token_strategy": fields.String,
    "prompt_public": fields.Boolean,
    "show_workflow_steps": fields.Boolean,
    "use_icon_as_answer_icon": fields.Boolean,
}

leaked_dependency_fields = {"type": fields.String, "value": fields.Raw, "current_identifier": fields.String}

pipeline_import_fields = {
    "id": fields.String,
    "status": fields.String,
    "pipeline_id": fields.String,
    "dataset_id": fields.String,
    "current_dsl_version": fields.String,
    "imported_dsl_version": fields.String,
    "error": fields.String,
}

pipeline_import_check_dependencies_fields = {
    "leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
}
@@ -40,6 +40,23 @@ conversation_variable_fields = {
     "description": fields.String,
 }

+pipeline_variable_fields = {
+    "label": fields.String,
+    "variable": fields.String,
+    "type": fields.String,
+    "belong_to_node_id": fields.String,
+    "max_length": fields.Integer,
+    "required": fields.Boolean,
+    "unit": fields.String,
+    "default_value": fields.Raw,
+    "options": fields.List(fields.String),
+    "placeholder": fields.String,
+    "tooltips": fields.String,
+    "allowed_file_types": fields.List(fields.String),
+    "allow_file_extension": fields.List(fields.String),
+    "allow_file_upload_methods": fields.List(fields.String),
+}
+
 workflow_fields = {
     "id": fields.String,
     "graph": fields.Raw(attribute="graph_dict"),
@@ -55,6 +72,7 @@ workflow_fields = {
     "tool_published": fields.Boolean,
     "environment_variables": fields.List(EnvironmentVariableField()),
     "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
+    "rag_pipeline_variables": fields.List(fields.Nested(pipeline_variable_fields)),
 }

 workflow_partial_fields = {
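
For orientation (readers' note, not part of the diff): pipeline_variable_fields is an ordinary flask_restful marshalling dict, so a stored rag-pipeline variable can be serialized like any other field set. The import path and the payload below are assumptions for illustration.

from flask_restful import marshal

from fields.workflow_fields import pipeline_variable_fields

variable = {"label": "Start URL", "variable": "start_url", "type": "text-input", "required": True}
# Keys not declared in pipeline_variable_fields are dropped; missing ones come back as None.
print(marshal(variable, pipeline_variable_fields))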
1  api/installed_plugins.jsonl  Normal file
@@ -0,0 +1 @@
{"not_installed": [], "plugin_install_failed": []}
@@ -0,0 +1,113 @@
"""add_pipeline_info

Revision ID: b35c3db83d09
Revises: d28f2004b072
Create Date: 2025-05-15 15:58:05.179877

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'b35c3db83d09'
down_revision = '0ab65e1cc7fa'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('pipeline_built_in_templates',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
    sa.Column('name', sa.String(length=255), nullable=False),
    sa.Column('description', sa.Text(), nullable=False),
    sa.Column('icon', sa.JSON(), nullable=False),
    sa.Column('copyright', sa.String(length=255), nullable=False),
    sa.Column('privacy_policy', sa.String(length=255), nullable=False),
    sa.Column('position', sa.Integer(), nullable=False),
    sa.Column('install_count', sa.Integer(), nullable=False),
    sa.Column('language', sa.String(length=255), nullable=False),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey')
    )
    op.create_table('pipeline_customized_templates',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
    sa.Column('name', sa.String(length=255), nullable=False),
    sa.Column('description', sa.Text(), nullable=False),
    sa.Column('icon', sa.JSON(), nullable=False),
    sa.Column('position', sa.Integer(), nullable=False),
    sa.Column('install_count', sa.Integer(), nullable=False),
    sa.Column('language', sa.String(length=255), nullable=False),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
    )
    with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
        batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False)

    op.create_table('pipelines',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('name', sa.String(length=255), nullable=False),
    sa.Column('description', sa.Text(), server_default=sa.text("''::character varying"), nullable=False),
    sa.Column('mode', sa.String(length=255), nullable=False),
    sa.Column('workflow_id', models.types.StringUUID(), nullable=True),
    sa.Column('is_public', sa.Boolean(), server_default=sa.text('false'), nullable=False),
    sa.Column('is_published', sa.Boolean(), server_default=sa.text('false'), nullable=False),
    sa.Column('created_by', models.types.StringUUID(), nullable=True),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.Column('updated_by', models.types.StringUUID(), nullable=True),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='pipeline_pkey')
    )
    op.create_table('tool_builtin_datasource_providers',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
    sa.Column('user_id', models.types.StringUUID(), nullable=False),
    sa.Column('provider', sa.String(length=256), nullable=False),
    sa.Column('encrypted_credentials', sa.Text(), nullable=True),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='tool_builtin_datasource_provider_pkey'),
    sa.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_datasource_provider')
    )

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True))
        batch_op.add_column(sa.Column('icon_info', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
        batch_op.add_column(sa.Column('runtime_mode', sa.String(length=255), server_default=sa.text("'general'::character varying"), nullable=True))
        batch_op.add_column(sa.Column('pipeline_id', models.types.StringUUID(), nullable=True))
        batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=True))

    with op.batch_alter_table('workflows', schema=None) as batch_op:
        batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('workflows', schema=None) as batch_op:
        batch_op.drop_column('rag_pipeline_variables')

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_column('chunk_structure')
        batch_op.drop_column('pipeline_id')
        batch_op.drop_column('runtime_mode')
        batch_op.drop_column('icon_info')
        batch_op.drop_column('keyword_number')

    op.drop_table('tool_builtin_datasource_providers')
    op.drop_table('pipelines')
    with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
        batch_op.drop_index('pipeline_customized_template_tenant_idx')

    op.drop_table('pipeline_customized_templates')
    op.drop_table('pipeline_built_in_templates')
    # ### end Alembic commands ###
@@ -0,0 +1,33 @@
"""add_pipeline_info_2

Revision ID: abb18a379e62
Revises: b35c3db83d09
Create Date: 2025-05-16 16:59:16.423127

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'abb18a379e62'
down_revision = 'b35c3db83d09'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('pipelines', schema=None) as batch_op:
        batch_op.drop_column('mode')

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('pipelines', schema=None) as batch_op:
        batch_op.add_column(sa.Column('mode', sa.VARCHAR(length=255), autoincrement=False, nullable=False))

    # ### end Alembic commands ###
@@ -0,0 +1,70 @@
"""add_pipeline_info_3

Revision ID: c459994abfa8
Revises: abb18a379e62
Create Date: 2025-05-30 00:33:14.068312

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'c459994abfa8'
down_revision = 'abb18a379e62'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('datasource_oauth_params',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
    sa.Column('provider', sa.String(length=255), nullable=False),
    sa.Column('system_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
    sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'),
    sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx')
    )
    op.create_table('datasource_providers',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
    sa.Column('provider', sa.String(length=255), nullable=False),
    sa.Column('auth_type', sa.String(length=255), nullable=False),
    sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'),
    sa.UniqueConstraint('plugin_id', 'provider', name='datasource_provider_plugin_id_provider_idx')
    )
    with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
        batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=False))
        batch_op.add_column(sa.Column('yaml_content', sa.Text(), nullable=False))
        batch_op.drop_column('pipeline_id')

    with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
        batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=False))
        batch_op.add_column(sa.Column('yaml_content', sa.Text(), nullable=False))
        batch_op.drop_column('pipeline_id')

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###

    with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
        batch_op.add_column(sa.Column('pipeline_id', sa.UUID(), autoincrement=False, nullable=False))
        batch_op.drop_column('yaml_content')
        batch_op.drop_column('chunk_structure')

    with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
        batch_op.add_column(sa.Column('pipeline_id', sa.UUID(), autoincrement=False, nullable=False))
        batch_op.drop_column('yaml_content')
        batch_op.drop_column('chunk_structure')

    op.drop_table('datasource_providers')
    op.drop_table('datasource_oauth_params')
    # ### end Alembic commands ###
Some files were not shown because too many files have changed in this diff.