mirror of
https://github.com/langgenius/dify.git
synced 2026-02-05 23:53:58 +00:00
Compare commits
460 Commits
dev/plugin
...
feat/knowl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3107ec878a | ||
|
|
7d5fcfef4c | ||
|
|
e862ab0def | ||
|
|
15f80a72b8 | ||
|
|
1a7de23864 | ||
|
|
10fccd2b3f | ||
|
|
b568947e00 | ||
|
|
7dcbb75839 | ||
|
|
ffdfcdd4a4 | ||
|
|
f4604bf6d0 | ||
|
|
3a72b76c32 | ||
|
|
49dd77e219 | ||
|
|
b9f223d9d4 | ||
|
|
3c4da03575 | ||
|
|
7692476097 | ||
|
|
428438eeca | ||
|
|
b7c546f2ad | ||
|
|
0ed892a747 | ||
|
|
5e2bd407a8 | ||
|
|
a4668e0ffc | ||
|
|
1ca79ea729 | ||
|
|
ebb6de5f52 | ||
|
|
2adc704463 | ||
|
|
b74f1b3c07 | ||
|
|
f60e650400 | ||
|
|
83d0142641 | ||
|
|
56c7f49625 | ||
|
|
7c1d842cfe | ||
|
|
2ea3b64a45 | ||
|
|
824f8d8994 | ||
|
|
31c17e6378 | ||
|
|
50cfb7c9ec | ||
|
|
8281c688ca | ||
|
|
ad9d6eb5f4 | ||
|
|
aa3dc9002c | ||
|
|
4a43e165fb | ||
|
|
4d25b598f9 | ||
|
|
3e9c3d0bb7 | ||
|
|
fec3bb4469 | ||
|
|
d4a09805a3 | ||
|
|
7e1d9894fb | ||
|
|
a8a8a5513c | ||
|
|
470e72c820 | ||
|
|
beebba0340 | ||
|
|
4e27d82d68 | ||
|
|
cdeaf3f70b | ||
|
|
24839bb3e1 | ||
|
|
1650dbfbb1 | ||
|
|
fd11817044 | ||
|
|
6642fc6012 | ||
|
|
2710242982 | ||
|
|
1de84fdda0 | ||
|
|
3befbc1d68 | ||
|
|
62c413aca5 | ||
|
|
6887b501b8 | ||
|
|
f93bf131ab | ||
|
|
ef1f429437 | ||
|
|
c966bf1474 | ||
|
|
899df30bf6 | ||
|
|
8d8d3e3f2f | ||
|
|
5f0fa38ec6 | ||
|
|
cc1fe70d34 | ||
|
|
15ee1e11be | ||
|
|
c8b4a76530 | ||
|
|
6ee4eba86b | ||
|
|
357d2e8be8 | ||
|
|
b5accda3fe | ||
|
|
de4752a16b | ||
|
|
60427f1adf | ||
|
|
1a313c868d | ||
|
|
0b32b1988f | ||
|
|
e56c051d97 | ||
|
|
0a6b4d01d7 | ||
|
|
98b139c680 | ||
|
|
f0a3c14adb | ||
|
|
51947575c2 | ||
|
|
cb8debee3e | ||
|
|
d56079a549 | ||
|
|
c08b451874 | ||
|
|
ac336ff359 | ||
|
|
4cbd511cd7 | ||
|
|
c03adcb154 | ||
|
|
04dade2f9b | ||
|
|
f69220ca96 | ||
|
|
a5e24ff6d3 | ||
|
|
71976f9192 | ||
|
|
39ec6c8025 | ||
|
|
e370045ac4 | ||
|
|
28edbbac0b | ||
|
|
782abcecd8 | ||
|
|
4deb02fc2c | ||
|
|
f967180dc2 | ||
|
|
cead13cbc3 | ||
|
|
078c151065 | ||
|
|
17babca362 | ||
|
|
8efed8858c | ||
|
|
0d411a0b5a | ||
|
|
13f0c01f93 | ||
|
|
3c014f3ae5 | ||
|
|
e4c4490175 | ||
|
|
94a62f6b4e | ||
|
|
d76af08784 | ||
|
|
f748d6c7c4 | ||
|
|
76e24d91c0 | ||
|
|
5ce4ddc0ed | ||
|
|
491d641485 | ||
|
|
172c5f19cc | ||
|
|
b7d168ac59 | ||
|
|
fb309462ad | ||
|
|
b56d2b739b | ||
|
|
fb7b2c8ff3 | ||
|
|
c3440a27fb | ||
|
|
ff3d3f71fb | ||
|
|
9685b9a302 | ||
|
|
07c7b7b886 | ||
|
|
8d75abc976 | ||
|
|
aa6452b3bf | ||
|
|
3799d40937 | ||
|
|
d2ff8a2381 | ||
|
|
5f51a19de2 | ||
|
|
71e0bfcbd8 | ||
|
|
d815c74fc5 | ||
|
|
107e44c8fb | ||
|
|
adf7eea7fe | ||
|
|
6e73ad2fc6 | ||
|
|
06412b37d3 | ||
|
|
63665a5ff1 | ||
|
|
05a43e3e80 | ||
|
|
83fdb42520 | ||
|
|
cbf405beea | ||
|
|
af2aede783 | ||
|
|
e359ace633 | ||
|
|
a5555f90c6 | ||
|
|
78664c8903 | ||
|
|
45070535bd | ||
|
|
048e8cf0d1 | ||
|
|
598d208e54 | ||
|
|
8102cee8df | ||
|
|
c9eb9c14d7 | ||
|
|
e77cd87842 | ||
|
|
ac5e3caebc | ||
|
|
23066a9ba8 | ||
|
|
0249f15609 | ||
|
|
2f523dd29f | ||
|
|
b34d815883 | ||
|
|
51cc63d9ce | ||
|
|
430af95b53 | ||
|
|
0164d1410a | ||
|
|
cbc5045b7a | ||
|
|
b980c07af8 | ||
|
|
e231cf2c48 | ||
|
|
80d8e47e42 | ||
|
|
fee4dd7d7a | ||
|
|
00cf5f3841 | ||
|
|
9ee0c7a694 | ||
|
|
6ee7ca1890 | ||
|
|
f589397f25 | ||
|
|
ee080dddf9 | ||
|
|
ee6841648c | ||
|
|
5a57dad93c | ||
|
|
4199998c7e | ||
|
|
39656f7f84 | ||
|
|
bf39e314d8 | ||
|
|
8cc4c109d0 | ||
|
|
a1cdca02e3 | ||
|
|
1b21d7513d | ||
|
|
d5c708c62b | ||
|
|
342d4060ff | ||
|
|
05232d36f0 | ||
|
|
636dde94c7 | ||
|
|
75fe785d88 | ||
|
|
a61da6cf95 | ||
|
|
93c3699128 | ||
|
|
6357450a7a | ||
|
|
6339706c68 | ||
|
|
65a4cb769b | ||
|
|
63206a7967 | ||
|
|
9a6f120e5c | ||
|
|
dedc1b0c3a | ||
|
|
46bb246ecc | ||
|
|
3c628d0c26 | ||
|
|
c2983ecbb7 | ||
|
|
527c1cf608 | ||
|
|
93786f516c | ||
|
|
a175d6b2d7 | ||
|
|
296fd82bbf | ||
|
|
4ccd571364 | ||
|
|
ae72514cb4 | ||
|
|
16b49ac436 | ||
|
|
c377eb8c28 | ||
|
|
337eff2b79 | ||
|
|
b7ac287fec | ||
|
|
c1a85b0208 | ||
|
|
01efdee1dd | ||
|
|
0af9c4fd9d | ||
|
|
ee38bd8817 | ||
|
|
86291c13e4 | ||
|
|
7679a57f18 | ||
|
|
dcf19549cb | ||
|
|
574a6c1ded | ||
|
|
c34877aecf | ||
|
|
632b2bac2a | ||
|
|
77a62f33b3 | ||
|
|
ad899844a1 | ||
|
|
b10d6051ba | ||
|
|
fb44cd87e7 | ||
|
|
89af726985 | ||
|
|
6f2d5ff099 | ||
|
|
687455ca31 | ||
|
|
8c5928da2f | ||
|
|
772009115d | ||
|
|
0452dfd029 | ||
|
|
eead6abe85 | ||
|
|
5c6d919a4a | ||
|
|
e39eddab03 | ||
|
|
db726e02a0 | ||
|
|
e4b8220bc2 | ||
|
|
08cfcb453c | ||
|
|
992e1eedde | ||
|
|
c2ce8e638e | ||
|
|
ba3659a792 | ||
|
|
965fabd578 | ||
|
|
accbbae755 | ||
|
|
49bd1a7a49 | ||
|
|
5ff9cee326 | ||
|
|
200f9af5d8 | ||
|
|
1443fd6739 | ||
|
|
e63ae36665 | ||
|
|
cfa7c89dfe | ||
|
|
a6835ac64d | ||
|
|
a700b49461 | ||
|
|
22df86fe8a | ||
|
|
24734009b9 | ||
|
|
959d060a44 | ||
|
|
4492295683 | ||
|
|
88fac0d898 | ||
|
|
8b30099672 | ||
|
|
97a3727962 | ||
|
|
2cb640de15 | ||
|
|
fb4ee813c7 | ||
|
|
6300e506fb | ||
|
|
a0543ab8fb | ||
|
|
634cb6233e | ||
|
|
db68ae4a73 | ||
|
|
d25e79e794 | ||
|
|
183b943803 | ||
|
|
5828abcd62 | ||
|
|
56bd0dedfe | ||
|
|
f6136427a4 | ||
|
|
21fd58caf9 | ||
|
|
9a69d03fbe | ||
|
|
1d2118fc5d | ||
|
|
bc0724b499 | ||
|
|
5cdbfe2f41 | ||
|
|
5fd82084f9 | ||
|
|
f0637ba332 | ||
|
|
115c9486c3 | ||
|
|
8b5231b7ee | ||
|
|
38cae29757 | ||
|
|
7a2b2a04c9 | ||
|
|
fe677cc5f9 | ||
|
|
28c9ec3f4f | ||
|
|
6baa98f166 | ||
|
|
e9d69f020a | ||
|
|
3c89d45a2d | ||
|
|
baab81714e | ||
|
|
507bb3549a | ||
|
|
2d1e5fb4e0 | ||
|
|
b9198639e2 | ||
|
|
43c7739b88 | ||
|
|
f65d577f54 | ||
|
|
b88145096f | ||
|
|
33219e850a | ||
|
|
3040d538f7 | ||
|
|
4e1af81e11 | ||
|
|
56e19fd8f5 | ||
|
|
d330d31ee5 | ||
|
|
0858108423 | ||
|
|
2cd976846a | ||
|
|
5d2c88ef59 | ||
|
|
fe3cde973e | ||
|
|
794f495ef2 | ||
|
|
0dda682033 | ||
|
|
01d8d10f1c | ||
|
|
c711c5e36e | ||
|
|
1e27557865 | ||
|
|
2d9632d8b9 | ||
|
|
7e42de1e7b | ||
|
|
bd674d27be | ||
|
|
5735761920 | ||
|
|
405b704f02 | ||
|
|
f38abaaa6a | ||
|
|
c8a5fee622 | ||
|
|
fe1c0ac602 | ||
|
|
e79c3e4531 | ||
|
|
3ea3df7189 | ||
|
|
b01e7d778e | ||
|
|
7c45859594 | ||
|
|
aa9fd76072 | ||
|
|
e7d947379f | ||
|
|
8cd386f2c1 | ||
|
|
987e1b9ced | ||
|
|
81a77d0623 | ||
|
|
ac1f93e3d5 | ||
|
|
0d5c0b4fe4 | ||
|
|
d1c480a7d8 | ||
|
|
007b561e32 | ||
|
|
c100f24f7d | ||
|
|
d92cb994a9 | ||
|
|
413326905e | ||
|
|
5605ff9803 | ||
|
|
84b7a4607a | ||
|
|
10cc4e758c | ||
|
|
8070be9b76 | ||
|
|
f1f1baae9c | ||
|
|
f20c9ef763 | ||
|
|
f798add31c | ||
|
|
8c2dbe876f | ||
|
|
6fd0a55b00 | ||
|
|
bb58f5c6e5 | ||
|
|
18edeb8e0a | ||
|
|
459cb9dd72 | ||
|
|
f9e2c738b0 | ||
|
|
739e15f88b | ||
|
|
5bf86ff66d | ||
|
|
c657378d06 | ||
|
|
685e8cdc7d | ||
|
|
d36dece0af | ||
|
|
5f61aa85db | ||
|
|
e5837b88e0 | ||
|
|
ffdc6f5c60 | ||
|
|
99c8f364ae | ||
|
|
a0a1243c90 | ||
|
|
b916b4064a | ||
|
|
dea2962a79 | ||
|
|
1450e5d5cb | ||
|
|
43a2d4335b | ||
|
|
11270a7ef2 | ||
|
|
53e1b45d40 | ||
|
|
bedbd658fe | ||
|
|
7b62b5578e | ||
|
|
ccbe42eb5f | ||
|
|
45f8651a3d | ||
|
|
7754431a34 | ||
|
|
fa7215cfea | ||
|
|
678c89891a | ||
|
|
beebcbd962 | ||
|
|
8495ed3348 | ||
|
|
31cca4a849 | ||
|
|
43ffccc8fd | ||
|
|
a81293cf5a | ||
|
|
276701e1b7 | ||
|
|
8e1cf3233c | ||
|
|
dd551e6ca8 | ||
|
|
ae1eeb9b2a | ||
|
|
b58f8dd7b4 | ||
|
|
118fa66567 | ||
|
|
699d41deec | ||
|
|
dd0462c1dc | ||
|
|
a470e0e60e | ||
|
|
2622159763 | ||
|
|
dfaf639790 | ||
|
|
ae96f66a08 | ||
|
|
570b7d18ac | ||
|
|
a9c21ef929 | ||
|
|
e27a03ae15 | ||
|
|
56b7853afe | ||
|
|
e12f4009d3 | ||
|
|
6dfc31a542 | ||
|
|
c9f80b46a1 | ||
|
|
0025b27200 | ||
|
|
0dd05d7b6d | ||
|
|
7c83d5ce76 | ||
|
|
a57f60a6e0 | ||
|
|
2f36692bf9 | ||
|
|
bcdb407be8 | ||
|
|
d4e007f9db | ||
|
|
8563155d1b | ||
|
|
8236373498 | ||
|
|
196bfeaaf4 | ||
|
|
957ab093c9 | ||
|
|
e9e5c8806a | ||
|
|
c8bc3892b3 | ||
|
|
735e57b73a | ||
|
|
635a53ea38 | ||
|
|
7b76b1ff82 | ||
|
|
47c8824be6 | ||
|
|
1c3213184e | ||
|
|
d9cced8419 | ||
|
|
c3359a9291 | ||
|
|
2da32e49d0 | ||
|
|
1837692a66 | ||
|
|
5dcd25a613 | ||
|
|
507fff0259 | ||
|
|
0ad9dbea63 | ||
|
|
4c28034224 | ||
|
|
1d575524c3 | ||
|
|
dc255cc154 | ||
|
|
ea497f828f | ||
|
|
153dc5b3f3 | ||
|
|
a91951b374 | ||
|
|
68c10a1672 | ||
|
|
592f85f7a9 | ||
|
|
cda9f6ec6b | ||
|
|
64706c709c | ||
|
|
9722e6bcb1 | ||
|
|
1907d791e1 | ||
|
|
fb3a701c86 | ||
|
|
947bfdc807 | ||
|
|
7a3e756020 | ||
|
|
435e71eb60 | ||
|
|
91cb80f795 | ||
|
|
3c1d32e3ac | ||
|
|
eef79a5196 | ||
|
|
2223dfb266 | ||
|
|
9693b5ad0c | ||
|
|
d4bf575d0a | ||
|
|
73ce692e24 | ||
|
|
661392eaef | ||
|
|
c472ea6c67 | ||
|
|
4eaba3049a | ||
|
|
00d1c45518 | ||
|
|
87c746f6bb | ||
|
|
70c001436e | ||
|
|
cf73374c1b | ||
|
|
b0d53c0ac4 | ||
|
|
9c7bcd5abc | ||
|
|
b7c5abc5dd | ||
|
|
de01ca8d55 | ||
|
|
60e75dc748 | ||
|
|
279dee485d | ||
|
|
db8bf2a85e | ||
|
|
46ba16fe90 | ||
|
|
886a160115 | ||
|
|
cf4e9f317e | ||
|
|
1fa3b9cfd8 | ||
|
|
50a5cfe56a | ||
|
|
ece82b87bf | ||
|
|
12ea085e22 | ||
|
|
41ed2e0cc2 | ||
|
|
113ff27d07 | ||
|
|
ec711d094d | ||
|
|
a073de44e9 | ||
|
|
6ce02b07d3 | ||
|
|
f47712beae | ||
|
|
4a8d3c54ca | ||
|
|
c8b0160ea9 | ||
|
|
531ffaec4f | ||
|
|
c28998a6f0 | ||
|
|
4b4741f7ed | ||
|
|
25b8a512bf | ||
|
|
02d26818ad | ||
|
|
31e8b134d1 | ||
|
|
d52476c1c9 | ||
|
|
f29b44acd8 | ||
|
|
ed7fcc5f7d | ||
|
|
c6f34f5c17 | ||
|
|
e1db77eec2 | ||
|
|
563d81277b | ||
|
|
364df36ac4 |
4
.github/workflows/api-tests.yml
vendored
4
.github/workflows/api-tests.yml
vendored
@@ -4,6 +4,7 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- plugins/beta
|
||||
paths:
|
||||
- api/**
|
||||
- docker/**
|
||||
@@ -26,9 +27,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Poetry and Python ${{ matrix.python-version }}
|
||||
uses: ./.github/actions/setup-poetry
|
||||
|
||||
17
.github/workflows/build-push.yml
vendored
17
.github/workflows/build-push.yml
vendored
@@ -5,7 +5,7 @@ on:
|
||||
branches:
|
||||
- "main"
|
||||
- "deploy/dev"
|
||||
- "dev/plugin-deploy"
|
||||
- "plugins/beta"
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
@@ -80,12 +80,10 @@ jobs:
|
||||
cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}
|
||||
|
||||
- name: Export digest
|
||||
env:
|
||||
DIGEST: ${{ steps.build.outputs.digest }}
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
sanitized_digest=${DIGEST#sha256:}
|
||||
touch "/tmp/digests/${sanitized_digest}"
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
@@ -135,15 +133,10 @@ jobs:
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
env:
|
||||
IMAGE_NAME: ${{ env[matrix.image_name_env] }}
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf "$IMAGE_NAME@sha256:%s " *)
|
||||
$(printf '${{ env[matrix.image_name_env] }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
env:
|
||||
IMAGE_NAME: ${{ env[matrix.image_name_env] }}
|
||||
IMAGE_VERSION: ${{ steps.meta.outputs.version }}
|
||||
run: |
|
||||
docker buildx imagetools inspect "$IMAGE_NAME:$IMAGE_VERSION"
|
||||
docker buildx imagetools inspect ${{ env[matrix.image_name_env] }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
3
.github/workflows/db-migration-test.yml
vendored
3
.github/workflows/db-migration-test.yml
vendored
@@ -20,9 +20,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Poetry and Python
|
||||
uses: ./.github/actions/setup-poetry
|
||||
|
||||
2
.github/workflows/expose_service_ports.sh
vendored
2
.github/workflows/expose_service_ports.sh
vendored
@@ -9,6 +9,6 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compos
|
||||
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
|
||||
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml
|
||||
|
||||
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
|
||||
|
||||
17
.github/workflows/style.yml
vendored
17
.github/workflows/style.yml
vendored
@@ -4,6 +4,7 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- plugins/beta
|
||||
|
||||
concurrency:
|
||||
group: style-${{ github.head_ref || github.run_id }}
|
||||
@@ -17,9 +18,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
@@ -62,9 +60,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
@@ -92,7 +87,7 @@ jobs:
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: pnpm run lint
|
||||
run: yarn run lint
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
@@ -101,9 +96,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
@@ -132,9 +124,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
@@ -152,7 +141,7 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
BASH_SEVERITY: warning
|
||||
DEFAULT_BRANCH: main
|
||||
DEFAULT_BRANCH: plugins/beta
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
IGNORE_GENERATED_FILES: true
|
||||
IGNORE_GITIGNORED_FILES: true
|
||||
|
||||
5
.github/workflows/tool-test-sdks.yaml
vendored
5
.github/workflows/tool-test-sdks.yaml
vendored
@@ -26,9 +26,6 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Use Node.js ${{ matrix.node-version }}
|
||||
uses: actions/setup-node@v4
|
||||
@@ -38,7 +35,7 @@ jobs:
|
||||
cache-dependency-path: 'pnpm-lock.yaml'
|
||||
|
||||
- name: Install Dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
run: pnpm install
|
||||
|
||||
- name: Test
|
||||
run: pnpm test
|
||||
|
||||
@@ -16,7 +16,6 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 2 # last 2 commits
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check for file changes in i18n/en-US
|
||||
id: check_files
|
||||
|
||||
17
.github/workflows/vdb-tests.yml
vendored
17
.github/workflows/vdb-tests.yml
vendored
@@ -28,9 +28,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Poetry and Python ${{ matrix.python-version }}
|
||||
uses: ./.github/actions/setup-poetry
|
||||
@@ -54,15 +51,7 @@ jobs:
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Vector Store (TiDB)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: docker/tidb/docker-compose.yaml
|
||||
services: |
|
||||
tidb
|
||||
tiflash
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
|
||||
- name: Set up Vector Stores (TiDB, Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: |
|
||||
@@ -78,9 +67,7 @@ jobs:
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
|
||||
- name: Check TiDB Ready
|
||||
run: poetry run -P api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
tidb
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: poetry run -P api bash dev/pytest/pytest_vdb.sh
|
||||
|
||||
35
.github/workflows/web-tests.yml
vendored
35
.github/workflows/web-tests.yml
vendored
@@ -22,34 +22,25 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v45
|
||||
with:
|
||||
files: web/**
|
||||
# to run pnpm, should install package canvas, but it always install failed on amd64 under ubuntu-latest
|
||||
# - name: Install pnpm
|
||||
# uses: pnpm/action-setup@v4
|
||||
# with:
|
||||
# version: 10
|
||||
# run_install: false
|
||||
|
||||
# - name: Setup Node.js
|
||||
# uses: actions/setup-node@v4
|
||||
# if: steps.changed-files.outputs.any_changed == 'true'
|
||||
# with:
|
||||
# node-version: 20
|
||||
# cache: pnpm
|
||||
# cache-dependency-path: ./web/package.json
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 20
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
|
||||
# - name: Install dependencies
|
||||
# if: steps.changed-files.outputs.any_changed == 'true'
|
||||
# run: pnpm install --frozen-lockfile
|
||||
- name: Install dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# - name: Run tests
|
||||
# if: steps.changed-files.outputs.any_changed == 'true'
|
||||
# run: pnpm test
|
||||
- name: Run tests
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: pnpm test
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -163,7 +163,6 @@ docker/volumes/db/data/*
|
||||
docker/volumes/redis/data/*
|
||||
docker/volumes/weaviate/*
|
||||
docker/volumes/qdrant/*
|
||||
docker/tidb/volumes/*
|
||||
docker/volumes/etcd/*
|
||||
docker/volumes/minio/*
|
||||
docker/volumes/milvus/*
|
||||
|
||||
@@ -73,7 +73,7 @@ Dify requires the following dependencies to build, make sure they're installed o
|
||||
* [Docker](https://www.docker.com/)
|
||||
* [Docker Compose](https://docs.docker.com/compose/install/)
|
||||
* [Node.js v18.x (LTS)](http://nodejs.org)
|
||||
* [pnpm](https://pnpm.io/)
|
||||
* [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
|
||||
* [Python](https://www.python.org/) version 3.11.x or 3.12.x
|
||||
|
||||
### 4. Installations
|
||||
|
||||
@@ -70,7 +70,7 @@ Dify 依赖以下工具和库:
|
||||
- [Docker](https://www.docker.com/)
|
||||
- [Docker Compose](https://docs.docker.com/compose/install/)
|
||||
- [Node.js v18.x (LTS)](http://nodejs.org)
|
||||
- [pnpm](https://pnpm.io/)
|
||||
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
|
||||
- [Python](https://www.python.org/) version 3.11.x or 3.12.x
|
||||
|
||||
### 4. 安装
|
||||
|
||||
@@ -73,7 +73,7 @@ Dify を構築するには次の依存関係が必要です。それらがシス
|
||||
- [Docker](https://www.docker.com/)
|
||||
- [Docker Compose](https://docs.docker.com/compose/install/)
|
||||
- [Node.js v18.x (LTS)](http://nodejs.org)
|
||||
- [pnpm](https://pnpm.io/)
|
||||
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
|
||||
- [Python](https://www.python.org/) version 3.11.x or 3.12.x
|
||||
|
||||
### 4. インストール
|
||||
|
||||
@@ -72,7 +72,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ
|
||||
- [Docker](https://www.docker.com/)
|
||||
- [Docker Compose](https://docs.docker.com/compose/install/)
|
||||
- [Node.js v18.x (LTS)](http://nodejs.org)
|
||||
- [pnpm](https://pnpm.io/)
|
||||
- [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/)
|
||||
- [Python](https://www.python.org/) phiên bản 3.11.x hoặc 3.12.x
|
||||
|
||||
### 4. Cài đặt
|
||||
|
||||
23
LICENSE
23
LICENSE
@@ -1,12 +1,12 @@
|
||||
# Open Source License
|
||||
|
||||
Dify is licensed under a modified version of the Apache License 2.0, with the following additional conditions:
|
||||
Dify is licensed under the Apache License 2.0, with the following additional conditions:
|
||||
|
||||
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
|
||||
|
||||
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
|
||||
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
|
||||
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
|
||||
|
||||
|
||||
b. LOGO and copyright information: In the process of using Dify's frontend, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend.
|
||||
- Frontend Definition: For the purposes of this license, the "frontend" of Dify includes all components located in the `web/` directory when running Dify from the raw source code, or the "web" image when running Dify with Docker.
|
||||
|
||||
@@ -21,4 +21,19 @@ Apart from the specific conditions mentioned above, all other rights and restric
|
||||
|
||||
The interactive design of this product is protected by appearance patent.
|
||||
|
||||
© 2025 LangGenius, Inc.
|
||||
© 2024 LangGenius, Inc.
|
||||
|
||||
|
||||
----------
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
66
README.md
66
README.md
@@ -108,72 +108,6 @@ Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-host
|
||||
**7. Backend-as-a-Service**:
|
||||
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
|
||||
|
||||
## Feature Comparison
|
||||
<table style="width: 100%;">
|
||||
<tr>
|
||||
<th align="center">Feature</th>
|
||||
<th align="center">Dify.AI</th>
|
||||
<th align="center">LangChain</th>
|
||||
<th align="center">Flowise</th>
|
||||
<th align="center">OpenAI Assistants API</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Programming Approach</td>
|
||||
<td align="center">API + App-oriented</td>
|
||||
<td align="center">Python Code</td>
|
||||
<td align="center">App-oriented</td>
|
||||
<td align="center">API-oriented</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Supported LLMs</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
<td align="center">OpenAI-only</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">RAG Engine</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Agent</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Workflow</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Observability</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Enterprise Feature (SSO/Access control)</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Local Deployment</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Using Dify
|
||||
|
||||
|
||||
16
README_FR.md
16
README_FR.md
@@ -55,7 +55,7 @@
|
||||
Dify est une plateforme de développement d'applications LLM open source. Son interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales:
|
||||
</br> </br>
|
||||
|
||||
**1. Flux de travail** :
|
||||
**1. Flux de travail**:
|
||||
Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore.
|
||||
|
||||
|
||||
@@ -63,25 +63,27 @@ Dify est une plateforme de développement d'applications LLM open source. Son in
|
||||
|
||||
|
||||
|
||||
**2. Prise en charge complète des modèles** :
|
||||
**2. Prise en charge complète des modèles**:
|
||||
Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
||||
|
||||
**3. IDE de prompt** :
|
||||
**3. IDE de prompt**:
|
||||
Interface intuitive pour créer des prompts, comparer les performances des modèles et ajouter des fonctionnalités supplémentaires telles que la synthèse vocale à une application basée sur des chats.
|
||||
|
||||
**4. Pipeline RAG** :
|
||||
**4. Pipeline RAG**:
|
||||
Des capacités RAG étendues qui couvrent tout, de l'ingestion de documents à la récupération, avec un support prêt à l'emploi pour l'extraction de texte à partir de PDF, PPT et autres formats de document courants.
|
||||
|
||||
**5. Capacités d'agent** :
|
||||
**5. Capac
|
||||
|
||||
ités d'agent**:
|
||||
Vous pouvez définir des agents basés sur l'appel de fonction LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha.
|
||||
|
||||
**6. LLMOps** :
|
||||
**6. LLMOps**:
|
||||
Surveillez et analysez les journaux d'application et les performances au fil du temps. Vous pouvez continuellement améliorer les prompts, les ensembles de données et les modèles en fonction des données de production et des annotations.
|
||||
|
||||
**7. Backend-as-a-Service** :
|
||||
**7. Backend-as-a-Service**:
|
||||
Toutes les offres de Dify sont accompagnées d'API correspondantes, vous permettant d'intégrer facilement Dify dans votre propre logique métier.
|
||||
|
||||
|
||||
|
||||
@@ -164,7 +164,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
|
||||
|
||||
- **企業/組織向けのDify</br>**
|
||||
企業中心の機能を提供しています。[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)して企業のニーズについて相談してください。 </br>
|
||||
> AWSを使用しているスタートアップ企業や中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t23mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで自分のAWS VPCにデプロイできます。さらに、手頃な価格のAMIオファリングとして、ロゴやブランディングをカスタマイズしてアプリケーションを作成するオプションがあります。
|
||||
> AWSを使用しているスタートアップ企業や中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで自分のAWS VPCにデプロイできます。さらに、手頃な価格のAMIオファリングどして、ロゴやブランディングをカスタマイズしてアプリケーションを作成するオプションがあります。
|
||||
|
||||
|
||||
## 最新の情報を入手
|
||||
|
||||
@@ -87,7 +87,9 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
||||
|
||||
## Feature Comparison
|
||||
<table style="width: 100%;">
|
||||
<tr>
|
||||
<tr
|
||||
|
||||
>
|
||||
<th align="center">Feature</th>
|
||||
<th align="center">Dify.AI</th>
|
||||
<th align="center">LangChain</th>
|
||||
|
||||
69
README_SI.md
69
README_SI.md
@@ -106,73 +106,6 @@ Prosimo, glejte naša pogosta vprašanja [FAQ](https://docs.dify.ai/getting-star
|
||||
**7. Backend-as-a-Service**:
|
||||
AVse ponudbe Difyja so opremljene z ustreznimi API-ji, tako da lahko Dify brez težav integrirate v svojo poslovno logiko.
|
||||
|
||||
## Primerjava Funkcij
|
||||
|
||||
<table style="width: 100%;">
|
||||
<tr>
|
||||
<th align="center">Funkcija</th>
|
||||
<th align="center">Dify.AI</th>
|
||||
<th align="center">LangChain</th>
|
||||
<th align="center">Flowise</th>
|
||||
<th align="center">OpenAI Assistants API</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Programski pristop</td>
|
||||
<td align="center">API + usmerjeno v aplikacije</td>
|
||||
<td align="center">Python koda</td>
|
||||
<td align="center">Usmerjeno v aplikacije</td>
|
||||
<td align="center">Usmerjeno v API</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Podprti LLM-ji</td>
|
||||
<td align="center">Bogata izbira</td>
|
||||
<td align="center">Bogata izbira</td>
|
||||
<td align="center">Bogata izbira</td>
|
||||
<td align="center">Samo OpenAI</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">RAG pogon</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Agent</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Potek dela</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Spremljanje</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Funkcija za podjetja (SSO/nadzor dostopa)</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Lokalna namestitev</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Uporaba Dify
|
||||
|
||||
@@ -254,4 +187,4 @@ Zaradi zaščite vaše zasebnosti se izogibajte objavljanju varnostnih vprašanj
|
||||
|
||||
## Licenca
|
||||
|
||||
To skladišče je na voljo pod [odprtokodno licenco Dify](LICENSE) , ki je v bistvu Apache 2.0 z nekaj dodatnimi omejitvami.
|
||||
To skladišče je na voljo pod [odprtokodno licenco Dify](LICENSE) , ki je v bistvu Apache 2.0 z nekaj dodatnimi omejitvami.
|
||||
@@ -55,11 +55,9 @@ RUN \
|
||||
# basic environment
|
||||
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
|
||||
# For Security
|
||||
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||
# expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||
# install a chinese font to support the use of tools like matplotlib
|
||||
fonts-noto-cjk \
|
||||
# install a package to improve the accuracy of guessing mime type and file extension
|
||||
media-types \
|
||||
# install libmagic to support the use of python-magic guess MIMETYPE
|
||||
libmagic1 \
|
||||
&& apt-get autoremove -y \
|
||||
|
||||
@@ -37,13 +37,7 @@
|
||||
|
||||
4. Create environment.
|
||||
|
||||
Dify API service uses [Poetry](https://python-poetry.org/docs/) to manage dependencies. First, you need to add the poetry shell plugin, if you don't have it already, in order to run in a virtual environment. [Note: Poetry shell is no longer a native command so you need to install the poetry plugin beforehand]
|
||||
|
||||
```bash
|
||||
poetry self add poetry-plugin-shell
|
||||
```
|
||||
|
||||
Then, You can execute `poetry shell` to activate the environment.
|
||||
Dify API service uses [Poetry](https://python-poetry.org/docs/) to manage dependencies. You can execute `poetry shell` to activate the environment.
|
||||
|
||||
5. Install dependencies
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
import time
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
@@ -17,12 +16,6 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
dify_app = DifyApp(__name__)
|
||||
dify_app.config.from_mapping(dify_config.model_dump())
|
||||
|
||||
# add before request hook
|
||||
@dify_app.before_request
|
||||
def before_request():
|
||||
# add an unique identifier to each request
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
return dify_app
|
||||
|
||||
|
||||
|
||||
@@ -707,13 +707,12 @@ def extract_unique_plugins(output_file: str, input_file: str):
|
||||
@click.option(
|
||||
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
|
||||
)
|
||||
@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100)
|
||||
def install_plugins(input_file: str, output_file: str, workers: int):
|
||||
def install_plugins(input_file: str, output_file: str):
|
||||
"""
|
||||
Install plugins.
|
||||
"""
|
||||
click.echo(click.style("Starting install plugins.", fg="white"))
|
||||
|
||||
PluginMigration.install_plugins(input_file, output_file, workers)
|
||||
PluginMigration.install_plugins(input_file, output_file)
|
||||
|
||||
click.echo(click.style("Install plugins completed.", fg="green"))
|
||||
|
||||
@@ -373,8 +373,8 @@ class HttpConfig(BaseSettings):
|
||||
)
|
||||
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field(
|
||||
description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers"
|
||||
" when the app is behind a single trusted reverse proxy.",
|
||||
description="Enable or disable the X-Forwarded-For Proxy Fix middleware from Werkzeug"
|
||||
" to respect X-* headers to redirect clients",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from typing import Any, Literal, Optional
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
@@ -167,11 +166,6 @@ class DatabaseConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
|
||||
description="Number of processes for the retrieval service, default to CPU cores.",
|
||||
default=os.cpu_count(),
|
||||
)
|
||||
|
||||
@computed_field
|
||||
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@@ -15,7 +15,7 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]
|
||||
DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
DOCUMENT_EXTENSIONS.append("ppt")
|
||||
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||
|
||||
@@ -2,8 +2,6 @@ from contextvars import ContextVar
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
@@ -14,17 +12,8 @@ tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||
|
||||
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||
|
||||
"""
|
||||
To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with
|
||||
"""
|
||||
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_tool_providers")
|
||||
)
|
||||
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
|
||||
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
|
||||
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
|
||||
|
||||
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_providers")
|
||||
)
|
||||
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_providers_lock")
|
||||
)
|
||||
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
|
||||
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class HiddenValue:
|
||||
pass
|
||||
|
||||
|
||||
_default = HiddenValue()
|
||||
|
||||
|
||||
class RecyclableContextVar(Generic[T]):
|
||||
"""
|
||||
RecyclableContextVar is a wrapper around ContextVar
|
||||
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
|
||||
|
||||
NOTE: you need to call `increment_thread_recycles` before requests
|
||||
"""
|
||||
|
||||
_thread_recycles: ContextVar[int] = ContextVar("thread_recycles")
|
||||
|
||||
@classmethod
|
||||
def increment_thread_recycles(cls):
|
||||
try:
|
||||
recycles = cls._thread_recycles.get()
|
||||
cls._thread_recycles.set(recycles + 1)
|
||||
except LookupError:
|
||||
cls._thread_recycles.set(0)
|
||||
|
||||
def __init__(self, context_var: ContextVar[T]):
|
||||
self._context_var = context_var
|
||||
self._updates = ContextVar[int](context_var.name + "_updates", default=0)
|
||||
|
||||
def get(self, default: T | HiddenValue = _default) -> T:
|
||||
thread_recycles = self._thread_recycles.get(0)
|
||||
self_updates = self._updates.get()
|
||||
if thread_recycles > self_updates:
|
||||
self._updates.set(thread_recycles)
|
||||
|
||||
# check if thread is recycled and should be updated
|
||||
if thread_recycles < self_updates:
|
||||
return self._context_var.get()
|
||||
else:
|
||||
# thread_recycles >= self_updates, means current context is invalid
|
||||
if isinstance(default, HiddenValue) or default is _default:
|
||||
raise LookupError
|
||||
else:
|
||||
return default
|
||||
|
||||
def set(self, value: T):
|
||||
# it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before
|
||||
# increase it manually
|
||||
thread_recycles = self._thread_recycles.get(0)
|
||||
self_updates = self._updates.get()
|
||||
if thread_recycles > self_updates:
|
||||
self._updates.set(thread_recycles)
|
||||
|
||||
if self._updates.get() == self._thread_recycles.get(0):
|
||||
# after increment,
|
||||
self._updates.set(self._updates.get() + 1)
|
||||
|
||||
# set the context
|
||||
self._context_var.set(value)
|
||||
@@ -14,7 +14,6 @@ from controllers.console.wraps import account_initialization_required, enterpris
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
@@ -73,9 +72,7 @@ class DatasetListApi(Resource):
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
# convert embedding_model_provider to plugin standard format
|
||||
if item["indexing_technique"] == "high_quality":
|
||||
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item["embedding_available"] = True
|
||||
@@ -623,6 +620,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.RELYT
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
|
||||
@@ -617,7 +617,7 @@ class DocumentDetailApi(DocumentResource):
|
||||
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
|
||||
|
||||
if metadata == "only":
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
@@ -678,7 +678,7 @@ class DocumentDetailApi(DocumentResource):
|
||||
"disabled_by": document.disabled_by,
|
||||
"archived": document.archived,
|
||||
"doc_type": document.doc_type,
|
||||
"doc_metadata": document.doc_metadata_details,
|
||||
"doc_metadata": document.doc_metadata,
|
||||
"segment_count": document.segment_count,
|
||||
"average_segment_length": document.average_segment_length,
|
||||
"hit_count": document.hit_count,
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
from flask_login import current_user # type: ignore # type: ignore
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class DatasetListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
def post(self, dataset_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataArgs(**args)
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
|
||||
return metadata, 201
|
||||
|
||||
|
||||
class DatasetMetadataApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
|
||||
return metadata, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, dataset_id, metadata_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
|
||||
return 200
|
||||
|
||||
|
||||
class DatasetMetadataBuiltInFieldApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
built_in_fields = MetadataService.get_built_in_fields()
|
||||
return built_in_fields, 200
|
||||
|
||||
|
||||
class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, dataset_id, action):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
if action == "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
elif action == "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return 200
|
||||
|
||||
|
||||
class DocumentMetadataApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataOperationData(**args)
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
return 200
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, "/datasets/<uuid:dataset_id>/metadata")
|
||||
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
|
||||
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/metadata/built-in/<string:action>")
|
||||
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
@@ -1,5 +1,3 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from werkzeug.exceptions import NotFound
|
||||
@@ -73,8 +71,7 @@ class FilePreviewApi(Resource):
|
||||
if upload_file.size > 0:
|
||||
response.headers["Content-Length"] = str(upload_file.size)
|
||||
if args["as_attachment"]:
|
||||
encoded_filename = quote(upload_file.name)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Disposition"] = f"attachment; filename={upload_file.name}"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@@ -50,8 +50,8 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
||||
"plan": tenant.plan,
|
||||
"status": tenant.status,
|
||||
"custom_config": json.loads(tenant.custom_config) if tenant.custom_config else {},
|
||||
"created_at": tenant.created_at.isoformat() + "Z" if tenant.created_at else None,
|
||||
"updated_at": tenant.updated_at.isoformat() + "Z" if tenant.updated_at else None,
|
||||
"created_at": tenant.created_at.isoformat() if tenant.created_at else None,
|
||||
"updated_at": tenant.updated_at.isoformat() if tenant.updated_at else None,
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -10,7 +10,6 @@ from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from fields.message_fields import feedback_fields, retriever_resource_fields
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
@@ -19,6 +18,26 @@ from services.message_service import MessageService
|
||||
|
||||
|
||||
class MessageListApi(Resource):
|
||||
feedback_fields = {"rating": fields.String}
|
||||
retriever_resource_fields = {
|
||||
"id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"dataset_id": fields.String,
|
||||
"dataset_name": fields.String,
|
||||
"document_id": fields.String,
|
||||
"document_name": fields.String,
|
||||
"data_source_type": fields.String,
|
||||
"segment_id": fields.String,
|
||||
"score": fields.Float,
|
||||
"hit_count": fields.Integer,
|
||||
"word_count": fields.Integer,
|
||||
"segment_position": fields.Integer,
|
||||
"index_node_hash": fields.String,
|
||||
"content": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
agent_thought_fields = {
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
@@ -70,7 +89,7 @@ class MessageListApi(Resource):
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@@ -336,10 +336,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset is not exist.")
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if "file" in request.files:
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
|
||||
@@ -154,7 +154,7 @@ def validate_dataset_token(view=None):
|
||||
) # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.query(Account).filter(Account.id == ta.account_id).first()
|
||||
account = Account.query.filter_by(id=ta.account_id).first()
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
|
||||
@@ -21,7 +21,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
|
||||
from fields.message_fields import agent_thought_fields
|
||||
from fields.raws import FilesContainedField
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
@@ -34,6 +34,27 @@ from services.message_service import MessageService
|
||||
|
||||
|
||||
class MessageListApi(WebApiResource):
|
||||
feedback_fields = {"rating": fields.String}
|
||||
|
||||
retriever_resource_fields = {
|
||||
"id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"dataset_id": fields.String,
|
||||
"dataset_name": fields.String,
|
||||
"document_id": fields.String,
|
||||
"document_name": fields.String,
|
||||
"data_source_type": fields.String,
|
||||
"segment_id": fields.String,
|
||||
"score": fields.Float,
|
||||
"hit_count": fields.Integer,
|
||||
"word_count": fields.Integer,
|
||||
"segment_position": fields.Integer,
|
||||
"index_node_hash": fields.String,
|
||||
"content": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
|
||||
@@ -329,7 +329,6 @@ class BaseAgentRunner(AppRunner):
|
||||
)
|
||||
if not updated_agent_thought:
|
||||
raise ValueError("agent thought not found")
|
||||
agent_thought = updated_agent_thought
|
||||
|
||||
if thought:
|
||||
agent_thought.thought = thought
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
|
||||
@@ -14,7 +14,7 @@ class AgentToolEntity(BaseModel):
|
||||
provider_type: ToolProviderType
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
tool_parameters: dict[str, Any] = {}
|
||||
plugin_unique_identifier: str | None = None
|
||||
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
|
||||
@@ -61,7 +61,9 @@ class ModelConfigManager:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
if "/" not in config["model"]["provider"]:
|
||||
config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"]))
|
||||
config["model"]["provider"] = (
|
||||
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
|
||||
)
|
||||
|
||||
if config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
@@ -17,8 +17,8 @@ class ModelConfigEntity(BaseModel):
|
||||
provider: str
|
||||
model: str
|
||||
mode: Optional[str] = None
|
||||
parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
stop: list[str] = Field(default_factory=list)
|
||||
parameters: dict[str, Any] = {}
|
||||
stop: list[str] = []
|
||||
|
||||
|
||||
class AdvancedChatMessageEntity(BaseModel):
|
||||
@@ -132,7 +132,7 @@ class ExternalDataVariableEntity(BaseModel):
|
||||
|
||||
variable: str
|
||||
type: str
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
|
||||
class DatasetRetrieveConfigEntity(BaseModel):
|
||||
@@ -188,7 +188,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
|
||||
"""
|
||||
|
||||
type: str
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
|
||||
class TextToSpeechEntity(BaseModel):
|
||||
|
||||
@@ -140,7 +140,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_config=app_config,
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=self._prepare_user_inputs(
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
|
||||
@@ -149,7 +149,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=self._prepare_user_inputs(
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
|
||||
@@ -141,7 +141,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=self._prepare_user_inputs(
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(
|
||||
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
|
||||
),
|
||||
query=query,
|
||||
|
||||
@@ -42,6 +42,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
],
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
|
||||
@@ -63,9 +63,9 @@ class ModelConfigWithCredentialsEntity(BaseModel):
|
||||
model_schema: AIModelEntity
|
||||
mode: str
|
||||
provider_model_bundle: ProviderModelBundle
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
stop: list[str] = Field(default_factory=list)
|
||||
credentials: dict[str, Any] = {}
|
||||
parameters: dict[str, Any] = {}
|
||||
stop: list[str] = []
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
@@ -94,7 +94,7 @@ class AppGenerateEntity(BaseModel):
|
||||
call_depth: int = 0
|
||||
|
||||
# extra parameters, like: auto_generate_conversation_name
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
extras: dict[str, Any] = {}
|
||||
|
||||
# tracing instance
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
|
||||
@@ -844,7 +844,7 @@ class WorkflowCycleManage:
|
||||
if node_execution_id not in self._workflow_node_executions:
|
||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
||||
return session.merge(cached_workflow_node_execution)
|
||||
return cached_workflow_node_execution
|
||||
|
||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||
"""
|
||||
|
||||
@@ -6,10 +6,10 @@ from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import or_
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
@@ -28,7 +28,6 @@ from core.model_runtime.entities.provider_entities import (
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from extensions.ext_database import db
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
@@ -191,11 +190,8 @@ class ProviderConfiguration(BaseModel):
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
or_(
|
||||
Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -283,10 +279,7 @@ class ProviderConfiguration(BaseModel):
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
or_(
|
||||
Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
),
|
||||
Provider.provider_name == self.provider.provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
)
|
||||
.first()
|
||||
@@ -1003,7 +996,7 @@ class ProviderConfigurations(BaseModel):
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
|
||||
configurations: dict[str, ProviderConfiguration] = {}
|
||||
|
||||
def __init__(self, tenant_id: str):
|
||||
super().__init__(tenant_id=tenant_id)
|
||||
@@ -1059,7 +1052,7 @@ class ProviderConfigurations(BaseModel):
|
||||
|
||||
def __getitem__(self, key):
|
||||
if "/" not in key:
|
||||
key = str(ModelProviderID(key))
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
|
||||
return self.configurations[key]
|
||||
|
||||
@@ -1074,7 +1067,7 @@ class ProviderConfigurations(BaseModel):
|
||||
|
||||
def get(self, key, default=None) -> ProviderConfiguration | None:
|
||||
if "/" not in key:
|
||||
key = str(ModelProviderID(key))
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
|
||||
return self.configurations.get(key, default) # type: ignore
|
||||
|
||||
|
||||
@@ -41,13 +41,9 @@ class HostedModerationConfig(BaseModel):
|
||||
|
||||
|
||||
class HostingConfiguration:
|
||||
provider_map: dict[str, HostingProvider]
|
||||
provider_map: dict[str, HostingProvider] = {}
|
||||
moderation_config: Optional[HostedModerationConfig] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.provider_map = {}
|
||||
self.moderation_config = None
|
||||
|
||||
def init_app(self, app: Flask) -> None:
|
||||
if dify_config.EDITION != "CLOUD":
|
||||
return
|
||||
|
||||
@@ -228,7 +228,7 @@ class LargeLanguageModel(AIModel):
|
||||
:return: result generator
|
||||
"""
|
||||
callbacks = callbacks or []
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
prompt_message = AssistantPromptMessage(content="")
|
||||
usage = None
|
||||
system_fingerprint = None
|
||||
real_model = model
|
||||
@@ -250,7 +250,7 @@ class LargeLanguageModel(AIModel):
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
assistant_message.content += chunk.delta.message.content
|
||||
prompt_message.content += chunk.delta.message.content
|
||||
real_model = chunk.model
|
||||
if chunk.delta.usage:
|
||||
usage = chunk.delta.usage
|
||||
@@ -265,7 +265,7 @@ class LargeLanguageModel(AIModel):
|
||||
result=LLMResult(
|
||||
model=real_model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_message,
|
||||
message=prompt_message,
|
||||
usage=usage or LLMUsage.empty_usage(),
|
||||
system_fingerprint=system_fingerprint,
|
||||
),
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
import contexts
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||
@@ -33,11 +34,9 @@ class ModelProviderExtension(BaseModel):
|
||||
|
||||
|
||||
class ModelProviderFactory:
|
||||
provider_position_map: dict[str, int]
|
||||
provider_position_map: dict[str, int] = {}
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
self.provider_position_map = {}
|
||||
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_model_manager = PluginModelManager()
|
||||
|
||||
@@ -361,5 +360,11 @@ class ModelProviderFactory:
|
||||
:param provider: provider name
|
||||
:return: plugin id and provider name
|
||||
"""
|
||||
provider_id = ModelProviderID(provider)
|
||||
return provider_id.plugin_id, provider_id.provider_name
|
||||
plugin_id = DEFAULT_PLUGIN_ID
|
||||
provider_name = provider
|
||||
if "/" in provider:
|
||||
# get the plugin_id before provider
|
||||
plugin_id = "/".join(provider.split("/")[:-1])
|
||||
provider_name = provider.split("/")[-1]
|
||||
|
||||
return str(plugin_id), provider_name
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
- claude-3-haiku@20240307
|
||||
- claude-3-opus@20240229
|
||||
- claude-3-sonnet@20240229
|
||||
- claude-3-5-sonnet-v2@20241022
|
||||
- claude-3-5-sonnet@20240620
|
||||
- gemini-1.0-pro-vision-001
|
||||
- gemini-1.0-pro-002
|
||||
- gemini-1.5-flash-001
|
||||
- gemini-1.5-flash-002
|
||||
- gemini-1.5-pro-001
|
||||
- gemini-1.5-pro-002
|
||||
- gemini-2.0-flash-001
|
||||
- gemini-2.0-flash-exp
|
||||
- gemini-2.0-flash-lite-preview-02-05
|
||||
- gemini-2.0-flash-thinking-exp-01-21
|
||||
- gemini-2.0-flash-thinking-exp-1219
|
||||
- gemini-2.0-pro-exp-02-05
|
||||
- gemini-exp-1114
|
||||
- gemini-exp-1121
|
||||
- gemini-exp-1206
|
||||
- gemini-flash-experimental
|
||||
- gemini-pro-experimental
|
||||
@@ -159,7 +159,7 @@ class GenericProviderID:
|
||||
if re.match(r"^[a-z0-9_-]+$", value):
|
||||
value = f"langgenius/{value}/{value}"
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin id {value}")
|
||||
raise ValueError("Invalid plugin id")
|
||||
|
||||
self.organization, self.plugin_name, self.provider_name = value.split("/")
|
||||
self.is_hardcoded = is_hardcoded
|
||||
@@ -173,15 +173,15 @@ class ModelProviderID(GenericProviderID):
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
super().__init__(value, is_hardcoded)
|
||||
if self.organization == "langgenius" and self.provider_name == "google":
|
||||
self.plugin_name = "gemini"
|
||||
self.provider_name = "gemini"
|
||||
|
||||
|
||||
class ToolProviderID(GenericProviderID):
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
super().__init__(value, is_hardcoded)
|
||||
if self.organization == "langgenius":
|
||||
if self.provider_name in ["jina", "siliconflow", "stepfun"]:
|
||||
self.plugin_name = f"{self.provider_name}_tool"
|
||||
if self.provider_name in ["jina", "siliconflow"]:
|
||||
self.provider_name = f"{self.provider_name}_tool"
|
||||
|
||||
|
||||
class PluginDependency(BaseModel):
|
||||
@@ -212,9 +212,3 @@ class PluginDependency(BaseModel):
|
||||
|
||||
type: Type
|
||||
value: Github | Marketplace | Package
|
||||
current_identifier: Optional[str] = None
|
||||
|
||||
|
||||
class MissingPluginDependency(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
current_identifier: Optional[str] = None
|
||||
|
||||
@@ -3,7 +3,6 @@ from collections.abc import Sequence
|
||||
from core.plugin.entities.bundle import PluginBundleDependency
|
||||
from core.plugin.entities.plugin import (
|
||||
GenericProviderID,
|
||||
MissingPluginDependency,
|
||||
PluginDeclaration,
|
||||
PluginEntity,
|
||||
PluginInstallation,
|
||||
@@ -176,16 +175,14 @@ class PluginInstallationManager(BasePluginManager):
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def fetch_missing_dependencies(
|
||||
self, tenant_id: str, plugin_unique_identifiers: list[str]
|
||||
) -> list[MissingPluginDependency]:
|
||||
def fetch_missing_dependencies(self, tenant_id: str, plugin_unique_identifiers: list[str]) -> list[str]:
|
||||
"""
|
||||
Fetch missing dependencies
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/installation/missing",
|
||||
list[MissingPluginDependency],
|
||||
list[str],
|
||||
data={"plugin_unique_identifiers": plugin_unique_identifiers},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
@@ -45,7 +45,7 @@ class PluginToolManager(BasePluginManager):
|
||||
"""
|
||||
Fetch tool provider for the given tenant and plugin.
|
||||
"""
|
||||
tool_provider_id = ToolProviderID(provider)
|
||||
tool_provider_id = GenericProviderID(provider)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
data = json_response.get("data")
|
||||
|
||||
@@ -30,7 +30,6 @@ from core.model_runtime.entities.provider_entities import (
|
||||
ProviderEntity,
|
||||
)
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from extensions import ext_hosting_provider
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@@ -100,15 +99,6 @@ class ProviderManager:
|
||||
tenant_id, provider_name_to_provider_records_dict
|
||||
)
|
||||
|
||||
# append providers with langgenius/openai/openai
|
||||
provider_name_list = list(provider_name_to_provider_records_dict.keys())
|
||||
for provider_name in provider_name_list:
|
||||
provider_id = ModelProviderID(provider_name)
|
||||
if str(provider_id) not in provider_name_list:
|
||||
provider_name_to_provider_records_dict[str(provider_id)] = provider_name_to_provider_records_dict[
|
||||
provider_name
|
||||
]
|
||||
|
||||
# Get all provider model records of the workspace
|
||||
provider_name_to_provider_model_records_dict = self._get_all_provider_models(tenant_id)
|
||||
|
||||
@@ -201,7 +191,7 @@ class ProviderManager:
|
||||
model_settings=model_settings,
|
||||
)
|
||||
|
||||
provider_configurations[str(ModelProviderID(provider_name))] = provider_configuration
|
||||
provider_configurations[provider_name] = provider_configuration
|
||||
|
||||
# Return the encapsulated object
|
||||
return provider_configurations
|
||||
@@ -369,8 +359,7 @@ class ProviderManager:
|
||||
|
||||
provider_name_to_provider_records_dict = defaultdict(list)
|
||||
for provider in providers:
|
||||
# TODO: Use provider name with prefix after the data migration
|
||||
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
|
||||
provider_name_to_provider_records_dict[provider.provider_name].append(provider)
|
||||
|
||||
return provider_name_to_provider_records_dict
|
||||
|
||||
@@ -464,9 +453,11 @@ class ProviderManager:
|
||||
|
||||
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
|
||||
for provider_load_balancing_config in provider_load_balancing_configs:
|
||||
provider_name_to_provider_load_balancing_model_configs_dict[
|
||||
provider_load_balancing_config.provider_name
|
||||
].append(provider_load_balancing_config)
|
||||
(
|
||||
provider_name_to_provider_load_balancing_model_configs_dict[
|
||||
provider_load_balancing_config.provider_name
|
||||
].append(provider_load_balancing_config)
|
||||
)
|
||||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
@@ -509,8 +500,7 @@ class ProviderManager:
|
||||
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
|
||||
provider_record = Provider(
|
||||
tenant_id=tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
provider_name=ModelProviderID(provider_name).provider_name,
|
||||
provider_name=provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=ProviderQuotaType.TRIAL.value,
|
||||
quota_limit=quota.quota_limit, # type: ignore
|
||||
@@ -525,12 +515,13 @@ class ProviderManager:
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == ModelProviderID(provider_name).provider_name,
|
||||
Provider.provider_name == provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider_record and not provider_record.is_valid:
|
||||
provider_record.is_valid = True
|
||||
db.session.commit()
|
||||
|
||||
@@ -88,17 +88,16 @@ class Jieba(BaseKeyword):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
|
||||
k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
|
||||
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
|
||||
|
||||
documents = []
|
||||
for chunk_index in sorted_chunk_indices:
|
||||
segment_query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
|
||||
.first()
|
||||
)
|
||||
if document_ids_filter:
|
||||
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
|
||||
segment = segment_query.first()
|
||||
|
||||
if segment:
|
||||
documents.append(
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import concurrent.futures
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy.orm import load_only
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
@@ -30,7 +26,6 @@ default_retrieval_model = {
|
||||
|
||||
|
||||
class RetrievalService:
|
||||
# Cache precompiled regular expressions to avoid repeated compilation
|
||||
@classmethod
|
||||
def retrieve(
|
||||
cls,
|
||||
@@ -42,68 +37,77 @@ class RetrievalService:
|
||||
reranking_model: Optional[dict] = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: Optional[dict] = None,
|
||||
document_ids_filter: Optional[list[str]] = None,
|
||||
):
|
||||
if not query:
|
||||
return []
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return []
|
||||
all_documents: list[Document] = []
|
||||
threads: list[threading.Thread] = []
|
||||
exceptions: list[str] = []
|
||||
# retrieval_model source with keyword
|
||||
if retrieval_method == "keyword_search":
|
||||
keyword_thread = threading.Thread(
|
||||
target=RetrievalService.keyword_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents,
|
||||
"exceptions": exceptions,
|
||||
},
|
||||
)
|
||||
threads.append(keyword_thread)
|
||||
keyword_thread.start()
|
||||
# retrieval_model source with semantic
|
||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||
embedding_thread = threading.Thread(
|
||||
target=RetrievalService.embedding_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"reranking_model": reranking_model,
|
||||
"all_documents": all_documents,
|
||||
"retrieval_method": retrieval_method,
|
||||
"exceptions": exceptions,
|
||||
},
|
||||
)
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# Optimize multithreading with thread pools
|
||||
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
|
||||
futures = []
|
||||
if retrieval_method == "keyword_search":
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.keyword_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
all_documents=all_documents,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
)
|
||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.embedding_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
all_documents=all_documents,
|
||||
retrieval_method=retrieval_method,
|
||||
exceptions=exceptions,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
)
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.full_text_index_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
all_documents=all_documents,
|
||||
retrieval_method=retrieval_method,
|
||||
exceptions=exceptions,
|
||||
)
|
||||
)
|
||||
concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
|
||||
# retrieval source with full text
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
||||
full_text_index_thread = threading.Thread(
|
||||
target=RetrievalService.full_text_index_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"retrieval_method": retrieval_method,
|
||||
"score_threshold": score_threshold,
|
||||
"top_k": top_k,
|
||||
"reranking_model": reranking_model,
|
||||
"all_documents": all_documents,
|
||||
"exceptions": exceptions,
|
||||
},
|
||||
)
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if exceptions:
|
||||
raise ValueError(";\n".join(exceptions))
|
||||
exception_message = ";\n".join(exceptions)
|
||||
raise ValueError(exception_message)
|
||||
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(
|
||||
@@ -128,32 +132,19 @@ class RetrievalService:
|
||||
)
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
|
||||
return db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
@classmethod
|
||||
def keyword_search(
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
all_documents: list,
|
||||
exceptions: list,
|
||||
document_ids_filter: Optional[list[str]] = None,
|
||||
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError("dataset not found")
|
||||
|
||||
keyword = Keyword(dataset=dataset)
|
||||
|
||||
documents = keyword.search(
|
||||
cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter
|
||||
)
|
||||
documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
|
||||
all_documents.extend(documents)
|
||||
except Exception as e:
|
||||
exceptions.append(str(e))
|
||||
@@ -170,22 +161,21 @@ class RetrievalService:
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
document_ids_filter: Optional[list[str]] = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError("dataset not found")
|
||||
|
||||
vector = Vector(dataset=dataset)
|
||||
|
||||
documents = vector.search_by_vector(
|
||||
query,
|
||||
cls.escape_query_for_search(query),
|
||||
search_type="similarity_score_threshold",
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
filter={"group_id": [dataset.id]},
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
|
||||
if documents:
|
||||
@@ -196,7 +186,7 @@ class RetrievalService:
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
|
||||
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
|
||||
)
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
@@ -226,11 +216,13 @@ class RetrievalService:
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError("dataset not found")
|
||||
|
||||
vector_processor = Vector(dataset=dataset)
|
||||
vector_processor = Vector(
|
||||
dataset=dataset,
|
||||
)
|
||||
|
||||
documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
|
||||
if documents:
|
||||
@@ -241,7 +233,7 @@ class RetrievalService:
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
|
||||
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
|
||||
)
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
@@ -258,106 +250,66 @@ class RetrievalService:
|
||||
|
||||
@staticmethod
|
||||
def escape_query_for_search(query: str) -> str:
|
||||
return json.dumps(query).strip('"')
|
||||
|
||||
@classmethod
|
||||
def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]:
|
||||
"""Format retrieval documents with optimized batch processing"""
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Collect document IDs
|
||||
document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata}
|
||||
if not document_ids:
|
||||
return []
|
||||
|
||||
# Batch query dataset documents
|
||||
dataset_documents = {
|
||||
doc.id: doc
|
||||
for doc in db.session.query(DatasetDocument)
|
||||
.filter(DatasetDocument.id.in_(document_ids))
|
||||
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
|
||||
.all()
|
||||
}
|
||||
|
||||
records = []
|
||||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
|
||||
# Process documents
|
||||
for document in documents:
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
continue
|
||||
|
||||
dataset_document = dataset_documents[document_id]
|
||||
return query.replace('"', '\\"')
|
||||
|
||||
@staticmethod
|
||||
def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
|
||||
records = []
|
||||
include_segment_ids = []
|
||||
segment_child_map = {}
|
||||
for document in documents:
|
||||
document_id = document.metadata.get("document_id")
|
||||
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
|
||||
if dataset_document:
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
# Handle parent-child documents
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
|
||||
)
|
||||
|
||||
if not child_chunk:
|
||||
continue
|
||||
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
result = (
|
||||
db.session.query(ChildChunk, DocumentSegment)
|
||||
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
|
||||
.filter(
|
||||
ChildChunk.index_node_id == child_index_node_id,
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == child_chunk.segment_id,
|
||||
)
|
||||
.options(
|
||||
load_only(
|
||||
DocumentSegment.id,
|
||||
DocumentSegment.content,
|
||||
DocumentSegment.answer,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
if result:
|
||||
child_chunk, segment = result
|
||||
if not segment:
|
||||
continue
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.append(segment.id)
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
else:
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||
)
|
||||
else:
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||
)
|
||||
else:
|
||||
# Handle normal documents
|
||||
index_node_id = document.metadata.get("doc_id")
|
||||
if not index_node_id:
|
||||
continue
|
||||
else:
|
||||
index_node_id = document.metadata["doc_id"]
|
||||
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
@@ -372,21 +324,16 @@ class RetrievalService:
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
|
||||
include_segment_ids.add(segment.id)
|
||||
include_segment_ids.append(segment.id)
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get("score"), # type: ignore
|
||||
"score": document.metadata.get("score", None),
|
||||
}
|
||||
records.append(record)
|
||||
|
||||
# Add child chunks information to records
|
||||
records.append(record)
|
||||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||
|
||||
return [RetrievalSegments(**record) for record in records]
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
raise e
|
||||
return [RetrievalSegments(**record) for record in records]
|
||||
|
||||
@@ -53,7 +53,7 @@ class AnalyticdbVector(BaseVector):
|
||||
self.analyticdb_vector.delete_by_metadata_field(key, value)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)
|
||||
return self.analyticdb_vector.search_by_vector(query_vector)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
|
||||
|
||||
@@ -194,11 +194,6 @@ class AnalyticdbVectorBySql:
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = "WHERE 1=1"
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
with self._get_cursor() as cur:
|
||||
query_vector_str = json.dumps(query_vector)
|
||||
@@ -207,7 +202,7 @@ class AnalyticdbVectorBySql:
|
||||
f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
|
||||
f"t.page_content as page_content, t.metadata_ AS metadata_ "
|
||||
f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
|
||||
f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t",
|
||||
f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
|
||||
(query_vector_str,),
|
||||
)
|
||||
documents = []
|
||||
@@ -225,17 +220,12 @@ class AnalyticdbVectorBySql:
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT id, vector, page_content, metadata_,
|
||||
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
|
||||
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
(f"'{query}'", f"'{query}'"),
|
||||
|
||||
@@ -123,21 +123,11 @@ class BaiduVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
filter=f"document_id IN ({document_ids})",
|
||||
)
|
||||
else:
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
)
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||
)
|
||||
res = self._db.table(self._collection_name).search(
|
||||
anns=anns,
|
||||
projections=[self.field_id, self.field_text, self.field_metadata],
|
||||
|
||||
@@ -95,15 +95,7 @@ class ChromaVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
results: QueryResult = collection.query(
|
||||
query_embeddings=query_vector,
|
||||
n_results=kwargs.get("top_k", 4),
|
||||
where={"document_id": {"$in": document_ids_filter}},
|
||||
)
|
||||
else:
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
# Check if results contain data
|
||||
@@ -119,9 +111,8 @@ class ChromaVector(BaseVector):
|
||||
for index in range(len(ids)):
|
||||
distance = distances[index]
|
||||
metadata = dict(metadatas[index])
|
||||
score = 1 - distance
|
||||
if score > score_threshold:
|
||||
metadata["score"] = score
|
||||
if distance >= score_threshold:
|
||||
metadata["score"] = distance
|
||||
doc = Document(
|
||||
page_content=documents[index],
|
||||
metadata=metadata,
|
||||
|
||||
@@ -117,9 +117,6 @@ class ElasticSearchVector(BaseVector):
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
num_candidates = math.ceil(top_k * 1.5)
|
||||
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||
|
||||
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
|
||||
|
||||
@@ -148,9 +145,6 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
query_str = {"match": {Field.CONTENT_KEY.value: query}}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
|
||||
docs = []
|
||||
for hit in results["hits"]["hits"]:
|
||||
|
||||
@@ -168,12 +168,7 @@ class LindormVectorStore(BaseVector):
|
||||
raise ValueError("All elements in query_vector should be floats")
|
||||
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filters = []
|
||||
if document_ids_filter:
|
||||
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
|
||||
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
|
||||
|
||||
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
|
||||
try:
|
||||
params = {}
|
||||
if self._using_ugc:
|
||||
@@ -211,10 +206,7 @@ class LindormVectorStore(BaseVector):
|
||||
should = kwargs.get("should")
|
||||
minimum_should_match = kwargs.get("minimum_should_match", 0)
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
filters = kwargs.get("filter", [])
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
|
||||
filters = kwargs.get("filter")
|
||||
routing = self._routing
|
||||
full_text_query = default_text_search_query(
|
||||
query_text=query,
|
||||
|
||||
@@ -218,18 +218,12 @@ class MilvusVector(BaseVector):
|
||||
"""
|
||||
Search for documents by vector similarity.
|
||||
"""
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f'metadata["document_id"] in ({document_ids})'
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
anns_field=Field.VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
@@ -245,11 +239,6 @@ class MilvusVector(BaseVector):
|
||||
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
|
||||
return []
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f'metadata["document_id"] in ({document_ids})'
|
||||
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
@@ -257,7 +246,6 @@ class MilvusVector(BaseVector):
|
||||
anns_field=Field.SPARSE_VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
|
||||
@@ -131,10 +131,6 @@ class MyScaleVector(BaseVector):
|
||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
||||
else ""
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
|
||||
sql = f"""
|
||||
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||
|
||||
@@ -154,11 +154,6 @@ class OceanBaseVector(BaseVector):
|
||||
return []
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = None
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"metadata->>'$.document_id' in ({document_ids})"
|
||||
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
|
||||
if ef_search != self._hnsw_ef_search:
|
||||
self._client.set_ob_hnsw_ef_search(ef_search)
|
||||
@@ -172,7 +167,6 @@ class OceanBaseVector(BaseVector):
|
||||
distance_func=func.l2_distance,
|
||||
output_column_names=["text", "metadata"],
|
||||
with_dist=True,
|
||||
where_clause=where_clause,
|
||||
)
|
||||
docs = []
|
||||
for text, metadata, distance in cur:
|
||||
|
||||
@@ -154,9 +154,6 @@ class OpenSearchVector(BaseVector):
|
||||
"size": kwargs.get("top_k", 4),
|
||||
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
|
||||
}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||
|
||||
try:
|
||||
response = self._client.search(index=self._collection_name.lower(), body=query)
|
||||
@@ -182,9 +179,6 @@ class OpenSearchVector(BaseVector):
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}
|
||||
|
||||
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
|
||||
|
||||
|
||||
@@ -185,15 +185,10 @@ class OracleVector(BaseVector):
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
|
||||
f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
|
||||
f" ORDER BY distance fetch first {top_k} rows only",
|
||||
[numpy.array(query_vector)],
|
||||
)
|
||||
docs = []
|
||||
@@ -246,15 +241,9 @@ class OracleVector(BaseVector):
|
||||
if token not in stop_words:
|
||||
entities.append(token)
|
||||
with self._get_cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
cur.execute(
|
||||
f"select meta, text, embedding FROM {self.table_name}"
|
||||
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
|
||||
f"order by score(1) desc fetch first {top_k} rows only",
|
||||
f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
|
||||
[" ACCUM ".join(entities)],
|
||||
)
|
||||
docs = []
|
||||
|
||||
@@ -189,9 +189,6 @@ class PGVectoRS(BaseVector):
|
||||
.limit(kwargs.get("top_k", 4))
|
||||
.order_by("distance")
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter))
|
||||
res = session.execute(stmt)
|
||||
results = [(row[0], row[1]) for row in res]
|
||||
|
||||
|
||||
@@ -155,16 +155,10 @@ class PGVector(BaseVector):
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) "
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
|
||||
f" {where_clause}"
|
||||
f" ORDER BY distance LIMIT {top_k}",
|
||||
(json.dumps(query_vector),),
|
||||
)
|
||||
@@ -182,16 +176,10 @@ class PGVector(BaseVector):
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
|
||||
{where_clause}
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
# f"'{query}'" is required in order to account for whitespace in query
|
||||
|
||||
@@ -286,26 +286,27 @@ class QdrantVector(BaseVector):
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchAny(any=ids),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
for node_id in ids:
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
all_collection_name = []
|
||||
@@ -330,14 +331,6 @@ class QdrantVector(BaseVector):
|
||||
),
|
||||
],
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
@@ -384,14 +377,6 @@ class QdrantVector(BaseVector):
|
||||
),
|
||||
]
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
scroll_filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
|
||||
@@ -223,12 +223,8 @@ class RelytVector(BaseVector):
|
||||
return len(result) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = kwargs.get("filter", {})
|
||||
if document_ids_filter:
|
||||
filter["document_id"] = document_ids_filter
|
||||
results = self.similarity_search_with_score_by_vector(
|
||||
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter
|
||||
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter")
|
||||
)
|
||||
|
||||
# Organize results.
|
||||
@@ -250,9 +246,9 @@ class RelytVector(BaseVector):
|
||||
filter_condition = ""
|
||||
if filter is not None:
|
||||
conditions = [
|
||||
f"metadata->>'{key!r}' in ({', '.join(map(repr, value))})"
|
||||
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
|
||||
if len(value) > 1
|
||||
else f"metadata->>'{key!r}' = {value[0]!r}"
|
||||
else f"metadata->>{key!r} = {value[0]!r}"
|
||||
for key, value in filter.items()
|
||||
]
|
||||
filter_condition = f"WHERE {' AND '.join(conditions)}"
|
||||
|
||||
@@ -145,16 +145,11 @@ class TencentVector(BaseVector):
|
||||
self._db.collection(self._collection_name).delete(document_ids=ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value])))
|
||||
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = None
|
||||
if document_ids_filter:
|
||||
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
|
||||
res = self._db.collection(self._collection_name).search(
|
||||
vectors=[query_vector],
|
||||
filter=filter,
|
||||
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
|
||||
retrieve_vector=False,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
|
||||
@@ -326,14 +326,6 @@ class TidbOnQdrantVector(BaseVector):
|
||||
),
|
||||
],
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
@@ -376,14 +368,6 @@ class TidbOnQdrantVector(BaseVector):
|
||||
)
|
||||
]
|
||||
)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
scroll_filter.must.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.document_id",
|
||||
match=models.MatchAny(any=document_ids_filter),
|
||||
)
|
||||
)
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
|
||||
@@ -9,7 +9,6 @@ from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.orm import Session, declarative_base
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@@ -55,13 +54,14 @@ class TiDBVector(BaseVector):
|
||||
return Table(
|
||||
self._collection_name,
|
||||
self._orm_base.metadata,
|
||||
Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False),
|
||||
Column("id", String(36), primary_key=True, nullable=False),
|
||||
Column(
|
||||
Field.VECTOR.value,
|
||||
"vector",
|
||||
VectorType(dim),
|
||||
nullable=False,
|
||||
comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
|
||||
),
|
||||
Column(Field.TEXT_KEY.value, TEXT, nullable=False),
|
||||
Column("text", TEXT, nullable=False),
|
||||
Column("meta", JSON, nullable=False),
|
||||
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
|
||||
Column(
|
||||
@@ -96,7 +96,6 @@ class TiDBVector(BaseVector):
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
tidb_dist_func = self._get_distance_func()
|
||||
with Session(self._engine) as session:
|
||||
session.begin()
|
||||
create_statement = sql_text(f"""
|
||||
@@ -105,14 +104,14 @@ class TiDBVector(BaseVector):
|
||||
text TEXT NOT NULL,
|
||||
meta JSON NOT NULL,
|
||||
doc_id VARCHAR(64) AS (JSON_UNQUOTE(JSON_EXTRACT(meta, '$.doc_id'))) STORED,
|
||||
vector VECTOR<FLOAT>({dimension}) NOT NULL,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
KEY (doc_id),
|
||||
VECTOR INDEX idx_vector (({tidb_dist_func}(vector))) USING HNSW
|
||||
vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
);
|
||||
""")
|
||||
session.execute(create_statement)
|
||||
# tidb vector not support 'CREATE/ADD INDEX' now
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
@@ -195,36 +194,23 @@ class TiDBVector(BaseVector):
|
||||
)
|
||||
|
||||
docs = []
|
||||
tidb_dist_func = self._get_distance_func()
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) "
|
||||
if self._distance_func == "l2":
|
||||
tidb_func = "Vec_l2_distance"
|
||||
elif self._distance_func == "cosine":
|
||||
tidb_func = "Vec_Cosine_distance"
|
||||
else:
|
||||
tidb_func = "Vec_Cosine_distance"
|
||||
|
||||
with Session(self._engine) as session:
|
||||
select_statement = sql_text(f"""
|
||||
SELECT meta, text, distance
|
||||
FROM (
|
||||
SELECT
|
||||
meta,
|
||||
text,
|
||||
{tidb_dist_func}(vector, :query_vector_str) AS distance
|
||||
FROM {self._collection_name}
|
||||
{where_clause}
|
||||
ORDER BY distance ASC
|
||||
LIMIT :top_k
|
||||
) t
|
||||
WHERE distance <= :distance
|
||||
""")
|
||||
res = session.execute(
|
||||
select_statement,
|
||||
params={
|
||||
"query_vector_str": query_vector_str,
|
||||
"distance": distance,
|
||||
"top_k": top_k,
|
||||
},
|
||||
select_statement = sql_text(
|
||||
f"""SELECT meta, text, distance FROM (
|
||||
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
|
||||
FROM {self._collection_name}
|
||||
ORDER BY distance
|
||||
LIMIT {top_k}
|
||||
) t WHERE distance < {distance};"""
|
||||
)
|
||||
res = session.execute(select_statement)
|
||||
results = [(row[0], row[1], row[2]) for row in res]
|
||||
for meta, text, distance in results:
|
||||
metadata = json.loads(meta)
|
||||
@@ -241,16 +227,6 @@ class TiDBVector(BaseVector):
|
||||
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
|
||||
session.commit()
|
||||
|
||||
def _get_distance_func(self) -> str:
|
||||
match self._distance_func:
|
||||
case "l2":
|
||||
tidb_dist_func = "VEC_L2_DISTANCE"
|
||||
case "cosine":
|
||||
tidb_dist_func = "VEC_COSINE_DISTANCE"
|
||||
case _:
|
||||
tidb_dist_func = "VEC_COSINE_DISTANCE"
|
||||
return tidb_dist_func
|
||||
|
||||
|
||||
class TiDBVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
|
||||
|
||||
@@ -88,20 +88,7 @@ class UpstashVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
filter = f"document_id in ({document_ids})"
|
||||
else:
|
||||
filter = ""
|
||||
result = self.index.query(
|
||||
vector=query_vector,
|
||||
top_k=top_k,
|
||||
include_metadata=True,
|
||||
include_data=True,
|
||||
include_vectors=False,
|
||||
filter=filter,
|
||||
)
|
||||
result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in result:
|
||||
|
||||
@@ -49,10 +49,6 @@ class BaseVector(ABC):
|
||||
def delete(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def update_metadata(self, document_id: str, metadata: dict) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
for text in texts.copy():
|
||||
if text.metadata and "doc_id" in text.metadata:
|
||||
|
||||
@@ -177,11 +177,7 @@ class VikingDBVector(BaseVector):
|
||||
query_vector, limit=kwargs.get("top_k", 4)
|
||||
)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
docs = self._get_search_res(results, score_threshold)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter]
|
||||
return docs
|
||||
return self._get_search_res(results, score_threshold)
|
||||
|
||||
def _get_search_res(self, results, score_threshold) -> list[Document]:
|
||||
if len(results) == 0:
|
||||
|
||||
@@ -168,16 +168,16 @@ class WeaviateVector(BaseVector):
|
||||
# check whether the index already exists
|
||||
schema = self._default_schema(self._collection_name)
|
||||
if self._client.schema.contains(schema):
|
||||
try:
|
||||
self._client.batch.delete_objects(
|
||||
class_name=self._collection_name,
|
||||
where={"operator": "ContainsAny", "path": ["id"], "valueTextArray": ids},
|
||||
output="minimal",
|
||||
)
|
||||
except weaviate.UnexpectedStatusCodeException as e:
|
||||
# tolerate not found error
|
||||
if e.status_code != 404:
|
||||
raise e
|
||||
for uuid in ids:
|
||||
try:
|
||||
self._client.data_object.delete(
|
||||
class_name=self._collection_name,
|
||||
uuid=uuid,
|
||||
)
|
||||
except weaviate.UnexpectedStatusCodeException as e:
|
||||
# tolerate not found error
|
||||
if e.status_code != 404:
|
||||
raise e
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Look up similar documents by embedding vector in Weaviate."""
|
||||
@@ -187,10 +187,8 @@ class WeaviateVector(BaseVector):
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
|
||||
vector = {"vector": query_vector}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
|
||||
query_obj = query_obj.with_where(where_filter)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
result = (
|
||||
query_obj.with_near_vector(vector)
|
||||
.with_limit(kwargs.get("top_k", 4))
|
||||
@@ -235,10 +233,8 @@ class WeaviateVector(BaseVector):
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
|
||||
query_obj = query_obj.with_where(where_filter)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
query_obj = query_obj.with_additional(["vector"])
|
||||
properties = ["text"]
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class BuiltInField(str, Enum):
|
||||
document_name = "document_name"
|
||||
uploader = "uploader"
|
||||
upload_date = "upload_date"
|
||||
last_update_date = "last_update_date"
|
||||
source = "source"
|
||||
@@ -237,7 +237,6 @@ class DatasetRetrieval:
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
message_id: Optional[str] = None,
|
||||
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
|
||||
):
|
||||
tools = []
|
||||
for dataset in available_datasets:
|
||||
@@ -292,11 +291,6 @@ class DatasetRetrieval:
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
else:
|
||||
document_ids_filter = None
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
|
||||
|
||||
# get top k
|
||||
@@ -328,7 +322,6 @@ class DatasetRetrieval:
|
||||
reranking_model=reranking_model,
|
||||
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
|
||||
weights=retrieval_model_config.get("weights", None),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
||||
|
||||
|
||||
@@ -105,10 +105,10 @@ class ApiTool(Tool):
|
||||
needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
|
||||
for parameter in needed_parameters:
|
||||
if parameter.required and parameter.name not in parameters:
|
||||
if parameter.default is not None:
|
||||
parameters[parameter.name] = parameter.default
|
||||
else:
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
|
||||
|
||||
if parameter.default is not None and parameter.name not in parameters:
|
||||
parameters[parameter.name] = parameter.default
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
@@ -246,11 +246,10 @@ class ToolEngine:
|
||||
+ "you do not need to create it, just tell the user to check it now."
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
result = json.dumps(
|
||||
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
|
||||
)
|
||||
text = json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False)
|
||||
result += f"tool response: {text}."
|
||||
else:
|
||||
result += str(response.message)
|
||||
result += f"tool response: {response.message!r}."
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@@ -188,7 +188,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
if isinstance(provider_controller, PluginToolProviderController):
|
||||
provider_id_entity = ToolProviderID(provider_id)
|
||||
provider_id_entity = GenericProviderID(provider_id)
|
||||
# get credentials
|
||||
builtin_provider: BuiltinToolProvider | None = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
@@ -572,96 +572,95 @@ class ToolManager:
|
||||
else:
|
||||
filters.append(typ)
|
||||
|
||||
with db.session.no_autoflush:
|
||||
if "builtin" in filters:
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||
if "builtin" in filters:
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
# rewrite db_builtin_providers
|
||||
for db_provider in db_builtin_providers:
|
||||
tool_provider_id = GenericProviderID(db_provider.provider)
|
||||
db_provider.provider = tool_provider_id.to_string()
|
||||
|
||||
def find_db_builtin_provider(provider):
|
||||
return next((x for x in db_builtin_providers if x.provider == provider), None)
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
||||
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
||||
decrypt_credentials=False,
|
||||
)
|
||||
|
||||
# rewrite db_builtin_providers
|
||||
for db_provider in db_builtin_providers:
|
||||
tool_provider_id = str(ToolProviderID(db_provider.provider))
|
||||
db_provider.provider = tool_provider_id
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
|
||||
else:
|
||||
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
def find_db_builtin_provider(provider):
|
||||
return next((x for x in db_builtin_providers if x.provider == provider), None)
|
||||
# get db api providers
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
||||
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
if "api" in filters:
|
||||
db_api_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
||||
decrypt_credentials=False,
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||
for provider in db_api_providers
|
||||
]
|
||||
|
||||
# get labels
|
||||
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
|
||||
|
||||
for api_provider_controller in api_provider_controllers:
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=api_provider_controller["controller"],
|
||||
db_provider=api_provider_controller["provider"],
|
||||
decrypt_credentials=False,
|
||||
labels=labels.get(api_provider_controller["controller"].provider_id, []),
|
||||
)
|
||||
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
if "workflow" in filters:
|
||||
# get workflow providers
|
||||
workflow_providers: list[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
||||
)
|
||||
except Exception:
|
||||
# app has been deleted
|
||||
pass
|
||||
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
|
||||
else:
|
||||
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
||||
labels = ToolLabelManager.get_tools_labels(
|
||||
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
||||
)
|
||||
|
||||
# get db api providers
|
||||
|
||||
if "api" in filters:
|
||||
db_api_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
labels=labels.get(provider_controller.provider_id, []),
|
||||
)
|
||||
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||
for provider in db_api_providers
|
||||
]
|
||||
|
||||
# get labels
|
||||
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
|
||||
|
||||
for api_provider_controller in api_provider_controllers:
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=api_provider_controller["controller"],
|
||||
db_provider=api_provider_controller["provider"],
|
||||
decrypt_credentials=False,
|
||||
labels=labels.get(api_provider_controller["controller"].provider_id, []),
|
||||
)
|
||||
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
if "workflow" in filters:
|
||||
# get workflow providers
|
||||
workflow_providers: list[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
||||
)
|
||||
except Exception:
|
||||
# app has been deleted
|
||||
pass
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(
|
||||
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
||||
)
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
labels=labels.get(provider_controller.provider_id, []),
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
return BuiltinToolProviderSort.sort(list(result_providers.values()))
|
||||
|
||||
|
||||
@@ -3,13 +3,11 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
@@ -56,6 +54,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
if not dataset:
|
||||
return ""
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_query(query, dataset.id)
|
||||
if dataset.provider == "external":
|
||||
@@ -126,6 +125,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_tool_end(documents)
|
||||
document_score_list = {}
|
||||
@@ -134,46 +134,50 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
if item.metadata is not None and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
document_context_list = []
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
if records:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
index_node_ids = [document.metadata["doc_id"] for document in documents]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id == self.dataset_id,
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
).all()
|
||||
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
sorted_segments = sorted(
|
||||
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
|
||||
)
|
||||
for segment in sorted_segments:
|
||||
if segment.answer:
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
|
||||
score=record.score,
|
||||
)
|
||||
f"question:{segment.get_sign_content()} answer:{segment.answer}"
|
||||
)
|
||||
else:
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=segment.get_sign_content(),
|
||||
score=record.score,
|
||||
)
|
||||
)
|
||||
retrieval_resource_list = []
|
||||
document_context_list.append(segment.get_sign_content())
|
||||
if self.return_resource:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
context_list = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
document_segment = Document.query.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
if not document_segment:
|
||||
continue
|
||||
if dataset and document_segment:
|
||||
source = {
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id, # type: ignore
|
||||
"document_name": document.name, # type: ignore
|
||||
"data_source_type": document.data_source_type, # type: ignore
|
||||
"document_id": document_segment.id,
|
||||
"document_name": document_segment.name,
|
||||
"data_source_type": document_segment.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": record.score or 0.0,
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
}
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
@@ -183,19 +187,10 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.content
|
||||
retrieval_resource_list.append(source)
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
|
||||
if self.return_resource and retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=lambda x: x.get("score") or 0.0,
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
|
||||
item["position"] = position # type: ignore
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
return str("\n".join([document_context.content for document_context in document_context_list]))
|
||||
return ""
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
|
||||
@@ -665,7 +665,7 @@ class GraphEngine:
|
||||
retries += 1
|
||||
route_node_state.node_run_result = run_result
|
||||
yield NodeRunRetryEvent(
|
||||
id=str(uuid.uuid4()),
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
@@ -680,7 +680,7 @@ class GraphEngine:
|
||||
start_at=retry_start_at,
|
||||
)
|
||||
time.sleep(retry_interval)
|
||||
break
|
||||
continue
|
||||
route_node_state.set_finished(run_result=run_result)
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
|
||||
@@ -107,10 +107,8 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
||||
return _extract_text_from_plain_text(file_content)
|
||||
case "application/pdf":
|
||||
return _extract_text_from_pdf(file_content)
|
||||
case "application/msword":
|
||||
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document" | "application/msword":
|
||||
return _extract_text_from_doc(file_content)
|
||||
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
return _extract_text_from_docx(file_content)
|
||||
case "text/csv":
|
||||
return _extract_text_from_csv(file_content)
|
||||
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel":
|
||||
@@ -144,10 +142,8 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
|
||||
return _extract_text_from_yaml(file_content)
|
||||
case ".pdf":
|
||||
return _extract_text_from_pdf(file_content)
|
||||
case ".doc":
|
||||
case ".doc" | ".docx":
|
||||
return _extract_text_from_doc(file_content)
|
||||
case ".docx":
|
||||
return _extract_text_from_docx(file_content)
|
||||
case ".csv":
|
||||
return _extract_text_from_csv(file_content)
|
||||
case ".xls" | ".xlsx":
|
||||
@@ -207,33 +203,7 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
|
||||
|
||||
def _extract_text_from_doc(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOC file.
|
||||
"""
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
if not (dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY):
|
||||
raise TextExtractionError("UNSTRUCTURED_API_URL and UNSTRUCTURED_API_KEY must be set")
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
with open(temp_file.name, "rb") as file:
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY,
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
return "\n".join([getattr(element, "text", "") for element in elements])
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from DOC: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOCX file.
|
||||
Extract text from a DOC/DOCX file.
|
||||
For now support only paragraph and table add more if needed
|
||||
"""
|
||||
try:
|
||||
@@ -285,13 +255,13 @@ def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
|
||||
text.append(markdown_table)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract table from DOC: {e}")
|
||||
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
||||
continue
|
||||
|
||||
return "\n".join(text)
|
||||
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e
|
||||
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
|
||||
|
||||
|
||||
def _download_file_content(file: File) -> bytes:
|
||||
@@ -359,29 +329,14 @@ def _extract_text_from_excel(file_content: bytes) -> str:
|
||||
|
||||
|
||||
def _extract_text_from_ppt(file_content: bytes) -> str:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
|
||||
try:
|
||||
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
|
||||
with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
with open(temp_file.name, "rb") as file:
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY,
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
else:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_ppt(file=file)
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_ppt(file=file)
|
||||
return "\n".join([getattr(element, "text", "") for element in elements])
|
||||
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
|
||||
raise TextExtractionError(f"Failed to extract text from PPT: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_pptx(file_content: bytes) -> str:
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
@@ -27,3 +30,20 @@ class EndNode(BaseNode[EndNodeData]):
|
||||
inputs=outputs,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: EndNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Literal
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
@@ -87,6 +88,23 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IfElseNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
@deprecated("This function is deprecated. You should use the new cases structure.")
|
||||
def _should_not_use_old_function(
|
||||
|
||||
@@ -590,7 +590,6 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
with flask_app.app_context():
|
||||
parallel_mode_run_id = uuid.uuid4().hex
|
||||
graph_engine_copy = graph_engine.create_copy()
|
||||
graph_engine_copy.graph_runtime_state.total_tokens = 0
|
||||
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
|
||||
variable_pool_copy.add([self.node_id, "index"], index)
|
||||
variable_pool_copy.add([self.node_id, "item"], item)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.iteration.entities import IterationStartNodeData
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@@ -18,3 +21,16 @@ class IterationStartNode(BaseNode):
|
||||
Run the node.
|
||||
"""
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm.entities import VisionConfig
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
@@ -75,48 +73,6 @@ class SingleRetrievalConfig(BaseModel):
|
||||
model: ModelConfig
|
||||
|
||||
|
||||
SupportedComparisonOperator = Literal[
|
||||
# for string or array
|
||||
"contains",
|
||||
"not contains",
|
||||
"starts with",
|
||||
"ends with",
|
||||
"is",
|
||||
"is not",
|
||||
"empty",
|
||||
"is not empty",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
">",
|
||||
"<",
|
||||
"≥",
|
||||
"≤",
|
||||
# for time
|
||||
"before",
|
||||
"after",
|
||||
]
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
"""
|
||||
Conditon detail
|
||||
"""
|
||||
|
||||
metadata_name: str
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
class MetadataFilteringCondition(BaseModel):
|
||||
"""
|
||||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Optional[Literal["and", "or"]] = "and"
|
||||
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge retrieval Node Data.
|
||||
@@ -128,7 +84,3 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
retrieval_mode: Literal["single", "multiple"]
|
||||
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
|
||||
single_retrieval_config: Optional[SingleRetrievalConfig] = None
|
||||
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
|
||||
metadata_model_config: Optional[ModelConfig] = None
|
||||
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@@ -16,7 +16,3 @@ class ModelNotSupportedError(KnowledgeRetrievalNodeError):
|
||||
|
||||
class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
|
||||
"""Raised when the model provider quota is exceeded."""
|
||||
|
||||
|
||||
class InvalidModelTypeError(KnowledgeRetrievalNodeError):
|
||||
"""Raised when the model is not a Large Language Model."""
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import func
|
||||
|
||||
@@ -11,38 +9,21 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.variables import StringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
|
||||
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_1,
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_2,
|
||||
METADATA_FILTER_COMPLETION_PROMPT,
|
||||
METADATA_FILTER_SYSTEM_PROMPT,
|
||||
METADATA_FILTER_USER_PROMPT_1,
|
||||
METADATA_FILTER_USER_PROMPT_3,
|
||||
)
|
||||
from core.workflow.nodes.list_operator.exc import InvalidConditionError
|
||||
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
|
||||
from extensions.ext_database import db
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.dataset import Dataset, DatasetMetadata, Document
|
||||
from models.dataset import Dataset, Document
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import KnowledgeRetrievalNodeData
|
||||
from .exc import (
|
||||
InvalidModelTypeError,
|
||||
KnowledgeRetrievalNodeError,
|
||||
ModelCredentialsNotInitializedError,
|
||||
ModelNotExistError,
|
||||
@@ -61,14 +42,13 @@ default_retrieval_model = {
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(LLMNode):
|
||||
class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
||||
_node_data_cls = KnowledgeRetrievalNodeData
|
||||
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
|
||||
# extract variables
|
||||
variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
|
||||
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
|
||||
if not isinstance(variable, StringSegment):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@@ -83,7 +63,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
)
|
||||
# retrieve knowledge
|
||||
try:
|
||||
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
|
||||
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
|
||||
outputs = {"result": results}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
|
||||
@@ -137,9 +117,6 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
if not dataset:
|
||||
continue
|
||||
available_datasets.append(dataset)
|
||||
metadata_filter_document_ids = self._get_metadata_filter_condition(
|
||||
[dataset.id for dataset in available_datasets], query, node_data
|
||||
)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
||||
@@ -169,7 +146,6 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
planning_strategy=planning_strategy,
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
@@ -282,134 +258,6 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
item["metadata"]["position"] = position
|
||||
return retrieval_resource_list
|
||||
|
||||
def _get_metadata_filter_condition(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> dict[str, list[str]]:
|
||||
document_query = db.session.query(Document.id).filter(
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
if node_data.metadata_filtering_mode == "disabled":
|
||||
return None
|
||||
elif node_data.metadata_filtering_mode == "automatic":
|
||||
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
|
||||
if automatic_metadata_filters:
|
||||
for filter in automatic_metadata_filters:
|
||||
self._process_metadata_filter_func(
|
||||
filter.get("condition"), filter.get("metadata_name"), filter.get("value"), document_query
|
||||
)
|
||||
elif node_data.metadata_filtering_mode == "manual":
|
||||
for condition in node_data.metadata_filtering_conditions.conditions:
|
||||
metadata_name = condition.metadata_name
|
||||
expected_value = condition.value
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(expected_value).text
|
||||
self._process_metadata_filter_func(
|
||||
condition.comparison_operator, metadata_name, expected_value, document_query
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
documnents = document_query.all()
|
||||
# group by dataset_id
|
||||
metadata_filter_document_ids = defaultdict(list)
|
||||
for document in documnents:
|
||||
metadata_filter_document_ids[document.dataset_id].append(document.id)
|
||||
return metadata_filter_document_ids
|
||||
|
||||
def _automatic_metadata_filter_func(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> list[dict[str, Any]]:
|
||||
# get all metadata field
|
||||
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
all_metadata_fields = [metadata_field.field_name for metadata_field in metadata_fields]
|
||||
# get metadata model config
|
||||
metadata_model_config = node_data.metadata_model_config
|
||||
if metadata_model_config is None:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
# get metadata model instance
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config)
|
||||
# fetch prompt messages
|
||||
prompt_template = self._get_prompt_template(
|
||||
node_data=node_data,
|
||||
query=query or "",
|
||||
metadata_fields=all_metadata_fields,
|
||||
)
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
sys_query=query,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
sys_files=[],
|
||||
vision_enabled=node_data.vision.enabled,
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
jinja2_variables=[],
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
try:
|
||||
# handle invoke result
|
||||
generator = self._invoke_llm(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
for event in generator:
|
||||
if isinstance(event, ModelInvokeCompletedEvent):
|
||||
result_text = event.text
|
||||
break
|
||||
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
automatic_metadata_filters = []
|
||||
if "metadata_map" in result_text_json:
|
||||
metadata_map = result_text_json["metadata_map"]
|
||||
for item in metadata_map:
|
||||
if item.get("metadata_field_name") in all_metadata_fields:
|
||||
automatic_metadata_filters.append(
|
||||
{
|
||||
"metadata_name": item.get("metadata_field_name"),
|
||||
"value": item.get("metadata_field_value"),
|
||||
"condition": item.get("comparison_operator"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
return None
|
||||
return automatic_metadata_filters
|
||||
|
||||
def _process_metadata_filter_func(*, condition: str, metadata_name: str, value: str, query):
|
||||
match condition:
|
||||
case "contains":
|
||||
query = query.filter(Document.doc_metadata[metadata_name].like(f"%{value}%"))
|
||||
case "not contains":
|
||||
query = query.filter(Document.doc_metadata[metadata_name].notlike(f"%{value}%"))
|
||||
case "start with":
|
||||
query = query.filter(Document.doc_metadata[metadata_name].like(f"{value}%"))
|
||||
case "end with":
|
||||
query = query.filter(Document.doc_metadata[metadata_name].like(f"%{value}"))
|
||||
case "is", "=":
|
||||
query = query.filter(Document.doc_metadata[metadata_name] == value)
|
||||
case "is not", "≠":
|
||||
query = query.filter(Document.doc_metadata[metadata_name] != value)
|
||||
case "is empty":
|
||||
query = query.filter(Document.doc_metadata[metadata_name].is_(None))
|
||||
case "is not empty":
|
||||
query = query.filter(Document.doc_metadata[metadata_name].isnot(None))
|
||||
case "before", "<":
|
||||
query = query.filter(Document.doc_metadata[metadata_name] < value)
|
||||
case "after", ">":
|
||||
query = query.filter(Document.doc_metadata[metadata_name] > value)
|
||||
case "≤", ">=":
|
||||
query = query.filter(Document.doc_metadata[metadata_name] <= value)
|
||||
case "≥", ">=":
|
||||
query = query.filter(Document.doc_metadata[metadata_name] >= value)
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
@@ -495,94 +343,3 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _calculate_rest_token(
|
||||
self,
|
||||
node_data: KnowledgeRetrievalNodeData,
|
||||
query: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str],
|
||||
) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="",
|
||||
files=[],
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
rest_tokens = 2000
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
if model_context_tokens:
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||
)
|
||||
|
||||
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template or "")
|
||||
) or 0
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
|
||||
return rest_tokens
|
||||
|
||||
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
|
||||
model_mode = ModelMode.value_of(node_data.metadata_model_config.mode)
|
||||
input_text = query
|
||||
memory_str = ""
|
||||
|
||||
prompt_messages: list[LLMNodeChatModelMessage] = []
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
|
||||
)
|
||||
prompt_messages.append(system_prompt_messages)
|
||||
user_prompt_message_1 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_1)
|
||||
assistant_prompt_message_1 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_1)
|
||||
user_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_2)
|
||||
assistant_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_2)
|
||||
user_prompt_message_3 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=METADATA_FILTER_USER_PROMPT_3.format(
|
||||
input_text=input_text,
|
||||
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_3)
|
||||
return prompt_messages
|
||||
elif model_mode == ModelMode.COMPLETION:
|
||||
return LLMNodeCompletionModelPromptTemplate(
|
||||
text=METADATA_FILTER_COMPLETION_PROMPT.format(
|
||||
input_text=input_text,
|
||||
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
@@ -30,7 +29,6 @@ from core.model_runtime.entities.message_entities import (
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables import (
|
||||
@@ -249,24 +247,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
|
||||
if isinstance(invoke_result, LLMResult):
|
||||
content = invoke_result.message.content
|
||||
if content is None:
|
||||
message_text = ""
|
||||
elif isinstance(content, str):
|
||||
message_text = content
|
||||
elif isinstance(content, list):
|
||||
# Assuming the list contains PromptMessageContent objects with a "data" attribute
|
||||
message_text = "".join(
|
||||
item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
|
||||
)
|
||||
else:
|
||||
message_text = str(content)
|
||||
|
||||
yield ModelInvokeCompletedEvent(
|
||||
text=message_text,
|
||||
usage=invoke_result.usage,
|
||||
finish_reason=None,
|
||||
)
|
||||
return
|
||||
|
||||
model = None
|
||||
@@ -760,17 +740,11 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_name == model_instance.provider,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
).update(
|
||||
{
|
||||
"quota_used": Provider.quota_used + used_quota,
|
||||
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
|
||||
}
|
||||
)
|
||||
).update({"quota_used": Provider.quota_used + used_quota})
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
@@ -20,3 +23,13 @@ class StartNode(BaseNode[StartNodeData]):
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: StartNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
return {}
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
@@ -33,3 +36,16 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
|
||||
break
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@@ -64,10 +64,6 @@ class ConditionProcessor:
|
||||
expected=expected_value,
|
||||
)
|
||||
group_results.append(result)
|
||||
# Implemented short-circuit evaluation for logical conditions
|
||||
if (operator == "and" and not result) or (operator == "or" and result):
|
||||
final_result = result
|
||||
return input_conditions, group_results, final_result
|
||||
|
||||
final_result = all(group_results) if operator == "and" else any(group_results)
|
||||
return input_conditions, group_results, final_result
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user