Compare commits

..

1 Commits

Author SHA1 Message Date
Stephen Zhou
dacba93e00 chore: dev with http2 2026-03-17 16:31:47 +08:00
2878 changed files with 49889 additions and 155038 deletions

View File

@@ -63,8 +63,7 @@ pnpm analyze-component <path> --review
### File Naming
- Test files: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory
- Placement rule: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`.
- Test files: `ComponentName.spec.tsx` (same directory as component)
- Integration tests: `web/__tests__/` directory
## Test Structure Template

View File

@@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event'
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = vi.fn()
// vi.mock('@/next/navigation', () => ({
// vi.mock('next/navigation', () => ({
// useRouter: () => ({ push: mockPush }),
// usePathname: () => '/test-path',
// }))

View File

@@ -1,13 +0,0 @@
have_fun: false
memory_config:
disabled: false
code_review:
disable: true
comment_severity_threshold: MEDIUM
max_review_comments: -1
pull_request_opened:
help: false
summary: false
code_review: false
include_drafts: false
ignore_patterns: []

2
.github/CODEOWNERS vendored
View File

@@ -36,7 +36,7 @@
/api/core/workflow/graph/ @laipz8200 @QuantumGhost
/api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
/api/core/workflow/node_events/ @laipz8200 @QuantumGhost
/api/graphon/model_runtime/ @laipz8200 @WH-2099
/api/dify_graph/model_runtime/ @laipz8200 @QuantumGhost
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
/api/core/workflow/nodes/agent/ @Nov1c444

View File

@@ -4,9 +4,10 @@ runs:
using: composite
steps:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
uses: voidzero-dev/setup-vp@b5d848f5a62488f3d3d920f8aa6ac318a60c5f07 # v1
with:
working-directory: web
node-version-file: .nvmrc
node-version-file: "./web/.nvmrc"
cache: true
run-install: true
run-install: |
- cwd: ./web
args: ['--frozen-lockfile']

View File

@@ -12,7 +12,7 @@ jobs:
anti-slop:
runs-on: ubuntu-latest
steps:
- uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
- uses: peakoss/anti-slop@v0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
close-pr: false

View File

@@ -2,12 +2,6 @@ name: Run Pytest
on:
workflow_call:
secrets:
CODECOV_TOKEN:
required: false
permissions:
contents: read
concurrency:
group: api-tests-${{ github.head_ref || github.run_id }}
@@ -17,8 +11,6 @@ jobs:
test:
name: API Tests
runs-on: ubuntu-latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
@@ -32,11 +24,10 @@ jobs:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -88,12 +79,21 @@ jobs:
api/tests/test_containers_integration_tests \
api/tests/unit_tests
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' && matrix.python-version == '3.12' }}
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
with:
files: ./coverage.xml
disable_search: true
flags: api
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
{
echo ""
echo "<details><summary>File-level coverage (click to expand)</summary>"
echo ""
echo '```'
uv run --project api coverage report -m
echo '```'
echo "</details>"
} >> $GITHUB_STEP_SUMMARY

View File

@@ -39,7 +39,7 @@ jobs:
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
- uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'
@@ -94,6 +94,11 @@ jobs:
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
- name: Setup web environment
if: steps.web-changes.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web

View File

@@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true
python-version: "3.12"
@@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true
python-version: "3.12"

View File

@@ -56,14 +56,16 @@ jobs:
needs: check-changes
if: needs.check-changes.outputs.api-changed == 'true'
uses: ./.github/workflows/api-tests.yml
secrets: inherit
web-tests:
name: Web Tests
needs: check-changes
if: needs.check-changes.outputs.web-changed == 'true'
uses: ./.github/workflows/web-tests.yml
secrets: inherit
with:
base_sha: ${{ github.event.before || github.event.pull_request.base.sha }}
diff_range_mode: ${{ github.event.before && 'exact' || 'merge-base' }}
head_sha: ${{ github.event.after || github.event.pull_request.head.sha || github.sha }}
style-check:
name: Style Check

View File

@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true

View File

@@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: false
python-version: "3.12"
@@ -84,20 +84,20 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
restore-keys: |
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: vp run lint:ci
run: |
vp run lint:ci
# pnpm run lint:report
# continue-on-error: true
# - name: Annotate Code
# if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
# uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
# with:
# eslint-report: web/eslint_report.json
# github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
@@ -114,13 +114,6 @@ jobs:
working-directory: ./web
run: vp run knip
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
superlinter:
name: SuperLinter
runs-on: ubuntu-latest

View File

