Compare commits

..

106 Commits

Author SHA1 Message Date
CodingOnStar
ebabbc8dce refactor: update Content Security Policy to allow 'wasm-unsafe-eval' in script-src directive 2025-10-20 14:26:09 +08:00
CodingOnStar
9a362d9ac6 refactor: improve nonce extraction from Content Security Policy in Google Analytics component 2025-10-20 14:26:09 +08:00
CodingOnStar
00af0ed63c refactor: update Content Security Policy to allow 'wasm-unsafe-eval' and set nonce in response headers 2025-10-20 14:26:09 +08:00
CodingOnStar
138a8b9f7a refactor: update Content Security Policy to allow 'wasm-unsafe-eval' and set nonce in response headers 2025-10-20 14:26:09 +08:00
CodingOnStar
ed6fd6f3d9 refactor: update Content Security Policy to include 'strict-dynamic' and improve nonce handling in Google Analytics component 2025-10-20 14:26:09 +08:00
CodingOnStar
3d7c7acba6 refactor: update Google Analytics component to use window.gtag for event tracking 2025-10-20 14:26:09 +08:00
CodingOnStar
70600ea861 refactor: remove console logs from Google Analytics component 2025-10-20 14:26:09 +08:00
CodingOnStar
b6c196fdbe feat: enhance filter components with Google Analytics event tracking on clear actions 2025-10-20 14:26:09 +08:00
CodingOnStar
81500b9b19 feat: add Google Analytics event tracking for filter selections in chart and workflow logs 2025-10-20 14:26:09 +08:00
CodingOnStar
eb7e03b0ce feat: integrate Google Analytics event tracking utility 2025-10-20 14:26:07 +08:00
Guangdong Liu
34fbcc9457 fix: ensure document re-querying in indexing process for consistency (#27077) 2025-10-20 14:12:39 +08:00
yangzheli
9cc8ac981b fix(web): improve UI consistency and remove related unused icons (#27004) 2025-10-20 14:03:16 +08:00
zyssyz123
1153dcef69 fix: delete migrate sync data script (#27061) 2025-10-20 14:54:24 +09:00
white-loub
f811471b18 fix: support structured output in streaming mode for LLM node (#27089) 2025-10-20 14:53:25 +09:00
hj24
2382229c7d fix variable-truncator max size comments (#27129) 2025-10-20 14:52:40 +09:00
crazywoola
f0e739be43 fix: immer version and ref in code base (#27130) 2025-10-20 14:49:26 +09:00
-LAN-
4dccdf9478 Ensure suggested questions parser returns typed sequence (#27104)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-20 13:01:09 +08:00
GuanMu
4c37d650d3 fix: update attribute types to allow undefined values in icon utilities (#27121) 2025-10-20 12:55:37 +08:00
Guangdong Liu
1b334e6966 fix: handle None values in dataset and document deletion logic (#27083) 2025-10-20 12:52:48 +08:00
crazywoola
d463bd6323 Revert "chore(deps): bump immer from 9.0.21 to 10.1.3 in /web" (#27119)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
2025-10-20 11:28:45 +08:00
GuanMu
8c298b33cd Fix frontend type error (#27116) 2025-10-20 11:27:18 +08:00
非法操作
dc1a380888 chore: improve storybook (#27111) 2025-10-20 10:17:17 +08:00
dependabot[bot]
7e9be4d3d9 chore(deps): bump immer from 9.0.21 to 10.1.3 in /web (#27113)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-20 10:16:35 +08:00
dependabot[bot]
5579521ffc chore(deps-dev): bump cross-env from 7.0.3 to 10.1.0 in /web (#27112)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-20 10:12:30 +08:00
dependabot[bot]
ab1059134d chore(deps): bump pydantic-settings from 2.9.1 to 2.11.0 in /api (#27114)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-20 10:12:16 +08:00
dependabot[bot]
fe2ac66a52 chore(deps): bump html-to-image from 1.11.11 to 1.11.13 in /web (#27109)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-20 09:37:10 +08:00
dependabot[bot]
f87db2652b chore(deps): bump @lexical/selection from 0.36.2 to 0.37.0 in /web (#27108)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-20 09:37:02 +08:00
-LAN-
3f9f02b9e7 docs: mention backend lint gate in AGENTS (#27102)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-20 09:36:41 +08:00
-LAN-
578247ffbc feat(graph_engine): Support pausing workflow graph executions (#26585)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-10-19 21:33:41 +08:00
-LAN-
9a5f214623 refactor: replace localStorage with HTTP-only cookies for auth tokens (#24365)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Yunlu Wen <wylswz@163.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: GareArc <chen4851@purdue.edu>
Co-authored-by: NFish <douxc512@gmail.com>
Co-authored-by: Davide Delbianco <davide.delbianco@outlook.com>
Co-authored-by: minglu7 <1347866672@qq.com>
Co-authored-by: Ponder <ruan.lj@foxmail.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: heyszt <270985384@qq.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Guangdong Liu <liugddx@gmail.com>
Co-authored-by: Eric Guo <eric.guocz@gmail.com>
Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: XlKsyt <caixuesen@outlook.com>
Co-authored-by: Dhruv Gorasiya <80987415+DhruvGorasiya@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: hj24 <mambahj24@gmail.com>
Co-authored-by: GuanMu <ballmanjq@gmail.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Tonlo <123lzs123@gmail.com>
Co-authored-by: Yusuke Yamada <yamachu.dev@gmail.com>
Co-authored-by: Novice <novice12185727@gmail.com>
Co-authored-by: kenwoodjw <blackxin55+@gmail.com>
Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com>
Co-authored-by: znn <jubinkumarsoni@gmail.com>
Co-authored-by: yangzheli <43645580+yangzheli@users.noreply.github.com>
2025-10-19 21:29:04 +08:00
QuantumGhost
141ca8904a fix(api): ensure JSON responses are properly serialized in ApiTool (#27097)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-19 18:56:02 +08:00
Asuka Minato
4488c090b2 fluent api (#27093)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-19 12:54:41 +09:00
Bowen Liang
59c1fde351 doc: add Grafana dashboard template link to docs (#27079)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-18 23:24:35 +08:00
GuanMu
cf7ff76165 fix(web): resolve TypeScript type errors in workflow components (#27086) 2025-10-18 23:09:00 +08:00
Jacky Su
ac79691d69 Feat/add status filter to workflow runs (#26850)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: Jacky Su <jacky_su@trendmicro.com>
2025-10-18 12:15:29 +08:00
GuanMu
1a37989769 Fix type-check error (#27051) 2025-10-18 12:03:40 +08:00
Amy
830f891a74 Fix json in md when use quesion classifier node (#26992)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-18 11:58:40 +08:00
Eric Guo
5937a66e22 Sync same logic for datasets. (#27056) 2025-10-18 11:49:20 +08:00
wangxiaolei
894e38f713 fix: https://github.com/langgenius/dify/issues/27063 (#27074) 2025-10-18 11:47:04 +08:00
Guangdong Liu
e4b5b0e5fd feat: implement strict type validation for remote file uploads (#27010) 2025-10-18 11:44:11 +08:00
Guangdong Liu
598dd1f816 fix: allow optional config parameter and conditionally include message file ID (#26960) 2025-10-18 11:43:24 +08:00
Yongtao Huang
35e24d4d14 Chore: remove redundant tenant lookup in APIBasedExtensionAPI.post (#27067)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com>
2025-10-18 09:54:52 +08:00
GuanMu
fea2ffb3ba fix: improve URL validation logic in validateRedirectUrl function (#27058)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
2025-10-17 17:46:28 +08:00
Wu Tianwei
64f55d55a1 fix: update TopK and Score Threshold components to use InputNumber and improve value handling (#27045) 2025-10-17 14:58:30 +08:00
2h0ng
bfda4ce7e6 Merge commit from fork 2025-10-17 14:58:15 +08:00
GuanMu
4f7cb7cd2a Fix type error (#27044) 2025-10-17 14:42:58 +08:00
NeatGuyCoding
6517323add Feature: add test containers based tests for mail register tasks (#27040)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-10-17 14:29:56 +08:00
NFish
531a0b755a fix: show 'Invalid email or password' error tip when web app login failed (#27034) 2025-10-17 14:03:34 +08:00
Joel
91bb8ae4d2 fix: happy-dom security issues (#27037) 2025-10-17 13:42:56 +08:00
GuanMu
8cafc20098 Fix type error (#27024)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
2025-10-17 10:46:43 +08:00
-LAN-
9d5300440c Restore coverage for skipped workflow tests (#27018)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-17 09:11:48 +08:00
Guangdong Liu
58524d6d2b fix: remove unnecessary properties from condition draft (#27009)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-17 09:11:03 +08:00
Asuka Minato
19cc6ea993 fix 27003 (#27005)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-17 09:10:16 +08:00
quicksand
d7f0a31e24 Fix: User Context Loss When Invoking Workflow Tool Node in Knowledge … (#26495) 2025-10-17 09:09:45 +08:00
Yongtao Huang
312974aa20 Chore: remove unused class-level variables in DatasourceManager (#27011)
Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-17 09:07:28 +08:00
Dhruv Gorasiya
d19c100166 fix: logical error in Weaviate distance calculation (#27019)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-17 09:06:50 +08:00
Dhruv Gorasiya
a8ad80c405 Fixed Weaviate no module found issue (issue #26938) (#26964)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-16 22:41:48 +08:00
GuanMu
650e38e17f refactor: improve TypeScript types for NodeCardProps and debug configuration context (#27001) 2025-10-16 22:16:01 +08:00
-LAN-
24612adf2c Fix dispatcher idle hang and add pytest timeouts (#26998) 2025-10-16 22:15:03 +08:00
Xiyuan Chen
06649f6c21 Update email templates to improve clarity and consistency in messagin… (#26970)
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Check i18n Files and Create PR / check-and-update (push) Has been cancelled
2025-10-16 01:42:22 -07:00
Yongtao Huang
8b61f5e9c4 Fix: avoid duplicate response_chunk update in convert_stream_simple_response (#26965)
Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-16 15:53:07 +08:00
GuanMu
6432898e7a refactor: update TypeScript definitions for custom JSX elements and clean up global declarations in emoji picker (#26985) 2025-10-16 15:51:39 +08:00
Asuka Minato
cced33d068 use deco to avoid current_user (#26077)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-16 15:45:51 +09:00
Xiyuan Chen
bd01af6415 fix: update load balancing configurations with new credential IDs and… (#26900) 2025-10-15 21:15:26 -07:00
wellCh4n
35011b810d feat: run with params from logs (#26787)
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
2025-10-16 11:01:11 +08:00
Xin Zhang
f295c7532c fix plugin installation permissions when using a local pkg (#26822)
Co-authored-by: zhangx1n <zhangxin@dify.ai>
2025-10-16 10:58:28 +08:00
zyssyz123
7065b67d07 add app mode for message (#26876)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
2025-10-16 10:19:49 +08:00
GuanMu
c0b50ef61d chore: remove unused icon components and related features from the co… (#26933)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Check i18n Files and Create PR / check-and-update (push) Waiting to run
2025-10-15 16:48:02 +08:00
-LAN-
1d8cca4fa2 Fix: check external commands after node completion (#26891)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-15 16:47:43 +08:00
Wu Tianwei
3474c179e6 fix: enhance dataset menu and add service API translations (#26931) 2025-10-15 16:46:46 +08:00
GuanMu
433dad7e1a chore: add type-check script to package.json for TypeScript validation (#26929) 2025-10-15 16:37:46 +08:00
github-actions[bot]
be7ee380bc chore: translate i18n files and update type definitions (#26916)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-15 16:36:39 +08:00
yangzheli
cff5de626b feat(agent): similar to the start node of workflow, agent variables also support drag-and-drop (#26899) 2025-10-15 13:07:51 +08:00
znn
4d8b8f9210 allow editing of hidden inputs in preview (#24370)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2025-10-15 11:19:53 +08:00
Ademílson Tonato
a16ef7e73c refactor: Update Firecrawl to use v2 API (#24734)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-15 10:48:54 +08:00
kenwoodjw
c39dae06d4 fix: workflow token usage (#26723)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2025-10-15 10:39:51 +08:00
Novice
f906e70f6b chore: remove redundant dependencies (#26907)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
2025-10-15 09:55:39 +08:00
lyzno1
5139119307 chore: bump pnpm version (#26905)
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
2025-10-15 09:55:05 +08:00
Yusuke Yamada
1b537f904a fix: replace CodeGroup's POST /meta with GET /site (#26886) 2025-10-15 09:43:10 +08:00
-LAN-
556b631c54 Normalize null metadata handling in tool entities (#26890) 2025-10-15 09:42:22 +08:00
NeatGuyCoding
49df9ceaf3 minor fix: test cases for alibabacloud mysql and chinese translations (#26902)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-15 09:41:12 +08:00
Tonlo
92ec1ac27a Fix/remove logo in withoutbrand template (#26882) 2025-10-15 09:40:33 +08:00
-LAN-
e74097afdf Remove unused after_request hooks from console API keys (#26896)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
2025-10-15 00:43:11 +08:00
Asuka Minato
8ddc4f2292 example to auto rollback (#26200)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-15 00:42:55 +09:00
非法操作
7b51320346 fix: when create provider credential set the provider record to vaild (#26868)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-14 19:42:48 +08:00
GuanMu
9e39be0770 fix: correct indentation in JSON payloads (#26871) 2025-10-14 19:41:01 +08:00
GuanMu
3e5e87930c feat: add Knip configuration for dead code detection and remove unused icon components (#26758)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-14 19:06:31 +08:00
hj24
15a5ba67f1 fix: use account id in workflow app log filter (#26811)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Check i18n Files and Create PR / check-and-update (push) Waiting to run
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-14 14:32:40 +08:00
github-actions[bot]
9e3b4dc90d chore: translate i18n files and update type definitions (#26859)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
2025-10-14 10:43:28 +08:00
Dhruv Gorasiya
48c42a9fba Weaviate update version (#25447)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-14 10:39:53 +08:00
XlKsyt
0b35bc1ede feat: add Tencent Cloud APM tracing integration (#25657)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-14 10:21:17 +08:00
Davide Delbianco
8e01bb40fe fix: Do not show the toggle button for chat input when all input hidden (#26826)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-14 10:15:06 +08:00
Guangdong Liu
9d21772820 fix: Validate transfer method in file mapping and improve file input handling (#26848)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-14 10:10:31 +08:00
NeatGuyCoding
b745839bdb Feature add test containers mail owner transfer task (#26854)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-10-14 10:01:47 +08:00
Eric Guo
59ad6e02ce Add timeout so any plugin daemon call (including the SSE path) that legitimately takes longer than 5s would right. (#26852)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
2025-10-14 09:23:27 +08:00
Guangdong Liu
a3b33cbe28 refactor: streamline database session usage in batch_create_segment_to_index_task (#26795)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-10-14 09:22:48 +08:00
Davide Delbianco
7b8540281a fix: Chat Opener visibility flickering (#26836) 2025-10-14 09:21:00 +08:00
Asuka Minato
0a6b78f883 Use hook to get userid (#26839)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-14 09:20:37 +08:00
heyszt
56ee8f7d64 fix: files/support-type JSON serialization error (#26842) 2025-10-14 09:20:19 +08:00
Davide Delbianco
3cfcd32876 chore: Fix 25795 (#26823)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-13 17:44:51 +08:00
Davide Delbianco
06dcb55a9d chore: Don't show chat input area scrollbar overflow (#26828) 2025-10-13 17:43:46 +08:00
Ponder
ec6cafd7aa feat: Cache AppQueueManager.is_stopped() to reduce unnecessary Redis … (#26778) 2025-10-13 17:41:16 +08:00
Davide Delbianco
6e9858960d chore: Fix chat-input-area resize (#26824) 2025-10-13 17:36:15 +08:00
minglu7
150a8276b9 fix: avoid closing shared session during embeddings (#26830) 2025-10-13 17:36:00 +08:00
Davide Delbianco
c6a90d4bb3 fix: Don't hide chat streaming loader on '\n' content (#26829) 2025-10-13 17:31:52 +08:00
Davide Delbianco
c71fd7113c chore: Correct padding in embedded chatbot (#26832) 2025-10-13 17:29:47 +08:00
1173 changed files with 17943 additions and 45189 deletions

View File

@@ -1,7 +1,6 @@
#!/bin/bash
WORKSPACE_ROOT=$(pwd)
npm add -g pnpm@10.15.0
corepack enable
cd web && pnpm install
pipx install uv

View File

@@ -39,25 +39,11 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run pyrefly check
run: |
cd api
uv add --dev pyrefly
uv run pyrefly check || true
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
@@ -93,3 +79,19 @@ jobs:
- name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY

View File

@@ -2,8 +2,6 @@ name: autofix.ci
on:
pull_request:
branches: ["main"]
push:
branches: ["main"]
permissions:
contents: read

View File

@@ -1,6 +1,7 @@
#!/bin/bash
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
@@ -13,4 +14,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"

5
.gitignore vendored
View File

@@ -234,7 +234,4 @@ scripts/stress-test/reports/
# mcp
.playwright-mcp/
.serena/
# settings
*.local.json
.serena/

View File

@@ -14,7 +14,7 @@ The codebase is split into:
- Run backend CLI commands through `uv run --project api <command>`.
- Backend QA gate requires passing `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before review.
- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.

View File

@@ -129,8 +129,18 @@ Star Dify on GitHub and be instantly notified of new releases.
## Advanced Setup
### Custom configurations
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
### Metrics Monitoring with Grafana
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
- [Grafana Dashboard by @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard)
### Deployment with Kubernetes
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)

View File

@@ -454,9 +454,6 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
HTTP_REQUEST_NODE_SSL_VERIFY=True
# Webhook request configuration
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
# Respect X-* headers to redirect clients
RESPECT_XFORWARD_HEADERS_ENABLED=false
@@ -534,12 +531,6 @@ ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
# Interval time in minutes for polling scheduled workflows(default: 1 min)
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
# Position configuration
POSITION_TOOL_PINS=

View File

@@ -54,7 +54,7 @@
"--loglevel",
"DEBUG",
"-Q",
"dataset,generation,mail,ops_trace,app_deletion,workflow"
"dataset,generation,mail,ops_trace,app_deletion"
]
}
]

View File

@@ -15,12 +15,12 @@ from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import languages
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
@@ -1227,55 +1227,6 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_system_trigger_oauth_client(provider, client_params):
"""
Setup system trigger oauth client
"""
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerOAuthSystemClient
provider_id = TriggerProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(TriggerOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
oauth_client = TriggerOAuthSystemClient(
provider=provider_name,
plugin_id=plugin_id,
encrypted_oauth_params=oauth_client_params,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
"""
Find draft variables that reference non-existent apps.

View File

@@ -174,17 +174,6 @@ class CodeExecutionSandboxConfig(BaseSettings):
)
class TriggerConfig(BaseSettings):
"""
Configuration for trigger
"""
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
description="Maximum allowed size for webhook request bodies in bytes",
default=10485760,
)
class PluginConfig(BaseSettings):
"""
Plugin configs
@@ -200,6 +189,11 @@ class PluginConfig(BaseSettings):
default="plugin-api-key",
)
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=300.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
@@ -554,7 +548,7 @@ class UpdateConfig(BaseSettings):
class WorkflowVariableTruncationConfig(BaseSettings):
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
# 100KB
# 1000 KiB
1024_000,
description="Maximum size for variable to trigger final truncation.",
)
@@ -996,22 +990,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable check upgradable plugin task",
default=True,
)
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
description="Enable workflow schedule poller task",
default=True,
)
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
description="Workflow schedule poller interval in minutes",
default=1,
)
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
description="Maximum number of schedules to process in each poll batch",
default=100,
)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
default=0,
)
class PositionConfig(BaseSettings):
@@ -1135,7 +1113,6 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
TriggerConfig,
PluginConfig,
MarketplaceConfig,
DataSetConfig,

View File

@@ -55,3 +55,12 @@ else:
"properties",
}
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
COOKIE_NAME_ACCESS_TOKEN = "access_token"
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
COOKIE_NAME_PASSPORT = "passport"
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
HEADER_NAME_APP_CODE = "X-App-Code"
HEADER_NAME_PASSPORT = "X-App-Passport"

View File

@@ -9,7 +9,6 @@ if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController
"""
@@ -42,11 +41,3 @@ datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginPro
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("datasource_plugin_providers_lock")
)
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
ContextVar("plugin_trigger_providers")
)
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_trigger_providers_lock")
)

View File

@@ -66,7 +66,6 @@ from .app import (
workflow_draft_variable,
workflow_run,
workflow_statistic,
workflow_trigger,
)
# Import auth controllers
@@ -127,7 +126,6 @@ from .workspace import (
models,
plugin,
tool_providers,
trigger_providers,
workspace,
)
@@ -198,7 +196,6 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
"trigger_providers",
"version",
"website",
"workflow",
@@ -206,6 +203,5 @@ __all__ = [
"workflow_draft_variable",
"workflow_run",
"workflow_statistic",
"workflow_trigger",
"workspace",
]

View File

@@ -15,6 +15,7 @@ from constants.languages import supported_language
from controllers.console import api, console_ns
from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp
@@ -24,19 +25,9 @@ def admin_required(view: Callable[P, R]):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
auth_header = request.headers.get("Authorization")
if auth_header is None:
auth_token = extract_access_token(request)
if not auth_token:
raise Unauthorized("Authorization header is missing.")
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if auth_token != dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
@@ -70,15 +61,17 @@ class InsertExploreAppListApi(Resource):
@only_edition_cloud
@admin_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("desc", type=str, location="json")
parser.add_argument("copyright", type=str, location="json")
parser.add_argument("privacy_policy", type=str, location="json")
parser.add_argument("custom_disclaimer", type=str, location="json")
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
.add_argument("desc", type=str, location="json")
.add_argument("copyright", type=str, location="json")
.add_argument("privacy_policy", type=str, location="json")
.add_argument("custom_disclaimer", type=str, location="json")
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("position", type=int, required=True, nullable=False, location="json")
)
args = parser.parse_args()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()

View File

@@ -7,13 +7,12 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from . import api, console_ns
from .wraps import account_initialization_required, setup_required
from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = {
"id": fields.String,
@@ -57,9 +56,9 @@ class BaseApiKeyListResource(Resource):
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
@@ -68,15 +67,12 @@ class BaseApiKeyListResource(Resource):
return {"items": keys}
@marshal_with(api_key_fields)
@edit_permission_required
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.has_edit_permission:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
@@ -93,7 +89,7 @@ class BaseApiKeyListResource(Resource):
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id
api_token.tenant_id = current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
@@ -112,9 +108,8 @@ class BaseApiKeyResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
@@ -158,11 +153,6 @@ class AppApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for an app"""
return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app"
resource_model = App
resource_id_field = "app_id"
@@ -179,11 +169,6 @@ class AppApiKeyResource(BaseApiKeyResource):
"""Delete an API key for an app"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app"
resource_model = App
resource_id_field = "app_id"
@@ -208,11 +193,6 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for a dataset"""
return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"
@@ -229,11 +209,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
"""Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"

View File

@@ -25,11 +25,13 @@ class AdvancedPromptTemplateList(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("app_mode", type=str, required=True, location="args")
parser.add_argument("model_mode", type=str, required=True, location="args")
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
parser.add_argument("model_name", type=str, required=True, location="args")
parser = (
reqparse.RequestParser()
.add_argument("app_mode", type=str, required=True, location="args")
.add_argument("model_mode", type=str, required=True, location="args")
.add_argument("has_context", type=str, required=False, default="true", location="args")
.add_argument("model_name", type=str, required=True, location="args")
)
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args)

View File

@@ -27,9 +27,11 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model):
"""Get agent logs"""
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args")
.add_argument("conversation_id", type=uuid_value, required=True, location="args")
)
args = parser.parse_args()

View File

@@ -1,15 +1,14 @@
from typing import Literal
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import api, console_ns
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
from extensions.ext_redis import redis_client
@@ -42,15 +41,15 @@ class AnnotationReplyActionApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
parser = (
reqparse.RequestParser()
.add_argument("score_threshold", required=True, type=float, location="json")
.add_argument("embedding_provider_name", required=True, type=str, location="json")
.add_argument("embedding_model_name", required=True, type=str, location="json")
)
args = parser.parse_args()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id)
@@ -69,10 +68,8 @@ class AppAnnotationSettingDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
return result, 200
@@ -98,15 +95,12 @@ class AppAnnotationSettingUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id, annotation_setting_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
@@ -124,10 +118,8 @@ class AnnotationReplyActionStatusApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id, action):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key)
@@ -159,10 +151,8 @@ class AnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
@@ -198,14 +188,14 @@ class AnnotationApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
@edit_permission_required
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation
@@ -213,10 +203,8 @@ class AnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
# Use request.args.getlist to get annotation_ids array directly
@@ -249,10 +237,8 @@ class AnnotationExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)}
@@ -271,16 +257,16 @@ class AnnotationUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation
@@ -288,10 +274,8 @@ class AnnotationUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
@@ -310,10 +294,8 @@ class AnnotationBatchImportApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
# check file
if "file" not in request.files:
@@ -341,10 +323,8 @@ class AnnotationBatchImportStatusApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
cache_result = redis_client.get(indexing_cache_key)
@@ -376,10 +356,8 @@ class AnnotationHitHistoryListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
app_id = str(app_id)

View File

@@ -1,7 +1,5 @@
import uuid
from typing import cast
from flask_login import current_user
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -12,15 +10,16 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
enterprise_license_required,
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import Account, App
from models import App
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@@ -56,6 +55,7 @@ class AppListApi(Resource):
@enterprise_license_required
def get(self):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
def uuid_list(value):
try:
@@ -63,34 +63,36 @@ class AppListApi(Resource):
except ValueError:
abort(400, message="Invalid UUID format in tag_ids.")
parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"mode",
type=str,
choices=[
"completion",
"chat",
"advanced-chat",
"workflow",
"agent-chat",
"channel",
"all",
],
default="all",
location="args",
required=False,
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"mode",
type=str,
choices=[
"completion",
"chat",
"advanced-chat",
"workflow",
"agent-chat",
"channel",
"all",
],
default="all",
location="args",
required=False,
)
.add_argument("name", type=str, location="args", required=False)
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
args = parser.parse_args()
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@@ -129,30 +131,26 @@ class AppListApi(Resource):
@account_initialization_required
@marshal_with(app_detail_fields)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
"""Create app"""
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required")
app_service = AppService()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if current_user.current_tenant_id is None:
raise ValueError("current_user.current_tenant_id cannot be None")
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
app = app_service.create_app(current_tenant_id, args, current_user)
return app, 201
@@ -205,21 +203,20 @@ class AppApi(Resource):
@login_required
@account_initialization_required
@get_app_model
@edit_permission_required
@marshal_with(app_detail_fields_with_site)
def put(self, app_model):
"""Update app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
parser.add_argument("max_active_requests", type=int, location="json")
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
.add_argument("max_active_requests", type=int, location="json")
)
args = parser.parse_args()
app_service = AppService()
@@ -248,12 +245,9 @@ class AppApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_model):
"""Delete app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
app_service = AppService()
app_service.delete_app(app_model)
@@ -283,27 +277,28 @@ class AppCopyApi(Resource):
@login_required
@account_initialization_required
@get_app_model
@edit_permission_required
@marshal_with(app_detail_fields_with_site)
def post(self, app_model):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
with Session(db.engine) as session:
import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
account=current_user,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name=args.get("name"),
@@ -340,16 +335,15 @@ class AppExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_model):
"""Export app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
parser.add_argument("workflow_id", type=str, location="args")
parser = (
reqparse.RequestParser()
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
.add_argument("workflow_id", type=str, location="args")
)
args = parser.parse_args()
return {
@@ -371,13 +365,9 @@ class AppNameApi(Resource):
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
@@ -408,14 +398,13 @@ class AppIconApi(Resource):
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser = (
reqparse.RequestParser()
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
app_service = AppService()
@@ -441,13 +430,9 @@ class AppSiteStatus(Resource):
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_site", type=bool, required=True, location="json")
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
@@ -475,11 +460,11 @@ class AppApiStatus(Resource):
@marshal_with(app_detail_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_api", type=bool, required=True, location="json")
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
@@ -520,13 +505,14 @@ class AppTraceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id):
# add app trace
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enabled", type=bool, required=True, location="json")
parser.add_argument("tracing_provider", type=str, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("enabled", type=bool, required=True, location="json")
.add_argument("tracing_provider", type=str, required=True, location="json")
)
args = parser.parse_args()
OpsTraceManager.update_app_tracing_config(

View File

@@ -1,20 +1,16 @@
from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService
@@ -30,28 +26,29 @@ class AppImportApi(Resource):
@account_initialization_required
@marshal_with(app_import_fields)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("app_id", type=str, location="json")
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("app_id", type=str, location="json")
)
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = cast(Account, current_user)
account = current_user
result = import_service.import_app(
account=account,
import_mode=args["mode"],
@@ -83,16 +80,16 @@ class AppImportConfirmApi(Resource):
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
@edit_permission_required
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Confirm import
account = cast(Account, current_user)
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
@@ -109,10 +106,8 @@ class AppImportCheckDependenciesApi(Resource):
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_fields)
@edit_permission_required
def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)

View File

@@ -111,11 +111,13 @@ class ChatMessageTextApi(Resource):
@account_initialization_required
def post(self, app_model: App):
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
message_id = args.get("message_id", None)
@@ -166,8 +168,7 @@ class TextModesApi(Resource):
@account_initialization_required
def get(self, app_model):
try:
parser = reqparse.RequestParser()
parser.add_argument("language", type=str, required=True, location="args")
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
args = parser.parse_args()
response = AudioService.transcript_tts_voices(

View File

@@ -2,7 +2,7 @@ import logging
from flask import request
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.console import api, console_ns
@@ -15,7 +15,7 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -64,13 +64,15 @@ class CompletionMessageApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args()
streaming = args["response_mode"] != "blocking"
@@ -151,22 +153,19 @@ class ChatMessageApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args()
streaming = args["response_mode"] != "blocking"

View File

@@ -1,17 +1,16 @@
from datetime import datetime
import pytz # pip install pytz
import pytz
import sqlalchemy as sa
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound
from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
@@ -22,8 +21,8 @@ from fields.conversation_fields import (
)
from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString
from libs.login import login_required
from models import Account, Conversation, EndUser, Message, MessageAnnotation
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
@@ -57,18 +56,24 @@ class CompletionConversationApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields)
@edit_permission_required
def get(self, app_model):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
query = sa.select(Conversation).where(
@@ -84,6 +89,7 @@ class CompletionConversationApi(Resource):
)
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -137,9 +143,8 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields)
@edit_permission_required
def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id)
@@ -154,14 +159,12 @@ class CompletionConversationDetailApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
conversation_id = str(conversation_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -206,26 +209,32 @@ class ChatConversationApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_with_summary_pagination_fields)
@edit_permission_required
def get(self, app_model):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
)
args = parser.parse_args()
@@ -260,6 +269,7 @@ class ChatConversationApi(Resource):
)
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -341,9 +351,8 @@ class ChatConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_detail_fields)
@edit_permission_required
def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id)
@@ -358,14 +367,12 @@ class ChatConversationDetailApi(Resource):
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@edit_permission_required
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
conversation_id = str(conversation_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -374,6 +381,7 @@ class ChatConversationDetailApi(Resource):
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = (
db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)

View File

@@ -29,8 +29,7 @@ class ConversationVariablesApi(Resource):
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_fields)
def get(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", type=str, location="args")
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
args = parser.parse_args()
stmt = (

View File

@@ -1,6 +1,5 @@
from collections.abc import Sequence
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns
@@ -12,13 +11,12 @@ from controllers.console.app.error import (
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
@@ -44,16 +42,18 @@ class RuleGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=args["no_variable"],
@@ -94,17 +94,19 @@ class RuleCodeGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
code_result = LLMGenerator.generate_code(
tenant_id=account.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["code_language"],
@@ -141,15 +143,17 @@ class RuleStructuredOutputGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=account.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
)
@@ -190,20 +194,25 @@ class InstructionGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("flow_id", type=str, required=True, default="", location="json")
parser.add_argument("node_id", type=str, required=False, default="", location="json")
parser.add_argument("current", type=str, required=False, default="", location="json")
parser.add_argument("language", type=str, required=False, default="javascript", location="json")
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
args = parser.parse_args()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args["language"])), None
parser = (
reqparse.RequestParser()
.add_argument("flow_id", type=str, required=True, default="", location="json")
.add_argument("node_id", type=str, required=False, default="", location="json")
.add_argument("current", type=str, required=False, default="", location="json")
.add_argument("language", type=str, required=False, default="javascript", location="json")
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("ideal_output", type=str, required=False, default="", location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
code_template = (
Python3CodeProvider.get_default_code()
if args["language"] == "python"
else (JavascriptCodeProvider.get_default_code())
if args["language"] == "javascript"
else ""
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
@@ -221,21 +230,21 @@ class InstructionGenerateApi(Resource):
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["language"],
@@ -244,7 +253,7 @@ class InstructionGenerateApi(Resource):
return {"error": f"invalid node type: {node_type}"}
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
flow_id=args["flow_id"],
current=args["current"],
instruction=args["instruction"],
@@ -253,7 +262,7 @@ class InstructionGenerateApi(Resource):
)
if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
flow_id=args["flow_id"],
node_id=args["node_id"],
current=args["current"],
@@ -292,8 +301,7 @@ class InstructionGenerationTemplateApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, default=False, location="json")
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args()
match args["type"]:
case "prompt":

View File

@@ -1,16 +1,15 @@
import json
from enum import StrEnum
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_server_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer
@@ -25,9 +24,9 @@ class AppMCPServerController(Resource):
@api.doc(description="Get MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
@setup_required
@login_required
@account_initialization_required
@setup_required
@get_app_model
@marshal_with(app_server_fields)
def get(self, app_model):
@@ -48,17 +47,19 @@ class AppMCPServerController(Resource):
)
@api.response(201, "MCP server configuration created successfully", app_server_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@login_required
@setup_required
@marshal_with(app_server_fields)
@edit_permission_required
def post(self, app_model):
if not current_user.is_editor:
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("description", type=str, required=False, location="json")
.add_argument("parameters", type=dict, required=True, location="json")
)
args = parser.parse_args()
description = args.get("description")
@@ -71,7 +72,7 @@ class AppMCPServerController(Resource):
parameters=json.dumps(args["parameters"], ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
server_code=AppMCPServer.generate_server_code(16),
)
db.session.add(server)
@@ -95,19 +96,20 @@ class AppMCPServerController(Resource):
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@login_required
@setup_required
@account_initialization_required
@marshal_with(app_server_fields)
@edit_permission_required
def put(self, app_model):
if not current_user.is_editor:
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("id", type=str, required=True, location="json")
.add_argument("description", type=str, required=False, location="json")
.add_argument("parameters", type=dict, required=True, location="json")
.add_argument("status", type=str, required=False, location="json")
)
args = parser.parse_args()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
if not server:
@@ -142,13 +144,13 @@ class AppMCPServerRefreshController(Resource):
@login_required
@account_initialization_required
@marshal_with(app_server_fields)
@edit_permission_required
def get(self, server_id):
if not current_user.is_editor:
raise NotFound()
_, current_tenant_id = current_account_with_tenant()
server = (
db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_user.current_tenant_id)
.where(AppMCPServer.tenant_id == current_tenant_id)
.first()
)
if not server:

View File

@@ -3,7 +3,7 @@ import logging
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.console import api, console_ns
from controllers.console.app.error import (
@@ -17,6 +17,7 @@ from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDi
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -26,8 +27,7 @@ from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
@@ -56,19 +56,19 @@ class ChatMessageListApi(Resource):
)
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
@api.response(404, "Conversation not found")
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(message_infinite_scroll_pagination_fields)
@edit_permission_required
def get(self, app_model):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
conversation = (
@@ -154,12 +154,13 @@ class MessageFeedbackApi(Resource):
@login_required
@account_initialization_required
def post(self, app_model):
if current_user is None:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=True, type=uuid_value, location="json")
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
)
args = parser.parse_args()
message_id = str(args["message_id"])
@@ -211,23 +212,21 @@ class MessageAnnotationApi(Resource):
)
@api.response(200, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions")
@marshal_with(annotation_fields)
@get_app_model
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@get_app_model
@marshal_with(annotation_fields)
@account_initialization_required
@edit_permission_required
def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
@@ -270,6 +269,7 @@ class MessageSuggestedQuestionApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model, message_id):
current_user, _ = current_account_with_tenant()
message_id = str(message_id)
try:
@@ -304,12 +304,12 @@ class MessageApi(Resource):
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@api.response(200, "Message retrieved successfully", message_detail_fields)
@api.response(404, "Message not found")
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(message_detail_fields)
def get(self, app_model, message_id):
def get(self, app_model, message_id: str):
message_id = str(message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()

View File

@@ -2,7 +2,6 @@ import json
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
@@ -15,8 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
@@ -54,16 +52,14 @@ class ModelConfigResource(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model):
"""Modify app model config"""
if not isinstance(current_user, Account):
raise Forbidden()
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
config=cast(dict, request.json),
app_mode=AppMode.value_of(app_model.mode),
)
@@ -95,12 +91,12 @@ class ModelConfigResource(Resource):
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
@@ -134,7 +130,7 @@ class ModelConfigResource(Resource):
else:
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
@@ -142,7 +138,7 @@ class ModelConfigResource(Resource):
continue
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,

View File

@@ -30,8 +30,7 @@ class TraceAppConfigApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="args")
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
try:
@@ -63,9 +62,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def post(self, app_id):
"""Create a new trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_config", type=dict, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args()
try:
@@ -99,9 +100,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def patch(self, app_id):
"""Update an existing trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_config", type=dict, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args()
try:
@@ -129,8 +132,7 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def delete(self, app_id):
"""Delete an existing trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="args")
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
try:

View File

@@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
@@ -9,30 +8,36 @@ from controllers.console.wraps import account_initialization_required, setup_req
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Account, Site
from libs.login import current_account_with_tenant, login_required
from models import Site
def parse_app_site_args():
parser = reqparse.RequestParser()
parser.add_argument("title", type=str, required=False, location="json")
parser.add_argument("icon_type", type=str, required=False, location="json")
parser.add_argument("icon", type=str, required=False, location="json")
parser.add_argument("icon_background", type=str, required=False, location="json")
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("default_language", type=supported_language, required=False, location="json")
parser.add_argument("chat_color_theme", type=str, required=False, location="json")
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
parser.add_argument("customize_domain", type=str, required=False, location="json")
parser.add_argument("copyright", type=str, required=False, location="json")
parser.add_argument("privacy_policy", type=str, required=False, location="json")
parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
parser.add_argument(
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
parser = (
reqparse.RequestParser()
.add_argument("title", type=str, required=False, location="json")
.add_argument("icon_type", type=str, required=False, location="json")
.add_argument("icon", type=str, required=False, location="json")
.add_argument("icon_background", type=str, required=False, location="json")
.add_argument("description", type=str, required=False, location="json")
.add_argument("default_language", type=supported_language, required=False, location="json")
.add_argument("chat_color_theme", type=str, required=False, location="json")
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
.add_argument("customize_domain", type=str, required=False, location="json")
.add_argument("copyright", type=str, required=False, location="json")
.add_argument("privacy_policy", type=str, required=False, location="json")
.add_argument("custom_disclaimer", type=str, required=False, location="json")
.add_argument(
"customize_token_strategy",
type=str,
choices=["must", "allow", "not_allow"],
required=False,
location="json",
)
.add_argument("prompt_public", type=bool, required=False, location="json")
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
)
parser.add_argument("prompt_public", type=bool, required=False, location="json")
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
return parser.parse_args()
@@ -76,9 +81,10 @@ class AppSite(Resource):
@marshal_with(app_site_fields)
def post(self, app_model):
args = parse_app_site_args()
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
@@ -107,8 +113,6 @@ class AppSite(Resource):
if value is not None:
setattr(site, attr_name, value)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id
site.updated_at = naive_utc_now()
db.session.commit()
@@ -131,6 +135,8 @@ class AppSiteAccessTokenReset(Resource):
@marshal_with(app_site_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -140,8 +146,6 @@ class AppSiteAccessTokenReset(Resource):
raise NotFound
site.code = Site.generate_code(16)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id
site.updated_at = naive_utc_now()
db.session.commit()

View File

@@ -4,7 +4,6 @@ from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns
@@ -13,7 +12,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message
@@ -37,11 +36,13 @@ class DailyMessageStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@@ -53,6 +54,7 @@ WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -109,13 +111,15 @@ class DailyConversationStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -175,11 +179,13 @@ class DailyTerminalsStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@@ -191,7 +197,7 @@ WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -247,11 +253,13 @@ class DailyTokenCostStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@@ -264,7 +272,7 @@ WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -322,11 +330,13 @@ class AverageSessionInteractionStatistic(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@@ -346,7 +356,7 @@ FROM
c.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -413,11 +423,13 @@ class UserSatisfactionRateStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@@ -433,7 +445,7 @@ WHERE
m.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -494,11 +506,13 @@ class AverageResponseTimeStatistic(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@@ -510,7 +524,7 @@ WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -566,11 +580,13 @@ class TokensPerSecondStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@@ -585,7 +601,7 @@ WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc

View File

@@ -12,14 +12,13 @@ import services
from controllers.console import api, console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from factories import file_factory, variable_factory
@@ -28,21 +27,13 @@ from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from libs.login import current_account_with_tenant, login_required
from models import App
from models.account import Account
from models.model import AppMode
from models.provider_ids import TriggerProviderID
from models.workflow import NodeType, Workflow
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.trigger.trigger_debug_service import (
PluginTriggerDebugEvent,
TriggerDebugService,
WebhookDebugEvent,
)
from services.trigger.webhook_service import WebhookService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
@@ -78,15 +69,11 @@ class DraftWorkflowApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
@edit_permission_required
def get(self, app_model: App):
"""
Get draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model)
@@ -118,24 +105,24 @@ class DraftWorkflowApi(Resource):
@api.response(200, "Draft workflow synced successfully", workflow_fields)
@api.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied")
@edit_permission_required
def post(self, app_model: App):
"""
Sync draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type:
parser = reqparse.RequestParser()
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
parser.add_argument("hash", type=str, required=False, location="json")
parser.add_argument("environment_variables", type=list, required=True, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
.add_argument("features", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=True, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
elif "text/plain" in content_type:
try:
@@ -157,10 +144,6 @@ class DraftWorkflowApi(Resource):
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService()
try:
@@ -214,24 +197,21 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App):
"""
Run draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
parser.add_argument("query", type=str, required=True, location="json", default="")
parser.add_argument("files", type=list, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, location="json")
.add_argument("query", type=str, required=True, location="json", default="")
.add_argument("files", type=list, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
)
args = parser.parse_args()
@@ -279,18 +259,13 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
@@ -331,18 +306,13 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
@@ -383,19 +353,13 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow loop node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
@@ -436,19 +400,13 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow loop node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
@@ -488,20 +446,17 @@ class DraftWorkflowRunApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Run draft workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
external_trace_id = get_external_trace_id(request)
@@ -534,17 +489,11 @@ class WorkflowTaskStopApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, task_id: str):
"""
Stop workflow task
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
@@ -576,21 +525,18 @@ class DraftWorkflowNodeRunApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_fields)
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("query", type=str, required=False, location="json", default="")
parser.add_argument("files", type=list, location="json", default=[])
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("query", type=str, required=False, location="json", default="")
.add_argument("files", type=list, location="json", default=[])
)
args = parser.parse_args()
user_inputs = args.get("inputs")
@@ -630,17 +576,11 @@ class PublishedWorkflowApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
@edit_permission_required
def get(self, app_model: App):
"""
Get published workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# fetch published workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app_model)
@@ -652,19 +592,17 @@ class PublishedWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Publish workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
args = parser.parse_args()
# Validate name and comment length
@@ -710,17 +648,11 @@ class DefaultBlockConfigsApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App):
"""
Get default block config
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs
workflow_service = WorkflowService()
return workflow_service.get_default_block_configs()
@@ -737,18 +669,12 @@ class DefaultBlockConfigApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App, block_type: str):
"""
Get default block config
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args")
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
args = parser.parse_args()
q = args.get("q")
@@ -777,24 +703,23 @@ class ConvertToWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
@edit_permission_required
def post(self, app_model: App):
"""
Convert basic mode of chatbot app to workflow mode
Convert expert mode of chatbot app to workflow mode
Convert Completion App to Workflow App
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
if request.data:
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
else:
args = {}
@@ -820,21 +745,20 @@ class PublishedAllWorkflowApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_pagination_fields)
@edit_permission_required
def get(self, app_model: App):
"""
Get published workflows
"""
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("user_id", type=str, required=False, location="args")
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
args = parser.parse_args()
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
@@ -887,19 +811,17 @@ class WorkflowByIdApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
@edit_permission_required
def patch(self, app_model: App, workflow_id: str):
"""
Update workflow attributes
"""
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
args = parser.parse_args()
# Validate name and comment length
@@ -942,16 +864,11 @@ class WorkflowByIdApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def delete(self, app_model: App, workflow_id: str):
"""
Delete workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.has_edit_permission:
raise Forbidden()
workflow_service = WorkflowService()
# Create a session and manage the transaction
@@ -998,417 +915,3 @@ class DraftWorkflowNodeLastRunApi(Resource):
if node_exec is None:
raise NotFound("last run not found")
return node_exec
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger")
class DraftWorkflowTriggerNodeApi(Resource):
"""
Single node debug - Polling API for trigger events
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger
"""
@api.doc("poll_draft_workflow_trigger_node")
@api.doc(description="Poll for trigger events and execute single node when event arrives")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.expect(
api.model(
"DraftWorkflowTriggerNodeRequest",
{
"event_name": fields.String(required=True, description="Event name"),
"subscription_id": fields.String(required=True, description="Subscription ID"),
"provider_id": fields.String(required=True, description="Provider ID"),
},
)
)
@api.response(200, "Trigger event received and node executed successfully")
@api.response(403, "Permission denied")
@api.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
"""
Poll for trigger events and execute single node when event arrives
"""
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("event_name", type=str, required=True, location="json", nullable=False)
parser.add_argument("subscription_id", type=str, required=True, location="json", nullable=False)
parser.add_argument("provider_id", type=str, required=True, location="json", nullable=False)
args = parser.parse_args()
event_name = args["event_name"]
subscription_id = args["subscription_id"]
provider_id = args["provider_id"]
pool_key = PluginTriggerDebugEvent.build_pool_key(
tenant_id=app_model.tenant_id,
provider_id=provider_id,
subscription_id=subscription_id,
event_name=event_name,
)
event: PluginTriggerDebugEvent | None = TriggerDebugService.poll(
event_type=PluginTriggerDebugEvent,
pool_key=pool_key,
tenant_id=app_model.tenant_id,
user_id=current_user.id,
app_id=app_model.id,
node_id=node_id,
)
if not event:
return jsonable_encoder({"status": "waiting"})
try:
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
raise ValueError("Workflow not found")
user_inputs = event.model_dump()
node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model,
draft_workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
account=current_user,
query="",
files=[],
)
return jsonable_encoder(node_execution)
except Exception:
logger.exception("Error running draft workflow trigger node")
return jsonable_encoder(
{
"status": "error",
}
), 500
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/plugin/run")
class DraftWorkflowTriggerRunApi(Resource):
"""
Full workflow debug - Polling API for trigger events
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
"""
@api.doc("poll_draft_workflow_trigger_run")
@api.doc(description="Poll for trigger events and execute full workflow when event arrives")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"DraftWorkflowTriggerRunRequest",
{
"node_id": fields.String(required=True, description="Node ID"),
},
)
)
@api.response(200, "Trigger event received and workflow executed successfully")
@api.response(403, "Permission denied")
@api.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App):
"""
Poll for trigger events and execute full workflow when event arrives
"""
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
args = parser.parse_args()
node_id = args["node_id"]
workflow_service = WorkflowService()
workflow: Workflow | None = workflow_service.get_draft_workflow(
app_model=app_model,
workflow_id=None,
)
if not workflow:
return jsonable_encoder({"status": "error", "message": "Workflow not found"}), 404
node_data = workflow.get_node_config_by_id(node_id=node_id).get("data")
if not node_data:
return jsonable_encoder({"status": "error", "message": "Node config not found"}), 404
event_name = node_data.get("event_name")
subscription_id = node_data.get("subscription_id")
if not subscription_id:
return jsonable_encoder({"status": "error", "message": "Subscription ID not found"}), 404
provider_id = TriggerProviderID(node_data.get("provider_id"))
pool_key: str = PluginTriggerDebugEvent.build_pool_key(
tenant_id=app_model.tenant_id,
provider_id=provider_id,
subscription_id=subscription_id,
event_name=event_name,
)
event: PluginTriggerDebugEvent | None = TriggerDebugService.poll(
event_type=PluginTriggerDebugEvent,
pool_key=pool_key,
tenant_id=app_model.tenant_id,
user_id=current_user.id,
app_id=app_model.id,
node_id=node_id,
)
if not event:
return jsonable_encoder({"status": "waiting", "retry_in": 2000})
workflow_args = {
"inputs": event.model_dump(),
"query": "",
"files": [],
}
external_trace_id = get_external_trace_id(request)
if external_trace_id:
workflow_args["external_trace_id"] = external_trace_id
try:
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
root_node_id=node_id,
)
return helper.compact_generate_response(response)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except Exception:
logger.exception("Error running draft workflow trigger run")
return jsonable_encoder(
{
"status": "error",
}
), 500
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/webhook/run")
class DraftWorkflowTriggerWebhookRunApi(Resource):
"""
Full workflow debug when the start node is a webhook trigger
Path: /apps/<uuid:app_id>/workflows/draft/trigger/webhook/run
"""
@api.doc("draft_workflow_trigger_webhook_run")
@api.doc(description="Full workflow debug when the start node is a webhook trigger")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"DraftWorkflowTriggerWebhookRunRequest",
{
"node_id": fields.String(required=True, description="Node ID"),
},
)
)
@api.response(200, "Workflow executed successfully")
@api.response(403, "Permission denied")
@api.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App):
"""
Full workflow debug when the start node is a webhook trigger
"""
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
args = parser.parse_args()
node_id = args["node_id"]
pool_key = WebhookDebugEvent.build_pool_key(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
node_id=node_id,
)
event: WebhookDebugEvent | None = TriggerDebugService.poll(
event_type=WebhookDebugEvent,
pool_key=pool_key,
tenant_id=app_model.tenant_id,
user_id=current_user.id,
app_id=app_model.id,
node_id=node_id,
)
if not event:
return jsonable_encoder({"status": "waiting", "retry_in": 2000})
payload = event.payload or {}
workflow_inputs = payload.get("inputs")
if workflow_inputs is None:
webhook_data = payload.get("webhook_data", {})
workflow_inputs = WebhookService.build_workflow_inputs(webhook_data)
workflow_args = {
"inputs": workflow_inputs or {},
"query": "",
"files": [],
}
external_trace_id = get_external_trace_id(request)
if external_trace_id:
workflow_args["external_trace_id"] = external_trace_id
try:
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
root_node_id=node_id,
)
return helper.compact_generate_response(response)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except Exception:
logger.exception("Error running draft workflow trigger webhook run")
return jsonable_encoder(
{
"status": "error",
}
), 500
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/debug/webhook/run")
class DraftWorkflowNodeWebhookDebugRunApi(Resource):
"""Single node debug when the node is a webhook trigger."""
@api.doc("draft_workflow_node_webhook_debug_run")
@api.doc(description="Poll for webhook debug payload and execute single node when event arrives")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.response(200, "Node executed successfully")
@api.response(403, "Permission denied")
@api.response(400, "Invalid node type")
@api.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
pool_key = WebhookDebugEvent.build_pool_key(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
node_id=node_id,
)
event: WebhookDebugEvent | None = TriggerDebugService.poll(
event_type=WebhookDebugEvent,
pool_key=pool_key,
tenant_id=app_model.tenant_id,
user_id=current_user.id,
app_id=app_model.id,
node_id=node_id,
)
if not event:
return jsonable_encoder({"status": "waiting", "retry_in": 2000})
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise DraftWorkflowNotExist()
node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = Workflow.get_node_type_from_node_config(node_config)
if node_type != NodeType.TRIGGER_WEBHOOK:
return jsonable_encoder(
{
"status": "error",
"message": "node is not webhook trigger",
}
), 400
payload = event.payload or {}
workflow_inputs = payload.get("inputs")
if workflow_inputs is None:
webhook_data = payload.get("webhook_data", {})
workflow_inputs = WebhookService.build_workflow_inputs(webhook_data)
workflow_node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model,
draft_workflow=draft_workflow,
node_id=node_id,
user_inputs=workflow_inputs or {},
account=current_user,
query="",
files=[],
)
return jsonable_encoder(workflow_node_execution)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/schedule/run")
class DraftWorkflowTriggerScheduleRunApi(Resource):
"""
Full workflow debug when the start node is a schedule trigger
Path: /apps/<uuid:app_id>/workflows/draft/trigger/schedule/run
"""
@api.doc("draft_workflow_trigger_schedule_run")
@api.doc(description="Full workflow debug when the start node is a schedule trigger")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"DraftWorkflowTriggerScheduleRunRequest",
{
"node_id": fields.String(required=True, description="Node ID"),
},
)
)
@api.response(200, "Workflow executed successfully")
@api.response(403, "Permission denied")
@api.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App):
"""
Full workflow debug when the start node is a schedule trigger
"""
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
args = parser.parse_args()
node_id = args["node_id"]
workflow_args = {
"inputs": {},
"query": "",
"files": [],
}
try:
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
root_node_id=node_id,
)
return helper.compact_generate_response(response)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except Exception:
logger.exception("Error running draft workflow trigger schedule run")
return jsonable_encoder(
{
"status": "error",
}
), 500

View File

@@ -42,33 +42,35 @@ class WorkflowAppLogApi(Resource):
"""
Get workflow app logs
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
parser.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
parser.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
args.status = WorkflowExecutionStatus(args.status) if args.status else None

View File

@@ -22,8 +22,7 @@ from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode
from models.account import Account
from models import Account, App, AppMode
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService
@@ -58,16 +57,18 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
def _create_pagination_parser():
parser = reqparse.RequestParser()
parser.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
parser = (
reqparse.RequestParser()
.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser
@@ -320,10 +321,11 @@ class VariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
parser = reqparse.RequestParser()
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
# Parse 'value' field as-is to maintain its original data structure
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),

View File

@@ -1,6 +1,5 @@
from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
@@ -9,15 +8,81 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.workflow_run_fields import (
advanced_chat_workflow_run_pagination_fields,
workflow_run_count_fields,
workflow_run_detail_fields,
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
from libs.custom_inputs import time_duration
from libs.helper import uuid_value
from libs.login import login_required
from models import Account, App, AppMode, EndUser
from libs.login import current_user, login_required
from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
from services.workflow_run_service import WorkflowRunService
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
def _parse_workflow_run_list_args():
"""
Parse common arguments for workflow run list endpoints.
Returns:
Parsed arguments containing last_id, limit, status, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
return parser.parse_args()
def _parse_workflow_run_count_args():
"""
Parse common arguments for workflow run count endpoints.
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
class AdvancedChatAppWorkflowRunListApi(Resource):
@@ -25,6 +90,8 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@api.doc(description="Get advanced chat workflow run list")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
@setup_required
@login_required
@@ -35,13 +102,64 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
"""
Get advanced chat app workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
args = _parse_workflow_run_list_args()
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
app_model=app_model, args=args, triggered_from=triggered_from
)
return result
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource):
@api.doc("get_advanced_chat_workflow_runs_count")
@api.doc(description="Get advanced chat workflow runs count statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
)
}
)
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@marshal_with(workflow_run_count_fields)
def get(self, app_model: App):
"""
Get advanced chat workflow runs count statistics
"""
args = _parse_workflow_run_count_args()
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_runs_count(
app_model=app_model,
status=args.get("status"),
time_range=args.get("time_range"),
triggered_from=triggered_from,
)
return result
@@ -52,6 +170,8 @@ class WorkflowRunListApi(Resource):
@api.doc(description="Get workflow run list")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
@setup_required
@login_required
@@ -62,13 +182,64 @@ class WorkflowRunListApi(Resource):
"""
Get workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
args = _parse_workflow_run_list_args()
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
result = workflow_run_service.get_paginate_workflow_runs(
app_model=app_model, args=args, triggered_from=triggered_from
)
return result
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
class WorkflowRunCountApi(Resource):
@api.doc("get_workflow_runs_count")
@api.doc(description="Get workflow runs count statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
)
}
)
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_count_fields)
def get(self, app_model: App):
"""
Get workflow runs count statistics
"""
args = _parse_workflow_run_count_args()
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_runs_count(
app_model=app_model,
status=args.get("status"),
time_range=args.get("time_range"),
triggered_from=triggered_from,
)
return result

View File

@@ -4,7 +4,6 @@ from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, reqparse
from controllers.console import api, console_ns
@@ -12,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
@@ -29,11 +28,13 @@ class WorkflowDailyRunsStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@@ -49,7 +50,7 @@ WHERE
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -97,11 +98,13 @@ class WorkflowDailyTerminalsStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@@ -117,7 +120,7 @@ WHERE
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -165,11 +168,13 @@ class WorkflowDailyTokenCostStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@@ -185,7 +190,7 @@ WHERE
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@@ -238,11 +243,13 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@@ -271,7 +278,7 @@ GROUP BY
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc

View File

@@ -1,149 +0,0 @@
import logging
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
from models.model import Account, AppMode
from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger
logger = logging.getLogger(__name__)
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(webhook_trigger_fields)
def get(self, app_model):
"""Get webhook trigger for a node"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args()
node_id = args["node_id"]
with Session(db.engine) as session:
# Get webhook trigger for this app and node
webhook_trigger = (
session.query(WorkflowWebhookTrigger)
.filter(
WorkflowWebhookTrigger.app_id == app_model.id,
WorkflowWebhookTrigger.node_id == node_id,
)
.first()
)
if not webhook_trigger:
raise NotFound("Webhook trigger not found for this node")
# Add computed fields for marshal_with
base_url = dify_config.SERVICE_API_URL
webhook_trigger.webhook_url = f"{base_url}/triggers/webhook/{webhook_trigger.webhook_id}" # type: ignore
webhook_trigger.webhook_debug_url = f"{base_url}/triggers/webhook-debug/{webhook_trigger.webhook_id}" # type: ignore
return webhook_trigger
class AppTriggersApi(Resource):
"""App Triggers list API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(triggers_list_fields)
def get(self, app_model):
"""Get app triggers list"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with Session(db.engine) as session:
# Get all triggers for this app using select API
triggers = (
session.execute(
select(AppTrigger)
.where(
AppTrigger.tenant_id == current_user.current_tenant_id,
AppTrigger.app_id == app_model.id,
)
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
)
.scalars()
.all()
)
# Add computed icon field for each trigger
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
for trigger in triggers:
if trigger.trigger_type == "trigger-plugin":
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
else:
trigger.icon = "" # type: ignore
return {"data": triggers}
class AppTriggerEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_fields)
def post(self, app_model):
"""Update app trigger (enable/disable)"""
parser = reqparse.RequestParser()
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if not current_user.has_edit_permission:
raise Forbidden()
trigger_id = args["trigger_id"]
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
select(AppTrigger).where(
AppTrigger.id == trigger_id,
AppTrigger.tenant_id == current_user.current_tenant_id,
AppTrigger.app_id == app_model.id,
)
).scalar_one_or_none()
if not trigger:
raise NotFound("Trigger not found")
# Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
# Add computed icon field
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
if trigger.trigger_type == "trigger-plugin":
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
else:
trigger.icon = "" # type: ignore
return trigger
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")

View File

@@ -4,28 +4,29 @@ from typing import ParamSpec, TypeVar, Union
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from libs.login import current_account_with_tenant
from models import App, AppMode
from models.account import Account
P = ParamSpec("P")
R = TypeVar("R")
P1 = ParamSpec("P1")
R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None:
assert isinstance(current_user, Account)
_, current_tenant_id = current_account_with_tenant()
app_model = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@@ -7,18 +7,14 @@ from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus
from models import AccountStatus
from services.account_service import AccountService, RegisterService
active_check_parser = reqparse.RequestParser()
active_check_parser.add_argument(
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID"
)
active_check_parser.add_argument(
"email", type=email, required=False, nullable=True, location="args", help="Email address"
)
active_check_parser.add_argument(
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
active_check_parser = (
reqparse.RequestParser()
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
)
@@ -60,15 +56,15 @@ class ActivateCheckApi(Resource):
return {"is_valid": False}
active_parser = reqparse.RequestParser()
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json")
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json")
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
active_parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"
active_parser = (
reqparse.RequestParser()
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
.add_argument("email", type=email, required=False, nullable=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
)
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
@console_ns.route("/activate")

View File

@@ -1,10 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required
@@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource):
@login_required
@account_initialization_required
def get(self):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
if data_source_api_key_bindings:
return {
"sources": [
@@ -41,16 +41,20 @@ class ApiKeyAuthDataSourceBinding(Resource):
@account_initialization_required
def post(self):
# The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200
@@ -63,9 +67,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@account_initialization_required
def delete(self, binding_id):
# The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
return {"result": "success"}, 204

View File

@@ -2,13 +2,12 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api, console_ns
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required
@@ -45,6 +44,7 @@ class OAuthDataSource(Resource):
@api.response(403, "Admin privileges required")
def get(self, provider: str):
# The role of the current user in the table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()

View File

@@ -19,7 +19,7 @@ from controllers.console.wraps import email_password_login_enabled, email_regist
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
from models.account import Account
from models import Account
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError
@@ -31,9 +31,11 @@ class EmailRegisterSendEmailApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request)
@@ -59,10 +61,12 @@ class EmailRegisterCheckApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
user_email = args["email"]
@@ -100,10 +104,12 @@ class EmailRegisterResetApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# Validate passwords match

View File

@@ -20,7 +20,7 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
@@ -54,9 +54,11 @@ class ForgotPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request)
@@ -111,10 +113,12 @@ class ForgotPasswordCheckApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
user_email = args["email"]
@@ -169,10 +173,12 @@ class ForgotPasswordResetApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# Validate passwords match

View File

@@ -1,7 +1,5 @@
from typing import cast
import flask_login
from flask import request
from flask import make_response, request
from flask_restx import Resource, reqparse
import services
@@ -26,7 +24,17 @@ from controllers.console.error import (
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from models.account import Account
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_access_token,
extract_csrf_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.errors.account import AccountRegisterError
@@ -42,11 +50,13 @@ class LoginApi(Resource):
@email_password_login_enabled
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=str, required=True, location="json")
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("password", type=str, required=True, location="json")
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
.add_argument("invite_token", type=str, required=False, default=None, location="json")
)
args = parser.parse_args()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
@@ -89,19 +99,36 @@ class LoginApi(Resource):
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
@console_ns.route("/logout")
class LogoutApi(Resource):
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
def post(self):
current_user, _ = current_account_with_tenant()
account = current_user
if isinstance(account, flask_login.AnonymousUserMixin):
return {"result": "success"}
AccountService.logout(account=account)
flask_login.logout_user()
return {"result": "success"}
response = make_response({"result": "success"})
else:
AccountService.logout(account=account)
flask_login.logout_user()
response = make_response({"result": "success"})
# Clear cookies on logout
clear_access_token_from_cookie(response)
clear_refresh_token_from_cookie(response)
clear_csrf_token_from_cookie(response)
return response
@console_ns.route("/reset-password")
@@ -109,9 +136,11 @@ class ResetPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
if args["language"] is not None and args["language"] == "zh-Hans":
@@ -137,9 +166,11 @@ class ResetPasswordSendEmailApi(Resource):
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request)
@@ -170,10 +201,12 @@ class EmailCodeLoginSendEmailApi(Resource):
class EmailCodeLoginApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
)
args = parser.parse_args()
user_email = args["email"]
@@ -220,18 +253,46 @@ class EmailCodeLoginApi(Resource):
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
# Set HTTP-only secure cookies for tokens
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
return response
@console_ns.route("/refresh-token")
class RefreshTokenApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("refresh_token", type=str, required=True, location="json")
args = parser.parse_args()
# Get refresh token from cookie instead of request body
refresh_token = request.cookies.get("refresh_token")
if not refresh_token:
return {"result": "fail", "message": "No refresh token provided"}, 401
try:
new_token_pair = AccountService.refresh_token(args["refresh_token"])
return {"result": "success", "data": new_token_pair.model_dump()}
new_token_pair = AccountService.refresh_token(refresh_token)
# Create response with new cookies
response = make_response({"result": "success"})
# Update cookies with new tokens
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
set_access_token_to_cookie(request, response, new_token_pair.access_token)
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
return response
except Exception as e:
return {"result": "fail", "data": str(e)}, 401
return {"result": "fail", "message": str(e)}, 401
# this api helps frontend to check whether user is authenticated
# TODO: remove in the future. frontend should redirect to login page by catching 401 status
@console_ns.route("/login/status")
class LoginStatus(Resource):
def get(self):
token = extract_access_token(request)
csrf_token = extract_csrf_token(request)
return {"logged_in": bool(token) and bool(csrf_token)}

View File

@@ -14,8 +14,12 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from libs.token import (
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from models import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError
@@ -153,9 +157,12 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request),
)
return redirect(
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
)
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:

View File

@@ -1,16 +1,15 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast
from typing import Concatenate, ParamSpec, TypeVar
import flask_login
from flask import jsonify, request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models import Account
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
@@ -24,8 +23,7 @@ T = TypeVar("T")
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser()
parser.add_argument("client_id", type=str, required=True, location="json")
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args()
client_id = parsed_args.get("client_id")
if not client_id:
@@ -91,8 +89,7 @@ class OAuthServerAppApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser()
parser.add_argument("redirect_uri", type=str, required=True, location="json")
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
parsed_args = parser.parse_args()
redirect_uri = parsed_args.get("redirect_uri")
@@ -116,7 +113,8 @@ class OAuthServerUserAuthorizeApi(Resource):
@account_initialization_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
account = cast(Account, flask_login.current_user)
current_user, _ = current_account_with_tenant()
account = current_user
user_account_id = account.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
@@ -132,12 +130,14 @@ class OAuthServerUserTokenApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser()
parser.add_argument("grant_type", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=False, location="json")
parser.add_argument("client_secret", type=str, required=False, location="json")
parser.add_argument("redirect_uri", type=str, required=False, location="json")
parser.add_argument("refresh_token", type=str, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("grant_type", type=str, required=True, location="json")
.add_argument("code", type=str, required=False, location="json")
.add_argument("client_secret", type=str, required=False, location="json")
.add_argument("redirect_uri", type=str, required=False, location="json")
.add_argument("refresh_token", type=str, required=False, location="json")
)
parsed_args = parser.parse_args()
try:

View File

@@ -2,8 +2,7 @@ from flask_restx import Resource, reqparse
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_user, login_required
from models.model import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@@ -14,17 +13,15 @@ class Subscription(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
args = parser.parse_args()
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
)
args = parser.parse_args()
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
@console_ns.route("/billing/invoices")
@@ -34,7 +31,6 @@ class Invoices(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
assert isinstance(current_user, Account)
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
return BillingService.get_invoices(current_user.email, current_tenant_id)

View File

@@ -2,8 +2,7 @@ from flask import request
from flask_restx import Resource, reqparse
from libs.helper import extract_remote_ip
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from .. import console_ns
@@ -17,19 +16,16 @@ class ComplianceApi(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
parser = reqparse.RequestParser()
parser.add_argument("doc_name", type=str, required=True, location="args")
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
args = parser.parse_args()
ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device")
return BillingService.get_compliance_download_link(
doc_name=args.doc_name,
account_id=current_user.id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
ip=ip_address,
device_info=device_info,
)

View File

@@ -3,7 +3,6 @@ from collections.abc import Generator
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -20,7 +19,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
@@ -37,10 +36,12 @@ class DataSourceApi(Resource):
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
_, current_tenant_id = current_account_with_tenant()
# get workspace data source integrates
data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.tenant_id == current_tenant_id,
DataSourceOauthBinding.disabled == False,
)
).all()
@@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource):
@account_initialization_required
@marshal_with(integrate_notion_info_list_fields)
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
@@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource):
documents = session.scalars(
select(Document).filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
@@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource):
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id="langgenius/notion_datasource/notion_datasource",
datasource_name="notion_datasource",
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
datasource_provider_service = DatasourceProviderService()
@@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def get(self, workspace_id, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
@@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource):
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
)
text_docs = extractor.extract()
@@ -239,12 +244,14 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
# validate args
@@ -263,7 +270,7 @@ class DataSourceNotionApi(Resource):
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
}
),
document_model=args["doc_form"],
@@ -271,7 +278,7 @@ class DataSourceNotionApi(Resource):
extract_settings.append(extract_setting)
indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
extract_settings,
args["process_rule"],
args["doc_form"],

View File

@@ -1,7 +1,6 @@
from typing import Any, cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
@@ -30,10 +29,9 @@ from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@@ -138,6 +136,7 @@ class DatasetListApi(Resource):
@account_initialization_required
@enterprise_license_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
@@ -146,15 +145,15 @@ class DatasetListApi(Resource):
tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
page, limit, current_tenant_id, current_user, search, tag_ids, include_all
)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -207,50 +206,53 @@ class DatasetListApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=validate_description_length,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
)
parser.add_argument(
"provider",
type=str,
nullable=True,
choices=Dataset.PROVIDER_LIST,
required=False,
default="vendor",
)
parser.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
)
.add_argument(
"provider",
type=str,
nullable=True,
choices=Dataset.PROVIDER_LIST,
required=False,
default="vendor",
)
.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@@ -258,11 +260,11 @@ class DatasetListApi(Resource):
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
account=cast(Account, current_user),
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
@@ -286,6 +288,7 @@ class DatasetApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -305,7 +308,7 @@ class DatasetApi(Resource):
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -351,73 +354,76 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument("description", location="json", store_missing=False, type=validate_description_length)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
parser.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
parser.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
parser.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
parser.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
parser.add_argument(
"icon_info",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid icon info.",
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(
DatasetPermissionEnum.ONLY_ME,
DatasetPermissionEnum.ALL_TEAM,
DatasetPermissionEnum.PARTIAL_TEAM,
),
help="Invalid permission.",
)
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
.add_argument(
"icon_info",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid icon info.",
)
)
args = parser.parse_args()
data = request.get_json()
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
@@ -440,7 +446,7 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.")
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
@@ -464,9 +470,9 @@ class DatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id_str = str(dataset_id)
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_operator):
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
try:
@@ -505,6 +511,7 @@ class DatasetQueryApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -539,32 +546,31 @@ class DatasetIndexingEstimateApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
parser.add_argument(
"indexing_technique",
type=str,
required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
location="json",
)
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
parser = (
reqparse.RequestParser()
.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument(
"indexing_technique",
type=str,
required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
location="json",
)
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
# validate args
DocumentService.estimate_args_validate(args)
extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = db.session.scalars(
select(UploadFile).where(
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
)
select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
).all()
if file_details is None:
@@ -592,7 +598,7 @@ class DatasetIndexingEstimateApi(Resource):
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
}
),
document_model=args["doc_form"],
@@ -608,7 +614,7 @@ class DatasetIndexingEstimateApi(Resource):
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
}
@@ -621,7 +627,7 @@ class DatasetIndexingEstimateApi(Resource):
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
extract_settings,
args["process_rule"],
args["doc_form"],
@@ -652,6 +658,7 @@ class DatasetRelatedAppListApi(Resource):
@account_initialization_required
@marshal_with(related_app_list)
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -683,11 +690,10 @@ class DatasetIndexingStatusApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
)
select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
).all()
documents_status = []
for document in documents:
@@ -739,10 +745,9 @@ class DatasetApiKeyApi(Resource):
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
)
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
).all()
return {"items": keys}
@@ -752,12 +757,13 @@ class DatasetApiKeyApi(Resource):
@marshal_with(api_key_fields)
def post(self):
# The role of the current user in the ta table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
.count()
)
@@ -770,7 +776,7 @@ class DatasetApiKeyApi(Resource):
key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken()
api_token.tenant_id = current_user.current_tenant_id
api_token.tenant_id = current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
@@ -790,6 +796,7 @@ class DatasetApiDeleteApi(Resource):
@login_required
@account_initialization_required
def delete(self, api_key_id):
current_user, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner
@@ -799,7 +806,7 @@ class DatasetApiDeleteApi(Resource):
key = (
db.session.query(ApiToken)
.where(
ApiToken.tenant_id == current_user.current_tenant_id,
ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
@@ -898,6 +905,7 @@ class DatasetPermissionUserListApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@@ -6,7 +6,6 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
@@ -53,9 +52,8 @@ from fields.document_fields import (
document_with_segments_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@@ -65,6 +63,7 @@ logger = logging.getLogger(__name__)
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@@ -79,12 +78,13 @@ class DocumentResource(Resource):
if not document:
raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id:
if document.tenant_id != current_tenant_id:
raise Forbidden("No permission.")
return document
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
current_user, _ = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@@ -112,6 +112,7 @@ class GetProcessRuleApi(Resource):
@login_required
@account_initialization_required
def get(self):
current_user, _ = current_account_with_tenant()
req_data = request.args
document_id = req_data.get("document_id")
@@ -168,6 +169,7 @@ class DatasetDocumentListApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
@@ -199,7 +201,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
if search:
search = f"%{search}%"
@@ -273,6 +275,7 @@ class DatasetDocumentListApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -289,20 +292,20 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
parser = reqparse.RequestParser()
parser.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
parser.add_argument("data_source", type=dict, required=False, location="json")
parser.add_argument("process_rule", type=dict, required=False, location="json")
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
parser = (
reqparse.RequestParser()
.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
.add_argument("data_source", type=dict, required=False, location="json")
.add_argument("process_rule", type=dict, required=False, location="json")
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
.add_argument("original_document_id", type=str, required=False, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args)
@@ -372,27 +375,28 @@ class DatasetInitApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"indexing_technique",
type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
required=True,
nullable=False,
location="json",
parser = (
reqparse.RequestParser()
.add_argument(
"indexing_technique",
type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
required=True,
nullable=False,
location="json",
)
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args)
@@ -402,7 +406,7 @@ class DatasetInitApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=args["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=args["embedding_model"],
@@ -419,9 +423,9 @@ class DatasetInitApi(Resource):
try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
knowledge_config=knowledge_config,
account=cast(Account, current_user),
account=current_user,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -447,6 +451,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
@@ -482,7 +487,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
try:
estimate_response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
[extract_setting],
data_process_rule_dict,
document.doc_form,
@@ -511,6 +516,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@login_required
@account_initialization_required
def get(self, dataset_id, batch):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch)
@@ -530,7 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
@@ -553,7 +559,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
@@ -569,7 +575,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
@@ -583,7 +589,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
extract_settings,
data_process_rule_dict,
document.doc_form,
@@ -834,6 +840,7 @@ class DocumentProcessingApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
@@ -884,6 +891,7 @@ class DocumentMetadataApi(DocumentResource):
@login_required
@account_initialization_required
def put(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
@@ -931,6 +939,7 @@ class DocumentStatusApi(DocumentResource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
@@ -1034,8 +1043,9 @@ class DocumentRetryApi(DocumentResource):
def post(self, dataset_id):
"""retry document."""
parser = reqparse.RequestParser()
parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"document_ids", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -1077,14 +1087,14 @@ class DocumentRenameApi(DocumentResource):
@marshal_with(document_fields)
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
@@ -1102,6 +1112,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
@account_initialization_required
def get(self, dataset_id, document_id):
"""sync website document."""
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
@@ -1110,7 +1121,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id:
if document.tenant_id != current_tenant_id:
raise Forbidden("No permission.")
if document.data_source_type != "website_crawl":
raise ValueError("Document is not a website document.")

View File

@@ -1,7 +1,6 @@
import uuid
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
@@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
@@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -59,13 +60,15 @@ class DatasetDocumentSegmentListApi(Resource):
if not document:
raise NotFound("Document not found.")
parser = reqparse.RequestParser()
parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("status", type=str, action="append", default=[], location="args")
parser.add_argument("hit_count_gte", type=int, default=None, location="args")
parser.add_argument("enabled", type=str, default="all", location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument("page", type=int, default=1, location="args")
parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("status", type=str, action="append", default=[], location="args")
.add_argument("hit_count_gte", type=int, default=None, location="args")
.add_argument("enabled", type=str, default="all", location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
args = parser.parse_args()
@@ -79,7 +82,7 @@ class DatasetDocumentSegmentListApi(Resource):
select(DocumentSegment)
.where(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id,
DocumentSegment.tenant_id == current_tenant_id,
)
.order_by(DocumentSegment.position.asc())
)
@@ -115,6 +118,8 @@ class DatasetDocumentSegmentListApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -148,6 +153,8 @@ class DatasetDocumentSegmentApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
@@ -171,7 +178,7 @@ class DatasetDocumentSegmentApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@@ -204,6 +211,8 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -221,7 +230,7 @@ class DatasetDocumentSegmentAddApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@@ -237,10 +246,12 @@ class DatasetDocumentSegmentAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("content", type=str, required=True, nullable=False, location="json")
.add_argument("answer", type=str, required=False, nullable=True, location="json")
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset)
@@ -255,6 +266,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -272,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@@ -287,7 +300,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -300,12 +313,14 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
parser.add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
parser = (
reqparse.RequestParser()
.add_argument("content", type=str, required=True, nullable=False, location="json")
.add_argument("answer", type=str, required=False, nullable=True, location="json")
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
.add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
@@ -317,6 +332,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -333,7 +350,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -361,6 +378,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -372,8 +391,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if not document:
raise NotFound("Document not found.")
parser = reqparse.RequestParser()
parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"upload_file_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
upload_file_id = args["upload_file_id"]
@@ -396,7 +416,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
upload_file_id,
dataset_id,
document_id,
current_user.current_tenant_id,
current_tenant_id,
current_user.id,
)
except Exception as e:
@@ -427,6 +447,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -441,7 +463,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -453,7 +475,7 @@ class ChildChunkAddApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@@ -469,8 +491,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try:
content = args["content"]
@@ -483,6 +506,8 @@ class ChildChunkAddApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id, document_id, segment_id):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -499,15 +524,17 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
parser = reqparse.RequestParser()
parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument("page", type=int, default=1, location="args")
parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
args = parser.parse_args()
@@ -530,6 +557,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -546,7 +575,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -559,8 +588,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"chunks", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try:
chunks_data = args["chunks"]
@@ -580,6 +610,8 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -596,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -607,7 +639,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
@@ -634,6 +666,8 @@ class ChildChunkUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@@ -650,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@@ -661,7 +695,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
@@ -677,8 +711,9 @@ class ChildChunkUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try:
content = args["content"]

View File

@@ -1,7 +1,4 @@
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -10,8 +7,7 @@ from controllers.console import api, console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, setup_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
@@ -40,12 +36,13 @@ class ExternalApiTemplateListApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
page, limit, current_user.current_tenant_id, search
page, limit, current_tenant_id, search
)
response = {
"data": [item.to_dict() for item in external_knowledge_apis],
@@ -60,20 +57,23 @@ class ExternalApiTemplateListApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
parser.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
)
args = parser.parse_args()
@@ -85,7 +85,7 @@ class ExternalApiTemplateListApi(Resource):
try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
tenant_id=current_tenant_id, user_id=current_user.id, args=args
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@@ -115,28 +115,31 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
parser.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"])
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id,
args=args,
@@ -148,13 +151,13 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
def delete(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_operator):
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
return {"result": "success"}, 204
@@ -199,21 +202,24 @@ class ExternalDatasetCreateApi(Resource):
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"name",
nullable=False,
required=True,
help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
parser = (
reqparse.RequestParser()
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
.add_argument(
"name",
nullable=False,
required=True,
help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument("description", type=str, required=False, nullable=True, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()
@@ -223,7 +229,7 @@ class ExternalDatasetCreateApi(Resource):
try:
dataset = ExternalDatasetService.create_external_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
user_id=current_user.id,
args=args,
)
@@ -255,6 +261,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@login_required
@account_initialization_required
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -265,10 +272,12 @@ class ExternalKnowledgeHitTestingApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
)
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
@@ -277,7 +286,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
response = HitTestingService.external_retrieve(
dataset=dataset,
query=args["query"],
account=cast(Account, current_user),
account=current_user,
external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"],
)
@@ -304,15 +313,17 @@ class BedrockRetrievalApi(Resource):
)
@api.response(200, "Bedrock retrieval test completed")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
parser.add_argument(
"query",
nullable=False,
required=True,
type=str,
parser = (
reqparse.RequestParser()
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
.add_argument("knowledge_id", nullable=False, required=True, type=str)
)
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
args = parser.parse_args()
# Call the knowledge retrieval service

View File

@@ -48,11 +48,12 @@ class DatasetsHitTestingBase:
@staticmethod
def parse_args():
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args()
@staticmethod

View File

@@ -1,13 +1,12 @@
from typing import Literal
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
@@ -24,9 +23,12 @@ class DatasetMetadataCreateApi(Resource):
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
def post(self, dataset_id):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
metadata_args = MetadataArgs.model_validate(args)
@@ -59,8 +61,8 @@ class DatasetMetadataApi(Resource):
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
def patch(self, dataset_id, metadata_id):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
name = args["name"]
@@ -79,6 +81,7 @@ class DatasetMetadataApi(Resource):
@account_initialization_required
@enterprise_license_required
def delete(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@@ -108,6 +111,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@account_initialization_required
@enterprise_license_required
def post(self, dataset_id, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -128,14 +132,16 @@ class DocumentMetadataEditApi(Resource):
@account_initialization_required
@enterprise_license_required
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
parser = reqparse.RequestParser()
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"operation_data", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
metadata_args = MetadataOperationData.model_validate(args)

View File

@@ -1,19 +1,15 @@
from flask import make_response, redirect, request
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
@@ -24,11 +20,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, provider_id: str):
user = current_user
tenant_id = user.current_tenant_id
if not current_user.is_editor:
raise Forbidden()
current_user, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
credential_id = request.args.get("credential_id")
datasource_provider_id = DatasourceProviderID(provider_id)
@@ -52,7 +48,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
user_id=current_user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
@@ -130,22 +126,24 @@ class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument(
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
parser = (
reqparse.RequestParser()
.add_argument(
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
try:
datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider_id=datasource_provider_id,
credentials=args["credentials"],
name=args["name"],
@@ -160,8 +158,10 @@ class DatasourceAuth(Resource):
def get(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
_, current_tenant_id = current_account_with_tenant()
datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
@@ -173,18 +173,21 @@ class DatasourceAuthDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
auth_id=args["credential_id"],
provider=provider_name,
plugin_id=plugin_id,
@@ -197,18 +200,22 @@ class DatasourceAuthUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
auth_id=args["credential_id"],
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
@@ -224,10 +231,10 @@ class DatasourceAuthListApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_all_datasource_credentials(
tenant_id=current_user.current_tenant_id
)
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@@ -237,10 +244,10 @@ class DatasourceHardCodeAuthListApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
tenant_id=current_user.current_tenant_id
)
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@@ -249,17 +256,20 @@ class DatasourceAuthOauthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}),
enabled=args.get("enable_oauth_custom_client", False),
@@ -270,10 +280,12 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required
@account_initialization_required
def delete(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
)
return {"result": "success"}, 200
@@ -284,16 +296,16 @@ class DatasourceAuthDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
credential_id=args["id"],
)
@@ -305,17 +317,20 @@ class DatasourceUpdateProviderNameApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
name=args["name"],
credential_id=args["credential_id"],

View File

@@ -26,10 +26,12 @@ class DataSourceContentPreviewApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args()
inputs = args.get("inputs")

View File

@@ -66,26 +66,28 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
@@ -123,26 +125,28 @@ class PublishCustomizedPipelineTemplateApi(Resource):
@enterprise_license_required
@knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
)
args = parser.parse_args()
rag_pipeline_service = RagPipelineService()

View File

@@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@@ -13,7 +12,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@@ -27,9 +26,7 @@ class CreateRagPipelineDatasetApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
parser = reqparse.RequestParser().add_argument(
"yaml_content",
type=str,
nullable=False,
@@ -38,7 +35,7 @@ class CreateRagPipelineDatasetApi(Resource):
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
@@ -58,12 +55,12 @@ class CreateRagPipelineDatasetApi(Resource):
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id,
current_tenant_id,
import_info["dataset_id"],
rag_pipeline_dataset_create_entity.partial_member_list,
)
@@ -81,10 +78,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
name="",
description="",

View File

@@ -23,7 +23,7 @@ from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models.account import Account
from models import Account
from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService
@@ -33,16 +33,18 @@ logger = logging.getLogger(__name__)
def _create_pagination_parser():
parser = reqparse.RequestParser()
parser.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
parser = (
reqparse.RequestParser()
.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser
@@ -206,10 +208,11 @@ class RagPipelineVariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
parser = reqparse.RequestParser()
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
# Parse 'value' field as-is to maintain its original data structure
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),

View File

@@ -1,6 +1,3 @@
from typing import cast
from flask_login import current_user # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@@ -13,8 +10,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@@ -28,26 +24,29 @@ class RagPipelineImportApi(Resource):
@marshal_with(pipeline_import_fields)
def post(self):
# Check user role first
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("pipeline_id", type=str, location="json")
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("pipeline_id", type=str, location="json")
)
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Import app
account = cast(Account, current_user)
account = current_user
result = import_service.import_rag_pipeline(
account=account,
import_mode=args["mode"],
@@ -74,15 +73,16 @@ class RagPipelineImportConfirmApi(Resource):
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self, import_id):
current_user, _ = current_account_with_tenant()
# Check user role first
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = cast(Account, current_user)
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
@@ -100,7 +100,8 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@account_initialization_required
@marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
with Session(db.engine) as session:
@@ -117,12 +118,12 @@ class RagPipelineExportApi(Resource):
@get_rag_pipeline
@account_initialization_required
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=str, default="false", location="args")
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
args = parser.parse_args()
with Session(db.engine) as session:

View File

@@ -18,6 +18,7 @@ from controllers.console.app.error import (
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@@ -36,8 +37,8 @@ from fields.workflow_run_fields import (
)
from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
from models.model import EndUser
from services.errors.app import WorkflowHashNotEqualError
@@ -56,15 +57,12 @@ class DraftRagPipelineApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_fields)
def get(self, pipeline: Pipeline):
"""
Get draft rag pipeline's workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@@ -79,23 +77,25 @@ class DraftRagPipelineApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline):
"""
Sync draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type:
parser = reqparse.RequestParser()
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("hash", type=str, required=False, location="json")
parser.add_argument("environment_variables", type=list, required=False, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json")
parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=False, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
elif "text/plain" in content_type:
try:
@@ -154,16 +154,15 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline, node_id: str):
"""
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
@@ -194,11 +193,11 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
@@ -229,14 +228,17 @@ class DraftRagPipelineRunApi(Resource):
Run draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
args = parser.parse_args()
try:
@@ -264,17 +266,20 @@ class PublishedRagPipelineRunApi(Resource):
Run published workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
@@ -303,15 +308,16 @@ class PublishedRagPipelineRunApi(Resource):
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.is_editor:
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
# raise Forbidden()
#
# parser = reqparse.RequestParser()
# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
# parser.add_argument("datasource_type", type=str, required=True, location="json")
# parser = (reqparse.RequestParser()
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args()
#
# job_id = args.get("job_id")
@@ -344,15 +350,16 @@ class PublishedRagPipelineRunApi(Resource):
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.is_editor:
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
# raise Forbidden()
#
# parser = reqparse.RequestParser()
# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
# parser.add_argument("datasource_type", type=str, required=True, location="json")
# parser = (reqparse.RequestParser()
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args()
#
# job_id = args.get("job_id")
@@ -385,13 +392,16 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args()
inputs = args.get("inputs")
@@ -428,13 +438,16 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args()
inputs = args.get("inputs")
@@ -472,11 +485,13 @@ class RagPipelineDraftNodeRunApi(Resource):
Run draft workflow node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"inputs", type=dict, required=True, nullable=False, location="json"
)
args = parser.parse_args()
inputs = args.get("inputs")
@@ -505,7 +520,8 @@ class RagPipelineTaskStopApi(Resource):
Stop workflow task
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@@ -525,7 +541,8 @@ class PublishedRagPipelineApi(Resource):
Get published pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
if not pipeline.is_published:
return None
@@ -545,7 +562,8 @@ class PublishedRagPipelineApi(Resource):
Publish workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
rag_pipeline_service = RagPipelineService()
@@ -580,7 +598,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs
@@ -599,11 +618,11 @@ class DefaultRagPipelineBlockConfigApi(Resource):
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args")
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
args = parser.parse_args()
q = args.get("q")
@@ -631,14 +650,17 @@ class PublishedAllRagPipelineApi(Resource):
"""
Get published workflows
"""
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("user_id", type=str, required=False, location="args")
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
args = parser.parse_args()
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
@@ -681,12 +703,15 @@ class RagPipelineByIdApi(Resource):
Update workflow attributes
"""
# Check permission
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
args = parser.parse_args()
# Validate name and comment length
@@ -733,15 +758,12 @@ class PublishedRagPipelineSecondStepApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline):
"""
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
node_id = args.get("node_id")
if not node_id:
@@ -759,15 +781,12 @@ class PublishedRagPipelineFirstStepApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline):
"""
Get first step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
node_id = args.get("node_id")
if not node_id:
@@ -785,15 +804,12 @@ class DraftRagPipelineFirstStepApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline):
"""
Get first step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
node_id = args.get("node_id")
if not node_id:
@@ -811,15 +827,12 @@ class DraftRagPipelineSecondStepApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline):
"""
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
node_id = args.get("node_id")
if not node_id:
@@ -843,9 +856,11 @@ class RagPipelineWorkflowRunListApi(Resource):
"""
Get workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
rag_pipeline_service = RagPipelineService()
@@ -880,7 +895,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_list_fields)
def get(self, pipeline: Pipeline, run_id):
def get(self, pipeline: Pipeline, run_id: str):
"""
Get workflow run node execution list
"""
@@ -903,14 +918,8 @@ class DatasourceListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
if not isinstance(user, Account):
raise Forbidden()
tenant_id = user.current_tenant_id
if not tenant_id:
raise Forbidden()
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id))
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
@@ -940,9 +949,8 @@ class RagPipelineTransformApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
if not isinstance(current_user, Account):
raise Forbidden()
def post(self, dataset_id: str):
current_user, _ = current_account_with_tenant()
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
@@ -959,19 +967,20 @@ class RagPipelineDatasourceVariableApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_run_node_execution_fields)
def post(self, pipeline: Pipeline):
"""
Set datasource variables
"""
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info", type=dict, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("start_node_title", type=str, required=True, location="json")
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info", type=dict, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("start_node_title", type=str, required=True, location="json")
)
args = parser.parse_args()
rag_pipeline_service = RagPipelineService()

View File

@@ -31,17 +31,19 @@ class WebsiteCrawlApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"provider",
type=str,
choices=["firecrawl", "watercrawl", "jinareader"],
required=True,
nullable=True,
location="json",
parser = (
reqparse.RequestParser()
.add_argument(
"provider",
type=str,
choices=["firecrawl", "watercrawl", "jinareader"],
required=True,
nullable=True,
location="json",
)
.add_argument("url", type=str, required=True, nullable=True, location="json")
.add_argument("options", type=dict, required=True, nullable=True, location="json")
)
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args()
# Create typed request and validate
@@ -70,8 +72,7 @@ class WebsiteCrawlStatusApi(Resource):
@login_required
@account_initialization_required
def get(self, job_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
parser = reqparse.RequestParser().add_argument(
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
args = parser.parse_args()

View File

@@ -3,8 +3,7 @@ from functools import wraps
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from libs.login import current_account_with_tenant
from models.dataset import Pipeline
@@ -17,8 +16,7 @@ def get_rag_pipeline(
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")
if not isinstance(current_user, Account):
raise ValueError("current_user is not an account")
_, current_tenant_id = current_account_with_tenant()
pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id)
@@ -27,7 +25,7 @@ def get_rag_pipeline(
pipeline = (
db.session.query(Pipeline)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
.first()
)

View File

@@ -81,11 +81,13 @@ class ChatTextApi(InstalledAppResource):
app_model = installed_app.app
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, required=False, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
message_id = args.get("message_id", None)

View File

@@ -49,12 +49,14 @@ class CompletionApi(InstalledAppResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
@@ -121,13 +123,15 @@ class ChatApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args()
args["auto_generate_name"] = False

View File

@@ -31,10 +31,12 @@ class ConversationListApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
)
args = parser.parse_args()
pinned = None
@@ -94,9 +96,11 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, location="json")
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, location="json")
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
)
args = parser.parse_args()
try:

View File

@@ -12,10 +12,9 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import current_user, login_required
from models import Account, App, InstalledApp, RecommendedApp
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@@ -29,9 +28,7 @@ class InstalledAppsListApi(Resource):
@marshal_with(installed_app_list_fields)
def get(self):
app_id = request.args.get("app_id", default=None, type=str)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
current_user, current_tenant_id = current_account_with_tenant()
if app_id:
installed_apps = db.session.scalars(
@@ -69,31 +66,26 @@ class InstalledAppsListApi(Resource):
# Pre-filter out apps without setting or with sso_verified
filtered_installed_apps = []
app_id_to_app_code = {}
for installed_app in installed_app_list:
app_id = installed_app["app"].id
webapp_setting = webapp_settings.get(app_id)
if not webapp_setting or webapp_setting.access_mode == "sso_verified":
continue
app_code = AppService.get_app_code_by_id(str(app_id))
app_id_to_app_code[app_id] = app_code
filtered_installed_apps.append(installed_app)
app_codes = list(app_id_to_app_code.values())
# Batch permission check
app_ids = [installed_app["app"].id for installed_app in filtered_installed_apps]
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
user_id=user_id,
app_codes=app_codes,
app_ids=app_ids,
)
# Keep only allowed apps
res = []
for installed_app in filtered_installed_apps:
app_id = installed_app["app"].id
app_code = app_id_to_app_code[app_id]
if permissions.get(app_code):
if permissions.get(app_id):
res.append(installed_app)
installed_app_list = res
@@ -113,17 +105,15 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args()
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
if recommended_app is None:
raise NotFound("App not found")
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
app = db.session.query(App).where(App.id == args["app_id"]).first()
if app is None:
@@ -163,9 +153,8 @@ class InstalledAppApi(InstalledAppResource):
"""
def delete(self, installed_app):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
_, current_tenant_id = current_account_with_tenant()
if installed_app.app_owner_tenant_id == current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant")
db.session.delete(installed_app)
@@ -174,8 +163,7 @@ class InstalledAppApi(InstalledAppResource):
return {"result": "success", "message": "App uninstalled successfully"}, 204
def patch(self, installed_app):
parser = reqparse.RequestParser()
parser.add_argument("is_pinned", type=inputs.boolean)
parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
args = parser.parse_args()
commit_args = False

View File

@@ -23,8 +23,7 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from libs.login import current_account_with_tenant
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@@ -48,21 +47,22 @@ logger = logging.getLogger(__name__)
class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
)
@@ -78,18 +78,19 @@ class MessageListApi(InstalledAppResource):
)
class MessageFeedbackApi(InstalledAppResource):
def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
parser.add_argument("content", type=str, location="json")
parser = (
reqparse.RequestParser()
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
.add_argument("content", type=str, location="json")
)
args = parser.parse_args()
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
@@ -109,14 +110,14 @@ class MessageFeedbackApi(InstalledAppResource):
)
class MessageMoreLikeThisApi(InstalledAppResource):
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument(
parser = reqparse.RequestParser().add_argument(
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
)
args = parser.parse_args()
@@ -124,8 +125,6 @@ class MessageMoreLikeThisApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming"
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate_more_like_this(
app_model=app_model,
user=current_user,
@@ -159,6 +158,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
)
class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -167,8 +167,6 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
message_id = str(message_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)

View File

@@ -42,8 +42,7 @@ class RecommendedAppListApi(Resource):
@marshal_with(recommended_app_list_fields)
def get(self):
# language args
parser = reqparse.RequestParser()
parser.add_argument("language", type=str, location="args")
parser = reqparse.RequestParser().add_argument("language", type=str, location="args")
args = parser.parse_args()
language = args.get("language")

View File

@@ -7,8 +7,7 @@ from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from libs.login import current_user
from models import Account
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
@@ -35,31 +34,30 @@ class SavedMessageListApi(InstalledAppResource):
@marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
def post(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=uuid_value, required=True, location="json")
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
args = parser.parse_args()
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.save(app_model, current_user, args["message_id"])
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@@ -72,6 +70,7 @@ class SavedMessageListApi(InstalledAppResource):
)
class SavedMessageApi(InstalledAppResource):
def delete(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
message_id = str(message_id)
@@ -79,8 +78,6 @@ class SavedMessageApi(InstalledAppResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204

View File

@@ -22,7 +22,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.login import current_user
from libs.login import current_user as current_user_
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@@ -31,6 +31,8 @@ from .. import console_ns
logger = logging.getLogger(__name__)
current_user = current_user_._get_current_object() # type: ignore
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
@@ -45,9 +47,11 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
assert current_user is not None
try:

View File

@@ -8,10 +8,8 @@ from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import current_user, login_required
from libs.login import current_account_with_tenant, login_required
from models import InstalledApp
from models.account import Account
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@@ -24,13 +22,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
installed_app = (
db.session.query(InstalledApp)
.where(
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
)
.where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
.first()
)
@@ -56,14 +51,13 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
assert isinstance(current_user, Account)
app_id = installed_app.app_id
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=str(current_user.id),
app_code=app_code,
app_id=app_id,
)
if not res:
raise AppAccessDeniedError()

View File

@@ -4,8 +4,7 @@ from constants import HIDDEN_VALUE
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
@@ -30,8 +29,7 @@ class CodeBasedExtensionAPI(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("module", type=str, required=True, location="args")
parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args")
args = parser.parse_args()
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
@@ -47,9 +45,7 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@api.doc("create_api_based_extension")
@@ -70,16 +66,17 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("api_endpoint", type=str, required=True, location="json")
.add_argument("api_key", type=str, required=True, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
name=args["name"],
api_endpoint=args["api_endpoint"],
api_key=args["api_key"],
@@ -99,10 +96,8 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@@ -125,17 +120,17 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("api_endpoint", type=str, required=True, location="json")
.add_argument("api_key", type=str, required=True, location="json")
)
args = parser.parse_args()
extension_data_from_db.name = args["name"]
@@ -154,12 +149,10 @@ class APIBasedExtensionDetailAPI(Resource):
@login_required
@account_initialization_required
def delete(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
APIBasedExtensionService.delete(extension_data_from_db)

View File

@@ -1,7 +1,6 @@
from flask_restx import Resource, fields
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.feature_service import FeatureService
from . import api, console_ns
@@ -23,9 +22,9 @@ class FeatureApi(Resource):
@cloud_utm_record
def get(self):
"""Get feature configuration for current tenant"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
_, current_tenant_id = current_account_with_tenant()
return FeatureService.get_features(current_tenant_id).model_dump()
@console_ns.route("/system-features")

View File

@@ -1,7 +1,6 @@
from typing import Literal
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with
from werkzeug.exceptions import Forbidden
@@ -22,8 +21,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.file_service import FileService
from . import console_ns
@@ -53,6 +51,7 @@ class FileApi(Resource):
@marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents")
def post(self):
current_user, _ = current_account_with_tenant()
source_str = request.form.get("source")
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
@@ -65,16 +64,12 @@ class FileApi(Resource):
if not file.filename:
raise FilenameNotExistsError
if source == "datasets" and not current_user.is_dataset_editor:
raise Forbidden()
if source not in ("datasets", None):
source = None
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
@@ -108,4 +103,4 @@ class FileSupportTypeApi(Resource):
@login_required
@account_initialization_required
def get(self):
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
return {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}

View File

@@ -57,8 +57,7 @@ class InitValidateAPI(Resource):
if tenant_count > 0:
raise AlreadySetupError()
parser = reqparse.RequestParser()
parser.add_argument("password", type=StrLen(30), required=True, location="json")
parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json")
input_password = parser.parse_args()["password"]
if input_password != os.environ.get("INIT_PASSWORD"):

View File

@@ -14,8 +14,7 @@ from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from libs.login import current_user
from models.account import Account
from libs.login import current_account_with_tenant
from services.file_service import FileService
from . import console_ns
@@ -41,8 +40,7 @@ class RemoteFileInfoApi(Resource):
class RemoteFileUploadApi(Resource):
@marshal_with(file_fields_with_signed_url)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, help="URL is required")
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args()
url = args["url"]
@@ -64,8 +62,7 @@ class RemoteFileUploadApi(Resource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
assert isinstance(current_user, Account)
user = current_user
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,

View File

@@ -69,10 +69,12 @@ class SetupApi(Resource):
if not get_init_validate_status():
raise NotInitValidateError()
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("name", type=StrLen(30), required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("name", type=StrLen(30), required=True, location="json")
.add_argument("password", type=valid_password, required=True, location="json")
)
args = parser.parse_args()
# setup

View File

@@ -5,8 +5,7 @@ from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.model import Tag
from services.tag_service import TagService
@@ -24,11 +23,10 @@ class TagListApi(Resource):
@account_initialization_required
@marshal_with(dataset_tag_fields)
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
tag_type = request.args.get("type", type=str, default="")
keyword = request.args.get("keyword", default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
tags = TagService.get_tags(tag_type, current_tenant_id, keyword)
return tags, 200
@@ -36,18 +34,23 @@ class TagListApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
parser.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
)
args = parser.parse_args()
tag = TagService.save_tags(args)
@@ -63,15 +66,13 @@ class TagUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def patch(self, tag_id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, _ = current_account_with_tenant()
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
parser = reqparse.RequestParser().add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
args = parser.parse_args()
@@ -87,8 +88,7 @@ class TagUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def delete(self, tag_id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, _ = current_account_with_tenant()
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
@@ -105,21 +105,22 @@ class TagBindingCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
)
parser.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
)
parser.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
parser = (
reqparse.RequestParser()
.add_argument(
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
)
.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
)
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
)
args = parser.parse_args()
TagService.save_tag_binding(args)
@@ -133,17 +134,18 @@ class TagBindingDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
parser.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
parser = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
)
args = parser.parse_args()
TagService.delete_tag_binding(args)

View File

@@ -37,8 +37,7 @@ class VersionApi(Resource):
)
def get(self):
"""Check for application version updates"""
parser = reqparse.RequestParser()
parser.add_argument("current_version", type=str, required=True, location="args")
parser = reqparse.RequestParser().add_argument("current_version", type=str, required=True, location="args")
args = parser.parse_args()
check_update_url = dify_config.CHECK_UPDATE_URL

View File

@@ -2,11 +2,11 @@ from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask_login import current_user
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.account import TenantPluginPermission
P = ParamSpec("P")
@@ -20,8 +20,9 @@ def plugin_permission_required(
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
tenant_id = user.current_tenant_id
tenant_id = current_tenant_id
with Session(db.engine) as session:
permission = (

View File

@@ -2,7 +2,6 @@ from datetime import datetime
import pytz
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -37,9 +36,8 @@ from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, email, extract_remote_ip, timezone
from libs.login import login_required
from models import AccountIntegrate, InvitationCode
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -50,9 +48,7 @@ class AccountInitApi(Resource):
@setup_required
@login_required
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
if account.status == "active":
raise AccountAlreadyInitedError()
@@ -61,9 +57,9 @@ class AccountInitApi(Resource):
if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
parser.add_argument("timezone", type=timezone, required=True, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
"timezone", type=timezone, required=True, location="json"
)
args = parser.parse_args()
if dify_config.EDITION == "CLOUD":
@@ -106,8 +102,7 @@ class AccountProfileApi(Resource):
@marshal_with(account_fields)
@enterprise_license_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
return current_user
@@ -118,10 +113,8 @@ class AccountNameApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
# Validate account name length
@@ -140,10 +133,8 @@ class AccountAvatarApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("avatar", type=str, required=True, location="json")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
args = parser.parse_args()
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
@@ -158,10 +149,10 @@ class AccountInterfaceLanguageApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"interface_language", type=supported_language, required=True, location="json"
)
args = parser.parse_args()
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
@@ -176,10 +167,10 @@ class AccountInterfaceThemeApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
)
args = parser.parse_args()
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
@@ -194,10 +185,8 @@ class AccountTimezoneApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("timezone", type=str, required=True, location="json")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
args = parser.parse_args()
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
@@ -216,12 +205,13 @@ class AccountPasswordApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("password", type=str, required=False, location="json")
parser.add_argument("new_password", type=str, required=True, location="json")
parser.add_argument("repeat_new_password", type=str, required=True, location="json")
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("password", type=str, required=False, location="json")
.add_argument("new_password", type=str, required=True, location="json")
.add_argument("repeat_new_password", type=str, required=True, location="json")
)
args = parser.parse_args()
if args["new_password"] != args["repeat_new_password"]:
@@ -253,9 +243,7 @@ class AccountIntegrateApi(Resource):
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
account_integrates = db.session.scalars(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
@@ -298,9 +286,7 @@ class AccountDeleteVerifyApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
token, code = AccountService.generate_account_deletion_verification_code(account)
AccountService.send_account_deletion_verification_email(account, code)
@@ -314,13 +300,13 @@ class AccountDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
args = parser.parse_args()
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
@@ -335,9 +321,11 @@ class AccountDeleteApi(Resource):
class AccountDeleteUpdateFeedbackApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("feedback", type=str, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
args = parser.parse_args()
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
@@ -358,9 +346,7 @@ class EducationVerifyApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(verify_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
return BillingService.EducationIdentity.verify(account.id, account.email)
@@ -380,14 +366,14 @@ class EducationApi(Resource):
@only_edition_cloud
@cloud_edition_billing_enabled
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, location="json")
parser.add_argument("institution", type=str, required=True, location="json")
parser.add_argument("role", type=str, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
args = parser.parse_args()
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
@@ -399,9 +385,7 @@ class EducationApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(status_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
res = BillingService.EducationIdentity.status(account.id)
# convert expire_at to UTC timestamp from isoformat
@@ -425,10 +409,12 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("keywords", type=str, required=True, location="args")
parser.add_argument("page", type=int, required=False, location="args", default=0)
parser.add_argument("limit", type=int, required=False, location="args", default=20)
parser = (
reqparse.RequestParser()
.add_argument("keywords", type=str, required=True, location="args")
.add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
args = parser.parse_args()
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
@@ -441,11 +427,14 @@ class ChangeEmailSendEmailApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
parser.add_argument("phase", type=str, required=False, location="json")
parser.add_argument("token", type=str, required=False, location="json")
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request)
@@ -467,8 +456,6 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if user_email != current_user.email:
raise InvalidEmailError()
else:
@@ -490,10 +477,12 @@ class ChangeEmailCheckApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
user_email = args["email"]
@@ -533,9 +522,11 @@ class ChangeEmailResetApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("new_email", type=email, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
if AccountService.is_account_in_freeze(args["new_email"]):
@@ -551,8 +542,7 @@ class ChangeEmailResetApi(Resource):
AccountService.revoke_change_email_token(args["token"])
old_email = reset_data.get("old_email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if current_user.email != old_email:
raise AccountNotFound()
@@ -569,8 +559,7 @@ class ChangeEmailResetApi(Resource):
class CheckEmailUnique(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
args = parser.parse_args()
if AccountService.is_account_in_freeze(args["email"]):
raise AccountInFreezeError()

View File

@@ -3,8 +3,7 @@ from flask_restx import Resource, fields
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.agent_service import AgentService
@@ -21,12 +20,11 @@ class AgentProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
assert isinstance(current_user, Account)
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
assert user.current_tenant_id is not None
user_id = user.id
tenant_id = user.current_tenant_id
tenant_id = current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
@@ -45,9 +43,5 @@ class AgentProviderApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_name: str):
assert isinstance(current_user, Account)
user = current_user
assert user.current_tenant_id is not None
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
current_user, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))

View File

@@ -5,18 +5,10 @@ from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService
def _current_account_with_tenant() -> tuple[Account, str]:
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
return current_user, tenant_id
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
@api.doc("create_endpoint")
@@ -41,14 +33,16 @@ class EndpointCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True)
parser.add_argument("settings", type=dict, required=True)
parser.add_argument("name", type=str, required=True)
parser = (
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True)
.add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args()
plugin_unique_identifier = args["plugin_unique_identifier"]
@@ -87,11 +81,13 @@ class EndpointListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args()
page = args["page"]
@@ -130,12 +126,14 @@ class EndpointListForSinglePluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
parser.add_argument("plugin_id", type=str, required=True, location="args")
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
.add_argument("plugin_id", type=str, required=True, location="args")
)
args = parser.parse_args()
page = args["page"]
@@ -172,10 +170,9 @@ class EndpointDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
if not user.is_admin_or_owner:
@@ -212,12 +209,14 @@ class EndpointUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
parser.add_argument("settings", type=dict, required=True)
parser.add_argument("name", type=str, required=True)
parser = (
reqparse.RequestParser()
.add_argument("endpoint_id", type=str, required=True)
.add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
@@ -255,10 +254,9 @@ class EndpointEnableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
@@ -288,10 +286,9 @@ class EndpointDisableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]

View File

@@ -5,8 +5,8 @@ from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_user, login_required
from models.account import Account, TenantAccountRole
from libs.login import current_account_with_tenant, login_required
from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
@@ -18,24 +18,25 @@ class LoadBalancingCredentialsValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
assert isinstance(current_user, Account)
current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
# validate model load balancing credentials
@@ -72,24 +73,25 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str, config_id: str):
assert isinstance(current_user, Account)
current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
# validate model load balancing config credentials

View File

@@ -25,7 +25,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
from libs.helper import extract_remote_ip
from libs.login import current_user, login_required
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
@@ -41,8 +41,7 @@ class MemberListApi(Resource):
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
@@ -58,10 +57,12 @@ class MemberInviteEmailApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("emails", type=list, required=True, location="json")
parser.add_argument("role", type=str, required=True, default="admin", location="json")
parser.add_argument("language", type=str, required=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("emails", type=list, required=True, location="json")
.add_argument("role", type=str, required=True, default="admin", location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
invitee_emails = args["emails"]
@@ -69,9 +70,7 @@ class MemberInviteEmailApi(Resource):
interface_language = args["language"]
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
@@ -120,8 +119,7 @@ class MemberCancelInviteApi(Resource):
@login_required
@account_initialization_required
def delete(self, member_id):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first()
@@ -153,16 +151,13 @@ class MemberUpdateRoleApi(Resource):
@login_required
@account_initialization_required
def put(self, member_id):
parser = reqparse.RequestParser()
parser.add_argument("role", type=str, required=True, location="json")
parser = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
args = parser.parse_args()
new_role = args["role"]
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id))
@@ -189,8 +184,7 @@ class DatasetOperatorMemberListApi(Resource):
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
@@ -206,16 +200,13 @@ class SendOwnerTransferEmailApi(Resource):
@account_initialization_required
@is_allow_transfer_owner
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json")
parser = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
current_user, _ = current_account_with_tenant()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@@ -245,13 +236,14 @@ class OwnerTransferCheckApi(Resource):
@account_initialization_required
@is_allow_transfer_owner
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@@ -291,13 +283,13 @@ class OwnerTransfer(Resource):
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"token", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):

View File

@@ -1,7 +1,6 @@
import io
from flask import send_file
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@@ -23,14 +21,10 @@ class ModelProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument(
parser = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=False,
@@ -52,14 +46,12 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
parser = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
)
args = parser.parse_args()
model_provider_service = ModelProviderService()
@@ -73,23 +65,22 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try:
model_provider_service.create_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_name=args["name"],
@@ -103,24 +94,23 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def put(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try:
model_provider_service.update_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_id=args["credential_id"],
@@ -135,19 +125,17 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
)
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
)
return {"result": "success"}, 204
@@ -159,19 +147,17 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
service = ModelProviderService()
service.switch_active_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credential_id=args["credential_id"],
)
@@ -184,15 +170,13 @@ class ModelProviderValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"credentials", type=dict, required=True, nullable=False, location="json"
)
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
model_provider_service = ModelProviderService()
@@ -240,17 +224,13 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument(
parser = reqparse.RequestParser().add_argument(
"preferred_provider_type",
type=str,
required=True,
@@ -276,14 +256,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
def get(self, provider: str):
if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
data = BillingService.get_model_provider_payment_link(
provider_name=provider,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
account_id=current_user.id,
prefilled_email=current_user.email,
)

View File

@@ -1,6 +1,5 @@
import logging
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@@ -23,8 +22,9 @@ class DefaultModelApi(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument(
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=True,
@@ -34,8 +34,6 @@ class DefaultModelApi(Resource):
)
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id, model_type=args["model_type"]
@@ -47,15 +45,15 @@ class DefaultModelApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"model_settings", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
model_settings = args["model_settings"]
for model_setting in model_settings:
@@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
@@ -104,24 +102,26 @@ class ModelProviderModelApi(Resource):
@account_initialization_required
def post(self, provider: str):
# To save the model's load balance configs
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
)
parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
args = parser.parse_args()
if args.get("config_from", "") == "custom-model":
@@ -129,7 +129,7 @@ class ModelProviderModelApi(Resource):
raise ValueError("credential_id is required when configuring a custom-model")
service = ModelProviderService()
service.switch_active_custom_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@@ -164,20 +164,22 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
args = parser.parse_args()
@@ -195,20 +197,22 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="args")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
)
parser.add_argument("config_from", type=str, required=False, nullable=True, location="args")
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
args = parser.parse_args()
model_provider_service = ModelProviderService()
@@ -257,24 +261,27 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
try:
@@ -301,29 +308,33 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
try:
model_provider_service.update_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@@ -340,24 +351,28 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
)
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@@ -373,24 +388,28 @@ class ModelProviderModelCredentialSwitchApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
service = ModelProviderService()
service.add_model_credential_to_model_list(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@@ -407,17 +426,19 @@ class ModelProviderModelEnableApi(Resource):
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
args = parser.parse_args()
@@ -437,17 +458,19 @@ class ModelProviderModelDisableApi(Resource):
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
args = parser.parse_args()
@@ -465,19 +488,21 @@ class ModelProviderModelValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
@@ -511,11 +536,11 @@ class ModelProviderModelParameterRuleApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
parser = reqparse.RequestParser().add_argument(
"model", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
@@ -531,8 +556,7 @@ class ModelProviderAvailableModelApi(Resource):
@login_required
@account_initialization_required
def get(self, model_type):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

View File

@@ -1,7 +1,6 @@
import io
from flask import request, send_file
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService
@@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource):
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {
@@ -44,10 +43,12 @@ class PluginListApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=False, location="args", default=1)
parser.add_argument("page_size", type=int, required=False, location="args", default=256)
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=False, location="args", default=1)
.add_argument("page_size", type=int, required=False, location="args", default=256)
)
args = parser.parse_args()
try:
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
@@ -63,8 +64,7 @@ class PluginListLatestVersionsApi(Resource):
@login_required
@account_initialization_required
def post(self):
req = reqparse.RequestParser()
req.add_argument("plugin_ids", type=list, required=True, location="json")
req = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
args = req.parse_args()
try:
@@ -81,10 +81,9 @@ class PluginListInstallationsFromIdsApi(Resource):
@login_required
@account_initialization_required
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_ids", type=list, required=True, location="json")
parser = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
args = parser.parse_args()
try:
@@ -99,9 +98,11 @@ class PluginListInstallationsFromIdsApi(Resource):
class PluginIconApi(Resource):
@setup_required
def get(self):
req = reqparse.RequestParser()
req.add_argument("tenant_id", type=str, required=True, location="args")
req.add_argument("filename", type=str, required=True, location="args")
req = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="args")
.add_argument("filename", type=str, required=True, location="args")
)
args = req.parse_args()
try:
@@ -120,7 +121,7 @@ class PluginUploadFromPkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
file = request.files["pkg"]
@@ -144,12 +145,14 @@ class PluginUploadFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
args = parser.parse_args()
try:
@@ -167,7 +170,7 @@ class PluginUploadFromBundleApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
file = request.files["bundle"]
@@ -191,10 +194,11 @@ class PluginInstallFromPkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
args = parser.parse_args()
# check if all plugin_unique_identifiers are valid string
@@ -217,13 +221,15 @@ class PluginInstallFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
)
args = parser.parse_args()
try:
@@ -247,10 +253,11 @@ class PluginInstallFromMarketplaceApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
args = parser.parse_args()
# check if all plugin_unique_identifiers are valid string
@@ -273,10 +280,11 @@ class PluginFetchMarketplacePkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
args = parser.parse_args()
try:
@@ -299,10 +307,11 @@ class PluginFetchManifestApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
args = parser.parse_args()
try:
@@ -324,11 +333,13 @@ class PluginFetchInstallTasksApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
parser.add_argument("page_size", type=int, required=True, location="args")
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args()
try:
@@ -346,7 +357,7 @@ class PluginFetchInstallTaskApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self, task_id: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
@@ -361,7 +372,7 @@ class PluginDeleteInstallTaskApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self, task_id: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
@@ -376,7 +387,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
@@ -391,7 +402,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
@@ -406,11 +417,13 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
)
args = parser.parse_args()
try:
@@ -430,14 +443,16 @@ class PluginUpgradeFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("repo", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
args = parser.parse_args()
try:
@@ -462,11 +477,10 @@ class PluginUninstallApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
req = reqparse.RequestParser()
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
req = reqparse.RequestParser().add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
@@ -480,19 +494,22 @@ class PluginChangePermissionApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
req = reqparse.RequestParser()
req.add_argument("install_permission", type=str, required=True, location="json")
req.add_argument("debug_permission", type=str, required=True, location="json")
req = (
reqparse.RequestParser()
.add_argument("install_permission", type=str, required=True, location="json")
.add_argument("debug_permission", type=str, required=True, location="json")
)
args = req.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
tenant_id = user.current_tenant_id
tenant_id = current_tenant_id
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
@@ -503,7 +520,7 @@ class PluginFetchPermissionApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id)
if not permission:
@@ -529,31 +546,31 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
@account_initialization_required
def get(self):
# check if the user is admin or owner
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
user_id = current_user.id
parser = reqparse.RequestParser()
parser.add_argument("plugin_id", type=str, required=True, location="args")
parser.add_argument("provider", type=str, required=True, location="args")
parser.add_argument("action", type=str, required=True, location="args")
parser.add_argument("parameter", type=str, required=True, location="args")
parser.add_argument("credential_id", type=str, required=False, location="args")
parser.add_argument("provider_type", type=str, required=True, location="args")
parser = (
reqparse.RequestParser()
.add_argument("plugin_id", type=str, required=True, location="args")
.add_argument("provider", type=str, required=True, location="args")
.add_argument("action", type=str, required=True, location="args")
.add_argument("parameter", type=str, required=True, location="args")
.add_argument("provider_type", type=str, required=True, location="args")
)
args = parser.parse_args()
try:
options = PluginParameterService.get_dynamic_select_options(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=args["plugin_id"],
provider=args["provider"],
action=args["action"],
parameter=args["parameter"],
credential_id=args["credential_id"],
provider_type=args["provider_type"],
tenant_id,
user_id,
args["plugin_id"],
args["provider"],
args["action"],
args["parameter"],
args["provider_type"],
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@@ -567,17 +584,17 @@ class PluginChangePreferencesApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
req = reqparse.RequestParser()
req.add_argument("permission", type=dict, required=True, location="json")
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
req = (
reqparse.RequestParser()
.add_argument("permission", type=dict, required=True, location="json")
.add_argument("auto_upgrade", type=dict, required=True, location="json")
)
args = req.parse_args()
tenant_id = user.current_tenant_id
permission = args["permission"]
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
@@ -623,7 +640,7 @@ class PluginFetchPreferencesApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = {
@@ -663,10 +680,9 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
@account_initialization_required
def post(self):
# exclude one single plugin
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
req = reqparse.RequestParser()
req.add_argument("plugin_id", type=str, required=True, location="json")
req = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
args = req.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})

View File

@@ -2,7 +2,6 @@ import io
from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
from flask_login import current_user
from flask_restx import (
Resource,
reqparse,
@@ -21,13 +20,11 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import ToolProviderID
# from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
@@ -55,13 +52,11 @@ class ToolProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument(
req = reqparse.RequestParser().add_argument(
"type",
type=str,
choices=["builtin", "model", "api", "workflow", "mcp"],
@@ -80,9 +75,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.list_builtin_tool_provider_tools(
@@ -98,9 +91,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@@ -111,13 +102,13 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
req = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
args = req.parse_args()
return BuiltinToolManageService.delete_builtin_tool_provider(
@@ -133,15 +124,16 @@ class ToolBuiltinProviderAddApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
.add_argument("type", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
if args["type"] not in CredentialType.values():
@@ -163,18 +155,19 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
@@ -195,7 +188,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials(
@@ -220,23 +213,24 @@ class ToolApiProviderAddApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
@@ -260,14 +254,11 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, nullable=False, location="args")
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
@@ -284,14 +275,13 @@ class ToolApiProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args()
@@ -310,24 +300,25 @@ class ToolApiProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
)
args = parser.parse_args()
@@ -352,17 +343,16 @@ class ToolApiProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
@@ -379,14 +369,13 @@ class ToolApiProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args()
@@ -403,8 +392,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider, credential_type):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema(
@@ -419,9 +407,9 @@ class ToolApiProviderSchemaApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"schema", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
@@ -436,19 +424,20 @@ class ToolApiProviderPreviousTestApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
current_tenant_id,
args["provider_name"] or "",
args["tool_name"],
args["credentials"],
@@ -464,23 +453,24 @@ class ToolWorkflowProviderCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
reqparser.add_argument("label", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("description", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
reqparser = (
reqparse.RequestParser()
.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
args = reqparser.parse_args()
@@ -504,23 +494,24 @@ class ToolWorkflowProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
reqparser.add_argument("label", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("description", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
reqparser = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
args = reqparser.parse_args()
@@ -547,16 +538,16 @@ class ToolWorkflowProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
reqparser = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
)
args = reqparser.parse_args()
@@ -573,14 +564,15 @@ class ToolWorkflowProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
parser = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
)
args = parser.parse_args()
@@ -608,13 +600,13 @@ class ToolWorkflowProviderListToolApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
parser = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
)
args = parser.parse_args()
@@ -633,10 +625,9 @@ class ToolBuiltinListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@@ -655,8 +646,7 @@ class ToolApiListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
[
@@ -674,10 +664,9 @@ class ToolWorkflowListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@@ -711,19 +700,18 @@ class ToolPluginOAuthApi(Resource):
provider_name = tool_provider.provider_name
# todo check permission
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url(
@@ -802,11 +790,11 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
)
@@ -816,18 +804,20 @@ class ToolOAuthCustomClient(Resource):
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args()
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
provider=provider,
client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
@@ -837,20 +827,18 @@ class ToolOAuthCustomClient(Resource):
@login_required
@account_initialization_required
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
)
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
)
@@ -860,9 +848,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
tenant_id=current_user.current_tenant_id, provider_name=provider
tenant_id=current_tenant_id, provider_name=provider
)
)
@@ -873,7 +862,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
@@ -889,25 +878,25 @@ class ToolProviderMCPApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
parser.add_argument(
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
parser = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
.add_argument("sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300)
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
)
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
args = parser.parse_args()
user = current_user
user, tenant_id = current_account_with_tenant()
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
@@ -925,25 +914,28 @@ class ToolProviderMCPApi(Resource):
@login_required
@account_initialization_required
def put(self):
parser = reqparse.RequestParser()
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("timeout", type=float, required=False, nullable=True, location="json")
.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
.add_argument("headers", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
pass
else:
raise ValueError("Server URL is not valid.")
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
@@ -961,10 +953,12 @@ class ToolProviderMCPApi(Resource):
@login_required
@account_initialization_required
def delete(self):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser = reqparse.RequestParser().add_argument(
"provider_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
@@ -974,12 +968,14 @@ class ToolMCPAuthApi(Resource):
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
provider_id = args["provider_id"]
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
@@ -1020,8 +1016,8 @@ class ToolMCPDetailApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_id):
user = current_user
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
_, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@@ -1031,8 +1027,7 @@ class ToolMCPListAllApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
@@ -1045,7 +1040,7 @@ class ToolMCPUpdateApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_id):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id,
provider_id=provider_id,
@@ -1056,9 +1051,11 @@ class ToolMCPUpdateApi(Resource):
@console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, nullable=False, location="args")
parser.add_argument("state", type=str, required=True, nullable=False, location="args")
parser = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, nullable=False, location="args")
.add_argument("state", type=str, required=True, nullable=False, location="args")
)
args = parser.parse_args()
state_key = args["state"]
authorization_code = args["code"]

View File

@@ -1,590 +0,0 @@
import logging
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.trigger.entities.entities import SubscriptionBuilderUpdater
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_database import db
from libs.login import current_user, login_required
from models.account import Account
from models.provider_ids import TriggerProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.trigger.workflow_plugin_trigger_service import WorkflowPluginTriggerService
logger = logging.getLogger(__name__)
class TriggerProviderIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider)
class TriggerProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List all trigger providers for the current tenant"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
class TriggerProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Get info for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return jsonable_encoder(
TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
)
class TriggerSubscriptionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
return jsonable_encoder(
TriggerProviderService.list_trigger_provider_subscriptions(
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
)
)
except Exception as e:
logger.exception("Error listing trigger providers", exc_info=e)
raise
class TriggerSubscriptionBuilderCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
credential_type=credential_type,
)
return jsonable_encoder({"subscription_builder": subscription_builder})
except Exception as e:
logger.exception("Error adding provider credential", exc_info=e)
raise
class TriggerSubscriptionBuilderGetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get a subscription instance for a trigger provider"""
return jsonable_encoder(
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
)
class TriggerSubscriptionBuilderVerifyApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
# Use atomic update_and_verify to prevent race conditions
return TriggerSubscriptionBuilderService.update_and_verify_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=args.get("credentials", None),
),
)
except Exception as e:
logger.exception("Error verifying provider credential", exc_info=e)
raise ValueError(str(e)) from e
class TriggerSubscriptionBuilderUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Update a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
credentials=args.get("credentials", None),
),
)
)
except Exception as e:
logger.exception("Error updating provider credential", exc_info=e)
raise
class TriggerSubscriptionBuilderLogsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get the request logs for a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
except Exception as e:
logger.exception("Error getting request logs for subscription builder", exc_info=e)
raise
class TriggerSubscriptionBuilderBuildApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
# Use atomic update_and_build to prevent race conditions
TriggerSubscriptionBuilderService.update_and_build_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
),
)
return 200
except Exception as e:
logger.exception("Error building provider credential", exc_info=e)
raise ValueError(str(e)) from e
class TriggerSubscriptionDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, subscription_id):
"""Delete a subscription instance"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
with Session(db.engine) as session:
# Delete trigger provider subscription
TriggerProviderService.delete_trigger_provider(
session=session,
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
# Delete plugin triggers
WorkflowPluginTriggerService.delete_plugin_trigger_by_subscription(
session=session,
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
session.commit()
return {"result": "success"}
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error deleting provider credential", exc_info=e)
raise
class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Initiate OAuth authorization flow for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
tenant_id = user.current_tenant_id
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
tenant_id=tenant_id,
provider_id=provider_id,
)
if oauth_client_params is None:
raise Forbidden("No OAuth client configuration found for this trigger provider")
# Create subscription builder
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=tenant_id,
user_id=user.id,
provider_id=provider_id,
credential_type=CredentialType.OAUTH2,
)
# Create OAuth handler and proxy context
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=user.id,
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider_name,
extra_data={
"subscription_builder_id": subscription_builder.id,
},
)
# Build redirect URI for callback
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
# Get authorization URL
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
)
# Create response with cookie
response = make_response(
jsonable_encoder(
{
"authorization_url": authorization_url_response.authorization_url,
"subscription_builder_id": subscription_builder.id,
"subscription_builder": subscription_builder,
}
)
)
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
except Exception as e:
logger.exception("Error initiating OAuth flow", exc_info=e)
raise
class TriggerOAuthCallbackApi(Resource):
@setup_required
def get(self, provider):
"""Handle OAuth callback for trigger provider"""
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
# Use and validate proxy context
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
# Parse provider ID
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
user_id = context.get("user_id")
tenant_id = context.get("tenant_id")
subscription_builder_id = context.get("subscription_builder_id")
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
tenant_id=tenant_id,
provider_id=provider_id,
)
if oauth_client_params is None:
raise Forbidden("No OAuth client configuration found for this trigger provider")
# Get OAuth credentials from callback
oauth_handler = OAuthHandler()
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
)
credentials = credentials_response.credentials
expires_at = credentials_response.expires_at
if not credentials:
raise Exception("Failed to get OAuth credentials")
# Update subscription builder
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=tenant_id,
provider_id=provider_id,
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=credentials,
credential_expires_at=expires_at,
),
)
# Redirect to OAuth callback page
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class TriggerOAuthClientManageApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Get OAuth client configuration for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
provider_id = TriggerProviderID(provider)
# Get custom OAuth client params if exists
custom_params = TriggerProviderService.get_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
# Check if custom client is enabled
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
# Check if there's a system OAuth client
system_client = TriggerProviderService.get_oauth_client(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
return jsonable_encoder(
{
"configured": bool(custom_params or system_client),
"oauth_client_schema": provider_controller.get_oauth_client_schema(),
"custom_configured": bool(custom_params),
"custom_enabled": is_custom_enabled,
"redirect_uri": redirect_uri,
"params": custom_params or {},
}
)
except Exception as e:
logger.exception("Error getting OAuth client", exc_info=e)
raise
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
"""Configure custom OAuth client for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
client_params=args.get("client_params"),
enabled=args.get("enabled"),
)
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error configuring OAuth client", exc_info=e)
raise
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
"""Remove custom OAuth client configuration"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.delete_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e)
raise
# Trigger Subscription
api.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon")
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
api.add_resource(
TriggerSubscriptionDeleteApi,
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)
# Trigger Subscription Builder
api.add_resource(
TriggerSubscriptionBuilderCreateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
api.add_resource(
TriggerSubscriptionBuilderGetApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderUpdateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderVerifyApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderBuildApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderLogsApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
)
# OAuth
api.add_resource(
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
)
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")

View File

@@ -23,8 +23,8 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_user, login_required
from models.account import Account, Tenant, TenantStatus
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.feature_service import FeatureService
from services.file_service import FileService
@@ -70,8 +70,7 @@ class TenantListApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = []
@@ -85,7 +84,7 @@ class TenantListApi(Resource):
"status": tenant.status,
"created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
"current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False,
"current": tenant.id == current_tenant_id if current_tenant_id else False,
}
tenant_dicts.append(tenant_dict)
@@ -98,9 +97,11 @@ class WorkspaceListApi(Resource):
@setup_required
@admin_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
stmt = select(Tenant).order_by(Tenant.created_at.desc())
@@ -130,8 +131,7 @@ class TenantApi(Resource):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
tenant = current_user.current_tenant
if not tenant:
raise ValueError("No current tenant")
@@ -155,10 +155,8 @@ class SwitchWorkspaceApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
args = parser.parse_args()
# check if tenant_id is valid, 403 if not
@@ -181,16 +179,14 @@ class CustomConfigWorkspaceApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("remove_webapp_brand", type=bool, location="json")
parser.add_argument("replace_webapp_logo", type=str, location="json")
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("remove_webapp_brand", type=bool, location="json")
.add_argument("replace_webapp_logo", type=str, location="json")
)
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"],
@@ -212,8 +208,7 @@ class WebappLogoWorkspaceApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@@ -253,15 +248,13 @@ class WorkspaceInfoApi(Resource):
@account_initialization_required
# Change workspace name
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
if not current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant = db.get_or_404(Tenant, current_tenant_id)
tenant.name = args["name"]
db.session.commit()

View File

@@ -12,8 +12,8 @@ from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import current_user
from models.account import Account, AccountStatus
from libs.login import current_account_with_tenant
from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
@@ -25,18 +25,12 @@ P = ParamSpec("P")
R = TypeVar("R")
def _current_account() -> Account:
assert isinstance(current_user, Account)
return current_user
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization
account = _current_account()
if account.status == AccountStatus.UNINITIALIZED:
current_user, _ = current_account_with_tenant()
if current_user.status == AccountStatus.UNINITIALIZED:
raise AccountNotInitializedError()
return view(*args, **kwargs)
@@ -80,9 +74,8 @@ def only_edition_self_hosted(view: Callable[P, R]):
def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
@@ -94,10 +87,8 @@ def cloud_edition_billing_resource_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
members = features.members
apps = features.apps
@@ -138,9 +129,8 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
@@ -163,13 +153,11 @@ def cloud_edition_billing_rate_limit_check(resource: str):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge":
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
_, current_tenant_id = current_account_with_tenant()
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{tenant_id}"
key = f"rate_limit_{current_tenant_id}"
redis_client.zadd(key, {current_time: current_time})
@@ -180,7 +168,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=tenant_id,
tenant_id=current_tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
@@ -200,17 +188,15 @@ def cloud_utm_record(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(tenant_id, utm_info_dict)
OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs)
@@ -260,9 +246,9 @@ def email_password_login_enabled(view: Callable[P, R]):
return decorated
def email_register_enabled(view):
def email_register_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
if features.is_allow_register:
return view(*args, **kwargs)
@@ -289,9 +275,8 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)
@@ -301,14 +286,31 @@ def is_allow_transfer_owner(view: Callable[P, R]):
return decorated
def knowledge_pipeline_publish_enabled(view):
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.knowledge_pipeline.publish_enabled:
return view(*args, **kwargs)
abort(403)
return decorated
def edit_permission_required(f: Callable[P, R]):
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
from libs.login import current_user
from models import Account
user = current_user._get_current_object() # type: ignore
if not isinstance(user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)
return decorated_function

View File

@@ -46,11 +46,13 @@ class FilePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
parser = reqparse.RequestParser()
parser.add_argument("timestamp", type=str, required=True, location="args")
parser.add_argument("nonce", type=str, required=True, location="args")
parser.add_argument("sign", type=str, required=True, location="args")
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
parser = (
reqparse.RequestParser()
.add_argument("timestamp", type=str, required=True, location="args")
.add_argument("nonce", type=str, required=True, location="args")
.add_argument("sign", type=str, required=True, location="args")
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args()

View File

@@ -16,12 +16,13 @@ class ToolFileApi(Resource):
def get(self, file_id, extension):
file_id = str(file_id)
parser = reqparse.RequestParser()
parser.add_argument("timestamp", type=str, required=True, location="args")
parser.add_argument("nonce", type=str, required=True, location="args")
parser.add_argument("sign", type=str, required=True, location="args")
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
parser = (
reqparse.RequestParser()
.add_argument("timestamp", type=str, required=True, location="args")
.add_argument("nonce", type=str, required=True, location="args")
.add_argument("sign", type=str, required=True, location="args")
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args()
if not verify_tool_file_signature(

View File

@@ -18,19 +18,17 @@ from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import build_file_model
# Define parser for both documentation and validation
upload_parser = reqparse.RequestParser()
upload_parser.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
upload_parser.add_argument(
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
upload_parser = (
reqparse.RequestParser()
.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
.add_argument(
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
)
.add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification")
.add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation")
.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
)
upload_parser.add_argument(
"nonce", type=str, required=True, location="args", help="Random string for signature verification"
)
upload_parser.add_argument(
"sign", type=str, required=True, location="args", help="HMAC signature for request validation"
)
upload_parser.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
upload_parser.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
@files_ns.route("/upload/for-plugin")

View File

@@ -5,11 +5,13 @@ from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
from tasks.mail_inner_task import send_inner_email_task
_mail_parser = reqparse.RequestParser()
_mail_parser.add_argument("to", type=str, action="append", required=True)
_mail_parser.add_argument("subject", type=str, required=True)
_mail_parser.add_argument("body", type=str, required=True)
_mail_parser.add_argument("substitutions", type=dict, required=False)
_mail_parser = (
reqparse.RequestParser()
.add_argument("to", type=str, action="append", required=True)
.add_argument("subject", type=str, required=True)
.add_argument("body", type=str, required=True)
.add_argument("substitutions", type=dict, required=False)
)
class BaseMail(Resource):
@@ -17,7 +19,7 @@ class BaseMail(Resource):
def post(self):
args = _mail_parser.parse_args()
send_inner_email_task.delay(
send_inner_email_task.delay( # type: ignore
to=args["to"],
subject=args["subject"],
body=args["body"],

View File

@@ -31,7 +31,7 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import length_prefixed_response
from models.account import Account, Tenant
from models import Account, Tenant
from models.model import EndUser

Some files were not shown because too many files have changed in this diff Show More