Mirror of https://github.com/langgenius/dify.git (synced 2026-02-19 15:34:00 +00:00)

Compare commits: chore/refa ... refactor/w (53 commits)
| Author | SHA1 | Date |
|---|---|---|
| | c0ad6e04d9 | |
| | 8d9fbbe595 | |
| | 5362f69083 | |
| | 822374eca5 | |
| | 815ae6c754 | |
| | 9a22baf57d | |
| | c1bb310183 | |
| | 8f2aabf7bd | |
| | 3da975c745 | |
| | bde4b82695 | |
| | 6d7026255e | |
| | 9b6b2f3195 | |
| | ae43ad5cb6 | |
| | 5b02e5dcb6 | |
| | e3ef33366d | |
| | ee1d0df927 | |
| | 184077c37c | |
| | 3015e9be73 | |
| | 2bb1e24fb4 | |
| | cad7101534 | |
| | e856287b65 | |
| | 27be89c984 | |
| | fa69cce1e7 | |
| | f28a08a696 | |
| | 8129b04143 | |
| | 1b8e80a722 | |
| | 0421387672 | |
| | 2aaaa4bd34 | |
| | 64dc98e607 | |
| | 9007109a6b | |
| | 925168383b | |
| | e6f3528bb0 | |
| | fb5edd0bf6 | |
| | de53c78125 | |
| | 3a59ae9617 | |
| | 69589807fd | |
| | 6ca44eea28 | |
| | bf76f10653 | |
| | c1af6a7127 | |
| | 1873b5a766 | |
| | 9fbc7fa379 | |
| | 2399d00d86 | |
| | 3505516e8e | |
| | faef04cdf7 | |
| | 0ba9b9e6b5 | |
| | 30dd50ff83 | |
| | 5338cf85b1 | |
| | 673209d086 | |
| | 43758ec85d | |
| | 20944e7e1a | |
| | 7a5d2728a1 | |
| | 14bff10201 | |
| | 9a6b4147bc | |
@@ -28,17 +28,14 @@ import userEvent from '@testing-library/user-event'

// i18n (automatically mocked)
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
// No explicit mock needed - it returns translation keys as-is
// The global mock provides: useTranslation, Trans, useMixedTranslation, useGetLanguage
// No explicit mock needed for most tests
//
// Override only if custom translations are required:
// vi.mock('react-i18next', () => ({
//   useTranslation: () => ({
//     t: (key: string) => {
//       const customTranslations: Record<string, string> = {
//         'my.custom.key': 'Custom Translation',
//       }
//       return customTranslations[key] || key
//     },
//   }),
// import { createReactI18nextMock } from '@/test/i18n-mock'
// vi.mock('react-i18next', () => createReactI18nextMock({
//   'my.custom.key': 'Custom Translation',
//   'button.save': 'Save',
// }))

// Router (if component uses useRouter, usePathname, useSearchParams)

@@ -52,23 +52,29 @@ Modules are not mocked automatically. Use `vi.mock` in test files, or add global

### 1. i18n (Auto-loaded via Global Mock)

A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
**No explicit mock needed** for most tests - it returns translation keys as-is.

For tests requiring custom translations, override the mock:
The global mock provides:

- `useTranslation` - returns translation keys with namespace prefix
- `Trans` component - renders i18nKey and components
- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`)
- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'`

**Default behavior**: Most tests should use the global mock (no local override needed).

**For custom translations**: Use the helper function from `@/test/i18n-mock`:

```typescript
vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => {
      const translations: Record<string, string> = {
        'my.custom.key': 'Custom translation',
      }
      return translations[key] || key
    },
  }),
import { createReactI18nextMock } from '@/test/i18n-mock'

vi.mock('react-i18next', () => createReactI18nextMock({
  'my.custom.key': 'Custom translation',
  'button.save': 'Save',
}))
```

**Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this.

### 2. Next.js Router

```typescript
10 .github/workflows/style.yml (vendored)
@@ -110,6 +110,16 @@ jobs:
          working-directory: ./web
          run: pnpm run type-check:tsgo

      - name: Web dead code check
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm run knip

      - name: Web build check
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm run build

  superlinter:
    name: SuperLinter
    runs-on: ubuntu-latest

@@ -65,7 +65,7 @@ jobs:
      - name: Generate i18n translations
        if: env.FILES_CHANGED == 'true'
        working-directory: ./web
        run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
        run: pnpm run i18n:gen ${{ env.FILE_ARGS }}

      - name: Create Pull Request
        if: env.FILES_CHANGED == 'true'
@@ -101,6 +101,15 @@ S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region

# Workflow run and Conversation archive storage (S3-compatible)
ARCHIVE_STORAGE_ENABLED=false
ARCHIVE_STORAGE_ENDPOINT=
ARCHIVE_STORAGE_ARCHIVE_BUCKET=
ARCHIVE_STORAGE_EXPORT_BUCKET=
ARCHIVE_STORAGE_ACCESS_KEY=
ARCHIVE_STORAGE_SECRET_KEY=
ARCHIVE_STORAGE_REGION=auto

# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key

@@ -493,6 +502,8 @@ LOG_FILE_BACKUP_COUNT=5
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Log output format: text or json
LOG_OUTPUT_FORMAT=text
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
@@ -1,4 +1,8 @@
exclude = ["migrations/*"]
exclude = [
    "migrations/*",
    ".git",
    ".git/**",
]
line-length = 120

[format]
@@ -2,9 +2,11 @@ import logging
import time

from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID

from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from core.logging.context import init_request_context
from dify_app import DifyApp

logger = logging.getLogger(__name__)

@@ -25,28 +27,35 @@ def create_flask_app_with_configs() -> DifyApp:
    # add before request hook
    @dify_app.before_request
    def before_request():
        # add an unique identifier to each request
        # Initialize logging context for this request
        init_request_context()
        RecyclableContextVar.increment_thread_recycles()

    # add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
    # add after request hook for injecting trace headers from OpenTelemetry span context
    # Only adds headers when OTEL is enabled and has valid context
    @dify_app.after_request
    def add_trace_id_header(response):
    def add_trace_headers(response):
        try:
            span = get_current_span()
            ctx = span.get_span_context() if span else None
            if ctx and ctx.is_valid:
                trace_id_hex = format(ctx.trace_id, "032x")
                # Avoid duplicates if some middleware added it
                if "X-Trace-Id" not in response.headers:
                    response.headers["X-Trace-Id"] = trace_id_hex

            if not ctx or not ctx.is_valid:
                return response

            # Inject trace headers from OTEL context
            if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
                response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
            if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
                response.headers["X-Span-Id"] = format(ctx.span_id, "016x")

        except Exception:
            # Never break the response due to tracing header injection
            logger.warning("Failed to add trace ID to response header", exc_info=True)
            logger.warning("Failed to add trace headers to response", exc_info=True)
        return response

    # Capture the decorator's return value to avoid pyright reportUnusedFunction
    _ = before_request
    _ = add_trace_id_header
    _ = add_trace_headers

    return dify_app
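A minimal sketch of how the injected headers could be observed with Flask's test client. The `app_factory` import path and the `/health` route are assumptions, not taken from the diff, and the headers only appear when OTEL is enabled with a valid active span.

```python
# Sketch only: import path and endpoint are assumptions.
from app_factory import create_flask_app_with_configs  # assumed module name

app = create_flask_app_with_configs()
with app.test_client() as client:
    resp = client.get("/health")  # hypothetical endpoint
    # With a valid OTEL span context, both headers carry hex-encoded IDs;
    # otherwise they are simply absent.
    print(resp.headers.get("X-Trace-Id"))  # 32 hex chars or None
    print(resp.headers.get("X-Span-Id"))   # 16 hex chars or None
```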
@@ -1,9 +1,11 @@
from configs.extra.archive_config import ArchiveStorageConfig
from configs.extra.notion_config import NotionConfig
from configs.extra.sentry_config import SentryConfig


class ExtraServiceConfig(
    # place the configs in alphabet order
    ArchiveStorageConfig,
    NotionConfig,
    SentryConfig,
):
43 api/configs/extra/archive_config.py (new file)
@@ -0,0 +1,43 @@
from pydantic import Field
from pydantic_settings import BaseSettings


class ArchiveStorageConfig(BaseSettings):
    """
    Configuration settings for workflow run logs archiving storage.
    """

    ARCHIVE_STORAGE_ENABLED: bool = Field(
        description="Enable workflow run logs archiving to S3-compatible storage",
        default=False,
    )

    ARCHIVE_STORAGE_ENDPOINT: str | None = Field(
        description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')",
        default=None,
    )

    ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field(
        description="Name of the bucket to store archived workflow logs",
        default=None,
    )

    ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field(
        description="Name of the bucket to store exported workflow runs",
        default=None,
    )

    ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field(
        description="Access key ID for authenticating with storage",
        default=None,
    )

    ARCHIVE_STORAGE_SECRET_KEY: str | None = Field(
        description="Secret access key for authenticating with storage",
        default=None,
    )

    ARCHIVE_STORAGE_REGION: str = Field(
        description="Region for storage (use 'auto' if the provider supports it)",
        default="auto",
    )
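A small usage sketch, relying only on the pydantic-settings behaviour shown above: the ARCHIVE_STORAGE_* variables from .env are read straight into the new config class. The endpoint value is illustrative.

```python
# Sketch: exercise the new settings class against environment variables.
import os

from configs.extra.archive_config import ArchiveStorageConfig

os.environ["ARCHIVE_STORAGE_ENABLED"] = "true"
os.environ["ARCHIVE_STORAGE_ENDPOINT"] = "https://storage.example.com"

config = ArchiveStorageConfig()
print(config.ARCHIVE_STORAGE_ENABLED)  # True, parsed from the env string
print(config.ARCHIVE_STORAGE_REGION)   # "auto", the declared default
```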
@@ -587,6 +587,11 @@ class LoggingConfig(BaseSettings):
        default="INFO",
    )

    LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
        description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
        default="text",
    )

    LOG_FILE: str | None = Field(
        description="File path for log output.",
        default=None,
@@ -1,62 +1,59 @@
from flask_restx import Api, Namespace, fields
from __future__ import annotations

from libs.helper import AppIconUrlField
from typing import Any, TypeAlias

parameters__system_parameters = {
    "image_file_size_limit": fields.Integer,
    "video_file_size_limit": fields.Integer,
    "audio_file_size_limit": fields.Integer,
    "file_size_limit": fields.Integer,
    "workflow_file_upload_limit": fields.Integer,
}
from pydantic import BaseModel, ConfigDict, computed_field

from core.file import helpers as file_helpers
from models.model import IconType

JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
JSONObject: TypeAlias = dict[str, Any]


def build_system_parameters_model(api_or_ns: Api | Namespace):
    """Build the system parameters model for the API or Namespace."""
    return api_or_ns.model("SystemParameters", parameters__system_parameters)
class SystemParameters(BaseModel):
    image_file_size_limit: int
    video_file_size_limit: int
    audio_file_size_limit: int
    file_size_limit: int
    workflow_file_upload_limit: int


parameters_fields = {
    "opening_statement": fields.String,
    "suggested_questions": fields.Raw,
    "suggested_questions_after_answer": fields.Raw,
    "speech_to_text": fields.Raw,
    "text_to_speech": fields.Raw,
    "retriever_resource": fields.Raw,
    "annotation_reply": fields.Raw,
    "more_like_this": fields.Raw,
    "user_input_form": fields.Raw,
    "sensitive_word_avoidance": fields.Raw,
    "file_upload": fields.Raw,
    "system_parameters": fields.Nested(parameters__system_parameters),
}
class Parameters(BaseModel):
    opening_statement: str | None = None
    suggested_questions: list[str]
    suggested_questions_after_answer: JSONObject
    speech_to_text: JSONObject
    text_to_speech: JSONObject
    retriever_resource: JSONObject
    annotation_reply: JSONObject
    more_like_this: JSONObject
    user_input_form: list[JSONObject]
    sensitive_word_avoidance: JSONObject
    file_upload: JSONObject
    system_parameters: SystemParameters


def build_parameters_model(api_or_ns: Api | Namespace):
    """Build the parameters model for the API or Namespace."""
    copied_fields = parameters_fields.copy()
    copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
    return api_or_ns.model("Parameters", copied_fields)
class Site(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    title: str
    chat_color_theme: str | None = None
    chat_color_theme_inverted: bool
    icon_type: str | None = None
    icon: str | None = None
    icon_background: str | None = None
    description: str | None = None
    copyright: str | None = None
    privacy_policy: str | None = None
    custom_disclaimer: str | None = None
    default_language: str
    show_workflow_steps: bool
    use_icon_as_answer_icon: bool

site_fields = {
    "title": fields.String,
    "chat_color_theme": fields.String,
    "chat_color_theme_inverted": fields.Boolean,
    "icon_type": fields.String,
    "icon": fields.String,
    "icon_background": fields.String,
    "icon_url": AppIconUrlField,
    "description": fields.String,
    "copyright": fields.String,
    "privacy_policy": fields.String,
    "custom_disclaimer": fields.String,
    "default_language": fields.String,
    "show_workflow_steps": fields.Boolean,
    "use_icon_as_answer_icon": fields.Boolean,
}


def build_site_model(api_or_ns: Api | Namespace):
    """Build the site model for the API or Namespace."""
    return api_or_ns.model("Site", site_fields)
    @computed_field(return_type=str | None)  # type: ignore
    @property
    def icon_url(self) -> str | None:
        if self.icon and self.icon_type == IconType.IMAGE:
            return file_helpers.get_signed_file_url(self.icon)
        return None
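A short sketch of the new serialization path that replaces the flask_restx marshalling: validate a plain dict (or an ORM object, thanks to `from_attributes` on `Site`) and dump JSON-ready data. The import path matches the later controller changes in this compare; the field values below are purely illustrative.

```python
# Sketch: field values are made up; only the model and method calls come from the diff.
from controllers.common.fields import Parameters

raw = {
    "opening_statement": "Hello",
    "suggested_questions": ["What can you do?"],
    "suggested_questions_after_answer": {"enabled": False},
    "speech_to_text": {"enabled": False},
    "text_to_speech": {"enabled": False},
    "retriever_resource": {"enabled": False},
    "annotation_reply": {"enabled": False},
    "more_like_this": {"enabled": False},
    "user_input_form": [],
    "sensitive_word_avoidance": {"enabled": False},
    "file_upload": {"image": {"enabled": False}},
    "system_parameters": {
        "image_file_size_limit": 10,
        "video_file_size_limit": 100,
        "audio_file_size_limit": 50,
        "file_size_limit": 15,
        "workflow_file_upload_limit": 10,
    },
}
payload = Parameters.model_validate(raw).model_dump(mode="json")
```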
@@ -1,3 +1,4 @@
import re
import uuid
from typing import Literal

@@ -73,6 +74,48 @@ class AppListQuery(BaseModel):
            raise ValueError("Invalid UUID format in tag_ids.") from exc


# XSS prevention: patterns that could lead to XSS attacks
# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
_XSS_PATTERNS = [
    r"<script[^>]*>.*?</script>",  # Script tags
    r"<iframe\b[^>]*?(?:/>|>.*?</iframe>)",  # Iframe tags (including self-closing)
    r"javascript:",  # JavaScript protocol
    r"<svg[^>]*?\s+onload\s*=[^>]*>",  # SVG with onload handler (attribute-aware, flexible whitespace)
    r"<.*?on\s*\w+\s*=",  # Event handlers like onclick, onerror, etc.
    r"<object\b[^>]*(?:\s*/>|>.*?</object\s*>)",  # Object tags (opening tag)
    r"<embed[^>]*>",  # Embed tags (self-closing)
    r"<link[^>]*>",  # Link tags with javascript
]


def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
    """
    Validate that a string value doesn't contain potential XSS payloads.

    Args:
        value: The string value to validate
        field_name: Name of the field for error messages

    Returns:
        The original value if safe

    Raises:
        ValueError: If the value contains XSS patterns
    """
    if value is None:
        return None

    value_lower = value.lower()
    for pattern in _XSS_PATTERNS:
        if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
            raise ValueError(
                f"{field_name} contains invalid characters or patterns. "
                "HTML tags, JavaScript, and other potentially dangerous content are not allowed."
            )

    return value


class CreateAppPayload(BaseModel):
    name: str = Field(..., min_length=1, description="App name")
    description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)

@@ -81,6 +124,11 @@ class CreateAppPayload(BaseModel):
    icon: str | None = Field(default=None, description="Icon")
    icon_background: str | None = Field(default=None, description="Icon background color")

    @field_validator("name", "description", mode="before")
    @classmethod
    def validate_xss_safe(cls, value: str | None, info) -> str | None:
        return _validate_xss_safe(value, info.field_name)


class UpdateAppPayload(BaseModel):
    name: str = Field(..., min_length=1, description="App name")

@@ -91,6 +139,11 @@ class UpdateAppPayload(BaseModel):
    use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
    max_active_requests: int | None = Field(default=None, description="Maximum active requests")

    @field_validator("name", "description", mode="before")
    @classmethod
    def validate_xss_safe(cls, value: str | None, info) -> str | None:
        return _validate_xss_safe(value, info.field_name)


class CopyAppPayload(BaseModel):
    name: str | None = Field(default=None, description="Name for the copied app")

@@ -99,6 +152,11 @@ class CopyAppPayload(BaseModel):
    icon: str | None = Field(default=None, description="Icon")
    icon_background: str | None = Field(default=None, description="Icon background color")

    @field_validator("name", "description", mode="before")
    @classmethod
    def validate_xss_safe(cls, value: str | None, info) -> str | None:
        return _validate_xss_safe(value, info.field_name)


class AppExportQuery(BaseModel):
    include_secret: bool = Field(default=False, description="Include secrets in export")
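A hedged usage sketch of the validator above. The module path is not visible in the hunk headers, so the import below is an assumption; only the helper's behaviour is taken from the code shown.

```python
# Assumed import location - the diff does not show the module path.
from controllers.console.app.app import _validate_xss_safe  # hypothetical path

print(_validate_xss_safe("My harmless app name", "name"))  # returned unchanged

try:
    _validate_xss_safe("<script>alert(1)</script>", "name")
except ValueError as exc:
    print(exc)  # "name contains invalid characters or patterns. ..."
```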
@@ -124,7 +124,7 @@ class OAuthCallback(Resource):
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")

        try:
            account = _generate_account(provider, user_info)
            account, oauth_new_user = _generate_account(provider, user_info)
        except AccountNotFoundError:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
        except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):

@@ -159,7 +159,10 @@ class OAuthCallback(Resource):
            ip_address=extract_remote_ip(request),
        )

        response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
        base_url = dify_config.CONSOLE_WEB_URL
        query_char = "&" if "?" in base_url else "?"
        target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
        response = redirect(target_url)

        set_access_token_to_cookie(request, response, token_pair.access_token)
        set_refresh_token_to_cookie(request, response, token_pair.refresh_token)

@@ -177,9 +180,10 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
    return account


def _generate_account(provider: str, user_info: OAuthUserInfo):
def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
    # Get account by openid or email.
    account = _get_account_by_openid_or_email(provider, user_info)
    oauth_new_user = False

    if account:
        tenants = TenantService.get_join_tenants(account)

@@ -193,6 +197,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
            tenant_was_created.send(new_tenant)

    if not account:
        oauth_new_user = True
        if not FeatureService.get_system_features().is_allow_register:
            if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
                raise AccountRegisterError(

@@ -220,4 +225,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
    # Link account
    AccountService.link_account_integrate(provider, user_info.id, account)

    return account
    return account, oauth_new_user
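The redirect-building rule from the hunk above, restated as a standalone helper for clarity. The function name is mine; the logic is taken directly from the diff.

```python
def build_oauth_redirect(base_url: str, oauth_new_user: bool) -> str:
    # Append with "?" or "&" depending on whether the console URL already has a query string.
    query_char = "&" if "?" in base_url else "?"
    return f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"


print(build_oauth_redirect("https://console.example.com", True))
# -> https://console.example.com?oauth_new_user=true
print(build_oauth_redirect("https://console.example.com/?lang=en", False))
# -> https://console.example.com/?lang=en&oauth_new_user=false
```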
@@ -3,10 +3,12 @@ import uuid
from flask import request
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy import String, cast, func, or_, select
from sqlalchemy.dialects.postgresql import JSONB
from werkzeug.exceptions import Forbidden, NotFound

import services
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError

@@ -143,7 +145,29 @@ class DatasetDocumentSegmentListApi(Resource):
            query = query.where(DocumentSegment.hit_count >= hit_count_gte)

        if keyword:
            query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
            # Search in both content and keywords fields
            # Use database-specific methods for JSON array search
            if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
                # PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
                keywords_condition = func.array_to_string(
                    func.array(
                        select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
                        .correlate(DocumentSegment)
                        .scalar_subquery()
                    ),
                    ",",
                ).ilike(f"%{keyword}%")
            else:
                # MySQL: Cast JSON to string for pattern matching
                # MySQL stores Chinese text directly in JSON without Unicode escaping
                keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")

            query = query.where(
                or_(
                    DocumentSegment.content.ilike(f"%{keyword}%"),
                    keywords_condition,
                )
            )

        if args.enabled.lower() != "all":
            if args.enabled.lower() == "true":
@@ -1,5 +1,3 @@
from flask_restx import marshal_with

from controllers.common import fields
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError

@@ -13,7 +11,6 @@ from services.app_service import AppService
class AppParameterApi(InstalledAppResource):
    """Resource for app variables."""

    @marshal_with(fields.parameters_fields)
    def get(self, installed_app: InstalledApp):
        """Retrieve app parameters."""
        app_model = installed_app.app

@@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource):

        user_input_form = features_dict.get("user_input_form", [])

        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
        parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
        return fields.Parameters.model_validate(parameters).model_dump(mode="json")


@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
@@ -20,7 +20,6 @@ from controllers.console.wraps import (
)
from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient

@@ -987,9 +986,6 @@ class ToolProviderMCPApi(Resource):
            # Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
            logger.warning("Failed to fetch MCP tools after creation", exc_info=True)

        # Final cache invalidation to ensure list views are up to date
        ToolProviderListCache.invalidate_cache(tenant_id)

        return jsonable_encoder(result)

    @console_ns.expect(parser_mcp_put)

@@ -1036,9 +1032,6 @@ class ToolProviderMCPApi(Resource):
                validation_result=validation_result,
            )

        # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
        ToolProviderListCache.invalidate_cache(current_tenant_id)

        return {"result": "success"}

    @console_ns.expect(parser_mcp_delete)

@@ -1053,9 +1046,6 @@ class ToolProviderMCPApi(Resource):
            service = MCPToolManageService(session=session)
            service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])

        # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
        ToolProviderListCache.invalidate_cache(current_tenant_id)

        return {"result": "success"}


@@ -1106,8 +1096,6 @@ class ToolMCPAuthApi(Resource):
                    credentials=provider_entity.credentials,
                    authed=True,
                )
                # Invalidate cache after updating credentials
                ToolProviderListCache.invalidate_cache(tenant_id)
                return {"result": "success"}
        except MCPAuthError as e:
            try:

@@ -1121,22 +1109,16 @@ class ToolMCPAuthApi(Resource):
                with Session(db.engine) as session, session.begin():
                    service = MCPToolManageService(session=session)
                    response = service.execute_auth_actions(auth_result)
                # Invalidate cache after auth actions may have updated provider state
                ToolProviderListCache.invalidate_cache(tenant_id)
                return response
            except MCPRefreshTokenError as e:
                with Session(db.engine) as session, session.begin():
                    service = MCPToolManageService(session=session)
                    service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
                # Invalidate cache after clearing credentials
                ToolProviderListCache.invalidate_cache(tenant_id)
                raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
        except (MCPError, ValueError) as e:
            with Session(db.engine) as session, session.begin():
                service = MCPToolManageService(session=session)
                service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
            # Invalidate cache after clearing credentials
            ToolProviderListCache.invalidate_cache(tenant_id)
            raise ValueError(f"Failed to connect to MCP server: {e}") from e
@@ -1,7 +1,7 @@
from typing import Literal

from flask import request
from flask_restx import Api, Namespace, Resource, fields
from flask_restx import Namespace, Resource, fields
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field

@@ -92,7 +92,7 @@ annotation_list_fields = {
}


def build_annotation_list_model(api_or_ns: Api | Namespace):
def build_annotation_list_model(api_or_ns: Namespace):
    """Build the annotation list model for the API or Namespace."""
    copied_annotation_list_fields = annotation_list_fields.copy()
    copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
@@ -1,6 +1,6 @@
from flask_restx import Resource

from controllers.common.fields import build_parameters_model
from controllers.common.fields import Parameters
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token

@@ -23,7 +23,6 @@ class AppParameterApi(Resource):
        }
    )
    @validate_app_token
    @service_api_ns.marshal_with(build_parameters_model(service_api_ns))
    def get(self, app_model: App):
        """Retrieve app parameters.

@@ -45,7 +44,8 @@ class AppParameterApi(Resource):

        user_input_form = features_dict.get("user_input_form", [])

        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
        parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
        return Parameters.model_validate(parameters).model_dump(mode="json")


@service_api_ns.route("/meta")
@@ -1,7 +1,7 @@
from flask_restx import Resource
from werkzeug.exceptions import Forbidden

from controllers.common.fields import build_site_model
from controllers.common.fields import Site as SiteResponse
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db

@@ -23,7 +23,6 @@ class AppSiteApi(Resource):
        }
    )
    @validate_app_token
    @service_api_ns.marshal_with(build_site_model(service_api_ns))
    def get(self, app_model: App):
        """Retrieve app site info.

@@ -38,4 +37,4 @@ class AppSiteApi(Resource):
        if app_model.tenant.status == TenantStatus.ARCHIVE:
            raise Forbidden()

        return site
        return SiteResponse.model_validate(site).model_dump(mode="json")
@@ -3,7 +3,7 @@ from typing import Any, Literal

from dateutil.parser import isoparse
from flask import request
from flask_restx import Api, Namespace, Resource, fields
from flask_restx import Namespace, Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound

@@ -78,7 +78,7 @@ workflow_run_fields = {
}


def build_workflow_run_model(api_or_ns: Api | Namespace):
def build_workflow_run_model(api_or_ns: Namespace):
    """Build the workflow run model for the API or Namespace."""
    return api_or_ns.model("WorkflowRun", workflow_run_fields)
@@ -1,7 +1,7 @@
import logging

from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import Unauthorized

@@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource):
            500: "Internal Server Error",
        }
    )
    @marshal_with(fields.parameters_fields)
    def get(self, app_model: App, end_user):
        """Retrieve app parameters."""
        if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:

@@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource):

        user_input_form = features_dict.get("user_input_form", [])

        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
        parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
        return fields.Parameters.model_validate(parameters).model_dump(mode="json")


@web_ns.route("/meta")
@@ -22,6 +22,7 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message

logger = logging.getLogger(__name__)

@@ -165,6 +166,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
            self._agent_scratchpad.append(scratchpad)

            # Check if max iteration is reached and model still wants to call tools
            if iteration_step == max_iteration_steps and scratchpad.action:
                if scratchpad.action.action_name.lower() != "final answer":
                    raise AgentMaxIterationError(app_config.agent.max_iteration)

            # get llm usage
            if "usage" in usage_dict:
                if usage_dict["usage"] is not None:
@@ -25,6 +25,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message

logger = logging.getLogger(__name__)

@@ -222,6 +223,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):

            final_answer += response + "\n"

            # Check if max iteration is reached and model still wants to call tools
            if iteration_step == max_iteration_steps and tool_calls:
                raise AgentMaxIterationError(app_config.agent.max_iteration)

            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
@@ -30,7 +30,6 @@ class SimpleModelProviderEntity(BaseModel):
    label: I18nObject
    icon_small: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
    icon_large: I18nObject | None = None
    supported_model_types: list[ModelType]

    def __init__(self, provider_entity: ProviderEntity):

@@ -44,7 +43,6 @@ class SimpleModelProviderEntity(BaseModel):
            label=provider_entity.label,
            icon_small=provider_entity.icon_small,
            icon_small_dark=provider_entity.icon_small_dark,
            icon_large=provider_entity.icon_large,
            supported_model_types=provider_entity.supported_model_types,
        )

@@ -94,7 +92,6 @@ class DefaultModelProviderEntity(BaseModel):
    provider: str
    label: I18nObject
    icon_small: I18nObject | None = None
    icon_large: I18nObject | None = None
    supported_model_types: Sequence[ModelType] = []
@@ -88,7 +88,41 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
    return None


def _inject_trace_headers(headers: dict | None) -> dict:
    """
    Inject W3C traceparent header for distributed tracing.

    When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
    When OTEL is disabled, we manually inject the traceparent header.
    """
    if headers is None:
        headers = {}

    # Skip if already present (case-insensitive check)
    for key in headers:
        if key.lower() == "traceparent":
            return headers

    # Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
    if dify_config.ENABLE_OTEL:
        return headers

    # Generate and inject traceparent for non-OTEL scenarios
    try:
        from core.helper.trace_id_helper import generate_traceparent_header

        traceparent = generate_traceparent_header()
        if traceparent:
            headers["traceparent"] = traceparent
    except Exception:
        # Silently ignore errors to avoid breaking requests
        logger.debug("Failed to generate traceparent header", exc_info=True)

    return headers


def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    # Convert requests-style allow_redirects to httpx-style follow_redirects
    if "allow_redirects" in kwargs:
        allow_redirects = kwargs.pop("allow_redirects")
        if "follow_redirects" not in kwargs:

@@ -106,18 +140,21 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
    client = _get_ssrf_client(verify_option)

    # Inject traceparent header for distributed tracing (when OTEL is not enabled)
    headers = kwargs.get("headers") or {}
    headers = _inject_trace_headers(headers)
    kwargs["headers"] = headers

    # Preserve user-provided Host header
    # When using a forward proxy, httpx may override the Host header based on the URL.
    # We extract and preserve any explicitly set Host header to support virtual hosting.
    headers = kwargs.get("headers", {})
    user_provided_host = _get_user_provided_host_header(headers)

    retries = 0
    while retries <= max_retries:
        try:
            # Build the request manually to preserve the Host header
            # httpx may override the Host header when using a proxy, so we use
            # the request API to explicitly set headers before sending
            # Preserve the user-provided Host header
            # httpx may override the Host header when using a proxy
            headers = {k: v for k, v in headers.items() if k.lower() != "host"}
            if user_provided_host is not None:
                headers["host"] = user_provided_host
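A hedged sketch of the injection behaviour above, assuming ENABLE_OTEL is false so the manual path runs. The `core.helper.ssrf_proxy` module name is an assumption based on where `make_request` usually lives; only the helper's behaviour is taken from the diff.

```python
from core.helper import ssrf_proxy  # assumed module path for the helper shown above

# With OTEL disabled, a traceparent is generated and added.
headers = ssrf_proxy._inject_trace_headers({"Authorization": "Bearer token"})
print(headers.get("traceparent"))  # e.g. "00-<32 hex>-<16 hex>-01"

# An existing header is respected regardless of casing.
kept = ssrf_proxy._inject_trace_headers({"TraceParent": "00-abc-def-01"})
print("traceparent" not in kept)  # True: the original key is left untouched
```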
@@ -1,58 +0,0 @@
import json
import logging
from typing import Any, cast

from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from extensions.ext_redis import redis_client, redis_fallback

logger = logging.getLogger(__name__)


class ToolProviderListCache:
    """Cache for tool provider lists"""

    CACHE_TTL = 300  # 5 minutes

    @staticmethod
    def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
        """Generate cache key for tool providers list"""
        type_filter = typ or "all"
        return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"

    @staticmethod
    @redis_fallback(default_return=None)
    def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
        """Get cached tool providers"""
        cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
        cached_data = redis_client.get(cache_key)
        if cached_data:
            try:
                return json.loads(cached_data.decode("utf-8"))
            except (json.JSONDecodeError, UnicodeDecodeError):
                logger.warning("Failed to decode cached tool providers data")
                return None
        return None

    @staticmethod
    @redis_fallback()
    def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
        """Cache tool providers"""
        cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
        redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))

    @staticmethod
    @redis_fallback()
    def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
        """Invalidate cache for tool providers"""
        if typ:
            # Invalidate specific type cache
            cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
            redis_client.delete(cache_key)
        else:
            # Invalidate all caches for this tenant
            keys = ["builtin", "model", "api", "workflow", "mcp"]
            pipeline = redis_client.pipeline()
            for key in keys:
                cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key))
                pipeline.delete(cache_key)
            pipeline.execute()
@@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None:
    if len(parts) == 4 and len(parts[1]) == 32:
        return parts[1]
    return None


def get_span_id_from_otel_context() -> str | None:
    """
    Retrieve the current span ID from the active OpenTelemetry trace context.

    Returns:
        A 16-character hex string representing the span ID, or None if not available.
    """
    try:
        from opentelemetry.trace import get_current_span
        from opentelemetry.trace.span import INVALID_SPAN_ID

        span = get_current_span()
        if not span:
            return None

        span_context = span.get_span_context()
        if not span_context or span_context.span_id == INVALID_SPAN_ID:
            return None

        return f"{span_context.span_id:016x}"
    except Exception:
        return None


def generate_traceparent_header() -> str | None:
    """
    Generate a W3C traceparent header from the current context.

    Uses OpenTelemetry context if available, otherwise uses the
    ContextVar-based trace_id from the logging context.

    Format: {version}-{trace_id}-{span_id}-{flags}
    Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01

    Returns:
        A valid traceparent header string, or None if generation fails.
    """
    import uuid

    # Try OTEL context first
    trace_id = get_trace_id_from_otel_context()
    span_id = get_span_id_from_otel_context()

    if trace_id and span_id:
        return f"00-{trace_id}-{span_id}-01"

    # Fallback: use ContextVar-based trace_id or generate new one
    from core.logging.context import get_trace_id as get_logging_trace_id

    trace_id = get_logging_trace_id() or uuid.uuid4().hex

    # Generate a new span_id (16 hex chars)
    span_id = uuid.uuid4().hex[:16]

    return f"00-{trace_id}-{span_id}-01"
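A quick structural check of the header produced above, as a sketch; both helpers live in the `core.helper.trace_id_helper` module referenced elsewhere in this compare.

```python
from core.helper.trace_id_helper import generate_traceparent_header, parse_traceparent_header

tp = generate_traceparent_header()
assert tp is not None
version, trace_id, span_id, flags = tp.split("-")
assert (version, flags) == ("00", "01")
assert len(trace_id) == 32 and len(span_id) == 16
# parse_traceparent_header round-trips back to the 32-char trace id
assert parse_traceparent_header(tp) == trace_id
```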
20 api/core/logging/__init__.py (new file)

@@ -0,0 +1,20 @@
"""Structured logging components for Dify."""

from core.logging.context import (
    clear_request_context,
    get_request_id,
    get_trace_id,
    init_request_context,
)
from core.logging.filters import IdentityContextFilter, TraceContextFilter
from core.logging.structured_formatter import StructuredJSONFormatter

__all__ = [
    "IdentityContextFilter",
    "StructuredJSONFormatter",
    "TraceContextFilter",
    "clear_request_context",
    "get_request_id",
    "get_trace_id",
    "init_request_context",
]
35 api/core/logging/context.py (new file)

@@ -0,0 +1,35 @@
"""Request context for logging - framework agnostic.

This module provides request-scoped context variables for logging,
using Python's contextvars for thread-safe and async-safe storage.
"""

import uuid
from contextvars import ContextVar

_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")


def get_request_id() -> str:
    """Get current request ID (10 hex chars)."""
    return _request_id.get()


def get_trace_id() -> str:
    """Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
    return _trace_id.get()


def init_request_context() -> None:
    """Initialize request context. Call at start of each request."""
    req_id = uuid.uuid4().hex[:10]
    trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
    _request_id.set(req_id)
    _trace_id.set(trace_id)


def clear_request_context() -> None:
    """Clear request context. Call at end of request (optional)."""
    _request_id.set("")
    _trace_id.set("")
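Usage sketch for the request-scoped context above: both ids are empty strings until `init_request_context()` runs, and the fallback trace id is derived deterministically from the request id via uuid5.

```python
from core.logging.context import (
    clear_request_context,
    get_request_id,
    get_trace_id,
    init_request_context,
)

init_request_context()
print(len(get_request_id()))  # 10 hex chars
print(len(get_trace_id()))    # 32 hex chars

clear_request_context()
print(get_request_id() == "" and get_trace_id() == "")  # True
```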
94 api/core/logging/filters.py (new file)

@@ -0,0 +1,94 @@
"""Logging filters for structured logging."""

import contextlib
import logging

import flask

from core.logging.context import get_request_id, get_trace_id


class TraceContextFilter(logging.Filter):
    """
    Filter that adds trace_id and span_id to log records.
    Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        # Get trace context from OpenTelemetry
        trace_id, span_id = self._get_otel_context()

        # Set trace_id (fallback to ContextVar if no OTEL context)
        if trace_id:
            record.trace_id = trace_id
        else:
            record.trace_id = get_trace_id()

        record.span_id = span_id or ""

        # For backward compatibility, also set req_id
        record.req_id = get_request_id()

        return True

    def _get_otel_context(self) -> tuple[str, str]:
        """Extract trace_id and span_id from OpenTelemetry context."""
        with contextlib.suppress(Exception):
            from opentelemetry.trace import get_current_span
            from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID

            span = get_current_span()
            if span and span.get_span_context():
                ctx = span.get_span_context()
                if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
                    trace_id = f"{ctx.trace_id:032x}"
                    span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
                    return trace_id, span_id
        return "", ""


class IdentityContextFilter(logging.Filter):
    """
    Filter that adds user identity context to log records.
    Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        identity = self._extract_identity()
        record.tenant_id = identity.get("tenant_id", "")
        record.user_id = identity.get("user_id", "")
        record.user_type = identity.get("user_type", "")
        return True

    def _extract_identity(self) -> dict[str, str]:
        """Extract identity from current_user if in request context."""
        try:
            if not flask.has_request_context():
                return {}
            from flask_login import current_user

            # Check if user is authenticated using the proxy
            if not current_user.is_authenticated:
                return {}

            # Access the underlying user object
            user = current_user

            from models import Account
            from models.model import EndUser

            identity: dict[str, str] = {}

            if isinstance(user, Account):
                if user.current_tenant_id:
                    identity["tenant_id"] = user.current_tenant_id
                identity["user_id"] = user.id
                identity["user_type"] = "account"
            elif isinstance(user, EndUser):
                identity["tenant_id"] = user.tenant_id
                identity["user_id"] = user.id
                identity["user_type"] = user.type or "end_user"

            return identity
        except Exception:
            return {}
107 api/core/logging/structured_formatter.py (new file)

@@ -0,0 +1,107 @@
"""Structured JSON log formatter for Dify."""

import logging
import traceback
from datetime import UTC, datetime
from typing import Any

import orjson

from configs import dify_config


class StructuredJSONFormatter(logging.Formatter):
    """
    JSON log formatter following the specified schema:
    {
        "ts": "ISO 8601 UTC",
        "severity": "INFO|ERROR|WARN|DEBUG",
        "service": "service name",
        "caller": "file:line",
        "trace_id": "hex 32",
        "span_id": "hex 16",
        "identity": { "tenant_id", "user_id", "user_type" },
        "message": "log message",
        "attributes": { ... },
        "stack_trace": "..."
    }
    """

    SEVERITY_MAP: dict[int, str] = {
        logging.DEBUG: "DEBUG",
        logging.INFO: "INFO",
        logging.WARNING: "WARN",
        logging.ERROR: "ERROR",
        logging.CRITICAL: "ERROR",
    }

    def __init__(self, service_name: str | None = None):
        super().__init__()
        self._service_name = service_name or dify_config.APPLICATION_NAME

    def format(self, record: logging.LogRecord) -> str:
        log_dict = self._build_log_dict(record)
        try:
            return orjson.dumps(log_dict).decode("utf-8")
        except TypeError:
            # Fallback: convert non-serializable objects to string
            import json

            return json.dumps(log_dict, default=str, ensure_ascii=False)

    def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
        # Core fields
        log_dict: dict[str, Any] = {
            "ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
            "severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
            "service": self._service_name,
            "caller": f"{record.filename}:{record.lineno}",
            "message": record.getMessage(),
        }

        # Trace context (from TraceContextFilter)
        trace_id = getattr(record, "trace_id", "")
        span_id = getattr(record, "span_id", "")

        if trace_id:
            log_dict["trace_id"] = trace_id
        if span_id:
            log_dict["span_id"] = span_id

        # Identity context (from IdentityContextFilter)
        identity = self._extract_identity(record)
        if identity:
            log_dict["identity"] = identity

        # Dynamic attributes
        attributes = getattr(record, "attributes", None)
        if attributes:
            log_dict["attributes"] = attributes

        # Stack trace for errors with exceptions
        if record.exc_info and record.levelno >= logging.ERROR:
            log_dict["stack_trace"] = self._format_exception(record.exc_info)

        return log_dict

    def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
        tenant_id = getattr(record, "tenant_id", None)
        user_id = getattr(record, "user_id", None)
        user_type = getattr(record, "user_type", None)

        if not any([tenant_id, user_id, user_type]):
            return None

        identity: dict[str, str] = {}
        if tenant_id:
            identity["tenant_id"] = tenant_id
        if user_id:
            identity["user_id"] = user_id
        if user_type:
            identity["user_type"] = user_type
        return identity

    def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
        if exc_info and exc_info[0] is not None:
            return "".join(traceback.format_exception(*exc_info))
        return ""
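A minimal wiring sketch (not part of the diff): attach the new filter and formatter to a handler so every record carries trace context and is emitted as one JSON line. The logger name and attributes are illustrative.

```python
import logging

from core.logging import StructuredJSONFormatter, TraceContextFilter

handler = logging.StreamHandler()
handler.setFormatter(StructuredJSONFormatter(service_name="dify-api"))
handler.addFilter(TraceContextFilter())

demo_logger = logging.getLogger("demo")
demo_logger.addHandler(handler)
demo_logger.setLevel(logging.INFO)

# "attributes" passed via extra ends up in the JSON "attributes" field.
demo_logger.info("archive completed", extra={"attributes": {"component": "example"}})
# -> {"ts": "...Z", "severity": "INFO", "service": "dify-api", "caller": "...", "message": "archive completed", ...}
```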
@@ -100,7 +100,6 @@ class SimpleProviderEntity(BaseModel):
    label: I18nObject
    icon_small: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
    icon_large: I18nObject | None = None
    supported_model_types: Sequence[ModelType]
    models: list[AIModelEntity] = []

@@ -123,7 +122,6 @@ class ProviderEntity(BaseModel):
    label: I18nObject
    description: I18nObject | None = None
    icon_small: I18nObject | None = None
    icon_large: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
    background: str | None = None
    help: ProviderHelpEntity | None = None

@@ -157,7 +155,6 @@ class ProviderEntity(BaseModel):
            provider=self.provider,
            label=self.label,
            icon_small=self.icon_small,
            icon_large=self.icon_large,
            supported_model_types=self.supported_model_types,
            models=self.models,
        )
@@ -285,7 +285,7 @@ class ModelProviderFactory:
        """
        Get provider icon
        :param provider: provider name
        :param icon_type: icon type (icon_small or icon_large)
        :param icon_type: icon type (icon_small or icon_small_dark)
        :param lang: language (zh_Hans or en_US)
        :return: provider icon
        """

@@ -309,13 +309,7 @@ class ModelProviderFactory:
            else:
                file_name = provider_schema.icon_small_dark.en_US
        else:
            if not provider_schema.icon_large:
                raise ValueError(f"Provider {provider} does not have large icon.")

            if lang.lower() == "zh_hans":
                file_name = provider_schema.icon_large.zh_Hans
            else:
                file_name = provider_schema.icon_large.en_US
            raise ValueError(f"Unsupported icon type: {icon_type}.")

        if not file_name:
            raise ValueError(f"Provider {provider} does not have icon.")
@@ -103,6 +103,9 @@ class BasePluginClient:
        prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
        prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")

        # Inject traceparent header for distributed tracing
        self._inject_trace_headers(prepared_headers)

        prepared_data: bytes | dict[str, Any] | str | None = (
            data if isinstance(data, (bytes, str, dict)) or data is None else None
        )

@@ -114,6 +117,31 @@ class BasePluginClient:

        return str(url), prepared_headers, prepared_data, params, files

    def _inject_trace_headers(self, headers: dict[str, str]) -> None:
        """
        Inject W3C traceparent header for distributed tracing.

        This ensures trace context is propagated to plugin daemon even if
        HTTPXClientInstrumentor doesn't cover module-level httpx functions.
        """
        if not dify_config.ENABLE_OTEL:
            return

        import contextlib

        # Skip if already present (case-insensitive check)
        for key in headers:
            if key.lower() == "traceparent":
                return

        # Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
        with contextlib.suppress(Exception):
            from core.helper.trace_id_helper import generate_traceparent_header

            traceparent = generate_traceparent_header()
            if traceparent:
                headers["traceparent"] = traceparent

    def _stream_request(
        self,
        method: str,
@@ -331,7 +331,6 @@ class ProviderManager:
                provider=provider_schema.provider,
                label=provider_schema.label,
                icon_small=provider_schema.icon_small,
                icon_large=provider_schema.icon_large,
                supported_model_types=provider_schema.supported_model_types,
            ),
        )
@@ -27,26 +27,44 @@ class CleanProcessor:
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text)

# Remove URL but keep Markdown image URLs
# First, temporarily replace Markdown image URLs with a placeholder
markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
placeholders: list[str] = []
# Remove URL but keep Markdown image URLs and link URLs
# Replace the ENTIRE markdown link/image with a single placeholder to protect
# the link text (which might also be a URL) from being removed
markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)"
markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)"
placeholders: list[tuple[str, str, str]] = []  # (type, text, url)

def replace_with_placeholder(match, placeholders=placeholders):
def replace_markdown_with_placeholder(match, placeholders=placeholders):
link_type = "link"
link_text = match.group(1)
url = match.group(2)
placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
placeholders.append((link_type, link_text, url))
return placeholder

def replace_image_with_placeholder(match, placeholders=placeholders):
link_type = "image"
url = match.group(1)
placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
placeholders.append(url)
return f""
placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
placeholders.append((link_type, "image", url))
return placeholder

text = re.sub(markdown_image_pattern, replace_with_placeholder, text)
# Protect markdown links first
text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text)
# Then protect markdown images
text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text)

# Now remove all remaining URLs
url_pattern = r"https?://[^\s)]+"
url_pattern = r"https?://\S+"
text = re.sub(url_pattern, "", text)

# Finally, restore the Markdown image URLs
for i, url in enumerate(placeholders):
text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
# Restore the Markdown links and images
for i, (link_type, text_or_alt, url) in enumerate(placeholders):
placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__"
if link_type == "link":
text = text.replace(placeholder, f"[{text_or_alt}]({url})")
else:  # image
text = text.replace(placeholder, f"")
return text

def filter_string(self, text):

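Viewed on its own, the protect-then-strip-then-restore approach in this hunk reads roughly like the sketch below (regexes copied from the diff, function body simplified and only covering markdown links):

```python
import re

MARKDOWN_LINK = r"\[([^\]]*)\]\((https?://[^)]+)\)"
BARE_URL = r"https?://\S+"


def strip_bare_urls(text: str) -> str:
    """Remove bare URLs while keeping whole markdown links intact."""
    protected: list[str] = []

    def protect(match: re.Match[str]) -> str:
        protected.append(match.group(0))
        return f"__MARKDOWN_PLACEHOLDER_{len(protected) - 1}__"

    text = re.sub(MARKDOWN_LINK, protect, text)    # 1. swap whole links for placeholders
    text = re.sub(BARE_URL, "", text)              # 2. drop every remaining URL
    for i, original in enumerate(protected):       # 3. restore the links verbatim
        text = text.replace(f"__MARKDOWN_PLACEHOLDER_{i}__", original)
    return text
```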
@@ -1,4 +1,5 @@
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any

@@ -36,6 +37,8 @@ default_retrieval_model = {
"score_threshold_enabled": False,
}

logger = logging.getLogger(__name__)


class RetrievalService:
# Cache precompiled regular expressions to avoid repeated compilation
@@ -106,7 +109,12 @@ class RetrievalService:
)
)

concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
if futures:
for future in concurrent.futures.as_completed(futures, timeout=3600):
if exceptions:
for f in futures:
f.cancel()
break

if exceptions:
raise ValueError(";\n".join(exceptions))
@@ -210,6 +218,7 @@ class RetrievalService:
)
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))

@classmethod
@@ -303,6 +312,7 @@ class RetrievalService:
else:
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))

@classmethod
@@ -351,6 +361,7 @@ class RetrievalService:
else:
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))

@staticmethod
@@ -663,7 +674,14 @@ class RetrievalService:
document_ids_filter=document_ids_filter,
)
)
concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
# Use as_completed for early error propagation - cancel remaining futures on first error
if futures:
for future in concurrent.futures.as_completed(futures, timeout=300):
if future.exception():
# Cancel remaining futures to avoid unnecessary waiting
for f in futures:
f.cancel()
break

if exceptions:
raise ValueError(";\n".join(exceptions))

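The fail-fast pattern introduced in these hunks, reduced to a standalone sketch under the assumption of a generic `Executor` and a list of zero-argument callables (names here are illustrative):

```python
import concurrent.futures


def wait_fail_fast(executor: concurrent.futures.Executor, tasks, timeout: float = 300) -> None:
    """Submit tasks, stop waiting on the first failure, and cancel whatever has not started."""
    futures = [executor.submit(task) for task in tasks]
    errors: list[str] = []
    for future in concurrent.futures.as_completed(futures, timeout=timeout):
        exc = future.exception()
        if exc is not None:
            errors.append(str(exc))
            for f in futures:
                f.cancel()  # no-op for futures that are already running or done
            break
    if errors:
        raise ValueError(";\n".join(errors))
```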
@@ -112,7 +112,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@@ -148,7 +148,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:

@@ -1,25 +1,57 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pypdfium2
|
||||
import pypdfium2.raw as pdfium_c
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.blob.blob import Blob
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PdfExtractor(BaseExtractor):
|
||||
"""Load pdf files.
|
||||
|
||||
"""
|
||||
PdfExtractor is used to extract text and images from PDF files.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
file_path: Path to the PDF file.
|
||||
tenant_id: Workspace ID.
|
||||
user_id: ID of the user performing the extraction.
|
||||
file_cache_key: Optional cache key for the extracted text.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, file_cache_key: str | None = None):
|
||||
"""Initialize with file path."""
|
||||
# Magic bytes for image format detection: (magic_bytes, extension, mime_type)
|
||||
IMAGE_FORMATS = [
|
||||
(b"\xff\xd8\xff", "jpg", "image/jpeg"),
|
||||
(b"\x89PNG\r\n\x1a\n", "png", "image/png"),
|
||||
(b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"),
|
||||
(b"GIF8", "gif", "image/gif"),
|
||||
(b"BM", "bmp", "image/bmp"),
|
||||
(b"II*\x00", "tiff", "image/tiff"),
|
||||
(b"MM\x00*", "tiff", "image/tiff"),
|
||||
(b"II+\x00", "tiff", "image/tiff"),
|
||||
(b"MM\x00+", "tiff", "image/tiff"),
|
||||
]
|
||||
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
|
||||
|
||||
def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
|
||||
"""Initialize PdfExtractor."""
|
||||
self._file_path = file_path
|
||||
self._tenant_id = tenant_id
|
||||
self._user_id = user_id
|
||||
self._file_cache_key = file_cache_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
@@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor):
|
||||
|
||||
def parse(self, blob: Blob) -> Iterator[Document]:
|
||||
"""Lazily parse the blob."""
|
||||
import pypdfium2 # type: ignore
|
||||
|
||||
with blob.as_bytes_io() as file_path:
|
||||
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
|
||||
@@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor):
|
||||
text_page = page.get_textpage()
|
||||
content = text_page.get_text_range()
|
||||
text_page.close()
|
||||
|
||||
image_content = self._extract_images(page)
|
||||
if image_content:
|
||||
content += "\n" + image_content
|
||||
|
||||
page.close()
|
||||
metadata = {"source": blob.source, "page": page_number}
|
||||
yield Document(page_content=content, metadata=metadata)
|
||||
finally:
|
||||
pdf_reader.close()
|
||||
|
||||
def _extract_images(self, page) -> str:
|
||||
"""
|
||||
Extract images from a PDF page, save them to storage and database,
|
||||
and return markdown image links.
|
||||
|
||||
Args:
|
||||
page: pypdfium2 page object.
|
||||
|
||||
Returns:
|
||||
Markdown string containing links to the extracted images.
|
||||
"""
|
||||
image_content = []
|
||||
upload_files = []
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
|
||||
try:
|
||||
image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
|
||||
for obj in image_objects:
|
||||
try:
|
||||
# Extract image bytes
|
||||
img_byte_arr = io.BytesIO()
|
||||
# Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly
|
||||
# Fallback to png for other formats
|
||||
obj.extract(img_byte_arr, fb_format="png")
|
||||
img_bytes = img_byte_arr.getvalue()
|
||||
|
||||
if not img_bytes:
|
||||
continue
|
||||
|
||||
header = img_bytes[: self.MAX_MAGIC_LEN]
|
||||
image_ext = None
|
||||
mime_type = None
|
||||
for magic, ext, mime in self.IMAGE_FORMATS:
|
||||
if header.startswith(magic):
|
||||
image_ext = ext
|
||||
mime_type = mime
|
||||
break
|
||||
|
||||
if not image_ext or not mime_type:
|
||||
continue
|
||||
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." + image_ext
|
||||
|
||||
storage.save(file_key, img_bytes)
|
||||
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=file_key,
|
||||
name=file_key,
|
||||
size=len(img_bytes),
|
||||
extension=image_ext,
|
||||
mime_type=mime_type,
|
||||
created_by=self._user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self._user_id,
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
upload_files.append(upload_file)
|
||||
image_content.append(f"")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to extract image from PDF: %s", e)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning("Failed to get objects from PDF page: %s", e)
|
||||
if upload_files:
|
||||
db.session.add_all(upload_files)
|
||||
db.session.commit()
|
||||
return "\n".join(image_content)
|
||||
|
||||
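The magic-byte sniffing used by the extractor above can be read in isolation as the following sketch; the table is trimmed to a few formats for brevity, and the authoritative list is the IMAGE_FORMATS tuple in the class:

```python
# Minimal sketch: detect an image's format from its leading bytes.
IMAGE_FORMATS = [
    (b"\xff\xd8\xff", "jpg", "image/jpeg"),
    (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
    (b"GIF8", "gif", "image/gif"),
]
MAX_MAGIC_LEN = max(len(magic) for magic, _, _ in IMAGE_FORMATS)


def sniff_image_format(data: bytes) -> tuple[str, str] | None:
    """Return (extension, mime_type) if the header matches a known format, else None."""
    header = data[:MAX_MAGIC_LEN]
    for magic, ext, mime in IMAGE_FORMATS:
        if header.startswith(magic):
            return ext, mime
    return None
```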
@@ -516,6 +516,9 @@ class DatasetRetrieval:
|
||||
].embedding_model_provider
|
||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||
with measure_time() as timer:
|
||||
cancel_event = threading.Event()
|
||||
thread_exceptions: list[Exception] = []
|
||||
|
||||
if query:
|
||||
query_thread = threading.Thread(
|
||||
target=self._multiple_retrieve_thread,
|
||||
@@ -534,6 +537,8 @@ class DatasetRetrieval:
|
||||
"score_threshold": score_threshold,
|
||||
"query": query,
|
||||
"attachment_id": None,
|
||||
"cancel_event": cancel_event,
|
||||
"thread_exceptions": thread_exceptions,
|
||||
},
|
||||
)
|
||||
all_threads.append(query_thread)
|
||||
@@ -557,12 +562,25 @@ class DatasetRetrieval:
|
||||
"score_threshold": score_threshold,
|
||||
"query": None,
|
||||
"attachment_id": attachment_id,
|
||||
"cancel_event": cancel_event,
|
||||
"thread_exceptions": thread_exceptions,
|
||||
},
|
||||
)
|
||||
all_threads.append(attachment_thread)
|
||||
attachment_thread.start()
|
||||
for thread in all_threads:
|
||||
thread.join()
|
||||
|
||||
# Poll threads with short timeout to detect errors quickly (fail-fast)
|
||||
while any(t.is_alive() for t in all_threads):
|
||||
for thread in all_threads:
|
||||
thread.join(timeout=0.1)
|
||||
if thread_exceptions:
|
||||
cancel_event.set()
|
||||
break
|
||||
if thread_exceptions:
|
||||
break
|
||||
|
||||
if thread_exceptions:
|
||||
raise thread_exceptions[0]
|
||||
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
|
||||
|
||||
if all_documents:
|
||||
@@ -1404,40 +1422,53 @@ class DatasetRetrieval:
|
||||
score_threshold: float,
|
||||
query: str | None,
|
||||
attachment_id: str | None,
|
||||
cancel_event: threading.Event | None = None,
|
||||
thread_exceptions: list[Exception] | None = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
threads = []
|
||||
all_documents_item: list[Document] = []
|
||||
index_type = None
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
document_ids_filter = None
|
||||
if dataset.provider != "external":
|
||||
if metadata_condition and not metadata_filter_document_ids:
|
||||
continue
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
else:
|
||||
try:
|
||||
with flask_app.app_context():
|
||||
threads = []
|
||||
all_documents_item: list[Document] = []
|
||||
index_type = None
|
||||
for dataset in available_datasets:
|
||||
# Check for cancellation signal
|
||||
if cancel_event and cancel_event.is_set():
|
||||
break
|
||||
index_type = dataset.indexing_technique
|
||||
document_ids_filter = None
|
||||
if dataset.provider != "external":
|
||||
if metadata_condition and not metadata_filter_document_ids:
|
||||
continue
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": flask_app,
|
||||
"dataset_id": dataset.id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents_item,
|
||||
"document_ids_filter": document_ids_filter,
|
||||
"metadata_condition": metadata_condition,
|
||||
"attachment_ids": [attachment_id] if attachment_id else None,
|
||||
},
|
||||
)
|
||||
threads.append(retrieval_thread)
|
||||
retrieval_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
else:
|
||||
continue
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": flask_app,
|
||||
"dataset_id": dataset.id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents_item,
|
||||
"document_ids_filter": document_ids_filter,
|
||||
"metadata_condition": metadata_condition,
|
||||
"attachment_ids": [attachment_id] if attachment_id else None,
|
||||
},
|
||||
)
|
||||
threads.append(retrieval_thread)
|
||||
retrieval_thread.start()
|
||||
|
||||
# Poll threads with short timeout to respond quickly to cancellation
|
||||
while any(t.is_alive() for t in threads):
|
||||
for thread in threads:
|
||||
thread.join(timeout=0.1)
|
||||
if cancel_event and cancel_event.is_set():
|
||||
break
|
||||
if cancel_event and cancel_event.is_set():
|
||||
break
|
||||
|
||||
if reranking_enable:
|
||||
# do rerank for searched documents
|
||||
@@ -1470,3 +1501,8 @@ class DatasetRetrieval:
|
||||
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||
if all_documents_item:
|
||||
all_documents.extend(all_documents_item)
|
||||
except Exception as e:
|
||||
if cancel_event:
|
||||
cancel_event.set()
|
||||
if thread_exceptions is not None:
|
||||
thread_exceptions.append(e)
|
||||
|
||||
@@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
auto parse to tool bundle

@@ -4,6 +4,7 @@ import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Preserves markdown links like [text](url) at the start.

Args:
text (str): The input text to process.
@@ -11,6 +12,11 @@ def remove_leading_symbols(text: str) -> str:
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Check if text starts with a markdown link - preserve it
markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)"
if re.match(markdown_link_pattern, text):
return text

# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'

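A quick illustration of the intended behavior, assuming the docstring above; the inputs and expected outputs are illustrative, not taken from the test suite:

```python
# Leading punctuation is stripped...
remove_leading_symbols("...Hello, world")  # expected: "Hello, world"

# ...but a text that starts with a markdown link is returned untouched.
remove_leading_symbols("[Dify](https://example.com) docs")  # expected: "[Dify](https://example.com) docs"
```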
@@ -54,7 +54,6 @@ class WorkflowToolProviderController(ToolProviderController):
raise ValueError("app not found")

user = session.get(Account, db_provider.user_id) if db_provider.user_id else None

controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
@@ -67,7 +66,7 @@ class WorkflowToolProviderController(ToolProviderController):
credentials_schema=[],
plugin_id=None,
),
provider_id="",
provider_id=db_provider.id,
)

controller.tools = [

@@ -60,6 +60,7 @@ class SkipPropagator:
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
self._state_manager.start_execution(downstream_node_id)
return

# All edges are skipped, propagate skip to this node

@@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)


class AgentMaxIterationError(AgentNodeError):
"""Exception raised when the agent exceeds the maximum iteration limit."""

def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

@@ -12,9 +12,8 @@ from dify_app import DifyApp

def _get_celery_ssl_options() -> dict[str, Any] | None:
"""Get SSL configuration for Celery broker/backend connections."""
# Use REDIS_USE_SSL for consistency with the main Redis client
# Only apply SSL if we're using Redis as broker/backend
if not dify_config.REDIS_USE_SSL:
if not dify_config.BROKER_USE_SSL:
return None

# Check if Celery is actually using Redis
@@ -47,7 +46,11 @@ def _get_celery_ssl_options() -> dict[str, Any] | None:
def init_app(app: DifyApp) -> Celery:
class FlaskTask(Task):
def __call__(self, *args: object, **kwargs: object) -> object:
from core.logging.context import init_request_context

with app.app_context():
# Initialize logging context for this task (similar to before_request in Flask)
init_request_context()
return self.run(*args, **kwargs)

broker_transport_options = {}

@@ -1,18 +1,19 @@
|
||||
"""Logging extension for Dify Flask application."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
import flask
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.trace_id_helper import get_trace_id_from_otel_context
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
"""Initialize logging with support for text or JSON format."""
|
||||
log_handlers: list[logging.Handler] = []
|
||||
|
||||
# File handler
|
||||
log_file = dify_config.LOG_FILE
|
||||
if log_file:
|
||||
log_dir = os.path.dirname(log_file)
|
||||
@@ -25,27 +26,53 @@ def init_app(app: DifyApp):
|
||||
)
|
||||
)
|
||||
|
||||
# Always add StreamHandler to log to console
|
||||
# Console handler
|
||||
sh = logging.StreamHandler(sys.stdout)
|
||||
log_handlers.append(sh)
|
||||
|
||||
# Apply RequestIdFilter to all handlers
|
||||
for handler in log_handlers:
|
||||
handler.addFilter(RequestIdFilter())
|
||||
# Apply filters to all handlers
|
||||
from core.logging.filters import IdentityContextFilter, TraceContextFilter
|
||||
|
||||
for handler in log_handlers:
|
||||
handler.addFilter(TraceContextFilter())
|
||||
handler.addFilter(IdentityContextFilter())
|
||||
|
||||
# Configure formatter based on format type
|
||||
formatter = _create_formatter()
|
||||
for handler in log_handlers:
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
# Configure root logger
|
||||
logging.basicConfig(
|
||||
level=dify_config.LOG_LEVEL,
|
||||
format=dify_config.LOG_FORMAT,
|
||||
datefmt=dify_config.LOG_DATEFORMAT,
|
||||
handlers=log_handlers,
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Apply RequestIdFormatter to all handlers
|
||||
apply_request_id_formatter()
|
||||
|
||||
# Disable propagation for noisy loggers to avoid duplicate logs
|
||||
logging.getLogger("sqlalchemy.engine").propagate = False
|
||||
|
||||
# Apply timezone if specified (only for text format)
|
||||
if dify_config.LOG_OUTPUT_FORMAT == "text":
|
||||
_apply_timezone(log_handlers)
|
||||
|
||||
|
||||
def _create_formatter() -> logging.Formatter:
|
||||
"""Create appropriate formatter based on configuration."""
|
||||
if dify_config.LOG_OUTPUT_FORMAT == "json":
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
return StructuredJSONFormatter()
|
||||
else:
|
||||
# Text format - use existing pattern with backward compatible formatter
|
||||
return _TextFormatter(
|
||||
fmt=dify_config.LOG_FORMAT,
|
||||
datefmt=dify_config.LOG_DATEFORMAT,
|
||||
)
|
||||
|
||||
|
||||
def _apply_timezone(handlers: list[logging.Handler]):
|
||||
"""Apply timezone conversion to text formatters."""
|
||||
log_tz = dify_config.LOG_TZ
|
||||
if log_tz:
|
||||
from datetime import datetime
|
||||
@@ -57,34 +84,51 @@ def init_app(app: DifyApp):
|
||||
def time_converter(seconds):
|
||||
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
|
||||
|
||||
for handler in logging.root.handlers:
|
||||
for handler in handlers:
|
||||
if handler.formatter:
|
||||
handler.formatter.converter = time_converter
|
||||
handler.formatter.converter = time_converter # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def get_request_id():
|
||||
if getattr(flask.g, "request_id", None):
|
||||
return flask.g.request_id
|
||||
class _TextFormatter(logging.Formatter):
|
||||
"""Text formatter that ensures trace_id and req_id are always present."""
|
||||
|
||||
new_uuid = uuid.uuid4().hex[:10]
|
||||
flask.g.request_id = new_uuid
|
||||
|
||||
return new_uuid
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
if not hasattr(record, "req_id"):
|
||||
record.req_id = ""
|
||||
if not hasattr(record, "trace_id"):
|
||||
record.trace_id = ""
|
||||
if not hasattr(record, "span_id"):
|
||||
record.span_id = ""
|
||||
return super().format(record)
|
||||
|
||||
|
||||
def get_request_id() -> str:
|
||||
"""Get request ID for current request context.
|
||||
|
||||
Deprecated: Use core.logging.context.get_request_id() directly.
|
||||
"""
|
||||
from core.logging.context import get_request_id as _get_request_id
|
||||
|
||||
return _get_request_id()
|
||||
|
||||
|
||||
# Backward compatibility aliases
|
||||
class RequestIdFilter(logging.Filter):
|
||||
# This is a logging filter that makes the request ID available for use in
|
||||
# the logging format. Note that we're checking if we're in a request
|
||||
# context, as we may want to log things before Flask is fully loaded.
|
||||
def filter(self, record):
|
||||
trace_id = get_trace_id_from_otel_context() or ""
|
||||
record.req_id = get_request_id() if flask.has_request_context() else ""
|
||||
record.trace_id = trace_id
|
||||
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
from core.logging.context import get_request_id as _get_request_id
|
||||
from core.logging.context import get_trace_id as _get_trace_id
|
||||
|
||||
record.req_id = _get_request_id()
|
||||
record.trace_id = _get_trace_id()
|
||||
return True
|
||||
|
||||
|
||||
class RequestIdFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
"""Deprecated: Use _TextFormatter instead."""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
if not hasattr(record, "req_id"):
|
||||
record.req_id = ""
|
||||
if not hasattr(record, "trace_id"):
|
||||
@@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter):
|
||||
|
||||
|
||||
def apply_request_id_formatter():
|
||||
"""Deprecated: Formatter is now applied in init_app."""
|
||||
for handler in logging.root.handlers:
|
||||
if handler.formatter:
|
||||
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
|
||||
|
||||
@@ -19,26 +19,43 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExceptionLoggingHandler(logging.Handler):
|
||||
"""
|
||||
Handler that records exceptions to the current OpenTelemetry span.
|
||||
|
||||
Unlike creating a new span, this records exceptions on the existing span
|
||||
to maintain trace context consistency throughout the request lifecycle.
|
||||
"""
|
||||
|
||||
def emit(self, record: logging.LogRecord):
|
||||
with contextlib.suppress(Exception):
|
||||
if record.exc_info:
|
||||
tracer = get_tracer_provider().get_tracer("dify.exception.logging")
|
||||
with tracer.start_as_current_span(
|
||||
"log.exception",
|
||||
attributes={
|
||||
"log.level": record.levelname,
|
||||
"log.message": record.getMessage(),
|
||||
"log.logger": record.name,
|
||||
"log.file.path": record.pathname,
|
||||
"log.file.line": record.lineno,
|
||||
},
|
||||
) as span:
|
||||
span.set_status(StatusCode.ERROR)
|
||||
if record.exc_info[1]:
|
||||
span.record_exception(record.exc_info[1])
|
||||
span.set_attribute("exception.message", str(record.exc_info[1]))
|
||||
if record.exc_info[0]:
|
||||
span.set_attribute("exception.type", record.exc_info[0].__name__)
|
||||
if not record.exc_info:
|
||||
return
|
||||
|
||||
from opentelemetry.trace import get_current_span
|
||||
|
||||
span = get_current_span()
|
||||
if not span or not span.is_recording():
|
||||
return
|
||||
|
||||
# Record exception on the current span instead of creating a new one
|
||||
span.set_status(StatusCode.ERROR, record.getMessage())
|
||||
|
||||
# Add log context as span events/attributes
|
||||
span.add_event(
|
||||
"log.exception",
|
||||
attributes={
|
||||
"log.level": record.levelname,
|
||||
"log.message": record.getMessage(),
|
||||
"log.logger": record.name,
|
||||
"log.file.path": record.pathname,
|
||||
"log.file.line": record.lineno,
|
||||
},
|
||||
)
|
||||
|
||||
if record.exc_info[1]:
|
||||
span.record_exception(record.exc_info[1])
|
||||
if record.exc_info[0]:
|
||||
span.set_attribute("exception.type", record.exc_info[0].__name__)
|
||||
|
||||
|
||||
def instrument_exception_logging() -> None:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
@@ -12,7 +12,7 @@ annotation_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_annotation_model(api_or_ns: Api | Namespace):
|
||||
def build_annotation_model(api_or_ns: Namespace):
|
||||
"""Build the annotation model for the API or Namespace."""
|
||||
return api_or_ns.model("Annotation", annotation_fields)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from fields.member_fields import simple_account_fields
|
||||
from libs.helper import TimestampField
|
||||
@@ -46,7 +46,7 @@ message_file_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_message_file_model(api_or_ns: Api | Namespace):
|
||||
def build_message_file_model(api_or_ns: Namespace):
|
||||
"""Build the message file fields for the API or Namespace."""
|
||||
return api_or_ns.model("MessageFile", message_file_fields)
|
||||
|
||||
@@ -217,7 +217,7 @@ conversation_infinite_scroll_pagination_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
||||
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the conversation infinite scroll pagination model for the API or Namespace."""
|
||||
simple_conversation_model = build_simple_conversation_model(api_or_ns)
|
||||
|
||||
@@ -226,11 +226,11 @@ def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespa
|
||||
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
|
||||
|
||||
|
||||
def build_conversation_delete_model(api_or_ns: Api | Namespace):
|
||||
def build_conversation_delete_model(api_or_ns: Namespace):
|
||||
"""Build the conversation delete model for the API or Namespace."""
|
||||
return api_or_ns.model("ConversationDelete", conversation_delete_fields)
|
||||
|
||||
|
||||
def build_simple_conversation_model(api_or_ns: Api | Namespace):
|
||||
def build_simple_conversation_model(api_or_ns: Namespace):
|
||||
"""Build the simple conversation model for the API or Namespace."""
|
||||
return api_or_ns.model("SimpleConversation", simple_conversation_fields)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
@@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_conversation_variable_model(api_or_ns: Api | Namespace):
|
||||
def build_conversation_variable_model(api_or_ns: Namespace):
|
||||
"""Build the conversation variable model for the API or Namespace."""
|
||||
return api_or_ns.model("ConversationVariable", conversation_variable_fields)
|
||||
|
||||
|
||||
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
||||
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the conversation variable infinite scroll pagination model for the API or Namespace."""
|
||||
# Build the nested variable model first
|
||||
conversation_variable_model = build_conversation_variable_model(api_or_ns)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
simple_end_user_fields = {
|
||||
"id": fields.String,
|
||||
@@ -8,5 +8,5 @@ simple_end_user_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_simple_end_user_model(api_or_ns: Api | Namespace):
|
||||
def build_simple_end_user_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
@@ -14,7 +14,7 @@ upload_config_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_upload_config_model(api_or_ns: Api | Namespace):
|
||||
def build_upload_config_model(api_or_ns: Namespace):
|
||||
"""Build the upload config model for the API or Namespace.
|
||||
|
||||
Args:
|
||||
@@ -39,7 +39,7 @@ file_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_file_model(api_or_ns: Api | Namespace):
|
||||
def build_file_model(api_or_ns: Namespace):
|
||||
"""Build the file model for the API or Namespace.
|
||||
|
||||
Args:
|
||||
@@ -57,7 +57,7 @@ remote_file_info_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_remote_file_info_model(api_or_ns: Api | Namespace):
|
||||
def build_remote_file_info_model(api_or_ns: Namespace):
|
||||
"""Build the remote file info model for the API or Namespace.
|
||||
|
||||
Args:
|
||||
@@ -81,7 +81,7 @@ file_fields_with_signed_url = {
|
||||
}
|
||||
|
||||
|
||||
def build_file_with_signed_url_model(api_or_ns: Api | Namespace):
|
||||
def build_file_with_signed_url_model(api_or_ns: Namespace):
|
||||
"""Build the file with signed URL model for the API or Namespace.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from libs.helper import AvatarUrlField, TimestampField
|
||||
|
||||
@@ -9,7 +9,7 @@ simple_account_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_simple_account_model(api_or_ns: Api | Namespace):
|
||||
def build_simple_account_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("SimpleAccount", simple_account_fields)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField
|
||||
@@ -10,7 +10,7 @@ feedback_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_feedback_model(api_or_ns: Api | Namespace):
|
||||
def build_feedback_model(api_or_ns: Namespace):
|
||||
"""Build the feedback model for the API or Namespace."""
|
||||
return api_or_ns.model("Feedback", feedback_fields)
|
||||
|
||||
@@ -30,7 +30,7 @@ agent_thought_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_agent_thought_model(api_or_ns: Api | Namespace):
|
||||
def build_agent_thought_model(api_or_ns: Namespace):
|
||||
"""Build the agent thought model for the API or Namespace."""
|
||||
return api_or_ns.model("AgentThought", agent_thought_fields)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
@@ -8,5 +8,5 @@ dataset_tag_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_dataset_tag_fields(api_or_ns: Api | Namespace):
|
||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
|
||||
from fields.member_fields import build_simple_account_model, simple_account_fields
|
||||
@@ -17,7 +17,7 @@ workflow_app_log_partial_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
|
||||
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
|
||||
"""Build the workflow app log partial model for the API or Namespace."""
|
||||
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
|
||||
simple_account_model = build_simple_account_model(api_or_ns)
|
||||
@@ -43,7 +43,7 @@ workflow_app_log_pagination_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
|
||||
def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the workflow app log pagination model for the API or Namespace."""
|
||||
# Build the nested partial model first
|
||||
workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from flask_restx import Namespace, fields
|
||||
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
@@ -19,7 +19,7 @@ workflow_run_for_log_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_workflow_run_for_log_model(api_or_ns: Api | Namespace):
|
||||
def build_workflow_run_for_log_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)
|
||||
|
||||
|
||||
|
||||
api/libs/archive_storage.py (new file, 347 lines)
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
Archive Storage Client for S3-compatible storage.
|
||||
|
||||
This module provides a dedicated storage client for archiving or exporting logs
|
||||
to S3-compatible object storage.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import gzip
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
import boto3
|
||||
import orjson
|
||||
from botocore.client import Config
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ArchiveStorageError(Exception):
|
||||
"""Base exception for archive storage operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ArchiveStorageNotConfiguredError(ArchiveStorageError):
|
||||
"""Raised when archive storage is not properly configured."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ArchiveStorage:
|
||||
"""
|
||||
S3-compatible storage client for archiving or exporting.
|
||||
|
||||
This client provides methods for storing and retrieving archived data in JSONL+gzip format.
|
||||
"""
|
||||
|
||||
def __init__(self, bucket: str):
|
||||
if not dify_config.ARCHIVE_STORAGE_ENABLED:
|
||||
raise ArchiveStorageNotConfiguredError("Archive storage is not enabled")
|
||||
|
||||
if not bucket:
|
||||
raise ArchiveStorageNotConfiguredError("Archive storage bucket is not configured")
|
||||
if not all(
|
||||
[
|
||||
dify_config.ARCHIVE_STORAGE_ENDPOINT,
|
||||
bucket,
|
||||
dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
|
||||
dify_config.ARCHIVE_STORAGE_SECRET_KEY,
|
||||
]
|
||||
):
|
||||
raise ArchiveStorageNotConfiguredError(
|
||||
"Archive storage configuration is incomplete. "
|
||||
"Required: ARCHIVE_STORAGE_ENDPOINT, ARCHIVE_STORAGE_ACCESS_KEY, "
|
||||
"ARCHIVE_STORAGE_SECRET_KEY, and a bucket name"
|
||||
)
|
||||
|
||||
self.bucket = bucket
|
||||
self.client = boto3.client(
|
||||
"s3",
|
||||
endpoint_url=dify_config.ARCHIVE_STORAGE_ENDPOINT,
|
||||
aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
|
||||
aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY,
|
||||
region_name=dify_config.ARCHIVE_STORAGE_REGION,
|
||||
config=Config(s3={"addressing_style": "path"}),
|
||||
)
|
||||
|
||||
# Verify bucket accessibility
|
||||
try:
|
||||
self.client.head_bucket(Bucket=self.bucket)
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code")
|
||||
if error_code == "404":
|
||||
raise ArchiveStorageNotConfiguredError(f"Archive bucket '{self.bucket}' does not exist")
|
||||
elif error_code == "403":
|
||||
raise ArchiveStorageNotConfiguredError(f"Access denied to archive bucket '{self.bucket}'")
|
||||
else:
|
||||
raise ArchiveStorageError(f"Failed to access archive bucket: {e}")
|
||||
|
||||
def put_object(self, key: str, data: bytes) -> str:
|
||||
"""
|
||||
Upload an object to the archive storage.
|
||||
|
||||
Args:
|
||||
key: Object key (path) within the bucket
|
||||
data: Binary data to upload
|
||||
|
||||
Returns:
|
||||
MD5 checksum of the uploaded data
|
||||
|
||||
Raises:
|
||||
ArchiveStorageError: If upload fails
|
||||
"""
|
||||
checksum = hashlib.md5(data).hexdigest()
|
||||
try:
|
||||
self.client.put_object(
|
||||
Bucket=self.bucket,
|
||||
Key=key,
|
||||
Body=data,
|
||||
ContentMD5=self._content_md5(data),
|
||||
)
|
||||
logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum)
|
||||
return checksum
|
||||
except ClientError as e:
|
||||
raise ArchiveStorageError(f"Failed to upload object '{key}': {e}")
|
||||
|
||||
def get_object(self, key: str) -> bytes:
|
||||
"""
|
||||
Download an object from the archive storage.
|
||||
|
||||
Args:
|
||||
key: Object key (path) within the bucket
|
||||
|
||||
Returns:
|
||||
Binary data of the object
|
||||
|
||||
Raises:
|
||||
ArchiveStorageError: If download fails
|
||||
FileNotFoundError: If object does not exist
|
||||
"""
|
||||
try:
|
||||
response = self.client.get_object(Bucket=self.bucket, Key=key)
|
||||
return response["Body"].read()
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code")
|
||||
if error_code == "NoSuchKey":
|
||||
raise FileNotFoundError(f"Archive object not found: {key}")
|
||||
raise ArchiveStorageError(f"Failed to download object '{key}': {e}")
|
||||
|
||||
def get_object_stream(self, key: str) -> Generator[bytes, None, None]:
|
||||
"""
|
||||
Stream an object from the archive storage.
|
||||
|
||||
Args:
|
||||
key: Object key (path) within the bucket
|
||||
|
||||
Yields:
|
||||
Chunks of binary data
|
||||
|
||||
Raises:
|
||||
ArchiveStorageError: If download fails
|
||||
FileNotFoundError: If object does not exist
|
||||
"""
|
||||
try:
|
||||
response = self.client.get_object(Bucket=self.bucket, Key=key)
|
||||
yield from response["Body"].iter_chunks()
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code")
|
||||
if error_code == "NoSuchKey":
|
||||
raise FileNotFoundError(f"Archive object not found: {key}")
|
||||
raise ArchiveStorageError(f"Failed to stream object '{key}': {e}")
|
||||
|
||||
def object_exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if an object exists in the archive storage.
|
||||
|
||||
Args:
|
||||
key: Object key (path) within the bucket
|
||||
|
||||
Returns:
|
||||
True if object exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
self.client.head_object(Bucket=self.bucket, Key=key)
|
||||
return True
|
||||
except ClientError:
|
||||
return False
|
||||
|
||||
def delete_object(self, key: str) -> None:
|
||||
"""
|
||||
Delete an object from the archive storage.
|
||||
|
||||
Args:
|
||||
key: Object key (path) within the bucket
|
||||
|
||||
Raises:
|
||||
ArchiveStorageError: If deletion fails
|
||||
"""
|
||||
try:
|
||||
self.client.delete_object(Bucket=self.bucket, Key=key)
|
||||
logger.debug("Deleted object: %s", key)
|
||||
except ClientError as e:
|
||||
raise ArchiveStorageError(f"Failed to delete object '{key}': {e}")
|
||||
|
||||
def generate_presigned_url(self, key: str, expires_in: int = 3600) -> str:
|
||||
"""
|
||||
Generate a pre-signed URL for downloading an object.
|
||||
|
||||
Args:
|
||||
key: Object key (path) within the bucket
|
||||
expires_in: URL validity duration in seconds (default: 1 hour)
|
||||
|
||||
Returns:
|
||||
Pre-signed URL string.
|
||||
|
||||
Raises:
|
||||
ArchiveStorageError: If generation fails
|
||||
"""
|
||||
try:
|
||||
return self.client.generate_presigned_url(
|
||||
ClientMethod="get_object",
|
||||
Params={"Bucket": self.bucket, "Key": key},
|
||||
ExpiresIn=expires_in,
|
||||
)
|
||||
except ClientError as e:
|
||||
raise ArchiveStorageError(f"Failed to generate pre-signed URL for '{key}': {e}")
|
||||
|
||||
def list_objects(self, prefix: str) -> list[str]:
|
||||
"""
|
||||
List objects under a given prefix.
|
||||
|
||||
Args:
|
||||
prefix: Object key prefix to filter by
|
||||
|
||||
Returns:
|
||||
List of object keys matching the prefix
|
||||
"""
|
||||
keys = []
|
||||
paginator = self.client.get_paginator("list_objects_v2")
|
||||
|
||||
try:
|
||||
for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
|
||||
for obj in page.get("Contents", []):
|
||||
keys.append(obj["Key"])
|
||||
except ClientError as e:
|
||||
raise ArchiveStorageError(f"Failed to list objects with prefix '{prefix}': {e}")
|
||||
|
||||
return keys
|
||||
|
||||
@staticmethod
|
||||
def _content_md5(data: bytes) -> str:
|
||||
"""Calculate base64-encoded MD5 for Content-MD5 header."""
|
||||
return base64.b64encode(hashlib.md5(data).digest()).decode()
|
||||
|
||||
@staticmethod
|
||||
def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes:
|
||||
"""
|
||||
Serialize records to gzipped JSONL format.
|
||||
|
||||
Args:
|
||||
records: List of dictionaries to serialize
|
||||
|
||||
Returns:
|
||||
Gzipped JSONL bytes
|
||||
"""
|
||||
lines = []
|
||||
for record in records:
|
||||
# Convert datetime objects to ISO format strings
|
||||
serialized = ArchiveStorage._serialize_record(record)
|
||||
lines.append(orjson.dumps(serialized))
|
||||
|
||||
jsonl_content = b"\n".join(lines)
|
||||
if jsonl_content:
|
||||
jsonl_content += b"\n"
|
||||
|
||||
return gzip.compress(jsonl_content)
|
||||
|
||||
@staticmethod
|
||||
def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Deserialize gzipped JSONL data to records.
|
||||
|
||||
Args:
|
||||
data: Gzipped JSONL bytes
|
||||
|
||||
Returns:
|
||||
List of dictionaries
|
||||
"""
|
||||
jsonl_content = gzip.decompress(data)
|
||||
records = []
|
||||
|
||||
for line in jsonl_content.splitlines():
|
||||
if line:
|
||||
records.append(orjson.loads(line))
|
||||
|
||||
return records
|
||||
|
||||
@staticmethod
|
||||
def _serialize_record(record: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Serialize a single record, converting special types."""
|
||||
|
||||
def _serialize(item: Any) -> Any:
|
||||
if isinstance(item, datetime.datetime):
|
||||
return item.isoformat()
|
||||
if isinstance(item, dict):
|
||||
return {key: _serialize(value) for key, value in item.items()}
|
||||
if isinstance(item, list):
|
||||
return [_serialize(value) for value in item]
|
||||
return item
|
||||
|
||||
return cast(dict[str, Any], _serialize(record))
|
||||
|
||||
@staticmethod
|
||||
def compute_checksum(data: bytes) -> str:
|
||||
"""Compute MD5 checksum of data."""
|
||||
return hashlib.md5(data).hexdigest()
|
||||
|
||||
|
||||
# Singleton instance (lazy initialization)
|
||||
_archive_storage: ArchiveStorage | None = None
|
||||
_export_storage: ArchiveStorage | None = None
|
||||
|
||||
|
||||
def get_archive_storage() -> ArchiveStorage:
|
||||
"""
|
||||
Get the archive storage singleton instance.
|
||||
|
||||
Returns:
|
||||
ArchiveStorage instance
|
||||
|
||||
Raises:
|
||||
ArchiveStorageNotConfiguredError: If archive storage is not configured
|
||||
"""
|
||||
global _archive_storage
|
||||
if _archive_storage is None:
|
||||
archive_bucket = dify_config.ARCHIVE_STORAGE_ARCHIVE_BUCKET
|
||||
if not archive_bucket:
|
||||
raise ArchiveStorageNotConfiguredError(
|
||||
"Archive storage bucket is not configured. Required: ARCHIVE_STORAGE_ARCHIVE_BUCKET"
|
||||
)
|
||||
_archive_storage = ArchiveStorage(bucket=archive_bucket)
|
||||
return _archive_storage
|
||||
|
||||
|
||||
def get_export_storage() -> ArchiveStorage:
|
||||
"""
|
||||
Get the export storage singleton instance.
|
||||
|
||||
Returns:
|
||||
ArchiveStorage instance
|
||||
"""
|
||||
global _export_storage
|
||||
if _export_storage is None:
|
||||
export_bucket = dify_config.ARCHIVE_STORAGE_EXPORT_BUCKET
|
||||
if not export_bucket:
|
||||
raise ArchiveStorageNotConfiguredError(
|
||||
"Archive export bucket is not configured. Required: ARCHIVE_STORAGE_EXPORT_BUCKET"
|
||||
)
|
||||
_export_storage = ArchiveStorage(bucket=export_bucket)
|
||||
return _export_storage
|
||||
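A short usage sketch for the helpers in this new file; the module path, object key, and records below are illustrative assumptions rather than values taken from the codebase:

```python
from libs.archive_storage import ArchiveStorage, get_archive_storage

records = [{"id": "run-1", "status": "succeeded"}, {"id": "run-2", "status": "failed"}]

# Round-trip through the gzipped JSONL helpers (static methods, no bucket required).
payload = ArchiveStorage.serialize_to_jsonl_gz(records)
assert ArchiveStorage.deserialize_from_jsonl_gz(payload) == records

# With archive storage configured, upload the blob and read it back.
storage = get_archive_storage()
key = "workflow_runs/2024/01/run-1.jsonl.gz"
checksum = storage.put_object(key, payload)
restored = ArchiveStorage.deserialize_from_jsonl_gz(storage.get_object(key))
```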
@@ -1,5 +1,4 @@
import re
import sys
from collections.abc import Mapping
from typing import Any

@@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api):
data.setdefault("code", "unknown")
data.setdefault("status", status_code)

# Log stack
exc_info: Any = sys.exc_info()
if exc_info[1] is None:
exc_info = (None, None, None)
current_app.log_exception(exc_info)
# Note: Exception logging is handled by Flask/Flask-RESTX framework automatically
# Explicit log_exception call removed to avoid duplicate log entries

return data, status_code

@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '00bacef91f18'
|
||||
down_revision = '8ec536f3c800'
|
||||
@@ -23,31 +20,17 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('description', sa.Text(), nullable=False))
|
||||
batch_op.drop_column('description_str')
|
||||
else:
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
|
||||
batch_op.drop_column('description_str')
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
|
||||
batch_op.drop_column('description_str')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False))
|
||||
batch_op.drop_column('description')
|
||||
else:
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
|
||||
batch_op.drop_column('description')
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
|
||||
batch_op.drop_column('description')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -7,14 +7,10 @@ Create Date: 2024-01-10 04:40:57.257824
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '114eed84c228'
|
||||
down_revision = 'c71211c8f604'
|
||||
@@ -32,13 +28,7 @@ def upgrade():
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False))
|
||||
else:
|
||||
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
|
||||
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '161cadc1af8d'
|
||||
down_revision = '7e6a8693e07a'
|
||||
@@ -23,16 +20,9 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
|
||||
# Step 1: Add column without NOT NULL constraint
|
||||
op.add_column('dataset_permissions', sa.Column('tenant_id', sa.UUID(), nullable=False))
|
||||
else:
|
||||
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
|
||||
# Step 1: Add column without NOT NULL constraint
|
||||
op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
|
||||
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
|
||||
# Step 1: Add column without NOT NULL constraint
|
||||
op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -9,11 +9,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '6af6a521a53e'
|
||||
down_revision = 'd57ba9ebb251'
|
||||
@@ -23,58 +18,30 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
|
||||
batch_op.alter_column('document_id',
|
||||
existing_type=sa.UUID(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('data_source_type',
|
||||
existing_type=sa.TEXT(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('segment_id',
|
||||
existing_type=sa.UUID(),
|
||||
nullable=True)
|
||||
else:
|
||||
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
|
||||
batch_op.alter_column('document_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('data_source_type',
|
||||
existing_type=models.types.LongText(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('segment_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
|
||||
batch_op.alter_column('document_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('data_source_type',
|
||||
existing_type=models.types.LongText(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('segment_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
|
||||
batch_op.alter_column('segment_id',
|
||||
existing_type=sa.UUID(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('data_source_type',
|
||||
existing_type=sa.TEXT(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('document_id',
|
||||
existing_type=sa.UUID(),
|
||||
nullable=False)
|
||||
else:
|
||||
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
|
||||
batch_op.alter_column('segment_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('data_source_type',
|
||||
existing_type=models.types.LongText(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('document_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
|
||||
batch_op.alter_column('segment_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('data_source_type',
|
||||
existing_type=models.types.LongText(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('document_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -8,7 +8,6 @@ Create Date: 2024-11-01 04:34:23.816198
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'd3f6769a94a3'
|
||||
|
||||
@@ -28,85 +28,45 @@ def upgrade():
|
||||
op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
|
||||
op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=sa.TEXT(),
|
||||
nullable=False)
|
||||
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=models.types.LongText(),
|
||||
nullable=False)
|
||||
|
||||
with op.batch_alter_table('sites', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=sa.TEXT(),
|
||||
nullable=False)
|
||||
with op.batch_alter_table('sites', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=models.types.LongText(),
|
||||
nullable=False)
|
||||
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=sa.TEXT(),
|
||||
nullable=False)
|
||||
else:
|
||||
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=models.types.LongText(),
|
||||
nullable=False)
|
||||
|
||||
with op.batch_alter_table('sites', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=models.types.LongText(),
|
||||
nullable=False)
|
||||
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=models.types.LongText(),
|
||||
nullable=False)
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=models.types.LongText(),
|
||||
nullable=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=sa.TEXT(),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=models.types.LongText(),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
|
||||
with op.batch_alter_table('sites', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=sa.TEXT(),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
with op.batch_alter_table('sites', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=models.types.LongText(),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
|
||||
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=sa.TEXT(),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
else:
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=models.types.LongText(),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
|
||||
with op.batch_alter_table('sites', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=models.types.LongText(),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
|
||||
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=models.types.LongText(),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
|
||||
batch_op.alter_column('custom_disclaimer',
|
||||
existing_type=models.types.LongText(),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -49,57 +49,33 @@ def upgrade():
|
||||
op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL")
|
||||
op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL")
|
||||
op.execute("UPDATE workflows SET features = '' WHERE features IS NULL")
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||
batch_op.alter_column('graph',
|
||||
existing_type=sa.TEXT(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('features',
|
||||
existing_type=sa.TEXT(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('updated_at',
|
||||
existing_type=postgresql.TIMESTAMP(),
|
||||
nullable=False)
|
||||
else:
|
||||
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||
batch_op.alter_column('graph',
|
||||
existing_type=models.types.LongText(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('features',
|
||||
existing_type=models.types.LongText(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('updated_at',
|
||||
existing_type=sa.TIMESTAMP(),
|
||||
nullable=False)
|
||||
|
||||
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||
batch_op.alter_column('graph',
|
||||
existing_type=models.types.LongText(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('features',
|
||||
existing_type=models.types.LongText(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('updated_at',
|
||||
existing_type=sa.TIMESTAMP(),
|
||||
nullable=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||
batch_op.alter_column('updated_at',
|
||||
existing_type=postgresql.TIMESTAMP(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('features',
|
||||
existing_type=sa.TEXT(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('graph',
|
||||
existing_type=sa.TEXT(),
|
||||
nullable=True)
|
||||
else:
|
||||
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||
batch_op.alter_column('updated_at',
|
||||
existing_type=sa.TIMESTAMP(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('features',
|
||||
existing_type=models.types.LongText(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('graph',
|
||||
existing_type=models.types.LongText(),
|
||||
nullable=True)
|
||||
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||
batch_op.alter_column('updated_at',
|
||||
existing_type=sa.TIMESTAMP(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('features',
|
||||
existing_type=models.types.LongText(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('graph',
|
||||
existing_type=models.types.LongText(),
|
||||
nullable=True)
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('messages', schema=None) as batch_op:
|
||||
|
||||
@@ -86,57 +86,30 @@ def upgrade():
|
||||
|
||||
def migrate_existing_provider_models_data():
|
||||
"""migrate provider_models table data to provider_model_credentials"""
|
||||
conn = op.get_bind()
|
||||
# Define table structure for data manipulation
|
||||
if _is_pg(conn):
|
||||
provider_models_table = table('provider_models',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('tenant_id', models.types.StringUUID()),
|
||||
column('provider_name', sa.String()),
|
||||
column('model_name', sa.String()),
|
||||
column('model_type', sa.String()),
|
||||
column('encrypted_config', sa.Text()),
|
||||
column('created_at', sa.DateTime()),
|
||||
column('updated_at', sa.DateTime()),
|
||||
column('credential_id', models.types.StringUUID()),
|
||||
)
|
||||
else:
|
||||
provider_models_table = table('provider_models',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('tenant_id', models.types.StringUUID()),
|
||||
column('provider_name', sa.String()),
|
||||
column('model_name', sa.String()),
|
||||
column('model_type', sa.String()),
|
||||
column('encrypted_config', models.types.LongText()),
|
||||
column('created_at', sa.DateTime()),
|
||||
column('updated_at', sa.DateTime()),
|
||||
column('credential_id', models.types.StringUUID()),
|
||||
)
|
||||
# Define table structure for data manipulatio
|
||||
provider_models_table = table('provider_models',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('tenant_id', models.types.StringUUID()),
|
||||
column('provider_name', sa.String()),
|
||||
column('model_name', sa.String()),
|
||||
column('model_type', sa.String()),
|
||||
column('encrypted_config', models.types.LongText()),
|
||||
column('created_at', sa.DateTime()),
|
||||
column('updated_at', sa.DateTime()),
|
||||
column('credential_id', models.types.StringUUID()),
|
||||
)
|
||||
|
||||
if _is_pg(conn):
|
||||
provider_model_credentials_table = table('provider_model_credentials',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('tenant_id', models.types.StringUUID()),
|
||||
column('provider_name', sa.String()),
|
||||
column('model_name', sa.String()),
|
||||
column('model_type', sa.String()),
|
||||
column('credential_name', sa.String()),
|
||||
column('encrypted_config', sa.Text()),
|
||||
column('created_at', sa.DateTime()),
|
||||
column('updated_at', sa.DateTime())
|
||||
)
|
||||
else:
|
||||
provider_model_credentials_table = table('provider_model_credentials',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('tenant_id', models.types.StringUUID()),
|
||||
column('provider_name', sa.String()),
|
||||
column('model_name', sa.String()),
|
||||
column('model_type', sa.String()),
|
||||
column('credential_name', sa.String()),
|
||||
column('encrypted_config', models.types.LongText()),
|
||||
column('created_at', sa.DateTime()),
|
||||
column('updated_at', sa.DateTime())
|
||||
)
|
||||
provider_model_credentials_table = table('provider_model_credentials',
|
||||
column('id', models.types.StringUUID()),
|
||||
column('tenant_id', models.types.StringUUID()),
|
||||
column('provider_name', sa.String()),
|
||||
column('model_name', sa.String()),
|
||||
column('model_type', sa.String()),
|
||||
column('credential_name', sa.String()),
|
||||
column('encrypted_config', models.types.LongText()),
|
||||
column('created_at', sa.DateTime()),
|
||||
column('updated_at', sa.DateTime())
|
||||
)
|
||||
|
||||
|
||||
# Get database connection
|
||||
@@ -183,14 +156,8 @@ def migrate_existing_provider_models_data():
|
||||
|
||||
def downgrade():
|
||||
# Re-add encrypted_config column to provider_models table
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('provider_models', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('provider_models', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True))
|
||||
with op.batch_alter_table('provider_models', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True))
|
||||
|
||||
if not context.is_offline_mode():
|
||||
# Migrate data back from provider_model_credentials to provider_models
|
||||
|
||||
@@ -8,7 +8,6 @@ Create Date: 2025-08-20 17:47:17.015695
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
|
||||
@@ -9,8 +9,6 @@ from alembic import op
|
||||
import models as models
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
@@ -23,12 +21,7 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# Add encrypted_headers column to tool_mcp_providers table
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', sa.Text(), nullable=True))
|
||||
else:
|
||||
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True))
|
||||
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
|
||||
@@ -44,6 +44,7 @@ def upgrade():
|
||||
sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'),
|
||||
sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx')
|
||||
)
|
||||
|
||||
if _is_pg(conn):
|
||||
op.create_table('datasource_oauth_tenant_params',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||
@@ -70,6 +71,7 @@ def upgrade():
|
||||
sa.PrimaryKeyConstraint('id', name='datasource_oauth_tenant_config_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='datasource_oauth_tenant_config_unique')
|
||||
)
|
||||
|
||||
if _is_pg(conn):
|
||||
op.create_table('datasource_providers',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||
@@ -104,6 +106,7 @@ def upgrade():
|
||||
sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', 'name', name='datasource_provider_unique_name')
|
||||
)
|
||||
|
||||
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
|
||||
batch_op.create_index('datasource_provider_auth_type_provider_idx', ['tenant_id', 'plugin_id', 'provider'], unique=False)
|
||||
|
||||
@@ -133,6 +136,7 @@ def upgrade():
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey')
|
||||
)
|
||||
|
||||
with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op:
|
||||
batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False)
|
||||
|
||||
@@ -174,6 +178,7 @@ def upgrade():
|
||||
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey')
|
||||
)
|
||||
|
||||
if _is_pg(conn):
|
||||
op.create_table('pipeline_customized_templates',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||
@@ -193,7 +198,6 @@ def upgrade():
|
||||
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
|
||||
)
|
||||
else:
|
||||
# MySQL: Use compatible syntax
|
||||
op.create_table('pipeline_customized_templates',
|
||||
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
@@ -211,6 +215,7 @@ def upgrade():
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
|
||||
)
|
||||
|
||||
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
|
||||
batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False)
|
||||
|
||||
@@ -236,6 +241,7 @@ def upgrade():
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='pipeline_recommended_plugin_pkey')
|
||||
)
|
||||
|
||||
if _is_pg(conn):
|
||||
op.create_table('pipelines',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||
@@ -266,6 +272,7 @@ def upgrade():
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='pipeline_pkey')
|
||||
)
|
||||
|
||||
if _is_pg(conn):
|
||||
op.create_table('workflow_draft_variable_files',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||
@@ -292,6 +299,7 @@ def upgrade():
|
||||
sa.Column('value_type', sa.String(20), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name=op.f('workflow_draft_variable_files_pkey'))
|
||||
)
|
||||
|
||||
if _is_pg(conn):
|
||||
op.create_table('workflow_node_execution_offload',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||
@@ -316,6 +324,7 @@ def upgrade():
|
||||
sa.PrimaryKeyConstraint('id', name=op.f('workflow_node_execution_offload_pkey')),
|
||||
sa.UniqueConstraint('node_execution_id', 'type', name=op.f('workflow_node_execution_offload_node_execution_id_key'))
|
||||
)
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True))
|
||||
@@ -342,6 +351,7 @@ def upgrade():
|
||||
comment='Indicates whether the current value is the default for a conversation variable. Always `FALSE` for other types of variables.',)
|
||||
)
|
||||
batch_op.create_index('workflow_draft_variable_file_id_idx', ['file_id'], unique=False)
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False))
|
||||
|
||||
@@ -9,8 +9,6 @@ from alembic import op
|
||||
import models as models
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
@@ -33,15 +31,9 @@ def upgrade():
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
|
||||
batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False))
|
||||
batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True))
|
||||
|
||||
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False))
|
||||
batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -9,7 +9,6 @@ Create Date: 2025-10-22 16:11:31.805407
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
@@ -105,6 +105,7 @@ def upgrade():
|
||||
sa.PrimaryKeyConstraint('id', name='trigger_oauth_tenant_client_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_trigger_oauth_tenant_client')
|
||||
)
|
||||
|
||||
if _is_pg(conn):
|
||||
op.create_table('trigger_subscriptions',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
@@ -143,6 +144,7 @@ def upgrade():
|
||||
sa.PrimaryKeyConstraint('id', name='trigger_provider_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'provider_id', 'name', name='unique_trigger_provider')
|
||||
)
|
||||
|
||||
with op.batch_alter_table('trigger_subscriptions', schema=None) as batch_op:
|
||||
batch_op.create_index('idx_trigger_providers_endpoint', ['endpoint_id'], unique=True)
|
||||
batch_op.create_index('idx_trigger_providers_tenant_endpoint', ['tenant_id', 'endpoint_id'], unique=False)
|
||||
@@ -176,6 +178,7 @@ def upgrade():
|
||||
sa.PrimaryKeyConstraint('id', name='workflow_plugin_trigger_pkey'),
|
||||
sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node_subscription')
|
||||
)
|
||||
|
||||
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
|
||||
batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id', 'event_name'], unique=False)
|
||||
|
||||
@@ -207,6 +210,7 @@ def upgrade():
|
||||
sa.PrimaryKeyConstraint('id', name='workflow_schedule_plan_pkey'),
|
||||
sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node')
|
||||
)
|
||||
|
||||
with op.batch_alter_table('workflow_schedule_plans', schema=None) as batch_op:
|
||||
batch_op.create_index('workflow_schedule_plan_next_idx', ['next_run_at'], unique=False)
|
||||
|
||||
@@ -264,6 +268,7 @@ def upgrade():
|
||||
sa.Column('finished_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey')
|
||||
)
|
||||
|
||||
with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op:
|
||||
batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False)
|
||||
batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False)
|
||||
@@ -299,6 +304,7 @@ def upgrade():
|
||||
sa.UniqueConstraint('app_id', 'node_id', name='uniq_node'),
|
||||
sa.UniqueConstraint('webhook_id', name='uniq_webhook_id')
|
||||
)
|
||||
|
||||
with op.batch_alter_table('workflow_webhook_triggers', schema=None) as batch_op:
|
||||
batch_op.create_index('workflow_webhook_trigger_tenant_idx', ['tenant_id'], unique=False)
|
||||
|
||||
|
||||
@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '23db93619b9d'
|
||||
down_revision = '8ae9bc661daa'
|
||||
@@ -23,14 +20,8 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('message_files', sa.Text(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True))
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -62,14 +62,8 @@ def upgrade():
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True))
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True))
|
||||
|
||||
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
|
||||
batch_op.drop_index('app_annotation_settings_app_idx')
|
||||
|
||||
@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models as models
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '2a3aebbbf4bb'
|
||||
down_revision = 'c031d46af369'
|
||||
@@ -23,14 +20,8 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('apps', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('apps', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True))
|
||||
with op.batch_alter_table('apps', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -7,14 +7,10 @@ Create Date: 2023-09-22 15:41:01.243183
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '2e9819ca5b28'
|
||||
down_revision = 'ab23c11305d4'
|
||||
@@ -24,35 +20,19 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True))
|
||||
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
|
||||
batch_op.drop_column('dataset_id')
|
||||
else:
|
||||
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True))
|
||||
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
|
||||
batch_op.drop_column('dataset_id')
|
||||
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True))
|
||||
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
|
||||
batch_op.drop_column('dataset_id')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True))
|
||||
batch_op.drop_index('api_token_tenant_idx')
|
||||
batch_op.drop_column('tenant_id')
|
||||
else:
|
||||
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True))
|
||||
batch_op.drop_index('api_token_tenant_idx')
|
||||
batch_op.drop_column('tenant_id')
|
||||
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True))
|
||||
batch_op.drop_index('api_token_tenant_idx')
|
||||
batch_op.drop_column('tenant_id')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -7,14 +7,10 @@ Create Date: 2024-03-07 08:30:29.133614
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '42e85ed5564d'
|
||||
down_revision = 'f9107f83abab'
|
||||
@@ -24,59 +20,31 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||
batch_op.alter_column('app_model_config_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('model_provider',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
batch_op.alter_column('model_id',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
else:
|
||||
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||
batch_op.alter_column('app_model_config_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('model_provider',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
batch_op.alter_column('model_id',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||
batch_op.alter_column('app_model_config_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('model_provider',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
batch_op.alter_column('model_id',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||
batch_op.alter_column('model_id',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=False)
|
||||
batch_op.alter_column('model_provider',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=False)
|
||||
batch_op.alter_column('app_model_config_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=False)
|
||||
else:
|
||||
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||
batch_op.alter_column('model_id',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=False)
|
||||
batch_op.alter_column('model_provider',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=False)
|
||||
batch_op.alter_column('app_model_config_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||
batch_op.alter_column('model_id',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=False)
|
||||
batch_op.alter_column('model_provider',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=False)
|
||||
batch_op.alter_column('app_model_config_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -6,14 +6,10 @@ Create Date: 2024-01-12 03:42:27.362415
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '4829e54d2fee'
|
||||
down_revision = '114eed84c228'
|
||||
@@ -23,39 +19,21 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
# PostgreSQL: Keep original syntax
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.alter_column('message_chain_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=True)
|
||||
else:
|
||||
# MySQL: Use compatible syntax
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.alter_column('message_chain_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.alter_column('message_chain_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
# PostgreSQL: Keep original syntax
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.alter_column('message_chain_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=False)
|
||||
else:
|
||||
# MySQL: Use compatible syntax
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.alter_column('message_chain_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.alter_column('message_chain_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -6,14 +6,10 @@ Create Date: 2024-03-14 04:54:56.679506
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '563cf8bf777b'
|
||||
down_revision = 'b5429b71023c'
|
||||
@@ -23,35 +19,19 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('tool_files', schema=None) as batch_op:
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=True)
|
||||
else:
|
||||
with op.batch_alter_table('tool_files', schema=None) as batch_op:
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
with op.batch_alter_table('tool_files', schema=None) as batch_op:
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('tool_files', schema=None) as batch_op:
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=False)
|
||||
else:
|
||||
with op.batch_alter_table('tool_files', schema=None) as batch_op:
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
with op.batch_alter_table('tool_files', schema=None) as batch_op:
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -48,12 +48,9 @@ def upgrade():
|
||||
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
|
||||
batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False)
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True))
|
||||
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '714aafe25d39'
|
||||
down_revision = 'f2a6fc85e260'
|
||||
@@ -23,16 +20,9 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False))
|
||||
batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False))
|
||||
else:
|
||||
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False))
|
||||
batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False))
|
||||
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False))
|
||||
batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '77e83833755c'
|
||||
down_revision = '6dcb43972bdc'
|
||||
@@ -23,14 +20,8 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True))
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ def upgrade():
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
# PostgreSQL: Keep original syntax
|
||||
op.create_table('tool_providers',
|
||||
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', postgresql.UUID(), nullable=False),
|
||||
@@ -40,7 +39,6 @@ def upgrade():
|
||||
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
|
||||
)
|
||||
else:
|
||||
# MySQL: Use compatible syntax
|
||||
op.create_table('tool_providers',
|
||||
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
@@ -52,12 +50,9 @@ def upgrade():
|
||||
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
|
||||
)
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True))
|
||||
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '88072f0caa04'
|
||||
down_revision = '246ba09cbbdb'
|
||||
@@ -23,14 +20,8 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('tenants', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('custom_config', sa.Text(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('tenants', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True))
|
||||
with op.batch_alter_table('tenants', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '89c7899ca936'
|
||||
down_revision = '187385f442fc'
|
||||
@@ -23,39 +20,21 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('sites', schema=None) as batch_op:
|
||||
batch_op.alter_column('description',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=sa.Text(),
|
||||
existing_nullable=True)
|
||||
else:
|
||||
with op.batch_alter_table('sites', schema=None) as batch_op:
|
||||
batch_op.alter_column('description',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=models.types.LongText(),
|
||||
existing_nullable=True)
|
||||
with op.batch_alter_table('sites', schema=None) as batch_op:
|
||||
batch_op.alter_column('description',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=models.types.LongText(),
|
||||
existing_nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('sites', schema=None) as batch_op:
|
||||
batch_op.alter_column('description',
|
||||
existing_type=sa.Text(),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
existing_nullable=True)
|
||||
else:
|
||||
with op.batch_alter_table('sites', schema=None) as batch_op:
|
||||
batch_op.alter_column('description',
|
||||
existing_type=models.types.LongText(),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
existing_nullable=True)
|
||||
with op.batch_alter_table('sites', schema=None) as batch_op:
|
||||
batch_op.alter_column('description',
|
||||
existing_type=models.types.LongText(),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
existing_nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '8ec536f3c800'
|
||||
down_revision = 'ad472b61a054'
|
||||
@@ -23,14 +20,8 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credentials_str', sa.Text(), nullable=False))
|
||||
else:
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False))
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -57,12 +57,9 @@ def upgrade():
|
||||
batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False)
|
||||
batch_op.create_index('message_file_message_idx', ['message_id'], unique=False)
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True))
|
||||
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True))
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('upload_files', schema=None) as batch_op:
|
||||
|
||||
@@ -24,7 +24,6 @@ def upgrade():
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
# PostgreSQL: Keep original syntax
|
||||
with op.batch_alter_table('pinned_conversations', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'::character varying"), nullable=False))
|
||||
batch_op.drop_index('pinned_conversation_conversation_idx')
|
||||
@@ -35,7 +34,6 @@ def upgrade():
|
||||
batch_op.drop_index('saved_message_message_idx')
|
||||
batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by_role', 'created_by'], unique=False)
|
||||
else:
|
||||
# MySQL: Use compatible syntax
|
||||
with op.batch_alter_table('pinned_conversations', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'"), nullable=False))
|
||||
batch_op.drop_index('pinned_conversation_conversation_idx')
|
||||
|
||||
@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'a5b56fb053ef'
|
||||
down_revision = 'd3d503a3471c'
|
||||
@@ -23,14 +20,8 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('speech_to_text', sa.Text(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True))
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'a9836e3baeee'
|
||||
down_revision = '968fff4c0ab9'
|
||||
@@ -23,14 +20,8 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('external_data_tools', sa.Text(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True))
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b24be59fbb04'
|
||||
down_revision = 'de95f5c77138'
|
||||
@@ -23,14 +20,8 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('text_to_speech', sa.Text(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True))
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -11,9 +11,6 @@ from alembic import op
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b3a09c049e8e'
|
||||
down_revision = '2e9819ca5b28'
|
||||
@@ -23,20 +20,11 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
|
||||
batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True))
|
||||
batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True))
|
||||
batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
|
||||
batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True))
|
||||
batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True))
|
||||
batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True))
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
|
||||
batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True))
|
||||
batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True))
|
||||
batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ Create Date: 2024-06-17 10:01:00.255189
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import models.types
|
||||
|
||||
|
||||
@@ -54,12 +54,9 @@ def upgrade():
|
||||
batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False)
|
||||
batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False)
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True))
|
||||
else:
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True))
|
||||
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True))
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
|
||||
@@ -68,54 +65,31 @@ def upgrade():
|
||||
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'"), nullable=False))
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('question', sa.Text(), nullable=True))
|
||||
batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('message_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=True)
|
||||
else:
|
||||
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True))
|
||||
batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('message_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True))
|
||||
batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('message_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
|
||||
batch_op.alter_column('message_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=False)
|
||||
batch_op.drop_column('hit_count')
|
||||
batch_op.drop_column('question')
|
||||
else:
|
||||
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
|
||||
batch_op.alter_column('message_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
batch_op.drop_column('hit_count')
|
||||
batch_op.drop_column('question')
|
||||
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
|
||||
batch_op.alter_column('message_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=models.types.StringUUID(),
|
||||
nullable=False)
|
||||
batch_op.drop_column('hit_count')
|
||||
batch_op.drop_column('question')
|
||||
|
||||
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
|
||||
batch_op.drop_column('type')
|
||||
|
||||
@@ -12,9 +12,6 @@ from sqlalchemy.dialects import postgresql
 import models.types
 
 
-def _is_pg(conn):
-    return conn.dialect.name == "postgresql"
-
 # revision identifiers, used by Alembic.
 revision = 'f2a6fc85e260'
 down_revision = '46976cc39132'
@@ -24,16 +21,9 @@ depends_on = None
 
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
-    conn = op.get_bind()
-
-    if _is_pg(conn):
-        with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
-            batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False))
-            batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
-    else:
-        with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
-            batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False))
-            batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
+    with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False))
+        batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
 
     # ### end Alembic commands ###
 
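These two migration hunks drop the `_is_pg()` dialect branch because `models.types.StringUUID` already picks the right column type for each backend. As a rough illustration of that pattern (a hypothetical sketch, not Dify's actual `StringUUID` implementation), a dialect-aware UUID type can be built on SQLAlchemy's `TypeDecorator`:

```python
# Hypothetical sketch of a dialect-aware UUID column type (illustrative only,
# not Dify's models.types.StringUUID): it renders as a native UUID on
# PostgreSQL and as CHAR(36) elsewhere, so migrations never need to branch
# on the dialect themselves.
import uuid

from sqlalchemy import CHAR, TypeDecorator
from sqlalchemy.dialects import postgresql


class PortableUUID(TypeDecorator):
    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Pick the concrete type per backend at DDL/compile time.
        if dialect.name == "postgresql":
            return dialect.type_descriptor(postgresql.UUID())
        return dialect.type_descriptor(CHAR(36))

    def process_bind_param(self, value, dialect):
        # Always send a string representation to the driver.
        return None if value is None else str(value)

    def process_result_value(self, value, dialect):
        # Normalize whatever comes back to a canonical string.
        return None if value is None else str(uuid.UUID(str(value)))
```

With a type like this, the single `existing_type=models.types.StringUUID()` path covers both dialects, which is what lets the migrations above drop the `conn.dialect.name` check.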
@@ -8,7 +8,7 @@ from uuid import uuid4
 import sqlalchemy as sa
 from flask_login import UserMixin
 from sqlalchemy import DateTime, String, func, select
-from sqlalchemy.orm import Mapped, Session, mapped_column
+from sqlalchemy.orm import Mapped, Session, mapped_column, validates
 from typing_extensions import deprecated
 
 from .base import TypeBase
@@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase):
     role: TenantAccountRole | None = field(default=None, init=False)
     _current_tenant: "Tenant | None" = field(default=None, init=False)
 
+    @validates("status")
+    def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
+        if isinstance(value, AccountStatus):
+            return value.value
+        return value
+
     @property
     def is_password_set(self):
         return self.password is not None
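The `@validates("status")` hook added above normalizes enum assignments to their string value before they reach the mapped column. A standalone sketch of the same SQLAlchemy pattern, using a throwaway model rather than Dify's `Account`:

```python
# Illustrative model only (not Dify's Account): shows how @validates("status")
# coerces Enum members to plain strings on assignment.
import enum

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, validates


class Base(DeclarativeBase):
    pass


class Status(str, enum.Enum):
    ACTIVE = "active"
    BANNED = "banned"


class Demo(Base):
    __tablename__ = "demo"

    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(default="active")

    @validates("status")
    def _normalize_status(self, _key: str, value: str | Status) -> str:
        # Enum members are stored by value; plain strings pass through untouched.
        if isinstance(value, Status):
            return value.value
        return value


row = Demo(status=Status.BANNED)   # the validator also runs on constructor kwargs
assert row.status == "banned"
row.status = "active"
assert row.status == "active"
```

In the `Account` model this means callers can keep assigning `AccountStatus` members while the column continues to store plain strings.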
@@ -16,6 +16,11 @@ celery_redis = Redis(
     port=redis_config.get("port") or 6379,
     password=redis_config.get("password") or None,
     db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
+    ssl=bool(dify_config.BROKER_USE_SSL),
+    ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None,
+    ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
+    ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
+    ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
 )
 
 logger = logging.getLogger(__name__)
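The new `ssl_*` keyword arguments are standard redis-py TLS options, gated on `BROKER_USE_SSL`. A standalone sketch with made-up host and certificate paths (real deployments take these values from `dify_config` / the environment):

```python
# Hypothetical values throughout; this only mirrors the keyword arguments wired
# up above so the TLS connection can be exercised in isolation.
from redis import Redis

celery_redis = Redis(
    host="redis.internal",                       # hypothetical host
    port=6380,
    password="change-me",                        # hypothetical password
    db=1,
    ssl=True,                                    # BROKER_USE_SSL
    ssl_ca_certs="/etc/ssl/certs/redis-ca.pem",  # REDIS_SSL_CA_CERTS
    ssl_cert_reqs="required",                    # REDIS_SSL_CERT_REQS
    ssl_certfile="/etc/ssl/certs/client.pem",    # REDIS_SSL_CERTFILE
    ssl_keyfile="/etc/ssl/private/client.key",   # REDIS_SSL_KEYFILE
)
celery_redis.ping()
```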
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 from collections.abc import Sequence
@@ -31,6 +32,11 @@ class BillingService:
 
     compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)
 
+    # Redis key prefix for tenant plan cache
+    _PLAN_CACHE_KEY_PREFIX = "tenant_plan:"
+    # Cache TTL: 10 minutes
+    _PLAN_CACHE_TTL = 600
+
     @classmethod
     def get_info(cls, tenant_id: str):
         params = {"tenant_id": tenant_id}
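The two constants above define the cache layout used by `get_plan_bulk_with_cache` further down: one Redis string per tenant, keyed by prefix plus tenant ID, expiring after ten minutes. A rough sketch with made-up values, using a local client in place of the app's shared `redis_client`:

```python
# Made-up tenant ID and plan payload; the payload shape (plan + expiration_date)
# follows the docstring of get_plan_bulk_with_cache below.
import json

import redis

r = redis.Redis()  # stand-in for the shared redis_client

tenant_id = "00000000-0000-0000-0000-000000000000"
key = f"tenant_plan:{tenant_id}"  # _PLAN_CACHE_KEY_PREFIX + tenant_id

r.setex(key, 600, json.dumps({"plan": "team", "expiration_date": 1767139200}))  # 600s = _PLAN_CACHE_TTL

cached = r.mget([key])[0]
plan = json.loads(cached) if cached else None
```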
@@ -272,14 +278,110 @@ class BillingService:
                 data = resp.get("data", {})
 
                 for tenant_id, plan in data.items():
-                    subscription_plan = subscription_adapter.validate_python(plan)
-                    results[tenant_id] = subscription_plan
+                    try:
+                        subscription_plan = subscription_adapter.validate_python(plan)
+                        results[tenant_id] = subscription_plan
+                    except Exception:
+                        logger.exception(
+                            "get_plan_bulk: failed to validate subscription plan for tenant(%s)", tenant_id
+                        )
+                        continue
             except Exception:
-                logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
+                logger.exception("get_plan_bulk: failed to fetch billing info batch for tenants: %s", chunk)
                 continue
 
         return results
 
+    @classmethod
+    def _make_plan_cache_key(cls, tenant_id: str) -> str:
+        return f"{cls._PLAN_CACHE_KEY_PREFIX}{tenant_id}"
+
+    @classmethod
+    def get_plan_bulk_with_cache(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
+        """
+        Bulk fetch billing subscription plans with a cache to reduce billing API load in batch job scenarios.
+
+        NOTE: if you need high data consistency, use get_plan_bulk instead.
+
+        Returns:
+            Mapping of tenant_id -> {plan: str, expiration_date: int}
+        """
+        tenant_plans: dict[str, SubscriptionPlan] = {}
+
+        if not tenant_ids:
+            return tenant_plans
+
+        subscription_adapter = TypeAdapter(SubscriptionPlan)
+
+        # Step 1: Batch fetch from Redis cache using mget
+        redis_keys = [cls._make_plan_cache_key(tenant_id) for tenant_id in tenant_ids]
+        try:
+            cached_values = redis_client.mget(redis_keys)
+
+            if len(cached_values) != len(tenant_ids):
+                raise Exception(
+                    "get_plan_bulk_with_cache: unexpected error: redis mget failed: cached values length mismatch"
+                )
+
+            # Map cached values back to tenant_ids
+            cache_misses: list[str] = []
+
+            for tenant_id, cached_value in zip(tenant_ids, cached_values):
+                if cached_value:
+                    try:
+                        # Redis returns bytes, decode to string and parse JSON
+                        json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
+                        plan_dict = json.loads(json_str)
+                        subscription_plan = subscription_adapter.validate_python(plan_dict)
+                        tenant_plans[tenant_id] = subscription_plan
+                    except Exception:
+                        logger.exception(
+                            "get_plan_bulk_with_cache: process tenant(%s) failed, add to cache misses", tenant_id
+                        )
+                        cache_misses.append(tenant_id)
+                else:
+                    cache_misses.append(tenant_id)
+
+            logger.info(
+                "get_plan_bulk_with_cache: cache hits=%s, cache misses=%s",
+                len(tenant_plans),
+                len(cache_misses),
+            )
+        except Exception:
+            logger.exception("get_plan_bulk_with_cache: redis mget failed, falling back to API")
+            cache_misses = list(tenant_ids)
+
+        # Step 2: Fetch missing plans from billing API
+        if cache_misses:
+            bulk_plans = BillingService.get_plan_bulk(cache_misses)
+
+            if bulk_plans:
+                plans_to_cache: dict[str, SubscriptionPlan] = {}
+
+                for tenant_id, subscription_plan in bulk_plans.items():
+                    tenant_plans[tenant_id] = subscription_plan
+                    plans_to_cache[tenant_id] = subscription_plan
+
+                # Step 3: Batch update Redis cache using pipeline
+                if plans_to_cache:
+                    try:
+                        pipe = redis_client.pipeline()
+                        for tenant_id, subscription_plan in plans_to_cache.items():
+                            redis_key = cls._make_plan_cache_key(tenant_id)
+                            # Serialize dict to JSON string
+                            json_str = json.dumps(subscription_plan)
+                            pipe.setex(redis_key, cls._PLAN_CACHE_TTL, json_str)
+                        pipe.execute()
+
+                        logger.info(
+                            "get_plan_bulk_with_cache: cached %s new tenant plans to Redis",
+                            len(plans_to_cache),
+                        )
+                    except Exception:
+                        logger.exception("get_plan_bulk_with_cache: redis pipeline failed")
+
+        return tenant_plans
+
     @classmethod
     def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
         resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
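Taken together, the new method gives batch jobs a cheap way to look up plans: Redis hits are served directly, misses fall back to `get_plan_bulk` and are written back with the ten-minute TTL, so staleness is bounded by `_PLAN_CACHE_TTL`. A hypothetical call site (tenant IDs are made up; the import path assumes the usual `api/services` layout):

```python
from services.billing_service import BillingService

tenant_ids = [
    "11111111-1111-1111-1111-111111111111",  # hypothetical tenant IDs
    "22222222-2222-2222-2222-222222222222",
]

# Cached plans come straight from Redis; misses are fetched in bulk and cached.
plans = BillingService.get_plan_bulk_with_cache(tenant_ids)
for tenant_id, plan in plans.items():
    print(tenant_id, plan)
```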