@@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77
uses: anthropics/claude-code-action@cd77b50d2b0808657f8e6774085c8bf54484351c # v1.0.72
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -31,7 +31,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@@ -2,9 +2,16 @@ name: Web Tests
on:
workflow_call:
secrets:
CODECOV_TOKEN:
inputs:
base_sha:
required: false
type: string
diff_range_mode:
required: false
type: string
head_sha:
required: false
type: string
permissions:
contents: read
@@ -22,8 +29,8 @@ jobs:
strategy:
fail-fast: false
matrix:
shardIndex: [1, 2, 3, 4, 5, 6]
shardTotal: [6]
shardIndex: [1, 2, 3, 4]
shardTotal: [4]
defaults:
run:
shell: bash
@@ -56,7 +63,7 @@ jobs:
needs: [test]
runs-on: ubuntu-latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
VITEST_COVERAGE_SCOPE: app-components
defaults:
run:
shell: bash
@@ -80,16 +87,52 @@ jobs:
merge-multiple: true
- name: Merge reports
run: vp test --merge-reports --coverage --silent=passed-only
run: vp test --merge-reports --reporter=json --reporter=agent --coverage
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
with:
directory: web/coverage
flags: web
- name: Report app/components baseline coverage
run: node ./scripts/report-components-coverage-baseline.mjs
- name: Report app/components test touch
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
BASE_SHA: ${{ inputs.base_sha }}
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
HEAD_SHA: ${{ inputs.head_sha }}
run: node ./scripts/report-components-test-touch.mjs
- name: Check app/components pure diff coverage
env:
BASE_SHA: ${{ inputs.base_sha }}
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
HEAD_SHA: ${{ inputs.head_sha }}
run: node ./scripts/check-components-diff-coverage.mjs
- name: Check Coverage Summary
if: always()
id: coverage-summary
run: |
set -eo pipefail
COVERAGE_FILE="coverage/coverage-final.json"
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
if [ -f "$COVERAGE_FILE" ] || [ -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 app/components Diff Coverage" >> "$GITHUB_STEP_SUMMARY"
echo "" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage artifacts not found. Ensure Vitest merge reports ran with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: web-coverage-report
path: web/coverage
retention-days: 30
if-no-files-found: error
web-build:
name: Web Build

View File

@@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process.
## Getting Help
If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
## Automated Agent Contributions
> [!NOTE]
> If you are an automated agent just add 🤖🤖🤖 to the end of the PR title to opt-in.

View File

@@ -353,9 +353,6 @@ BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300
# Upstash configuration
UPSTASH_VECTOR_URL=your-server-url

View File

@@ -1,14 +1,10 @@
[importlinter]
root_packages =
core
constants
context
graphon
dify_graph
configs
controllers
extensions
factories
libs
models
tasks
services
@@ -26,30 +22,40 @@ layers =
runtime
entities
containers =
graphon
dify_graph
ignore_imports =
graphon.nodes.base.node -> graphon.graph_events
graphon.nodes.iteration.iteration_node -> graphon.graph_events
graphon.nodes.loop.loop_node -> graphon.graph_events
dify_graph.nodes.base.node -> dify_graph.graph_events
dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_events
dify_graph.nodes.loop.loop_node -> dify_graph.graph_events
graphon.nodes.iteration.iteration_node -> graphon.graph_engine
graphon.nodes.loop.loop_node -> graphon.graph_engine
dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine
dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine
# TODO(QuantumGhost): fix the import violation later
graphon.entities.pause_reason -> graphon.nodes.human_input.entities
dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities
[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
type = forbidden
source_modules =
dify_graph
forbidden_modules =
extensions.ext_database
extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
type = forbidden
source_modules =
graphon
dify_graph
forbidden_modules =
constants
configs
context
controllers
extensions
factories
libs
models
services
tasks
@@ -82,14 +88,46 @@ forbidden_modules =
core.tools
core.trigger
core.variables
[importlinter:contract:workflow-third-party-imports]
name = Workflow Third-Party Imports
type = forbidden
source_modules =
graphon
forbidden_modules =
sqlalchemy
ignore_imports =
dify_graph.nodes.llm.llm_utils -> core.model_manager
dify_graph.nodes.llm.protocols -> core.model_manager
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
dify_graph.nodes.llm.node -> core.tools.signature
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
dify_graph.nodes.llm.node -> core.model_manager
dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.llm.node -> models.dataset
dify_graph.nodes.llm.file_saver -> core.tools.signature
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
dify_graph.nodes.tool.tool_node -> core.tools.errors
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.llm.node -> models.model
dify_graph.nodes.tool.tool_node -> services
dify_graph.model_runtime.model_providers.__base.ai_model -> configs
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.__base.large_language_model -> configs
dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type
dify_graph.model_runtime.model_providers.model_provider_factory -> configs
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids
[importlinter:contract:rsc]
name = RSC
@@ -98,7 +136,7 @@ layers =
graph_engine
response_coordinator
containers =
graphon.graph_engine
dify_graph.graph_engine
[importlinter:contract:worker]
name = Worker
@@ -107,7 +145,7 @@ layers =
graph_engine
worker
containers =
graphon.graph_engine
dify_graph.graph_engine
[importlinter:contract:graph-engine-architecture]
name = Graph Engine Architecture
@@ -123,28 +161,28 @@ layers =
worker_management
domain
containers =
graphon.graph_engine
dify_graph.graph_engine
[importlinter:contract:domain-isolation]
name = Domain Model Isolation
type = forbidden
source_modules =
graphon.graph_engine.domain
dify_graph.graph_engine.domain
forbidden_modules =
graphon.graph_engine.worker_management
graphon.graph_engine.command_channels
graphon.graph_engine.layers
graphon.graph_engine.protocols
dify_graph.graph_engine.worker_management
dify_graph.graph_engine.command_channels
dify_graph.graph_engine.layers
dify_graph.graph_engine.protocols
[importlinter:contract:worker-management]
name = Worker Management
type = forbidden
source_modules =
graphon.graph_engine.worker_management
dify_graph.graph_engine.worker_management
forbidden_modules =
graphon.graph_engine.orchestration
graphon.graph_engine.command_processing
graphon.graph_engine.event_management
dify_graph.graph_engine.orchestration
dify_graph.graph_engine.command_processing
dify_graph.graph_engine.event_management
[importlinter:contract:graph-traversal-components]
@@ -154,11 +192,11 @@ layers =
edge_processor
skip_propagator
containers =
graphon.graph_engine.graph_traversal
dify_graph.graph_engine.graph_traversal
[importlinter:contract:command-channels]
name = Command Channels Independence
type = independence
modules =
graphon.graph_engine.command_channels.in_memory_channel
graphon.graph_engine.command_channels.redis_channel
dify_graph.graph_engine.command_channels.in_memory_channel
dify_graph.graph_engine.command_channels.redis_channel

View File

@@ -100,7 +100,7 @@ ignore = [
"configs/*" = [
"N802", # invalid-function-name
]
"graphon/model_runtime/callbacks/base_callback.py" = ["T201"]
"dify_graph/model_runtime/callbacks/base_callback.py" = ["T201"]
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name

View File

@@ -78,7 +78,7 @@ class UserProfile(TypedDict):
nickname: NotRequired[str]
```
- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
```python
from datetime import datetime

View File

@@ -1,11 +1,9 @@
import json
import logging
from typing import Any, cast
from typing import Any
import click
from pydantic import TypeAdapter
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from configs import dify_config
from core.helper import encrypter
@@ -50,15 +48,14 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = cast(
CursorResult,
db.session.execute(
delete(ToolOAuthSystemClient).where(
ToolOAuthSystemClient.provider == provider_name,
ToolOAuthSystemClient.plugin_id == plugin_id,
)
),
).rowcount
deleted_count = (
db.session.query(ToolOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
@@ -100,15 +97,14 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = cast(
CursorResult,
db.session.execute(
delete(TriggerOAuthSystemClient).where(
TriggerOAuthSystemClient.provider == provider_name,
TriggerOAuthSystemClient.plugin_id == plugin_id,
)
),
).rowcount
deleted_count = (
db.session.query(TriggerOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
@@ -143,15 +139,14 @@ def setup_datasource_oauth_client(provider, client_params):
return
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
deleted_count = cast(
CursorResult,
db.session.execute(
delete(DatasourceOauthParamConfig).where(
DatasourceOauthParamConfig.provider == provider_name,
DatasourceOauthParamConfig.plugin_id == plugin_id,
)
),
).rowcount
deleted_count = (
db.session.query(DatasourceOauthParamConfig)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
@@ -197,9 +192,7 @@ def transform_datasource_credentials(environment: str):
# deal notion credentials
deal_notion_count = 0
notion_credentials = db.session.scalars(
select(DataSourceOauthBinding).where(DataSourceOauthBinding.provider == "notion")
).all()
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
if notion_credentials:
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
for notion_credential in notion_credentials:
@@ -208,7 +201,7 @@ def transform_datasource_credentials(environment: str):
notion_credentials_tenant_mapping[tenant_id] = []
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
if not tenant:
continue
try:
@@ -257,9 +250,7 @@ def transform_datasource_credentials(environment: str):
db.session.commit()
# deal firecrawl credentials
deal_firecrawl_count = 0
firecrawl_credentials = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "firecrawl")
).all()
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
if firecrawl_credentials:
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for firecrawl_credential in firecrawl_credentials:
@@ -268,7 +259,7 @@ def transform_datasource_credentials(environment: str):
firecrawl_credentials_tenant_mapping[tenant_id] = []
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
if not tenant:
continue
try:
@@ -321,9 +312,7 @@ def transform_datasource_credentials(environment: str):
db.session.commit()
# deal jina credentials
deal_jina_count = 0
jina_credentials = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "jinareader")
).all()
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
if jina_credentials:
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for jina_credential in jina_credentials:
@@ -332,7 +321,7 @@ def transform_datasource_credentials(environment: str):
jina_credentials_tenant_mapping[tenant_id] = []
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
if not tenant:
continue
try:

View File

@@ -1,10 +1,7 @@
import json
from typing import cast
import click
import sqlalchemy as sa
from sqlalchemy import update
from sqlalchemy.engine import CursorResult
from configs import dify_config
from extensions.ext_database import db
@@ -743,17 +740,14 @@ def migrate_oss(
else:
try:
source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL
updated = cast(
CursorResult,
db.session.execute(
update(UploadFile)
.where(
UploadFile.storage_type == source_storage_type,
UploadFile.key.in_(copied_upload_file_keys),
)
.values(storage_type=dify_config.STORAGE_TYPE)
),
).rowcount
updated = (
db.session.query(UploadFile)
.where(
UploadFile.storage_type == source_storage_type,
UploadFile.key.in_(copied_upload_file_keys),
)
.update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False)
)
db.session.commit()
click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green"))
except Exception as e:

View File

@@ -2,7 +2,6 @@ import logging
import click
import sqlalchemy as sa
from sqlalchemy import delete, select, update
from sqlalchemy.orm import sessionmaker
from configs import dify_config
@@ -42,7 +41,7 @@ def reset_encrypt_key_pair():
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
tenants = session.scalars(select(Tenant)).all()
tenants = session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
@@ -50,8 +49,8 @@ def reset_encrypt_key_pair():
tenant.encrypt_public_key = generate_key_pair(tenant.id)
session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id))
session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id))
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
click.echo(
click.style(
@@ -94,7 +93,7 @@ def convert_to_agent_apps():
app_id = str(i.id)
if app_id not in proceeded_app_ids:
proceeded_app_ids.append(app_id)
app = db.session.scalar(select(App).where(App.id == app_id))
app = db.session.query(App).where(App.id == app_id).first()
if app is not None:
apps.append(app)
@@ -109,8 +108,8 @@ def convert_to_agent_apps():
db.session.commit()
# update conversation mode to agent
db.session.execute(
update(Conversation).where(Conversation.app_id == app.id).values(mode=AppMode.AGENT_CHAT)
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
{Conversation.mode: AppMode.AGENT_CHAT}
)
db.session.commit()
@@ -178,7 +177,7 @@ where sites.id is null limit 1000"""
continue
try:
app = db.session.scalar(select(App).where(App.id == app_id))
app = db.session.query(App).where(App.id == app_id).first()
if not app:
logger.info("App %s not found", app_id)
continue

View File

@@ -10,12 +10,10 @@ from configs import dify_config
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import DatasetMetadataType, IndexingStatus, SegmentStatus
from models.model import App, AppAnnotationSetting, MessageAnnotation
@@ -42,13 +40,14 @@ def migrate_annotation_vector_database():
# get apps info
per_page = 50
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
apps = session.scalars(
select(App)
apps = (
session.query(App)
.where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
).all()
.all()
)
if not apps:
break
except SQLAlchemyError:
@@ -63,8 +62,8 @@ def migrate_annotation_vector_database():
try:
click.echo(f"Creating app annotation index: {app.id}")
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
app_annotation_setting = session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).limit(1)
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
)
if not app_annotation_setting:
@@ -72,10 +71,10 @@ def migrate_annotation_vector_database():
click.echo(f"App annotation setting disabled: {app.id}")
continue
# get dataset_collection_binding info
dataset_collection_binding = session.scalar(
select(DatasetCollectionBinding).where(
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
)
dataset_collection_binding = (
session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
@@ -86,7 +85,7 @@ def migrate_annotation_vector_database():
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
@@ -178,9 +177,7 @@ def migrate_knowledge_vector_database():
while True:
try:
stmt = (
select(Dataset)
.where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY)
.order_by(Dataset.created_at.desc())
select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@@ -207,11 +204,11 @@ def migrate_knowledge_vector_database():
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
elif vector_type == VectorType.QDRANT:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.execute(
select(DatasetCollectionBinding).where(
DatasetCollectionBinding.id == dataset.collection_binding_id
)
).scalar_one_or_none()
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
@@ -245,7 +242,7 @@ def migrate_knowledge_vector_database():
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == IndexingStatus.COMPLETED,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
@@ -257,7 +254,7 @@ def migrate_knowledge_vector_database():
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == SegmentStatus.COMPLETED,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
).all()
@@ -272,7 +269,7 @@ def migrate_knowledge_vector_database():
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if dataset_document.doc_form == "hierarchical_model":
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
@@ -336,7 +333,7 @@ def add_qdrant_index(field: str):
create_count = 0
try:
bindings = db.session.scalars(select(DatasetCollectionBinding)).all()
bindings = db.session.query(DatasetCollectionBinding).all()
if not bindings:
click.echo(click.style("No dataset collection bindings found.", fg="red"))
return
@@ -423,22 +420,22 @@ def old_metadata_migration():
if field.value == key:
break
else:
dataset_metadata = db.session.scalar(
select(DatasetMetadata)
dataset_metadata = (
db.session.query(DatasetMetadata)
.where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
.limit(1)
.first()
)
if not dataset_metadata:
dataset_metadata = DatasetMetadata(
tenant_id=document.tenant_id,
dataset_id=document.dataset_id,
name=key,
type=DatasetMetadataType.STRING,
type="string",
created_by=document.created_by,
)
db.session.add(dataset_metadata)
db.session.flush()
dataset_metadata_binding: DatasetMetadataBinding | None = DatasetMetadataBinding(
dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=document.tenant_id,
dataset_id=document.dataset_id,
metadata_id=dataset_metadata.id,
@@ -447,14 +444,14 @@ def old_metadata_migration():
)
db.session.add(dataset_metadata_binding)
else:
dataset_metadata_binding = db.session.scalar(
select(DatasetMetadataBinding)
dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore
.where(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
)
.limit(1)
.first()
)
if not dataset_metadata_binding:
dataset_metadata_binding = DatasetMetadataBinding(

View File

@@ -1,4 +1,4 @@
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
from pydantic_settings import BaseSettings
@@ -116,13 +116,3 @@ class RedisConfig(BaseSettings):
description="Maximum connections in the Redis connection pool (unset for library default)",
default=None,
)
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
@classmethod
def _empty_string_to_none_for_max_conns(cls, v):
"""Allow empty string in env/.env to mean 'unset' (None)."""
if v is None:
return None
if isinstance(v, str) and v.strip() == "":
return None
return v

View File

@@ -1,4 +1,4 @@
from typing import Literal, Protocol, cast
from typing import Literal, Protocol
from urllib.parse import quote_plus, urlunparse
from pydantic import AliasChoices, Field
@@ -12,13 +12,16 @@ class RedisConfigDefaults(Protocol):
REDIS_PASSWORD: str | None
REDIS_DB: int
REDIS_USE_SSL: bool
REDIS_USE_SENTINEL: bool | None
REDIS_USE_CLUSTERS: bool
def _redis_defaults(config: object) -> RedisConfigDefaults:
return cast(RedisConfigDefaults, config)
class RedisConfigDefaultsMixin:
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
return self
class RedisPubSubConfig(BaseSettings):
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
"""
Configuration settings for event transport between API and workers.
@@ -71,7 +74,7 @@ class RedisPubSubConfig(BaseSettings):
)
def _build_default_pubsub_url(self) -> str:
defaults = _redis_defaults(self)
defaults = self._redis_defaults()
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
@@ -88,9 +91,11 @@ class RedisPubSubConfig(BaseSettings):
if userinfo:
userinfo = f"{userinfo}@"
host = defaults.REDIS_HOST
port = defaults.REDIS_PORT
db = defaults.REDIS_DB
netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
netloc = f"{userinfo}{host}:{port}"
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
@property

View File

@@ -51,18 +51,3 @@ class BaiduVectorDBConfig(BaseSettings):
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
default="COARSE_MODE",
)
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: int = Field(
description="Auto build row count increment threshold (default is 500)",
default=500,
)
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: float = Field(
description="Auto build row count increment ratio threshold (default is 0.05)",
default=0.05,
)
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: int = Field(
description="Timeout in seconds for rebuilding the index in Baidu Vector Database (default is 3600 seconds)",
default=300,
)

