mirror of
https://github.com/langgenius/dify.git
synced 2026-02-08 17:24:00 +00:00
Compare commits
167 Commits
feat/trigg
...
feature/sm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9dadecdd42 | ||
|
|
c20e0ad90d | ||
|
|
22f64d60bb | ||
|
|
7b7d332239 | ||
|
|
b1d189324a | ||
|
|
00fb468f2e | ||
|
|
bbbb6e04cb | ||
|
|
f5161d9add | ||
|
|
787251f00e | ||
|
|
cfe21f0826 | ||
|
|
196f691865 | ||
|
|
7a5bb1cfac | ||
|
|
b80d55b764 | ||
|
|
dd71625f52 | ||
|
|
19936d23d1 | ||
|
|
decf0f3da0 | ||
|
|
7242a67f84 | ||
|
|
c4884eb669 | ||
|
|
d49f3327e4 | ||
|
|
633e68a2f7 | ||
|
|
809f48f733 | ||
|
|
578b1b45ea | ||
|
|
86c3c58e64 | ||
|
|
8d803a26eb | ||
|
|
aa3129c2a9 | ||
|
|
97c924fe29 | ||
|
|
591c463e4b | ||
|
|
e1691fddaa | ||
|
|
b4d4351203 | ||
|
|
f7b1348623 | ||
|
|
2619c7553a | ||
|
|
f79d8baf63 | ||
|
|
bbdcbac544 | ||
|
|
d552680e72 | ||
|
|
df43c6ab8a | ||
|
|
cd47a47c3b | ||
|
|
e5d4235f1b | ||
|
|
f60aa36fa0 | ||
|
|
b2bcb6d21a | ||
|
|
ee04f0d250 | ||
|
|
b6cea71023 | ||
|
|
6462328620 | ||
|
|
fd86cadf67 | ||
|
|
c43c72c1a3 | ||
|
|
d77c2e4d17 | ||
|
|
1a7898dff1 | ||
|
|
af662b100b | ||
|
|
595df172a8 | ||
|
|
70bc5ca7f4 | ||
|
|
30617feff8 | ||
|
|
756864c85b | ||
|
|
c8c94ef870 | ||
|
|
10d51ada59 | ||
|
|
00f3a53f1c | ||
|
|
d2f0551170 | ||
|
|
cba2b9b2ad | ||
|
|
029d5d36ac | ||
|
|
8d897153a5 | ||
|
|
2e914808ea | ||
|
|
16e9ea44a9 | ||
|
|
d00a72a435 | ||
|
|
36580221aa | ||
|
|
e686cc9eab | ||
|
|
66196459d5 | ||
|
|
a5387b304e | ||
|
|
beb1448441 | ||
|
|
272102c06d | ||
|
|
36406cd62f | ||
|
|
87c41c88a3 | ||
|
|
095c56a646 | ||
|
|
244c132656 | ||
|
|
043ec46c33 | ||
|
|
0e4f19eee0 | ||
|
|
ff34969f21 | ||
|
|
9a7245e1df | ||
|
|
4906eeac18 | ||
|
|
4da93ba579 | ||
|
|
319ecdd312 | ||
|
|
4a8ac18879 | ||
|
|
0c1ec35244 | ||
|
|
46375aacdb | ||
|
|
e6d4331994 | ||
|
|
2a0abc51b1 | ||
|
|
3bb67885ef | ||
|
|
e682749d03 | ||
|
|
9b83b0aadd | ||
|
|
0cac330bc2 | ||
|
|
fb8114792a | ||
|
|
eab6f65409 | ||
|
|
915023b809 | ||
|
|
f104839672 | ||
|
|
6841a09667 | ||
|
|
e937c8c72e | ||
|
|
960bb8a9b4 | ||
|
|
9b36059292 | ||
|
|
a4acc64afd | ||
|
|
1fc4844beb | ||
|
|
25c69ac540 | ||
|
|
bccd18b838 | ||
|
|
96a0b9991e | ||
|
|
2913d17fe2 | ||
|
|
d9e45a1abe | ||
|
|
24b4289d6c | ||
|
|
fb6ccccc3d | ||
|
|
8b74ae683a | ||
|
|
dd08957381 | ||
|
|
f486d1bcee | ||
|
|
2c2069f77c | ||
|
|
407323f817 | ||
|
|
2e2c87c5a1 | ||
|
|
cf222ecfed | ||
|
|
2c343e98cc | ||
|
|
2d4d4b6b8a | ||
|
|
25ae492247 | ||
|
|
16d30fbd60 | ||
|
|
69f712b713 | ||
|
|
f4522fd695 | ||
|
|
760a2c656c | ||
|
|
8940decd1b | ||
|
|
0c4193bd91 | ||
|
|
cd40cde790 | ||
|
|
c60c754ac9 | ||
|
|
ef80d3b707 | ||
|
|
24e8d21b3f | ||
|
|
d823da18db | ||
|
|
1e3df09fc6 | ||
|
|
75a10c276c | ||
|
|
50050527eb | ||
|
|
a39b185627 | ||
|
|
15270f09af | ||
|
|
f6a5ac0698 | ||
|
|
2b79da722b | ||
|
|
71d69e43cd | ||
|
|
5bc6e8a433 | ||
|
|
68076f2e22 | ||
|
|
8c38363038 | ||
|
|
345ac8333c | ||
|
|
2375047ef0 | ||
|
|
857a48012e | ||
|
|
208fe3d7de | ||
|
|
92cddbcc02 | ||
|
|
599b53c9cb | ||
|
|
062b173c66 | ||
|
|
db690013fd | ||
|
|
e93bfe3d41 | ||
|
|
ab910c736c | ||
|
|
4047a6bb12 | ||
|
|
df2478dc26 | ||
|
|
4cc3f6045b | ||
|
|
1550316b8d | ||
|
|
87394d2512 | ||
|
|
bad59c95bc | ||
|
|
9f138ef246 | ||
|
|
6453fc4973 | ||
|
|
f62f926537 | ||
|
|
b3dafd913b | ||
|
|
b2d8a7eaf1 | ||
|
|
3e54414191 | ||
|
|
a173546c8d | ||
|
|
aa69d90489 | ||
|
|
4ba1292455 | ||
|
|
bb01c31f30 | ||
|
|
cd90b2ca9e | ||
|
|
9a65350cf7 | ||
|
|
680eb7a9f6 | ||
|
|
878420463c | ||
|
|
4692e20daf |
@@ -1,4 +1,4 @@
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
#!/bin/bash
|
||||
WORKSPACE_ROOT=$(pwd)
|
||||
|
||||
npm add -g pnpm@10.15.0
|
||||
corepack enable
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
|
||||
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
|
||||
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
|
||||
echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
|
||||
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
|
||||
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
|
||||
|
||||
source /home/vscode/.bashrc
|
||||
|
||||
|
||||
3
.github/ISSUE_TEMPLATE/config.yml
vendored
3
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,5 +1,8 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: "\U0001F510 Security Vulnerabilities"
|
||||
url: "https://github.com/langgenius/dify/security/advisories/new"
|
||||
about: Report security vulnerabilities through GitHub Security Advisories to ensure responsible disclosure. 💡 Please do not report security vulnerabilities in public issues.
|
||||
- name: "\U0001F4A1 Model Providers & Plugins"
|
||||
url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose"
|
||||
about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details.
|
||||
|
||||
6
.github/workflows/autofix.yml
vendored
6
.github/workflows/autofix.yml
vendored
@@ -2,8 +2,6 @@ name: autofix.ci
|
||||
on:
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
push:
|
||||
branches: ["main"]
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
@@ -17,10 +15,12 @@ jobs:
|
||||
# Use uv to ensure we have the same ruff version in CI and locally.
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
python-version: "3.11"
|
||||
- run: |
|
||||
cd api
|
||||
uv sync --dev
|
||||
# fmt first to avoid line too long
|
||||
uv run ruff format ..
|
||||
# Fix lint errors
|
||||
uv run ruff check --fix .
|
||||
# Format code
|
||||
|
||||
3
.github/workflows/build-push.yml
vendored
3
.github/workflows/build-push.yml
vendored
@@ -8,8 +8,7 @@ on:
|
||||
- "deploy/enterprise"
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "deploy/rag-dev"
|
||||
- "feat/rag-2"
|
||||
- "hotfix/**"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
||||
4
.github/workflows/deploy-dev.yml
vendored
4
.github/workflows/deploy-dev.yml
vendored
@@ -4,7 +4,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/rag-dev"
|
||||
- "deploy/dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
||||
github.event.workflow_run.head_branch == 'deploy/dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -231,5 +231,7 @@ api/.env.backup
|
||||
# Benchmark
|
||||
scripts/stress-test/setup/config/
|
||||
scripts/stress-test/reports/
|
||||
|
||||
# mcp
|
||||
.serena
|
||||
.playwright-mcp/
|
||||
.serena/
|
||||
90
AGENTS.md
90
AGENTS.md
@@ -4,85 +4,51 @@
|
||||
|
||||
Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
|
||||
|
||||
The codebase consists of:
|
||||
The codebase is split into:
|
||||
|
||||
- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture
|
||||
- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
|
||||
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
|
||||
- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19
|
||||
- **Docker deployment** (`/docker`): Containerized deployment configurations
|
||||
|
||||
## Development Commands
|
||||
## Backend Workflow
|
||||
|
||||
### Backend (API)
|
||||
- Run backend CLI commands through `uv run --project api <command>`.
|
||||
|
||||
All Python commands must be prefixed with `uv run --project api`:
|
||||
- Backend QA gate requires passing `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before review.
|
||||
|
||||
```bash
|
||||
# Start development servers
|
||||
./dev/start-api # Start API server
|
||||
./dev/start-worker # Start Celery worker
|
||||
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
|
||||
|
||||
# Run tests
|
||||
uv run --project api pytest # Run all tests
|
||||
uv run --project api pytest tests/unit_tests/ # Unit tests only
|
||||
uv run --project api pytest tests/integration_tests/ # Integration tests
|
||||
- Integration tests are CI-only and are not expected to run in the local environment.
|
||||
|
||||
# Code quality
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run --project api ruff check --fix ./ # Fix linting issues
|
||||
uv run --project api ruff format ./ # Format code
|
||||
uv run --directory api basedpyright # Type checking
|
||||
```
|
||||
|
||||
### Frontend (Web)
|
||||
## Frontend Workflow
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm lint # Run ESLint
|
||||
pnpm eslint-fix # Fix ESLint issues
|
||||
pnpm test # Run Jest tests
|
||||
pnpm lint
|
||||
pnpm lint:fix
|
||||
pnpm test
|
||||
```
|
||||
|
||||
## Testing Guidelines
|
||||
## Testing & Quality Practices
|
||||
|
||||
### Backend Testing
|
||||
- Follow TDD: red → green → refactor.
|
||||
- Use `pytest` for backend tests with Arrange-Act-Assert structure.
|
||||
- Enforce strong typing; avoid `Any` and prefer explicit type annotations.
|
||||
- Write self-documenting code; only add comments that explain intent.
|
||||
|
||||
- Use `pytest` for all backend tests
|
||||
- Write tests first (TDD approach)
|
||||
- Test structure: Arrange-Act-Assert
|
||||
## Language Style
|
||||
|
||||
## Code Style Requirements
|
||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
|
||||
- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types.
|
||||
|
||||
### Python
|
||||
## General Practices
|
||||
|
||||
- Use type hints for all functions and class attributes
|
||||
- No `Any` types unless absolutely necessary
|
||||
- Implement special methods (`__repr__`, `__str__`) appropriately
|
||||
- Prefer editing existing files; add new documentation only when requested.
|
||||
- Inject dependencies through constructors and preserve clean architecture boundaries.
|
||||
- Handle errors with domain-specific exceptions at the correct layer.
|
||||
|
||||
### TypeScript/JavaScript
|
||||
## Project Conventions
|
||||
|
||||
- Strict TypeScript configuration
|
||||
- ESLint with Prettier integration
|
||||
- Avoid `any` type
|
||||
|
||||
## Important Notes
|
||||
|
||||
- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
|
||||
- **Comments**: Only write meaningful comments that explain "why", not "what"
|
||||
- **File Creation**: Always prefer editing existing files over creating new ones
|
||||
- **Documentation**: Don't create documentation files unless explicitly requested
|
||||
- **Code Quality**: Always run `./dev/reformat` before committing backend changes
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a New API Endpoint
|
||||
|
||||
1. Create controller in `/api/controllers/`
|
||||
1. Add service logic in `/api/services/`
|
||||
1. Update routes in controller's `__init__.py`
|
||||
1. Write tests in `/api/tests/`
|
||||
|
||||
## Project-Specific Conventions
|
||||
|
||||
- All async tasks use Celery with Redis as broker
|
||||
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
|
||||
- **Logging**: Never use `str(e)` in `logger.exception()` calls. Use `logger.exception("message", exc_info=e)` instead
|
||||
- Backend architecture adheres to DDD and Clean Architecture principles.
|
||||
- Async work runs through Celery with Redis as the broker.
|
||||
- Frontend user-facing strings must use `web/i18n/en-US/`; avoid hardcoded text.
|
||||
|
||||
6
Makefile
6
Makefile
@@ -26,7 +26,6 @@ prepare-web:
|
||||
@echo "🌐 Setting up web environment..."
|
||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||
@cd web && pnpm install
|
||||
@cd web && pnpm build
|
||||
@echo "✅ Web environment prepared (not started)"
|
||||
|
||||
# Step 3: Prepare API environment
|
||||
@@ -61,8 +60,9 @@ check:
|
||||
@echo "✅ Code check complete"
|
||||
|
||||
lint:
|
||||
@echo "🔧 Running ruff format and check with fixes..."
|
||||
@uv run --directory api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
|
||||
@echo "🔧 Running ruff format, check with fixes, and import linter..."
|
||||
@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
|
||||
@uv run --directory api --dev lint-imports
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
type-check:
|
||||
|
||||
24
README.md
24
README.md
@@ -40,18 +40,18 @@
|
||||
|
||||
<p align="center">
|
||||
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||
<a href="./README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
<a href="./README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
<a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
<a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
|
||||
|
||||
@@ -304,6 +304,8 @@ BAIDU_VECTOR_DB_API_KEY=dify
|
||||
BAIDU_VECTOR_DB_DATABASE=dify
|
||||
BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE
|
||||
|
||||
# Upstash configuration
|
||||
UPSTASH_VECTOR_URL=your-server-url
|
||||
@@ -377,6 +379,19 @@ SMTP_USERNAME=123
|
||||
SMTP_PASSWORD=abc
|
||||
SMTP_USE_TLS=true
|
||||
SMTP_OPPORTUNISTIC_TLS=false
|
||||
|
||||
# SMTP authentication type: 'basic' for username/password, 'oauth2' for Microsoft OAuth 2.0
|
||||
# Use 'oauth2' for Microsoft Exchange/Outlook due to Basic Auth retirement (September 2025)
|
||||
SMTP_AUTH_TYPE=basic
|
||||
|
||||
# Microsoft OAuth 2.0 configuration for SMTP authentication
|
||||
# Required when SMTP_AUTH_TYPE=oauth2 and using Microsoft Exchange/Outlook
|
||||
# Setup: Create Azure AD app → Add Mail.Send + SMTP.Send permissions → Get Client ID/Secret
|
||||
# For Exchange Online: SMTP_SERVER=smtp.office365.com, SMTP_PORT=587, SMTP_USE_TLS=true
|
||||
MICROSOFT_OAUTH2_CLIENT_ID=
|
||||
MICROSOFT_OAUTH2_CLIENT_SECRET=
|
||||
MICROSOFT_OAUTH2_TENANT_ID=common
|
||||
MICROSOFT_OAUTH2_ACCESS_TOKEN=
|
||||
# Sendgid configuration
|
||||
SENDGRID_API_KEY=
|
||||
# Sentry configuration
|
||||
@@ -406,6 +421,9 @@ SSRF_DEFAULT_TIME_OUT=5
|
||||
SSRF_DEFAULT_CONNECT_TIME_OUT=5
|
||||
SSRF_DEFAULT_READ_TIME_OUT=5
|
||||
SSRF_DEFAULT_WRITE_TIME_OUT=5
|
||||
SSRF_POOL_MAX_CONNECTIONS=100
|
||||
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS=20
|
||||
SSRF_POOL_KEEPALIVE_EXPIRY=5.0
|
||||
|
||||
BATCH_UPLOAD_LIMIT=10
|
||||
KEYWORD_DATA_SOURCE_TYPE=database
|
||||
@@ -416,6 +434,10 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10
|
||||
# CODE EXECUTION CONFIGURATION
|
||||
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
|
||||
CODE_EXECUTION_API_KEY=dify-sandbox
|
||||
CODE_EXECUTION_SSL_VERIFY=True
|
||||
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
|
||||
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
|
||||
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
|
||||
CODE_MAX_NUMBER=9223372036854775807
|
||||
CODE_MIN_NUMBER=-9223372036854775808
|
||||
CODE_MAX_STRING_LENGTH=80000
|
||||
@@ -436,9 +458,6 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY=True
|
||||
|
||||
# Webhook request configuration
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
|
||||
|
||||
# Respect X-* headers to redirect clients
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED=false
|
||||
|
||||
@@ -462,7 +481,6 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
|
||||
WORKFLOW_MAX_EXECUTION_STEPS=500
|
||||
WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||
WORKFLOW_CALL_MAX_DEPTH=5
|
||||
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
@@ -517,12 +535,6 @@ ENABLE_CLEAN_MESSAGES=false
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
||||
ENABLE_DATASETS_QUEUE_MONITOR=false
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
|
||||
# Interval time in minutes for polling scheduled workflows(default: 1 min)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
|
||||
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
|
||||
|
||||
# Position configuration
|
||||
POSITION_TOOL_PINS=
|
||||
|
||||
@@ -30,6 +30,7 @@ select = [
|
||||
"RUF022", # unsorted-dunder-all
|
||||
"S506", # unsafe-yaml-load
|
||||
"SIM", # flake8-simplify rules
|
||||
"T201", # print-found
|
||||
"TRY400", # error-instead-of-exception
|
||||
"TRY401", # verbose-log-message
|
||||
"UP", # pyupgrade rules
|
||||
@@ -91,11 +92,18 @@ ignore = [
|
||||
"configs/*" = [
|
||||
"N802", # invalid-function-name
|
||||
]
|
||||
"core/model_runtime/callbacks/base_callback.py" = [
|
||||
"T201",
|
||||
]
|
||||
"core/workflow/callbacks/workflow_logging_callback.py" = [
|
||||
"T201",
|
||||
]
|
||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests
|
||||
]
|
||||
|
||||
[lint.pyflakes]
|
||||
|
||||
2
api/.vscode/launch.json.example
vendored
2
api/.vscode/launch.json.example
vendored
@@ -54,7 +54,7 @@
|
||||
"--loglevel",
|
||||
"DEBUG",
|
||||
"-Q",
|
||||
"dataset,generation,mail,ops_trace,app_deletion,workflow"
|
||||
"dataset,generation,mail,ops_trace,app_deletion"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
@@ -80,10 +80,10 @@
|
||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
||||
uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
||||
```
|
||||
|
||||
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
|
||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery beat
|
||||
|
||||
@@ -1,20 +1,11 @@
|
||||
import logging
|
||||
|
||||
import psycogreen.gevent as pscycogreen_gevent # type: ignore
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _log(message: str):
|
||||
print(message, flush=True)
|
||||
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
_log("gRPC patched with gevent.")
|
||||
print("gRPC patched with gevent.", flush=True) # noqa: T201
|
||||
pscycogreen_gevent.patch_psycopg()
|
||||
_log("psycopg2 patched with gevent.")
|
||||
print("psycopg2 patched with gevent.", flush=True) # noqa: T201
|
||||
|
||||
|
||||
from app import app, celery
|
||||
|
||||
422
api/commands.py
422
api/commands.py
@@ -10,16 +10,17 @@ from flask import current_app
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.models.document import Document
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
@@ -61,31 +62,30 @@ def reset_password(email, new_password, password_confirm):
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style("Passwords do not match.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command("reset-email", help="Reset the account email.")
|
||||
@@ -100,22 +100,21 @@ def reset_email(email, new_email, email_confirm):
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style("New emails do not match.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
db.session.commit()
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
account.email = new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command(
|
||||
@@ -139,25 +138,24 @@ def reset_encrypt_key_pair():
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
tenants = session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
return
|
||||
|
||||
tenants = db.session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
return
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
|
||||
db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
db.session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
fg="green",
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command("vdb-migrate", help="Migrate vector db.")
|
||||
@@ -182,14 +180,15 @@ def migrate_annotation_vector_database():
|
||||
try:
|
||||
# get apps info
|
||||
per_page = 50
|
||||
apps = (
|
||||
db.session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
apps = (
|
||||
session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
if not apps:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
@@ -203,26 +202,27 @@ def migrate_annotation_vector_database():
|
||||
)
|
||||
try:
|
||||
click.echo(f"Creating app annotation index: {app.id}")
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
app_annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
continue
|
||||
annotations = db.session.scalars(
|
||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||
).all()
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
continue
|
||||
annotations = session.scalars(
|
||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||
).all()
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
@@ -739,18 +739,18 @@ where sites.id is null limit 1000"""
|
||||
try:
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
print(f"App {app_id} not found")
|
||||
logger.info("App %s not found", app_id)
|
||||
continue
|
||||
|
||||
tenant = app.tenant
|
||||
if tenant:
|
||||
accounts = tenant.get_accounts()
|
||||
if not accounts:
|
||||
print(f"Fix failed for app {app.id}")
|
||||
logger.info("Fix failed for app %s", app.id)
|
||||
continue
|
||||
|
||||
account = accounts[0]
|
||||
print(f"Fixing missing site for app {app.id}")
|
||||
logger.info("Fixing missing site for app %s", app.id)
|
||||
app_was_created.send(app, account=account)
|
||||
except Exception:
|
||||
failed_app_ids.append(app_id)
|
||||
@@ -1227,55 +1227,6 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_system_trigger_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup system trigger oauth client
|
||||
"""
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerOAuthSystemClient
|
||||
|
||||
provider_id = TriggerProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
oauth_client = TriggerOAuthSystemClient(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
encrypted_oauth_params=oauth_client_params,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
"""
|
||||
Find draft variables that reference non-existent apps.
|
||||
@@ -1497,41 +1448,52 @@ def transform_datasource_credentials():
|
||||
notion_credentials_tenant_mapping[tenant_id] = []
|
||||
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
|
||||
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
|
||||
auth_count = 0
|
||||
for notion_tenant_credential in notion_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = notion_tenant_credential.access_token
|
||||
# notion info
|
||||
notion_info = notion_tenant_credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion_datasource",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
|
||||
auth_count = 0
|
||||
for notion_tenant_credential in notion_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = notion_tenant_credential.access_token
|
||||
# notion info
|
||||
notion_info = notion_tenant_credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion_datasource",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
|
||||
)
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
continue
|
||||
db.session.commit()
|
||||
# deal firecrawl credentials
|
||||
deal_firecrawl_count = 0
|
||||
@@ -1544,37 +1506,48 @@ def transform_datasource_credentials():
|
||||
firecrawl_credentials_tenant_mapping[tenant_id] = []
|
||||
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
|
||||
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
base_url = credentials_json.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
base_url = credentials_json.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
|
||||
)
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
continue
|
||||
db.session.commit()
|
||||
# deal jina credentials
|
||||
deal_jina_count = 0
|
||||
@@ -1587,36 +1560,45 @@ def transform_datasource_credentials():
|
||||
jina_credentials_tenant_mapping[tenant_id] = []
|
||||
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
|
||||
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
print(jina_plugin_unique_identifier)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red")
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
continue
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
|
||||
@@ -113,6 +113,21 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_POOL_MAX_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of concurrent connections for the code execution HTTP client",
|
||||
default=100,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of persistent keep-alive connections for the code execution HTTP client",
|
||||
default=20,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field(
|
||||
description="Keep-alive expiry in seconds for idle connections (set to None to disable)",
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
CODE_MAX_NUMBER: PositiveInt = Field(
|
||||
description="Maximum allowed numeric value in code execution",
|
||||
default=9223372036854775807,
|
||||
@@ -153,15 +168,9 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
default=1000,
|
||||
)
|
||||
|
||||
|
||||
class TriggerConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for trigger
|
||||
"""
|
||||
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size for webhook request bodies in bytes",
|
||||
default=10485760,
|
||||
CODE_EXECUTION_SSL_VERIFY: bool = Field(
|
||||
description="Enable or disable SSL verification for code execution requests",
|
||||
default=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -415,6 +424,21 @@ class HttpConfig(BaseSettings):
|
||||
default=5,
|
||||
)
|
||||
|
||||
SSRF_POOL_MAX_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of concurrent connections for the SSRF HTTP client",
|
||||
default=100,
|
||||
)
|
||||
|
||||
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of persistent keep-alive connections for the SSRF HTTP client",
|
||||
default=20,
|
||||
)
|
||||
|
||||
SSRF_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field(
|
||||
description="Keep-alive expiry in seconds for idle SSRF connections (set to None to disable)",
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field(
|
||||
description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers"
|
||||
" when the app is behind a single trusted reverse proxy.",
|
||||
@@ -553,11 +577,6 @@ class WorkflowConfig(BaseSettings):
|
||||
default=5,
|
||||
)
|
||||
|
||||
WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
|
||||
description="Maximum allowed depth for nested parallel executions",
|
||||
default=3,
|
||||
)
|
||||
|
||||
MAX_VARIABLE_SIZE: PositiveInt = Field(
|
||||
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
|
||||
default=200 * 1024,
|
||||
@@ -802,6 +821,32 @@ class MailConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
SMTP_AUTH_TYPE: str = Field(
|
||||
description="SMTP authentication type ('basic' or 'oauth2')",
|
||||
default="basic",
|
||||
)
|
||||
|
||||
# Microsoft OAuth 2.0 configuration for SMTP
|
||||
MICROSOFT_OAUTH2_CLIENT_ID: str | None = Field(
|
||||
description="Microsoft OAuth 2.0 client ID for SMTP authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MICROSOFT_OAUTH2_CLIENT_SECRET: str | None = Field(
|
||||
description="Microsoft OAuth 2.0 client secret for SMTP authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MICROSOFT_OAUTH2_TENANT_ID: str = Field(
|
||||
description="Microsoft OAuth 2.0 tenant ID (use 'common' for multi-tenant)",
|
||||
default="common",
|
||||
)
|
||||
|
||||
MICROSOFT_OAUTH2_ACCESS_TOKEN: str | None = Field(
|
||||
description="Microsoft OAuth 2.0 access token for SMTP authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
|
||||
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
|
||||
default=50,
|
||||
@@ -961,22 +1006,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable check upgradable plugin task",
|
||||
default=True,
|
||||
)
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
|
||||
description="Enable workflow schedule poller task",
|
||||
default=True,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
|
||||
description="Workflow schedule poller interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
|
||||
description="Maximum number of schedules to process in each poll batch",
|
||||
default=100,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
|
||||
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
class PositionConfig(BaseSettings):
|
||||
@@ -1100,7 +1129,6 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
TriggerConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
|
||||
@@ -41,3 +41,13 @@ class BaiduVectorDBConfig(BaseSettings):
|
||||
description="Number of replicas for the Baidu Vector Database (default is 3)",
|
||||
default=3,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: str = Field(
|
||||
description="Analyzer type for inverted index in Baidu Vector Database (default is DEFAULT_ANALYZER)",
|
||||
default="DEFAULT_ANALYZER",
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: str = Field(
|
||||
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
|
||||
default="COARSE_MODE",
|
||||
)
|
||||
|
||||
@@ -37,3 +37,15 @@ class OceanBaseVectorConfig(BaseSettings):
|
||||
"with older versions",
|
||||
default=False,
|
||||
)
|
||||
|
||||
OCEANBASE_FULLTEXT_PARSER: str | None = Field(
|
||||
description=(
|
||||
"Fulltext parser to use for text indexing. "
|
||||
"Built-in options: 'ngram' (N-gram tokenizer for English/numbers), "
|
||||
"'beng' (Basic English tokenizer), 'space' (Space-based tokenizer), "
|
||||
"'ngram2' (Improved N-gram tokenizer), 'ik' (Chinese tokenizer). "
|
||||
"External plugins (require installation): 'japanese_ftparser' (Japanese tokenizer), "
|
||||
"'thai_ftparser' (Thai tokenizer). Default is 'ik'"
|
||||
),
|
||||
default="ik",
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -30,10 +30,10 @@ class NacosHttpClient:
|
||||
params = {}
|
||||
try:
|
||||
self._inject_auth_info(headers, params)
|
||||
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||
response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except requests.RequestException as e:
|
||||
except httpx.RequestError as e:
|
||||
return f"Request to Nacos failed: {e}"
|
||||
|
||||
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
|
||||
@@ -78,7 +78,7 @@ class NacosHttpClient:
|
||||
params = {"username": self.username, "password": self.password}
|
||||
url = "http://" + self.server + "/nacos/v1/auth/login"
|
||||
try:
|
||||
resp = requests.request("POST", url, headers=None, params=params)
|
||||
resp = httpx.request("POST", url, headers=None, params=params)
|
||||
resp.raise_for_status()
|
||||
response_data = resp.json()
|
||||
self.token = response_data.get("accessToken")
|
||||
|
||||
@@ -9,8 +9,6 @@ if TYPE_CHECKING:
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
"""
|
||||
@@ -43,11 +41,3 @@ datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginPro
|
||||
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("datasource_plugin_providers_lock")
|
||||
)
|
||||
|
||||
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers")
|
||||
)
|
||||
|
||||
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers_lock")
|
||||
)
|
||||
|
||||
@@ -1,31 +1,10 @@
|
||||
from importlib import import_module
|
||||
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
|
||||
from .explore.audio import ChatAudioApi, ChatTextApi
|
||||
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
|
||||
from .explore.conversation import (
|
||||
ConversationApi,
|
||||
ConversationListApi,
|
||||
ConversationPinApi,
|
||||
ConversationRenameApi,
|
||||
ConversationUnPinApi,
|
||||
)
|
||||
from .explore.message import (
|
||||
MessageFeedbackApi,
|
||||
MessageListApi,
|
||||
MessageMoreLikeThisApi,
|
||||
MessageSuggestedQuestionApi,
|
||||
)
|
||||
from .explore.workflow import (
|
||||
InstalledAppWorkflowRunApi,
|
||||
InstalledAppWorkflowTaskStopApi,
|
||||
)
|
||||
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
|
||||
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
bp = Blueprint("console", __name__, url_prefix="/console/api")
|
||||
|
||||
api = ExternalApi(
|
||||
@@ -35,23 +14,23 @@ api = ExternalApi(
|
||||
description="Console management APIs for app configuration, monitoring, and administration",
|
||||
)
|
||||
|
||||
# Create namespace
|
||||
console_ns = Namespace("console", description="Console management API operations", path="/")
|
||||
|
||||
# File
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
||||
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
||||
RESOURCE_MODULES = (
|
||||
"controllers.console.app.app_import",
|
||||
"controllers.console.explore.audio",
|
||||
"controllers.console.explore.completion",
|
||||
"controllers.console.explore.conversation",
|
||||
"controllers.console.explore.message",
|
||||
"controllers.console.explore.workflow",
|
||||
"controllers.console.files",
|
||||
"controllers.console.remote_files",
|
||||
)
|
||||
|
||||
# Remote files
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
|
||||
# Import App
|
||||
api.add_resource(AppImportApi, "/apps/imports")
|
||||
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
|
||||
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
||||
for module_name in RESOURCE_MODULES:
|
||||
import_module(module_name)
|
||||
|
||||
# Ensure resource modules are imported so route decorators are evaluated.
|
||||
# Import other controllers
|
||||
from . import (
|
||||
admin,
|
||||
@@ -87,7 +66,6 @@ from .app import (
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
workflow_trigger,
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
@@ -151,94 +129,8 @@ from .workspace import (
|
||||
workspace,
|
||||
)
|
||||
|
||||
# Explore Audio
|
||||
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
|
||||
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
|
||||
|
||||
# Explore Completion
|
||||
api.add_resource(
|
||||
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
|
||||
)
|
||||
api.add_resource(
|
||||
CompletionStopApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
|
||||
endpoint="installed_app_stop_completion",
|
||||
)
|
||||
api.add_resource(
|
||||
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
|
||||
)
|
||||
api.add_resource(
|
||||
ChatStopApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
|
||||
endpoint="installed_app_stop_chat_completion",
|
||||
)
|
||||
|
||||
# Explore Conversation
|
||||
api.add_resource(
|
||||
ConversationRenameApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
|
||||
endpoint="installed_app_conversation_rename",
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
|
||||
endpoint="installed_app_conversation",
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationPinApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
|
||||
endpoint="installed_app_conversation_pin",
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationUnPinApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
|
||||
endpoint="installed_app_conversation_unpin",
|
||||
)
|
||||
|
||||
|
||||
# Explore Message
|
||||
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
|
||||
api.add_resource(
|
||||
MessageFeedbackApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
|
||||
endpoint="installed_app_message_feedback",
|
||||
)
|
||||
api.add_resource(
|
||||
MessageMoreLikeThisApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
|
||||
endpoint="installed_app_more_like_this",
|
||||
)
|
||||
api.add_resource(
|
||||
MessageSuggestedQuestionApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="installed_app_suggested_question",
|
||||
)
|
||||
# Explore Workflow
|
||||
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
api.add_resource(
|
||||
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
|
||||
)
|
||||
|
||||
api.add_namespace(console_ns)
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import (
|
||||
account,
|
||||
agent_providers,
|
||||
endpoint,
|
||||
load_balancing_config,
|
||||
members,
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"account",
|
||||
"activate",
|
||||
@@ -304,7 +196,6 @@ __all__ = [
|
||||
"statistic",
|
||||
"tags",
|
||||
"tool_providers",
|
||||
"trigger_providers",
|
||||
"version",
|
||||
"website",
|
||||
"workflow",
|
||||
|
||||
@@ -19,6 +19,7 @@ from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||
from libs.login import login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import Account, App
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
@@ -28,12 +29,6 @@ from services.feature_service import FeatureService
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@api.doc("list_apps")
|
||||
@@ -138,7 +133,7 @@ class AppListApi(Resource):
|
||||
"""Create app"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
@@ -219,7 +214,7 @@ class AppApi(Resource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
@@ -297,7 +292,7 @@ class AppCopyApi(Resource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
|
||||
@@ -20,7 +20,10 @@ from services.app_dsl_service import AppDslService, ImportStatus
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -74,6 +77,7 @@ class AppImportApi(Resource):
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports/<string:import_id>/confirm")
|
||||
class AppImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -98,6 +102,7 @@ class AppImportConfirmApi(Resource):
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
|
||||
class AppImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz # pip install pytz
|
||||
import sqlalchemy as sa
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
@@ -70,7 +71,7 @@ class CompletionConversationApi(Resource):
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
query = db.select(Conversation).where(
|
||||
query = sa.select(Conversation).where(
|
||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||
)
|
||||
|
||||
@@ -236,7 +237,7 @@ class ChatConversationApi(Resource):
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
|
||||
if args["keyword"]:
|
||||
keyword_filter = f"%{args['keyword']}%"
|
||||
|
||||
@@ -12,7 +12,6 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
@@ -199,11 +198,13 @@ class InstructionGenerateApi(Resource):
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||
args = parser.parse_args()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args["language"])), None
|
||||
code_template = (
|
||||
Python3CodeProvider.get_default_code()
|
||||
if args["language"] == "python"
|
||||
else (JavascriptCodeProvider.get_default_code())
|
||||
if args["language"] == "javascript"
|
||||
else ""
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
|
||||
@@ -62,6 +62,9 @@ class ChatMessageListApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model):
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||
|
||||
@@ -50,8 +50,9 @@ class DailyMessageStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -187,8 +188,9 @@ class DailyTerminalsStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -259,8 +261,9 @@ class DailyTokenCostStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -340,8 +343,9 @@ FROM
|
||||
messages m
|
||||
ON c.id = m.conversation_id
|
||||
WHERE
|
||||
c.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
c.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -426,8 +430,9 @@ LEFT JOIN
|
||||
message_feedbacks mf
|
||||
ON mf.message_id=m.id AND mf.rating='like'
|
||||
WHERE
|
||||
m.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
m.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -502,8 +507,9 @@ class AverageResponseTimeStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
@@ -576,8 +582,9 @@ class TokensPerSecondStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
@@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@@ -20,7 +19,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory, variable_factory
|
||||
@@ -36,7 +34,6 @@ from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.trigger_debug_service import TriggerDebugService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -799,24 +796,6 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
|
||||
class WorkflowConfigApi(Resource):
|
||||
"""Resource for workflow configuration."""
|
||||
|
||||
@api.doc("get_workflow_config")
|
||||
@api.doc(description="Get workflow configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Workflow configuration retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
return {
|
||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@api.doc("get_all_published_workflows")
|
||||
@@ -1006,165 +985,3 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
if node_exec is None:
|
||||
raise NotFound("last run not found")
|
||||
return node_exec
|
||||
|
||||
|
||||
class DraftWorkflowTriggerNodeApi(Resource):
|
||||
"""
|
||||
Single node debug - Polling API for trigger events
|
||||
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger
|
||||
"""
|
||||
|
||||
@api.doc("poll_draft_workflow_trigger_node")
|
||||
@api.doc(description="Poll for trigger events and execute single node when event arrives")
|
||||
@api.doc(params={
|
||||
"app_id": "Application ID",
|
||||
"node_id": "Node ID"
|
||||
})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DraftWorkflowTriggerNodeRequest",
|
||||
{
|
||||
"trigger_name": fields.String(required=True, description="Trigger name"),
|
||||
"subscription_id": fields.String(required=True, description="Subscription ID"),
|
||||
}
|
||||
)
|
||||
)
|
||||
@api.response(200, "Trigger event received and node executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Poll for trigger events and execute single node when event arrives
|
||||
"""
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("trigger_name", type=str, required=True, location="json", nullable=False)
|
||||
parser.add_argument("subscription_id", type=str, required=True, location="json", nullable=False)
|
||||
args = parser.parse_args()
|
||||
trigger_name = args["trigger_name"]
|
||||
subscription_id = args["subscription_id"]
|
||||
|
||||
event = TriggerDebugService.poll_event(
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=app_model.id,
|
||||
subscription_id=subscription_id,
|
||||
node_id=node_id,
|
||||
trigger_name=trigger_name,
|
||||
)
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting"})
|
||||
|
||||
try:
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
user_inputs = event.model_dump()
|
||||
node_execution = workflow_service.run_draft_workflow_node(
|
||||
app_model=app_model,
|
||||
draft_workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
account=current_user,
|
||||
query="",
|
||||
files=[],
|
||||
)
|
||||
return jsonable_encoder(node_execution)
|
||||
except Exception:
|
||||
logger.exception("Error running draft workflow trigger node")
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"status": "error",
|
||||
}
|
||||
), 500
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunApi(Resource):
|
||||
"""
|
||||
Full workflow debug - Polling API for trigger events
|
||||
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
|
||||
"""
|
||||
|
||||
@api.doc("poll_draft_workflow_trigger_run")
|
||||
@api.doc(description="Poll for trigger events and execute full workflow when event arrives")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DraftWorkflowTriggerRunRequest",
|
||||
{
|
||||
"node_id": fields.String(required=True, description="Node ID"),
|
||||
"trigger_name": fields.String(required=True, description="Trigger name"),
|
||||
"subscription_id": fields.String(required=True, description="Subscription ID"),
|
||||
}
|
||||
)
|
||||
)
|
||||
@api.response(200, "Trigger event received and workflow executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Poll for trigger events and execute full workflow when event arrives
|
||||
"""
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
|
||||
parser.add_argument("trigger_name", type=str, required=True, location="json", nullable=False)
|
||||
parser.add_argument("subscription_id", type=str, required=True, location="json", nullable=False)
|
||||
args = parser.parse_args()
|
||||
node_id = args["node_id"]
|
||||
trigger_name = args["trigger_name"]
|
||||
subscription_id = args["subscription_id"]
|
||||
|
||||
event = TriggerDebugService.poll_event(
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=app_model.id,
|
||||
subscription_id=subscription_id,
|
||||
node_id=node_id,
|
||||
trigger_name=trigger_name,
|
||||
)
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting"})
|
||||
|
||||
workflow_args = {
|
||||
"inputs": event.model_dump(),
|
||||
"query": "",
|
||||
"files": [],
|
||||
}
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
workflow_args["external_trace_id"] = external_trace_id
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=workflow_args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
return helper.compact_generate_response(response)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except Exception:
|
||||
logger.exception("Error running draft workflow trigger run")
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"status": "error",
|
||||
}
|
||||
), 500
|
||||
|
||||
|
||||
@@ -1,249 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import Account, AppMode
|
||||
from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
|
||||
|
||||
|
||||
class PluginTriggerApi(Resource):
|
||||
"""Workflow Plugin Trigger API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
def post(self, app_model):
|
||||
"""Create plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=False, location="json")
|
||||
parser.add_argument("provider_id", type=str, required=False, location="json")
|
||||
parser.add_argument("trigger_name", type=str, required=False, location="json")
|
||||
parser.add_argument("subscription_id", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
plugin_trigger = WorkflowPluginTriggerService.create_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
node_id=args["node_id"],
|
||||
provider_id=args["provider_id"],
|
||||
trigger_name=args["trigger_name"],
|
||||
subscription_id=args["subscription_id"],
|
||||
)
|
||||
|
||||
return jsonable_encoder(plugin_trigger)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
def get(self, app_model):
|
||||
"""Get plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
plugin_trigger = WorkflowPluginTriggerService.get_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
node_id=args["node_id"],
|
||||
)
|
||||
|
||||
return jsonable_encoder(plugin_trigger)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
def put(self, app_model):
|
||||
"""Update plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
parser.add_argument("subscription_id", type=str, required=True, location="json", help="Subscription ID")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
plugin_trigger = WorkflowPluginTriggerService.update_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
node_id=args["node_id"],
|
||||
subscription_id=args["subscription_id"],
|
||||
)
|
||||
|
||||
return jsonable_encoder(plugin_trigger)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
def delete(self, app_model):
|
||||
"""Delete plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
WorkflowPluginTriggerService.delete_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
node_id=args["node_id"],
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class WebhookTriggerApi(Resource):
|
||||
"""Webhook Trigger API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(webhook_trigger_fields)
|
||||
def get(self, app_model):
|
||||
"""Get webhook trigger for a node"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
node_id = args["node_id"]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = (
|
||||
session.query(WorkflowWebhookTrigger)
|
||||
.filter(
|
||||
WorkflowWebhookTrigger.app_id == app_model.id,
|
||||
WorkflowWebhookTrigger.node_id == node_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not webhook_trigger:
|
||||
raise NotFound("Webhook trigger not found for this node")
|
||||
|
||||
# Add computed fields for marshal_with
|
||||
base_url = dify_config.SERVICE_API_URL
|
||||
webhook_trigger.webhook_url = f"{base_url}/triggers/webhook/{webhook_trigger.webhook_id}" # type: ignore
|
||||
webhook_trigger.webhook_debug_url = f"{base_url}/triggers/webhook-debug/{webhook_trigger.webhook_id}" # type: ignore
|
||||
|
||||
return webhook_trigger
|
||||
|
||||
|
||||
class AppTriggersApi(Resource):
|
||||
"""App Triggers list API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(triggers_list_fields)
|
||||
def get(self, app_model):
|
||||
"""Get app triggers list"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
# Add computed icon field for each trigger
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
for trigger in triggers:
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return {"data": triggers}
|
||||
|
||||
|
||||
class AppTriggerEnableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(trigger_fields)
|
||||
def post(self, app_model):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
trigger_id = args["trigger_id"]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
select(AppTrigger).where(
|
||||
AppTrigger.id == trigger_id,
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not trigger:
|
||||
raise NotFound("Trigger not found")
|
||||
|
||||
# Update status based on enable_trigger boolean
|
||||
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
|
||||
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Add computed icon field
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return trigger
|
||||
|
||||
|
||||
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||
api.add_resource(PluginTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/plugin")
|
||||
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
|
||||
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
|
||||
@@ -2,7 +2,7 @@ from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from libs.login import login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
@@ -10,6 +10,7 @@ from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source")
|
||||
class ApiKeyAuthDataSource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -33,6 +34,7 @@ class ApiKeyAuthDataSource(Resource):
|
||||
return {"sources": []}
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source/binding")
|
||||
class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -54,6 +56,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source/<uuid:binding_id>")
|
||||
class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -66,8 +69,3 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
|
||||
api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
|
||||
api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
@@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource):
|
||||
return {"error": "Invalid code"}, 400
|
||||
try:
|
||||
oauth_provider.get_access_token(code)
|
||||
except requests.HTTPError as e:
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
@@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource):
|
||||
return {"error": "Invalid provider"}, 400
|
||||
try:
|
||||
oauth_provider.sync_data_source(binding_id)
|
||||
except requests.HTTPError as e:
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
EmailCodeError,
|
||||
@@ -25,6 +25,7 @@ from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
|
||||
|
||||
@console_ns.route("/email-register/send-email")
|
||||
class EmailRegisterSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@@ -52,6 +53,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/email-register/validity")
|
||||
class EmailRegisterCheckApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@@ -92,6 +94,7 @@ class EmailRegisterCheckApi(Resource):
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/email-register")
|
||||
class EmailRegisterResetApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@@ -148,8 +151,3 @@ class EmailRegisterResetApi(Resource):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
return account
|
||||
|
||||
|
||||
api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email")
|
||||
api.add_resource(EmailRegisterCheckApi, "/email-register/validity")
|
||||
api.add_resource(EmailRegisterResetApi, "/email-register")
|
||||
|
||||
@@ -221,8 +221,3 @@ class ForgotPasswordResetApi(Resource):
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
|
||||
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
|
||||
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
|
||||
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
|
||||
|
||||
@@ -7,7 +7,7 @@ from flask_restx import Resource, reqparse
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
@@ -34,6 +34,7 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@console_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
"""Resource for user login."""
|
||||
|
||||
@@ -91,6 +92,7 @@ class LoginApi(Resource):
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
|
||||
@console_ns.route("/logout")
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
@@ -102,6 +104,7 @@ class LogoutApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/reset-password")
|
||||
class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@@ -130,6 +133,7 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/email-code-login")
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
@@ -162,6 +166,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/email-code-login/validity")
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
@@ -218,6 +223,7 @@ class EmailCodeLoginApi(Resource):
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
|
||||
@console_ns.route("/refresh-token")
|
||||
class RefreshTokenApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -229,11 +235,3 @@ class RefreshTokenApi(Resource):
|
||||
return {"result": "success", "data": new_token_pair.model_dump()}
|
||||
except Exception as e:
|
||||
return {"result": "fail", "data": str(e)}, 401
|
||||
|
||||
|
||||
api.add_resource(LoginApi, "/login")
|
||||
api.add_resource(LogoutApi, "/logout")
|
||||
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")
|
||||
api.add_resource(RefreshTokenApi, "/refresh-token")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
@@ -101,8 +101,10 @@ class OAuthCallback(Resource):
|
||||
try:
|
||||
token = oauth_provider.get_access_token(code)
|
||||
user_info = oauth_provider.get_user_info(token)
|
||||
except requests.RequestException as e:
|
||||
error_text = e.response.text if e.response else str(e)
|
||||
except httpx.RequestError as e:
|
||||
error_text = str(e)
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
error_text = e.response.text
|
||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from models.account import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
|
||||
from .. import api
|
||||
from .. import console_ns
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -86,6 +86,7 @@ def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProvid
|
||||
return decorated
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider")
|
||||
class OAuthServerAppApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
@@ -108,6 +109,7 @@ class OAuthServerAppApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider/authorize")
|
||||
class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -125,6 +127,7 @@ class OAuthServerUserAuthorizeApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider/token")
|
||||
class OAuthServerUserTokenApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
@@ -180,6 +183,7 @@ class OAuthServerUserTokenApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider/account")
|
||||
class OAuthServerUserAccountApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
@@ -194,9 +198,3 @@ class OAuthServerUserAccountApi(Resource):
|
||||
"timezone": account.timezone,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(OAuthServerAppApi, "/oauth/provider")
|
||||
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
|
||||
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
|
||||
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@console_ns.route("/billing/subscription")
|
||||
class Subscription(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -26,6 +27,7 @@ class Subscription(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/billing/invoices")
|
||||
class Invoices(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -36,7 +38,3 @@ class Invoices(Resource):
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||
|
||||
|
||||
api.add_resource(Subscription, "/billing/subscription")
|
||||
api.add_resource(Invoices, "/billing/invoices")
|
||||
|
||||
@@ -6,10 +6,11 @@ from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import api
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
|
||||
|
||||
@console_ns.route("/compliance/download")
|
||||
class ComplianceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -30,6 +31,3 @@ class ComplianceApi(Resource):
|
||||
ip=ip_address,
|
||||
device_info=device_info,
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(ComplianceApi, "/compliance/download")
|
||||
|
||||
@@ -9,7 +9,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
@@ -27,6 +27,10 @@ from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/data-source/integrates",
|
||||
"/data-source/integrates/<uuid:binding_id>/<string:action>",
|
||||
)
|
||||
class DataSourceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -109,6 +113,7 @@ class DataSourceApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/notion/pre-import/pages")
|
||||
class DataSourceNotionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -196,6 +201,10 @@ class DataSourceNotionListApi(Resource):
|
||||
return {"notion_info": {**workspace_info, "pages": pages}}, 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/datasets/notion-indexing-estimate",
|
||||
)
|
||||
class DataSourceNotionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -269,6 +278,7 @@ class DataSourceNotionApi(Resource):
|
||||
return response.model_dump(), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/notion/sync")
|
||||
class DataSourceNotionDatasetSyncApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -285,6 +295,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync")
|
||||
class DataSourceNotionDocumentSyncApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -301,16 +312,3 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
document_indexing_sync_task.delay(dataset_id_str, document_id_str)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
|
||||
api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
|
||||
api.add_resource(
|
||||
DataSourceNotionApi,
|
||||
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/datasets/notion-indexing-estimate",
|
||||
)
|
||||
api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
|
||||
api.add_resource(
|
||||
DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import flask_restx
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
@@ -30,24 +31,20 @@ from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
from fields.document_fields import document_status_fields
|
||||
from libs.login import login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/datasets")
|
||||
class DatasetListApi(Resource):
|
||||
@api.doc("get_datasets")
|
||||
@@ -92,7 +89,7 @@ class DatasetListApi(Resource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
|
||||
for item in data:
|
||||
# convert embedding_model_provider to plugin standard format
|
||||
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
|
||||
@@ -147,7 +144,7 @@ class DatasetListApi(Resource):
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
@@ -192,7 +189,7 @@ class DatasetListApi(Resource):
|
||||
name=args["name"],
|
||||
description=args["description"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
account=cast(Account, current_user),
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
provider=args["provider"],
|
||||
external_knowledge_api_id=args["external_knowledge_api_id"],
|
||||
@@ -224,7 +221,7 @@ class DatasetApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.embedding_model_provider:
|
||||
provider_id = ModelProviderID(dataset.embedding_model_provider)
|
||||
@@ -288,7 +285,7 @@ class DatasetApi(Resource):
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
|
||||
parser.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
@@ -369,7 +366,7 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||
@@ -688,7 +685,7 @@ class DatasetApiKeyApi(Resource):
|
||||
)
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
flask_restx.abort(
|
||||
api.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code="max_keys_exceeded",
|
||||
@@ -733,7 +730,7 @@ class DatasetApiDeleteApi(Resource):
|
||||
)
|
||||
|
||||
if key is None:
|
||||
flask_restx.abort(404, message="API key not found")
|
||||
api.abort(404, message="API key not found")
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
@@ -782,7 +779,6 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
@@ -809,6 +805,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.TENCENT
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
@@ -838,7 +835,6 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
@@ -863,6 +859,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
|
||||
@@ -4,6 +4,7 @@ from argparse import ArgumentTypeError
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
@@ -54,6 +55,7 @@ from fields.document_fields import (
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.account import Account
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
@@ -211,13 +213,13 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
if sort == "hit_count":
|
||||
sub_query = (
|
||||
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
.group_by(DocumentSegment.document_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
|
||||
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||
sort_logic(Document.position),
|
||||
)
|
||||
elif sort == "created_at":
|
||||
@@ -417,7 +419,9 @@ class DatasetInitApi(Resource):
|
||||
|
||||
try:
|
||||
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
|
||||
tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
knowledge_config=knowledge_config,
|
||||
account=cast(Account, current_user),
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@@ -451,7 +455,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
raise DocumentAlreadyFinishedError()
|
||||
|
||||
data_process_rule = document.dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||
|
||||
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
|
||||
|
||||
@@ -513,7 +517,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if not documents:
|
||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
@@ -752,7 +756,7 @@ class DocumentApi(DocumentResource):
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
@@ -1072,7 +1076,9 @@ class DocumentRenameApi(DocumentResource):
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
@@ -1113,6 +1119,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log")
|
||||
class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -1146,29 +1153,3 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
"input_data": log.input_data,
|
||||
"datasource_node_id": log.datasource_node_id,
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
|
||||
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||
api.add_resource(DatasetInitApi, "/datasets/init")
|
||||
api.add_resource(
|
||||
DocumentIndexingEstimateApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate"
|
||||
)
|
||||
api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
|
||||
api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
|
||||
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
|
||||
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(
|
||||
DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
|
||||
)
|
||||
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
|
||||
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
|
||||
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
|
||||
api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
|
||||
api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
|
||||
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
|
||||
|
||||
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
api.add_resource(
|
||||
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import (
|
||||
ChildChunkDeleteIndexError,
|
||||
@@ -37,6 +37,7 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
class DatasetDocumentSegmentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -139,6 +140,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
|
||||
class DatasetDocumentSegmentApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -193,6 +195,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||
class DatasetDocumentSegmentAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -244,6 +247,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -345,6 +349,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
|
||||
"/datasets/batch_import_status/<uuid:job_id>",
|
||||
)
|
||||
class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -384,7 +392,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, "waiting")
|
||||
batch_create_segment_to_index_task.delay(
|
||||
str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id
|
||||
str(job_id),
|
||||
upload_file_id,
|
||||
dataset_id,
|
||||
document_id,
|
||||
current_user.current_tenant_id,
|
||||
current_user.id,
|
||||
)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 500
|
||||
@@ -393,7 +406,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, job_id):
|
||||
def get(self, job_id=None, dataset_id=None, document_id=None):
|
||||
if job_id is None:
|
||||
raise NotFound("The job does not exist.")
|
||||
job_id = str(job_id)
|
||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
@@ -403,6 +418,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
|
||||
class ChildChunkAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -457,7 +473,8 @@ class ChildChunkAddApi(Resource):
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
|
||||
content = args["content"]
|
||||
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
@@ -546,13 +563,17 @@ class ChildChunkAddApi(Resource):
|
||||
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
|
||||
chunks_data = args["chunks"]
|
||||
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
|
||||
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>"
|
||||
)
|
||||
class ChildChunkUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -660,33 +681,8 @@ class ChildChunkUpdateApi(Resource):
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
child_chunk = SegmentService.update_child_chunk(
|
||||
args.get("content"), child_chunk, segment, document, dataset
|
||||
)
|
||||
content = args["content"]
|
||||
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
|
||||
)
|
||||
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentUpdateApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentBatchImportApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
|
||||
"/datasets/batch_import_status/<uuid:job_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
ChildChunkAddApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
|
||||
)
|
||||
api.add_resource(
|
||||
ChildChunkUpdateApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, reqparse
|
||||
@@ -9,13 +11,14 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 100:
|
||||
raise ValueError("Name must be between 1 to 100 characters.")
|
||||
return name
|
||||
@@ -274,7 +277,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
response = HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
account=cast(Account, current_user),
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
metadata_filtering_conditions=args["metadata_filtering_conditions"],
|
||||
)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
import services
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
@@ -20,6 +21,7 @@ from core.errors.error import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
@@ -59,7 +61,7 @@ class DatasetsHitTestingBase:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
account=cast(Account, current_user),
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from libs.login import login_required
|
||||
@@ -16,6 +16,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
class DatasetMetadataCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -50,6 +51,7 @@ class DatasetMetadataCreateApi(Resource):
|
||||
return MetadataService.get_dataset_metadatas(dataset), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
class DatasetMetadataApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -60,6 +62,7 @@ class DatasetMetadataApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
name = args["name"]
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
@@ -68,7 +71,7 @@ class DatasetMetadataApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
|
||||
return metadata, 200
|
||||
|
||||
@setup_required
|
||||
@@ -87,6 +90,7 @@ class DatasetMetadataApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/metadata/built-in")
|
||||
class DatasetMetadataBuiltInFieldApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -97,6 +101,7 @@ class DatasetMetadataBuiltInFieldApi(Resource):
|
||||
return {"fields": built_in_fields}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -116,6 +121,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
class DocumentMetadataEditApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -136,10 +142,3 @@ class DocumentMetadataEditApi(Resource):
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
|
||||
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
|
||||
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
|
||||
@@ -5,7 +5,7 @@ from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
@@ -19,6 +19,7 @@ from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
|
||||
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -68,6 +69,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/callback")
|
||||
class DatasourceOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self, provider_id: str):
|
||||
@@ -123,6 +125,7 @@ class DatasourceOAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||
class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -165,6 +168,7 @@ class DatasourceAuth(Resource):
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -188,6 +192,7 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -213,6 +218,7 @@ class DatasourceAuthUpdateApi(Resource):
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/list")
|
||||
class DatasourceAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -225,6 +231,7 @@ class DatasourceAuthListApi(Resource):
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/default-list")
|
||||
class DatasourceHardCodeAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -237,6 +244,7 @@ class DatasourceHardCodeAuthListApi(Resource):
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -271,6 +279,7 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -291,6 +300,7 @@ class DatasourceAuthDefaultApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -311,52 +321,3 @@ class DatasourceUpdateProviderNameApi(Resource):
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DatasourcePluginOAuthAuthorizationUrl,
|
||||
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceOAuthCallback,
|
||||
"/oauth/plugin/<path:provider_id>/datasource/callback",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceAuth,
|
||||
"/auth/plugin/datasource/<path:provider_id>",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthUpdateApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthDeleteApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/delete",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthListApi,
|
||||
"/auth/plugin/datasource/list",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceHardCodeAuthListApi,
|
||||
"/auth/plugin/datasource/default-list",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthOauthCustomClient,
|
||||
"/auth/plugin/datasource/<path:provider_id>/custom-client",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthDefaultApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/default",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceUpdateProviderNameApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update-name",
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
|
||||
)
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
@@ -13,6 +13,7 @@ from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -49,9 +50,3 @@ class DataSourceContentPreviewApi(Resource):
|
||||
credential_id=args.get("credential_id"),
|
||||
)
|
||||
return preview_content, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DataSourceContentPreviewApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
@@ -20,18 +20,19 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
def _validate_description_length(description: str) -> str:
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates")
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -45,6 +46,7 @@ class PipelineTemplateListApi(Resource):
|
||||
return pipeline_templates, 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates/<string:template_id>")
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -57,6 +59,7 @@ class PipelineTemplateDetailApi(Resource):
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -73,7 +76,7 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
@@ -112,6 +115,7 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
return {"data": template.yaml_content}, 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -129,7 +133,7 @@ class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
@@ -144,21 +148,3 @@ class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
PipelineTemplateListApi,
|
||||
"/rag/pipeline/templates",
|
||||
)
|
||||
api.add_resource(
|
||||
PipelineTemplateDetailApi,
|
||||
"/rag/pipeline/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
CustomizedPipelineTemplateApi,
|
||||
"/rag/pipeline/customized/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/customized/publish",
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from flask_login import current_user # type: ignore # type: ignore
|
||||
from flask_restx import Resource, marshal, reqparse # type: ignore
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@@ -20,18 +20,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/dataset")
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -84,6 +73,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
return import_info, 201
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/empty-dataset")
|
||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -108,7 +98,3 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
),
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
|
||||
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
import logging
|
||||
from typing import Any, NoReturn
|
||||
from typing import NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
|
||||
)
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
@@ -34,32 +32,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
@@ -104,13 +76,14 @@ def _api_prerequisite(f):
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args, **kwargs):
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables")
|
||||
class RagPipelineVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
@@ -168,6 +141,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
return None
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@@ -190,6 +164,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
class RagPipelineVariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
@@ -284,6 +259,7 @@ class RagPipelineVariableApi(Resource):
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class RagPipelineVariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, pipeline: Pipeline, variable_id: str):
|
||||
@@ -325,6 +301,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
|
||||
return draft_vars
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables")
|
||||
class RagPipelineSystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@@ -332,6 +309,7 @@ class RagPipelineSystemVariableCollectionApi(Resource):
|
||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables")
|
||||
class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, pipeline: Pipeline):
|
||||
@@ -364,26 +342,3 @@ class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineNodeVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineEnvironmentVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@@ -20,6 +20,7 @@ from services.app_dsl_service import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports")
|
||||
class RagPipelineImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -66,6 +67,7 @@ class RagPipelineImportApi(Resource):
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
|
||||
class RagPipelineImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -90,6 +92,7 @@ class RagPipelineImportConfirmApi(Resource):
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:pipeline_id>/check-dependencies")
|
||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -107,6 +110,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/exports")
|
||||
class RagPipelineExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -128,22 +132,3 @@ class RagPipelineExportApi(Resource):
|
||||
)
|
||||
|
||||
return {"data": result}, 200
|
||||
|
||||
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
RagPipelineImportApi,
|
||||
"/rag/pipelines/imports",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportConfirmApi,
|
||||
"/rag/pipelines/imports/<string:import_id>/confirm",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineExportApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/exports",
|
||||
)
|
||||
|
||||
@@ -9,8 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
DraftWorkflowNotExist,
|
||||
@@ -51,6 +50,7 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
|
||||
class DraftRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -148,6 +148,7 @@ class DraftRagPipelineApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -182,6 +183,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -216,6 +218,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
|
||||
class DraftRagPipelineRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -250,6 +253,7 @@ class DraftRagPipelineRunApi(Resource):
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
|
||||
class PublishedRagPipelineRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -370,6 +374,7 @@ class PublishedRagPipelineRunApi(Resource):
|
||||
#
|
||||
# return result
|
||||
#
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -412,6 +417,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -454,6 +460,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -487,6 +494,7 @@ class RagPipelineDraftNodeRunApi(Resource):
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
class RagPipelineTaskStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -505,6 +513,7 @@ class RagPipelineTaskStopApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/publish")
|
||||
class PublishedRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -560,6 +569,7 @@ class PublishedRagPipelineApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs")
|
||||
class DefaultRagPipelineBlockConfigsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -578,6 +588,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
|
||||
return rag_pipeline_service.get_default_block_configs()
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -609,18 +620,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
class RagPipelineConfigApi(Resource):
|
||||
"""Resource for rag pipeline configuration."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, pipeline_id):
|
||||
return {
|
||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
|
||||
class PublishedAllRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -669,6 +669,7 @@ class PublishedAllRagPipelineApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
||||
class RagPipelineByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -726,6 +727,7 @@ class RagPipelineByIdApi(Resource):
|
||||
return workflow
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
|
||||
class PublishedRagPipelineSecondStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -751,6 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
|
||||
class PublishedRagPipelineFirstStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -776,6 +779,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
|
||||
class DraftRagPipelineFirstStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -801,6 +805,7 @@ class DraftRagPipelineFirstStepApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
|
||||
class DraftRagPipelineSecondStepApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -827,6 +832,7 @@ class DraftRagPipelineSecondStepApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
|
||||
class RagPipelineWorkflowRunListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -848,6 +854,7 @@ class RagPipelineWorkflowRunListApi(Resource):
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>")
|
||||
class RagPipelineWorkflowRunDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -866,6 +873,7 @@ class RagPipelineWorkflowRunDetailApi(Resource):
|
||||
return workflow_run
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -889,6 +897,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
||||
return {"data": node_executions}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/datasource-plugins")
|
||||
class DatasourceListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -904,6 +913,7 @@ class DatasourceListApi(Resource):
|
||||
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||
class RagPipelineWorkflowLastRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -925,6 +935,7 @@ class RagPipelineWorkflowLastRunApi(Resource):
|
||||
return node_exec
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/transform/datasets/<uuid:dataset_id>")
|
||||
class RagPipelineTransformApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -942,6 +953,7 @@ class RagPipelineTransformApi(Resource):
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
|
||||
class RagPipelineDatasourceVariableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -971,6 +983,7 @@ class RagPipelineDatasourceVariableApi(Resource):
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/recommended-plugins")
|
||||
class RagPipelineRecommendedPluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -979,118 +992,3 @@ class RagPipelineRecommendedPluginApi(Resource):
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins()
|
||||
return recommended_plugins
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DraftRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineConfigApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/config",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftRagPipelineRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedRagPipelineRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineTaskStopApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineDraftNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelinePublishedDatasourceNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineDraftDatasourceNodeRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineDraftRunIterationNodeApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineDraftRunLoopNodeApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
PublishedRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/publish",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedAllRagPipelineApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultRagPipelineBlockConfigsApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultRagPipelineBlockConfigApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineByIdApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineWorkflowRunListApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineWorkflowRunDetailApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineWorkflowRunNodeExecutionListApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceListApi,
|
||||
"/rag/pipelines/datasource-plugins",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedRagPipelineSecondStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedRagPipelineFirstStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftRagPipelineSecondStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftRagPipelineFirstStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineWorkflowLastRunApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineTransformApi,
|
||||
"/rag/pipelines/transform/datasets/<uuid:dataset_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineDatasourceVariableApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineRecommendedPluginApi,
|
||||
"/rag/pipelines/recommended-plugins",
|
||||
)
|
||||
|
||||
@@ -26,9 +26,15 @@ from services.errors.audio import (
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/audio-to-text",
|
||||
endpoint="installed_app_audio",
|
||||
)
|
||||
class ChatAudioApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
@@ -65,6 +71,10 @@ class ChatAudioApi(InstalledAppResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/text-to-audio",
|
||||
endpoint="installed_app_text",
|
||||
)
|
||||
class ChatTextApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
from flask_restx import reqparse
|
||||
|
||||
@@ -33,10 +33,16 @@ from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# define completion api for user
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/completion-messages",
|
||||
endpoint="installed_app_completion",
|
||||
)
|
||||
class CompletionApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
@@ -87,6 +93,10 @@ class CompletionApi(InstalledAppResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
|
||||
endpoint="installed_app_stop_completion",
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
@@ -100,6 +110,10 @@ class CompletionStopApi(InstalledAppResource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/chat-messages",
|
||||
endpoint="installed_app_chat_completion",
|
||||
)
|
||||
class ChatApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
@@ -153,6 +167,10 @@ class ChatApi(InstalledAppResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
|
||||
endpoint="installed_app_stop_chat_completion",
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
|
||||
@@ -16,7 +16,13 @@ from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations",
|
||||
endpoint="installed_app_conversations",
|
||||
)
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, installed_app):
|
||||
@@ -52,6 +58,10 @@ class ConversationListApi(InstalledAppResource):
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
|
||||
endpoint="installed_app_conversation",
|
||||
)
|
||||
class ConversationApi(InstalledAppResource):
|
||||
def delete(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
@@ -70,6 +80,10 @@ class ConversationApi(InstalledAppResource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
|
||||
endpoint="installed_app_conversation_rename",
|
||||
)
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, installed_app, c_id):
|
||||
@@ -95,6 +109,10 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
|
||||
endpoint="installed_app_conversation_pin",
|
||||
)
|
||||
class ConversationPinApi(InstalledAppResource):
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
@@ -114,6 +132,10 @@ class ConversationPinApi(InstalledAppResource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
|
||||
endpoint="installed_app_conversation_unpin",
|
||||
)
|
||||
class ConversationUnPinApi(InstalledAppResource):
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
|
||||
@@ -36,9 +36,15 @@ from services.errors.message import (
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/messages",
|
||||
endpoint="installed_app_messages",
|
||||
)
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, installed_app):
|
||||
@@ -66,6 +72,10 @@ class MessageListApi(InstalledAppResource):
|
||||
raise NotFound("First Message Not Exists.")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
|
||||
endpoint="installed_app_message_feedback",
|
||||
)
|
||||
class MessageFeedbackApi(InstalledAppResource):
|
||||
def post(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
@@ -93,6 +103,10 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
|
||||
endpoint="installed_app_more_like_this",
|
||||
)
|
||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
def get(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
@@ -139,6 +153,10 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="installed_app_suggested_question",
|
||||
)
|
||||
class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
def get(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
|
||||
@@ -27,9 +27,12 @@ from models.model import AppMode, InstalledApp
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
"""
|
||||
@@ -70,6 +73,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop")
|
||||
class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
"""
|
||||
|
||||
@@ -26,9 +26,12 @@ from libs.login import login_required
|
||||
from models import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
from . import console_ns
|
||||
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
@console_ns.route("/files/upload")
|
||||
class FileApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -88,6 +91,7 @@ class FileApi(Resource):
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
@console_ns.route("/files/<uuid:file_id>/preview")
|
||||
class FilePreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -98,6 +102,7 @@ class FilePreviewApi(Resource):
|
||||
return {"content": text}
|
||||
|
||||
|
||||
@console_ns.route("/files/support-type")
|
||||
class FileSupportTypeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@@ -19,7 +19,10 @@ from fields.file_fields import file_fields_with_signed_url, remote_file_info_fie
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
from . import console_ns
|
||||
|
||||
|
||||
@console_ns.route("/remote-files/<path:url>")
|
||||
class RemoteFileInfoApi(Resource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
def get(self, url):
|
||||
@@ -35,6 +38,7 @@ class RemoteFileInfoApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/remote-files/upload")
|
||||
class RemoteFileUploadApi(Resource):
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self):
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
@@ -10,9 +9,12 @@ from controllers.console.wraps import (
|
||||
from core.schemas.schema_manager import SchemaManager
|
||||
from libs.login import login_required
|
||||
|
||||
from . import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/spec/schema-definitions")
|
||||
class SpecSchemaDefinitionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -30,6 +32,3 @@ class SpecSchemaDefinitionsApi(Resource):
|
||||
logger.exception("Failed to get schema definitions from local registry")
|
||||
# Return empty array as fallback
|
||||
return [], 200
|
||||
|
||||
|
||||
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")
|
||||
|
||||
@@ -3,7 +3,7 @@ from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import login_required
|
||||
@@ -17,6 +17,7 @@ def _validate_name(name):
|
||||
return name
|
||||
|
||||
|
||||
@console_ns.route("/tags")
|
||||
class TagListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -52,6 +53,7 @@ class TagListApi(Resource):
|
||||
return response, 200
|
||||
|
||||
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -89,6 +91,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
return 204
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -114,6 +117,7 @@ class TagBindingCreateApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -133,9 +137,3 @@ class TagBindingDeleteApi(Resource):
|
||||
TagService.delete_tag_binding(args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(TagListApi, "/tags")
|
||||
api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>")
|
||||
api.add_resource(TagBindingCreateApi, "/tag-bindings/create")
|
||||
api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from packaging import version
|
||||
|
||||
@@ -57,7 +57,11 @@ class VersionApi(Resource):
|
||||
return result
|
||||
|
||||
try:
|
||||
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10))
|
||||
response = httpx.get(
|
||||
check_update_url,
|
||||
params={"current_version": args["current_version"]},
|
||||
timeout=httpx.Timeout(connect=3, read=10),
|
||||
)
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
result["version"] = args["current_version"]
|
||||
|
||||
@@ -516,20 +516,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
parser.add_argument("provider", type=str, required=True, location="args")
|
||||
parser.add_argument("action", type=str, required=True, location="args")
|
||||
parser.add_argument("parameter", type=str, required=True, location="args")
|
||||
parser.add_argument("credential_id", type=str, required=False, location="args")
|
||||
parser.add_argument("provider_type", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=args["plugin_id"],
|
||||
provider=args["provider"],
|
||||
action=args["action"],
|
||||
parameter=args["parameter"],
|
||||
credential_id=args["credential_id"],
|
||||
provider_type=args["provider_type"],
|
||||
tenant_id,
|
||||
user_id,
|
||||
args["plugin_id"],
|
||||
args["provider"],
|
||||
args["action"],
|
||||
args["parameter"],
|
||||
args["provider_type"],
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
@@ -21,8 +21,8 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from models.provider_ids import ToolProviderID
|
||||
|
||||
@@ -1,589 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.trigger.entities.entities import SubscriptionBuilderUpdater
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""List all trigger providers for the current tenant"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
|
||||
|
||||
|
||||
class TriggerProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Get info for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
|
||||
)
|
||||
|
||||
|
||||
class TriggerSubscriptionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""List all trigger subscriptions for the current tenant's provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.list_trigger_provider_subscriptions(
|
||||
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error listing trigger providers", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Add a new subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
credential_type=credential_type,
|
||||
)
|
||||
return jsonable_encoder({"subscription_builder": subscription_builder})
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error adding provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
"""Get a subscription instance for a trigger provider"""
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
|
||||
)
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Verify a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
credentials=args.get("credentials", None),
|
||||
),
|
||||
)
|
||||
return TriggerSubscriptionBuilderService.verify_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error verifying provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
credentials=args.get("credentials", None),
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error updating provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
"""Get the request logs for a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
|
||||
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
|
||||
except Exception as e:
|
||||
logger.exception("Error getting request logs for subscription builder", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Build a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
),
|
||||
)
|
||||
TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
)
|
||||
return 200
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error building provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, subscription_id):
|
||||
"""Delete a subscription instance"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
# Delete trigger provider subscription
|
||||
TriggerProviderService.delete_trigger_provider(
|
||||
session=session,
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
# Delete plugin triggers
|
||||
WorkflowPluginTriggerService.delete_plugin_trigger_by_subscription(
|
||||
session=session,
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error deleting provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerOAuthAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Initiate OAuth authorization flow for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("No OAuth client configuration found for this trigger provider")
|
||||
|
||||
# Create subscription builder
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=provider_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
)
|
||||
|
||||
# Create OAuth handler and proxy context
|
||||
oauth_handler = OAuthHandler()
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
extra_data={
|
||||
"subscription_builder_id": subscription_builder.id,
|
||||
},
|
||||
)
|
||||
|
||||
# Build redirect URI for callback
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
|
||||
# Get authorization URL
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
)
|
||||
|
||||
# Create response with cookie
|
||||
response = make_response(
|
||||
jsonable_encoder(
|
||||
{
|
||||
"authorization_url": authorization_url_response.authorization_url,
|
||||
"subscription_builder_id": subscription_builder.id,
|
||||
"subscription_builder": subscription_builder,
|
||||
}
|
||||
)
|
||||
)
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
context_id,
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
max_age=OAuthProxyService.__MAX_AGE__,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error initiating OAuth flow", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerOAuthCallbackApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
"""Handle OAuth callback for trigger provider"""
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
|
||||
# Use and validate proxy context
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
# Parse provider ID
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
user_id = context.get("user_id")
|
||||
tenant_id = context.get("tenant_id")
|
||||
subscription_builder_id = context.get("subscription_builder_id")
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("No OAuth client configuration found for this trigger provider")
|
||||
|
||||
# Get OAuth credentials from callback
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
|
||||
credentials_response = oauth_handler.get_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
|
||||
credentials = credentials_response.credentials
|
||||
expires_at = credentials_response.expires_at
|
||||
|
||||
if not credentials:
|
||||
raise Exception("Failed to get OAuth credentials")
|
||||
|
||||
# Update subscription builder
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
credentials=credentials,
|
||||
credential_expires_at=expires_at,
|
||||
),
|
||||
)
|
||||
# Redirect to OAuth callback page
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
class TriggerOAuthClientManageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Get OAuth client configuration for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
# Get custom OAuth client params if exists
|
||||
custom_params = TriggerProviderService.get_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
# Check if custom client is enabled
|
||||
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
# Check if there's a system OAuth client
|
||||
system_client = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"configured": bool(custom_params or system_client),
|
||||
"oauth_client_schema": provider_controller.get_oauth_client_schema(),
|
||||
"custom_configured": bool(custom_params),
|
||||
"custom_enabled": is_custom_enabled,
|
||||
"redirect_uri": redirect_uri,
|
||||
"params": custom_params or {},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Configure custom OAuth client for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
return TriggerProviderService.save_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
client_params=args.get("client_params"),
|
||||
enabled=args.get("enabled"),
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error configuring OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
"""Remove custom OAuth client configuration"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
return TriggerProviderService.delete_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error removing OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
# Trigger Subscription
|
||||
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
|
||||
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
|
||||
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
|
||||
api.add_resource(
|
||||
TriggerSubscriptionDeleteApi,
|
||||
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
|
||||
)
|
||||
|
||||
# Trigger Subscription Builder
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderCreateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderGetApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderUpdateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderVerifyApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderBuildApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderLogsApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
|
||||
)
|
||||
|
||||
|
||||
# OAuth
|
||||
api.add_resource(
|
||||
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
|
||||
)
|
||||
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")
|
||||
@@ -24,20 +24,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
NOTE: user_id is not trusted, it could be maliciously set to any value.
|
||||
As a result, it could only be considered as an end user id.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
user_model = None
|
||||
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not user_model:
|
||||
if is_anonymous:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
@@ -46,11 +40,21 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
|
||||
is_anonymous=is_anonymous,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(user_model)
|
||||
|
||||
@@ -9,9 +9,10 @@ from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import mcp_ns
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types as mcp_types
|
||||
from core.mcp.server.streamable_http import handle_mcp_request
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
|
||||
|
||||
class MCPRequestError(Exception):
|
||||
@@ -194,6 +195,50 @@ class MCPAppApi(Resource):
|
||||
except ValidationError as e:
|
||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
|
||||
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
|
||||
response = mcp_server_handler.handle()
|
||||
return helper.compact_generate_response(response)
|
||||
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
|
||||
"""Get end user from existing session - optimized query"""
|
||||
return (
|
||||
session.query(EndUser)
|
||||
.where(EndUser.tenant_id == tenant_id)
|
||||
.where(EndUser.session_id == mcp_server_id)
|
||||
.where(EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
def _create_end_user(
|
||||
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
|
||||
) -> EndUser:
|
||||
"""Create end user in existing session"""
|
||||
end_user = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
type="mcp",
|
||||
name=client_name,
|
||||
session_id=mcp_server_id,
|
||||
)
|
||||
session.add(end_user)
|
||||
session.flush() # Use flush instead of commit to keep transaction open
|
||||
session.refresh(end_user)
|
||||
return end_user
|
||||
|
||||
def _handle_mcp_request(
|
||||
self,
|
||||
app: App,
|
||||
mcp_server: AppMCPServer,
|
||||
mcp_request: mcp_types.ClientRequest,
|
||||
user_input_form: list[VariableEntity],
|
||||
session: Session,
|
||||
request_id: Union[int, str],
|
||||
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
|
||||
"""Handle MCP request and return response"""
|
||||
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
|
||||
|
||||
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
|
||||
client_info = mcp_request.root.params.clientInfo
|
||||
client_name = f"{client_info.name}@{client_info.version}"
|
||||
# Commit the session before creating end user to avoid transaction conflicts
|
||||
session.commit()
|
||||
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
|
||||
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
|
||||
|
||||
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
import services
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
from controllers.service_api.wraps import (
|
||||
@@ -17,6 +17,7 @@ from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from libs.validators import validate_description_length
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
@@ -31,12 +32,6 @@ def _validate_name(name):
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
# Define parsers for dataset operations
|
||||
dataset_create_parser = reqparse.RequestParser()
|
||||
dataset_create_parser.add_argument(
|
||||
@@ -48,7 +43,7 @@ dataset_create_parser.add_argument(
|
||||
)
|
||||
dataset_create_parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
@@ -101,7 +96,7 @@ dataset_update_parser.add_argument(
|
||||
type=_validate_name,
|
||||
)
|
||||
dataset_update_parser.add_argument(
|
||||
"description", location="json", store_missing=False, type=_validate_description_length
|
||||
"description", location="json", store_missing=False, type=validate_description_length
|
||||
)
|
||||
dataset_update_parser.add_argument(
|
||||
"indexing_technique",
|
||||
@@ -254,19 +249,21 @@ class DatasetListApi(DatasetApiResource):
|
||||
"""Resource for creating datasets."""
|
||||
args = dataset_create_parser.parse_args()
|
||||
|
||||
if args.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
||||
)
|
||||
embedding_model_provider = args.get("embedding_model_provider")
|
||||
embedding_model = args.get("embedding_model")
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -317,7 +314,7 @@ class DatasetApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
assert isinstance(current_user, Account)
|
||||
@@ -331,8 +328,8 @@ class DatasetApi(DatasetApiResource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
if data["indexing_technique"] == "high_quality":
|
||||
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
|
||||
if data.get("indexing_technique") == "high_quality":
|
||||
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
|
||||
if item_model in model_names:
|
||||
data["embedding_available"] = True
|
||||
else:
|
||||
@@ -341,7 +338,9 @@ class DatasetApi(DatasetApiResource):
|
||||
data["embedding_available"] = True
|
||||
|
||||
# force update search method to keyword_search if indexing_technique is economic
|
||||
data["retrieval_model_dict"]["search_method"] = "keyword_search"
|
||||
retrieval_model_dict = data.get("retrieval_model_dict")
|
||||
if retrieval_model_dict:
|
||||
retrieval_model_dict["search_method"] = "keyword_search"
|
||||
|
||||
if data.get("permission") == "partial_members":
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
@@ -372,19 +371,24 @@ class DatasetApi(DatasetApiResource):
|
||||
data = request.get_json()
|
||||
|
||||
# check embedding model setting
|
||||
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
|
||||
)
|
||||
embedding_model_provider = data.get("embedding_model_provider")
|
||||
embedding_model = data.get("embedding_model")
|
||||
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, embedding_model_provider, embedding_model
|
||||
)
|
||||
|
||||
retrieval_model = data.get("retrieval_model")
|
||||
if (
|
||||
data.get("retrieval_model")
|
||||
and data.get("retrieval_model").get("reranking_model")
|
||||
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
dataset.tenant_id,
|
||||
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
@@ -397,7 +401,7 @@ class DatasetApi(DatasetApiResource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
@@ -591,9 +595,10 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
args = tag_update_parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.update_tags(args, args.get("tag_id"))
|
||||
tag_id = args["tag_id"]
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
@@ -616,7 +621,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
args = tag_delete_parser.parse_args()
|
||||
TagService.delete_tag(args.get("tag_id"))
|
||||
TagService.delete_tag(args["tag_id"])
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ from extensions.ext_database import db
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.model import EndUser
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
from services.file_service import FileService
|
||||
@@ -109,19 +108,21 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
if text is None or name is None:
|
||||
raise ValueError("Both 'text' and 'name' must be non-null values.")
|
||||
|
||||
if args.get("embedding_model_provider"):
|
||||
DatasetService.check_embedding_model_setting(
|
||||
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
|
||||
)
|
||||
embedding_model_provider = args.get("embedding_model_provider")
|
||||
embedding_model = args.get("embedding_model")
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
@@ -188,15 +189,16 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
if (
|
||||
args.get("retrieval_model")
|
||||
and args.get("retrieval_model").get("reranking_model")
|
||||
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
@@ -311,8 +313,6 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not isinstance(current_user, EndUser):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
@@ -406,9 +406,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
if not isinstance(current_user, EndUser):
|
||||
raise ValueError("Invalid user account")
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
|
||||
@@ -106,7 +106,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
|
||||
return marshal(metadata, dataset_metadata_fields), 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset_metadata")
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from flask import Blueprint
|
||||
|
||||
# Create trigger blueprint
|
||||
bp = Blueprint("trigger", __name__, url_prefix="/triggers")
|
||||
|
||||
# Import routes after blueprint creation to avoid circular imports
|
||||
from . import trigger, webhook
|
||||
@@ -1,41 +0,0 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
from flask import jsonify, request
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.trigger import bp
|
||||
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||
from services.trigger_service import TriggerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
|
||||
UUID_MATCHER = re.compile(UUID_PATTERN)
|
||||
|
||||
|
||||
@bp.route("/plugin/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def trigger_endpoint(endpoint_id: str):
|
||||
"""
|
||||
Handle endpoint trigger calls.
|
||||
"""
|
||||
# endpoint_id must be UUID
|
||||
if not UUID_MATCHER.match(endpoint_id):
|
||||
raise NotFound("Invalid endpoint ID")
|
||||
handling_chain = [
|
||||
TriggerService.process_endpoint,
|
||||
TriggerSubscriptionBuilderService.process_builder_validation_endpoint,
|
||||
]
|
||||
try:
|
||||
for handler in handling_chain:
|
||||
response = handler(endpoint_id, request)
|
||||
if response:
|
||||
break
|
||||
if not response:
|
||||
raise NotFound("Endpoint not found")
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Webhook processing failed for {endpoint_id}")
|
||||
return jsonify({"error": "Internal server error", "message": str(e)}), 500
|
||||
@@ -1,46 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask import jsonify
|
||||
from werkzeug.exceptions import NotFound, RequestEntityTooLarge
|
||||
|
||||
from controllers.trigger import bp
|
||||
from services.webhook_service import WebhookService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@bp.route("/webhook/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def handle_webhook(webhook_id: str):
|
||||
"""
|
||||
Handle webhook trigger calls.
|
||||
|
||||
This endpoint receives webhook calls and processes them according to the
|
||||
configured webhook trigger settings.
|
||||
"""
|
||||
try:
|
||||
# Get webhook trigger, workflow, and node configuration
|
||||
webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(webhook_id)
|
||||
|
||||
# Extract request data
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
# Validate request against node configuration
|
||||
validation_result = WebhookService.validate_webhook_request(webhook_data, node_config)
|
||||
if not validation_result["valid"]:
|
||||
return jsonify({"error": "Bad Request", "message": validation_result["error"]}), 400
|
||||
|
||||
# Process webhook call (send to Celery)
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
|
||||
# Return configured response
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
return jsonify(response_data), status_code
|
||||
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
except RequestEntityTooLarge:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Webhook processing failed for %s", webhook_id)
|
||||
return jsonify({"error": "Internal server error", "message": str(e)}), 500
|
||||
@@ -261,6 +261,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
|
||||
)
|
||||
# questions is a list of strings, not a list of Message objects
|
||||
# so we can directly return it
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
|
||||
@@ -79,29 +79,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
@@ -551,7 +551,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
total_steps=validated_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
exceptions_count=event.exceptions_count,
|
||||
conversation_id=None,
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
|
||||
)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import logging
|
||||
import queue
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from enum import IntEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
from configs import dify_config
|
||||
@@ -18,6 +20,8 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PublishFrom(IntEnum):
|
||||
APPLICATION_MANAGER = auto()
|
||||
@@ -35,9 +39,8 @@ class AppQueueManager:
|
||||
self.invoke_from = invoke_from # Public accessor for invoke_from
|
||||
|
||||
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||
redis_client.setex(
|
||||
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
|
||||
)
|
||||
self._task_belong_cache_key = AppQueueManager._generate_task_belong_cache_key(self._task_id)
|
||||
redis_client.setex(self._task_belong_cache_key, 1800, f"{user_prefix}-{self._user_id}")
|
||||
|
||||
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
||||
|
||||
@@ -79,9 +82,21 @@ class AppQueueManager:
|
||||
Stop listen to queue
|
||||
:return:
|
||||
"""
|
||||
self._clear_task_belong_cache()
|
||||
self._q.put(None)
|
||||
|
||||
def publish_error(self, e, pub_from: PublishFrom):
|
||||
def _clear_task_belong_cache(self) -> None:
|
||||
"""
|
||||
Remove the task belong cache key once listening is finished.
|
||||
"""
|
||||
try:
|
||||
redis_client.delete(self._task_belong_cache_key)
|
||||
except RedisError:
|
||||
logger.exception(
|
||||
"Failed to clear task belong cache for task %s (key: %s)", self._task_id, self._task_belong_cache_key
|
||||
)
|
||||
|
||||
def publish_error(self, e, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish error
|
||||
:param e: error
|
||||
|
||||
@@ -427,6 +427,9 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
@@ -465,6 +468,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
context=contextvars.copy_context(),
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
@@ -559,6 +563,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
context=contextvars.copy_context(),
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
|
||||
@@ -86,29 +86,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
db.session.close()
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single loop run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -53,8 +53,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@@ -68,8 +66,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@@ -83,8 +79,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
@@ -97,8 +91,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
@@ -127,26 +119,24 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
app_id=app_model.id,
|
||||
user_id=user.id if isinstance(user, Account) else user.session_id,
|
||||
)
|
||||
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
|
||||
extras = {
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
if triggered_from in (WorkflowRunTriggeredFrom.DEBUGGING, WorkflowRunTriggeredFrom.APP_RUN):
|
||||
# start node get inputs
|
||||
inputs = self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
tenant_id=app_model.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
)
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=inputs,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=app_config.variables,
|
||||
tenant_id=app_model.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
),
|
||||
files=list(system_files),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
@@ -165,10 +155,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
# Create session factory
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
# Create workflow execution(aka workflow run) repository
|
||||
if triggered_from is not None:
|
||||
# Use explicitly provided triggered_from (for async triggers)
|
||||
workflow_triggered_from = triggered_from
|
||||
elif invoke_from == InvokeFrom.DEBUGGER:
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
|
||||
else:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
@@ -195,7 +182,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
root_node_id=root_node_id,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
@@ -210,7 +196,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
streaming: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
@@ -246,7 +231,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
"queue_manager": queue_manager,
|
||||
"context": context,
|
||||
"variable_loader": variable_loader,
|
||||
"root_node_id": root_node_id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -440,16 +424,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
:return:
|
||||
"""
|
||||
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = session.scalar(
|
||||
@@ -482,7 +465,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
root_node_id=root_node_id,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -34,7 +34,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
variable_loader: VariableLoader,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
root_node_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@@ -44,7 +43,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
self._root_node_id = root_node_id
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
@@ -53,30 +51,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
# if only single iteration or single loop run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
@@ -107,7 +87,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
root_node_id=self._root_node_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@@ -79,7 +80,6 @@ class WorkflowBasedAppRunner:
|
||||
workflow_id: str = "",
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
@@ -113,22 +113,88 @@ class WorkflowBasedAppRunner:
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
return graph
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
def _prepare_single_node_execution(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
single_iteration_run: Any | None = None,
|
||||
single_loop_run: Any | None = None,
|
||||
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
||||
"""
|
||||
Prepare graph, variable pool, and runtime state for single node execution
|
||||
(either single iteration or single loop).
|
||||
|
||||
Args:
|
||||
workflow: The workflow instance
|
||||
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
|
||||
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise
|
||||
|
||||
Returns:
|
||||
A tuple containing (graph, variable_pool, graph_runtime_state)
|
||||
|
||||
Raises:
|
||||
ValueError: If neither single_iteration_run nor single_loop_run is specified
|
||||
"""
|
||||
# Create initial runtime state with variable pool containing environment variables
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
),
|
||||
start_at=time.time(),
|
||||
)
|
||||
|
||||
# Determine which type of single node execution and get graph/variable_pool
|
||||
if single_iteration_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_inputs=dict(single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif single_loop_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=single_loop_run.node_id,
|
||||
user_inputs=dict(single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
|
||||
|
||||
# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
|
||||
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
|
||||
return graph, variable_pool, graph_runtime_state
|
||||
|
||||
def _get_graph_and_variable_pool_for_single_node_run(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
|
||||
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
Get graph and variable pool for single node execution (iteration or loop).
|
||||
|
||||
Args:
|
||||
workflow: The workflow instance
|
||||
node_id: The node ID to execute
|
||||
user_inputs: User inputs for the node
|
||||
graph_runtime_state: The graph runtime state
|
||||
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
|
||||
node_type_label: Label for error messages ('iteration' or 'loop')
|
||||
|
||||
Returns:
|
||||
A tuple containing (graph, variable_pool)
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
@@ -146,23 +212,27 @@ class WorkflowBasedAppRunner:
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in iteration
|
||||
# filter nodes only in the specified node type (iteration or loop)
|
||||
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
|
||||
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
|
||||
if node.get("id") == node_id
|
||||
or node.get("data", {}).get(node_type_filter_key, "") == node_id
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in iteration
|
||||
# filter edges only in the specified node type
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
@@ -191,30 +261,26 @@ class WorkflowBasedAppRunner:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
iteration_node_config = None
|
||||
target_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
iteration_node_config = node
|
||||
target_node_config = node
|
||||
break
|
||||
|
||||
if not iteration_node_config:
|
||||
raise ValueError("iteration node id not found in workflow graph")
|
||||
if not target_node_config:
|
||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
|
||||
node_version = iteration_node_config.get("data", {}).get("version", "1")
|
||||
node_type = NodeType(target_node_config.get("data", {}).get("type"))
|
||||
node_version = target_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
# Use the variable pool from graph_runtime_state instead of creating a new one
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=iteration_node_config
|
||||
graph_config=workflow.graph_dict, config=target_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
@@ -235,120 +301,44 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_loop(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in loop
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in loop
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
loop_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
loop_node_config = node
|
||||
break
|
||||
|
||||
if not loop_node_config:
|
||||
raise ValueError("loop node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
|
||||
node_version = loop_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=loop_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
load_into_variable_pool(
|
||||
self._variable_loader,
|
||||
variable_pool=variable_pool,
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
)
|
||||
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle event
|
||||
|
||||
@@ -1,388 +0,0 @@
|
||||
import re
|
||||
import uuid
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
from yaml import YAMLError, safe_load # type: ignore
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
# set description to extra_info
|
||||
extra_info["description"] = openapi["info"].get("description", "")
|
||||
|
||||
if len(openapi["servers"]) == 0:
|
||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||
|
||||
server_url = openapi["servers"][0]["url"]
|
||||
request_env = request.headers.get("X-Request-Env")
|
||||
if request_env:
|
||||
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
|
||||
server_url = matched_servers[0] if matched_servers else server_url
|
||||
|
||||
# list all interfaces
|
||||
interfaces = []
|
||||
for path, path_item in openapi["paths"].items():
|
||||
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||
for method in methods:
|
||||
if method in path_item:
|
||||
interfaces.append(
|
||||
{
|
||||
"path": path,
|
||||
"method": method,
|
||||
"operation": path_item[method],
|
||||
}
|
||||
)
|
||||
|
||||
# get all parameters
|
||||
bundles = []
|
||||
for interface in interfaces:
|
||||
# convert parameters
|
||||
parameters = []
|
||||
if "parameters" in interface["operation"]:
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=parameter.get("required", False),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=parameter.get("description"),
|
||||
default=parameter["schema"]["default"]
|
||||
if "schema" in parameter and "default" in parameter["schema"]
|
||||
else None,
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
|
||||
if typ:
|
||||
tool_parameter.type = typ
|
||||
|
||||
parameters.append(tool_parameter)
|
||||
# create tool bundle
|
||||
# check if there is a request body
|
||||
if "requestBody" in interface["operation"]:
|
||||
request_body = interface["operation"]["requestBody"]
|
||||
if "content" in request_body:
|
||||
for content_type, content in request_body["content"].items():
|
||||
# if there is a reference, get the reference and overwrite the content
|
||||
if "schema" not in content:
|
||||
continue
|
||||
|
||||
if "$ref" in content["schema"]:
|
||||
# get the reference
|
||||
root = openapi
|
||||
reference = content["schema"]["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
# overwrite the content
|
||||
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
|
||||
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]: # pyright: ignore[reportIndexIssue, reportPossiblyUnboundVariable]
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] # pyright: ignore[reportIndexIssue, reportPossiblyUnboundVariable]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
|
||||
parameters.append(tool)
|
||||
|
||||
# check if parameters is duplicated
|
||||
parameters_count = {}
|
||||
for parameter in parameters:
|
||||
if parameter.name not in parameters_count:
|
||||
parameters_count[parameter.name] = 0
|
||||
parameters_count[parameter.name] += 1
|
||||
for name, count in parameters_count.items():
|
||||
if count > 1:
|
||||
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
|
||||
|
||||
# check if there is a operation id, use $path_$method as operation id if not
|
||||
if "operationId" not in interface["operation"]:
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = interface["path"]
|
||||
if interface["path"].startswith("/"):
|
||||
path = interface["path"][1:]
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
|
||||
if not path:
|
||||
path = str(uuid.uuid4())
|
||||
|
||||
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
|
||||
|
||||
bundles.append(
|
||||
ApiToolBundle(
|
||||
server_url=server_url + interface["path"],
|
||||
method=interface["method"],
|
||||
summary=interface["operation"]["description"]
|
||||
if "description" in interface["operation"]
|
||||
else interface["operation"].get("summary", None),
|
||||
operation_id=interface["operation"]["operationId"],
|
||||
parameters=parameters,
|
||||
author="",
|
||||
icon=None,
|
||||
openapi=interface["operation"],
|
||||
)
|
||||
)
|
||||
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
|
||||
parameter = parameter or {}
|
||||
typ: str | None = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
if "type" in parameter:
|
||||
typ = parameter["type"]
|
||||
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||
typ = parameter["schema"]["type"]
|
||||
|
||||
if typ in {"integer", "number"}:
|
||||
return ToolParameter.ToolParameterType.NUMBER
|
||||
elif typ == "boolean":
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
elif typ == "string":
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
elif typ == "array":
|
||||
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
||||
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
yaml: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
:param yaml: the yaml string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
openapi: dict = safe_load(yaml)
|
||||
if openapi is None:
|
||||
raise ToolApiSchemaError("Invalid openapi yaml.")
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
||||
:param swagger: the swagger dict
|
||||
:return: the openapi dict
|
||||
"""
|
||||
# convert swagger to openapi
|
||||
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
|
||||
|
||||
servers = swagger.get("servers", [])
|
||||
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
openapi = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
"description": info.get("description", "Swagger"),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
},
|
||||
"servers": swagger["servers"],
|
||||
"paths": {},
|
||||
"components": {"schemas": {}},
|
||||
}
|
||||
|
||||
# check paths
|
||||
if "paths" not in swagger or len(swagger["paths"]) == 0:
|
||||
raise ToolApiSchemaError("No paths found in the swagger yaml.")
|
||||
|
||||
# convert paths
|
||||
for path, path_item in swagger["paths"].items():
|
||||
openapi["paths"][path] = {} # pyright: ignore[reportIndexIssue]
|
||||
for method, operation in path_item.items():
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
|
||||
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||
"description" not in operation or len(operation["description"]) == 0
|
||||
):
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = { # pyright: ignore[reportIndexIssue]
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": operation.get("parameters", []),
|
||||
"responses": operation.get("responses", {}),
|
||||
}
|
||||
|
||||
if "requestBody" in operation:
|
||||
openapi["paths"][path][method]["requestBody"] = operation["requestBody"] # pyright: ignore[reportIndexIssue]
|
||||
|
||||
# convert definitions
|
||||
for name, definition in swagger["definitions"].items():
|
||||
openapi["components"]["schemas"][name] = definition # pyright: ignore[reportIndexIssue, reportArgumentType]
|
||||
|
||||
return openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(
|
||||
json: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
|
||||
:param json: the json string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
try:
|
||||
openai_plugin = json_loads(json)
|
||||
api = openai_plugin["api"]
|
||||
api_url = api["url"]
|
||||
api_type = api["type"]
|
||||
except JSONDecodeError:
|
||||
raise ToolProviderNotFoundError("Invalid openai plugin json.")
|
||||
|
||||
if api_type != "openapi":
|
||||
raise ToolNotSupportedError("Only openapi is supported now.")
|
||||
|
||||
# get openapi yaml
|
||||
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
|
||||
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
|
||||
response.text, extra_info=extra_info, warning=warning
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(
|
||||
content: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
:param content: the content
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: tools bundle, schema_type
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
content = content.strip()
|
||||
loaded_content = None
|
||||
json_error = None
|
||||
yaml_error = None
|
||||
|
||||
try:
|
||||
loaded_content = json_loads(content)
|
||||
except JSONDecodeError as e:
|
||||
json_error = e
|
||||
|
||||
if loaded_content is None:
|
||||
try:
|
||||
loaded_content = safe_load(content)
|
||||
except YAMLError as e:
|
||||
yaml_error = e
|
||||
if loaded_content is None:
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
|
||||
f" yaml error: {str(yaml_error)}"
|
||||
)
|
||||
|
||||
swagger_error = None
|
||||
openapi_error = None
|
||||
openapi_plugin_error = None
|
||||
schema_type = None
|
||||
|
||||
try:
|
||||
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.OPENAPI.value
|
||||
return openapi, schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
openapi_error = e
|
||||
|
||||
# openai parse error, fallback to swagger
|
||||
try:
|
||||
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.SWAGGER.value
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
converted_swagger, extra_info=extra_info, warning=warning
|
||||
), schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
swagger_error = e
|
||||
|
||||
# swagger parse error, fallback to openai plugin
|
||||
try:
|
||||
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||
json_dumps(loaded_content), extra_info=extra_info, warning=warning
|
||||
)
|
||||
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
|
||||
except ToolNotSupportedError as e:
|
||||
# maybe it's not plugin at all
|
||||
openapi_plugin_error = e
|
||||
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
|
||||
f" openapi plugin error: {str(openapi_plugin_error)}"
|
||||
)
|
||||
@@ -1,17 +0,0 @@
|
||||
import re
|
||||
|
||||
|
||||
def remove_leading_symbols(text: str) -> str:
|
||||
"""
|
||||
Remove leading punctuation or symbols from the given text.
|
||||
|
||||
Args:
|
||||
text (str): The input text to process.
|
||||
|
||||
Returns:
|
||||
str: The text with leading punctuation or symbols removed.
|
||||
"""
|
||||
# Match Unicode ranges for punctuation and symbols
|
||||
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
||||
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
|
||||
return re.sub(pattern, "", text)
|
||||
@@ -1,9 +0,0 @@
|
||||
import uuid
|
||||
|
||||
|
||||
def is_valid_uuid(uuid_str: str) -> bool:
|
||||
try:
|
||||
uuid.UUID(uuid_str)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@@ -1,43 +0,0 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
|
||||
|
||||
if not start_node:
|
||||
return []
|
||||
|
||||
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
):
|
||||
"""
|
||||
check is synced
|
||||
|
||||
raise ValueError if not synced
|
||||
"""
|
||||
variable_names = [variable.variable for variable in variables]
|
||||
|
||||
if len(tool_configurations) != len(variables):
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
@@ -1,35 +0,0 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml # type: ignore
|
||||
from yaml import YAMLError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
|
||||
"""
|
||||
Safe loading a YAML file
|
||||
:param file_path: the path of the YAML file
|
||||
:param ignore_error:
|
||||
if True, return default_value if error occurs and the error will be logged in debug level
|
||||
if False, raise error if error occurs
|
||||
:param default_value: the value returned when errors ignored
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
if not file_path or not Path(file_path).exists():
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, encoding="utf-8") as yaml_file:
|
||||
try:
|
||||
yaml_content = yaml.safe_load(yaml_file)
|
||||
return yaml_content or default_value
|
||||
except Exception as e:
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||
@@ -205,16 +205,10 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Get custom provider record.
|
||||
"""
|
||||
# get provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
stmt = select(Provider).where(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name.in_(provider_names),
|
||||
Provider.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
return session.execute(stmt).scalar_one_or_none()
|
||||
@@ -276,7 +270,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
stmt = select(ProviderCredential.id).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.credential_name == credential_name,
|
||||
)
|
||||
if exclude_id:
|
||||
@@ -324,7 +318,7 @@ class ProviderConfiguration(BaseModel):
|
||||
try:
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.id == credential_id,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
@@ -374,7 +368,7 @@ class ProviderConfiguration(BaseModel):
|
||||
session=session,
|
||||
query_factory=lambda: select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -387,7 +381,7 @@ class ProviderConfiguration(BaseModel):
|
||||
session=session,
|
||||
query_factory=lambda: select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
),
|
||||
@@ -423,6 +417,16 @@ class ProviderConfiguration(BaseModel):
|
||||
logger.warning("Error generating next credential name: %s", str(e))
|
||||
return "API KEY 1"
|
||||
|
||||
def _get_provider_names(self):
|
||||
"""
|
||||
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
|
||||
"""
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
return provider_names
|
||||
|
||||
def create_provider_credential(self, credentials: dict, credential_name: str | None):
|
||||
"""
|
||||
Add custom provider credentials.
|
||||
@@ -501,7 +505,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
# Get the credential record to update
|
||||
@@ -554,7 +558,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# Find all load balancing configs that use this credential_id
|
||||
stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source,
|
||||
)
|
||||
@@ -591,7 +595,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
# Get the credential record to update
|
||||
@@ -602,7 +606,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# Check if this credential is used in load balancing configs
|
||||
lb_stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "provider",
|
||||
)
|
||||
@@ -624,7 +628,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# if this is the last credential, we need to delete the provider record
|
||||
count_stmt = select(func.count(ProviderCredential.id)).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
available_credentials_count = session.execute(count_stmt).scalar() or 0
|
||||
session.delete(credential_record)
|
||||
@@ -668,7 +672,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -737,7 +741,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -784,7 +788,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.credential_name == credential_name,
|
||||
@@ -860,7 +864,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -997,7 +1001,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1042,7 +1046,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1052,7 +1056,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
lb_stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
||||
)
|
||||
@@ -1075,7 +1079,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# if this is the last credential, we need to delete the custom model record
|
||||
count_stmt = select(func.count(ProviderModelCredential.id)).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1115,7 +1119,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1157,7 +1161,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
@@ -1204,15 +1208,9 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Get provider model setting.
|
||||
"""
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
stmt = select(ProviderModelSetting).where(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name.in_(provider_names),
|
||||
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
@@ -1384,15 +1382,9 @@ class ProviderConfiguration(BaseModel):
|
||||
return
|
||||
|
||||
def _switch(s: Session):
|
||||
# get preferred provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
stmt = select(TenantPreferredModelProvider).where(
|
||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.in_(provider_names),
|
||||
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
preferred_model_provider = s.execute(stmt).scalars().first()
|
||||
|
||||
|
||||
@@ -207,7 +207,6 @@ class ProviderConfig(BasicProviderConfig):
|
||||
required: bool = False
|
||||
default: Union[int, str, float, bool] | None = None
|
||||
options: list[Option] | None = None
|
||||
multiple: bool | None = False
|
||||
label: I18nObject | None = None
|
||||
help: I18nObject | None = None
|
||||
url: str | None = None
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import StrEnum
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from httpx import Timeout, post
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from yarl import URL
|
||||
|
||||
@@ -13,9 +13,17 @@ from core.helper.code_executor.javascript.javascript_transformer import NodeJsTe
|
||||
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
|
||||
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))
|
||||
CODE_EXECUTION_SSL_VERIFY = dify_config.CODE_EXECUTION_SSL_VERIFY
|
||||
_CODE_EXECUTOR_CLIENT_LIMITS = httpx.Limits(
|
||||
max_connections=dify_config.CODE_EXECUTION_POOL_MAX_CONNECTIONS,
|
||||
max_keepalive_connections=dify_config.CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS,
|
||||
keepalive_expiry=dify_config.CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY,
|
||||
)
|
||||
_CODE_EXECUTOR_CLIENT_KEY = "code_executor:http_client"
|
||||
|
||||
|
||||
class CodeExecutionError(Exception):
|
||||
@@ -38,6 +46,13 @@ class CodeLanguage(StrEnum):
|
||||
JAVASCRIPT = "javascript"
|
||||
|
||||
|
||||
def _build_code_executor_client() -> httpx.Client:
|
||||
return httpx.Client(
|
||||
verify=CODE_EXECUTION_SSL_VERIFY,
|
||||
limits=_CODE_EXECUTOR_CLIENT_LIMITS,
|
||||
)
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
dependencies_cache: dict[str, str] = {}
|
||||
dependencies_cache_lock = Lock()
|
||||
@@ -76,17 +91,21 @@ class CodeExecutor:
|
||||
"enable_network": True,
|
||||
}
|
||||
|
||||
timeout = httpx.Timeout(
|
||||
connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
|
||||
read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
|
||||
write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
|
||||
pool=None,
|
||||
)
|
||||
|
||||
client = get_pooled_http_client(_CODE_EXECUTOR_CLIENT_KEY, _build_code_executor_client)
|
||||
|
||||
try:
|
||||
response = post(
|
||||
response = client.post(
|
||||
str(url),
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=Timeout(
|
||||
connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
|
||||
read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
|
||||
write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
|
||||
pool=None,
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
if response.status_code == 503:
|
||||
raise CodeExecutionError("Code execution service is unavailable")
|
||||
@@ -106,8 +125,8 @@ class CodeExecutor:
|
||||
|
||||
try:
|
||||
response_data = response.json()
|
||||
except:
|
||||
raise CodeExecutionError("Failed to parse response")
|
||||
except Exception as e:
|
||||
raise CodeExecutionError("Failed to parse response") from e
|
||||
|
||||
if (code := response_data.get("code")) != 0:
|
||||
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")
|
||||
|
||||
59
api/core/helper/http_client_pooling.py
Normal file
59
api/core/helper/http_client_pooling.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""HTTP client pooling utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
|
||||
import httpx
|
||||
|
||||
ClientBuilder = Callable[[], httpx.Client]
|
||||
|
||||
|
||||
class HttpClientPoolFactory:
|
||||
"""Thread-safe factory that maintains reusable HTTP client instances."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._clients: dict[str, httpx.Client] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_or_create(self, key: str, builder: ClientBuilder) -> httpx.Client:
|
||||
"""Return a pooled client associated with ``key`` creating it on demand."""
|
||||
client = self._clients.get(key)
|
||||
if client is not None:
|
||||
return client
|
||||
|
||||
with self._lock:
|
||||
client = self._clients.get(key)
|
||||
if client is None:
|
||||
client = builder()
|
||||
self._clients[key] = client
|
||||
return client
|
||||
|
||||
def close_all(self) -> None:
|
||||
"""Close all pooled clients and clear the pool."""
|
||||
with self._lock:
|
||||
for client in self._clients.values():
|
||||
client.close()
|
||||
self._clients.clear()
|
||||
|
||||
|
||||
_factory = HttpClientPoolFactory()
|
||||
|
||||
|
||||
def get_pooled_http_client(key: str, builder: ClientBuilder) -> httpx.Client:
|
||||
"""Return a pooled client for the given ``key`` using ``builder`` when missing."""
|
||||
return _factory.get_or_create(key, builder)
|
||||
|
||||
|
||||
def close_all_pooled_clients() -> None:
|
||||
"""Close every client created through the pooling factory."""
|
||||
_factory.close_all()
|
||||
|
||||
|
||||
def _register_shutdown_hook() -> None:
|
||||
atexit.register(close_all_pooled_clients)
|
||||
|
||||
|
||||
_register_shutdown_hook()
|
||||
@@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
|
||||
return []
|
||||
|
||||
url = str(marketplace_api_url / "api/v1/plugins/batch")
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids})
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
|
||||
response.raise_for_status()
|
||||
|
||||
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
@@ -36,7 +36,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
||||
return []
|
||||
|
||||
url = str(marketplace_api_url / "api/v1/plugins/batch")
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids})
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
|
||||
response.raise_for_status()
|
||||
result: list[MarketplacePluginDeclaration] = []
|
||||
for plugin in response.json()["data"]["plugins"]:
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
import contextlib
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
|
||||
|
||||
class ProviderConfigCache(Protocol):
|
||||
"""
|
||||
Interface for provider configuration cache operations
|
||||
"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider configuration"""
|
||||
...
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
class ProviderConfigEncrypter:
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_config_cache: ProviderConfigCache
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
config: list[BasicProviderConfig],
|
||||
provider_config_cache: ProviderConfigCache,
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
self.provider_config_cache = provider_config_cache
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.mask_credentials(data)
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cached_credentials = self.provider_config_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
with contextlib.suppress(Exception):
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
|
||||
self.provider_config_cache.set(data)
|
||||
return data
|
||||
|
||||
|
||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
@@ -8,27 +8,23 @@ import time
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
||||
|
||||
http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True
|
||||
try:
|
||||
config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
http_request_node_ssl_verify_lower = str(config_value).lower()
|
||||
if http_request_node_ssl_verify_lower == "true":
|
||||
http_request_node_ssl_verify = True
|
||||
elif http_request_node_ssl_verify_lower == "false":
|
||||
http_request_node_ssl_verify = False
|
||||
else:
|
||||
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
|
||||
except NameError:
|
||||
http_request_node_ssl_verify = True
|
||||
|
||||
BACKOFF_FACTOR = 0.5
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
|
||||
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
|
||||
_SSRF_CLIENT_LIMITS = httpx.Limits(
|
||||
max_connections=dify_config.SSRF_POOL_MAX_CONNECTIONS,
|
||||
max_keepalive_connections=dify_config.SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS,
|
||||
keepalive_expiry=dify_config.SSRF_POOL_KEEPALIVE_EXPIRY,
|
||||
)
|
||||
|
||||
|
||||
class MaxRetriesExceededError(ValueError):
|
||||
"""Raised when the maximum number of retries is exceeded."""
|
||||
@@ -36,6 +32,45 @@ class MaxRetriesExceededError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
|
||||
return {
|
||||
"http://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTP_URL,
|
||||
),
|
||||
"https://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTPS_URL,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _build_ssrf_client(verify: bool) -> httpx.Client:
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
return httpx.Client(
|
||||
proxy=dify_config.SSRF_PROXY_ALL_URL,
|
||||
verify=verify,
|
||||
limits=_SSRF_CLIENT_LIMITS,
|
||||
)
|
||||
|
||||
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
return httpx.Client(
|
||||
mounts=_create_proxy_mounts(),
|
||||
verify=verify,
|
||||
limits=_SSRF_CLIENT_LIMITS,
|
||||
)
|
||||
|
||||
return httpx.Client(verify=verify, limits=_SSRF_CLIENT_LIMITS)
|
||||
|
||||
|
||||
def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
|
||||
if not isinstance(ssl_verify_enabled, bool):
|
||||
raise ValueError("SSRF client verify flag must be a boolean")
|
||||
|
||||
return get_pooled_http_client(
|
||||
_SSL_VERIFIED_POOL_KEY if ssl_verify_enabled else _SSL_UNVERIFIED_POOL_KEY,
|
||||
lambda: _build_ssrf_client(verify=ssl_verify_enabled),
|
||||
)
|
||||
|
||||
|
||||
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
if "allow_redirects" in kwargs:
|
||||
allow_redirects = kwargs.pop("allow_redirects")
|
||||
@@ -50,33 +85,22 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
|
||||
)
|
||||
|
||||
if "ssl_verify" not in kwargs:
|
||||
kwargs["ssl_verify"] = http_request_node_ssl_verify
|
||||
|
||||
ssl_verify = kwargs.pop("ssl_verify")
|
||||
# prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI
|
||||
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
|
||||
client = _get_ssrf_client(verify_option)
|
||||
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client:
|
||||
response = client.request(method=method, url=url, **kwargs)
|
||||
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
proxy_mounts = {
|
||||
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify),
|
||||
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify),
|
||||
}
|
||||
with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client:
|
||||
response = client.request(method=method, url=url, **kwargs)
|
||||
else:
|
||||
with httpx.Client(verify=ssl_verify) as client:
|
||||
response = client.request(method=method, url=url, **kwargs)
|
||||
response = client.request(method=method, url=url, **kwargs)
|
||||
|
||||
if response.status_code not in STATUS_FORCELIST:
|
||||
return response
|
||||
else:
|
||||
logger.warning(
|
||||
"Received status code %s for URL %s which is in the force list", response.status_code, url
|
||||
"Received status code %s for URL %s which is in the force list",
|
||||
response.status_code,
|
||||
url,
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
|
||||
@@ -28,7 +28,6 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.node_events import AgentLogEvent
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models import App, Message, WorkflowNodeExecutionModel
|
||||
@@ -462,19 +461,18 @@ class LLMGenerator:
|
||||
)
|
||||
|
||||
def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
|
||||
raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG)
|
||||
raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG, [])
|
||||
if not raw_agent_log:
|
||||
return []
|
||||
parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log)
|
||||
|
||||
def dict_of_event(event: AgentLogEvent):
|
||||
return {
|
||||
"status": event.status,
|
||||
"error": event.error,
|
||||
"data": event.data,
|
||||
return [
|
||||
{
|
||||
"status": event["status"],
|
||||
"error": event["error"],
|
||||
"data": event["data"],
|
||||
}
|
||||
|
||||
return [dict_of_event(event) for event in parsed]
|
||||
for event in raw_agent_log
|
||||
]
|
||||
|
||||
inputs = last_run.load_full_inputs(session, storage)
|
||||
last_run_dict = {
|
||||
|
||||
@@ -74,7 +74,7 @@ class TextPromptMessageContent(PromptMessageContent):
|
||||
Model class for text prompt message content.
|
||||
"""
|
||||
|
||||
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
|
||||
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore
|
||||
data: str
|
||||
|
||||
|
||||
@@ -95,11 +95,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
||||
|
||||
|
||||
class VideoPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
|
||||
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore
|
||||
|
||||
|
||||
class AudioPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
|
||||
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore
|
||||
|
||||
|
||||
class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
@@ -111,12 +111,12 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
LOW = auto()
|
||||
HIGH = auto()
|
||||
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
||||
|
||||
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
|
||||
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore
|
||||
|
||||
|
||||
PromptMessageContentUnionTypes = Annotated[
|
||||
|
||||
@@ -15,7 +15,7 @@ class GPT2Tokenizer:
|
||||
use gpt2 tokenizer to get num tokens
|
||||
"""
|
||||
_tokenizer = GPT2Tokenizer.get_encoder()
|
||||
tokens = _tokenizer.encode(text)
|
||||
tokens = _tokenizer.encode(text) # type: ignore
|
||||
return len(tokens)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user