Compare commits

..

7 Commits

Author  SHA1        Message                                                     Date
Yi      4fcb048607  add "add block" shortcut ui                                 2024-08-23 16:18:19 +08:00
Yi      18b61591b8  "add block" shortcut update                                 2024-08-23 16:17:14 +08:00
Yi      5408e923e3  Merge branch 'main' into feat/workflow-add-block-shortcut  2024-08-23 15:56:37 +08:00
Yi      1bcfa747db  add block shortcut                                          2024-08-23 15:54:29 +08:00
Yi      071b7d607b  Merge branch 'main' into feat/workflow-add-block-shortcut  2024-08-23 13:58:35 +08:00
Yi      1a5e5cd8f8  Merge branch 'main' into feat/workflow-add-block-shortcut  2024-08-23 11:46:15 +08:00
Yi      3bd5f9542e  add "add block" shortcut ui                                 2024-08-22 17:51:12 +08:00
882 changed files with 20426 additions and 32283 deletions

View File

@@ -125,6 +125,7 @@ jobs:
with:
images: ${{ env[matrix.image_name_env] }}
tags: |
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}

View File

@@ -1,54 +0,0 @@
name: Check i18n Files and Create PR
on:
pull_request:
types: [closed]
branches: [main]
jobs:
check-and-update:
if: github.event.pull_request.merged == true
runs-on: ubuntu-latest
defaults:
run:
working-directory: web
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 2 # last 2 commits
- name: Check for file changes in i18n/en-US
id: check_files
run: |
recent_commit_sha=$(git rev-parse HEAD)
second_recent_commit_sha=$(git rev-parse HEAD~1)
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v2
with:
node-version: 'lts/*'
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
run: yarn install --frozen-lockfile
- name: Run npm script
if: env.FILES_CHANGED == 'true'
run: npm run auto-gen-i18n
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
commit-message: Update i18n files based on en-US changes
title: 'chore: translate i18n files'
body: This PR was automatically created to update i18n files based on changes in en-US locale.
branch: chore/automated-i18n-updates
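
The change-detection step in this deleted workflow reduces to a single git invocation. A minimal Python sketch of the same check, assuming it runs from the `web` directory as the workflow does:

```python
import subprocess

# Compare the two most recent commits and list changed locale files,
# mirroring the deleted workflow's check_files step.
def i18n_files_changed() -> bool:
    changed = subprocess.run(
        ["git", "diff", "--name-only", "HEAD", "HEAD~1", "--", "i18n/en-US/*.ts"],
        capture_output=True, text=True, check=True,
    ).stdout.strip()
    print(f"Changed files: {changed}")
    return bool(changed)

if __name__ == "__main__":
    print("FILES_CHANGED=true" if i18n_files_changed() else "FILES_CHANGED=false")
```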

View File

@@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
## Before you jump in
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
### Feature requests:

View File

@@ -8,7 +8,7 @@
## 在开始之前
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
### 功能请求:

View File

@@ -10,7 +10,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは
## 飛び込む前に
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
### 機能リクエスト

View File

@@ -8,7 +8,7 @@ Về vấn đề cấp phép, xin vui lòng dành chút thời gian đọc qua [
## Trước khi bắt đầu
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:open) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
### Yêu cầu tính năng:

View File

@@ -60,8 +60,7 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
ALIYUN_OSS_ENDPOINT=your-endpoint
ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
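
The ALIYUN_OSS_PATH comment above is easy to miss: OSS rejects object names that begin with '/'. A hypothetical normalization sketch; the key layout here is illustrative, not Dify's actual storage code:

```python
import os

# Hypothetical guard: strip a leading '/' from ALIYUN_OSS_PATH before
# building object keys, since OSS object names must not start with '/'.
oss_path = os.environ.get("ALIYUN_OSS_PATH", "").lstrip("/")
object_key = f"{oss_path}/upload_files/example.png" if oss_path else "upload_files/example.png"
print(object_key)
```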

api/.idea/vcs.xml generated
View File

@@ -12,6 +12,5 @@
</component>
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component>
</project>

View File

@@ -5,10 +5,6 @@ WORKDIR /app/api
# Install Poetry
ENV POETRY_VERSION=1.8.3
# if you are located in China, you can use the aliyun mirror to speed up
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
# Configure Poetry
@@ -20,9 +16,6 @@ ENV POETRY_REQUESTS_TIMEOUT=15
FROM base AS packages
# if you are located in China, you can use the aliyun mirror to speed up
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
@@ -50,12 +43,10 @@ WORKDIR /app/api
RUN apt-get update \
&& apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# if you are located in China, you can use the aliyun mirror to speed up
# && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-2 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-2 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*
@@ -65,7 +56,7 @@ COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
RUN python -c "import nltk; nltk.download('punkt')"
# Copy source code
COPY . /app/api/

View File

@@ -559,9 +559,8 @@ def add_qdrant_doc_id_index(field: str):
@click.command("create-tenant", help="Create account and tenant.")
@click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
@click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
def create_tenant(email: str, language: Optional[str] = None):
"""
Create tenant account
"""
@@ -581,15 +580,13 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
if language not in languages:
language = "en-US"
name = name.strip()
# generate random password
new_password = secrets.token_urlsafe(16)
# register account
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
TenantService.create_owner_tenant_if_not_exist(account, name)
TenantService.create_owner_tenant_if_not_exist(account)
click.echo(
click.style(

View File

@@ -1,3 +1,3 @@
from .app_config import DifyConfig
dify_config = DifyConfig()
dify_config = DifyConfig()

View File

@@ -1,3 +1,4 @@
from pydantic import Field, computed_field
from pydantic_settings import SettingsConfigDict
from configs.deploy import DeploymentConfig
@@ -23,16 +24,44 @@ class DifyConfig(
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
):
DEBUG: bool = Field(default=False, description='whether to enable debug mode.')
model_config = SettingsConfigDict(
# read from dotenv format config file
env_file=".env",
env_file_encoding="utf-8",
env_file='.env',
env_file_encoding='utf-8',
frozen=True,
# ignore extra attributes
extra="ignore",
extra='ignore',
)
# Before adding any config,
# please consider to arrange it in the proper config group of existed or added
# for better readability and maintainability.
# Thanks for your concentration and consideration.
CODE_MAX_NUMBER: int = 9223372036854775807
CODE_MIN_NUMBER: int = -9223372036854775808
CODE_MAX_DEPTH: int = 5
CODE_MAX_PRECISION: int = 20
CODE_MAX_STRING_LENGTH: int = 80000
CODE_MAX_STRING_ARRAY_LENGTH: int = 30
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300
HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10
@computed_field
def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024
@computed_field
def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str:
return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
SSRF_PROXY_HTTP_URL: str | None = None
SSRF_PROXY_HTTPS_URL: str | None = None
MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')
MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.')
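
The readable-size fields above are plain pydantic: a computed_field derives a display string from the raw byte limit. A self-contained sketch of just that pattern, with the surrounding DifyConfig mixins omitted:

```python
from pydantic import computed_field
from pydantic_settings import BaseSettings

class HttpRequestNodeConfig(BaseSettings):
    # Raw limit in bytes, overridable from the environment or a .env file.
    HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10

    @computed_field
    def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
        # Derived, read-only rendering of the raw value.
        return f"{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB"

print(HttpRequestNodeConfig().HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE)  # 10.00MB
```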

View File

@@ -6,28 +6,22 @@ class DeploymentConfig(BaseSettings):
"""
Deployment configs
"""
APPLICATION_NAME: str = Field(
description="application name",
default="langgenius/dify",
)
DEBUG: bool = Field(
description="whether to enable debug mode.",
default=False,
description='application name',
default='langgenius/dify',
)
TESTING: bool = Field(
description="",
description='',
default=False,
)
EDITION: str = Field(
description="deployment edition",
default="SELF_HOSTED",
description='deployment edition',
default='SELF_HOSTED',
)
DEPLOY_ENV: str = Field(
description="deployment environment, default to PRODUCTION.",
default="PRODUCTION",
description='deployment environment, default to PRODUCTION.',
default='PRODUCTION',
)

View File

@@ -7,14 +7,13 @@ class EnterpriseFeatureConfig(BaseSettings):
Enterprise feature configs.
**Before using, please contact business@dify.ai by email to inquire about licensing matters.**
"""
ENTERPRISE_ENABLED: bool = Field(
description="whether to enable enterprise features."
"Before using, please contact business@dify.ai by email to inquire about licensing matters.",
description='whether to enable enterprise features.'
'Before using, please contact business@dify.ai by email to inquire about licensing matters.',
default=False,
)
CAN_REPLACE_LOGO: bool = Field(
description="whether to allow replacing enterprise logo.",
description='whether to allow replacing enterprise logo.',
default=False,
)

View File

@@ -8,28 +8,27 @@ class NotionConfig(BaseSettings):
"""
Notion integration configs
"""
NOTION_CLIENT_ID: Optional[str] = Field(
description="Notion client ID",
description='Notion client ID',
default=None,
)
NOTION_CLIENT_SECRET: Optional[str] = Field(
description="Notion client secret key",
description='Notion client secret key',
default=None,
)
NOTION_INTEGRATION_TYPE: Optional[str] = Field(
description="Notion integration type, default to None, available values: internal.",
description='Notion integration type, default to None, available values: internal.',
default=None,
)
NOTION_INTERNAL_SECRET: Optional[str] = Field(
description="Notion internal secret key",
description='Notion internal secret key',
default=None,
)
NOTION_INTEGRATION_TOKEN: Optional[str] = Field(
description="Notion integration token",
description='Notion integration token',
default=None,
)

View File

@@ -8,18 +8,17 @@ class SentryConfig(BaseSettings):
"""
Sentry configs
"""
SENTRY_DSN: Optional[str] = Field(
description="Sentry DSN",
description='Sentry DSN',
default=None,
)
SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field(
description="Sentry trace sample rate",
description='Sentry trace sample rate',
default=1.0,
)
SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field(
description="Sentry profiles sample rate",
description='Sentry profiles sample rate',
default=1.0,
)

View File

@@ -1,6 +1,6 @@
from typing import Annotated, Optional
from typing import Optional
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig
@@ -10,17 +10,16 @@ class SecurityConfig(BaseSettings):
"""
Secret Key configs
"""
SECRET_KEY: Optional[str] = Field(
description="Your App secret key will be used for securely signing the session cookie"
"Make sure you are changing this key for your deployment with a strong key."
"You can generate a strong key using `openssl rand -base64 42`."
"Alternatively you can set it with `SECRET_KEY` environment variable.",
description='Your App secret key will be used for securely signing the session cookie'
'Make sure you are changing this key for your deployment with a strong key.'
'You can generate a strong key using `openssl rand -base64 42`.'
'Alternatively you can set it with `SECRET_KEY` environment variable.',
default=None,
)
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
description="Expiry time in hours for reset token",
description='Expiry time in hours for reset token',
default=24,
)
@@ -29,13 +28,12 @@ class AppExecutionConfig(BaseSettings):
"""
App Execution configs
"""
APP_MAX_EXECUTION_TIME: PositiveInt = Field(
description="execution timeout in seconds for app execution",
description='execution timeout in seconds for app execution',
default=1200,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="max active request per app, 0 means unlimited",
description='max active request per app, 0 means unlimited',
default=0,
)
@@ -44,70 +42,14 @@ class CodeExecutionSandboxConfig(BaseSettings):
"""
Code Execution Sandbox configs
"""
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
description="endpoint URL of code execution servcie",
default="http://sandbox:8194",
CODE_EXECUTION_ENDPOINT: str = Field(
description='endpoint URL of code execution service',
default='http://sandbox:8194',
)
CODE_EXECUTION_API_KEY: str = Field(
description="API key for code execution service",
default="dify-sandbox",
)
CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field(
description="connect timeout in seconds for code execution request",
default=10.0,
)
CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field(
description="read timeout in seconds for code execution request",
default=60.0,
)
CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field(
description="write timeout in seconds for code execution request",
default=10.0,
)
CODE_MAX_NUMBER: PositiveInt = Field(
description="max depth for code execution",
default=9223372036854775807,
)
CODE_MIN_NUMBER: NegativeInt = Field(
description="",
default=-9223372036854775807,
)
CODE_MAX_DEPTH: PositiveInt = Field(
description="max depth for code execution",
default=5,
)
CODE_MAX_PRECISION: PositiveInt = Field(
description="max precision digits for float type in code execution",
default=20,
)
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="max string length for code execution",
default=80000,
)
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
description="",
default=30,
)
CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field(
description="",
default=30,
)
CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field(
description="",
default=1000,
description='API key for code execution service',
default='dify-sandbox',
)
@@ -115,27 +57,28 @@ class EndpointConfig(BaseSettings):
"""
Module URL configs
"""
CONSOLE_API_URL: str = Field(
description="The backend URL prefix of the console API."
"used to concatenate the login authorization callback or notion integration callback.",
default="",
description='The backend URL prefix of the console API.'
'used to concatenate the login authorization callback or notion integration callback.',
default='',
)
CONSOLE_WEB_URL: str = Field(
description="The front-end URL prefix of the console web."
"used to concatenate some front-end addresses and for CORS configuration use.",
default="",
description='The front-end URL prefix of the console web.'
'used to concatenate some front-end addresses and for CORS configuration use.',
default='',
)
SERVICE_API_URL: str = Field(
description="Service API Url prefix." "used to display Service API Base Url to the front-end.",
default="",
description='Service API Url prefix.'
'used to display Service API Base Url to the front-end.',
default='',
)
APP_WEB_URL: str = Field(
description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.",
default="",
description='WebApp Url prefix.'
'used to display WebAPP API Base Url to the front-end.',
default='',
)
@@ -143,18 +86,17 @@ class FileAccessConfig(BaseSettings):
"""
File Access configs
"""
FILES_URL: str = Field(
description="File preview or download Url prefix."
" used to display File preview or download Url to the front-end or as Multi-model inputs;"
"Url is signed and has expiration time.",
validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"),
description='File preview or download Url prefix.'
' used to display File preview or download Url to the front-end or as Multi-model inputs;'
'Url is signed and has expiration time.',
validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'),
alias_priority=1,
default="",
default='',
)
FILES_ACCESS_TIMEOUT: int = Field(
description="timeout in seconds for file accessing",
description='timeout in seconds for file accessing',
default=300,
)
@@ -163,24 +105,23 @@ class FileUploadConfig(BaseSettings):
"""
File Uploading configs
"""
UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="size limit in Megabytes for uploading files",
description='size limit in Megabytes for uploading files',
default=15,
)
UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field(
description="batch size limit for uploading files",
description='batch size limit for uploading files',
default=5,
)
UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="image file size limit in Megabytes for uploading files",
description='image file size limit in Megabytes for uploading files',
default=10,
)
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
description="", # todo: to be clarified
description='', # todo: to be clarified
default=20,
)
@@ -189,79 +130,45 @@ class HttpConfig(BaseSettings):
"""
HTTP configs
"""
API_COMPRESSION_ENABLED: bool = Field(
description="whether to enable HTTP response compression of gzip",
description='whether to enable HTTP response compression of gzip',
default=False,
)
inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
description="",
validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"),
default="",
description='',
validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'),
default='',
)
@computed_field
@property
def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',')
inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field(
description="",
validation_alias=AliasChoices("WEB_API_CORS_ALLOW_ORIGINS"),
default="*",
description='',
validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'),
default='*',
)
@computed_field
@property
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
] = 10
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
] = 60
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
] = 20
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
description="",
default=10 * 1024 * 1024,
)
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field(
description="",
default=1 * 1024 * 1024,
)
SSRF_PROXY_HTTP_URL: Optional[str] = Field(
description="HTTP URL for SSRF proxy",
default=None,
)
SSRF_PROXY_HTTPS_URL: Optional[str] = Field(
description="HTTPS URL for SSRF proxy",
default=None,
)
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')
class InnerAPIConfig(BaseSettings):
"""
Inner API configs
"""
INNER_API: bool = Field(
description="whether to enable the inner API",
description='whether to enable the inner API',
default=False,
)
INNER_API_KEY: Optional[str] = Field(
description="The inner API key is used to authenticate the inner API",
description='The inner API key is used to authenticate the inner API',
default=None,
)
@@ -272,27 +179,28 @@ class LoggingConfig(BaseSettings):
"""
LOG_LEVEL: str = Field(
description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.",
default="INFO",
description='Log output level, default to INFO.'
'It is recommended to set it to ERROR for production.',
default='INFO',
)
LOG_FILE: Optional[str] = Field(
description="logging output file path",
description='logging output file path',
default=None,
)
LOG_FORMAT: str = Field(
description="log format",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
description='log format',
default='%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s',
)
LOG_DATEFORMAT: Optional[str] = Field(
description="log date format",
description='log date format',
default=None,
)
LOG_TZ: Optional[str] = Field(
description="specify log timezone, eg: America/New_York",
description='specify log timezone, eg: America/New_York',
default=None,
)
@@ -301,9 +209,8 @@ class ModelLoadBalanceConfig(BaseSettings):
"""
Model load balance configs
"""
MODEL_LB_ENABLED: bool = Field(
description="whether to enable model load balancing",
description='whether to enable model load balancing',
default=False,
)
@@ -312,9 +219,8 @@ class BillingConfig(BaseSettings):
"""
Platform Billing Configurations
"""
BILLING_ENABLED: bool = Field(
description="whether to enable billing",
description='whether to enable billing',
default=False,
)
@@ -323,10 +229,9 @@ class UpdateConfig(BaseSettings):
"""
Update configs
"""
CHECK_UPDATE_URL: str = Field(
description="url for checking updates",
default="https://updates.dify.ai",
description='url for checking updates',
default='https://updates.dify.ai',
)
@@ -336,53 +241,47 @@ class WorkflowConfig(BaseSettings):
"""
WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field(
description="max execution steps in single workflow execution",
description='max execution steps in single workflow execution',
default=500,
)
WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field(
description="max execution time in seconds in single workflow execution",
description='max execution time in seconds in single workflow execution',
default=1200,
)
WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field(
description="max depth of calling in single workflow execution",
description='max depth of calling in single workflow execution',
default=5,
)
MAX_VARIABLE_SIZE: PositiveInt = Field(
description="The maximum size in bytes of a variable. default to 5KB.",
default=5 * 1024,
)
class OAuthConfig(BaseSettings):
"""
oauth configs
"""
OAUTH_REDIRECT_PATH: str = Field(
description="redirect path for OAuth",
default="/console/api/oauth/authorize",
description='redirect path for OAuth',
default='/console/api/oauth/authorize',
)
GITHUB_CLIENT_ID: Optional[str] = Field(
description="GitHub client id for OAuth",
description='GitHub client id for OAuth',
default=None,
)
GITHUB_CLIENT_SECRET: Optional[str] = Field(
description="GitHub client secret key for OAuth",
description='GitHub client secret key for OAuth',
default=None,
)
GOOGLE_CLIENT_ID: Optional[str] = Field(
description="Google client id for OAuth",
description='Google client id for OAuth',
default=None,
)
GOOGLE_CLIENT_SECRET: Optional[str] = Field(
description="Google client secret key for OAuth",
description='Google client secret key for OAuth',
default=None,
)
@@ -392,8 +291,9 @@ class ModerationConfig(BaseSettings):
Moderation in app configs.
"""
MODERATION_BUFFER_SIZE: PositiveInt = Field(
description="buffer size for moderation",
# todo: to be clarified in usage and unit
OUTPUT_MODERATION_BUFFER_SIZE: PositiveInt = Field(
description='buffer size for moderation',
default=300,
)
@@ -404,7 +304,7 @@ class ToolConfig(BaseSettings):
"""
TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field(
description="max age in seconds for tool icon caching",
description='max age in seconds for tool icon caching',
default=3600,
)
@@ -415,52 +315,52 @@ class MailConfig(BaseSettings):
"""
MAIL_TYPE: Optional[str] = Field(
description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
description='Mail provider type name, default to None, availabile values are `smtp` and `resend`.',
default=None,
)
MAIL_DEFAULT_SEND_FROM: Optional[str] = Field(
description="default email address for sending from ",
description='default email address for sending from ',
default=None,
)
RESEND_API_KEY: Optional[str] = Field(
description="API key for Resend",
description='API key for Resend',
default=None,
)
RESEND_API_URL: Optional[str] = Field(
description="API URL for Resend",
description='API URL for Resend',
default=None,
)
SMTP_SERVER: Optional[str] = Field(
description="smtp server host",
description='smtp server host',
default=None,
)
SMTP_PORT: Optional[int] = Field(
description="smtp server port",
description='smtp server port',
default=465,
)
SMTP_USERNAME: Optional[str] = Field(
description="smtp server username",
description='smtp server username',
default=None,
)
SMTP_PASSWORD: Optional[str] = Field(
description="smtp server password",
description='smtp server password',
default=None,
)
SMTP_USE_TLS: bool = Field(
description="whether to use TLS connection to smtp server",
description='whether to use TLS connection to smtp server',
default=False,
)
SMTP_OPPORTUNISTIC_TLS: bool = Field(
description="whether to use opportunistic TLS connection to smtp server",
description='whether to use opportunistic TLS connection to smtp server',
default=False,
)
@@ -471,22 +371,22 @@ class RagEtlConfig(BaseSettings):
"""
ETL_TYPE: str = Field(
description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ",
default="dify",
description='RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ',
default='dify',
)
KEYWORD_DATA_SOURCE_TYPE: str = Field(
description="source type for keyword data, default to `database`, available values are `database` .",
default="database",
description='source type for keyword data, default to `database`, available values are `database` .',
default='database',
)
UNSTRUCTURED_API_URL: Optional[str] = Field(
description="API URL for Unstructured",
description='API URL for Unstructured',
default=None,
)
UNSTRUCTURED_API_KEY: Optional[str] = Field(
description="API key for Unstructured",
description='API key for Unstructured',
default=None,
)
@@ -497,12 +397,12 @@ class DataSetConfig(BaseSettings):
"""
CLEAN_DAY_SETTING: PositiveInt = Field(
description="interval in days for cleaning up dataset",
description='interval in days for cleaning up dataset',
default=30,
)
DATASET_OPERATOR_ENABLED: bool = Field(
description="whether to enable dataset operator",
description='whether to enable dataset operator',
default=False,
)
@@ -513,7 +413,7 @@ class WorkspaceConfig(BaseSettings):
"""
INVITE_EXPIRY_HOURS: PositiveInt = Field(
description="workspaces invitation expiration in hours",
description='workspaces invitation expiration in hours',
default=72,
)
@@ -524,79 +424,80 @@ class IndexingConfig(BaseSettings):
"""
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field(
description="max segmentation token length for indexing",
description='max segmentation token length for indexing',
default=1000,
)
class ImageFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
description="multi model send image format, support base64, url, default is base64",
default="base64",
description='multimodal send image format, support base64, url, default is base64',
default='base64',
)
class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field(
description="the time of the celery scheduler, default to 1 day",
description='the time of the celery scheduler, default to 1 day',
default=1,
)
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description="The heads of model providers",
default="",
description='The heads of model providers',
default='',
)
POSITION_PROVIDER_INCLUDES: str = Field(
description="The included model providers",
default="",
description='The included model providers',
default='',
)
POSITION_PROVIDER_EXCLUDES: str = Field(
description="The excluded model providers",
default="",
description='The excluded model providers',
default='',
)
POSITION_TOOL_PINS: str = Field(
description="The heads of tools",
default="",
description='The heads of tools',
default='',
)
POSITION_TOOL_INCLUDES: str = Field(
description="The included tools",
default="",
description='The included tools',
default='',
)
POSITION_TOOL_EXCLUDES: str = Field(
description="The excluded tools",
default="",
description='The excluded tools',
default='',
)
@computed_field
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""]
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
@computed_field
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""}
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""}
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""]
return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
@computed_field
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""}
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
class FeatureConfig(
@@ -624,6 +525,7 @@ class FeatureConfig(
WorkflowConfig,
WorkspaceConfig,
PositionConfig,
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,
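
Several computed fields in this file share one pattern: keep the raw comma-separated environment value on an inner field, and expose the parsed list through a computed property. A minimal sketch of that pattern, assuming only pydantic-settings; names are shortened for illustration:

```python
import os

from pydantic import AliasChoices, Field, computed_field
from pydantic_settings import BaseSettings

class CorsConfig(BaseSettings):
    # Raw comma-separated value; falls back to CONSOLE_WEB_URL as in HttpConfig above.
    inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
        validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"),
        default="",
    )

    @computed_field
    @property
    def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
        return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")

os.environ["CONSOLE_CORS_ALLOW_ORIGINS"] = "https://a.example,https://b.example"
print(CorsConfig().CONSOLE_CORS_ALLOW_ORIGINS)  # ['https://a.example', 'https://b.example']
```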

View File

@@ -10,62 +10,62 @@ class HostedOpenAiConfig(BaseSettings):
"""
HOSTED_OPENAI_API_KEY: Optional[str] = Field(
description="",
description='',
default=None,
)
HOSTED_OPENAI_API_BASE: Optional[str] = Field(
description="",
description='',
default=None,
)
HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field(
description="",
description='',
default=None,
)
HOSTED_OPENAI_TRIAL_ENABLED: bool = Field(
description="",
description='',
default=False,
)
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description="",
default="gpt-3.5-turbo,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-instruct,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"text-davinci-003",
description='',
default='gpt-3.5-turbo,'
'gpt-3.5-turbo-1106,'
'gpt-3.5-turbo-instruct,'
'gpt-3.5-turbo-16k,'
'gpt-3.5-turbo-16k-0613,'
'gpt-3.5-turbo-0613,'
'gpt-3.5-turbo-0125,'
'text-davinci-003',
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="",
description='',
default=200,
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
description="",
description='',
default=False,
)
HOSTED_OPENAI_PAID_MODELS: str = Field(
description="",
default="gpt-4,"
"gpt-4-turbo-preview,"
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
"gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
"text-davinci-003",
description='',
default='gpt-4,'
'gpt-4-turbo-preview,'
'gpt-4-turbo-2024-04-09,'
'gpt-4-1106-preview,'
'gpt-4-0125-preview,'
'gpt-3.5-turbo,'
'gpt-3.5-turbo-16k,'
'gpt-3.5-turbo-16k-0613,'
'gpt-3.5-turbo-1106,'
'gpt-3.5-turbo-0613,'
'gpt-3.5-turbo-0125,'
'gpt-3.5-turbo-instruct,'
'text-davinci-003',
)
@@ -75,22 +75,22 @@ class HostedAzureOpenAiConfig(BaseSettings):
"""
HOSTED_AZURE_OPENAI_ENABLED: bool = Field(
description="",
description='',
default=False,
)
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
description="",
description='',
default=None,
)
HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field(
description="",
description='',
default=None,
)
HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="",
description='',
default=200,
)
@@ -101,27 +101,27 @@ class HostedAnthropicConfig(BaseSettings):
"""
HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field(
description="",
description='',
default=None,
)
HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field(
description="",
description='',
default=None,
)
HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field(
description="",
description='',
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description="",
description='',
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="",
description='',
default=False,
)
@@ -132,7 +132,7 @@ class HostedMinmaxConfig(BaseSettings):
"""
HOSTED_MINIMAX_ENABLED: bool = Field(
description="",
description='',
default=False,
)
@@ -143,7 +143,7 @@ class HostedSparkConfig(BaseSettings):
"""
HOSTED_SPARK_ENABLED: bool = Field(
description="",
description='',
default=False,
)
@@ -154,7 +154,7 @@ class HostedZhipuAIConfig(BaseSettings):
"""
HOSTED_ZHIPUAI_ENABLED: bool = Field(
description="",
description='',
default=False,
)
@@ -165,13 +165,13 @@ class HostedModerationConfig(BaseSettings):
"""
HOSTED_MODERATION_ENABLED: bool = Field(
description="",
description='',
default=False,
)
HOSTED_MODERATION_PROVIDERS: str = Field(
description="",
default="",
description='',
default='',
)
@@ -181,15 +181,15 @@ class HostedFetchAppTemplateConfig(BaseSettings):
"""
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
description="the mode for fetching app templates,"
" default to remote,"
" available values: remote, db, builtin",
default="remote",
description='the mode for fetching app templates,'
' default to remote,'
' available values: remote, db, builtin',
default='remote',
)
HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field(
description="the domain for fetching remote app templates",
default="https://tmpl.dify.ai",
description='the domain for fetching remote app templates',
default='https://tmpl.dify.ai',
)
@@ -202,6 +202,7 @@ class HostedServiceConfig(
HostedOpenAiConfig,
HostedSparkConfig,
HostedZhipuAIConfig,
# moderation
HostedModerationConfig,
):

View File

@@ -13,7 +13,6 @@ from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
@@ -29,108 +28,108 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseSettings):
STORAGE_TYPE: str = Field(
description="storage type,"
" default to `local`,"
" available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.",
default="local",
description='storage type,'
' default to `local`,'
' available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.',
default='local',
)
STORAGE_LOCAL_PATH: str = Field(
description="local storage path",
default="storage",
description='local storage path',
default='storage',
)
class VectorStoreConfig(BaseSettings):
VECTOR_STORE: Optional[str] = Field(
description="vector store type",
description='vector store type',
default=None,
)
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(
description="keyword store type",
default="jieba",
description='keyword store type',
default='jieba',
)
class DatabaseConfig:
DB_HOST: str = Field(
description="db host",
default="localhost",
description='db host',
default='localhost',
)
DB_PORT: PositiveInt = Field(
description="db port",
description='db port',
default=5432,
)
DB_USERNAME: str = Field(
description="db username",
default="postgres",
description='db username',
default='postgres',
)
DB_PASSWORD: str = Field(
description="db password",
default="",
description='db password',
default='',
)
DB_DATABASE: str = Field(
description="db database",
default="dify",
description='db database',
default='dify',
)
DB_CHARSET: str = Field(
description="db charset",
default="",
description='db charset',
default='',
)
DB_EXTRAS: str = Field(
description="db extras options. Example: keepalives_idle=60&keepalives=1",
default="",
description='db extras options. Example: keepalives_idle=60&keepalives=1',
default='',
)
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
description="db uri scheme",
default="postgresql",
description='db uri scheme',
default='postgresql',
)
@computed_field
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = (
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}"
if self.DB_CHARSET
else self.DB_EXTRAS
).strip("&")
db_extras = f"?{db_extras}" if db_extras else ""
return (
f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
f"{db_extras}"
)
return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
f"{db_extras}")
SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
description="pool size of SqlAlchemy",
description='pool size of SqlAlchemy',
default=30,
)
SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
description="max overflows for SqlAlchemy",
description='max overflows for SqlAlchemy',
default=10,
)
SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
description="SqlAlchemy pool recycle",
description='SqlAlchemy pool recycle',
default=3600,
)
SQLALCHEMY_POOL_PRE_PING: bool = Field(
description="whether to enable pool pre-ping in SqlAlchemy",
description='whether to enable pool pre-ping in SqlAlchemy',
default=False,
)
SQLALCHEMY_ECHO: bool | str = Field(
description="whether to enable SqlAlchemy echo",
description='whether to enable SqlAlchemy echo',
default=False,
)
@@ -138,38 +137,35 @@ class DatabaseConfig:
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
return {
"pool_size": self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": {"options": "-c timezone=UTC"},
'pool_size': self.SQLALCHEMY_POOL_SIZE,
'max_overflow': self.SQLALCHEMY_MAX_OVERFLOW,
'pool_recycle': self.SQLALCHEMY_POOL_RECYCLE,
'pool_pre_ping': self.SQLALCHEMY_POOL_PRE_PING,
'connect_args': {'options': '-c timezone=UTC'},
}
class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field(
description="Celery backend, available values are `database`, `redis`",
default="database",
description='Celery backend, available values are `database`, `redis`',
default='database',
)
CELERY_BROKER_URL: Optional[str] = Field(
description="CELERY_BROKER_URL",
description='CELERY_BROKER_URL',
default=None,
)
@computed_field
@property
def CELERY_RESULT_BACKEND(self) -> str | None:
return (
"db+{}".format(self.SQLALCHEMY_DATABASE_URI)
if self.CELERY_BACKEND == "database"
else self.CELERY_BROKER_URL
)
return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
@computed_field
@property
def BROKER_USE_SSL(self) -> bool:
return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False
return self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False
class MiddlewareConfig(
@@ -178,6 +174,7 @@ class MiddlewareConfig(
DatabaseConfig,
KeywordStoreConfig,
RedisConfig,
# configs of storage and storage providers
StorageConfig,
AliyunOSSStorageConfig,
@@ -186,6 +183,7 @@ class MiddlewareConfig(
TencentCloudCOSStorageConfig,
S3StorageConfig,
OCIStorageConfig,
# configs of vdb and vdb providers
VectorStoreConfig,
AnalyticdbConfig,
@@ -201,6 +199,5 @@ class MiddlewareConfig(
TencentVectorDBConfig,
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
):
pass
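
The URI assembly in DatabaseConfig above is the one piece of real logic in this file: credentials are URL-quoted, and client_encoding is appended to DB_EXTRAS only when DB_CHARSET is set. The same handling, written out as a plain function for illustration:

```python
from urllib.parse import quote_plus

# Mirrors DatabaseConfig.SQLALCHEMY_DATABASE_URI from the hunk above.
def database_uri(scheme="postgresql", user="postgres", password="p@ss/word",
                 host="localhost", port=5432, database="dify",
                 charset="", extras=""):
    db_extras = (f"{extras}&client_encoding={charset}" if charset else extras).strip("&")
    db_extras = f"?{db_extras}" if db_extras else ""
    return (f"{scheme}://{quote_plus(user)}:{quote_plus(password)}"
            f"@{host}:{port}/{database}{db_extras}")

print(database_uri(charset="utf8", extras="keepalives_idle=60&keepalives=1"))
# postgresql://postgres:p%40ss%2Fword@localhost:5432/dify?keepalives_idle=60&keepalives=1&client_encoding=utf8
```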

View File

@@ -8,33 +8,32 @@ class RedisConfig(BaseSettings):
"""
Redis configs
"""
REDIS_HOST: str = Field(
description="Redis host",
default="localhost",
description='Redis host',
default='localhost',
)
REDIS_PORT: PositiveInt = Field(
description="Redis port",
description='Redis port',
default=6379,
)
REDIS_USERNAME: Optional[str] = Field(
description="Redis username",
description='Redis username',
default=None,
)
REDIS_PASSWORD: Optional[str] = Field(
description="Redis password",
description='Redis password',
default=None,
)
REDIS_DB: NonNegativeInt = Field(
description="Redis database id, default to 0",
description='Redis database id, default to 0',
default=0,
)
REDIS_USE_SSL: bool = Field(
description="whether to use SSL for Redis connection",
description='whether to use SSL for Redis connection',
default=False,
)

View File

@@ -10,36 +10,31 @@ class AliyunOSSStorageConfig(BaseSettings):
"""
ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field(
description="Aliyun OSS bucket name",
description='Aliyun OSS bucket name',
default=None,
)
ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field(
description="Aliyun OSS access key",
description='Aliyun OSS access key',
default=None,
)
ALIYUN_OSS_SECRET_KEY: Optional[str] = Field(
description="Aliyun OSS secret key",
description='Aliyun OSS secret key',
default=None,
)
ALIYUN_OSS_ENDPOINT: Optional[str] = Field(
description="Aliyun OSS endpoint URL",
description='Aliyun OSS endpoint URL',
default=None,
)
ALIYUN_OSS_REGION: Optional[str] = Field(
description="Aliyun OSS region",
description='Aliyun OSS region',
default=None,
)
ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field(
description="Aliyun OSS authentication version",
default=None,
)
ALIYUN_OSS_PATH: Optional[str] = Field(
description="Aliyun OSS path",
description='Aliyun OSS authentication version',
default=None,
)

View File

@@ -10,36 +10,36 @@ class S3StorageConfig(BaseSettings):
"""
S3_ENDPOINT: Optional[str] = Field(
description="S3 storage endpoint",
description='S3 storage endpoint',
default=None,
)
S3_REGION: Optional[str] = Field(
description="S3 storage region",
description='S3 storage region',
default=None,
)
S3_BUCKET_NAME: Optional[str] = Field(
description="S3 storage bucket name",
description='S3 storage bucket name',
default=None,
)
S3_ACCESS_KEY: Optional[str] = Field(
description="S3 storage access key",
description='S3 storage access key',
default=None,
)
S3_SECRET_KEY: Optional[str] = Field(
description="S3 storage secret key",
description='S3 storage secret key',
default=None,
)
S3_ADDRESS_STYLE: str = Field(
description="S3 storage address style",
default="auto",
description='S3 storage address style',
default='auto',
)
S3_USE_AWS_MANAGED_IAM: bool = Field(
description="whether to use aws managed IAM for S3",
description='whether to use aws managed IAM for S3',
default=False,
)

View File

@@ -10,21 +10,21 @@ class AzureBlobStorageConfig(BaseSettings):
"""
AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field(
description="Azure Blob account name",
description='Azure Blob account name',
default=None,
)
AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field(
description="Azure Blob account key",
description='Azure Blob account key',
default=None,
)
AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field(
description="Azure Blob container name",
description='Azure Blob container name',
default=None,
)
AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field(
description="Azure Blob account URL",
description='Azure Blob account URL',
default=None,
)

View File

@@ -10,11 +10,11 @@ class GoogleCloudStorageConfig(BaseSettings):
"""
GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field(
description="Google Cloud storage bucket name",
description='Google Cloud storage bucket name',
default=None,
)
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field(
description="Google Cloud storage service account json base64",
description='Google Cloud storage service account json base64',
default=None,
)

View File

@@ -10,26 +10,27 @@ class OCIStorageConfig(BaseSettings):
"""
OCI_ENDPOINT: Optional[str] = Field(
description="OCI storage endpoint",
description='OCI storage endpoint',
default=None,
)
OCI_REGION: Optional[str] = Field(
description="OCI storage region",
description='OCI storage region',
default=None,
)
OCI_BUCKET_NAME: Optional[str] = Field(
description="OCI storage bucket name",
description='OCI storage bucket name',
default=None,
)
OCI_ACCESS_KEY: Optional[str] = Field(
description="OCI storage access key",
description='OCI storage access key',
default=None,
)
OCI_SECRET_KEY: Optional[str] = Field(
description="OCI storage secret key",
description='OCI storage secret key',
default=None,
)

View File

@@ -10,26 +10,26 @@ class TencentCloudCOSStorageConfig(BaseSettings):
"""
TENCENT_COS_BUCKET_NAME: Optional[str] = Field(
description="Tencent Cloud COS bucket name",
description='Tencent Cloud COS bucket name',
default=None,
)
TENCENT_COS_REGION: Optional[str] = Field(
description="Tencent Cloud COS region",
description='Tencent Cloud COS region',
default=None,
)
TENCENT_COS_SECRET_ID: Optional[str] = Field(
description="Tencent Cloud COS secret id",
description='Tencent Cloud COS secret id',
default=None,
)
TENCENT_COS_SECRET_KEY: Optional[str] = Field(
description="Tencent Cloud COS secret key",
description='Tencent Cloud COS secret key',
default=None,
)
TENCENT_COS_SCHEME: Optional[str] = Field(
description="Tencent Cloud COS scheme",
description='Tencent Cloud COS scheme',
default=None,
)

View File

@@ -10,28 +10,35 @@ class AnalyticdbConfig(BaseModel):
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
"""
ANALYTICDB_KEY_ID: Optional[str] = Field(
default=None, description="The Access Key ID provided by Alibaba Cloud for authentication."
)
ANALYTICDB_KEY_SECRET: Optional[str] = Field(
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access."
)
ANALYTICDB_REGION_ID: Optional[str] = Field(
default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
)
ANALYTICDB_INSTANCE_ID: Optional[str] = Field(
ANALYTICDB_KEY_ID : Optional[str] = Field(
default=None,
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..",
description="The Access Key ID provided by Alibaba Cloud for authentication."
)
ANALYTICDB_ACCOUNT: Optional[str] = Field(
default=None, description="The account name used to log in to the AnalyticDB instance."
ANALYTICDB_KEY_SECRET : Optional[str] = Field(
default=None,
description="The Secret Access Key corresponding to the Access Key ID for secure access."
)
ANALYTICDB_PASSWORD: Optional[str] = Field(
default=None, description="The password associated with the AnalyticDB account for authentication."
ANALYTICDB_REGION_ID : Optional[str] = Field(
default=None,
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
)
ANALYTICDB_NAMESPACE: Optional[str] = Field(
default=None, description="The namespace within AnalyticDB for schema isolation."
ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
default=None,
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
)
ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field(
default=None, description="The password for accessing the specified namespace within the AnalyticDB instance."
ANALYTICDB_ACCOUNT : Optional[str] = Field(
default=None,
description="The account name used to log in to the AnalyticDB instance."
)
ANALYTICDB_PASSWORD : Optional[str] = Field(
default=None,
description="The password associated with the AnalyticDB account for authentication."
)
ANALYTICDB_NAMESPACE : Optional[str] = Field(
default=None,
description="The namespace within AnalyticDB for schema isolation."
)
ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
default=None,
description="The password for accessing the specified namespace within the AnalyticDB instance."
)

View File

@@ -10,31 +10,31 @@ class ChromaConfig(BaseSettings):
"""
CHROMA_HOST: Optional[str] = Field(
description="Chroma host",
description='Chroma host',
default=None,
)
CHROMA_PORT: PositiveInt = Field(
description="Chroma port",
description='Chroma port',
default=8000,
)
CHROMA_TENANT: Optional[str] = Field(
description="Chroma database",
description='Chroma database',
default=None,
)
CHROMA_DATABASE: Optional[str] = Field(
description="Chroma database",
description='Chroma database',
default=None,
)
CHROMA_AUTH_PROVIDER: Optional[str] = Field(
description="Chroma authentication provider",
description='Chroma authentication provider',
default=None,
)
CHROMA_AUTH_CREDENTIALS: Optional[str] = Field(
description="Chroma authentication credentials",
description='Chroma authentication credentials',
default=None,
)

View File

@@ -1,30 +0,0 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class ElasticsearchConfig(BaseSettings):
"""
Elasticsearch configs
"""
ELASTICSEARCH_HOST: Optional[str] = Field(
description="Elasticsearch host",
default="127.0.0.1",
)
ELASTICSEARCH_PORT: PositiveInt = Field(
description="Elasticsearch port",
default=9200,
)
ELASTICSEARCH_USERNAME: Optional[str] = Field(
description="Elasticsearch username",
default="elastic",
)
ELASTICSEARCH_PASSWORD: Optional[str] = Field(
description="Elasticsearch password",
default="elastic",
)

View File

@@ -10,31 +10,31 @@ class MilvusConfig(BaseSettings):
"""
MILVUS_HOST: Optional[str] = Field(
description="Milvus host",
description='Milvus host',
default=None,
)
MILVUS_PORT: PositiveInt = Field(
description="Milvus RestFul API port",
description='Milvus RestFul API port',
default=9091,
)
MILVUS_USER: Optional[str] = Field(
description="Milvus user",
description='Milvus user',
default=None,
)
MILVUS_PASSWORD: Optional[str] = Field(
description="Milvus password",
description='Milvus password',
default=None,
)
MILVUS_SECURE: bool = Field(
description="whether to use SSL connection for Milvus",
description='whether to use SSL connection for Milvus',
default=False,
)
MILVUS_DATABASE: str = Field(
description="Milvus database, default to `default`",
default="default",
description='Milvus database, default to `default`',
default='default',
)

View File

@@ -1,3 +1,4 @@
from pydantic import BaseModel, Field, PositiveInt
@@ -7,31 +8,31 @@ class MyScaleConfig(BaseModel):
"""
MYSCALE_HOST: str = Field(
description="MyScale host",
default="localhost",
description='MyScale host',
default='localhost',
)
MYSCALE_PORT: PositiveInt = Field(
description="MyScale port",
description='MyScale port',
default=8123,
)
MYSCALE_USER: str = Field(
description="MyScale user",
default="default",
description='MyScale user',
default='default',
)
MYSCALE_PASSWORD: str = Field(
description="MyScale password",
default="",
description='MyScale password',
default='',
)
MYSCALE_DATABASE: str = Field(
description="MyScale database name",
default="default",
description='MyScale database name',
default='default',
)
MYSCALE_FTS_PARAMS: str = Field(
description="MyScale fts index parameters",
default="",
description='MyScale fts index parameters',
default='',
)

View File

@@ -10,26 +10,26 @@ class OpenSearchConfig(BaseSettings):
"""
OPENSEARCH_HOST: Optional[str] = Field(
description="OpenSearch host",
description='OpenSearch host',
default=None,
)
OPENSEARCH_PORT: PositiveInt = Field(
description="OpenSearch port",
description='OpenSearch port',
default=9200,
)
OPENSEARCH_USER: Optional[str] = Field(
description="OpenSearch user",
description='OpenSearch user',
default=None,
)
OPENSEARCH_PASSWORD: Optional[str] = Field(
description="OpenSearch password",
description='OpenSearch password',
default=None,
)
OPENSEARCH_SECURE: bool = Field(
description="whether to use SSL connection for OpenSearch",
description='whether to use SSL connection for OpenSearch',
default=False,
)

View File

@@ -10,26 +10,26 @@ class OracleConfig(BaseSettings):
"""
ORACLE_HOST: Optional[str] = Field(
description="ORACLE host",
description='ORACLE host',
default=None,
)
ORACLE_PORT: Optional[PositiveInt] = Field(
description="ORACLE port",
description='ORACLE port',
default=1521,
)
ORACLE_USER: Optional[str] = Field(
description="ORACLE user",
description='ORACLE user',
default=None,
)
ORACLE_PASSWORD: Optional[str] = Field(
description="ORACLE password",
description='ORACLE password',
default=None,
)
ORACLE_DATABASE: Optional[str] = Field(
description="ORACLE database",
description='ORACLE database',
default=None,
)

View File

@@ -10,26 +10,26 @@ class PGVectorConfig(BaseSettings):
"""
PGVECTOR_HOST: Optional[str] = Field(
description="PGVector host",
description='PGVector host',
default=None,
)
PGVECTOR_PORT: Optional[PositiveInt] = Field(
description="PGVector port",
description='PGVector port',
default=5433,
)
PGVECTOR_USER: Optional[str] = Field(
description="PGVector user",
description='PGVector user',
default=None,
)
PGVECTOR_PASSWORD: Optional[str] = Field(
description="PGVector password",
description='PGVector password',
default=None,
)
PGVECTOR_DATABASE: Optional[str] = Field(
description="PGVector database",
description='PGVector database',
default=None,
)

View File

@@ -10,26 +10,26 @@ class PGVectoRSConfig(BaseSettings):
"""
PGVECTO_RS_HOST: Optional[str] = Field(
description="PGVectoRS host",
description='PGVectoRS host',
default=None,
)
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
description="PGVectoRS port",
description='PGVectoRS port',
default=5431,
)
PGVECTO_RS_USER: Optional[str] = Field(
description="PGVectoRS user",
description='PGVectoRS user',
default=None,
)
PGVECTO_RS_PASSWORD: Optional[str] = Field(
description="PGVectoRS password",
description='PGVectoRS password',
default=None,
)
PGVECTO_RS_DATABASE: Optional[str] = Field(
description="PGVectoRS database",
description='PGVectoRS database',
default=None,
)

View File

@@ -10,26 +10,26 @@ class QdrantConfig(BaseSettings):
"""
QDRANT_URL: Optional[str] = Field(
description="Qdrant url",
description='Qdrant url',
default=None,
)
QDRANT_API_KEY: Optional[str] = Field(
description="Qdrant api key",
description='Qdrant api key',
default=None,
)
QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
description="Qdrant client timeout in seconds",
description='Qdrant client timeout in seconds',
default=20,
)
QDRANT_GRPC_ENABLED: bool = Field(
description="whether enable grpc support for Qdrant connection",
description='whether enable grpc support for Qdrant connection',
default=False,
)
QDRANT_GRPC_PORT: PositiveInt = Field(
description="Qdrant grpc port",
description='Qdrant grpc port',
default=6334,
)

View File

@@ -10,26 +10,26 @@ class RelytConfig(BaseSettings):
"""
RELYT_HOST: Optional[str] = Field(
description="Relyt host",
description='Relyt host',
default=None,
)
RELYT_PORT: PositiveInt = Field(
description="Relyt port",
description='Relyt port',
default=9200,
)
RELYT_USER: Optional[str] = Field(
description="Relyt user",
description='Relyt user',
default=None,
)
RELYT_PASSWORD: Optional[str] = Field(
description="Relyt password",
description='Relyt password',
default=None,
)
RELYT_DATABASE: Optional[str] = Field(
description="Relyt database",
default="default",
description='Relyt database',
default='default',
)


@@ -10,41 +10,41 @@ class TencentVectorDBConfig(BaseSettings):
"""
TENCENT_VECTOR_DB_URL: Optional[str] = Field(
description="Tencent Vector URL",
description='Tencent Vector URL',
default=None,
)
TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field(
description="Tencent Vector API key",
description='Tencent Vector API key',
default=None,
)
TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field(
description="Tencent Vector timeout in seconds",
description='Tencent Vector timeout in seconds',
default=30,
)
TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
description="Tencent Vector username",
description='Tencent Vector username',
default=None,
)
TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field(
description="Tencent Vector password",
description='Tencent Vector password',
default=None,
)
TENCENT_VECTOR_DB_SHARD: PositiveInt = Field(
description="Tencent Vector sharding number",
description='Tencent Vector sharding number',
default=1,
)
TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description="Tencent Vector replicas",
description='Tencent Vector replicas',
default=2,
)
TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
description="Tencent Vector Database",
description='Tencent Vector Database',
default=None,
)


@@ -10,26 +10,26 @@ class TiDBVectorConfig(BaseSettings):
"""
TIDB_VECTOR_HOST: Optional[str] = Field(
description="TiDB Vector host",
description='TiDB Vector host',
default=None,
)
TIDB_VECTOR_PORT: Optional[PositiveInt] = Field(
description="TiDB Vector port",
description='TiDB Vector port',
default=4000,
)
TIDB_VECTOR_USER: Optional[str] = Field(
description="TiDB Vector user",
description='TiDB Vector user',
default=None,
)
TIDB_VECTOR_PASSWORD: Optional[str] = Field(
description="TiDB Vector password",
description='TiDB Vector password',
default=None,
)
TIDB_VECTOR_DATABASE: Optional[str] = Field(
description="TiDB Vector database",
description='TiDB Vector database',
default=None,
)


@@ -10,21 +10,21 @@ class WeaviateConfig(BaseSettings):
"""
WEAVIATE_ENDPOINT: Optional[str] = Field(
description="Weaviate endpoint URL",
description='Weaviate endpoint URL',
default=None,
)
WEAVIATE_API_KEY: Optional[str] = Field(
description="Weaviate API key",
description='Weaviate API key',
default=None,
)
WEAVIATE_GRPC_ENABLED: bool = Field(
description="whether to enable gRPC for Weaviate connection",
description='whether to enable gRPC for Weaviate connection',
default=True,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Weaviate batch size",
description='Weaviate batch size',
default=100,
)


@@ -8,11 +8,11 @@ class PackagingInfo(BaseSettings):
"""
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.8.0-beta1",
description='Dify version',
default='0.7.1',
)
COMMIT_SHA: str = Field(
description="SHA-1 checksum of the git commit used to build the app",
default="",
default='',
)




@@ -2,7 +2,7 @@ from flask import Blueprint
from libs.external_api import ExternalApi
bp = Blueprint("console", __name__, url_prefix="/console/api")
bp = Blueprint('console', __name__, url_prefix='/console/api')
api = ExternalApi(bp)
# Import other controllers
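
For orientation, a sketch of how a blueprint like this is mounted on the app. ExternalApi appears to be Dify's wrapper around flask-restful's Api (an inference from the import above), so the stock Api stands in here to keep the snippet self-contained:

from flask import Flask, Blueprint
from flask_restful import Api, Resource

bp = Blueprint("console", __name__, url_prefix="/console/api")
api = Api(bp)  # the project uses ExternalApi here for its own error handling

class Ping(Resource):  # hypothetical resource, for illustration only
    def get(self):
        return {"result": "pong"}

api.add_resource(Ping, "/ping")  # becomes GET /console/api/ping

app = Flask(__name__)
app.register_blueprint(bp)  # every resource on `api` now lives under /console/api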


@@ -15,24 +15,24 @@ from models.model import App, InstalledApp, RecommendedApp
def admin_required(view):
@wraps(view)
def decorated(*args, **kwargs):
if not os.getenv("ADMIN_API_KEY"):
raise Unauthorized("API key is invalid.")
if not os.getenv('ADMIN_API_KEY'):
raise Unauthorized('API key is invalid.')
auth_header = request.headers.get("Authorization")
auth_header = request.headers.get('Authorization')
if auth_header is None:
raise Unauthorized("Authorization header is missing.")
raise Unauthorized('Authorization header is missing.')
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if os.getenv("ADMIN_API_KEY") != auth_token:
raise Unauthorized("API key is invalid.")
if os.getenv('ADMIN_API_KEY') != auth_token:
raise Unauthorized('API key is invalid.')
return view(*args, **kwargs)
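
Whichever quoting wins, the decorator checks a standard bearer token against the ADMIN_API_KEY environment variable. A hedged client-side sketch (host, port, and payload values are placeholders, not taken from the diff):

import requests

resp = requests.post(
    "http://localhost:5001/console/api/admin/insert-explore-apps",  # placeholder host/port
    headers={"Authorization": "Bearer my-admin-key"},  # must equal the server's ADMIN_API_KEY
    json={"app_id": "<app uuid>", "language": "en-US", "category": "Agent", "position": 1},
)
print(resp.status_code)  # 201 on first insert, 200 on update, 401 on a bad key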
@@ -44,41 +44,37 @@ class InsertExploreAppListApi(Resource):
@admin_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("desc", type=str, location="json")
parser.add_argument("copyright", type=str, location="json")
parser.add_argument("privacy_policy", type=str, location="json")
parser.add_argument("custom_disclaimer", type=str, location="json")
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
parser.add_argument('app_id', type=str, required=True, nullable=False, location='json')
parser.add_argument('desc', type=str, location='json')
parser.add_argument('copyright', type=str, location='json')
parser.add_argument('privacy_policy', type=str, location='json')
parser.add_argument('custom_disclaimer', type=str, location='json')
parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json')
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
parser.add_argument('position', type=int, required=True, nullable=False, location='json')
args = parser.parse_args()
app = App.query.filter(App.id == args["app_id"]).first()
app = App.query.filter(App.id == args['app_id']).first()
if not app:
raise NotFound(f'App \'{args["app_id"]}\' is not found')
site = app.site
if not site:
desc = args["desc"] if args["desc"] else ""
copy_right = args["copyright"] if args["copyright"] else ""
privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
desc = args['desc'] if args['desc'] else ''
copy_right = args['copyright'] if args['copyright'] else ''
privacy_policy = args['privacy_policy'] if args['privacy_policy'] else ''
custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else ''
else:
desc = site.description if site.description else args["desc"] if args["desc"] else ""
copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
privacy_policy = (
site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
)
custom_disclaimer = (
site.custom_disclaimer
if site.custom_disclaimer
else args["custom_disclaimer"]
if args["custom_disclaimer"]
else ""
)
desc = site.description if site.description else \
args['desc'] if args['desc'] else ''
copy_right = site.copyright if site.copyright else \
args['copyright'] if args['copyright'] else ''
privacy_policy = site.privacy_policy if site.privacy_policy else \
args['privacy_policy'] if args['privacy_policy'] else ''
custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \
args['custom_disclaimer'] if args['custom_disclaimer'] else ''
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
if not recommended_app:
recommended_app = RecommendedApp(
@@ -87,9 +83,9 @@ class InsertExploreAppListApi(Resource):
copyright=copy_right,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
language=args["language"],
category=args["category"],
position=args["position"],
language=args['language'],
category=args['category'],
position=args['position']
)
db.session.add(recommended_app)
@@ -97,21 +93,21 @@ class InsertExploreAppListApi(Resource):
app.is_public = True
db.session.commit()
return {"result": "success"}, 201
return {'result': 'success'}, 201
else:
recommended_app.description = desc
recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args["language"]
recommended_app.category = args["category"]
recommended_app.position = args["position"]
recommended_app.language = args['language']
recommended_app.category = args['category']
recommended_app.position = args['position']
app.is_public = True
db.session.commit()
return {"result": "success"}, 200
return {'result': 'success'}, 200
class InsertExploreAppApi(Resource):
@@ -120,14 +116,15 @@ class InsertExploreAppApi(Resource):
def delete(self, app_id):
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
if not recommended_app:
return {"result": "success"}, 204
return {'result': 'success'}, 204
app = App.query.filter(App.id == recommended_app.app_id).first()
if app:
app.is_public = False
installed_apps = InstalledApp.query.filter(
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
).all()
for installed_app in installed_apps:
@@ -136,8 +133,8 @@ class InsertExploreAppApi(Resource):
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
return {'result': 'success'}, 204
api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps")
api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>")
api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps')
api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/<uuid:app_id>')


@@ -14,21 +14,26 @@ from .setup import setup_required
from .wraps import account_initialization_required
api_key_fields = {
"id": fields.String,
"type": fields.String,
"token": fields.String,
"last_used_at": TimestampField,
"created_at": TimestampField,
'id': fields.String,
'type': fields.String,
'token': fields.String,
'last_used_at': TimestampField,
'created_at': TimestampField
}
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
api_key_list = {
'data': fields.List(fields.Nested(api_key_fields), attribute="items")
}
def _get_resource(resource_id, tenant_id, resource_model):
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
resource = resource_model.query.filter_by(
id=resource_id, tenant_id=tenant_id
).first()
if resource is None:
flask_restful.abort(404, message=f"{resource_model.__name__} not found.")
flask_restful.abort(
404, message=f"{resource_model.__name__} not found.")
return resource
@@ -45,32 +50,30 @@ class BaseApiKeyListResource(Resource):
@marshal_with(api_key_list)
def get(self, resource_id):
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.all()
)
_get_resource(resource_id, current_user.current_tenant_id,
self.resource_model)
keys = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
all()
return {"items": keys}
@marshal_with(api_key_fields)
def post(self, resource_id):
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
_get_resource(resource_id, current_user.current_tenant_id,
self.resource_model)
if not current_user.is_admin_or_owner:
raise Forbidden()
current_key_count = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.count()
)
current_key_count = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
count()
if current_key_count >= self.max_keys:
flask_restful.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
code='max_keys_exceeded'
)
key = ApiToken.generate_api_key(self.token_prefix, 24)
@@ -94,78 +97,79 @@ class BaseApiKeyResource(Resource):
def delete(self, resource_id, api_key_id):
resource_id = str(resource_id)
api_key_id = str(api_key_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
_get_resource(resource_id, current_user.current_tenant_id,
self.resource_model)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
key = (
db.session.query(ApiToken)
.filter(
getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
)
key = db.session.query(ApiToken). \
filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). \
first()
if key is None:
flask_restful.abort(404, message="API key not found")
flask_restful.abort(404, message='API key not found')
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit()
return {"result": "success"}, 204
return {'result': 'success'}, 204
class AppApiKeyListResource(BaseApiKeyListResource):
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
resp.headers['Access-Control-Allow-Origin'] = '*'
resp.headers['Access-Control-Allow-Credentials'] = 'true'
return resp
resource_type = "app"
resource_type = 'app'
resource_model = App
resource_id_field = "app_id"
token_prefix = "app-"
resource_id_field = 'app_id'
token_prefix = 'app-'
class AppApiKeyResource(BaseApiKeyResource):
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
resp.headers['Access-Control-Allow-Origin'] = '*'
resp.headers['Access-Control-Allow-Credentials'] = 'true'
return resp
resource_type = "app"
resource_type = 'app'
resource_model = App
resource_id_field = "app_id"
resource_id_field = 'app_id'
class DatasetApiKeyListResource(BaseApiKeyListResource):
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
resp.headers['Access-Control-Allow-Origin'] = '*'
resp.headers['Access-Control-Allow-Credentials'] = 'true'
return resp
resource_type = "dataset"
resource_type = 'dataset'
resource_model = Dataset
resource_id_field = "dataset_id"
token_prefix = "ds-"
resource_id_field = 'dataset_id'
token_prefix = 'ds-'
class DatasetApiKeyResource(BaseApiKeyResource):
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
resp.headers['Access-Control-Allow-Origin'] = '*'
resp.headers['Access-Control-Allow-Credentials'] = 'true'
return resp
resource_type = "dataset"
resource_type = 'dataset'
resource_model = Dataset
resource_id_field = "dataset_id"
resource_id_field = 'dataset_id'
api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys")
api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys")
api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
api.add_resource(AppApiKeyListResource, '/apps/<uuid:resource_id>/api-keys')
api.add_resource(AppApiKeyResource,
'/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiKeyListResource,
'/datasets/<uuid:resource_id>/api-keys')
api.add_resource(DatasetApiKeyResource,
'/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
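
Both layouts of api_key_fields marshal identically: flask-restful filters whatever the view returns down to the declared fields. A minimal runnable sketch (fields.String stands in for Dify's custom TimestampField):

from flask_restful import fields, marshal

api_key_fields = {
    "id": fields.String,
    "type": fields.String,
    "token": fields.String,
}
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}

record = {"id": "1", "type": "app", "token": "app-abc123", "internal": "dropped"}
print(marshal({"items": [record]}, api_key_list))
# "data" carries only the declared keys; the undeclared "internal" key is stripped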


@@ -8,18 +8,19 @@ from services.advanced_prompt_template_service import AdvancedPromptTemplateServ
class AdvancedPromptTemplateList(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("app_mode", type=str, required=True, location="args")
parser.add_argument("model_mode", type=str, required=True, location="args")
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
parser.add_argument("model_name", type=str, required=True, location="args")
parser.add_argument('app_mode', type=str, required=True, location='args')
parser.add_argument('model_mode', type=str, required=True, location='args')
parser.add_argument('has_context', type=str, required=False, default='true', location='args')
parser.add_argument('model_name', type=str, required=True, location='args')
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args)
api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates")
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
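
Same mechanical change, same parsing pattern throughout: location='args' reads the query string, while location='json' (used in the endpoints below) reads the request body. A runnable illustration using a test request context with made-up query values:

from flask import Flask
from flask_restful import reqparse

app = Flask(__name__)
with app.test_request_context("/app/prompt-templates?app_mode=chat&model_mode=completion&model_name=m1"):
    parser = reqparse.RequestParser()
    parser.add_argument("app_mode", type=str, required=True, location="args")
    parser.add_argument("model_mode", type=str, required=True, location="args")
    parser.add_argument("has_context", type=str, required=False, default="true", location="args")
    parser.add_argument("model_name", type=str, required=True, location="args")
    args = parser.parse_args()
    print(args["app_mode"], args["has_context"])  # chat true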


@@ -18,12 +18,15 @@ class AgentLogApi(Resource):
def get(self, app_model):
"""Get agent logs"""
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
parser.add_argument('message_id', type=uuid_value, required=True, location='args')
parser.add_argument('conversation_id', type=uuid_value, required=True, location='args')
args = parser.parse_args()
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
api.add_resource(AgentLogApi, "/apps/<uuid:app_id>/agent/logs")
return AgentService.get_agent_logs(
app_model,
args['conversation_id'],
args['message_id']
)
api.add_resource(AgentLogApi, '/apps/<uuid:app_id>/agent/logs')


@@ -21,23 +21,23 @@ class AnnotationReplyActionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@cloud_edition_billing_resource_check('annotation')
def post(self, app_id, action):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
parser.add_argument('score_threshold', required=True, type=float, location='json')
parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
parser.add_argument('embedding_model_name', required=True, type=str, location='json')
args = parser.parse_args()
if action == "enable":
if action == 'enable':
result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == "disable":
elif action == 'disable':
result = AppAnnotationService.disable_app_annotation(app_id)
else:
raise ValueError("Unsupported annotation reply action")
raise ValueError('Unsupported annotation reply action')
return result, 200
@@ -66,7 +66,7 @@ class AppAnnotationSettingUpdateApi(Resource):
annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument('score_threshold', required=True, type=float, location='json')
args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
@@ -77,24 +77,28 @@ class AnnotationReplyActionStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id, action):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job is not exist.")
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
error_msg = ''
if job_status == 'error':
app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200
class AnnotationListApi(Resource):
@@ -105,18 +109,18 @@ class AnnotationListApi(Resource):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default=None, type=str)
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
keyword = request.args.get('keyword', default=None, type=str)
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
'data': marshal(annotation_list, annotation_fields),
'has_more': len(annotation_list) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200
@@ -131,7 +135,9 @@ class AnnotationExportApi(Resource):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)}
response = {
'data': marshal(annotation_list, annotation_fields)
}
return response, 200
@@ -139,7 +145,7 @@ class AnnotationCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id):
if not current_user.is_editor:
@@ -147,8 +153,8 @@ class AnnotationCreateApi(Resource):
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation
@@ -158,7 +164,7 @@ class AnnotationUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
if not current_user.is_editor:
@@ -167,8 +173,8 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation
@@ -183,29 +189,29 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
return {"result": "success"}, 200
return {'result': 'success'}, 200
class AnnotationBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@cloud_edition_billing_resource_check('annotation')
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
# get file from request
file = request.files["file"]
file = request.files['file']
# check file
if "file" not in request.files:
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith(".csv"):
if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file)
@@ -214,23 +220,27 @@ class AnnotationBatchImportStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
raise ValueError("The job is not exist.")
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id))
error_msg = ''
if job_status == 'error':
indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
error_msg = redis_client.get(indexing_error_msg_key).decode()
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200
class AnnotationHitHistoryListApi(Resource):
@@ -241,32 +251,30 @@ class AnnotationHitHistoryListApi(Resource):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
app_id = str(app_id)
annotation_id = str(annotation_id)
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
app_id, annotation_id, page, limit
)
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
page, limit)
response = {
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
"has_more": len(annotation_hit_history_list) == limit,
"limit": limit,
"total": total,
"page": page,
'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
'has_more': len(annotation_hit_history_list) == limit,
'limit': limit,
'total': total,
'page': page
}
return response
api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>")
api.add_resource(
AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>"
)
api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations")
api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import")
api.add_resource(AnnotationBatchImportStatusApi, "/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
api.add_resource(AnnotationHitHistoryListApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
api.add_resource(AppAnnotationSettingDetailApi, "/apps/<uuid:app_id>/annotation-setting")
api.add_resource(AppAnnotationSettingUpdateApi, "/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
api.add_resource(AnnotationReplyActionStatusApi,
'/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
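
The two job-status endpoints in this file share one Redis pattern regardless of quote style: a key encodes the action and job id, and a companion key holds any error message. A distilled sketch (connection details are placeholders; the error message is reworded from the source):

import redis

redis_client = redis.Redis()  # placeholder; the app wires up its own client

def annotation_job_status(action: str, job_id: str) -> dict:
    cache_result = redis_client.get(f"{action}_app_annotation_job_{job_id}")
    if cache_result is None:
        raise ValueError("The job does not exist.")
    job_status = cache_result.decode()
    error_msg = ""
    if job_status == "error":
        error_msg = redis_client.get(f"{action}_app_annotation_error_{job_id}").decode()
    return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}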


@@ -18,35 +18,27 @@ from libs.login import login_required
from services.app_dsl_service import AppDslService
from services.app_service import AppService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
class AppListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""Get app list"""
def uuid_list(value):
try:
return [str(uuid.UUID(v)) for v in value.split(",")]
return [str(uuid.UUID(v)) for v in value.split(',')]
except ValueError:
abort(400, message="Invalid UUID format in tag_ids.")
parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"mode",
type=str,
choices=["chat", "workflow", "agent-chat", "channel", "all"],
default="all",
location="args",
required=False,
)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
parser.add_argument('name', type=str, location='args', required=False)
parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
args = parser.parse_args()
@@ -54,7 +46,7 @@ class AppListApi(Resource):
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}
return marshal(app_pagination, app_pagination_fields)
@@ -62,23 +54,23 @@ class AppListApi(Resource):
@login_required
@account_initialization_required
@marshal_with(app_detail_fields)
@cloud_edition_billing_resource_check("apps")
@cloud_edition_billing_resource_check('apps')
def post(self):
"""Create app"""
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if "mode" not in args or args["mode"] is None:
if 'mode' not in args or args['mode'] is None:
raise BadRequest("mode is required")
app_service = AppService()
@@ -92,7 +84,7 @@ class AppImportApi(Resource):
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check("apps")
@cloud_edition_billing_resource_check('apps')
def post(self):
"""Import app"""
# The role of the current user in the ta table must be admin, owner, or editor
@@ -100,16 +92,19 @@ class AppImportApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
tenant_id=current_user.current_tenant_id,
data=args['data'],
args=args,
account=current_user
)
return app, 201
@@ -120,7 +115,7 @@ class AppImportFromUrlApi(Resource):
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check("apps")
@cloud_edition_billing_resource_check('apps')
def post(self):
"""Import app from url"""
# The role of the current user in the ta table must be admin, owner, or editor
@@ -128,21 +123,25 @@ class AppImportFromUrlApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument('url', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
app = AppDslService.import_and_create_new_app_from_url(
tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
tenant_id=current_user.current_tenant_id,
url=args['url'],
args=args,
account=current_user
)
return app, 201
class AppApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -166,15 +165,14 @@ class AppApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("max_active_requests", type=int, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument('max_active_requests', type=int, location='json')
args = parser.parse_args()
app_service = AppService()
@@ -195,7 +193,7 @@ class AppApi(Resource):
app_service = AppService()
app_service.delete_app(app_model)
return {"result": "success"}, 204
return {'result': 'success'}, 204
class AppCopyApi(Resource):
@@ -211,16 +209,19 @@ class AppCopyApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user
tenant_id=current_user.current_tenant_id,
data=data,
args=args,
account=current_user
)
return app, 201
@@ -239,10 +240,12 @@ class AppExportApi(Resource):
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args')
args = parser.parse_args()
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
return {
"data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret'])
}
class AppNameApi(Resource):
@@ -255,13 +258,13 @@ class AppNameApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument('name', type=str, required=True, location='json')
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_name(app_model, args.get("name"))
app_model = app_service.update_app_name(app_model, args.get('name'))
return app_model
@@ -276,14 +279,14 @@ class AppIconApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background"))
app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background'))
return app_model
@@ -298,13 +301,13 @@ class AppSiteStatus(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_site", type=bool, required=True, location="json")
parser.add_argument('enable_site', type=bool, required=True, location='json')
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.get("enable_site"))
app_model = app_service.update_app_site_status(app_model, args.get('enable_site'))
return app_model
@@ -319,13 +322,13 @@ class AppApiStatus(Resource):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_api", type=bool, required=True, location="json")
parser.add_argument('enable_api', type=bool, required=True, location='json')
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.get("enable_api"))
app_model = app_service.update_app_api_status(app_model, args.get('enable_api'))
return app_model
@@ -336,7 +339,9 @@ class AppTraceApi(Resource):
@account_initialization_required
def get(self, app_id):
"""Get app trace"""
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id)
app_trace_config = OpsTraceManager.get_app_tracing_config(
app_id=app_id
)
return app_trace_config
@@ -348,27 +353,27 @@ class AppTraceApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enabled", type=bool, required=True, location="json")
parser.add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument('enabled', type=bool, required=True, location='json')
parser.add_argument('tracing_provider', type=str, required=True, location='json')
args = parser.parse_args()
OpsTraceManager.update_app_tracing_config(
app_id=app_id,
enabled=args["enabled"],
tracing_provider=args["tracing_provider"],
enabled=args['enabled'],
tracing_provider=args['tracing_provider'],
)
return {"result": "success"}
api.add_resource(AppListApi, "/apps")
api.add_resource(AppImportApi, "/apps/import")
api.add_resource(AppImportFromUrlApi, "/apps/import/url")
api.add_resource(AppApi, "/apps/<uuid:app_id>")
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name")
api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon")
api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable")
api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable")
api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace")
api.add_resource(AppListApi, '/apps')
api.add_resource(AppImportApi, '/apps/import')
api.add_resource(AppImportFromUrlApi, '/apps/import/url')
api.add_resource(AppApi, '/apps/<uuid:app_id>')
api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon')
api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
api.add_resource(AppTraceApi, '/apps/<uuid:app_id>/trace')
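
One helper from the AppListApi hunk is worth seeing on its own: tag_ids arrives as a comma-separated string and is normalized into canonical UUID strings. A standalone version (the view itself aborts with HTTP 400 rather than raising):

import uuid

def uuid_list(value: str) -> list[str]:
    try:
        return [str(uuid.UUID(v)) for v in value.split(",")]
    except ValueError:
        raise ValueError("Invalid UUID format in tag_ids.")

print(uuid_list("f47ac10b-58cc-4372-a567-0e02b2c3d479"))
# ['f47ac10b-58cc-4372-a567-0e02b2c3d479']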


@@ -39,7 +39,7 @@ class ChatMessageAudioApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model):
file = request.files["file"]
file = request.files['file']
try:
response = AudioService.transcript_asr(
@@ -85,31 +85,31 @@ class ChatMessageTextApi(Resource):
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
parser.add_argument('message_id', type=str, location='json')
parser.add_argument('text', type=str, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
message_id = args.get('message_id', None)
text = args.get('text', None)
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
'voice')
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
response = AudioService.transcript_tts(
app_model=app_model,
text=text,
message_id=message_id,
voice=voice
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
@@ -145,12 +145,12 @@ class TextModesApi(Resource):
def get(self, app_model):
try:
parser = reqparse.RequestParser()
parser.add_argument("language", type=str, required=True, location="args")
parser.add_argument('language', type=str, required=True, location='args')
args = parser.parse_args()
response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id,
language=args["language"],
language=args['language'],
)
return response
@@ -179,6 +179,6 @@ class TextModesApi(Resource):
raise InternalServerError()
api.add_resource(ChatMessageAudioApi, "/apps/<uuid:app_id>/audio-to-text")
api.add_resource(ChatMessageTextApi, "/apps/<uuid:app_id>/text-to-audio")
api.add_resource(TextModesApi, "/apps/<uuid:app_id>/text-to-audio/voices")
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
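
Behind the reflowed conditionals in ChatMessageTextApi sits a small fallback chain for picking a TTS voice: the explicit request value, then the workflow's text_to_speech feature, then the legacy app model config. A simplified restatement (function name and signature are invented for illustration):

def resolve_voice(requested, workflow_features, model_config_tts):
    """Mirrors the fallback order in ChatMessageTextApi, minus error handling."""
    if requested:
        return requested  # args['voice'] always wins
    if workflow_features and workflow_features.get("text_to_speech"):
        return workflow_features["text_to_speech"].get("voice")  # advanced-chat / workflow apps
    if model_config_tts:
        return model_config_tts.get("voice")  # older app model configs
    return None

print(resolve_voice(None, {"text_to_speech": {"voice": "alloy"}}, None))  # alloy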


@@ -17,7 +17,6 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@@ -32,33 +31,37 @@ from libs.helper import uuid_value
from libs.login import login_required
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
# define completion message api for user
class CompletionMessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False
streaming = args['response_mode'] != 'blocking'
args['auto_generate_name'] = False
account = flask_login.current_user
try:
response = AppGenerateService.generate(
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
app_model=app_model,
user=account,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming
)
return helper.compact_generate_response(response)
@@ -94,7 +97,7 @@ class CompletionMessageStopApi(Resource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {"result": "success"}, 200
return {'result': 'success'}, 200
class ChatMessageApi(Resource):
@@ -104,23 +107,27 @@ class ChatMessageApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def post(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False
streaming = args['response_mode'] != 'blocking'
args['auto_generate_name'] = False
account = flask_login.current_user
try:
response = AppGenerateService.generate(
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
app_model=app_model,
user=account,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming
)
return helper.compact_generate_response(response)
@@ -137,8 +144,6 @@ class ChatMessageApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
@@ -158,10 +163,10 @@ class ChatMessageStopApi(Resource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {"result": "success"}, 200
return {'result': 'success'}, 200
api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages")
api.add_resource(CompletionMessageStopApi, "/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
api.add_resource(ChatMessageApi, "/apps/<uuid:app_id>/chat-messages")
api.add_resource(ChatMessageStopApi, "/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
api.add_resource(CompletionMessageApi, '/apps/<uuid:app_id>/completion-messages')
api.add_resource(CompletionMessageStopApi, '/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop')
api.add_resource(ChatMessageApi, '/apps/<uuid:app_id>/chat-messages')
api.add_resource(ChatMessageStopApi, '/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop')
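
Both message endpoints derive their streaming flag the same way: anything other than an explicit "blocking" response_mode streams, including an omitted one. Trivial, but easy to invert:

for mode in ("blocking", "streaming", None):
    streaming = mode != "blocking"
    print(mode, "->", "streaming" if streaming else "blocking")
# only the literal "blocking" value disables streaming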


@@ -26,6 +26,7 @@ from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotat
class CompletionConversationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -35,23 +36,24 @@ class CompletionConversationApi(Resource):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
parser.add_argument('keyword', type=str, location='args')
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('annotation_status', type=str,
choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
args = parser.parse_args()
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion')
if args["keyword"]:
query = query.join(Message, Message.conversation_id == Conversation.id).filter(
if args['keyword']:
query = query.join(
Message, Message.conversation_id == Conversation.id
).filter(
or_(
Message.query.ilike("%{}%".format(args["keyword"])),
Message.answer.ilike("%{}%".format(args["keyword"])),
Message.query.ilike('%{}%'.format(args['keyword'])),
Message.answer.ilike('%{}%'.format(args['keyword']))
)
)
@@ -59,8 +61,8 @@ class CompletionConversationApi(Resource):
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
@@ -68,8 +70,8 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
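
The start/end handling converts a naive local timestamp into UTC before filtering; the line elided between these hunks presumably calls astimezone on the localized value. The same steps in isolation, with an assumed account timezone:

import pytz
from datetime import datetime

timezone = pytz.timezone("Asia/Shanghai")  # stands in for account.timezone
start_datetime = datetime.strptime("2024-08-23 15:54", "%Y-%m-%d %H:%M").replace(second=0)
start_datetime_utc = timezone.localize(start_datetime).astimezone(pytz.utc)
print(start_datetime_utc)  # 2024-08-23 07:54:00+00:00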
@@ -77,25 +79,29 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc)
if args["annotation_status"] == "annotated":
if args['annotation_status'] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join(
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args["annotation_status"] == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
elif args['annotation_status'] == "not_annotated":
query = query.outerjoin(
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
conversations = db.paginate(
query,
page=args['page'],
per_page=args['limit'],
error_out=False
)
return conversations
class CompletionConversationDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -117,11 +123,8 @@ class CompletionConversationDetailApi(Resource):
raise Forbidden()
conversation_id = str(conversation_id)
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
if not conversation:
raise NotFound("Conversation Not Exists.")
@@ -129,10 +132,11 @@ class CompletionConversationDetailApi(Resource):
conversation.is_deleted = True
db.session.commit()
return {"result": "success"}, 204
return {'result': 'success'}, 204
class ChatConversationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -142,28 +146,22 @@ class ChatConversationApi(Resource):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
parser.add_argument('keyword', type=str, location='args')
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('annotation_status', type=str,
choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args()
subquery = (
db.session.query(
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
Conversation.id.label('conversation_id'),
EndUser.session_id.label('from_end_user_session_id')
)
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
.subquery()
@@ -171,31 +169,28 @@ class ChatConversationApi(Resource):
query = db.select(Conversation).where(Conversation.app_id == app_model.id)
if args["keyword"]:
keyword_filter = "%{}%".format(args["keyword"])
query = (
query.join(
Message,
Message.conversation_id == Conversation.id,
)
.join(subquery, subquery.c.conversation_id == Conversation.id)
.filter(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
Conversation.name.ilike(keyword_filter),
Conversation.introduction.ilike(keyword_filter),
subquery.c.from_end_user_session_id.ilike(keyword_filter),
),
)
if args['keyword']:
keyword_filter = '%{}%'.format(args['keyword'])
query = query.join(
Message, Message.conversation_id == Conversation.id,
).join(
subquery, subquery.c.conversation_id == Conversation.id
).filter(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
Conversation.name.ilike(keyword_filter),
Conversation.introduction.ilike(keyword_filter),
subquery.c.from_end_user_session_id.ilike(keyword_filter)
),
)
account = current_user
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
@@ -203,8 +198,8 @@ class ChatConversationApi(Resource):
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
@@ -212,46 +207,50 @@ class ChatConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc)
if args["annotation_status"] == "annotated":
if args['annotation_status'] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join(
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args["annotation_status"] == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
elif args['annotation_status'] == "not_annotated":
query = query.outerjoin(
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
if args["message_count_gte"] and args["message_count_gte"] >= 1:
if args['message_count_gte'] and args['message_count_gte'] >= 1:
query = (
query.options(joinedload(Conversation.messages))
.join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(Message.id) >= args["message_count_gte"])
.having(func.count(Message.id) >= args['message_count_gte'])
)
if app_model.mode == AppMode.ADVANCED_CHAT.value:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
match args["sort_by"]:
case "created_at":
match args['sort_by']:
case 'created_at':
query = query.order_by(Conversation.created_at.asc())
case "-created_at":
case '-created_at':
query = query.order_by(Conversation.created_at.desc())
case "updated_at":
case 'updated_at':
query = query.order_by(Conversation.updated_at.asc())
case "-updated_at":
case '-updated_at':
query = query.order_by(Conversation.updated_at.desc())
case _:
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
conversations = db.paginate(
query,
page=args['page'],
per_page=args['limit'],
error_out=False
)
return conversations
class ChatConversationDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -273,11 +272,8 @@ class ChatConversationDetailApi(Resource):
raise Forbidden()
conversation_id = str(conversation_id)
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
if not conversation:
raise NotFound("Conversation Not Exists.")
@@ -285,21 +281,18 @@ class ChatConversationDetailApi(Resource):
conversation.is_deleted = True
db.session.commit()
return {"result": "success"}, 204
return {'result': 'success'}, 204
api.add_resource(CompletionConversationApi, "/apps/<uuid:app_id>/completion-conversations")
api.add_resource(CompletionConversationDetailApi, "/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
api.add_resource(ChatConversationApi, "/apps/<uuid:app_id>/chat-conversations")
api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>')
def _get_conversation(app_model, conversation_id):
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
if not conversation:
raise NotFound("Conversation Not Exists.")

View File

@@ -21,7 +21,7 @@ class ConversationVariablesApi(Resource):
@marshal_with(paginated_conversation_variable_fields)
def get(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", type=str, location="args")
parser.add_argument('conversation_id', type=str, location='args')
args = parser.parse_args()
stmt = (
@@ -29,10 +29,10 @@ class ConversationVariablesApi(Resource):
.where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at)
)
if args["conversation_id"]:
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
if args['conversation_id']:
stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id'])
else:
raise ValueError("conversation_id is required")
raise ValueError('conversation_id is required')
# NOTE: This is a temporary solution to avoid performance issues.
page = 1
@@ -43,14 +43,14 @@ class ConversationVariablesApi(Resource):
rows = session.scalars(stmt).all()
return {
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
'page': page,
'limit': page_size,
'total': len(rows),
'has_more': False,
'data': [
{
"created_at": row.created_at,
"updated_at": row.updated_at,
'created_at': row.created_at,
'updated_at': row.updated_at,
**row.to_variable().model_dump(),
}
for row in rows
@@ -58,4 +58,4 @@ class ConversationVariablesApi(Resource):
}
api.add_resource(ConversationVariablesApi, "/apps/<uuid:app_id>/conversation-variables")
api.add_resource(ConversationVariablesApi, '/apps/<uuid:app_id>/conversation-variables')
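
As a usage illustration only: calling the endpoint registered above from a script might look like the following. The base URL and bearer token header are assumptions for the sketch, not Dify's documented client contract.

import requests

BASE_URL = "http://localhost:5001/console/api"  # assumed console API root
APP_ID = "00000000-0000-0000-0000-000000000000"           # placeholder
CONVERSATION_ID = "11111111-1111-1111-1111-111111111111"  # placeholder

resp = requests.get(
    f"{BASE_URL}/apps/{APP_ID}/conversation-variables",
    params={"conversation_id": CONVERSATION_ID},  # required, per the handler
    headers={"Authorization": "Bearer <console-token>"},  # placeholder auth
    timeout=10,
)
resp.raise_for_status()
payload = resp.json()
# Single-page response for now: has_more is always False in the handler above.
for row in payload["data"]:
    print(row["created_at"], row["updated_at"])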

View File

@@ -2,128 +2,116 @@ from libs.exception import BaseHTTPException
class AppNotFoundError(BaseHTTPException):
error_code = "app_not_found"
error_code = 'app_not_found'
description = "App not found."
code = 404
class ProviderNotInitializeError(BaseHTTPException):
error_code = "provider_not_initialize"
description = (
"No valid model provider credentials found. "
"Please go to Settings -> Model Provider to complete your provider credentials."
)
error_code = 'provider_not_initialize'
description = "No valid model provider credentials found. " \
"Please go to Settings -> Model Provider to complete your provider credentials."
code = 400
class ProviderQuotaExceededError(BaseHTTPException):
error_code = "provider_quota_exceeded"
description = (
"Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials."
)
error_code = 'provider_quota_exceeded'
description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
"Please go to Settings -> Model Provider to complete your own provider credentials."
code = 400
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
error_code = "model_currently_not_support"
error_code = 'model_currently_not_support'
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
code = 400
class ConversationCompletedError(BaseHTTPException):
error_code = "conversation_completed"
error_code = 'conversation_completed'
description = "The conversation has ended. Please start a new conversation."
code = 400
class AppUnavailableError(BaseHTTPException):
error_code = "app_unavailable"
error_code = 'app_unavailable'
description = "App unavailable, please check your app configurations."
code = 400
class CompletionRequestError(BaseHTTPException):
error_code = "completion_request_error"
error_code = 'completion_request_error'
description = "Completion request failed."
code = 400
class AppMoreLikeThisDisabledError(BaseHTTPException):
error_code = "app_more_like_this_disabled"
error_code = 'app_more_like_this_disabled'
description = "The 'More like this' feature is disabled. Please refresh your page."
code = 403
class NoAudioUploadedError(BaseHTTPException):
error_code = "no_audio_uploaded"
error_code = 'no_audio_uploaded'
description = "Please upload your audio."
code = 400
class AudioTooLargeError(BaseHTTPException):
error_code = "audio_too_large"
error_code = 'audio_too_large'
description = "Audio size exceeded. {message}"
code = 413
class UnsupportedAudioTypeError(BaseHTTPException):
error_code = "unsupported_audio_type"
error_code = 'unsupported_audio_type'
description = "Audio type not allowed."
code = 415
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = "provider_not_support_speech_to_text"
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
error_code = 'no_file_uploaded'
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400
class DraftWorkflowNotExist(BaseHTTPException):
error_code = "draft_workflow_not_exist"
error_code = 'draft_workflow_not_exist'
description = "Draft workflow need to be initialized."
code = 400
class DraftWorkflowNotSync(BaseHTTPException):
error_code = "draft_workflow_not_sync"
error_code = 'draft_workflow_not_sync'
description = "Workflow graph might have been modified, please refresh and resubmit."
code = 400
class TracingConfigNotExist(BaseHTTPException):
error_code = "trace_config_not_exist"
error_code = 'trace_config_not_exist'
description = "Trace config not exist."
code = 400
class TracingConfigIsExist(BaseHTTPException):
error_code = "trace_config_is_exist"
error_code = 'trace_config_is_exist'
description = "Trace config is exist."
code = 400
class TracingConfigCheckError(BaseHTTPException):
error_code = "trace_config_check_error"
error_code = 'trace_config_check_error'
description = "Invalid Credentials."
code = 400
class InvokeRateLimitError(BaseHTTPException):
"""Raised when the Invoke returns rate limit error."""
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
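
Every class in this file follows the same three-field pattern: subclass BaseHTTPException and set error_code, description, and the HTTP status code. A self-contained sketch of that pattern, with a stand-in base class of assumed shape (the real one lives in libs.exception):

from werkzeug.exceptions import HTTPException

class BaseHTTPException(HTTPException):
    # Stand-in for libs.exception.BaseHTTPException, assumed shape.
    error_code: str = "unknown"
    description = "Unknown error."
    code = 500

class InvokeRateLimitError(BaseHTTPException):
    """Raised when the invoke call returns a rate limit error."""
    error_code = "rate_limit_error"
    description = "Rate Limit Error"
    code = 429

err = InvokeRateLimitError()
print(err.error_code, err.code, err.description)  # rate_limit_error 429 Rate Limit Error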

View File

@@ -24,21 +24,21 @@ class RuleGenerateApi(Resource):
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
parser.add_argument('instruction', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json')
parser.add_argument('no_variable', type=bool, required=True, default=False, location='json')
args = parser.parse_args()
account = current_user
PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512'))
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=args["no_variable"],
rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
instruction=args['instruction'],
model_config=args['model_config'],
no_variable=args['no_variable'],
rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -52,4 +52,4 @@ class RuleGenerateApi(Resource):
return rules
api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleGenerateApi, '/rule-generate')
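
PROMPT_GENERATION_MAX_TOKENS above is read from the environment on each request with a 512-token fallback, so operators can raise the cap without a code change. The pattern in isolation:

import os

def prompt_generation_max_tokens() -> int:
    # Same lookup the handler performs on every request.
    return int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))

print(prompt_generation_max_tokens())  # 512 by default
os.environ["PROMPT_GENERATION_MAX_TOKENS"] = "1024"
print(prompt_generation_max_tokens())  # 1024 after an operator override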

View File

@@ -33,9 +33,9 @@ from services.message_service import MessageService
class ChatMessageListApi(Resource):
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_detail_fields)),
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_detail_fields))
}
@setup_required
@@ -45,69 +45,55 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first()
)
conversation = db.session.query(Conversation).filter(
Conversation.id == args['conversation_id'],
Conversation.app_id == app_model.id
).first()
if not conversation:
raise NotFound("Conversation Not Exists.")
if args["first_id"]:
first_message = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first()
)
if args['first_id']:
first_message = db.session.query(Message) \
.filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first()
if not first_message:
raise NotFound("First message not found")
history_messages = (
db.session.query(Message)
.filter(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.all()
)
history_messages = db.session.query(Message).filter(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id
) \
.order_by(Message.created_at.desc()).limit(args['limit']).all()
else:
history_messages = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.all()
)
history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
.order_by(Message.created_at.desc()).limit(args['limit']).all()
has_more = False
if len(history_messages) == args["limit"]:
if len(history_messages) == args['limit']:
current_page_first_message = history_messages[-1]
rest_count = (
db.session.query(Message)
.filter(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
)
.count()
)
rest_count = db.session.query(Message).filter(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id
).count()
if rest_count > 0:
has_more = True
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
return InfiniteScrollPagination(
data=history_messages,
limit=args['limit'],
has_more=has_more
)
class MessageFeedbackApi(Resource):
@@ -117,46 +103,49 @@ class MessageFeedbackApi(Resource):
@get_app_model
def post(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
parser.add_argument('message_id', required=True, type=uuid_value, location='json')
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
args = parser.parse_args()
message_id = str(args["message_id"])
message_id = str(args['message_id'])
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id
).first()
if not message:
raise NotFound("Message Not Exists.")
feedback = message.admin_feedback
if not args["rating"] and feedback:
if not args['rating'] and feedback:
db.session.delete(feedback)
elif args["rating"] and feedback:
feedback.rating = args["rating"]
elif not args["rating"] and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
elif args['rating'] and feedback:
feedback.rating = args['rating']
elif not args['rating'] and not feedback:
raise ValueError('rating cannot be None when feedback not exists')
else:
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=args["rating"],
from_source="admin",
from_account_id=current_user.id,
rating=args['rating'],
from_source='admin',
from_account_id=current_user.id
)
db.session.add(feedback)
db.session.commit()
return {"result": "success"}
return {'result': 'success'}
class MessageAnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@cloud_edition_billing_resource_check('annotation')
@get_app_model
@marshal_with(annotation_fields)
def post(self, app_model):
@@ -164,10 +153,10 @@ class MessageAnnotationApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
parser.add_argument('message_id', required=False, type=uuid_value, location='json')
parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument('answer', required=True, type=str, location='json')
parser.add_argument('annotation_reply', required=False, type=dict, location='json')
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
@@ -180,9 +169,11 @@ class MessageAnnotationCountApi(Resource):
@account_initialization_required
@get_app_model
def get(self, app_model):
count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count()
count = db.session.query(MessageAnnotation).filter(
MessageAnnotation.app_id == app_model.id
).count()
return {"count": count}
return {'count': count}
class MessageSuggestedQuestionApi(Resource):
@@ -195,7 +186,10 @@ class MessageSuggestedQuestionApi(Resource):
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER
app_model=app_model,
message_id=message_id,
user=current_user,
invoke_from=InvokeFrom.DEBUGGER
)
except MessageNotExistsError:
raise NotFound("Message not found")
@@ -215,7 +209,7 @@ class MessageSuggestedQuestionApi(Resource):
logging.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
return {'data': questions}
class MessageApi(Resource):
@@ -227,7 +221,10 @@ class MessageApi(Resource):
def get(self, app_model, message_id):
message_id = str(message_id)
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id
).first()
if not message:
raise NotFound("Message Not Exists.")
@@ -235,9 +232,9 @@ class MessageApi(Resource):
return message
api.add_resource(MessageSuggestedQuestionApi, "/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
api.add_resource(ChatMessageListApi, "/apps/<uuid:app_id>/chat-messages", endpoint="console_chat_messages")
api.add_resource(MessageFeedbackApi, "/apps/<uuid:app_id>/feedbacks")
api.add_resource(MessageAnnotationApi, "/apps/<uuid:app_id>/annotations")
api.add_resource(MessageAnnotationCountApi, "/apps/<uuid:app_id>/annotations/count")
api.add_resource(MessageApi, "/apps/<uuid:app_id>/messages/<uuid:message_id>", endpoint="console_message")
api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions')
api.add_resource(ChatMessageListApi, '/apps/<uuid:app_id>/chat-messages', endpoint='console_chat_messages')
api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks')
api.add_resource(MessageAnnotationApi, '/apps/<uuid:app_id>/annotations')
api.add_resource(MessageAnnotationCountApi, '/apps/<uuid:app_id>/annotations/count')
api.add_resource(MessageApi, '/apps/<uuid:app_id>/messages/<uuid:message_id>', endpoint='console_message')

View File

@@ -19,35 +19,37 @@ from services.app_model_config_service import AppModelConfigService
class ModelConfigResource(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model):
"""Modify app model config"""
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode)
tenant_id=current_user.current_tenant_id,
config=request.json,
app_mode=AppMode.value_of(app_model.mode)
)
new_app_model_config = AppModelConfig(
app_id=app_model.id,
created_by=current_user.id,
updated_by=current_user.id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config
original_app_model_config: AppModelConfig = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
)
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
AppModelConfig.id == app_model.app_model_config_id
).first()
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get("tools") or []:
for tool in agent_mode.get('tools') or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
@@ -64,7 +66,7 @@ class ModelConfigResource(Resource):
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f"AGENT.{app_model.id}",
identity_id=f'AGENT.{app_model.id}'
)
except Exception as e:
continue
@@ -77,18 +79,18 @@ class ModelConfigResource(Resource):
parameters = {}
masked_parameter = {}
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get("tools") or []:
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map:
tool_runtime = tool_map[key]
else:
@@ -106,7 +108,7 @@ class ModelConfigResource(Resource):
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f"AGENT.{app_model.id}",
identity_id=f'AGENT.{app_model.id}'
)
manager.delete_tool_parameters_cache()
@@ -114,17 +116,15 @@ class ModelConfigResource(Resource):
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
for masked_key, masked_value in masked_parameter_map[key].items():
if (
masked_key in agent_tool_entity.tool_parameters
and agent_tool_entity.tool_parameters[masked_key] == masked_value
):
if masked_key in agent_tool_entity.tool_parameters and \
agent_tool_entity.tool_parameters[masked_key] == masked_value:
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
# encrypt parameters
if agent_tool_entity.tool_parameters:
tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)
@@ -135,9 +135,12 @@ class ModelConfigResource(Resource):
app_model.app_model_config_id = new_app_model_config.id
db.session.commit()
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
app_model_config_was_updated.send(
app_model,
app_model_config=new_app_model_config
)
return {"result": "success"}
return {'result': 'success'}
api.add_resource(ModelConfigResource, "/apps/<uuid:app_id>/model-config")
api.add_resource(ModelConfigResource, '/apps/<uuid:app_id>/model-config')
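
The secret-input handling above masks tool parameters when they are read back and, on save, swaps the stored original back in only if the submitted value still equals the mask. The compare-and-restore step in isolation (the mask format here is an assumption for illustration):

MASK = "******"  # assumed mask format, for illustration only

def restore_secrets(submitted: dict, masked: dict, originals: dict) -> dict:
    # If the client echoed the mask back unchanged, keep the stored original.
    merged = dict(submitted)
    for key, masked_value in masked.items():
        if merged.get(key) == masked_value:
            merged[key] = originals.get(key)
    return merged

originals = {"api_key": "sk-real-secret", "region": "us-east-1"}
masked = {"api_key": MASK}
submitted = {"api_key": MASK, "region": "eu-west-1"}  # user edited region only
print(restore_secrets(submitted, masked, originals))
# {'api_key': 'sk-real-secret', 'region': 'eu-west-1'}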

View File

@@ -18,11 +18,13 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def get(self, app_id):
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="args")
parser.add_argument('tracing_provider', type=str, required=True, location='args')
args = parser.parse_args()
try:
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
trace_config = OpsService.get_tracing_app_config(
app_id=app_id, tracing_provider=args['tracing_provider']
)
if not trace_config:
return {"has_not_configured": True}
return trace_config
@@ -35,17 +37,19 @@ class TraceAppConfigApi(Resource):
def post(self, app_id):
"""Create a new trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_config", type=dict, required=True, location="json")
parser.add_argument('tracing_provider', type=str, required=True, location='json')
parser.add_argument('tracing_config', type=dict, required=True, location='json')
args = parser.parse_args()
try:
result = OpsService.create_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
app_id=app_id,
tracing_provider=args['tracing_provider'],
tracing_config=args['tracing_config']
)
if not result:
raise TracingConfigIsExist()
if result.get("error"):
if result.get('error'):
raise TracingConfigCheckError()
return result
except Exception as e:
@@ -57,13 +61,15 @@ class TraceAppConfigApi(Resource):
def patch(self, app_id):
"""Update an existing trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_config", type=dict, required=True, location="json")
parser.add_argument('tracing_provider', type=str, required=True, location='json')
parser.add_argument('tracing_config', type=dict, required=True, location='json')
args = parser.parse_args()
try:
result = OpsService.update_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
app_id=app_id,
tracing_provider=args['tracing_provider'],
tracing_config=args['tracing_config']
)
if not result:
raise TracingConfigNotExist()
@@ -77,11 +83,14 @@ class TraceAppConfigApi(Resource):
def delete(self, app_id):
"""Delete an existing trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="args")
parser.add_argument('tracing_provider', type=str, required=True, location='args')
args = parser.parse_args()
try:
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
result = OpsService.delete_tracing_app_config(
app_id=app_id,
tracing_provider=args['tracing_provider']
)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}
@@ -89,4 +98,4 @@ class TraceAppConfigApi(Resource):
raise e
api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")
api.add_resource(TraceAppConfigApi, '/apps/<uuid:app_id>/trace-config')
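
All four handlers above share one request shape: tracing_provider as a string plus, for writes, a tracing_config dict. A runnable Flask-RESTful sketch of just that parsing layer (the route and echo response are scaffolding, not Dify's behavior):

from flask import Flask
from flask_restful import Api, Resource, reqparse

app = Flask(__name__)
api = Api(app)

class TraceConfigDemo(Resource):
    def post(self, app_id):
        parser = reqparse.RequestParser()
        parser.add_argument("tracing_provider", type=str, required=True, location="json")
        parser.add_argument("tracing_config", type=dict, required=True, location="json")
        args = parser.parse_args()  # responds 400 with field errors if either is missing
        return {"app_id": app_id, **args}  # echo instead of calling OpsService

api.add_resource(TraceConfigDemo, "/apps/<string:app_id>/trace-config")

if __name__ == "__main__":
    app.run(port=5050)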

View File

@@ -1,5 +1,3 @@
from datetime import datetime, timezone
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
@@ -17,24 +15,23 @@ from models.model import Site
def parse_app_site_args():
parser = reqparse.RequestParser()
parser.add_argument("title", type=str, required=False, location="json")
parser.add_argument("icon_type", type=str, required=False, location="json")
parser.add_argument("icon", type=str, required=False, location="json")
parser.add_argument("icon_background", type=str, required=False, location="json")
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("default_language", type=supported_language, required=False, location="json")
parser.add_argument("chat_color_theme", type=str, required=False, location="json")
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
parser.add_argument("customize_domain", type=str, required=False, location="json")
parser.add_argument("copyright", type=str, required=False, location="json")
parser.add_argument("privacy_policy", type=str, required=False, location="json")
parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
parser.add_argument(
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
)
parser.add_argument("prompt_public", type=bool, required=False, location="json")
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
parser.add_argument('title', type=str, required=False, location='json')
parser.add_argument('icon_type', type=str, required=False, location='json')
parser.add_argument('icon', type=str, required=False, location='json')
parser.add_argument('icon_background', type=str, required=False, location='json')
parser.add_argument('description', type=str, required=False, location='json')
parser.add_argument('default_language', type=supported_language, required=False, location='json')
parser.add_argument('chat_color_theme', type=str, required=False, location='json')
parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json')
parser.add_argument('customize_domain', type=str, required=False, location='json')
parser.add_argument('copyright', type=str, required=False, location='json')
parser.add_argument('privacy_policy', type=str, required=False, location='json')
parser.add_argument('custom_disclaimer', type=str, required=False, location='json')
parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'],
required=False,
location='json')
parser.add_argument('prompt_public', type=bool, required=False, location='json')
parser.add_argument('show_workflow_steps', type=bool, required=False, location='json')
return parser.parse_args()
@@ -51,38 +48,38 @@ class AppSite(Resource):
if not current_user.is_editor:
raise Forbidden()
site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404()
site = db.session.query(Site). \
filter(Site.app_id == app_model.id). \
one_or_404()
for attr_name in [
"title",
"icon_type",
"icon",
"icon_background",
"description",
"default_language",
"chat_color_theme",
"chat_color_theme_inverted",
"customize_domain",
"copyright",
"privacy_policy",
"custom_disclaimer",
"customize_token_strategy",
"prompt_public",
"show_workflow_steps",
"use_icon_as_answer_icon",
'title',
'icon_type',
'icon',
'icon_background',
'description',
'default_language',
'chat_color_theme',
'chat_color_theme_inverted',
'customize_domain',
'copyright',
'privacy_policy',
'custom_disclaimer',
'customize_token_strategy',
'prompt_public',
'show_workflow_steps'
]:
value = args.get(attr_name)
if value is not None:
setattr(site, attr_name, value)
site.updated_by = current_user.id
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
return site
class AppSiteAccessTokenReset(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -99,12 +96,10 @@ class AppSiteAccessTokenReset(Resource):
raise NotFound
site.code = Site.generate_code(16)
site.updated_by = current_user.id
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
return site
api.add_resource(AppSite, "/apps/<uuid:app_id>/site")
api.add_resource(AppSiteAccessTokenReset, "/apps/<uuid:app_id>/site/access-token-reset")
api.add_resource(AppSite, '/apps/<uuid:app_id>/site')
api.add_resource(AppSiteAccessTokenReset, '/apps/<uuid:app_id>/site/access-token-reset')
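
The AppSite POST applies a partial update: loop over a whitelist of attribute names and setattr only the fields the client actually sent. The same idea on a plain object (it inherits the handler's quirk that an explicit null is indistinguishable from an omitted field):

class Site:
    title = "Old title"
    copyright = ""
    prompt_public = False

ALLOWED = ["title", "copyright", "prompt_public"]

def apply_partial_update(site, args: dict):
    for attr_name in ALLOWED:
        value = args.get(attr_name)
        if value is not None:  # omitted / null fields stay untouched
            setattr(site, attr_name, value)
    return site

site = apply_partial_update(Site(), {"title": "New title", "prompt_public": True})
print(site.title, site.prompt_public)  # New title True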

View File

@@ -16,61 +16,8 @@ from libs.login import login_required
from models.model import AppMode
class DailyMessageStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date order by date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "message_count": i.message_count})
return jsonify({"data": response_data})
class DailyConversationStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -79,52 +26,58 @@ class DailyConversationStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args()
sql_query = """
sql_query = '''
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += ' GROUP BY date order by date'
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
response_data.append({
'date': str(i.date),
'conversation_count': i.conversation_count
})
return jsonify({"data": response_data})
return jsonify({
'data': response_data
})
class DailyTerminalsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -133,49 +86,54 @@ class DailyTerminalsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args()
sql_query = """
sql_query = '''
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += ' GROUP BY date order by date'
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
response_data.append({
'date': str(i.date),
'terminal_count': i.terminal_count
})
return jsonify({"data": response_data})
return jsonify({
'data': response_data
})
class DailyTokenCostStatistic(Resource):
@@ -187,53 +145,58 @@ class DailyTokenCostStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args()
sql_query = """
sql_query = '''
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
(sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
sum(total_price) as total_price
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += ' GROUP BY date order by date'
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"}
)
response_data.append({
'date': str(i.date),
'token_count': i.token_count,
'total_price': i.total_price,
'currency': 'USD'
})
return jsonify({"data": response_data})
return jsonify({
'data': response_data
})
class AverageSessionInteractionStatistic(Resource):
@@ -245,8 +208,8 @@ class AverageSessionInteractionStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
@@ -255,30 +218,30 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
FROM conversations c
JOIN messages m ON c.id = m.conversation_id
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and c.created_at >= :start"
arg_dict["start"] = start_datetime_utc
sql_query += ' and c.created_at >= :start'
arg_dict['start'] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and c.created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' and c.created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += """
GROUP BY m.conversation_id) subquery
@@ -287,15 +250,18 @@ GROUP BY date
ORDER BY date"""
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
)
response_data.append({
'date': str(i.date),
'interactions': float(i.interactions.quantize(Decimal('0.01')))
})
return jsonify({"data": response_data})
return jsonify({
'data': response_data
})
class UserSatisfactionRateStatistic(Resource):
@@ -307,57 +273,57 @@ class UserSatisfactionRateStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args()
sql_query = """
sql_query = '''
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
FROM messages m
LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
WHERE m.app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and m.created_at >= :start"
arg_dict["start"] = start_datetime_utc
sql_query += ' and m.created_at >= :start'
arg_dict['start'] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and m.created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' and m.created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += ' GROUP BY date order by date'
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{
"date": str(i.date),
"rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
}
)
response_data.append({
'date': str(i.date),
'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
})
return jsonify({"data": response_data})
return jsonify({
'data': response_data
})
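
The satisfaction query reports likes per 1,000 messages, guarded against days with no traffic. The same arithmetic in plain Python:

def satisfaction_rate(feedback_count: int, message_count: int) -> float:
    # likes per 1,000 messages; 0 when the day had no messages
    return round((feedback_count * 1000 / message_count) if message_count > 0 else 0, 2)

print(satisfaction_rate(12, 480))  # 25.0
print(satisfaction_rate(0, 0))     # 0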
class AverageResponseTimeStatistic(Resource):
@@ -369,51 +335,56 @@ class AverageResponseTimeStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args()
sql_query = """
sql_query = '''
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(provider_response_latency) as latency
FROM messages
WHERE app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += ' GROUP BY date order by date'
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)})
response_data.append({
'date': str(i.date),
'latency': round(i.latency * 1000, 4)
})
return jsonify({"data": response_data})
return jsonify({
'data': response_data
})
class TokensPerSecondStatistic(Resource):
@@ -425,59 +396,63 @@ class TokensPerSecondStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
CASE
WHEN SUM(provider_response_latency) = 0 THEN 0
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
END as tokens_per_second
FROM messages
WHERE app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
WHERE app_id = :app_id'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += ' GROUP BY date order by date'
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})
response_data.append({
'date': str(i.date),
'tps': round(i.tokens_per_second, 4)
})
return jsonify({"data": response_data})
return jsonify({
'data': response_data
})
api.add_resource(DailyMessageStatistic, "/apps/<uuid:app_id>/statistics/daily-messages")
api.add_resource(DailyConversationStatistic, "/apps/<uuid:app_id>/statistics/daily-conversations")
api.add_resource(DailyTerminalsStatistic, "/apps/<uuid:app_id>/statistics/daily-end-users")
api.add_resource(DailyTokenCostStatistic, "/apps/<uuid:app_id>/statistics/token-costs")
api.add_resource(AverageSessionInteractionStatistic, "/apps/<uuid:app_id>/statistics/average-session-interactions")
api.add_resource(UserSatisfactionRateStatistic, "/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
api.add_resource(AverageResponseTimeStatistic, "/apps/<uuid:app_id>/statistics/average-response-time")
api.add_resource(TokensPerSecondStatistic, "/apps/<uuid:app_id>/statistics/tokens-per-second")
api.add_resource(DailyConversationStatistic, '/apps/<uuid:app_id>/statistics/daily-conversations')
api.add_resource(DailyTerminalsStatistic, '/apps/<uuid:app_id>/statistics/daily-end-users')
api.add_resource(DailyTokenCostStatistic, '/apps/<uuid:app_id>/statistics/token-costs')
api.add_resource(AverageSessionInteractionStatistic, '/apps/<uuid:app_id>/statistics/average-session-interactions')
api.add_resource(UserSatisfactionRateStatistic, '/apps/<uuid:app_id>/statistics/user-satisfaction-rate')
api.add_resource(AverageResponseTimeStatistic, '/apps/<uuid:app_id>/statistics/average-response-time')
api.add_resource(TokensPerSecondStatistic, '/apps/<uuid:app_id>/statistics/tokens-per-second')
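Every statistics endpoint in this file repeats one window-filter recipe: parse the naive "%Y-%m-%d %H:%M" argument, localize it to the account timezone with pytz, convert it to UTC, and append a parameterized bound to the SQL. A self-contained sketch of that shared shape, assuming only pytz and the standard library; the table and column names below are illustrative, not Dify's schema:

from datetime import datetime
from typing import Optional

import pytz


def build_window_query(tz_name: str, app_id: str, start: Optional[str], end: Optional[str]):
    """Return (sql, params) with optional UTC-normalized window bounds."""
    sql = ("SELECT date(created_at AT TIME ZONE :tz) AS date, count(*) AS n "
           "FROM messages WHERE app_id = :app_id")
    params = {"tz": tz_name, "app_id": app_id}
    local_tz = pytz.timezone(tz_name)

    def to_utc(value: str) -> datetime:
        # Parse the naive local time, pin it to the account timezone,
        # then convert to UTC so it compares against UTC-stored timestamps.
        naive = datetime.strptime(value, "%Y-%m-%d %H:%M").replace(second=0)
        return local_tz.localize(naive).astimezone(pytz.utc)

    if start:
        sql += " AND created_at >= :start"
        params["start"] = to_utc(start)
    if end:
        sql += " AND created_at < :end"
        params["end"] = to_utc(end)
    return sql + " GROUP BY date ORDER BY date", params


sql, params = build_window_query("Asia/Shanghai", "some-app-id", "2024-08-01 00:00", None)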

View File

@@ -64,51 +64,51 @@ class DraftWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
content_type = request.headers.get('Content-Type', '')
content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type:
if 'application/json' in content_type:
parser = reqparse.RequestParser()
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
parser.add_argument("hash", type=str, required=False, location="json")
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
parser.add_argument('hash', type=str, required=False, location='json')
# TODO: set this to required=True after frontend is updated
parser.add_argument("environment_variables", type=list, required=False, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json")
parser.add_argument('environment_variables', type=list, required=False, location='json')
parser.add_argument('conversation_variables', type=list, required=False, location='json')
args = parser.parse_args()
elif "text/plain" in content_type:
elif 'text/plain' in content_type:
try:
data = json.loads(request.data.decode("utf-8"))
if "graph" not in data or "features" not in data:
raise ValueError("graph or features not found in data")
data = json.loads(request.data.decode('utf-8'))
if 'graph' not in data or 'features' not in data:
raise ValueError('graph or features not found in data')
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
raise ValueError("graph or features is not a dict")
if not isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict):
raise ValueError('graph or features is not a dict')
args = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
'graph': data.get('graph'),
'features': data.get('features'),
'hash': data.get('hash'),
'environment_variables': data.get('environment_variables'),
'conversation_variables': data.get('conversation_variables'),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
return {'message': 'Invalid JSON data'}, 400
else:
abort(415)
workflow_service = WorkflowService()
try:
environment_variables_list = args.get("environment_variables") or []
environment_variables_list = args.get('environment_variables') or []
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
conversation_variables_list = args.get("conversation_variables") or []
conversation_variables_list = args.get('conversation_variables') or []
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args["graph"],
features=args["features"],
unique_hash=args.get("hash"),
graph=args['graph'],
features=args['features'],
unique_hash=args.get('hash'),
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
@@ -119,7 +119,7 @@ class DraftWorkflowApi(Resource):
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
}
@@ -138,11 +138,13 @@ class DraftWorkflowImportApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()
workflow = AppDslService.import_and_overwrite_workflow(
app_model=app_model, data=args["data"], account=current_user
app_model=app_model,
data=args['data'],
account=current_user
)
return workflow
@@ -160,17 +162,21 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
parser.add_argument("query", type=str, required=True, location="json", default="")
parser.add_argument("files", type=list, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument('inputs', type=dict, location='json')
parser.add_argument('query', type=str, required=True, location='json', default='')
parser.add_argument('files', type=list, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
args = parser.parse_args()
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True
)
return helper.compact_generate_response(response)
@@ -184,7 +190,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
logging.exception("internal server error.")
raise InternalServerError()
class AdvancedChatDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@@ -197,14 +202,18 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
parser.add_argument('inputs', type=dict, location='json')
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
app_model=app_model,
user=current_user,
node_id=node_id,
args=args,
streaming=True
)
return helper.compact_generate_response(response)
@@ -218,7 +227,6 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@@ -231,14 +239,18 @@ class WorkflowDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
parser.add_argument('inputs', type=dict, location='json')
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
app_model=app_model,
user=current_user,
node_id=node_id,
args=args,
streaming=True
)
return helper.compact_generate_response(response)
@@ -252,7 +264,6 @@ class WorkflowDraftRunIterationNodeApi(Resource):
logging.exception("internal server error.")
raise InternalServerError()
class DraftWorkflowRunApi(Resource):
@setup_required
@login_required
@@ -265,15 +276,19 @@ class DraftWorkflowRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
parser.add_argument('files', type=list, required=False, location='json')
args = parser.parse_args()
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True
)
return helper.compact_generate_response(response)
@@ -296,10 +311,12 @@ class WorkflowTaskStopApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
return {"result": "success"}
return {
"result": "success"
}
class DraftWorkflowNodeRunApi(Resource):
@@ -315,20 +332,24 @@ class DraftWorkflowNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
workflow_service = WorkflowService()
workflow_node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user
app_model=app_model,
node_id=node_id,
user_inputs=args.get('inputs'),
account=current_user
)
return workflow_node_execution
class PublishedWorkflowApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -341,7 +362,7 @@ class PublishedWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# fetch published workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app_model)
@@ -360,11 +381,14 @@ class PublishedWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
workflow_service = WorkflowService()
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
return {"result": "success", "created_at": TimestampField().format(workflow.created_at)}
return {
"result": "success",
"created_at": TimestampField().format(workflow.created_at)
}
class DefaultBlockConfigsApi(Resource):
@@ -379,7 +403,7 @@ class DefaultBlockConfigsApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# Get default block configs
workflow_service = WorkflowService()
return workflow_service.get_default_block_configs()
@@ -397,21 +421,24 @@ class DefaultBlockConfigApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args")
parser.add_argument('q', type=str, location='args')
args = parser.parse_args()
filters = None
if args.get("q"):
if args.get('q'):
try:
filters = json.loads(args.get("q"))
filters = json.loads(args.get('q'))
except json.JSONDecodeError:
raise ValueError("Invalid filters")
raise ValueError('Invalid filters')
# Get default block configs
workflow_service = WorkflowService()
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
return workflow_service.get_default_block_config(
node_type=block_type,
filters=filters
)
class ConvertToWorkflowApi(Resource):
@@ -428,43 +455,41 @@ class ConvertToWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if request.data:
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
args = parser.parse_args()
else:
args = {}
# convert to workflow mode
workflow_service = WorkflowService()
new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args)
new_app_model = workflow_service.convert_to_workflow(
app_model=app_model,
account=current_user,
args=args
)
# return app id
return {
"new_app_id": new_app_model.id,
'new_app_id': new_app_model.id,
}
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
api.add_resource(
AdvancedChatDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run"
)
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
api.add_resource(
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
)
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")
api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft')
api.add_resource(DraftWorkflowImportApi, '/apps/<uuid:app_id>/workflows/draft/import')
api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run')
api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop')
api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run')
api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run')
api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish')
api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs'
'/<string:block_type>')
api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow')
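DraftWorkflowApi above accepts the same draft payload under two content types: application/json goes through reqparse, while text/plain is decoded and validated by hand before use. A minimal Flask-only sketch of that branching, with the validation reduced to the two required keys; the route name is illustrative:

import json

from flask import Flask, abort, request

app = Flask(__name__)


@app.post("/draft")
def save_draft():
    content_type = request.headers.get("Content-Type", "")
    if "application/json" in content_type:
        args = request.get_json()  # reqparse handles this branch in the real code
    elif "text/plain" in content_type:
        try:
            data = json.loads(request.data.decode("utf-8"))
            if "graph" not in data or "features" not in data:
                raise ValueError("graph or features not found in data")
            args = {"graph": data["graph"], "features": data["features"]}
        except json.JSONDecodeError:
            return {"message": "Invalid JSON data"}, 400
    else:
        abort(415)  # any other Content-Type is unsupported
    return {"result": "success", "keys": sorted(args)}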

View File

@@ -22,19 +22,20 @@ class WorkflowAppLogApi(Resource):
Get workflow app logs
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
parser.add_argument('keyword', type=str, location='args')
parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args')
parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
args = parser.parse_args()
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
app_model=app_model, args=args
app_model=app_model,
args=args
)
return workflow_app_log_pagination
api.add_resource(WorkflowAppLogApi, "/apps/<uuid:app_id>/workflow-app-logs")
api.add_resource(WorkflowAppLogApi, '/apps/<uuid:app_id>/workflow-app-logs')

View File

@@ -28,12 +28,15 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
Get advanced chat app workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
app_model=app_model,
args=args
)
return result
@@ -49,12 +52,15 @@ class WorkflowRunListApi(Resource):
Get workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
result = workflow_run_service.get_paginate_workflow_runs(
app_model=app_model,
args=args
)
return result
@@ -92,10 +98,12 @@ class WorkflowRunNodeExecutionListApi(Resource):
workflow_run_service = WorkflowRunService()
node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
return {"data": node_executions}
return {
'data': node_executions
}
api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps/<uuid:app_id>/advanced-chat/workflow-runs")
api.add_resource(WorkflowRunListApi, "/apps/<uuid:app_id>/workflow-runs")
api.add_resource(WorkflowRunDetailApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
api.add_resource(WorkflowRunNodeExecutionListApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")
api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps/<uuid:app_id>/advanced-chat/workflow-runs')
api.add_resource(WorkflowRunListApi, '/apps/<uuid:app_id>/workflow-runs')
api.add_resource(WorkflowRunDetailApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>')
api.add_resource(WorkflowRunNodeExecutionListApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions')
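The two run-list endpoints above page with a last_id cursor plus a bounded limit rather than page numbers. A small sketch of those two arguments; uuid_value and int_range are Dify helpers, so the stand-ins here are assumptions, not the library code:

from uuid import UUID

from flask_restful import reqparse


def uuid_value(value):
    # Stand-in: normalize and validate the cursor; raises ValueError if malformed.
    return str(UUID(value))


def int_range(low, high):
    # Stand-in: bounds-check an integer query argument against [low, high].
    def check(value):
        value = int(value)
        if not low <= value <= high:
            raise ValueError(f"must be between {low} and {high}")
        return value
    return check


parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")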

View File

@@ -26,56 +26,56 @@ class WorkflowDailyRunsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args()
sql_query = """
sql_query = '''
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += ' GROUP BY date order by date'
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "runs": i.runs})
return jsonify({"data": response_data})
response_data.append({
'date': str(i.date),
'runs': i.runs
})
return jsonify({
'data': response_data
})
class WorkflowDailyTerminalsStatistic(Resource):
@setup_required
@@ -86,56 +86,56 @@ class WorkflowDailyTerminalsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args()
sql_query = """
sql_query = '''
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += ' GROUP BY date order by date'
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
return jsonify({"data": response_data})
response_data.append({
'date': str(i.date),
'terminal_count': i.terminal_count
})
return jsonify({
'data': response_data
})
class WorkflowDailyTokenCostStatistic(Resource):
@setup_required
@@ -146,63 +146,58 @@ class WorkflowDailyTokenCostStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args()
sql_query = """
sql_query = '''
SELECT
date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
SUM(workflow_runs.total_tokens) as token_count
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
'''
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
sql_query += ' and created_at >= :start'
arg_dict['start'] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += ' and created_at < :end'
arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date order by date"
sql_query += ' GROUP BY date order by date'
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{
"date": str(i.date),
"token_count": i.token_count,
}
)
return jsonify({"data": response_data})
response_data.append({
'date': str(i.date),
'token_count': i.token_count,
})
return jsonify({
'data': response_data
})
class WorkflowAverageAppInteractionStatistic(Resource):
@setup_required
@@ -213,8 +208,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args()
sql_query = """
@@ -234,54 +229,50 @@ class WorkflowAverageAppInteractionStatistic(Resource):
GROUP BY date, c.created_by) sub
GROUP BY sub.date
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
if args['start']:
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
arg_dict["start"] = start_datetime_utc
sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start')
arg_dict['start'] = start_datetime_utc
else:
sql_query = sql_query.replace("{{start}}", "")
sql_query = sql_query.replace('{{start}}', '')
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
if args['end']:
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
arg_dict["end"] = end_datetime_utc
sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end')
arg_dict['end'] = end_datetime_utc
else:
sql_query = sql_query.replace("{{end}}", "")
sql_query = sql_query.replace('{{end}}', '')
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
)
response_data.append({
'date': str(i.date),
'interactions': float(i.interactions.quantize(Decimal('0.01')))
})
return jsonify({"data": response_data})
return jsonify({
'data': response_data
})
api.add_resource(WorkflowDailyRunsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
api.add_resource(WorkflowDailyTerminalsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
api.add_resource(WorkflowDailyTokenCostStatistic, "/apps/<uuid:app_id>/workflow/statistics/token-costs")
api.add_resource(
WorkflowAverageAppInteractionStatistic, "/apps/<uuid:app_id>/workflow/statistics/average-app-interactions"
)
api.add_resource(WorkflowDailyRunsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-conversations')
api.add_resource(WorkflowDailyTerminalsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-terminals')
api.add_resource(WorkflowDailyTokenCostStatistic, '/apps/<uuid:app_id>/workflow/statistics/token-costs')
api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps/<uuid:app_id>/workflow/statistics/average-app-interactions')
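WorkflowAverageAppInteractionStatistic differs from its siblings: because the date filter must land inside the subquery rather than at the end of the statement, the query carries {{start}}/{{end}} markers that are replaced, or blanked, before binding. A stripped-down illustration of that substitution (SQL and values shortened):

sql = (
    "SELECT AVG(sub.n) AS interactions FROM ("
    " SELECT count(*) AS n FROM workflow_runs c"
    " WHERE c.app_id = :app_id {{start}} {{end}}"
    " GROUP BY c.created_by"
    ") sub"
)
params = {"app_id": "example-app-id"}

start_utc = None  # would hold a UTC datetime when the caller passes ?start=...
if start_utc is not None:
    sql = sql.replace("{{start}}", " AND c.created_at >= :start")
    params["start"] = start_utc
else:
    sql = sql.replace("{{start}}", "")  # blank the marker when unfiltered
sql = sql.replace("{{end}}", "")        # same treatment for the end bound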

View File

@@ -8,23 +8,24 @@ from libs.login import current_user
from models.model import App, AppMode
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
def get_app_model(view: Optional[Callable] = None, *,
mode: Union[AppMode, list[AppMode]] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
if not kwargs.get('app_id'):
raise ValueError('missing app_id in path parameters')
app_id = kwargs.get("app_id")
app_id = kwargs.get('app_id')
app_id = str(app_id)
del kwargs["app_id"]
del kwargs['app_id']
app_model = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
app_model = db.session.query(App).filter(
App.id == app_id,
App.tenant_id == current_user.current_tenant_id,
App.status == 'normal'
).first()
if not app_model:
raise AppNotFoundError()
@@ -43,10 +44,9 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
mode_values = {m.value for m in modes}
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
kwargs["app_model"] = app_model
kwargs['app_model'] = app_model
return view_func(*args, **kwargs)
return decorated_view
if view is None:
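The hunk above is the classic decorator-with-optional-arguments pattern: get_app_model works both bare and called with mode=... . A stripped-down, runnable sketch of that shape, with the App lookup replaced by a stub:

from functools import wraps
from typing import Callable, Optional


def get_app_model(view: Optional[Callable] = None, *, mode: Optional[str] = None):
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            app_id = kwargs.pop("app_id", None)
            if app_id is None:
                raise ValueError("missing app_id in path parameters")
            # The real decorator loads the App row here and checks it against mode.
            kwargs["app_model"] = f"app:{app_id}"
            return view_func(*args, **kwargs)
        return decorated_view

    if view is None:
        return decorator       # used as @get_app_model(mode=...)
    return decorator(view)     # used as bare @get_app_model


@get_app_model
def show(app_model):
    return app_model


print(show(app_id="123"))  # -> app:123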

View File

@@ -17,61 +17,60 @@ from services.account_service import RegisterService
class ActivateCheckApi(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args")
parser.add_argument("email", type=email, required=False, nullable=True, location="args")
parser.add_argument("token", type=str, required=True, nullable=False, location="args")
parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args')
parser.add_argument('email', type=email, required=False, nullable=True, location='args')
parser.add_argument('token', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
workspaceId = args["workspace_id"]
reg_email = args["email"]
token = args["token"]
workspaceId = args['workspace_id']
reg_email = args['email']
token = args['token']
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None}
return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None}
class ActivateApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"
)
parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json')
parser.add_argument('email', type=email, required=False, nullable=True, location='json')
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
parser.add_argument('interface_language', type=supported_language, required=True, nullable=False,
location='json')
parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
args = parser.parse_args()
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token'])
if invitation is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
account = invitation["account"]
account.name = args["name"]
account = invitation['account']
account.name = args['name']
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(args["password"], salt)
password_hashed = hash_password(args['password'], salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_theme = "light"
account.interface_language = args['interface_language']
account.timezone = args['timezone']
account.interface_theme = 'light'
account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
return {"result": "success"}
return {'result': 'success'}
api.add_resource(ActivateCheckApi, "/activate/check")
api.add_resource(ActivateApi, "/activate")
api.add_resource(ActivateCheckApi, '/activate/check')
api.add_resource(ActivateApi, '/activate')
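The activation handler salts and hashes the password before storing it, base64-encoding both values. A self-contained sketch of that flow; hashlib.pbkdf2_hmac stands in for Dify's hash_password helper, whose actual algorithm is not shown in this diff:

import base64
import hashlib
import secrets


def hash_with_new_salt(password: str) -> tuple[str, str]:
    salt = secrets.token_bytes(16)  # fresh per-account random salt
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, 100_000)
    # Both values are base64-encoded so they fit in text columns.
    return base64.b64encode(digest).decode(), base64.b64encode(salt).decode()


hashed, salt = hash_with_new_salt("correct horse battery staple")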

View File

@@ -19,19 +19,18 @@ class ApiKeyAuthDataSource(Resource):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
if data_source_api_key_bindings:
return {
"sources": [
{
"id": data_source_api_key_binding.id,
"category": data_source_api_key_binding.category,
"provider": data_source_api_key_binding.provider,
"disabled": data_source_api_key_binding.disabled,
"created_at": int(data_source_api_key_binding.created_at.timestamp()),
"updated_at": int(data_source_api_key_binding.updated_at.timestamp()),
}
for data_source_api_key_binding in data_source_api_key_bindings
]
'sources': [{
'id': data_source_api_key_binding.id,
'category': data_source_api_key_binding.category,
'provider': data_source_api_key_binding.provider,
'disabled': data_source_api_key_binding.disabled,
'created_at': int(data_source_api_key_binding.created_at.timestamp()),
'updated_at': int(data_source_api_key_binding.updated_at.timestamp()),
}
for data_source_api_key_binding in
data_source_api_key_bindings]
}
return {"sources": []}
return {'sources': []}
class ApiKeyAuthDataSourceBinding(Resource):
@@ -43,16 +42,16 @@ class ApiKeyAuthDataSourceBinding(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200
return {'result': 'success'}, 200
class ApiKeyAuthDataSourceBindingDelete(Resource):
@@ -66,9 +65,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
return {"result": "success"}, 200
return {'result': 'success'}, 200
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")
api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')

View File

@@ -17,13 +17,13 @@ from ..wraps import account_initialization_required
def get_oauth_providers():
with current_app.app_context():
notion_oauth = NotionOAuth(
client_id=dify_config.NOTION_CLIENT_ID,
client_secret=dify_config.NOTION_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion",
)
notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID,
client_secret=dify_config.NOTION_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion')
OAUTH_PROVIDERS = {"notion": notion_oauth}
OAUTH_PROVIDERS = {
'notion': notion_oauth
}
return OAUTH_PROVIDERS
@@ -37,16 +37,18 @@ class OAuthDataSource(Resource):
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
print(vars(oauth_provider))
if not oauth_provider:
return {"error": "Invalid provider"}, 400
if dify_config.NOTION_INTEGRATION_TYPE == "internal":
return {'error': 'Invalid provider'}, 400
if dify_config.NOTION_INTEGRATION_TYPE == 'internal':
internal_secret = dify_config.NOTION_INTERNAL_SECRET
if not internal_secret:
return ({"error": "Internal secret is not set"},)
return {'error': 'Internal secret is not set'},
oauth_provider.save_internal_access_token(internal_secret)
return {"data": ""}
return { 'data': '' }
else:
auth_url = oauth_provider.get_authorization_url()
return {"data": auth_url}, 200
return { 'data': auth_url }, 200
class OAuthDataSourceCallback(Resource):
@@ -55,18 +57,18 @@ class OAuthDataSourceCallback(Resource):
with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
if not oauth_provider:
return {"error": "Invalid provider"}, 400
if "code" in request.args:
code = request.args.get("code")
return {'error': 'Invalid provider'}, 400
if 'code' in request.args:
code = request.args.get('code')
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}")
elif "error" in request.args:
error = request.args.get("error")
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}')
elif 'error' in request.args:
error = request.args.get('error')
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}")
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}')
else:
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied")
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied')
class OAuthDataSourceBinding(Resource):
def get(self, provider: str):
@@ -74,18 +76,17 @@ class OAuthDataSourceBinding(Resource):
with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
if not oauth_provider:
return {"error": "Invalid provider"}, 400
if "code" in request.args:
code = request.args.get("code")
return {'error': 'Invalid provider'}, 400
if 'code' in request.args:
code = request.args.get('code')
try:
oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e:
logging.exception(
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}"
)
return {"error": "OAuth data source process failed"}, 400
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
return {'error': 'OAuth data source process failed'}, 400
return {"result": "success"}, 200
return {'result': 'success'}, 200
class OAuthDataSourceSync(Resource):
@@ -99,17 +100,18 @@ class OAuthDataSourceSync(Resource):
with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
if not oauth_provider:
return {"error": "Invalid provider"}, 400
return {'error': 'Invalid provider'}, 400
try:
oauth_provider.sync_data_source(binding_id)
except requests.exceptions.HTTPError as e:
logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
return {"error": "OAuth data source process failed"}, 400
logging.exception(
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
return {'error': 'OAuth data source process failed'}, 400
return {"result": "success"}, 200
return {'result': 'success'}, 200
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>')
api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
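OAuthDataSourceCallback above routes purely on request.args: a code is forwarded to the console URL, an error is forwarded as-is, and anything else becomes an Access denied redirect. A compact sketch of that branching; the console URL is a stand-in for dify_config.CONSOLE_WEB_URL:

from flask import Flask, redirect, request

app = Flask(__name__)
CONSOLE_WEB_URL = "https://console.example.com"  # stand-in value


@app.get("/oauth/data-source/callback/notion")
def notion_callback():
    if "code" in request.args:
        return redirect(f"{CONSOLE_WEB_URL}?type=notion&code={request.args['code']}")
    if "error" in request.args:
        return redirect(f"{CONSOLE_WEB_URL}?type=notion&error={request.args['error']}")
    return redirect(f"{CONSOLE_WEB_URL}?type=notion&error=Access denied")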

View File

@@ -2,30 +2,31 @@ from libs.exception import BaseHTTPException
class ApiKeyAuthFailedError(BaseHTTPException):
error_code = "auth_failed"
error_code = 'auth_failed'
description = "{message}"
code = 500
class InvalidEmailError(BaseHTTPException):
error_code = "invalid_email"
error_code = 'invalid_email'
description = "The email address is not valid."
code = 400
class PasswordMismatchError(BaseHTTPException):
error_code = "password_mismatch"
error_code = 'password_mismatch'
description = "The passwords do not match."
code = 400
class InvalidTokenError(BaseHTTPException):
error_code = "invalid_or_expired_token"
error_code = 'invalid_or_expired_token'
description = "The token is invalid or has expired."
code = 400
class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = "password_reset_rate_limit_exceeded"
error_code = 'password_reset_rate_limit_exceeded'
description = "Password reset rate limit exceeded. Try again later."
code = 429
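Each class in this module pairs a machine-readable error_code with a human description and an HTTP status. BaseHTTPException itself is not part of this diff, so the stand-in below is an assumption that only serves to make the pattern concrete:

class BaseHTTPException(Exception):
    error_code = "unknown"
    description = "An error occurred."
    code = 500


class InvalidEmailError(BaseHTTPException):
    error_code = "invalid_email"
    description = "The email address is not valid."
    code = 400


try:
    raise InvalidEmailError()
except BaseHTTPException as e:
    payload = {"code": e.error_code, "message": e.description, "status": e.code}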

View File

@@ -21,13 +21,14 @@ from services.errors.account import RateLimitExceededError
class ForgotPasswordSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument('email', type=str, required=True, location='json')
args = parser.parse_args()
email = args["email"]
email = args['email']
if not email_validate(email):
raise InvalidEmailError()
@@ -48,36 +49,38 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()
token = args["token"]
token = args['token']
reset_data = AccountService.get_reset_password_data(token)
if reset_data is None:
return {"is_valid": False, "email": None}
return {"is_valid": True, "email": reset_data.get("email")}
return {'is_valid': False, 'email': None}
return {'is_valid': True, 'email': reset_data.get('email')}
class ForgotPasswordResetApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
args = parser.parse_args()
new_password = args["new_password"]
password_confirm = args["password_confirm"]
new_password = args['new_password']
password_confirm = args['password_confirm']
if str(new_password).strip() != str(password_confirm).strip():
raise PasswordMismatchError()
token = args["token"]
token = args['token']
reset_data = AccountService.get_reset_password_data(token)
if reset_data is None:
@@ -91,14 +94,14 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get("email")).first()
account = Account.query.filter_by(email=reset_data.get('email')).first()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
return {"result": "success"}
return {'result': 'success'}
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')

View File

@@ -20,39 +20,37 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
parser.add_argument('email', type=email, required=True, location='json')
parser.add_argument('password', type=valid_password, required=True, location='json')
parser.add_argument('remember_me', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
# todo: Verify the recaptcha
try:
account = AccountService.authenticate(args["email"], args["password"])
account = AccountService.authenticate(args['email'], args['password'])
except services.errors.account.AccountLoginError as e:
return {"code": "unauthorized", "message": str(e)}, 401
return {'code': 'unauthorized', 'message': str(e)}, 401
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
return {
"result": "fail",
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
}
return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
token = AccountService.login(account, ip_address=get_remote_ip(request))
return {"result": "success", "data": token}
return {'result': 'success', 'data': token}
class LogoutApi(Resource):
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
token = request.headers.get("Authorization", "").split(" ")[1]
token = request.headers.get('Authorization', '').split(' ')[1]
AccountService.logout(account=account, token=token)
flask_login.logout_user()
return {"result": "success"}
return {'result': 'success'}
class ResetPasswordApi(Resource):
@@ -82,11 +80,11 @@ class ResetPasswordApi(Resource):
# 'subject': 'Reset your Dify password',
# 'html': """
# <p>Dear User,</p>
# <p>The Dify team has generated a new password for you, details as follows:</p>
# <p><strong>{new_password}</strong></p>
# <p>Please change your password to log in as soon as possible.</p>
# <p>Regards,</p>
# <p>The Dify Team</p>
# """
# }
@@ -103,8 +101,8 @@ class ResetPasswordApi(Resource):
# # handle error
# pass
return {"result": "success"}
return {'result': 'success'}
api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout")
api.add_resource(LoginApi, '/login')
api.add_resource(LogoutApi, '/logout')
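LogoutApi above pulls the bearer token with split(' ')[1], which raises IndexError when the Authorization header is missing or malformed. A defensive variant of the same extraction, illustrative rather than the code in this diff:

from typing import Optional


def extract_bearer_token(authorization: Optional[str]) -> Optional[str]:
    if not authorization:
        return None
    scheme, _, token = authorization.partition(" ")
    # Accept only a well-formed "Bearer <token>" pair.
    if scheme.lower() != "bearer" or not token:
        return None
    return token


assert extract_bearer_token("Bearer abc123") == "abc123"
assert extract_bearer_token(None) is None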

View File

@@ -25,7 +25,7 @@ def get_oauth_providers():
github_oauth = GitHubOAuth(
client_id=dify_config.GITHUB_CLIENT_ID,
client_secret=dify_config.GITHUB_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github",
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github',
)
if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
google_oauth = None
@@ -33,10 +33,10 @@ def get_oauth_providers():
google_oauth = GoogleOAuth(
client_id=dify_config.GOOGLE_CLIENT_ID,
client_secret=dify_config.GOOGLE_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google',
)
OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth}
OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth}
return OAUTH_PROVIDERS
@@ -47,7 +47,7 @@ class OAuthLogin(Resource):
oauth_provider = OAUTH_PROVIDERS.get(provider)
print(vars(oauth_provider))
if not oauth_provider:
return {"error": "Invalid provider"}, 400
return {'error': 'Invalid provider'}, 400
auth_url = oauth_provider.get_authorization_url()
return redirect(auth_url)
@@ -59,20 +59,20 @@ class OAuthCallback(Resource):
with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider)
if not oauth_provider:
return {"error": "Invalid provider"}, 400
return {'error': 'Invalid provider'}, 400
code = request.args.get("code")
code = request.args.get('code')
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.HTTPError as e:
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
return {"error": "OAuth process failed"}, 400
logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}')
return {'error': 'OAuth process failed'}, 400
account = _generate_account(provider, user_info)
# Check account status
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
return {"error": "Account is banned or closed."}, 403
return {'error': 'Account is banned or closed.'}, 403
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
@@ -83,7 +83,7 @@ class OAuthCallback(Resource):
token = AccountService.login(account, ip_address=get_remote_ip(request))
return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}')
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
if not account:
# Create account
account_name = user_info.name if user_info.name else "Dify"
account_name = user_info.name if user_info.name else 'Dify'
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
)
@@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
return account
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")
api.add_resource(OAuthLogin, '/oauth/login/<provider>')
api.add_resource(OAuthCallback, '/oauth/authorize/<provider>')

View File

@@ -9,24 +9,28 @@ from services.billing_service import BillingService
class Subscription(Resource):
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
args = parser.parse_args()
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
)
return BillingService.get_subscription(args['plan'],
args['interval'],
current_user.email,
current_user.current_tenant_id)
class Invoices(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -36,5 +40,5 @@ class Invoices(Resource):
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
api.add_resource(Subscription, "/billing/subscription")
api.add_resource(Invoices, "/billing/invoices")
api.add_resource(Subscription, '/billing/subscription')
api.add_resource(Invoices, '/billing/invoices')
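The Subscription resource leans on reqparse to validate query-string values; a self-contained sketch of that choices-based validation, where anything outside the allowed set is rejected with a 400 (the stub response body is illustrative):

from flask import Flask
from flask_restful import Api, Resource, reqparse

app = Flask(__name__)
api = Api(app)

class Subscription(Resource):
    def get(self):
        parser = reqparse.RequestParser()
        # choices= rejects any value outside the allowed set automatically.
        parser.add_argument("plan", type=str, required=True, location="args",
                            choices=["professional", "team"])
        parser.add_argument("interval", type=str, required=True, location="args",
                            choices=["month", "year"])
        args = parser.parse_args()
        return {"plan": args["plan"], "interval": args["interval"]}, 200

api.add_resource(Subscription, "/billing/subscription")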


@@ -22,22 +22,19 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task
class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
# get workspace data source integrates
data_source_integrates = (
db.session.query(DataSourceOauthBinding)
.filter(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False,
)
.all()
)
data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False
).all()
base_url = request.url_root.rstrip("/")
base_url = request.url_root.rstrip('/')
data_source_oauth_base_path = "/console/api/oauth/data-source"
providers = ["notion"]
@@ -47,30 +44,26 @@ class DataSourceApi(Resource):
existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates)
if existing_integrates:
for existing_integrate in list(existing_integrates):
integrate_data.append(
{
"id": existing_integrate.id,
"provider": provider,
"created_at": existing_integrate.created_at,
"is_bound": True,
"disabled": existing_integrate.disabled,
"source_info": existing_integrate.source_info,
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
}
)
integrate_data.append({
'id': existing_integrate.id,
'provider': provider,
'created_at': existing_integrate.created_at,
'is_bound': True,
'disabled': existing_integrate.disabled,
'source_info': existing_integrate.source_info,
'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
})
else:
integrate_data.append(
{
"id": None,
"provider": provider,
"created_at": None,
"source_info": None,
"is_bound": False,
"disabled": None,
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
}
)
return {"data": integrate_data}, 200
integrate_data.append({
'id': None,
'provider': provider,
'created_at': None,
'source_info': None,
'is_bound': False,
'disabled': None,
'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
})
return {'data': integrate_data}, 200
@setup_required
@login_required
@@ -78,82 +71,92 @@ class DataSourceApi(Resource):
def patch(self, binding_id, action):
binding_id = str(binding_id)
action = str(action)
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
data_source_binding = DataSourceOauthBinding.query.filter_by(
id=binding_id
).first()
if data_source_binding is None:
raise NotFound("Data source binding not found.")
raise NotFound('Data source binding not found.')
# enable binding
if action == "enable":
if action == 'enable':
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is not disabled.")
raise ValueError('Data source is not disabled.')
# disable binding
if action == "disable":
if action == 'disable':
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is disabled.")
return {"result": "success"}, 200
raise ValueError('Data source is disabled.')
return {'result': 'success'}, 200
class DataSourceNotionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_notion_info_list_fields)
def get(self):
dataset_id = request.args.get("dataset_id", default=None, type=str)
dataset_id = request.args.get('dataset_id', default=None, type=str)
exist_page_ids = []
# import notion in the exist dataset
if dataset_id:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
raise ValueError("Dataset is not notion type.")
raise NotFound('Dataset not found.')
if dataset.data_source_type != 'notion_import':
raise ValueError('Dataset is not notion type.')
documents = Document.query.filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
data_source_type='notion_import',
enabled=True
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
exist_page_ids.append(data_source_info['notion_page_id'])
# get all authorized pages
data_source_bindings = DataSourceOauthBinding.query.filter_by(
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
tenant_id=current_user.current_tenant_id,
provider='notion',
disabled=False
).all()
if not data_source_bindings:
return {"notion_info": []}, 200
return {
'notion_info': []
}, 200
pre_import_info_list = []
for data_source_binding in data_source_bindings:
source_info = data_source_binding.source_info
pages = source_info["pages"]
pages = source_info['pages']
# Filter out already bound pages
for page in pages:
if page["page_id"] in exist_page_ids:
page["is_bound"] = True
if page['page_id'] in exist_page_ids:
page['is_bound'] = True
else:
page["is_bound"] = False
page['is_bound'] = False
pre_import_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
'workspace_name': source_info['workspace_name'],
'workspace_icon': source_info['workspace_icon'],
'workspace_id': source_info['workspace_id'],
'pages': pages,
}
pre_import_info_list.append(pre_import_info)
return {"notion_info": pre_import_info_list}, 200
return {
'notion_info': pre_import_info_list
}, 200
class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -163,67 +166,64 @@ class DataSourceNotionApi(Resource):
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
raise NotFound("Data source binding not found.")
raise NotFound('Data source binding not found.')
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=data_source_binding.access_token,
tenant_id=current_user.current_tenant_id,
tenant_id=current_user.current_tenant_id
)
text_docs = extractor.extract()
return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200
return {
'content': "\n".join([doc.page_content for doc in text_docs])
}, 200
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
notion_info_list = args["notion_info_list"]
notion_info_list = args['notion_info_list']
extract_settings = []
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
for page in notion_info["pages"]:
workspace_id = notion_info['workspace_id']
for page in notion_info['pages']:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
"notion_obj_id": page['page_id'],
"notion_page_type": page['type'],
"tenant_id": current_user.current_tenant_id
},
document_model=args["doc_form"],
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
extract_settings,
args["process_rule"],
args["doc_form"],
args["doc_language"],
)
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
args['process_rule'], args['doc_form'],
args['doc_language'])
return response, 200
class DataSourceNotionDatasetSyncApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -240,6 +240,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
class DataSourceNotionDocumentSyncApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -257,14 +258,10 @@ class DataSourceNotionDocumentSyncApi(Resource):
return 200
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
api.add_resource(
DataSourceNotionApi,
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
api.add_resource(
DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
)
api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
api.add_resource(DataSourceNotionApi,
'/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview',
'/datasets/notion-indexing-estimate')
api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync')
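The pre-import listing marks each authorized Notion page as bound when a document already references it; the marking step reduced to plain dictionaries (the sample IDs and workspace values are made up):

exist_page_ids = {"p1", "p3"}  # pages already imported as documents
source_info = {
    "workspace_name": "Acme",
    "workspace_icon": None,
    "workspace_id": "w1",
    "pages": [{"page_id": "p1"}, {"page_id": "p2"}],
}

for page in source_info["pages"]:
    # A page is "bound" when a document already points at it.
    page["is_bound"] = page["page_id"] in exist_page_ids

pre_import_info = {
    "workspace_name": source_info["workspace_name"],
    "workspace_icon": source_info["workspace_icon"],
    "workspace_id": source_info["workspace_id"],
    "pages": source_info["pages"],
}
print(pre_import_info)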


@@ -31,40 +31,45 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
raise ValueError('Name must be between 1 to 40 characters.')
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
raise ValueError('Description cannot exceed 400 characters.')
return description
class DatasetListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
ids = request.args.getlist('ids')
provider = request.args.get('provider', default="vendor")
search = request.args.get('keyword', default=None, type=str)
tag_ids = request.args.getlist('tag_ids')
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids
)
datasets, total = DatasetService.get_datasets(page, limit, provider,
current_user.current_tenant_id, current_user, search, tag_ids)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)
model_names = []
for embedding_model in embedding_models:
@@ -72,22 +77,28 @@ class DatasetListApi(Resource):
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item["indexing_technique"] == "high_quality":
if item['indexing_technique'] == 'high_quality':
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item["embedding_available"] = True
item['embedding_available'] = True
else:
item["embedding_available"] = False
item['embedding_available'] = False
else:
item["embedding_available"] = True
item['embedding_available'] = True
if item.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"])
item.update({"partial_member_list": part_users_list})
if item.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
item.update({'partial_member_list': part_users_list})
else:
item.update({"partial_member_list": []})
item.update({'partial_member_list': []})
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
response = {
'data': data,
'has_more': len(datasets) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200
@setup_required
@@ -95,21 +106,13 @@ class DatasetListApi(Resource):
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument('name', nullable=False, required=True,
help='type is required. Name must be between 1 to 40 characters.',
type=_validate_name)
parser.add_argument('indexing_technique', type=str, location='json',
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help='Invalid indexing technique.')
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -119,10 +122,9 @@ class DatasetListApi(Resource):
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
name=args["name"],
indexing_technique=args["indexing_technique"],
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
name=args['name'],
indexing_technique=args['indexing_technique'],
account=current_user
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@@ -140,36 +142,42 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
DatasetService.check_dataset_permission(
dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get("permission") == "partial_members":
if data.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
data.update({'partial_member_list': part_users_list})
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)
model_names = []
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data["indexing_technique"] == "high_quality":
if data['indexing_technique'] == 'high_quality':
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data["embedding_available"] = True
data['embedding_available'] = True
else:
data["embedding_available"] = False
data['embedding_available'] = False
else:
data["embedding_available"] = True
data['embedding_available'] = True
if data.get("permission") == "partial_members":
if data.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
data.update({'partial_member_list': part_users_list})
return data, 200
@@ -183,49 +191,42 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.")
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
parser.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
parser.add_argument('name', nullable=False,
help='type is required. Name must be between 1 to 40 characters.',
type=_validate_name)
parser.add_argument('description',
location='json', store_missing=False,
type=_validate_description_length)
parser.add_argument('indexing_technique', type=str, location='json',
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.'
)
parser.add_argument('embedding_model', type=str,
location='json', help='Invalid embedding model.')
parser.add_argument('embedding_model_provider', type=str,
location='json', help='Invalid embedding model provider.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
args = parser.parse_args()
data = request.get_json()
# check embedding model setting
if data.get("indexing_technique") == "high_quality":
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
if data.get('indexing_technique') == 'high_quality':
DatasetService.check_embedding_model_setting(dataset.tenant_id,
data.get('embedding_model_provider'),
data.get('embedding_model')
)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
current_user, dataset, data.get('permission'), data.get('partial_member_list')
)
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
dataset = DatasetService.update_dataset(
dataset_id_str, args, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
@@ -233,19 +234,16 @@ class DatasetApi(Resource):
result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
if data.get('partial_member_list') and data.get('permission') == 'partial_members':
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
tenant_id, dataset_id_str, data.get('partial_member_list')
)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
elif data.get('permission') == DatasetPermissionEnum.ONLY_ME or data.get('permission') == DatasetPermissionEnum.ALL_TEAM:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list})
result_data.update({'partial_member_list': partial_member_list})
return result_data, 200
@@ -262,13 +260,12 @@ class DatasetApi(Resource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {"result": "success"}, 204
return {'result': 'success'}, 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:
raise DatasetInUseError()
class DatasetUseCheckApi(Resource):
@setup_required
@login_required
@@ -277,10 +274,10 @@ class DatasetUseCheckApi(Resource):
dataset_id_str = str(dataset_id)
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
return {"is_using": dataset_is_using}, 200
return {'is_using': dataset_is_using}, 200
class DatasetQueryApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -295,53 +292,51 @@ class DatasetQueryApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
dataset_queries, total = DatasetService.get_dataset_queries(
dataset_id=dataset.id,
page=page,
per_page=limit
)
response = {
"data": marshal(dataset_queries, dataset_query_detail_fields),
"has_more": len(dataset_queries) == limit,
"limit": limit,
"total": total,
"page": page,
'data': marshal(dataset_queries, dataset_query_detail_fields),
'has_more': len(dataset_queries) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200
class DatasetIndexingEstimateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
parser.add_argument(
"indexing_technique",
type=str,
required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
location="json",
)
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('indexing_technique', type=str, required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
.all()
)
if args['info_list']['data_source_type'] == 'upload_file':
file_ids = args['info_list']['file_info_list']['file_ids']
file_details = db.session.query(UploadFile).filter(
UploadFile.tenant_id == current_user.current_tenant_id,
UploadFile.id.in_(file_ids)
).all()
if file_details is None:
raise NotFound("File not found.")
@@ -349,58 +344,55 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
datasource_type="upload_file",
upload_file=file_detail,
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "notion_import":
notion_info_list = args["info_list"]["notion_info_list"]
elif args['info_list']['data_source_type'] == 'notion_import':
notion_info_list = args['info_list']['notion_info_list']
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
for page in notion_info["pages"]:
workspace_id = notion_info['workspace_id']
for page in notion_info['pages']:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
"notion_obj_id": page['page_id'],
"notion_page_type": page['type'],
"tenant_id": current_user.current_tenant_id
},
document_model=args["doc_form"],
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "website_crawl":
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
elif args['info_list']['data_source_type'] == 'website_crawl':
website_info_list = args['info_list']['website_info_list']
for url in website_info_list['urls']:
extract_setting = ExtractSetting(
datasource_type="website_crawl",
website_info={
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"provider": website_info_list['provider'],
"job_id": website_info_list['job_id'],
"url": url,
"tenant_id": current_user.current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
"mode": 'crawl',
"only_main_content": website_info_list['only_main_content']
},
document_model=args["doc_form"],
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
else:
raise ValueError("Data source type not support")
raise ValueError('Data source type not support')
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
extract_settings,
args["process_rule"],
args["doc_form"],
args["doc_language"],
args["dataset_id"],
args["indexing_technique"],
)
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'],
args['indexing_technique'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
)
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except Exception as e:
@@ -410,6 +402,7 @@ class DatasetIndexingEstimateApi(Resource):
class DatasetRelatedAppListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -433,52 +426,52 @@ class DatasetRelatedAppListApi(Resource):
if app_model:
related_apps.append(app_model)
return {"data": related_apps, "total": len(related_apps)}, 200
return {
'data': related_apps,
'total': len(related_apps)
}, 200
class DatasetIndexingStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id = str(dataset_id)
documents = (
db.session.query(Document)
.filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
.all()
)
documents = db.session.query(Document).filter(
Document.dataset_id == dataset_id,
Document.tenant_id == current_user.current_tenant_id
).all()
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != 're_segment').count()
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
DocumentSegment.status != 're_segment').count()
document.completed_segments = completed_segments
document.total_segments = total_segments
documents_status.append(marshal(document, document_status_fields))
data = {"data": documents_status}
data = {
'data': documents_status
}
return data
class DatasetApiKeyApi(Resource):
max_keys = 10
token_prefix = "dataset-"
resource_type = "dataset"
token_prefix = 'dataset-'
resource_type = 'dataset'
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
keys = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.all()
)
keys = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
all()
return {"items": keys}
@setup_required
@@ -490,17 +483,15 @@ class DatasetApiKeyApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
current_key_count = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.count()
)
current_key_count = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
count()
if current_key_count >= self.max_keys:
flask_restful.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
code='max_keys_exceeded'
)
key = ApiToken.generate_api_key(self.token_prefix, 24)
@@ -514,7 +505,7 @@ class DatasetApiKeyApi(Resource):
class DatasetApiDeleteApi(Resource):
resource_type = "dataset"
resource_type = 'dataset'
@setup_required
@login_required
@@ -526,23 +517,18 @@ class DatasetApiDeleteApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
key = (
db.session.query(ApiToken)
.filter(
ApiToken.tenant_id == current_user.current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
)
key = db.session.query(ApiToken). \
filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id). \
first()
if key is None:
flask_restful.abort(404, message="API key not found")
flask_restful.abort(404, message='API key not found')
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit()
return {"result": "success"}, 204
return {'result': 'success'}, 204
class DatasetApiBaseUrlApi(Resource):
@@ -551,10 +537,8 @@ class DatasetApiBaseUrlApi(Resource):
@account_initialization_required
def get(self):
return {
"api_base_url": (
dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
)
+ "/v1"
'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL
else request.host_url.rstrip('/')) + '/v1'
}
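API key creation above enforces a per-tenant cap before minting a prefixed token; the same guard as a standalone function. Note that secrets.token_urlsafe stands in for the app's ApiToken.generate_api_key, whose implementation is not shown in this diff, so the token format here is an assumption:

import secrets

MAX_KEYS = 10
TOKEN_PREFIX = "dataset-"

def create_api_key(existing_keys: list) -> str:
    # Mirrors the max_keys guard above before any token is minted.
    if len(existing_keys) >= MAX_KEYS:
        raise ValueError(
            f"Cannot create more than {MAX_KEYS} API keys for this resource type."
        )
    # secrets.token_urlsafe(24) is an illustrative stand-in for
    # ApiToken.generate_api_key(self.token_prefix, 24).
    return TOKEN_PREFIX + secrets.token_urlsafe(24)

print(create_api_key([]))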
@@ -565,26 +549,15 @@ class DatasetRetrievalSettingApi(Resource):
def get(self):
vector_type = dify_config.VECTOR_STORE
match vector_type:
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.PGVECTOR
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
):
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
return {
"retrieval_method": [
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
@@ -600,27 +573,15 @@ class DatasetRetrievalSettingMockApi(Resource):
@account_initialization_required
def get(self, vector_type):
match vector_type:
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.PGVECTOR
):
case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS:
return {
"retrieval_method": [
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
@@ -630,6 +591,7 @@ class DatasetRetrievalSettingMockApi(Resource):
raise ValueError(f"Unsupported vector db type {vector_type}.")
class DatasetErrorDocs(Resource):
@setup_required
@login_required
@@ -641,7 +603,10 @@ class DatasetErrorDocs(Resource):
raise NotFound("Dataset not found.")
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
return {
'data': [marshal(item, document_status_fields) for item in results],
'total': len(results)
}, 200
class DatasetPermissionUserListApi(Resource):
@@ -661,21 +626,21 @@ class DatasetPermissionUserListApi(Resource):
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
return {
"data": partial_members_list,
'data': partial_members_list,
}, 200
api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries")
api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs")
api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate")
api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps")
api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status")
api.add_resource(DatasetApiKeyApi, "/datasets/api-keys")
api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>")
api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
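Both retrieval-setting endpoints pick the supported methods by pattern-matching on the vector store; a trimmed sketch of that dispatch (only three enum members shown, and the string values are illustrative):

from enum import Enum

class VectorType(str, Enum):
    MILVUS = "milvus"
    QDRANT = "qdrant"
    WEAVIATE = "weaviate"

def retrieval_methods(vector_type: VectorType) -> list:
    # Python 3.10+ structural pattern matching, as in the endpoints above.
    match vector_type:
        case VectorType.MILVUS:
            # Stores without full-text indexes only offer semantic search.
            return ["semantic_search"]
        case VectorType.QDRANT | VectorType.WEAVIATE:
            # `|` joins alternative patterns in a single case arm.
            return ["semantic_search", "full_text_search", "hybrid_search"]
        case _:
            raise ValueError(f"Unsupported vector db type {vector_type}.")

print(retrieval_methods(VectorType.QDRANT))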

File diff suppressed because it is too large.


@@ -40,7 +40,7 @@ class DatasetDocumentSegmentListApi(Resource):
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
raise NotFound('Dataset not found.')
try:
DatasetService.check_dataset_permission(dataset, current_user)
@@ -50,33 +50,37 @@ class DatasetDocumentSegmentListApi(Resource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
raise NotFound('Document not found.')
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=str, default=None, location="args")
parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("status", type=str, action="append", default=[], location="args")
parser.add_argument("hit_count_gte", type=int, default=None, location="args")
parser.add_argument("enabled", type=str, default="all", location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument('last_id', type=str, default=None, location='args')
parser.add_argument('limit', type=int, default=20, location='args')
parser.add_argument('status', type=str,
action='append', default=[], location='args')
parser.add_argument('hit_count_gte', type=int,
default=None, location='args')
parser.add_argument('enabled', type=str, default='all', location='args')
parser.add_argument('keyword', type=str, default=None, location='args')
args = parser.parse_args()
last_id = args["last_id"]
limit = min(args["limit"], 100)
status_list = args["status"]
hit_count_gte = args["hit_count_gte"]
keyword = args["keyword"]
last_id = args['last_id']
limit = min(args['limit'], 100)
status_list = args['status']
hit_count_gte = args['hit_count_gte']
keyword = args['keyword']
query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
)
if last_id is not None:
last_segment = db.session.get(DocumentSegment, str(last_id))
if last_segment:
query = query.filter(DocumentSegment.position > last_segment.position)
query = query.filter(
DocumentSegment.position > last_segment.position)
else:
return {"data": [], "has_more": False, "limit": limit}, 200
return {'data': [], 'has_more': False, 'limit': limit}, 200
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
@@ -85,12 +89,12 @@ class DatasetDocumentSegmentListApi(Resource):
query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
if args["enabled"].lower() != "all":
if args["enabled"].lower() == "true":
if args['enabled'].lower() != 'all':
if args['enabled'].lower() == 'true':
query = query.filter(DocumentSegment.enabled == True)
elif args["enabled"].lower() == "false":
elif args['enabled'].lower() == 'false':
query = query.filter(DocumentSegment.enabled == False)
total = query.count()
@@ -102,11 +106,11 @@ class DatasetDocumentSegmentListApi(Resource):
segments = segments[:-1]
return {
"data": marshal(segments, segment_fields),
"doc_form": document.doc_form,
"has_more": has_more,
"limit": limit,
"total": total,
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form,
'has_more': has_more,
'limit': limit,
'total': total
}, 200
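The segment listing above paginates by keyset rather than offset: it filters on position greater than the last seen row, fetches limit + 1 rows, and uses the extra row only as a has_more signal before trimming it off. The same idea over a plain list (a stand-in for the SQLAlchemy query):

def paginate_keyset(rows, last_position=None, limit=20):
    # rows are assumed sorted by position ascending, like the segment query.
    if last_position is not None:
        rows = [r for r in rows if r["position"] > last_position]
    window = rows[: limit + 1]       # fetch one extra row
    has_more = len(window) > limit   # its presence means another page exists
    return window[:limit], has_more

segments = [{"position": i} for i in range(1, 46)]
page, has_more = paginate_keyset(segments, last_position=40, limit=20)
print(len(page), has_more)  # -> 5 False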
@@ -114,12 +118,12 @@ class DatasetDocumentSegmentApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_resource_check('vector_space')
def patch(self, dataset_id, segment_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin, owner, or editor
@@ -130,7 +134,7 @@ class DatasetDocumentSegmentApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
model_manager = ModelManager()
@@ -138,32 +142,32 @@ class DatasetDocumentSegmentApi(Resource):
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
raise NotFound('Segment not found.')
if segment.status != "completed":
raise NotFound("Segment is not completed, enable or disable function is not allowed")
if segment.status != 'completed':
raise NotFound('Segment is not completed, enable or disable function is not allowed')
document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
cache_result = redis_client.get(document_indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later")
indexing_cache_key = "segment_{}_indexing".format(segment.id)
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Segment is being indexed, please try again later")
@@ -182,7 +186,7 @@ class DatasetDocumentSegmentApi(Resource):
enable_segment_to_index_task.delay(segment.id)
return {"result": "success"}, 200
return {'result': 'success'}, 200
elif action == "disable":
if not segment.enabled:
raise InvalidActionError("Segment is already disabled.")
@@ -197,7 +201,7 @@ class DatasetDocumentSegmentApi(Resource):
disable_segment_from_index_task.delay(segment.id)
return {"result": "success"}, 200
return {'result': 'success'}, 200
else:
raise InvalidActionError()
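The enable/disable handler refuses to touch a segment while an indexing flag is set in Redis; that guard in isolation (the key formats are copied from the handler above, while the local Redis connection is an assumption — the app wires in its own client):

import redis

redis_client = redis.Redis()  # assumes a reachable local Redis

def ensure_not_indexing(document_id: str, segment_id: str) -> None:
    # A non-None value under either key means a job currently holds the flag.
    if redis_client.get("document_{}_indexing".format(document_id)) is not None:
        raise RuntimeError("Document is being indexed, please try again later")
    if redis_client.get("segment_{}_indexing".format(segment_id)) is not None:
        raise RuntimeError("Segment is being indexed, please try again later")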
@@ -206,36 +210,35 @@ class DatasetDocumentSegmentAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_resource_check('vector_space')
@cloud_edition_billing_knowledge_limit_check('add_segment')
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
raise NotFound('Document not found.')
if not current_user.is_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == 'high_quality':
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
@@ -244,34 +247,37 @@ class DatasetDocumentSegmentAddApi(Resource):
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_resource_check('vector_space')
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if dataset.indexing_technique == "high_quality":
raise NotFound('Document not found.')
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
model_manager = ModelManager()
@@ -279,22 +285,22 @@ class DatasetDocumentSegmentUpdateApi(Resource):
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
@@ -304,13 +310,16 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
@setup_required
@login_required
@@ -320,21 +329,22 @@ class DatasetDocumentSegmentUpdateApi(Resource):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
raise NotFound('Document not found.')
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
@@ -343,36 +353,36 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segment(segment, document, dataset)
return {"result": "success"}, 200
return {'result': 'success'}, 200
class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_resource_check('vector_space')
@cloud_edition_billing_knowledge_limit_check('add_segment')
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
raise NotFound('Document not found.')
# get file from request
file = request.files["file"]
file = request.files['file']
# check file
if "file" not in request.files:
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith(".csv"):
if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed")
try:
@@ -380,47 +390,51 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
df = pd.read_csv(file)
result = []
for index, row in df.iterrows():
if document.doc_form == "qa_model":
data = {"content": row[0], "answer": row[1]}
if document.doc_form == 'qa_model':
data = {'content': row[0], 'answer': row[1]}
else:
data = {"content": row[0]}
data = {'content': row[0]}
result.append(data)
if len(result) == 0:
raise ValueError("The CSV file is empty.")
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = "segment_batch_import_{}".format(str(job_id))
indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
# send batch add segments task
redis_client.setnx(indexing_cache_key, "waiting")
batch_create_segment_to_index_task.delay(
str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
)
redis_client.setnx(indexing_cache_key, 'waiting')
batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
current_user.current_tenant_id, current_user.id)
except Exception as e:
return {"error": str(e)}, 500
return {"job_id": job_id, "job_status": "waiting"}, 200
return {'error': str(e)}, 500
return {
'job_id': job_id,
'job_status': 'waiting'
}, 200
@setup_required
@login_required
@account_initialization_required
def get(self, job_id):
job_id = str(job_id)
indexing_cache_key = "segment_batch_import_{}".format(job_id)
indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
raise ValueError("The job is not exist.")
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
return {
'job_id': job_id,
'job_status': cache_result.decode()
}, 200
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>")
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
api.add_resource(
DatasetDocumentSegmentUpdateApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
)
api.add_resource(
DatasetDocumentSegmentBatchImportApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
"/datasets/batch_import_status/<uuid:job_id>",
)
api.add_resource(DatasetDocumentSegmentListApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
'/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
api.add_resource(DatasetDocumentSegmentAddApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
api.add_resource(DatasetDocumentSegmentBatchImportApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
'/datasets/batch_import_status/<uuid:job_id>')
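The batch importer turns each CSV row into a segment payload before queueing the Celery task; the parsing half on inline data (the two sample rows are made up, and header=None is an explicit choice here so columns are addressed by position):

import io

import pandas as pd

csv_file = io.StringIO(
    "What is Dify?,An LLM app platform\n"
    "What is RAG?,Retrieval-augmented generation\n"
)
df = pd.read_csv(csv_file, header=None)

doc_form = "qa_model"
result = []
for _, row in df.iterrows():
    # qa_model rows carry a question and an answer column; text rows only content.
    if doc_form == "qa_model":
        result.append({"content": row[0], "answer": row[1]})
    else:
        result.append({"content": row[0]})

if len(result) == 0:
    raise ValueError("The CSV file is empty.")
print(result)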


@@ -2,90 +2,90 @@ from libs.exception import BaseHTTPException
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
error_code = 'no_file_uploaded'
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = "file_too_large"
error_code = 'file_too_large'
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
error_code = 'unsupported_file_type'
description = "File type not allowed."
code = 415
class HighQualityDatasetOnlyError(BaseHTTPException):
error_code = "high_quality_dataset_only"
error_code = 'high_quality_dataset_only'
description = "Current operation only supports 'high-quality' datasets."
code = 400
class DatasetNotInitializedError(BaseHTTPException):
error_code = "dataset_not_initialized"
error_code = 'dataset_not_initialized'
description = "The dataset is still being initialized or indexing. Please wait a moment."
code = 400
class ArchivedDocumentImmutableError(BaseHTTPException):
error_code = "archived_document_immutable"
error_code = 'archived_document_immutable'
description = "The archived document is not editable."
code = 403
class DatasetNameDuplicateError(BaseHTTPException):
error_code = "dataset_name_duplicate"
error_code = 'dataset_name_duplicate'
description = "The dataset name already exists. Please modify your dataset name."
code = 409
class InvalidActionError(BaseHTTPException):
error_code = "invalid_action"
error_code = 'invalid_action'
description = "Invalid action."
code = 400
class DocumentAlreadyFinishedError(BaseHTTPException):
error_code = "document_already_finished"
error_code = 'document_already_finished'
description = "The document has been processed. Please refresh the page or go to the document details."
code = 400
class DocumentIndexingError(BaseHTTPException):
error_code = "document_indexing"
error_code = 'document_indexing'
description = "The document is being processed and cannot be edited."
code = 400
class InvalidMetadataError(BaseHTTPException):
error_code = "invalid_metadata"
error_code = 'invalid_metadata'
description = "The metadata content is incorrect. Please check and verify."
code = 400
class WebsiteCrawlError(BaseHTTPException):
error_code = "crawl_failed"
error_code = 'crawl_failed'
description = "{message}"
code = 500
class DatasetInUseError(BaseHTTPException):
error_code = "dataset_in_use"
error_code = 'dataset_in_use'
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
code = 409
class IndexingEstimateError(BaseHTTPException):
error_code = "indexing_estimate_error"
error_code = 'indexing_estimate_error'
description = "Knowledge indexing estimate failed: {message}"
code = 500
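
Editor's note: every class in this file follows one shape — three class attributes that a handler turns into an HTTP error response. A guess at the contract these classes rely on (not Dify's actual libs.exception, whose details may differ):

class BaseHTTPException(Exception):
    error_code = "unknown"
    description = "Unknown error."
    code = 500

    def to_response(self) -> tuple:
        # A handler would serialize the class attributes into the body.
        return {"code": self.error_code, "message": self.description}, self.code

class DatasetInUseError(BaseHTTPException):
    error_code = "dataset_in_use"
    description = "The dataset is being used by some apps."
    code = 409

body, status = DatasetInUseError().to_response()
print(status, body)   # 409 {'code': 'dataset_in_use', 'message': '...'}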

View File

@@ -21,6 +21,7 @@ PREVIEW_WORDS_LIMIT = 3000
class FileApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -30,22 +31,23 @@ class FileApi(Resource):
batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
return {
"file_size_limit": file_size_limit,
"batch_count_limit": batch_count_limit,
"image_file_size_limit": image_file_size_limit,
'file_size_limit': file_size_limit,
'batch_count_limit': batch_count_limit,
'image_file_size_limit': image_file_size_limit
}, 200
@setup_required
@login_required
@account_initialization_required
@marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents")
@cloud_edition_billing_resource_check(resource='documents')
def post(self):
# get file from request
file = request.files["file"]
file = request.files['file']
# check file
if "file" not in request.files:
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
@@ -67,7 +69,7 @@ class FilePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
text = FileService.get_file_preview(file_id)
return {"content": text}
return {'content': text}
class FileSupportTypeApi(Resource):
@@ -76,10 +78,10 @@ class FileSupportTypeApi(Resource):
@account_initialization_required
def get(self):
etl_type = dify_config.ETL_TYPE
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
return {"allowed_extensions": allowed_extensions}
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
return {'allowed_extensions': allowed_extensions}
api.add_resource(FileApi, "/files/upload")
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
api.add_resource(FileSupportTypeApi, "/files/support-type")
api.add_resource(FileApi, '/files/upload')
api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
api.add_resource(FileSupportTypeApi, '/files/support-type')

View File

@@ -29,6 +29,7 @@ from services.hit_testing_service import HitTestingService
class HitTestingApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -45,8 +46,8 @@ class HitTestingApi(Resource):
raise Forbidden(str(e))
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument('query', type=str, location='json')
parser.add_argument('retrieval_model', type=dict, required=False, location='json')
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
@@ -54,13 +55,13 @@ class HitTestingApi(Resource):
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
query=args['query'],
account=current_user,
retrieval_model=args["retrieval_model"],
limit=10,
retrieval_model=args['retrieval_model'],
limit=10
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:
@@ -72,8 +73,7 @@ class HitTestingApi(Resource):
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
"in the Settings -> Model Provider.")
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
@@ -83,4 +83,4 @@ class HitTestingApi(Resource):
raise InternalServerError(str(e))
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')

View File

@@ -9,14 +9,16 @@ from services.website_service import WebsiteService
class WebsiteCrawlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json")
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
parser.add_argument('provider', type=str, choices=['firecrawl'],
required=True, nullable=True, location='json')
parser.add_argument('url', type=str, required=True, nullable=True, location='json')
parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
args = parser.parse_args()
WebsiteService.document_create_args_validate(args)
# crawl url
@@ -33,15 +35,15 @@ class WebsiteCrawlStatusApi(Resource):
@account_initialization_required
def get(self, job_id: str):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args")
parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
args = parser.parse_args()
# get crawl status
try:
result = WebsiteService.get_crawl_status(job_id, args["provider"])
result = WebsiteService.get_crawl_status(job_id, args['provider'])
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200
api.add_resource(WebsiteCrawlApi, "/website/crawl")
api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/<string:job_id>")
api.add_resource(WebsiteCrawlApi, '/website/crawl')
api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')

View File

@@ -2,41 +2,35 @@ from libs.exception import BaseHTTPException
class AlreadySetupError(BaseHTTPException):
error_code = "already_setup"
error_code = 'already_setup'
description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage."
code = 403
class NotSetupError(BaseHTTPException):
error_code = "not_setup"
description = (
"Dify has not been initialized and installed yet. "
"Please proceed with the initialization and installation process first."
)
error_code = 'not_setup'
description = "Dify has not been initialized and installed yet. " \
"Please proceed with the initialization and installation process first."
code = 401
class NotInitValidateError(BaseHTTPException):
error_code = "not_init_validated"
description = (
"Init validation has not been completed yet. " "Please proceed with the init validation process first."
)
error_code = 'not_init_validated'
description = "Init validation has not been completed yet. " \
"Please proceed with the init validation process first."
code = 401
class InitValidateFailedError(BaseHTTPException):
error_code = "init_validate_failed"
error_code = 'init_validate_failed'
description = "Init validation failed. Please check the password and try again."
code = 401
class AccountNotLinkTenantError(BaseHTTPException):
error_code = "account_not_link_tenant"
error_code = 'account_not_link_tenant'
description = "Account not link tenant."
code = 403
class AlreadyActivateError(BaseHTTPException):
error_code = "already_activate"
error_code = 'already_activate'
description = "Auth Token is invalid or account already activated, please check again."
code = 403

View File

@@ -33,10 +33,14 @@ class ChatAudioApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
file = request.files["file"]
file = request.files['file']
try:
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
response = AudioService.transcript_asr(
app_model=app_model,
file=file,
end_user=None
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
@@ -72,31 +76,30 @@ class ChatTextApi(InstalledAppResource):
app_model = installed_app.app
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('text', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
message_id = args.get('message_id', None)
text = args.get('text', None)
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
response = AudioService.transcript_tts(
app_model=app_model,
message_id=message_id,
voice=voice,
text=text
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
@@ -124,7 +127,7 @@ class ChatTextApi(InstalledAppResource):
raise InternalServerError()
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
# endpoint='installed_app_text_with_message_id')

View File

@@ -30,28 +30,33 @@ from services.app_generate_service import AppGenerateService
# define completion api for user
class CompletionApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != "completion":
if app_model.mode != 'completion':
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
streaming = args['response_mode'] == 'streaming'
args['auto_generate_name'] = False
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
)
return helper.compact_generate_response(response)
@@ -80,12 +85,12 @@ class CompletionApi(InstalledAppResource):
class CompletionStopApi(InstalledAppResource):
def post(self, installed_app, task_id):
app_model = installed_app.app
if app_model.mode != "completion":
if app_model.mode != 'completion':
raise NotCompletionAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}, 200
return {'result': 'success'}, 200
class ChatApi(InstalledAppResource):
@@ -96,21 +101,25 @@ class ChatApi(InstalledAppResource):
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
args = parser.parse_args()
args["auto_generate_name"] = False
args['auto_generate_name'] = False
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.EXPLORE,
streaming=True
)
return helper.compact_generate_response(response)
@@ -145,22 +154,10 @@ class ChatStopApi(InstalledAppResource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}, 200
return {'result': 'success'}, 200
api.add_resource(
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
)
api.add_resource(
CompletionStopApi,
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
endpoint="installed_app_stop_completion",
)
api.add_resource(
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
)
api.add_resource(
ChatStopApi,
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
endpoint="installed_app_stop_chat_completion",
)
api.add_resource(CompletionApi, '/installed-apps/<uuid:installed_app_id>/completion-messages', endpoint='installed_app_completion')
api.add_resource(CompletionStopApi, '/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop', endpoint='installed_app_stop_completion')
api.add_resource(ChatApi, '/installed-apps/<uuid:installed_app_id>/chat-messages', endpoint='installed_app_chat_completion')
api.add_resource(ChatStopApi, '/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop', endpoint='installed_app_stop_chat_completion')

View File

@@ -16,6 +16,7 @@ from services.web_conversation_service import WebConversationService
class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, installed_app):
app_model = installed_app.app
@@ -24,21 +25,21 @@ class ConversationListApi(InstalledAppResource):
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
args = parser.parse_args()
pinned = None
if "pinned" in args and args["pinned"] is not None:
pinned = True if args["pinned"] == "true" else False
if 'pinned' in args and args['pinned'] is not None:
pinned = True if args['pinned'] == 'true' else False
try:
return WebConversationService.pagination_by_last_id(
app_model=app_model,
user=current_user,
last_id=args["last_id"],
limit=args["limit"],
last_id=args['last_id'],
limit=args['limit'],
invoke_from=InvokeFrom.EXPLORE,
pinned=pinned,
)
@@ -64,6 +65,7 @@ class ConversationApi(InstalledAppResource):
class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields)
def post(self, installed_app, c_id):
app_model = installed_app.app
@@ -74,19 +76,24 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, location="json")
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
try:
return ConversationService.rename(
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
app_model,
conversation_id,
current_user,
args['name'],
args['auto_generate']
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
class ConversationPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
@@ -116,26 +123,8 @@ class ConversationUnPinApi(InstalledAppResource):
return {"result": "success"}
api.add_resource(
ConversationRenameApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
endpoint="installed_app_conversation_rename",
)
api.add_resource(
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
)
api.add_resource(
ConversationApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
endpoint="installed_app_conversation",
)
api.add_resource(
ConversationPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
endpoint="installed_app_conversation_pin",
)
api.add_resource(
ConversationUnPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
endpoint="installed_app_conversation_unpin",
)
api.add_resource(ConversationRenameApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name', endpoint='installed_app_conversation_rename')
api.add_resource(ConversationListApi, '/installed-apps/<uuid:installed_app_id>/conversations', endpoint='installed_app_conversations')
api.add_resource(ConversationApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>', endpoint='installed_app_conversation')
api.add_resource(ConversationPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin', endpoint='installed_app_conversation_pin')
api.add_resource(ConversationUnPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin', endpoint='installed_app_conversation_unpin')

View File

@@ -2,24 +2,24 @@ from libs.exception import BaseHTTPException
class NotCompletionAppError(BaseHTTPException):
error_code = "not_completion_app"
error_code = 'not_completion_app'
description = "Not Completion App"
code = 400
class NotChatAppError(BaseHTTPException):
error_code = "not_chat_app"
error_code = 'not_chat_app'
description = "App mode is invalid."
code = 400
class NotWorkflowAppError(BaseHTTPException):
error_code = "not_workflow_app"
error_code = 'not_workflow_app'
description = "Only support workflow app."
code = 400
class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
error_code = "app_suggested_questions_after_answer_disabled"
error_code = 'app_suggested_questions_after_answer_disabled'
description = "Function Suggested questions after answer disabled."
code = 403

View File

@@ -21,72 +21,72 @@ class InstalledAppsListApi(Resource):
@marshal_with(installed_app_list_fields)
def get(self):
current_tenant_id = current_user.current_tenant_id
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
installed_apps = db.session.query(InstalledApp).filter(
InstalledApp.tenant_id == current_tenant_id
).all()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_apps = [
{
"id": installed_app.id,
"app": installed_app.app,
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
"is_pinned": installed_app.is_pinned,
"last_used_at": installed_app.last_used_at,
"editable": current_user.role in ["owner", "admin"],
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
'id': installed_app.id,
'app': installed_app.app,
'app_owner_tenant_id': installed_app.app_owner_tenant_id,
'is_pinned': installed_app.is_pinned,
'last_used_at': installed_app.last_used_at,
'editable': current_user.role in ["owner", "admin"],
'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id
}
for installed_app in installed_apps
if installed_app.app is not None
]
installed_apps.sort(
key=lambda app: (
-app["is_pinned"],
app["last_used_at"] is None,
-app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0,
)
)
installed_apps.sort(key=lambda app: (-app['is_pinned'],
app['last_used_at'] is None,
-app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
return {"installed_apps": installed_apps}
return {'installed_apps': installed_apps}
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@cloud_edition_billing_resource_check('apps')
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
parser.add_argument('app_id', type=str, required=True, help='Invalid app_id')
args = parser.parse_args()
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
if recommended_app is None:
raise NotFound("App not found")
raise NotFound('App not found')
current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).filter(App.id == args["app_id"]).first()
app = db.session.query(App).filter(
App.id == args['app_id']
).first()
if app is None:
raise NotFound("App not found")
raise NotFound('App not found')
if not app.is_public:
raise Forbidden("You can't install a non-public app")
raise Forbidden('You can\'t install a non-public app')
installed_app = InstalledApp.query.filter(
and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)
).first()
installed_app = InstalledApp.query.filter(and_(
InstalledApp.app_id == args['app_id'],
InstalledApp.tenant_id == current_tenant_id
)).first()
if installed_app is None:
# todo: position
recommended_app.install_count += 1
new_installed_app = InstalledApp(
app_id=args["app_id"],
app_id=args['app_id'],
tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id,
is_pinned=False,
last_used_at=datetime.now(timezone.utc).replace(tzinfo=None),
last_used_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
db.session.add(new_installed_app)
db.session.commit()
return {"message": "App installed successfully"}
return {'message': 'App installed successfully'}
class InstalledAppApi(InstalledAppResource):
@@ -94,31 +94,30 @@ class InstalledAppApi(InstalledAppResource):
update and delete an installed app
use InstalledAppResource to apply default decorators and get installed_app
"""
def delete(self, installed_app):
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant")
raise BadRequest('You can\'t uninstall an app owned by the current tenant')
db.session.delete(installed_app)
db.session.commit()
return {"result": "success", "message": "App uninstalled successfully"}
return {'result': 'success', 'message': 'App uninstalled successfully'}
def patch(self, installed_app):
parser = reqparse.RequestParser()
parser.add_argument("is_pinned", type=inputs.boolean)
parser.add_argument('is_pinned', type=inputs.boolean)
args = parser.parse_args()
commit_args = False
if "is_pinned" in args:
installed_app.is_pinned = args["is_pinned"]
if 'is_pinned' in args:
installed_app.is_pinned = args['is_pinned']
commit_args = True
if commit_args:
db.session.commit()
return {"result": "success", "message": "App info updated successfully"}
return {'result': 'success', 'message': 'App info updated successfully'}
api.add_resource(InstalledAppsListApi, "/installed-apps")
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")
api.add_resource(InstalledAppsListApi, '/installed-apps')
api.add_resource(InstalledAppApi, '/installed-apps/<uuid:installed_app_id>')
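
Editor's note: the sort key in InstalledAppsListApi is worth seeing in isolation — pinned apps first (negating a bool puts True before False), then apps with a last_used_at timestamp (newest first), then never-used apps. The same key on toy data:

from datetime import datetime, timezone

apps = [
    {"name": "a", "is_pinned": False, "last_used_at": datetime(2024, 8, 1, tzinfo=timezone.utc)},
    {"name": "b", "is_pinned": True,  "last_used_at": None},
    {"name": "c", "is_pinned": False, "last_used_at": None},
    {"name": "d", "is_pinned": False, "last_used_at": datetime(2024, 8, 20, tzinfo=timezone.utc)},
]

apps.sort(key=lambda app: (
    -app["is_pinned"],                      # True sorts first (-1 < 0)
    app["last_used_at"] is None,            # rows with a timestamp first
    -app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0,
))
print([a["name"] for a in apps])            # ['b', 'd', 'a', 'c']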

View File

@@ -44,21 +44,19 @@ class MessageListApi(InstalledAppResource):
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()
try:
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
)
return MessageService.pagination_by_first_id(app_model, current_user,
args['conversation_id'], args['first_id'], args['limit'])
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.message.FirstMessageNotExistsError:
raise NotFound("First Message Not Exists.")
class MessageFeedbackApi(InstalledAppResource):
def post(self, installed_app, message_id):
app_model = installed_app.app
@@ -66,32 +64,30 @@ class MessageFeedbackApi(InstalledAppResource):
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
args = parser.parse_args()
try:
MessageService.create_feedback(app_model, message_id, current_user, args["rating"])
MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
return {'result': 'success'}
class MessageMoreLikeThisApi(InstalledAppResource):
def get(self, installed_app, message_id):
app_model = installed_app.app
if app_model.mode != "completion":
if app_model.mode != 'completion':
raise NotCompletionAppError()
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument(
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
)
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
streaming = args['response_mode'] == 'streaming'
try:
response = AppGenerateService.generate_more_like_this(
@@ -99,7 +95,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming,
streaming=streaming
)
return helper.compact_generate_response(response)
except MessageNotExistsError:
@@ -132,7 +128,10 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
app_model=app_model,
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.EXPLORE
)
except MessageNotExistsError:
raise NotFound("Message not found")
@@ -152,22 +151,10 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
logging.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
return {'data': questions}
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
api.add_resource(
MessageFeedbackApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
endpoint="installed_app_message_feedback",
)
api.add_resource(
MessageMoreLikeThisApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
endpoint="installed_app_more_like_this",
)
api.add_resource(
MessageSuggestedQuestionApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="installed_app_suggested_question",
)
api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this')
api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')

View File

@@ -1,3 +1,4 @@
from flask_restful import fields, marshal_with
from configs import dify_config
@@ -10,32 +11,33 @@ from services.app_service import AppService
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
variable_fields = {
"key": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"default": fields.String,
"max_length": fields.Integer,
"options": fields.List(fields.String),
'key': fields.String,
'name': fields.String,
'description': fields.String,
'type': fields.String,
'default': fields.String,
'max_length': fields.Integer,
'options': fields.List(fields.String)
}
system_parameters_fields = {"image_file_size_limit": fields.String}
system_parameters_fields = {
'image_file_size_limit': fields.String
}
parameters_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"suggested_questions_after_answer": fields.Raw,
"speech_to_text": fields.Raw,
"text_to_speech": fields.Raw,
"retriever_resource": fields.Raw,
"annotation_reply": fields.Raw,
"more_like_this": fields.Raw,
"user_input_form": fields.Raw,
"sensitive_word_avoidance": fields.Raw,
"file_upload": fields.Raw,
"system_parameters": fields.Nested(system_parameters_fields),
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'text_to_speech': fields.Raw,
'retriever_resource': fields.Raw,
'annotation_reply': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw,
'file_upload': fields.Raw,
'system_parameters': fields.Nested(system_parameters_fields)
}
@marshal_with(parameters_fields)
@@ -54,35 +56,30 @@ class AppParameterApi(InstalledAppResource):
app_model_config = app_model.app_model_config
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
user_input_form = features_dict.get('user_input_form', [])
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get(
"suggested_questions_after_answer", {"enabled": False}
),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
'opening_statement': features_dict.get('opening_statement'),
'suggested_questions': features_dict.get('suggested_questions', []),
'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
{"enabled": False}),
'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
'user_input_form': user_input_form,
'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
{"enabled": False, "type": "", "configs": []}),
'file_upload': features_dict.get('file_upload', {"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"]
}}),
'system_parameters': {
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
}
}
@@ -93,7 +90,6 @@ class ExploreAppMetaApi(InstalledAppResource):
return AppService().get_app_meta(app_model)
api.add_resource(
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
)
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters',
endpoint='installed_app_parameters')
api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta')

View File

@@ -8,28 +8,28 @@ from libs.login import login_required
from services.recommended_app_service import RecommendedAppService
app_fields = {
"id": fields.String,
"name": fields.String,
"mode": fields.String,
"icon": fields.String,
"icon_background": fields.String,
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String
}
recommended_app_fields = {
"app": fields.Nested(app_fields, attribute="app"),
"app_id": fields.String,
"description": fields.String(attribute="description"),
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
'app': fields.Nested(app_fields, attribute='app'),
'app_id': fields.String,
'description': fields.String(attribute='description'),
'copyright': fields.String,
'privacy_policy': fields.String,
'custom_disclaimer': fields.String,
'category': fields.String,
'position': fields.Integer,
'is_listed': fields.Boolean
}
recommended_app_list_fields = {
"recommended_apps": fields.List(fields.Nested(recommended_app_fields)),
"categories": fields.List(fields.String),
'recommended_apps': fields.List(fields.Nested(recommended_app_fields)),
'categories': fields.List(fields.String)
}
@@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource):
def get(self):
# language args
parser = reqparse.RequestParser()
parser.add_argument("language", type=str, location="args")
parser.add_argument('language', type=str, location='args')
args = parser.parse_args()
if args.get("language") and args.get("language") in languages:
language_prefix = args.get("language")
if args.get('language') and args.get('language') in languages:
language_prefix = args.get('language')
elif current_user and current_user.interface_language:
language_prefix = current_user.interface_language
else:
@@ -61,5 +61,5 @@ class RecommendedAppApi(Resource):
return RecommendedAppService.get_recommend_app_detail(app_id)
api.add_resource(RecommendedAppListApi, "/explore/apps")
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")
api.add_resource(RecommendedAppListApi, '/explore/apps')
api.add_resource(RecommendedAppApi, '/explore/apps/<uuid:app_id>')

View File

@@ -11,54 +11,56 @@ from libs.helper import TimestampField, uuid_value
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
feedback_fields = {"rating": fields.String}
feedback_fields = {
'rating': fields.String
}
message_fields = {
"id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField,
'id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'created_at': TimestampField
}
class SavedMessageListApi(InstalledAppResource):
saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
}
@marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, installed_app):
app_model = installed_app.app
if app_model.mode != "completion":
if app_model.mode != 'completion':
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit'])
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != "completion":
if app_model.mode != 'completion':
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=uuid_value, required=True, location="json")
parser.add_argument('message_id', type=uuid_value, required=True, location='json')
args = parser.parse_args()
try:
SavedMessageService.save(app_model, current_user, args["message_id"])
SavedMessageService.save(app_model, current_user, args['message_id'])
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
return {'result': 'success'}
class SavedMessageApi(InstalledAppResource):
@@ -67,21 +69,13 @@ class SavedMessageApi(InstalledAppResource):
message_id = str(message_id)
if app_model.mode != "completion":
if app_model.mode != 'completion':
raise NotCompletionAppError()
SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}
return {'result': 'success'}
api.add_resource(
SavedMessageListApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages",
endpoint="installed_app_saved_messages",
)
api.add_resource(
SavedMessageApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
endpoint="installed_app_saved_message",
)
api.add_resource(SavedMessageListApi, '/installed-apps/<uuid:installed_app_id>/saved-messages', endpoint='installed_app_saved_messages')
api.add_resource(SavedMessageApi, '/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>', endpoint='installed_app_saved_message')

View File

@@ -35,13 +35,17 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
parser.add_argument('files', type=list, required=False, location='json')
args = parser.parse_args()
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.EXPLORE,
streaming=True
)
return helper.compact_generate_response(response)
@@ -72,10 +76,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}
return {
"result": "success"
}
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
api.add_resource(
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
)
api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps/<uuid:installed_app_id>/workflows/run')
api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop')

View File

@@ -14,33 +14,29 @@ def installed_app_required(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
if not kwargs.get("installed_app_id"):
raise ValueError("missing installed_app_id in path parameters")
if not kwargs.get('installed_app_id'):
raise ValueError('missing installed_app_id in path parameters')
installed_app_id = kwargs.get("installed_app_id")
installed_app_id = kwargs.get('installed_app_id')
installed_app_id = str(installed_app_id)
del kwargs["installed_app_id"]
del kwargs['installed_app_id']
installed_app = (
db.session.query(InstalledApp)
.filter(
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
)
.first()
)
installed_app = db.session.query(InstalledApp).filter(
InstalledApp.id == str(installed_app_id),
InstalledApp.tenant_id == current_user.current_tenant_id
).first()
if installed_app is None:
raise NotFound("Installed app not found")
raise NotFound('Installed app not found')
if not installed_app.app:
db.session.delete(installed_app)
db.session.commit()
raise NotFound("Installed app not found")
raise NotFound('Installed app not found')
return view(installed_app, *args, **kwargs)
return decorated
if view:
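
Editor's note: installed_app_required takes an optional view argument and ends with "if view:", which is the standard dual-use decorator shape — usable both bare (@guard) and called (@guard()). A toy stand-in showing the shape, not Dify's actual checks:

from functools import wraps

def guard(view=None):
    def decorator(view):
        @wraps(view)
        def decorated(*args, **kwargs):
            print("checking preconditions for", view.__name__)
            return view(*args, **kwargs)
        return decorated

    if view:
        return decorator(view)   # applied bare: @guard
    return decorator             # applied called: @guard()

@guard
def a():
    return "a"

@guard()
def b():
    return "b"

print(a(), b())   # both calls pass through the same precondition check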

View File

@@ -13,18 +13,23 @@ from services.code_based_extension_service import CodeBasedExtensionService
class CodeBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("module", type=str, required=True, location="args")
parser.add_argument('module', type=str, required=True, location='args')
args = parser.parse_args()
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
return {
'module': args['module'],
'data': CodeBasedExtensionService.get_code_based_extension(args['module'])
}
class APIBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -39,22 +44,23 @@ class APIBasedExtensionAPI(Resource):
@marshal_with(api_based_extension_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json")
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('api_endpoint', type=str, required=True, location='json')
parser.add_argument('api_key', type=str, required=True, location='json')
args = parser.parse_args()
extension_data = APIBasedExtension(
tenant_id=current_user.current_tenant_id,
name=args["name"],
api_endpoint=args["api_endpoint"],
api_key=args["api_key"],
name=args['name'],
api_endpoint=args['api_endpoint'],
api_key=args['api_key']
)
return APIBasedExtensionService.save(extension_data)
class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -76,16 +82,16 @@ class APIBasedExtensionDetailAPI(Resource):
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json")
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('api_endpoint', type=str, required=True, location='json')
parser.add_argument('api_key', type=str, required=True, location='json')
args = parser.parse_args()
extension_data_from_db.name = args["name"]
extension_data_from_db.api_endpoint = args["api_endpoint"]
extension_data_from_db.name = args['name']
extension_data_from_db.api_endpoint = args['api_endpoint']
if args["api_key"] != HIDDEN_VALUE:
extension_data_from_db.api_key = args["api_key"]
if args['api_key'] != HIDDEN_VALUE:
extension_data_from_db.api_key = args['api_key']
return APIBasedExtensionService.save(extension_data_from_db)
@@ -100,10 +106,10 @@ class APIBasedExtensionDetailAPI(Resource):
APIBasedExtensionService.delete(extension_data_from_db)
return {"result": "success"}
return {'result': 'success'}
api.add_resource(CodeBasedExtensionAPI, "/code-based-extension")
api.add_resource(CodeBasedExtensionAPI, '/code-based-extension')
api.add_resource(APIBasedExtensionAPI, "/api-based-extension")
api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/<uuid:id>")
api.add_resource(APIBasedExtensionAPI, '/api-based-extension')
api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/<uuid:id>')

View File

@@ -10,6 +10,7 @@ from .wraps import account_initialization_required, cloud_utm_record
class FeatureApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -23,5 +24,5 @@ class SystemFeatureApi(Resource):
return FeatureService.get_system_features().model_dump()
api.add_resource(FeatureApi, "/features")
api.add_resource(SystemFeatureApi, "/system-features")
api.add_resource(FeatureApi, '/features')
api.add_resource(SystemFeatureApi, '/system-features')

View File

@@ -14,11 +14,12 @@ from .wraps import only_edition_self_hosted
class InitValidateAPI(Resource):
def get(self):
init_status = get_init_validate_status()
if init_status:
return {"status": "finished"}
return {"status": "not_started"}
return { 'status': 'finished' }
return {'status': 'not_started' }
@only_edition_self_hosted
def post(self):
@@ -28,23 +29,22 @@ class InitValidateAPI(Resource):
raise AlreadySetupError()
parser = reqparse.RequestParser()
parser.add_argument("password", type=str_len(30), required=True, location="json")
input_password = parser.parse_args()["password"]
parser.add_argument('password', type=str_len(30),
required=True, location='json')
input_password = parser.parse_args()['password']
if input_password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False
if input_password != os.environ.get('INIT_PASSWORD'):
session['is_init_validated'] = False
raise InitValidateFailedError()
session["is_init_validated"] = True
return {"result": "success"}, 201
session['is_init_validated'] = True
return {'result': 'success'}, 201
def get_init_validate_status():
if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"):
return session.get("is_init_validated") or DifySetup.query.first()
if dify_config.EDITION == 'SELF_HOSTED':
if os.environ.get('INIT_PASSWORD'):
return session.get('is_init_validated') or DifySetup.query.first()
return True
api.add_resource(InitValidateAPI, "/init")
api.add_resource(InitValidateAPI, '/init')
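
Editor's note: get_init_validate_status gates on three things — self-hosted edition, an INIT_PASSWORD in the environment, and either a session flag or an existing DifySetup row. Reduced to a pure function (setup_done stands in for DifySetup.query.first(); this is a sketch, not the handler):

import os

def init_validated(edition: str, session: dict, setup_done: bool) -> bool:
    # Only self-hosted installs with an INIT_PASSWORD require validation.
    if edition == "SELF_HOSTED" and os.environ.get("INIT_PASSWORD"):
        return bool(session.get("is_init_validated") or setup_done)
    return True

os.environ["INIT_PASSWORD"] = "secret"
print(init_validated("SELF_HOSTED", {}, setup_done=False))                # False
print(init_validated("SELF_HOSTED", {"is_init_validated": True}, False))  # True
print(init_validated("CLOUD", {}, False))                                 # True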

View File

@@ -4,11 +4,14 @@ from controllers.console import api
class PingApi(Resource):
def get(self):
"""
For connection health check
"""
return {"result": "pong"}
return {
"result": "pong"
}
api.add_resource(PingApi, "/ping")
api.add_resource(PingApi, '/ping')

View File

@@ -16,13 +16,17 @@ from .wraps import only_edition_self_hosted
class SetupApi(Resource):
def get(self):
if dify_config.EDITION == "SELF_HOSTED":
if dify_config.EDITION == 'SELF_HOSTED':
setup_status = get_setup_status()
if setup_status:
return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
return {"step": "not_started"}
return {"step": "finished"}
return {
'step': 'finished',
'setup_at': setup_status.setup_at.isoformat()
}
return {'step': 'not_started'}
return {'step': 'finished'}
@only_edition_self_hosted
def post(self):
@@ -34,22 +38,28 @@ class SetupApi(Resource):
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
if not get_init_validate_status():
raise NotInitValidateError()
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("name", type=str_len(30), required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
parser.add_argument('email', type=email,
required=True, location='json')
parser.add_argument('name', type=str_len(
30), required=True, location='json')
parser.add_argument('password', type=valid_password,
required=True, location='json')
args = parser.parse_args()
# setup
RegisterService.setup(
email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
email=args['email'],
name=args['name'],
password=args['password'],
ip_address=get_remote_ip(request)
)
return {"result": "success"}, 201
return {'result': 'success'}, 201
def setup_required(view):
@@ -58,7 +68,7 @@ def setup_required(view):
# check setup
if not get_init_validate_status():
raise NotInitValidateError()
elif not get_setup_status():
raise NotSetupError()
@@ -68,10 +78,9 @@ def setup_required(view):
def get_setup_status():
if dify_config.EDITION == "SELF_HOSTED":
if dify_config.EDITION == 'SELF_HOSTED':
return DifySetup.query.first()
else:
return True
api.add_resource(SetupApi, "/setup")
api.add_resource(SetupApi, '/setup')

View File

@@ -14,18 +14,19 @@ from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 50 characters.")
raise ValueError('Name must be between 1 to 50 characters.')
return name
class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(tag_fields)
def get(self):
tag_type = request.args.get("type", type=str)
keyword = request.args.get("keyword", default=None, type=str)
tag_type = request.args.get('type', type=str)
keyword = request.args.get('keyword', default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
return tags, 200
@@ -39,21 +40,28 @@ class TagListApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
parser.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
parser.add_argument('name', nullable=False, required=True,
help='Name must be between 1 to 50 characters.',
type=_validate_name)
parser.add_argument('type', type=str, location='json',
choices=Tag.TAG_TYPE_LIST,
nullable=True,
help='Invalid tag type.')
args = parser.parse_args()
tag = TagService.save_tags(args)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
response = {
'id': tag.id,
'name': tag.name,
'type': tag.type,
'binding_count': 0
}
return response, 200
class TagUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@@ -64,15 +72,20 @@ class TagUpdateDeleteApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
parser.add_argument('name', nullable=False, required=True,
help='Name must be between 1 to 50 characters.',
type=_validate_name)
args = parser.parse_args()
tag = TagService.update_tags(args, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
response = {
'id': tag.id,
'name': tag.name,
'type': tag.type,
'binding_count': binding_count
}
return response, 200
@@ -91,6 +104,7 @@ class TagUpdateDeleteApi(Resource):

 class TagBindingCreateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -100,15 +114,14 @@
             raise Forbidden()

         parser = reqparse.RequestParser()
-        parser.add_argument(
-            "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
-        )
-        parser.add_argument(
-            "target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
-        )
-        parser.add_argument(
-            "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
-        )
+        parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json',
+                            help='Tag IDs is required.')
+        parser.add_argument('target_id', type=str, nullable=False, required=True, location='json',
+                            help='Target ID is required.')
+        parser.add_argument('type', type=str, location='json',
+                            choices=Tag.TAG_TYPE_LIST,
+                            nullable=True,
+                            help='Invalid tag type.')
         args = parser.parse_args()

         TagService.save_tag_binding(args)
@@ -116,6 +129,7 @@ class TagBindingCreateApi(Resource):

 class TagBindingDeleteApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
@@ -125,18 +139,21 @@
             raise Forbidden()

         parser = reqparse.RequestParser()
-        parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
-        parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
-        parser.add_argument(
-            "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
-        )
+        parser.add_argument('tag_id', type=str, nullable=False, required=True,
+                            help='Tag ID is required.')
+        parser.add_argument('target_id', type=str, nullable=False, required=True,
+                            help='Target ID is required.')
+        parser.add_argument('type', type=str, location='json',
+                            choices=Tag.TAG_TYPE_LIST,
+                            nullable=True,
+                            help='Invalid tag type.')
         args = parser.parse_args()

         TagService.delete_tag_binding(args)

         return 200


-api.add_resource(TagListApi, "/tags")
-api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>")
-api.add_resource(TagBindingCreateApi, "/tag-bindings/create")
-api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove")
+api.add_resource(TagListApi, '/tags')
+api.add_resource(TagUpdateDeleteApi, '/tags/<uuid:tag_id>')
+api.add_resource(TagBindingCreateApi, '/tag-bindings/create')
+api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove')
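Taken together, the four registrations above expose tag CRUD plus bind/unbind endpoints on the console API. A hypothetical client session — base URL, token, and IDs are placeholders, not values taken from this diff:

    import requests

    BASE = 'http://localhost:5001/console/api'  # assumed self-hosted console root
    HEADERS = {'Authorization': 'Bearer <console-session-token>'}  # placeholder

    # Create a tag, then bind it to a target object such as an app.
    tag = requests.post(f'{BASE}/tags',
                        json={'name': 'demo', 'type': 'app'},
                        headers=HEADERS).json()
    requests.post(f'{BASE}/tag-bindings/create',
                  json={'tag_ids': [tag['id']], 'target_id': '<app-id>', 'type': 'app'},
                  headers=HEADERS)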

View File

@@ -1,3 +1,4 @@
 import json
 import logging
@@ -10,39 +11,42 @@ from . import api

 class VersionApi(Resource):
     def get(self):
         parser = reqparse.RequestParser()
-        parser.add_argument("current_version", type=str, required=True, location="args")
+        parser.add_argument('current_version', type=str, required=True, location='args')
         args = parser.parse_args()

         check_update_url = dify_config.CHECK_UPDATE_URL

         result = {
-            "version": dify_config.CURRENT_VERSION,
-            "release_date": "",
-            "release_notes": "",
-            "can_auto_update": False,
-            "features": {
-                "can_replace_logo": dify_config.CAN_REPLACE_LOGO,
-                "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED,
-            },
+            'version': dify_config.CURRENT_VERSION,
+            'release_date': '',
+            'release_notes': '',
+            'can_auto_update': False,
+            'features': {
+                'can_replace_logo': dify_config.CAN_REPLACE_LOGO,
+                'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED
+            }
         }

         if not check_update_url:
             return result

         try:
-            response = requests.get(check_update_url, {"current_version": args.get("current_version")})
+            response = requests.get(check_update_url, {
+                'current_version': args.get('current_version')
+            })
         except Exception as error:
             logging.warning("Check update version error: {}.".format(str(error)))
-            result["version"] = args.get("current_version")
+            result['version'] = args.get('current_version')
             return result

         content = json.loads(response.content)
-        result["version"] = content["version"]
-        result["release_date"] = content["releaseDate"]
-        result["release_notes"] = content["releaseNotes"]
-        result["can_auto_update"] = content["canAutoUpdate"]
+        result['version'] = content['version']
+        result['release_date'] = content['releaseDate']
+        result['release_notes'] = content['releaseNotes']
+        result['can_auto_update'] = content['canAutoUpdate']

         return result


-api.add_resource(VersionApi, "/version")
+api.add_resource(VersionApi, '/version')
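One behavioral note on the hunk above: the dict passed as the second positional argument to requests.get is its params mapping, so the check calls CHECK_UPDATE_URL?current_version=<v>. The content[...] keys read back imply the update feed answers with JSON of roughly this shape (values invented):

    # Inferred from the keys VersionApi reads; every value here is made up.
    example_update_feed = {
        'version': '0.7.2',
        'releaseDate': '2024-08-23',
        'releaseNotes': 'Bug fixes and improvements.',
        'canAutoUpdate': False,
    }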

Some files were not shown because too many files have changed in this diff.