View File

@@ -1,36 +1,74 @@
"""
Application-layer context adapters.
Core Context - Framework-agnostic context management.
Concrete execution-context implementations live here so `graphon` only
depends on injected context managers rather than framework state capture.
This module provides context management that is independent of any specific
web framework. Framework-specific implementations register their context
capture functions at application initialization time.
This ensures the workflow layer remains completely decoupled from Flask
or any other web framework.
"""
from context.execution_context import (
AppContext,
ContextProviderNotFoundError,
import contextvars
from collections.abc import Callable
from dify_graph.context.execution_context import (
ExecutionContext,
ExecutionContextBuilder,
IExecutionContext,
NullAppContext,
capture_current_context,
read_context,
register_context,
register_context_capturer,
reset_context_provider,
)
from context.models import SandboxContext
# Global capturer function - set by framework-specific modules
_capturer: Callable[[], IExecutionContext] | None = None
def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
"""
Register a context capture function.
This should be called by framework-specific modules (e.g., Flask)
during application initialization.
Args:
capturer: Function that captures current context and returns IExecutionContext
"""
global _capturer
_capturer = capturer
def capture_current_context() -> IExecutionContext:
"""
Capture current execution context.
This function uses the registered context capturer. If no capturer
is registered, it returns a minimal context with only contextvars
(suitable for non-framework environments like tests or standalone scripts).
Returns:
IExecutionContext with captured context
"""
if _capturer is None:
# No framework registered - return minimal context
return ExecutionContext(
app_context=NullAppContext(),
context_vars=contextvars.copy_context(),
)
return _capturer()
def reset_context_provider() -> None:
"""
Reset the context capturer.
This is primarily useful for testing to ensure a clean state.
"""
global _capturer
_capturer = None
__all__ = [
"AppContext",
"ContextProviderNotFoundError",
"ExecutionContext",
"ExecutionContextBuilder",
"IExecutionContext",
"NullAppContext",
"SandboxContext",
"capture_current_context",
"read_context",
"register_context",
"register_context_capturer",
"reset_context_provider",
]

View File

@@ -10,7 +10,11 @@ from typing import Any, final
from flask import Flask, current_app, g
from context.execution_context import AppContext, IExecutionContext, register_context_capturer
from dify_graph.context import register_context_capturer
from dify_graph.context.execution_context import (
AppContext,
IExecutionContext,
)
@final

View File

@@ -6,6 +6,7 @@ from contexts.wrapper import RecyclableContextVar
if TYPE_CHECKING:
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController
@@ -19,6 +20,14 @@ plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderControl
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
ContextVar("plugin_model_providers")
)
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)

View File

@@ -4,7 +4,7 @@ from typing import Any, TypeAlias
from pydantic import BaseModel, ConfigDict, computed_field
from graphon.file import helpers as file_helpers
from dify_graph.file import helpers as file_helpers
from models.model import IconType
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]

View File

@@ -1,7 +1,7 @@
import flask_restx
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import delete, func, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@@ -9,7 +9,6 @@ from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.enums import ApiTokenType
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache
@@ -34,10 +33,16 @@ api_key_list_model = console_ns.model(
def _get_resource(resource_id, tenant_id, resource_model):
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
if resource_model == App:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
else:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
if resource is None:
flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.")
@@ -48,7 +53,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
class BaseApiKeyListResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: ApiTokenType | None = None
resource_type: str | None = None
resource_model: type | None = None
resource_id_field: str | None = None
token_prefix: str | None = None
@@ -75,13 +80,10 @@ class BaseApiKeyListResource(Resource):
resource_id = str(resource_id)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
current_key_count: int = (
db.session.scalar(
select(func.count(ApiToken.id)).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
)
or 0
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.count()
)
if current_key_count >= self.max_keys:
@@ -92,7 +94,6 @@ class BaseApiKeyListResource(Resource):
)
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
assert self.resource_type is not None, "resource_type must be set"
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_tenant_id
@@ -106,7 +107,7 @@ class BaseApiKeyListResource(Resource):
class BaseApiKeyResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: ApiTokenType | None = None
resource_type: str | None = None
resource_model: type | None = None
resource_id_field: str | None = None
@@ -118,14 +119,14 @@ class BaseApiKeyResource(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
key = db.session.scalar(
select(ApiToken)
key = (
db.session.query(ApiToken)
.where(
getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.limit(1)
.first()
)
if key is None:
@@ -136,7 +137,7 @@ class BaseApiKeyResource(Resource):
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()
return {"result": "success"}, 204
@@ -161,7 +162,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for an app"""
return super().post(resource_id)
resource_type = ApiTokenType.APP
resource_type = "app"
resource_model = App
resource_id_field = "app_id"
token_prefix = "app-"
@@ -177,7 +178,7 @@ class AppApiKeyResource(BaseApiKeyResource):
"""Delete an API key for an app"""
return super().delete(resource_id, api_key_id)
resource_type = ApiTokenType.APP
resource_type = "app"
resource_model = App
resource_id_field = "app_id"
@@ -201,7 +202,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for a dataset"""
return super().post(resource_id)
resource_type = ApiTokenType.DATASET
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"
token_prefix = "ds-"
@@ -217,6 +218,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
"""Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id)
resource_type = ApiTokenType.DATASET
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"

View File

@@ -26,9 +26,9 @@ from controllers.console.wraps import (
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.file import helpers as file_helpers
from extensions.ext_database import db
from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
@@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel):
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
icon_type: IconType | None = Field(default=None, description="Icon type")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
@@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel):
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
icon_type: IconType | None = Field(default=None, description="Icon type")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@@ -594,7 +594,7 @@ class AppApi(Resource):
args_dict: AppService.ArgsDict = {
"name": args.name,
"description": args.description or "",
"icon_type": args.icon_type,
"icon_type": args.icon_type or "",
"icon": args.icon or "",
"icon_background": args.icon_background or "",
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,

View File

@@ -22,7 +22,7 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from graphon.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from models import App, AppMode
from services.audio_service import AudioService

View File

@@ -26,7 +26,7 @@ from core.errors.error import (
QuotaExceededError,
)
from core.helper.trace_id_helper import get_external_trace_id
from graphon.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user, login_required

View File

@@ -5,7 +5,7 @@ from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, or_
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@@ -376,12 +376,8 @@ class CompletionConversationApi(Resource):
# FIXME, the type ignore in this file
if args.annotation_status == "annotated":
query = (
query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type]
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.distinct()
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args.annotation_status == "not_annotated":
query = (
@@ -458,7 +454,9 @@ class ChatConversationApi(Resource):
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
subquery = (
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
db.session.query(
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
)
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
.subquery()
)
@@ -513,12 +511,8 @@ class ChatConversationApi(Resource):
match args.annotation_status:
case "annotated":
query = (
query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type]
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.distinct()
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
case "not_annotated":
query = (
@@ -593,8 +587,10 @@ class ChatConversationDetailApi(Resource):
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = db.session.scalar(
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
conversation = (
db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation:

View File

@@ -18,8 +18,8 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from dify_graph.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
@@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.get(App, args.flow_id)
app = db.session.query(App).where(App.id == args.flow_id).first()
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)

View File

@@ -2,7 +2,6 @@ import json
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@@ -48,7 +47,7 @@ class AppMCPServerController(Resource):
@get_app_model
@marshal_with(app_server_model)
def get(self, app_model):
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
return server
@console_ns.doc("create_app_mcp_server")
@@ -99,18 +98,18 @@ class AppMCPServerController(Resource):
@edit_permission_required
def put(self, app_model):
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
server = db.session.get(AppMCPServer, payload.id)
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
if not server:
raise NotFound()
description = payload.description
if description is None or not description:
if description is None:
pass
elif not description:
server.description = app_model.description or ""
else:
server.description = description
server.name = app_model.name
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
if payload.status:
try:
@@ -136,10 +135,11 @@ class AppMCPServerRefreshController(Resource):
@edit_permission_required
def get(self, server_id):
_, current_tenant_id = current_account_with_tenant()
server = db.session.scalar(
select(AppMCPServer)
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
.limit(1)
server = (
db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_tenant_id)
.first()
)
if not server:
raise NotFound()

View File

@@ -4,7 +4,7 @@ from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, func, select
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@@ -24,13 +24,12 @@ from controllers.console.wraps import (
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from dify_graph.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.raws import FilesContainedField
from graphon.model_runtime.errors.invoke import InvokeError
from libs.helper import TimestampField, uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@@ -244,25 +243,27 @@ class ChatMessageListApi(Resource):
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
conversation = db.session.scalar(
select(Conversation)
conversation = (
db.session.query(Conversation)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.limit(1)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if args.first_id:
first_message = db.session.scalar(
select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1)
first_message = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.first()
)
if not first_message:
raise NotFound("First message not found")
history_messages = db.session.scalars(
select(Message)
history_messages = (
db.session.query(Message)
.where(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
@@ -270,14 +271,16 @@ class ChatMessageListApi(Resource):
)
.order_by(Message.created_at.desc())
.limit(args.limit)
).all()
.all()
)
else:
history_messages = db.session.scalars(
select(Message)
history_messages = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args.limit)
).all()
.all()
)
# Initialize has_more based on whether we have a full page
if len(history_messages) == args.limit:
@@ -322,9 +325,7 @@ class MessageFeedbackApi(Resource):
message_id = str(args.message_id)
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")
@@ -334,7 +335,7 @@ class MessageFeedbackApi(Resource):
if not args.rating and feedback:
db.session.delete(feedback)
elif args.rating and feedback:
feedback.rating = FeedbackRating(args.rating)
feedback.rating = args.rating
feedback.content = args.content
elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
@@ -346,9 +347,9 @@ class MessageFeedbackApi(Resource):
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=FeedbackRating(rating_value),
rating=rating_value,
content=args.content,
from_source=FeedbackFromSource.ADMIN,
from_source="admin",
from_account_id=current_user.id,
)
db.session.add(feedback)
@@ -373,9 +374,7 @@ class MessageAnnotationCountApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
count = db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
)
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
return {"count": count}
@@ -479,9 +478,7 @@ class MessageApi(Resource):
def get(self, app_model, message_id: str):
message_id = str(message_id)
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")

View File

@@ -69,7 +69,9 @@ class ModelConfigResource(Resource):
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
# get original app model config
original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id)
original_app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
)
if original_app_model_config is None:
raise ValueError("Original app model config not found")
agent_mode = original_app_model_config.agent_mode_dict
@@ -88,7 +90,6 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_tenant_id,
@@ -128,7 +129,6 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
)
except Exception:
continue

View File

@@ -2,7 +2,6 @@ from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
@@ -76,7 +75,7 @@ class AppSite(Resource):
def post(self, app_model):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise NotFound
@@ -125,7 +124,7 @@ class AppSiteAccessTokenReset(Resource):
@marshal_with(app_site_model)
def post(self, app_model):
current_user, _ = current_account_with_tenant()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise NotFound

View File

@@ -7,7 +7,7 @@ from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import console_ns
@@ -20,7 +20,6 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.file_access import DatabaseFileAccessController
from core.helper.trace_id_helper import get_external_trace_id
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
@@ -30,15 +29,15 @@ from core.trigger.debug.event_selectors import (
create_event_poller,
select_trigger_debug_events,
)
from dify_graph.enums import NodeType
from dify_graph.file.models import File
from dify_graph.graph_engine.manager import GraphEngineManager
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import file_factory, variable_factory
from fields.member_fields import simple_account_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from graphon.enums import NodeType
from graphon.file.models import File
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value
@@ -47,15 +46,13 @@ from models import App
from models.model import AppMode
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@@ -206,7 +203,6 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence
mappings=files,
tenant_id=workflow.tenant_id,
config=file_extra_config,
access_controller=_file_access_controller,
)
return file_objs
@@ -288,9 +284,7 @@ class DraftWorkflowApi(Resource):
workflow_service = WorkflowService()
try:
environment_variables_list = Workflow.normalize_environment_variable_mappings(
args.get("environment_variables") or [],
)
environment_variables_list = args.get("environment_variables") or []
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
@@ -1000,43 +994,6 @@ class PublishedAllWorkflowApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>/restore")
class DraftWorkflowRestoreApi(Resource):
@console_ns.doc("restore_workflow_to_draft")
@console_ns.doc(description="Restore a published workflow version into the draft workflow")
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Published workflow ID"})
@console_ns.response(200, "Workflow restored successfully")
@console_ns.response(400, "Source workflow must be published")
@console_ns.response(404, "Workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, workflow_id: str):
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
try:
workflow = workflow_service.restore_published_workflow_to_draft(
app_model=app_model,
workflow_id=workflow_id,
account=current_user,
)
except IsDraftWorkflowError as exc:
raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
except WorkflowNotFoundError as exc:
raise NotFound(str(exc)) from exc
except ValueError as exc:
raise BadRequest(str(exc)) from exc
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
class WorkflowByIdApi(Resource):
@console_ns.doc("update_workflow_by_id")

View File

@@ -9,12 +9,12 @@ from sqlalchemy.orm import Session
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from dify_graph.enums import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import (
build_workflow_app_log_pagination_model,
build_workflow_archived_log_pagination_model,
)
from graphon.enums import WorkflowExecutionStatus
from libs.login import login_required
from models import App
from models.model import AppMode

View File

@@ -15,15 +15,14 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.file_access import DatabaseFileAccessController
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.file import helpers as file_helpers
from dify_graph.variables.segment_group import SegmentGroup
from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment
from dify_graph.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from graphon.file import helpers as file_helpers
from graphon.variables.segment_group import SegmentGroup
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
from graphon.variables.types import SegmentType
from libs.login import current_user, login_required
from models import App, AppMode
from models.workflow import WorkflowDraftVariable
@@ -31,7 +30,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -391,21 +389,13 @@ class VariableApi(Resource):
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@@ -12,7 +12,8 @@ from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import NotFoundError
from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
@@ -26,8 +27,6 @@ from fields.workflow_run_fields import (
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
from graphon.entities.pause_reason import HumanInputRequired
from graphon.enums import WorkflowExecutionStatus
from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
from libs.custom_inputs import time_duration
from libs.helper import uuid_value
@@ -497,9 +496,6 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
form_tokens_by_form_id = _load_form_tokens_by_form_id(
[reason.form_id for reason in pause_reasons if isinstance(reason, HumanInputRequired)]
)
# Build response
paused_at = pause_entity.paused_at if pause_entity else None
@@ -518,9 +514,7 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
"pause_type": {
"type": "human_input",
"form_id": reason.form_id,
"backstage_input_url": _build_backstage_input_url(
form_tokens_by_form_id.get(reason.form_id)
),
"backstage_input_url": _build_backstage_input_url(reason.form_token),
},
}
)

View File

@@ -2,8 +2,6 @@ from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar, Union
from sqlalchemy import select
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_account_with_tenant
@@ -17,14 +15,16 @@ R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None:
_, current_tenant_id = current_account_with_tenant()
app_model = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
app_model = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
return app_model
def _load_app_model_with_trial(app_id: str) -> App | None:
app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1))
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
return app_model

View File

@@ -1,7 +1,7 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import languages
@@ -73,7 +73,7 @@ class EmailRegisterSendEmailApi(Resource):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@@ -145,7 +145,7 @@ class EmailRegisterResetApi(Resource):
email = register_data.get("email", "")
normalized_email = email.lower()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:

View File

@@ -4,7 +4,7 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
@@ -102,7 +102,7 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email(
@@ -201,7 +201,7 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
@@ -215,6 +215,7 @@ class ForgotPasswordResetApi(Resource):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()
# Create workspace if needed
if (

View File

@@ -1,10 +1,9 @@
import logging
import urllib.parse
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@@ -113,9 +112,6 @@ class OAuthCallback(Resource):
error_text = e.response.text
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400
except ValueError as e:
logger.warning("OAuth error with %s", provider, exc_info=True)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}")
if invite_token and RegisterService.is_valid_invite_token(invite_token):
invitation = RegisterService.get_invitation_by_token(token=invite_token)
@@ -180,7 +176,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account: Account | None = Account.get_by_openid(provider, user_info.id)
if not account:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account

View File

@@ -8,7 +8,7 @@ from pydantic import BaseModel
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
from graphon.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from models import Account
from models.model import OAuthProviderApp

View File

@@ -3,7 +3,7 @@ from typing import Any, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, select
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -25,12 +25,12 @@ from controllers.console.wraps import (
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
@@ -51,11 +51,9 @@ from fields.dataset_fields import (
weighted_score_fields,
)
from fields.document_fields import document_status_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import ApiTokenType, SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@@ -332,7 +330,7 @@ class DatasetListApi(Resource):
)
# check embedding setting
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -356,7 +354,7 @@ class DatasetListApi(Resource):
for item in data:
# convert embedding_model_provider to plugin standard format
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
@@ -437,7 +435,7 @@ class DatasetApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.indexing_technique == "high_quality":
if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider)
data["embedding_model_provider"] = str(provider_id)
@@ -446,7 +444,7 @@ class DatasetApi(Resource):
data.update({"partial_member_list": part_users_list})
# check embedding setting
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -455,7 +453,7 @@ class DatasetApi(Resource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
if data["indexing_technique"] == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data["embedding_available"] = True
@@ -486,7 +484,7 @@ class DatasetApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY
payload.indexing_technique == "high_quality"
and payload.embedding_model_provider is not None
and payload.embedding_model is not None
):
@@ -739,23 +737,18 @@ class DatasetIndexingStatusApi(Resource):
documents_status = []
for document in documents:
completed_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
or 0
.count()
)
total_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
or 0
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields
document_dict = {
@@ -781,7 +774,7 @@ class DatasetIndexingStatusApi(Resource):
class DatasetApiKeyApi(Resource):
max_keys = 10
token_prefix = "dataset-"
resource_type = ApiTokenType.DATASET
resource_type = "dataset"
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get dataset API keys")
@@ -806,12 +799,9 @@ class DatasetApiKeyApi(Resource):
_, current_tenant_id = current_account_with_tenant()
current_key_count = (
db.session.scalar(
select(func.count(ApiToken.id)).where(
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id
)
)
or 0
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
.count()
)
if current_key_count >= self.max_keys:
@@ -833,7 +823,7 @@ class DatasetApiKeyApi(Resource):
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
class DatasetApiDeleteApi(Resource):
resource_type = ApiTokenType.DATASET
resource_type = "dataset"
@console_ns.doc("delete_dataset_api_key")
@console_ns.doc(description="Delete dataset API key")
@@ -846,14 +836,14 @@ class DatasetApiDeleteApi(Resource):
def delete(self, api_key_id):
_, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id)
key = db.session.scalar(
select(ApiToken)
key = (
db.session.query(ApiToken)
.where(
ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.limit(1)
.first()
)
if key is None:
@@ -864,7 +854,7 @@ class DatasetApiDeleteApi(Resource):
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.delete(key)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()
return {"result": "success"}, 204

View File

@@ -10,7 +10,7 @@ import sqlalchemy as sa
from flask import request, send_file
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, func, select
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -27,7 +27,8 @@ from core.model_manager import ModelManager
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from extensions.ext_database import db
from fields.dataset_fields import dataset_fields
from fields.document_fields import (
@@ -37,13 +38,10 @@ from fields.document_fields import (
document_status_fields,
document_with_segments_fields,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from models.enums import IndexingStatus, SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
@@ -212,11 +210,12 @@ class GetProcessRuleApi(Resource):
raise Forbidden(str(e))
# get the latest process rule
dataset_process_rule = db.session.scalar(
select(DatasetProcessRule)
dataset_process_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == document.dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
)
if dataset_process_rule:
mode = dataset_process_rule.mode
@@ -298,7 +297,6 @@ class DatasetDocumentListApi(Resource):
if sort == "hit_count":
sub_query = (
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
.where(DocumentSegment.dataset_id == str(dataset_id))
.group_by(DocumentSegment.document_id)
.subquery()
)
@@ -330,23 +328,18 @@ class DatasetDocumentListApi(Resource):
if fetch:
for document in documents:
completed_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
or 0
.count()
)
total_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
or 0
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
@@ -450,11 +443,11 @@ class DatasetInitApi(Resource):
raise Forbidden()
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=knowledge_config.embedding_model_provider,
@@ -464,7 +457,7 @@ class DatasetInitApi(Resource):
is_multimodal = DatasetService.check_is_multimodal_model(
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
)
knowledge_config.is_multimodal = is_multimodal # pyrefly: ignore[bad-assignment]
knowledge_config.is_multimodal = is_multimodal
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
@@ -510,7 +503,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule
@@ -523,10 +516,10 @@ class DocumentIndexingEstimateApi(DocumentResource):
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
file = db.session.scalar(
select(UploadFile)
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.limit(1)
.first()
)
# raise error if file not found
@@ -580,7 +573,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
extract_settings = []
for document in documents:
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
match document.data_source_type:
@@ -588,10 +581,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = db.session.scalar(
select(UploadFile)
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.limit(1)
.first()
)
if file_detail is None:
@@ -674,28 +667,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
documents_status = []
for document in documents:
completed_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
or 0
.count()
)
total_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
or 0
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
@@ -728,29 +716,24 @@ class DocumentIndexingStatusApi(DocumentResource):
document = self.get_document(dataset_id, document_id)
completed_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
)
or 0
.count()
)
total_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
or 0
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
@@ -972,7 +955,7 @@ class DocumentProcessingApi(DocumentResource):
match action:
case "pause":
if document.indexing_status != IndexingStatus.INDEXING:
if document.indexing_status != "indexing":
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
@@ -981,7 +964,7 @@ class DocumentProcessingApi(DocumentResource):
db.session.commit()
case "resume":
if document.indexing_status not in {IndexingStatus.PAUSED, IndexingStatus.ERROR}:
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None
@@ -1186,7 +1169,7 @@ class DocumentRetryApi(DocumentResource):
raise ArchivedDocumentImmutableError()
# 400 if document is completed
if document.indexing_status == IndexingStatus.COMPLETED:
if document.indexing_status == "completed":
raise DocumentAlreadyFinishedError()
retry_documents.append(document)
except Exception:
@@ -1268,11 +1251,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
log = db.session.scalar(
select(DocumentPipelineExecutionLog)
.where(DocumentPipelineExecutionLog.document_id == document_id)
log = (
db.session.query(DocumentPipelineExecutionLog)
.filter_by(document_id=document_id)
.order_by(DocumentPipelineExecutionLog.created_at.desc())
.limit(1)
.first()
)
if not log:
return {
@@ -1338,7 +1321,7 @@ class DocumentGenerateSummaryApi(Resource):
raise BadRequest("document_list cannot be empty.")
# Check if dataset configuration supports summary generation
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
if dataset.indexing_technique != "high_quality":
raise ValueError(
f"Summary generation is only available for 'high_quality' indexing technique. "
f"Current indexing technique: {dataset.indexing_technique}"

View File

@@ -26,11 +26,10 @@ from controllers.console.wraps import (
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
@@ -46,7 +45,7 @@ def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
from services.summary_index_service import SummaryIndexService
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
segment_dict = dict(marshal(segment, segment_fields))
# Query summary for this segment (only enabled summaries)
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
@@ -207,7 +206,7 @@ class DatasetDocumentSegmentListApi(Resource):
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
segment_dict = dict(marshal(segment, segment_fields))
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
@@ -280,10 +279,10 @@ class DatasetDocumentSegmentApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.indexing_technique == "high_quality":
# check embedding model setting
try:
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -334,9 +333,9 @@ class DatasetDocumentSegmentAddApi(Resource):
if not current_user.is_dataset_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -384,10 +383,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.indexing_technique == "high_quality":
# check embedding model setting
try:
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -402,10 +401,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
.first()
)
if not segment:
raise NotFound("Segment not found.")
@@ -448,10 +447,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
.first()
)
if not segment:
raise NotFound("Segment not found.")
@@ -495,7 +494,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
payload = BatchImportPayload.model_validate(console_ns.payload or {})
upload_file_id = payload.upload_file_id
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1))
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
@@ -560,19 +559,19 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
.first()
)
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_dataset_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id)
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
@@ -617,10 +616,10 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
.first()
)
if not segment:
raise NotFound("Segment not found.")
@@ -667,10 +666,10 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
.first()
)
if not segment:
raise NotFound("Segment not found.")
@@ -715,24 +714,24 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = db.session.scalar(
select(ChildChunk)
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
.limit(1)
.first()
)
if not child_chunk:
raise NotFound("Child chunk not found.")
@@ -772,24 +771,24 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = db.session.scalar(
select(ChildChunk)
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
.limit(1)
.first()
)
if not child_chunk:
raise NotFound("Child chunk not found.")

View File

@@ -25,7 +25,7 @@ from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService
from services.knowledge_service import ExternalDatasetTestService
def _build_dataset_detail_model():
@@ -86,7 +86,7 @@ class ExternalHitTestingPayload(BaseModel):
class BedrockRetrievalPayload(BaseModel):
retrieval_setting: "BedrockRetrievalSetting"
retrieval_setting: dict[str, object]
query: str
knowledge_id: str

View File

@@ -19,12 +19,11 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
from dify_graph.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import current_user
from models.account import Account
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__)
@@ -32,7 +31,7 @@ logger = logging.getLogger(__name__)
class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: RetrievalModel | None = None
retrieval_model: dict[str, Any] | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None

View File

@@ -10,8 +10,8 @@ from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.plugin.impl.oauth import OAuthHandler
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService

View File

@@ -46,8 +46,6 @@ class PipelineTemplateDetailApi(Resource):
type = request.args.get("type", default="built-in", type=str)
rag_pipeline_service = RagPipelineService()
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
if pipeline_template is None:
return {"error": "Pipeline template not found from upstream service."}, 404
return pipeline_template, 200

View File

@@ -21,12 +21,11 @@ from controllers.console.app.workflow_draft_variable import (
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.file_access import DatabaseFileAccessController
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from graphon.variables.types import SegmentType
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
@@ -34,7 +33,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
def _create_pagination_parser():
@@ -225,21 +223,13 @@ class RagPipelineVariableApi(Resource):
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@@ -6,7 +6,7 @@ from flask import abort, request
from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
@@ -16,11 +16,7 @@ from controllers.console.app.error import (
DraftWorkflowNotExist,
DraftWorkflowNotSync,
)
from controllers.console.app.workflow import (
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
workflow_model,
workflow_pagination_model,
)
from controllers.console.app.workflow import workflow_model, workflow_pagination_model
from controllers.console.app.workflow_run import (
workflow_run_detail_model,
workflow_run_node_execution_list_model,
@@ -37,17 +33,16 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import variable_factory
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs import helper
from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
from models.model import EndUser
from models.workflow import Workflow
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService
@@ -208,12 +203,9 @@ class DraftRagPipelineApi(Resource):
abort(415)
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
rag_pipeline_service = RagPipelineService()
try:
environment_variables_list = Workflow.normalize_environment_variable_mappings(
payload.environment_variables or [],
)
environment_variables_list = payload.environment_variables or []
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
@@ -221,6 +213,7 @@ class DraftRagPipelineApi(Resource):
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline,
graph=payload.graph,
@@ -712,36 +705,6 @@ class PublishedAllRagPipelineApi(Resource):
}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>/restore")
class RagPipelineDraftWorkflowRestoreApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, workflow_id: str):
current_user, _ = current_account_with_tenant()
rag_pipeline_service = RagPipelineService()
try:
workflow = rag_pipeline_service.restore_published_workflow_to_draft(
pipeline=pipeline,
workflow_id=workflow_id,
account=current_user,
)
except IsDraftWorkflowError as exc:
# Use a stable, predefined message to keep the 400 response consistent
raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
except WorkflowNotFoundError as exc:
raise NotFound(str(exc)) from exc
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource):
@setup_required

View File

@@ -2,8 +2,6 @@ from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy import select
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_account_with_tenant
@@ -26,8 +24,10 @@ def get_rag_pipeline(view_func: Callable[P, R]):
del kwargs["pipeline_id"]
pipeline = db.session.scalar(
select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1)
pipeline = (
db.session.query(Pipeline)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
.first()
)
if not pipeline:

View File

@@ -19,7 +19,7 @@ from controllers.console.app.error import (
)
from controllers.console.explore.wraps import InstalledAppResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from graphon.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,

View File

@@ -1,11 +1,9 @@
from flask import request
from flask_restx import Resource
from sqlalchemy import select
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
from extensions.ext_database import db
from models.enums import BannerStatus
from models.model import ExporleBanner
@@ -18,18 +16,14 @@ class BannerApi(Resource):
language = request.args.get("language", "en-US")
# Build base query for enabled banners
base_query = select(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
# Try to get banners in the requested language
banners = db.session.scalars(
base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort)
).all()
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
# Fallback to en-US if no banners found and language is not en-US
if not banners and language != "en-US":
banners = db.session.scalars(
base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort)
).all()
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
# Convert banners to serializable format
result = []
for banner in banners:

View File

@@ -24,8 +24,8 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
from dify_graph.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user

View File

@@ -133,15 +133,13 @@ class InstalledAppsListApi(Resource):
def post(self):
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
recommended_app = db.session.scalar(
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).limit(1)
)
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
if recommended_app is None:
raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
app = db.session.get(App, payload.app_id)
app = db.session.query(App).where(App.id == payload.app_id).first()
if app is None:
raise NotFound("App entity not found")
@@ -149,10 +147,10 @@ class InstalledAppsListApi(Resource):
if not app.is_public:
raise Forbidden("You can't install a non-public app")
installed_app = db.session.scalar(
select(InstalledApp)
installed_app = (
db.session.query(InstalledApp)
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
.limit(1)
.first()
)
if installed_app is None:

View File

@@ -21,13 +21,12 @@ from controllers.console.explore.error import (
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from dify_graph.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.enums import FeedbackRating
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@@ -117,7 +116,7 @@ class MessageFeedbackApi(InstalledAppResource):
app_model=app_model,
message_id=message_id,
user=current_user,
rating=FeedbackRating(payload.rating) if payload.rating else None,
rating=payload.rating,
content=payload.content,
)
except MessageNotExistsError:

View File

@@ -4,7 +4,6 @@ from typing import Any, Literal, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
@@ -42,6 +41,8 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
from dify_graph.graph_engine.manager import GraphEngineManager
from dify_graph.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.app_fields import (
@@ -59,8 +60,6 @@ from fields.workflow_fields import (
workflow_fields,
workflow_partial_fields,
)
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
@@ -477,7 +476,7 @@ class TrialSitApi(Resource):
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
@@ -542,7 +541,13 @@ class AppWorkflowApi(Resource):
if not app_model.workflow_id:
raise AppUnavailableError()
workflow = db.session.get(Workflow, app_model.workflow_id)
workflow = (
db.session.query(Workflow)
.where(
Workflow.id == app_model.workflow_id,
)
.first()
)
return workflow

View File

@@ -21,9 +21,9 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
from dify_graph.graph_engine.manager import GraphEngineManager
from dify_graph.model_runtime.errors.invoke import InvokeError
from extensions.ext_redis import redis_client
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_account_with_tenant
from models.model import AppMode, InstalledApp

View File

@@ -4,7 +4,6 @@ from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_restx import Resource
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
@@ -25,10 +24,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
installed_app = db.session.scalar(
select(InstalledApp)
installed_app = (
db.session.query(InstalledApp)
.where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
.limit(1)
.first()
)
if installed_app is None:
@@ -79,7 +78,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
trial_app = db.session.scalar(select(TrialApp).where(TrialApp.app_id == str(app_id)).limit(1))
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
if trial_app is None:
raise TrialAppNotAllowed()
@@ -88,10 +87,10 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
if app is None:
raise TrialAppNotAllowed()
account_trial_app_record = db.session.scalar(
select(AccountTrialAppRecord)
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
.limit(1)
.first()
)
if account_trial_app_record:
if account_trial_app_record.count >= trial_app.trial_limit:

View File

@@ -13,9 +13,9 @@ from controllers.common.errors import (
)
from controllers.console import console_ns
from core.helper import ssrf_proxy
from dify_graph.file import helpers as file_helpers
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
from libs.login import current_account_with_tenant, login_required
from services.file_service import FileService

View File

@@ -2,7 +2,6 @@ from typing import Literal
from flask import request
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from configs import dify_config
from controllers.fastopenapi import console_router
@@ -101,6 +100,6 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
def get_setup_status() -> DifySetup | bool | None:
if dify_config.EDITION == "SELF_HOSTED":
return db.session.scalar(select(DifySetup).limit(1))
return db.session.query(DifySetup).first()
return True

View File

@@ -212,13 +212,13 @@ class AccountInitApi(Resource):
raise ValueError("invitation_code is required")
# check invitation code
invitation_code = db.session.scalar(
select(InvitationCode)
invitation_code = (
db.session.query(InvitationCode)
.where(
InvitationCode.code == args.invitation_code,
InvitationCode.status == InvitationCodeStatus.UNUSED,
)
.limit(1)
.first()
)
if not invitation_code:

View File

@@ -2,7 +2,7 @@ from flask_restx import Resource, fields
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from graphon.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from services.agent_service import AgentService

View File

@@ -8,7 +8,7 @@ from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.plugin.impl.exc import PluginPermissionDeniedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService

View File

@@ -5,8 +5,8 @@ from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_account_with_tenant, login_required
from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService

View File

@@ -171,7 +171,7 @@ class MemberCancelInviteApi(Resource):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id))
member = db.session.query(Account).where(Account.id == str(member_id)).first()
if member is None:
abort(404)
else:

View File

@@ -7,9 +7,9 @@ from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService

View File

@@ -8,9 +8,9 @@ from pydantic import BaseModel, Field, field_validator
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService
@@ -282,18 +282,14 @@ class ModelProviderModelCredentialApi(Resource):
)
if args.config_from == "predefined-model":
available_credentials = model_provider_service.get_provider_available_credentials(
tenant_id=tenant_id,
provider=provider,
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
tenant_id=tenant_id, provider_name=provider
)
else:
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.get_provider_model_available_credentials(
tenant_id=tenant_id,
provider=provider,
model_type=normalized_model_type,
model=args.model,
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
)
return jsonable_encoder(

View File

@@ -14,7 +14,7 @@ from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.plugin.impl.exc import PluginDaemonClientSideError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

View File

@@ -26,8 +26,8 @@ from core.mcp.mcp_client import MCPClient
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import alphanumeric, uuid_value
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import ToolProviderID

View File

@@ -14,8 +14,8 @@ from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.trigger.entities.entities import SubscriptionBuilderUpdater
from core.trigger.trigger_manager import TriggerManager
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_user, login_required
from models.account import Account
from models.provider_ids import TriggerProviderID

View File

@@ -7,7 +7,6 @@ from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
@@ -30,7 +29,6 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
@@ -110,29 +108,9 @@ class TenantListApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = []
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
is_saas = dify_config.EDITION == "CLOUD" and dify_config.BILLING_ENABLED
tenant_plans: dict[str, SubscriptionPlan] = {}
if is_saas:
tenant_ids = [tenant.id for tenant in tenants]
if tenant_ids:
tenant_plans = BillingService.get_plan_bulk(tenant_ids)
if not tenant_plans:
logger.warning("get_plan_bulk returned empty result, falling back to legacy feature path")
for tenant in tenants:
plan: str = CloudPlan.SANDBOX
if is_saas:
tenant_plan = tenant_plans.get(tenant.id)
if tenant_plan:
plan = tenant_plan["plan"] or CloudPlan.SANDBOX
else:
features = FeatureService.get_features(tenant.id)
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
elif not is_enterprise_only:
features = FeatureService.get_features(tenant.id)
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
features = FeatureService.get_features(tenant.id)
# Create a dictionary with tenant attributes
tenant_dict = {
@@ -140,7 +118,7 @@ class TenantListApi(Resource):
"name": tenant.name,
"status": tenant.status,
"created_at": tenant.created_at,
"plan": plan,
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
"current": tenant.id == current_tenant_id if current_tenant_id else False,
}
@@ -220,7 +198,7 @@ class SwitchWorkspaceApi(Resource):
except Exception:
raise AccountNotLinkTenantError("Account not link tenant")
new_tenant = db.session.get(Tenant, args.tenant_id) # Get new tenant
new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant
if new_tenant is None:
raise ValueError("Tenant not found")

View File

@@ -7,7 +7,6 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request
from sqlalchemy import select
from configs import dify_config
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
@@ -219,9 +218,13 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check setup
if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)):
if os.environ.get("INIT_PASSWORD"):
raise NotInitValidateError()
if (
dify_config.EDITION == "SELF_HOSTED"
and os.environ.get("INIT_PASSWORD")
and not db.session.query(DifySetup).first()
):
raise NotInitValidateError()
elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
raise NotSetupError()
return view(*args, **kwargs)

View File

@@ -70,25 +70,22 @@ class ToolFileApi(Resource):
except Exception:
raise UnsupportedFileTypeError()
mime_type = tool_file.mime_type
filename = tool_file.filename
response = Response(
stream,
mimetype=mime_type,
mimetype=tool_file.mimetype,
direct_passthrough=True,
headers={},
)
if tool_file.size > 0:
response.headers["Content-Length"] = str(tool_file.size)
if args.as_attachment and filename:
encoded_filename = quote(filename)
if args.as_attachment:
encoded_filename = quote(tool_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
enforce_download_for_html(
response,
mime_type=mime_type,
filename=filename,
mime_type=tool_file.mimetype,
filename=tool_file.name,
extension=extension,
)

View File

@@ -7,8 +7,8 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
import services
from core.tools.signature import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.file.helpers import verify_plugin_file_signature
from fields.file_fields import FileResponse
from ..common.errors import (

View File

@@ -16,14 +16,12 @@ api = ExternalApi(
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
from . import mail as _mail
from .app import dsl as _app_dsl
from .plugin import plugin as _plugin
from .workspace import workspace as _workspace
api.add_namespace(inner_api_ns)
__all__ = [
"_app_dsl",
"_mail",
"_plugin",
"_workspace",

View File

@@ -1 +0,0 @@

View File

@@ -1,110 +0,0 @@
"""Inner API endpoints for app DSL import/export.
Called by the enterprise admin-api service. Import requires ``creator_email``
to attribute the created app; workspace/membership validation is done by the
Go admin-api caller.
"""
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required
from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import enterprise_inner_api_only
from extensions.ext_database import db
from models import Account, App
from models.account import AccountStatus
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
class InnerAppDSLImportPayload(BaseModel):
yaml_content: str = Field(description="YAML DSL content")
creator_email: str = Field(description="Email of the workspace member who will own the imported app")
name: str | None = Field(default=None, description="Override app name from DSL")
description: str | None = Field(default=None, description="Override app description from DSL")
register_schema_model(inner_api_ns, InnerAppDSLImportPayload)
@inner_api_ns.route("/enterprise/workspaces/<string:workspace_id>/dsl/import")
class EnterpriseAppDSLImport(Resource):
@setup_required
@enterprise_inner_api_only
@inner_api_ns.doc("enterprise_app_dsl_import")
@inner_api_ns.expect(inner_api_ns.models[InnerAppDSLImportPayload.__name__])
@inner_api_ns.doc(
responses={
200: "Import completed",
202: "Import pending (DSL version mismatch requires confirmation)",
400: "Import failed (business error)",
404: "Creator account not found or inactive",
}
)
def post(self, workspace_id: str):
"""Import a DSL into a workspace on behalf of a specified creator."""
args = InnerAppDSLImportPayload.model_validate(inner_api_ns.payload or {})
account = _get_active_account(args.creator_email)
if account is None:
return {"message": f"account '{args.creator_email}' not found or inactive"}, 404
account.set_tenant_id(workspace_id)
with Session(db.engine) as session:
dsl_service = AppDslService(session)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=args.yaml_content,
name=args.name,
description=args.description,
)
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
if result.status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@inner_api_ns.route("/enterprise/apps/<string:app_id>/dsl")
class EnterpriseAppDSLExport(Resource):
@setup_required
@enterprise_inner_api_only
@inner_api_ns.doc(
"enterprise_app_dsl_export",
responses={
200: "Export successful",
404: "App not found",
},
)
def get(self, app_id: str):
"""Export an app's DSL as YAML."""
include_secret = request.args.get("include_secret", "false").lower() == "true"
app_model = db.session.query(App).filter_by(id=app_id).first()
if not app_model:
return {"message": "app not found"}, 404
data = AppDslService.export_dsl(
app_model=app_model,
include_secret=include_secret,
)
return {"data": data}, 200
def _get_active_account(email: str) -> Account | None:
"""Look up an active account by email.
Workspace membership is already validated by the Go admin-api caller.
"""
account = db.session.query(Account).filter_by(email=email).first()
if account is None or account.status != AccountStatus.ACTIVE:
return None
return account

View File

@@ -28,8 +28,8 @@ from core.plugin.entities.request import (
RequestRequestUploadFile,
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.signature import get_signed_file_url_for_plugin
from graphon.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.file.helpers import get_signed_file_url_for_plugin
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import length_prefixed_response
from models import Account, Tenant
from models.model import EndUser

View File

@@ -5,7 +5,6 @@ from typing import ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
@@ -37,16 +36,23 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
user_model = None
if is_anonymous:
user_model = session.scalar(
select(EndUser)
user_model = (
session.query(EndUser)
.where(
EndUser.session_id == user_id,
EndUser.tenant_id == tenant_id,
)
.limit(1)
.first()
)
else:
user_model = session.get(EndUser, user_id)
user_model = (
session.query(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
if not user_model:
user_model = EndUser(
@@ -79,7 +85,16 @@ def get_user_tenant(view_func: Callable[P, R]):
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
tenant_model = db.session.get(Tenant, tenant_id)
try:
tenant_model = (
db.session.query(Tenant)
.where(
Tenant.id == tenant_id,
)
.first()
)
except Exception:
raise ValueError("tenant not found")
if not tenant_model:
raise ValueError("tenant not found")

View File

@@ -2,7 +2,6 @@ import json
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from controllers.common.schema import register_schema_models
from controllers.console.wraps import setup_required
@@ -43,7 +42,7 @@ class EnterpriseWorkspace(Resource):
def post(self):
args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {})
account = db.session.scalar(select(Account).where(Account.email == args.owner_email).limit(1))
account = db.session.query(Account).filter_by(email=args.owner_email).first()
if account is None:
return {"message": "owner account not found."}, 404

View File

@@ -75,7 +75,7 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]):
if signature_base64 != token:
return view(*args, **kwargs)
kwargs["user"] = db.session.get(EndUser, user_id)
kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first()
return view(*args, **kwargs)

View File

@@ -9,8 +9,8 @@ from controllers.common.schema import register_schema_model
from controllers.mcp import mcp_ns
from core.mcp import types as mcp_types
from core.mcp.server.streamable_http import handle_mcp_request
from dify_graph.variables.input_entities import VariableEntity
from extensions.ext_database import db
from graphon.variables.input_entities import VariableEntity
from libs import helper
from models.enums import AppMCPServerStatus
from models.model import App, AppMCPServer, AppMode, EndUser

View File

@@ -21,7 +21,7 @@ from controllers.service_api.app.error import (
)
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from graphon.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.errors.invoke import InvokeError
from models.model import App, EndUser
from services.audio_service import AudioService
from services.errors.audio import (

View File

@@ -28,7 +28,7 @@ from core.errors.error import (
QuotaExceededError,
)
from core.helper.trace_id_helper import get_external_trace_id
from graphon.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser

View File

@@ -4,7 +4,6 @@ from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from controllers.common.file_response import enforce_download_for_html
from controllers.common.schema import register_schema_model
@@ -103,27 +102,27 @@ class FilePreviewApi(Resource):
raise FileAccessDeniedError("Invalid file or app identifier")
# First, find the MessageFile that references this upload file
message_file = db.session.scalar(select(MessageFile).where(MessageFile.upload_file_id == file_id).limit(1))
message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first()
if not message_file:
raise FileNotFoundError("File not found in message context")
# Get the message and verify it belongs to the requesting app
message = db.session.scalar(
select(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).limit(1)
message = (
db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first()
)
if not message:
raise FileAccessDeniedError("File access denied: not owned by requesting app")
# Get the actual upload file record
upload_file = db.session.get(UploadFile, file_id)
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise FileNotFoundError("Upload file record not found")
# Additional security: verify tenant isolation
app = db.session.get(App, app_id)
app = db.session.query(App).where(App.id == app_id).first()
if app and upload_file.tenant_id != app.tenant_id:
raise FileAccessDeniedError("File access denied: tenant mismatch")

View File

@@ -15,7 +15,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.enums import FeedbackRating
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@@ -117,7 +116,7 @@ class MessageFeedbackApi(Resource):
app_model=app_model,
message_id=message_id,
user=end_user,
rating=FeedbackRating(payload.rating) if payload.rating else None,
rating=payload.rating,
content=payload.content,
)
except MessageNotExistsError:

View File

@@ -1,5 +1,4 @@
from flask_restx import Resource
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from controllers.common.fields import Site as SiteResponse
@@ -29,7 +28,7 @@ class AppSiteApi(Resource):
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()

View File

@@ -27,12 +27,12 @@ from core.errors.error import (
QuotaExceededError,
)
from core.helper.trace_id_helper import get_external_trace_id
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.graph_engine.manager import GraphEngineManager
from dify_graph.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from graphon.enums import WorkflowExecutionStatus
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import OptionalTimestampField, TimestampField
from models.model import App, AppMode, EndUser

View File

@@ -14,11 +14,10 @@ from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
)
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.model_entities import ModelType
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import DataSetTag
from graphon.model_runtime.entities.model_entities import ModelType
from libs.login import current_user
from models.account import Account
from models.dataset import DatasetPermissionEnum
@@ -140,10 +139,10 @@ class DatasetListApi(DatasetApiResource):
query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all
)
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
provider_manager = create_plugin_provider_manager(tenant_id=cid)
configurations = provider_manager.get_configurations(tenant_id=cid)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -154,20 +153,15 @@ class DatasetListApi(DatasetApiResource):
data = marshal(datasets, dataset_detail_fields)
for item in data:
if (
item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index]
and item["embedding_model_provider"] # pyrefly: ignore[bad-index]
):
item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation]
ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index]
)
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index]
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item["embedding_available"] = True # type: ignore
item["embedding_available"] = True
else:
item["embedding_available"] = False # type: ignore
item["embedding_available"] = False
else:
item["embedding_available"] = True # type: ignore
item["embedding_available"] = True
response = {
"data": data,
"has_more": len(datasets) == query.limit,
@@ -259,10 +253,10 @@ class DatasetApi(DatasetApiResource):
raise Forbidden(str(e))
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
# check embedding setting
provider_manager = ProviderManager()
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
provider_manager = create_plugin_provider_manager(tenant_id=cid)
configurations = provider_manager.get_configurations(tenant_id=cid)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -271,7 +265,7 @@ class DatasetApi(DatasetApiResource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY:
if data.get("indexing_technique") == "high_quality":
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
if item_model in model_names:
data["embedding_available"] = True
@@ -321,7 +315,7 @@ class DatasetApi(DatasetApiResource):
# check embedding model setting
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider:
if payload.indexing_technique == "high_quality" or embedding_model_provider:
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(
dataset.tenant_id, embedding_model_provider, embedding_model

View File

@@ -36,7 +36,6 @@ from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
KnowledgeConfig,
@@ -623,15 +622,13 @@ class DocumentIndexingStatusApi(DatasetApiResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields

Some files were not shown because too many files have changed in this diff Show More