mirror of
https://github.com/langgenius/dify.git
synced 2026-01-07 14:58:32 +00:00
Compare commits
7 Commits
0.8.0-beta
...
feat/workf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fcb048607 | ||
|
|
18b61591b8 | ||
|
|
5408e923e3 | ||
|
|
1bcfa747db | ||
|
|
071b7d607b | ||
|
|
1a5e5cd8f8 | ||
|
|
3bd5f9542e |
1
.github/workflows/build-push.yml
vendored
1
.github/workflows/build-push.yml
vendored
@@ -125,6 +125,7 @@ jobs:
|
||||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
tags: |
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
type=ref,event=branch
|
||||
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
|
||||
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
name: Check i18n Files and Create PR
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed]
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
check-and-update:
|
||||
if: github.event.pull_request.merged == true
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: web
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 2 # last 2 commits
|
||||
|
||||
- name: Check for file changes in i18n/en-US
|
||||
id: check_files
|
||||
run: |
|
||||
recent_commit_sha=$(git rev-parse HEAD)
|
||||
second_recent_commit_sha=$(git rev-parse HEAD~1)
|
||||
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
|
||||
echo "Changed files: $changed_files"
|
||||
if [ -n "$changed_files" ]; then
|
||||
echo "FILES_CHANGED=true" >> $GITHUB_ENV
|
||||
else
|
||||
echo "FILES_CHANGED=false" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Set up Node.js
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: actions/setup-node@v2
|
||||
with:
|
||||
node-version: 'lts/*'
|
||||
|
||||
- name: Install dependencies
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: yarn install --frozen-lockfile
|
||||
|
||||
- name: Run npm script
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: npm run auto-gen-i18n
|
||||
|
||||
- name: Create Pull Request
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
commit-message: Update i18n files based on en-US changes
|
||||
title: 'chore: translate i18n files'
|
||||
body: This PR was automatically created to update i18n files based on changes in en-US locale.
|
||||
branch: chore/automated-i18n-updates
|
||||
@@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
|
||||
|
||||
## Before you jump in
|
||||
|
||||
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
|
||||
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
|
||||
|
||||
### Feature requests:
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
## 在开始之前
|
||||
|
||||
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
|
||||
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
|
||||
|
||||
### 功能请求:
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは
|
||||
|
||||
## 飛び込む前に
|
||||
|
||||
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
|
||||
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
|
||||
|
||||
### 機能リクエスト
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ Về vấn đề cấp phép, xin vui lòng dành chút thời gian đọc qua [
|
||||
|
||||
## Trước khi bắt đầu
|
||||
|
||||
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:open) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
|
||||
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
|
||||
|
||||
### Yêu cầu tính năng:
|
||||
|
||||
|
||||
@@ -60,8 +60,7 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
|
||||
ALIYUN_OSS_ENDPOINT=your-endpoint
|
||||
ALIYUN_OSS_AUTH_VERSION=v1
|
||||
ALIYUN_OSS_REGION=your-region
|
||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||
ALIYUN_OSS_PATH=your-path
|
||||
|
||||
# Google Storage configuration
|
||||
GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
|
||||
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
|
||||
|
||||
1
api/.idea/vcs.xml
generated
1
api/.idea/vcs.xml
generated
@@ -12,6 +12,5 @@
|
||||
</component>
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
|
||||
@@ -5,10 +5,6 @@ WORKDIR /app/api
|
||||
|
||||
# Install Poetry
|
||||
ENV POETRY_VERSION=1.8.3
|
||||
|
||||
# if you located in China, you can use aliyun mirror to speed up
|
||||
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
|
||||
|
||||
RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
|
||||
|
||||
# Configure Poetry
|
||||
@@ -20,9 +16,6 @@ ENV POETRY_REQUESTS_TIMEOUT=15
|
||||
|
||||
FROM base AS packages
|
||||
|
||||
# if you located in China, you can use aliyun mirror to speed up
|
||||
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
|
||||
|
||||
@@ -50,12 +43,10 @@ WORKDIR /app/api
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
|
||||
# if you located in China, you can use aliyun mirror to speed up
|
||||
# && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \
|
||||
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
|
||||
&& apt-get update \
|
||||
# For Security
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-2 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-2 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
|
||||
&& apt-get autoremove -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@@ -65,7 +56,7 @@ COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
|
||||
|
||||
# Download nltk data
|
||||
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
|
||||
RUN python -c "import nltk; nltk.download('punkt')"
|
||||
|
||||
# Copy source code
|
||||
COPY . /app/api/
|
||||
|
||||
@@ -559,9 +559,8 @@ def add_qdrant_doc_id_index(field: str):
|
||||
|
||||
@click.command("create-tenant", help="Create account and tenant.")
|
||||
@click.option("--email", prompt=True, help="The email address of the tenant account.")
|
||||
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
|
||||
@click.option("--language", prompt=True, help="Account language, default: en-US.")
|
||||
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
|
||||
def create_tenant(email: str, language: Optional[str] = None):
|
||||
"""
|
||||
Create tenant account
|
||||
"""
|
||||
@@ -581,15 +580,13 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
|
||||
if language not in languages:
|
||||
language = "en-US"
|
||||
|
||||
name = name.strip()
|
||||
|
||||
# generate random password
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
|
||||
# register account
|
||||
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name)
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .app_config import DifyConfig
|
||||
|
||||
dify_config = DifyConfig()
|
||||
dify_config = DifyConfig()
|
||||
@@ -1,3 +1,4 @@
|
||||
from pydantic import Field, computed_field
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
from configs.deploy import DeploymentConfig
|
||||
@@ -23,16 +24,44 @@ class DifyConfig(
|
||||
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
|
||||
EnterpriseFeatureConfig,
|
||||
):
|
||||
DEBUG: bool = Field(default=False, description='whether to enable debug mode.')
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
# read from dotenv format config file
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
env_file='.env',
|
||||
env_file_encoding='utf-8',
|
||||
frozen=True,
|
||||
# ignore extra attributes
|
||||
extra="ignore",
|
||||
extra='ignore',
|
||||
)
|
||||
|
||||
# Before adding any config,
|
||||
# please consider to arrange it in the proper config group of existed or added
|
||||
# for better readability and maintainability.
|
||||
# Thanks for your concentration and consideration.
|
||||
CODE_MAX_NUMBER: int = 9223372036854775807
|
||||
CODE_MIN_NUMBER: int = -9223372036854775808
|
||||
CODE_MAX_DEPTH: int = 5
|
||||
CODE_MAX_PRECISION: int = 20
|
||||
CODE_MAX_STRING_LENGTH: int = 80000
|
||||
CODE_MAX_STRING_ARRAY_LENGTH: int = 30
|
||||
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
|
||||
CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10
|
||||
|
||||
@computed_field
|
||||
def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
|
||||
return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024
|
||||
|
||||
@computed_field
|
||||
def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str:
|
||||
return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
|
||||
|
||||
SSRF_PROXY_HTTP_URL: str | None = None
|
||||
SSRF_PROXY_HTTPS_URL: str | None = None
|
||||
|
||||
MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')
|
||||
|
||||
MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.')
|
||||
|
||||
@@ -6,28 +6,22 @@ class DeploymentConfig(BaseSettings):
|
||||
"""
|
||||
Deployment configs
|
||||
"""
|
||||
|
||||
APPLICATION_NAME: str = Field(
|
||||
description="application name",
|
||||
default="langgenius/dify",
|
||||
)
|
||||
|
||||
DEBUG: bool = Field(
|
||||
description="whether to enable debug mode.",
|
||||
default=False,
|
||||
description='application name',
|
||||
default='langgenius/dify',
|
||||
)
|
||||
|
||||
TESTING: bool = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=False,
|
||||
)
|
||||
|
||||
EDITION: str = Field(
|
||||
description="deployment edition",
|
||||
default="SELF_HOSTED",
|
||||
description='deployment edition',
|
||||
default='SELF_HOSTED',
|
||||
)
|
||||
|
||||
DEPLOY_ENV: str = Field(
|
||||
description="deployment environment, default to PRODUCTION.",
|
||||
default="PRODUCTION",
|
||||
description='deployment environment, default to PRODUCTION.',
|
||||
default='PRODUCTION',
|
||||
)
|
||||
|
||||
@@ -7,14 +7,13 @@ class EnterpriseFeatureConfig(BaseSettings):
|
||||
Enterprise feature configs.
|
||||
**Before using, please contact business@dify.ai by email to inquire about licensing matters.**
|
||||
"""
|
||||
|
||||
ENTERPRISE_ENABLED: bool = Field(
|
||||
description="whether to enable enterprise features."
|
||||
"Before using, please contact business@dify.ai by email to inquire about licensing matters.",
|
||||
description='whether to enable enterprise features.'
|
||||
'Before using, please contact business@dify.ai by email to inquire about licensing matters.',
|
||||
default=False,
|
||||
)
|
||||
|
||||
CAN_REPLACE_LOGO: bool = Field(
|
||||
description="whether to allow replacing enterprise logo.",
|
||||
description='whether to allow replacing enterprise logo.',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -8,28 +8,27 @@ class NotionConfig(BaseSettings):
|
||||
"""
|
||||
Notion integration configs
|
||||
"""
|
||||
|
||||
NOTION_CLIENT_ID: Optional[str] = Field(
|
||||
description="Notion client ID",
|
||||
description='Notion client ID',
|
||||
default=None,
|
||||
)
|
||||
|
||||
NOTION_CLIENT_SECRET: Optional[str] = Field(
|
||||
description="Notion client secret key",
|
||||
description='Notion client secret key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
NOTION_INTEGRATION_TYPE: Optional[str] = Field(
|
||||
description="Notion integration type, default to None, available values: internal.",
|
||||
description='Notion integration type, default to None, available values: internal.',
|
||||
default=None,
|
||||
)
|
||||
|
||||
NOTION_INTERNAL_SECRET: Optional[str] = Field(
|
||||
description="Notion internal secret key",
|
||||
description='Notion internal secret key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
NOTION_INTEGRATION_TOKEN: Optional[str] = Field(
|
||||
description="Notion integration token",
|
||||
description='Notion integration token',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -8,18 +8,17 @@ class SentryConfig(BaseSettings):
|
||||
"""
|
||||
Sentry configs
|
||||
"""
|
||||
|
||||
SENTRY_DSN: Optional[str] = Field(
|
||||
description="Sentry DSN",
|
||||
description='Sentry DSN',
|
||||
default=None,
|
||||
)
|
||||
|
||||
SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field(
|
||||
description="Sentry trace sample rate",
|
||||
description='Sentry trace sample rate',
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field(
|
||||
description="Sentry profiles sample rate",
|
||||
description='Sentry profiles sample rate',
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Annotated, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from configs.feature.hosted_service import HostedServiceConfig
|
||||
@@ -10,17 +10,16 @@ class SecurityConfig(BaseSettings):
|
||||
"""
|
||||
Secret Key configs
|
||||
"""
|
||||
|
||||
SECRET_KEY: Optional[str] = Field(
|
||||
description="Your App secret key will be used for securely signing the session cookie"
|
||||
"Make sure you are changing this key for your deployment with a strong key."
|
||||
"You can generate a strong key using `openssl rand -base64 42`."
|
||||
"Alternatively you can set it with `SECRET_KEY` environment variable.",
|
||||
description='Your App secret key will be used for securely signing the session cookie'
|
||||
'Make sure you are changing this key for your deployment with a strong key.'
|
||||
'You can generate a strong key using `openssl rand -base64 42`.'
|
||||
'Alternatively you can set it with `SECRET_KEY` environment variable.',
|
||||
default=None,
|
||||
)
|
||||
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
|
||||
description="Expiry time in hours for reset token",
|
||||
description='Expiry time in hours for reset token',
|
||||
default=24,
|
||||
)
|
||||
|
||||
@@ -29,13 +28,12 @@ class AppExecutionConfig(BaseSettings):
|
||||
"""
|
||||
App Execution configs
|
||||
"""
|
||||
|
||||
APP_MAX_EXECUTION_TIME: PositiveInt = Field(
|
||||
description="execution timeout in seconds for app execution",
|
||||
description='execution timeout in seconds for app execution',
|
||||
default=1200,
|
||||
)
|
||||
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
|
||||
description="max active request per app, 0 means unlimited",
|
||||
description='max active request per app, 0 means unlimited',
|
||||
default=0,
|
||||
)
|
||||
|
||||
@@ -44,70 +42,14 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
"""
|
||||
Code Execution Sandbox configs
|
||||
"""
|
||||
|
||||
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
|
||||
description="endpoint URL of code execution servcie",
|
||||
default="http://sandbox:8194",
|
||||
CODE_EXECUTION_ENDPOINT: str = Field(
|
||||
description='endpoint URL of code execution servcie',
|
||||
default='http://sandbox:8194',
|
||||
)
|
||||
|
||||
CODE_EXECUTION_API_KEY: str = Field(
|
||||
description="API key for code execution service",
|
||||
default="dify-sandbox",
|
||||
)
|
||||
|
||||
CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field(
|
||||
description="connect timeout in seconds for code execution request",
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field(
|
||||
description="read timeout in seconds for code execution request",
|
||||
default=60.0,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field(
|
||||
description="write timeout in seconds for code execution request",
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
CODE_MAX_NUMBER: PositiveInt = Field(
|
||||
description="max depth for code execution",
|
||||
default=9223372036854775807,
|
||||
)
|
||||
|
||||
CODE_MIN_NUMBER: NegativeInt = Field(
|
||||
description="",
|
||||
default=-9223372036854775807,
|
||||
)
|
||||
|
||||
CODE_MAX_DEPTH: PositiveInt = Field(
|
||||
description="max depth for code execution",
|
||||
default=5,
|
||||
)
|
||||
|
||||
CODE_MAX_PRECISION: PositiveInt = Field(
|
||||
description="max precision digits for float type in code execution",
|
||||
default=20,
|
||||
)
|
||||
|
||||
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
|
||||
description="max string length for code execution",
|
||||
default=80000,
|
||||
)
|
||||
|
||||
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
|
||||
description="",
|
||||
default=30,
|
||||
)
|
||||
|
||||
CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field(
|
||||
description="",
|
||||
default=30,
|
||||
)
|
||||
|
||||
CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field(
|
||||
description="",
|
||||
default=1000,
|
||||
description='API key for code execution service',
|
||||
default='dify-sandbox',
|
||||
)
|
||||
|
||||
|
||||
@@ -115,27 +57,28 @@ class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Module URL configs
|
||||
"""
|
||||
|
||||
CONSOLE_API_URL: str = Field(
|
||||
description="The backend URL prefix of the console API."
|
||||
"used to concatenate the login authorization callback or notion integration callback.",
|
||||
default="",
|
||||
description='The backend URL prefix of the console API.'
|
||||
'used to concatenate the login authorization callback or notion integration callback.',
|
||||
default='',
|
||||
)
|
||||
|
||||
CONSOLE_WEB_URL: str = Field(
|
||||
description="The front-end URL prefix of the console web."
|
||||
"used to concatenate some front-end addresses and for CORS configuration use.",
|
||||
default="",
|
||||
description='The front-end URL prefix of the console web.'
|
||||
'used to concatenate some front-end addresses and for CORS configuration use.',
|
||||
default='',
|
||||
)
|
||||
|
||||
SERVICE_API_URL: str = Field(
|
||||
description="Service API Url prefix." "used to display Service API Base Url to the front-end.",
|
||||
default="",
|
||||
description='Service API Url prefix.'
|
||||
'used to display Service API Base Url to the front-end.',
|
||||
default='',
|
||||
)
|
||||
|
||||
APP_WEB_URL: str = Field(
|
||||
description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.",
|
||||
default="",
|
||||
description='WebApp Url prefix.'
|
||||
'used to display WebAPP API Base Url to the front-end.',
|
||||
default='',
|
||||
)
|
||||
|
||||
|
||||
@@ -143,18 +86,17 @@ class FileAccessConfig(BaseSettings):
|
||||
"""
|
||||
File Access configs
|
||||
"""
|
||||
|
||||
FILES_URL: str = Field(
|
||||
description="File preview or download Url prefix."
|
||||
" used to display File preview or download Url to the front-end or as Multi-model inputs;"
|
||||
"Url is signed and has expiration time.",
|
||||
validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"),
|
||||
description='File preview or download Url prefix.'
|
||||
' used to display File preview or download Url to the front-end or as Multi-model inputs;'
|
||||
'Url is signed and has expiration time.',
|
||||
validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'),
|
||||
alias_priority=1,
|
||||
default="",
|
||||
default='',
|
||||
)
|
||||
|
||||
FILES_ACCESS_TIMEOUT: int = Field(
|
||||
description="timeout in seconds for file accessing",
|
||||
description='timeout in seconds for file accessing',
|
||||
default=300,
|
||||
)
|
||||
|
||||
@@ -163,24 +105,23 @@ class FileUploadConfig(BaseSettings):
|
||||
"""
|
||||
File Uploading configs
|
||||
"""
|
||||
|
||||
UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||
description="size limit in Megabytes for uploading files",
|
||||
description='size limit in Megabytes for uploading files',
|
||||
default=15,
|
||||
)
|
||||
|
||||
UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field(
|
||||
description="batch size limit for uploading files",
|
||||
description='batch size limit for uploading files',
|
||||
default=5,
|
||||
)
|
||||
|
||||
UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||
description="image file size limit in Megabytes for uploading files",
|
||||
description='image file size limit in Megabytes for uploading files',
|
||||
default=10,
|
||||
)
|
||||
|
||||
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
|
||||
description="", # todo: to be clarified
|
||||
description='', # todo: to be clarified
|
||||
default=20,
|
||||
)
|
||||
|
||||
@@ -189,79 +130,45 @@ class HttpConfig(BaseSettings):
|
||||
"""
|
||||
HTTP configs
|
||||
"""
|
||||
|
||||
API_COMPRESSION_ENABLED: bool = Field(
|
||||
description="whether to enable HTTP response compression of gzip",
|
||||
description='whether to enable HTTP response compression of gzip',
|
||||
default=False,
|
||||
)
|
||||
|
||||
inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
|
||||
description="",
|
||||
validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"),
|
||||
default="",
|
||||
description='',
|
||||
validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'),
|
||||
default='',
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")
|
||||
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',')
|
||||
|
||||
inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field(
|
||||
description="",
|
||||
validation_alias=AliasChoices("WEB_API_CORS_ALLOW_ORIGINS"),
|
||||
default="*",
|
||||
description='',
|
||||
validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'),
|
||||
default='*',
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
|
||||
] = 10
|
||||
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
|
||||
] = 60
|
||||
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
|
||||
] = 20
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
|
||||
description="",
|
||||
default=10 * 1024 * 1024,
|
||||
)
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field(
|
||||
description="",
|
||||
default=1 * 1024 * 1024,
|
||||
)
|
||||
|
||||
SSRF_PROXY_HTTP_URL: Optional[str] = Field(
|
||||
description="HTTP URL for SSRF proxy",
|
||||
default=None,
|
||||
)
|
||||
|
||||
SSRF_PROXY_HTTPS_URL: Optional[str] = Field(
|
||||
description="HTTPS URL for SSRF proxy",
|
||||
default=None,
|
||||
)
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')
|
||||
|
||||
|
||||
class InnerAPIConfig(BaseSettings):
|
||||
"""
|
||||
Inner API configs
|
||||
"""
|
||||
|
||||
INNER_API: bool = Field(
|
||||
description="whether to enable the inner API",
|
||||
description='whether to enable the inner API',
|
||||
default=False,
|
||||
)
|
||||
|
||||
INNER_API_KEY: Optional[str] = Field(
|
||||
description="The inner API key is used to authenticate the inner API",
|
||||
description='The inner API key is used to authenticate the inner API',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -272,27 +179,28 @@ class LoggingConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
LOG_LEVEL: str = Field(
|
||||
description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.",
|
||||
default="INFO",
|
||||
description='Log output level, default to INFO.'
|
||||
'It is recommended to set it to ERROR for production.',
|
||||
default='INFO',
|
||||
)
|
||||
|
||||
LOG_FILE: Optional[str] = Field(
|
||||
description="logging output file path",
|
||||
description='logging output file path',
|
||||
default=None,
|
||||
)
|
||||
|
||||
LOG_FORMAT: str = Field(
|
||||
description="log format",
|
||||
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
|
||||
description='log format',
|
||||
default='%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s',
|
||||
)
|
||||
|
||||
LOG_DATEFORMAT: Optional[str] = Field(
|
||||
description="log date format",
|
||||
description='log date format',
|
||||
default=None,
|
||||
)
|
||||
|
||||
LOG_TZ: Optional[str] = Field(
|
||||
description="specify log timezone, eg: America/New_York",
|
||||
description='specify log timezone, eg: America/New_York',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -301,9 +209,8 @@ class ModelLoadBalanceConfig(BaseSettings):
|
||||
"""
|
||||
Model load balance configs
|
||||
"""
|
||||
|
||||
MODEL_LB_ENABLED: bool = Field(
|
||||
description="whether to enable model load balancing",
|
||||
description='whether to enable model load balancing',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -312,9 +219,8 @@ class BillingConfig(BaseSettings):
|
||||
"""
|
||||
Platform Billing Configurations
|
||||
"""
|
||||
|
||||
BILLING_ENABLED: bool = Field(
|
||||
description="whether to enable billing",
|
||||
description='whether to enable billing',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -323,10 +229,9 @@ class UpdateConfig(BaseSettings):
|
||||
"""
|
||||
Update configs
|
||||
"""
|
||||
|
||||
CHECK_UPDATE_URL: str = Field(
|
||||
description="url for checking updates",
|
||||
default="https://updates.dify.ai",
|
||||
description='url for checking updates',
|
||||
default='https://updates.dify.ai',
|
||||
)
|
||||
|
||||
|
||||
@@ -336,53 +241,47 @@ class WorkflowConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field(
|
||||
description="max execution steps in single workflow execution",
|
||||
description='max execution steps in single workflow execution',
|
||||
default=500,
|
||||
)
|
||||
|
||||
WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field(
|
||||
description="max execution time in seconds in single workflow execution",
|
||||
description='max execution time in seconds in single workflow execution',
|
||||
default=1200,
|
||||
)
|
||||
|
||||
WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field(
|
||||
description="max depth of calling in single workflow execution",
|
||||
description='max depth of calling in single workflow execution',
|
||||
default=5,
|
||||
)
|
||||
|
||||
MAX_VARIABLE_SIZE: PositiveInt = Field(
|
||||
description="The maximum size in bytes of a variable. default to 5KB.",
|
||||
default=5 * 1024,
|
||||
)
|
||||
|
||||
|
||||
class OAuthConfig(BaseSettings):
|
||||
"""
|
||||
oauth configs
|
||||
"""
|
||||
|
||||
OAUTH_REDIRECT_PATH: str = Field(
|
||||
description="redirect path for OAuth",
|
||||
default="/console/api/oauth/authorize",
|
||||
description='redirect path for OAuth',
|
||||
default='/console/api/oauth/authorize',
|
||||
)
|
||||
|
||||
GITHUB_CLIENT_ID: Optional[str] = Field(
|
||||
description="GitHub client id for OAuth",
|
||||
description='GitHub client id for OAuth',
|
||||
default=None,
|
||||
)
|
||||
|
||||
GITHUB_CLIENT_SECRET: Optional[str] = Field(
|
||||
description="GitHub client secret key for OAuth",
|
||||
description='GitHub client secret key for OAuth',
|
||||
default=None,
|
||||
)
|
||||
|
||||
GOOGLE_CLIENT_ID: Optional[str] = Field(
|
||||
description="Google client id for OAuth",
|
||||
description='Google client id for OAuth',
|
||||
default=None,
|
||||
)
|
||||
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = Field(
|
||||
description="Google client secret key for OAuth",
|
||||
description='Google client secret key for OAuth',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -392,8 +291,9 @@ class ModerationConfig(BaseSettings):
|
||||
Moderation in app configs.
|
||||
"""
|
||||
|
||||
MODERATION_BUFFER_SIZE: PositiveInt = Field(
|
||||
description="buffer size for moderation",
|
||||
# todo: to be clarified in usage and unit
|
||||
OUTPUT_MODERATION_BUFFER_SIZE: PositiveInt = Field(
|
||||
description='buffer size for moderation',
|
||||
default=300,
|
||||
)
|
||||
|
||||
@@ -404,7 +304,7 @@ class ToolConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field(
|
||||
description="max age in seconds for tool icon caching",
|
||||
description='max age in seconds for tool icon caching',
|
||||
default=3600,
|
||||
)
|
||||
|
||||
@@ -415,52 +315,52 @@ class MailConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
MAIL_TYPE: Optional[str] = Field(
|
||||
description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
|
||||
description='Mail provider type name, default to None, availabile values are `smtp` and `resend`.',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MAIL_DEFAULT_SEND_FROM: Optional[str] = Field(
|
||||
description="default email address for sending from ",
|
||||
description='default email address for sending from ',
|
||||
default=None,
|
||||
)
|
||||
|
||||
RESEND_API_KEY: Optional[str] = Field(
|
||||
description="API key for Resend",
|
||||
description='API key for Resend',
|
||||
default=None,
|
||||
)
|
||||
|
||||
RESEND_API_URL: Optional[str] = Field(
|
||||
description="API URL for Resend",
|
||||
description='API URL for Resend',
|
||||
default=None,
|
||||
)
|
||||
|
||||
SMTP_SERVER: Optional[str] = Field(
|
||||
description="smtp server host",
|
||||
description='smtp server host',
|
||||
default=None,
|
||||
)
|
||||
|
||||
SMTP_PORT: Optional[int] = Field(
|
||||
description="smtp server port",
|
||||
description='smtp server port',
|
||||
default=465,
|
||||
)
|
||||
|
||||
SMTP_USERNAME: Optional[str] = Field(
|
||||
description="smtp server username",
|
||||
description='smtp server username',
|
||||
default=None,
|
||||
)
|
||||
|
||||
SMTP_PASSWORD: Optional[str] = Field(
|
||||
description="smtp server password",
|
||||
description='smtp server password',
|
||||
default=None,
|
||||
)
|
||||
|
||||
SMTP_USE_TLS: bool = Field(
|
||||
description="whether to use TLS connection to smtp server",
|
||||
description='whether to use TLS connection to smtp server',
|
||||
default=False,
|
||||
)
|
||||
|
||||
SMTP_OPPORTUNISTIC_TLS: bool = Field(
|
||||
description="whether to use opportunistic TLS connection to smtp server",
|
||||
description='whether to use opportunistic TLS connection to smtp server',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -471,22 +371,22 @@ class RagEtlConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
ETL_TYPE: str = Field(
|
||||
description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ",
|
||||
default="dify",
|
||||
description='RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ',
|
||||
default='dify',
|
||||
)
|
||||
|
||||
KEYWORD_DATA_SOURCE_TYPE: str = Field(
|
||||
description="source type for keyword data, default to `database`, available values are `database` .",
|
||||
default="database",
|
||||
description='source type for keyword data, default to `database`, available values are `database` .',
|
||||
default='database',
|
||||
)
|
||||
|
||||
UNSTRUCTURED_API_URL: Optional[str] = Field(
|
||||
description="API URL for Unstructured",
|
||||
description='API URL for Unstructured',
|
||||
default=None,
|
||||
)
|
||||
|
||||
UNSTRUCTURED_API_KEY: Optional[str] = Field(
|
||||
description="API key for Unstructured",
|
||||
description='API key for Unstructured',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -497,12 +397,12 @@ class DataSetConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
CLEAN_DAY_SETTING: PositiveInt = Field(
|
||||
description="interval in days for cleaning up dataset",
|
||||
description='interval in days for cleaning up dataset',
|
||||
default=30,
|
||||
)
|
||||
|
||||
DATASET_OPERATOR_ENABLED: bool = Field(
|
||||
description="whether to enable dataset operator",
|
||||
description='whether to enable dataset operator',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -513,7 +413,7 @@ class WorkspaceConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
INVITE_EXPIRY_HOURS: PositiveInt = Field(
|
||||
description="workspaces invitation expiration in hours",
|
||||
description='workspaces invitation expiration in hours',
|
||||
default=72,
|
||||
)
|
||||
|
||||
@@ -524,79 +424,80 @@ class IndexingConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field(
|
||||
description="max segmentation token length for indexing",
|
||||
description='max segmentation token length for indexing',
|
||||
default=1000,
|
||||
)
|
||||
|
||||
|
||||
class ImageFormatConfig(BaseSettings):
|
||||
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
|
||||
description="multi model send image format, support base64, url, default is base64",
|
||||
default="base64",
|
||||
description='multi model send image format, support base64, url, default is base64',
|
||||
default='base64',
|
||||
)
|
||||
|
||||
|
||||
class CeleryBeatConfig(BaseSettings):
|
||||
CELERY_BEAT_SCHEDULER_TIME: int = Field(
|
||||
description="the time of the celery scheduler, default to 1 day",
|
||||
description='the time of the celery scheduler, default to 1 day',
|
||||
default=1,
|
||||
)
|
||||
|
||||
|
||||
class PositionConfig(BaseSettings):
|
||||
|
||||
POSITION_PROVIDER_PINS: str = Field(
|
||||
description="The heads of model providers",
|
||||
default="",
|
||||
description='The heads of model providers',
|
||||
default='',
|
||||
)
|
||||
|
||||
POSITION_PROVIDER_INCLUDES: str = Field(
|
||||
description="The included model providers",
|
||||
default="",
|
||||
description='The included model providers',
|
||||
default='',
|
||||
)
|
||||
|
||||
POSITION_PROVIDER_EXCLUDES: str = Field(
|
||||
description="The excluded model providers",
|
||||
default="",
|
||||
description='The excluded model providers',
|
||||
default='',
|
||||
)
|
||||
|
||||
POSITION_TOOL_PINS: str = Field(
|
||||
description="The heads of tools",
|
||||
default="",
|
||||
description='The heads of tools',
|
||||
default='',
|
||||
)
|
||||
|
||||
POSITION_TOOL_INCLUDES: str = Field(
|
||||
description="The included tools",
|
||||
default="",
|
||||
description='The included tools',
|
||||
default='',
|
||||
)
|
||||
|
||||
POSITION_TOOL_EXCLUDES: str = Field(
|
||||
description="The excluded tools",
|
||||
default="",
|
||||
description='The excluded tools',
|
||||
default='',
|
||||
)
|
||||
|
||||
@computed_field
|
||||
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
|
||||
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""]
|
||||
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
|
||||
|
||||
@computed_field
|
||||
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
|
||||
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""}
|
||||
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
|
||||
|
||||
@computed_field
|
||||
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
|
||||
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""}
|
||||
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
|
||||
|
||||
@computed_field
|
||||
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
|
||||
return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""]
|
||||
return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
|
||||
|
||||
@computed_field
|
||||
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
|
||||
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""}
|
||||
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
|
||||
|
||||
@computed_field
|
||||
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
@@ -624,6 +525,7 @@ class FeatureConfig(
|
||||
WorkflowConfig,
|
||||
WorkspaceConfig,
|
||||
PositionConfig,
|
||||
|
||||
# hosted services config
|
||||
HostedServiceConfig,
|
||||
CeleryBeatConfig,
|
||||
|
||||
@@ -10,62 +10,62 @@ class HostedOpenAiConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
HOSTED_OPENAI_API_KEY: Optional[str] = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_API_BASE: Optional[str] = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_TRIAL_ENABLED: bool = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
|
||||
description="",
|
||||
default="gpt-3.5-turbo,"
|
||||
"gpt-3.5-turbo-1106,"
|
||||
"gpt-3.5-turbo-instruct,"
|
||||
"gpt-3.5-turbo-16k,"
|
||||
"gpt-3.5-turbo-16k-0613,"
|
||||
"gpt-3.5-turbo-0613,"
|
||||
"gpt-3.5-turbo-0125,"
|
||||
"text-davinci-003",
|
||||
description='',
|
||||
default='gpt-3.5-turbo,'
|
||||
'gpt-3.5-turbo-1106,'
|
||||
'gpt-3.5-turbo-instruct,'
|
||||
'gpt-3.5-turbo-16k,'
|
||||
'gpt-3.5-turbo-16k-0613,'
|
||||
'gpt-3.5-turbo-0613,'
|
||||
'gpt-3.5-turbo-0125,'
|
||||
'text-davinci-003',
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=200,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_PAID_MODELS: str = Field(
|
||||
description="",
|
||||
default="gpt-4,"
|
||||
"gpt-4-turbo-preview,"
|
||||
"gpt-4-turbo-2024-04-09,"
|
||||
"gpt-4-1106-preview,"
|
||||
"gpt-4-0125-preview,"
|
||||
"gpt-3.5-turbo,"
|
||||
"gpt-3.5-turbo-16k,"
|
||||
"gpt-3.5-turbo-16k-0613,"
|
||||
"gpt-3.5-turbo-1106,"
|
||||
"gpt-3.5-turbo-0613,"
|
||||
"gpt-3.5-turbo-0125,"
|
||||
"gpt-3.5-turbo-instruct,"
|
||||
"text-davinci-003",
|
||||
description='',
|
||||
default='gpt-4,'
|
||||
'gpt-4-turbo-preview,'
|
||||
'gpt-4-turbo-2024-04-09,'
|
||||
'gpt-4-1106-preview,'
|
||||
'gpt-4-0125-preview,'
|
||||
'gpt-3.5-turbo,'
|
||||
'gpt-3.5-turbo-16k,'
|
||||
'gpt-3.5-turbo-16k-0613,'
|
||||
'gpt-3.5-turbo-1106,'
|
||||
'gpt-3.5-turbo-0613,'
|
||||
'gpt-3.5-turbo-0125,'
|
||||
'gpt-3.5-turbo-instruct,'
|
||||
'text-davinci-003',
|
||||
)
|
||||
|
||||
|
||||
@@ -75,22 +75,22 @@ class HostedAzureOpenAiConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
HOSTED_AZURE_OPENAI_ENABLED: bool = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=200,
|
||||
)
|
||||
|
||||
@@ -101,27 +101,27 @@ class HostedAnthropicConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=600000,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -132,7 +132,7 @@ class HostedMinmaxConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
HOSTED_MINIMAX_ENABLED: bool = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -143,7 +143,7 @@ class HostedSparkConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
HOSTED_SPARK_ENABLED: bool = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -154,7 +154,7 @@ class HostedZhipuAIConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
HOSTED_ZHIPUAI_ENABLED: bool = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -165,13 +165,13 @@ class HostedModerationConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
HOSTED_MODERATION_ENABLED: bool = Field(
|
||||
description="",
|
||||
description='',
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_MODERATION_PROVIDERS: str = Field(
|
||||
description="",
|
||||
default="",
|
||||
description='',
|
||||
default='',
|
||||
)
|
||||
|
||||
|
||||
@@ -181,15 +181,15 @@ class HostedFetchAppTemplateConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
|
||||
description="the mode for fetching app templates,"
|
||||
" default to remote,"
|
||||
" available values: remote, db, builtin",
|
||||
default="remote",
|
||||
description='the mode for fetching app templates,'
|
||||
' default to remote,'
|
||||
' available values: remote, db, builtin',
|
||||
default='remote',
|
||||
)
|
||||
|
||||
HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field(
|
||||
description="the domain for fetching remote app templates",
|
||||
default="https://tmpl.dify.ai",
|
||||
description='the domain for fetching remote app templates',
|
||||
default='https://tmpl.dify.ai',
|
||||
)
|
||||
|
||||
|
||||
@@ -202,6 +202,7 @@ class HostedServiceConfig(
|
||||
HostedOpenAiConfig,
|
||||
HostedSparkConfig,
|
||||
HostedZhipuAIConfig,
|
||||
|
||||
# moderation
|
||||
HostedModerationConfig,
|
||||
):
|
||||
|
||||
@@ -13,7 +13,6 @@ from configs.middleware.storage.oci_storage_config import OCIStorageConfig
|
||||
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
|
||||
from configs.middleware.vdb.chroma_config import ChromaConfig
|
||||
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
|
||||
from configs.middleware.vdb.milvus_config import MilvusConfig
|
||||
from configs.middleware.vdb.myscale_config import MyScaleConfig
|
||||
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
|
||||
@@ -29,108 +28,108 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig
|
||||
|
||||
class StorageConfig(BaseSettings):
|
||||
STORAGE_TYPE: str = Field(
|
||||
description="storage type,"
|
||||
" default to `local`,"
|
||||
" available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.",
|
||||
default="local",
|
||||
description='storage type,'
|
||||
' default to `local`,'
|
||||
' available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.',
|
||||
default='local',
|
||||
)
|
||||
|
||||
STORAGE_LOCAL_PATH: str = Field(
|
||||
description="local storage path",
|
||||
default="storage",
|
||||
description='local storage path',
|
||||
default='storage',
|
||||
)
|
||||
|
||||
|
||||
class VectorStoreConfig(BaseSettings):
|
||||
VECTOR_STORE: Optional[str] = Field(
|
||||
description="vector store type",
|
||||
description='vector store type',
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class KeywordStoreConfig(BaseSettings):
|
||||
KEYWORD_STORE: str = Field(
|
||||
description="keyword store type",
|
||||
default="jieba",
|
||||
description='keyword store type',
|
||||
default='jieba',
|
||||
)
|
||||
|
||||
|
||||
class DatabaseConfig:
|
||||
DB_HOST: str = Field(
|
||||
description="db host",
|
||||
default="localhost",
|
||||
description='db host',
|
||||
default='localhost',
|
||||
)
|
||||
|
||||
DB_PORT: PositiveInt = Field(
|
||||
description="db port",
|
||||
description='db port',
|
||||
default=5432,
|
||||
)
|
||||
|
||||
DB_USERNAME: str = Field(
|
||||
description="db username",
|
||||
default="postgres",
|
||||
description='db username',
|
||||
default='postgres',
|
||||
)
|
||||
|
||||
DB_PASSWORD: str = Field(
|
||||
description="db password",
|
||||
default="",
|
||||
description='db password',
|
||||
default='',
|
||||
)
|
||||
|
||||
DB_DATABASE: str = Field(
|
||||
description="db database",
|
||||
default="dify",
|
||||
description='db database',
|
||||
default='dify',
|
||||
)
|
||||
|
||||
DB_CHARSET: str = Field(
|
||||
description="db charset",
|
||||
default="",
|
||||
description='db charset',
|
||||
default='',
|
||||
)
|
||||
|
||||
DB_EXTRAS: str = Field(
|
||||
description="db extras options. Example: keepalives_idle=60&keepalives=1",
|
||||
default="",
|
||||
description='db extras options. Example: keepalives_idle=60&keepalives=1',
|
||||
default='',
|
||||
)
|
||||
|
||||
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
|
||||
description="db uri scheme",
|
||||
default="postgresql",
|
||||
description='db uri scheme',
|
||||
default='postgresql',
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||
db_extras = (
|
||||
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS
|
||||
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}"
|
||||
if self.DB_CHARSET
|
||||
else self.DB_EXTRAS
|
||||
).strip("&")
|
||||
db_extras = f"?{db_extras}" if db_extras else ""
|
||||
return (
|
||||
f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
|
||||
f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
|
||||
f"{db_extras}"
|
||||
)
|
||||
return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
|
||||
f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
|
||||
f"{db_extras}")
|
||||
|
||||
SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
|
||||
description="pool size of SqlAlchemy",
|
||||
description='pool size of SqlAlchemy',
|
||||
default=30,
|
||||
)
|
||||
|
||||
SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
|
||||
description="max overflows for SqlAlchemy",
|
||||
description='max overflows for SqlAlchemy',
|
||||
default=10,
|
||||
)
|
||||
|
||||
SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
|
||||
description="SqlAlchemy pool recycle",
|
||||
description='SqlAlchemy pool recycle',
|
||||
default=3600,
|
||||
)
|
||||
|
||||
SQLALCHEMY_POOL_PRE_PING: bool = Field(
|
||||
description="whether to enable pool pre-ping in SqlAlchemy",
|
||||
description='whether to enable pool pre-ping in SqlAlchemy',
|
||||
default=False,
|
||||
)
|
||||
|
||||
SQLALCHEMY_ECHO: bool | str = Field(
|
||||
description="whether to enable SqlAlchemy echo",
|
||||
description='whether to enable SqlAlchemy echo',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -138,38 +137,35 @@ class DatabaseConfig:
|
||||
@property
|
||||
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
|
||||
return {
|
||||
"pool_size": self.SQLALCHEMY_POOL_SIZE,
|
||||
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
|
||||
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
|
||||
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
|
||||
"connect_args": {"options": "-c timezone=UTC"},
|
||||
'pool_size': self.SQLALCHEMY_POOL_SIZE,
|
||||
'max_overflow': self.SQLALCHEMY_MAX_OVERFLOW,
|
||||
'pool_recycle': self.SQLALCHEMY_POOL_RECYCLE,
|
||||
'pool_pre_ping': self.SQLALCHEMY_POOL_PRE_PING,
|
||||
'connect_args': {'options': '-c timezone=UTC'},
|
||||
}
|
||||
|
||||
|
||||
class CeleryConfig(DatabaseConfig):
|
||||
CELERY_BACKEND: str = Field(
|
||||
description="Celery backend, available values are `database`, `redis`",
|
||||
default="database",
|
||||
description='Celery backend, available values are `database`, `redis`',
|
||||
default='database',
|
||||
)
|
||||
|
||||
CELERY_BROKER_URL: Optional[str] = Field(
|
||||
description="CELERY_BROKER_URL",
|
||||
description='CELERY_BROKER_URL',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def CELERY_RESULT_BACKEND(self) -> str | None:
|
||||
return (
|
||||
"db+{}".format(self.SQLALCHEMY_DATABASE_URI)
|
||||
if self.CELERY_BACKEND == "database"
|
||||
else self.CELERY_BROKER_URL
|
||||
)
|
||||
return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
|
||||
if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def BROKER_USE_SSL(self) -> bool:
|
||||
return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False
|
||||
return self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False
|
||||
|
||||
|
||||
class MiddlewareConfig(
|
||||
@@ -178,6 +174,7 @@ class MiddlewareConfig(
|
||||
DatabaseConfig,
|
||||
KeywordStoreConfig,
|
||||
RedisConfig,
|
||||
|
||||
# configs of storage and storage providers
|
||||
StorageConfig,
|
||||
AliyunOSSStorageConfig,
|
||||
@@ -186,6 +183,7 @@ class MiddlewareConfig(
|
||||
TencentCloudCOSStorageConfig,
|
||||
S3StorageConfig,
|
||||
OCIStorageConfig,
|
||||
|
||||
# configs of vdb and vdb providers
|
||||
VectorStoreConfig,
|
||||
AnalyticdbConfig,
|
||||
@@ -201,6 +199,5 @@ class MiddlewareConfig(
|
||||
TencentVectorDBConfig,
|
||||
TiDBVectorConfig,
|
||||
WeaviateConfig,
|
||||
ElasticsearchConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
15
api/configs/middleware/cache/redis_config.py
vendored
15
api/configs/middleware/cache/redis_config.py
vendored
@@ -8,33 +8,32 @@ class RedisConfig(BaseSettings):
|
||||
"""
|
||||
Redis configs
|
||||
"""
|
||||
|
||||
REDIS_HOST: str = Field(
|
||||
description="Redis host",
|
||||
default="localhost",
|
||||
description='Redis host',
|
||||
default='localhost',
|
||||
)
|
||||
|
||||
REDIS_PORT: PositiveInt = Field(
|
||||
description="Redis port",
|
||||
description='Redis port',
|
||||
default=6379,
|
||||
)
|
||||
|
||||
REDIS_USERNAME: Optional[str] = Field(
|
||||
description="Redis username",
|
||||
description='Redis username',
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_PASSWORD: Optional[str] = Field(
|
||||
description="Redis password",
|
||||
description='Redis password',
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_DB: NonNegativeInt = Field(
|
||||
description="Redis database id, default to 0",
|
||||
description='Redis database id, default to 0',
|
||||
default=0,
|
||||
)
|
||||
|
||||
REDIS_USE_SSL: bool = Field(
|
||||
description="whether to use SSL for Redis connection",
|
||||
description='whether to use SSL for Redis connection',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -10,36 +10,31 @@ class AliyunOSSStorageConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Aliyun OSS bucket name",
|
||||
description='Aliyun OSS bucket name',
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field(
|
||||
description="Aliyun OSS access key",
|
||||
description='Aliyun OSS access key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_SECRET_KEY: Optional[str] = Field(
|
||||
description="Aliyun OSS secret key",
|
||||
description='Aliyun OSS secret key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_ENDPOINT: Optional[str] = Field(
|
||||
description="Aliyun OSS endpoint URL",
|
||||
description='Aliyun OSS endpoint URL',
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_REGION: Optional[str] = Field(
|
||||
description="Aliyun OSS region",
|
||||
description='Aliyun OSS region',
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field(
|
||||
description="Aliyun OSS authentication version",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_PATH: Optional[str] = Field(
|
||||
description="Aliyun OSS path",
|
||||
description='Aliyun OSS authentication version',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -10,36 +10,36 @@ class S3StorageConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
S3_ENDPOINT: Optional[str] = Field(
|
||||
description="S3 storage endpoint",
|
||||
description='S3 storage endpoint',
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_REGION: Optional[str] = Field(
|
||||
description="S3 storage region",
|
||||
description='S3 storage region',
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_BUCKET_NAME: Optional[str] = Field(
|
||||
description="S3 storage bucket name",
|
||||
description='S3 storage bucket name',
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_ACCESS_KEY: Optional[str] = Field(
|
||||
description="S3 storage access key",
|
||||
description='S3 storage access key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_SECRET_KEY: Optional[str] = Field(
|
||||
description="S3 storage secret key",
|
||||
description='S3 storage secret key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_ADDRESS_STYLE: str = Field(
|
||||
description="S3 storage address style",
|
||||
default="auto",
|
||||
description='S3 storage address style',
|
||||
default='auto',
|
||||
)
|
||||
|
||||
S3_USE_AWS_MANAGED_IAM: bool = Field(
|
||||
description="whether to use aws managed IAM for S3",
|
||||
description='whether to use aws managed IAM for S3',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -10,21 +10,21 @@ class AzureBlobStorageConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field(
|
||||
description="Azure Blob account name",
|
||||
description='Azure Blob account name',
|
||||
default=None,
|
||||
)
|
||||
|
||||
AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field(
|
||||
description="Azure Blob account key",
|
||||
description='Azure Blob account key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field(
|
||||
description="Azure Blob container name",
|
||||
description='Azure Blob container name',
|
||||
default=None,
|
||||
)
|
||||
|
||||
AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field(
|
||||
description="Azure Blob account URL",
|
||||
description='Azure Blob account URL',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -10,11 +10,11 @@ class GoogleCloudStorageConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Google Cloud storage bucket name",
|
||||
description='Google Cloud storage bucket name',
|
||||
default=None,
|
||||
)
|
||||
|
||||
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field(
|
||||
description="Google Cloud storage service account json base64",
|
||||
description='Google Cloud storage service account json base64',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -10,26 +10,27 @@ class OCIStorageConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
OCI_ENDPOINT: Optional[str] = Field(
|
||||
description="OCI storage endpoint",
|
||||
description='OCI storage endpoint',
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_REGION: Optional[str] = Field(
|
||||
description="OCI storage region",
|
||||
description='OCI storage region',
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_BUCKET_NAME: Optional[str] = Field(
|
||||
description="OCI storage bucket name",
|
||||
description='OCI storage bucket name',
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_ACCESS_KEY: Optional[str] = Field(
|
||||
description="OCI storage access key",
|
||||
description='OCI storage access key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_SECRET_KEY: Optional[str] = Field(
|
||||
description="OCI storage secret key",
|
||||
description='OCI storage secret key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,26 +10,26 @@ class TencentCloudCOSStorageConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
TENCENT_COS_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Tencent Cloud COS bucket name",
|
||||
description='Tencent Cloud COS bucket name',
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_REGION: Optional[str] = Field(
|
||||
description="Tencent Cloud COS region",
|
||||
description='Tencent Cloud COS region',
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_SECRET_ID: Optional[str] = Field(
|
||||
description="Tencent Cloud COS secret id",
|
||||
description='Tencent Cloud COS secret id',
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_SECRET_KEY: Optional[str] = Field(
|
||||
description="Tencent Cloud COS secret key",
|
||||
description='Tencent Cloud COS secret key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_SCHEME: Optional[str] = Field(
|
||||
description="Tencent Cloud COS scheme",
|
||||
description='Tencent Cloud COS scheme',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -10,28 +10,35 @@ class AnalyticdbConfig(BaseModel):
|
||||
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
|
||||
"""
|
||||
|
||||
ANALYTICDB_KEY_ID: Optional[str] = Field(
|
||||
default=None, description="The Access Key ID provided by Alibaba Cloud for authentication."
|
||||
)
|
||||
ANALYTICDB_KEY_SECRET: Optional[str] = Field(
|
||||
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access."
|
||||
)
|
||||
ANALYTICDB_REGION_ID: Optional[str] = Field(
|
||||
default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
|
||||
)
|
||||
ANALYTICDB_INSTANCE_ID: Optional[str] = Field(
|
||||
ANALYTICDB_KEY_ID : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..",
|
||||
description="The Access Key ID provided by Alibaba Cloud for authentication."
|
||||
)
|
||||
ANALYTICDB_ACCOUNT: Optional[str] = Field(
|
||||
default=None, description="The account name used to log in to the AnalyticDB instance."
|
||||
ANALYTICDB_KEY_SECRET : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Secret Access Key corresponding to the Access Key ID for secure access."
|
||||
)
|
||||
ANALYTICDB_PASSWORD: Optional[str] = Field(
|
||||
default=None, description="The password associated with the AnalyticDB account for authentication."
|
||||
ANALYTICDB_REGION_ID : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE: Optional[str] = Field(
|
||||
default=None, description="The namespace within AnalyticDB for schema isolation."
|
||||
ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field(
|
||||
default=None, description="The password for accessing the specified namespace within the AnalyticDB instance."
|
||||
ANALYTICDB_ACCOUNT : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The account name used to log in to the AnalyticDB instance."
|
||||
)
|
||||
ANALYTICDB_PASSWORD : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The password associated with the AnalyticDB account for authentication."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The namespace within AnalyticDB for schema isolation."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The password for accessing the specified namespace within the AnalyticDB instance."
|
||||
)
|
||||
|
||||
@@ -10,31 +10,31 @@ class ChromaConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
CHROMA_HOST: Optional[str] = Field(
|
||||
description="Chroma host",
|
||||
description='Chroma host',
|
||||
default=None,
|
||||
)
|
||||
|
||||
CHROMA_PORT: PositiveInt = Field(
|
||||
description="Chroma port",
|
||||
description='Chroma port',
|
||||
default=8000,
|
||||
)
|
||||
|
||||
CHROMA_TENANT: Optional[str] = Field(
|
||||
description="Chroma database",
|
||||
description='Chroma database',
|
||||
default=None,
|
||||
)
|
||||
|
||||
CHROMA_DATABASE: Optional[str] = Field(
|
||||
description="Chroma database",
|
||||
description='Chroma database',
|
||||
default=None,
|
||||
)
|
||||
|
||||
CHROMA_AUTH_PROVIDER: Optional[str] = Field(
|
||||
description="Chroma authentication provider",
|
||||
description='Chroma authentication provider',
|
||||
default=None,
|
||||
)
|
||||
|
||||
CHROMA_AUTH_CREDENTIALS: Optional[str] = Field(
|
||||
description="Chroma authentication credentials",
|
||||
description='Chroma authentication credentials',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ElasticsearchConfig(BaseSettings):
|
||||
"""
|
||||
Elasticsearch configs
|
||||
"""
|
||||
|
||||
ELASTICSEARCH_HOST: Optional[str] = Field(
|
||||
description="Elasticsearch host",
|
||||
default="127.0.0.1",
|
||||
)
|
||||
|
||||
ELASTICSEARCH_PORT: PositiveInt = Field(
|
||||
description="Elasticsearch port",
|
||||
default=9200,
|
||||
)
|
||||
|
||||
ELASTICSEARCH_USERNAME: Optional[str] = Field(
|
||||
description="Elasticsearch username",
|
||||
default="elastic",
|
||||
)
|
||||
|
||||
ELASTICSEARCH_PASSWORD: Optional[str] = Field(
|
||||
description="Elasticsearch password",
|
||||
default="elastic",
|
||||
)
|
||||
@@ -10,31 +10,31 @@ class MilvusConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
MILVUS_HOST: Optional[str] = Field(
|
||||
description="Milvus host",
|
||||
description='Milvus host',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_PORT: PositiveInt = Field(
|
||||
description="Milvus RestFul API port",
|
||||
description='Milvus RestFul API port',
|
||||
default=9091,
|
||||
)
|
||||
|
||||
MILVUS_USER: Optional[str] = Field(
|
||||
description="Milvus user",
|
||||
description='Milvus user',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_PASSWORD: Optional[str] = Field(
|
||||
description="Milvus password",
|
||||
description='Milvus password',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_SECURE: bool = Field(
|
||||
description="whether to use SSL connection for Milvus",
|
||||
description='whether to use SSL connection for Milvus',
|
||||
default=False,
|
||||
)
|
||||
|
||||
MILVUS_DATABASE: str = Field(
|
||||
description="Milvus database, default to `default`",
|
||||
default="default",
|
||||
description='Milvus database, default to `default`',
|
||||
default='default',
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
|
||||
|
||||
@@ -7,31 +8,31 @@ class MyScaleConfig(BaseModel):
|
||||
"""
|
||||
|
||||
MYSCALE_HOST: str = Field(
|
||||
description="MyScale host",
|
||||
default="localhost",
|
||||
description='MyScale host',
|
||||
default='localhost',
|
||||
)
|
||||
|
||||
MYSCALE_PORT: PositiveInt = Field(
|
||||
description="MyScale port",
|
||||
description='MyScale port',
|
||||
default=8123,
|
||||
)
|
||||
|
||||
MYSCALE_USER: str = Field(
|
||||
description="MyScale user",
|
||||
default="default",
|
||||
description='MyScale user',
|
||||
default='default',
|
||||
)
|
||||
|
||||
MYSCALE_PASSWORD: str = Field(
|
||||
description="MyScale password",
|
||||
default="",
|
||||
description='MyScale password',
|
||||
default='',
|
||||
)
|
||||
|
||||
MYSCALE_DATABASE: str = Field(
|
||||
description="MyScale database name",
|
||||
default="default",
|
||||
description='MyScale database name',
|
||||
default='default',
|
||||
)
|
||||
|
||||
MYSCALE_FTS_PARAMS: str = Field(
|
||||
description="MyScale fts index parameters",
|
||||
default="",
|
||||
description='MyScale fts index parameters',
|
||||
default='',
|
||||
)
|
||||
|
||||
@@ -10,26 +10,26 @@ class OpenSearchConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
OPENSEARCH_HOST: Optional[str] = Field(
|
||||
description="OpenSearch host",
|
||||
description='OpenSearch host',
|
||||
default=None,
|
||||
)
|
||||
|
||||
OPENSEARCH_PORT: PositiveInt = Field(
|
||||
description="OpenSearch port",
|
||||
description='OpenSearch port',
|
||||
default=9200,
|
||||
)
|
||||
|
||||
OPENSEARCH_USER: Optional[str] = Field(
|
||||
description="OpenSearch user",
|
||||
description='OpenSearch user',
|
||||
default=None,
|
||||
)
|
||||
|
||||
OPENSEARCH_PASSWORD: Optional[str] = Field(
|
||||
description="OpenSearch password",
|
||||
description='OpenSearch password',
|
||||
default=None,
|
||||
)
|
||||
|
||||
OPENSEARCH_SECURE: bool = Field(
|
||||
description="whether to use SSL connection for OpenSearch",
|
||||
description='whether to use SSL connection for OpenSearch',
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -10,26 +10,26 @@ class OracleConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
ORACLE_HOST: Optional[str] = Field(
|
||||
description="ORACLE host",
|
||||
description='ORACLE host',
|
||||
default=None,
|
||||
)
|
||||
|
||||
ORACLE_PORT: Optional[PositiveInt] = Field(
|
||||
description="ORACLE port",
|
||||
description='ORACLE port',
|
||||
default=1521,
|
||||
)
|
||||
|
||||
ORACLE_USER: Optional[str] = Field(
|
||||
description="ORACLE user",
|
||||
description='ORACLE user',
|
||||
default=None,
|
||||
)
|
||||
|
||||
ORACLE_PASSWORD: Optional[str] = Field(
|
||||
description="ORACLE password",
|
||||
description='ORACLE password',
|
||||
default=None,
|
||||
)
|
||||
|
||||
ORACLE_DATABASE: Optional[str] = Field(
|
||||
description="ORACLE database",
|
||||
description='ORACLE database',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -10,26 +10,26 @@ class PGVectorConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
PGVECTOR_HOST: Optional[str] = Field(
|
||||
description="PGVector host",
|
||||
description='PGVector host',
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTOR_PORT: Optional[PositiveInt] = Field(
|
||||
description="PGVector port",
|
||||
description='PGVector port',
|
||||
default=5433,
|
||||
)
|
||||
|
||||
PGVECTOR_USER: Optional[str] = Field(
|
||||
description="PGVector user",
|
||||
description='PGVector user',
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTOR_PASSWORD: Optional[str] = Field(
|
||||
description="PGVector password",
|
||||
description='PGVector password',
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTOR_DATABASE: Optional[str] = Field(
|
||||
description="PGVector database",
|
||||
description='PGVector database',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -10,26 +10,26 @@ class PGVectoRSConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
PGVECTO_RS_HOST: Optional[str] = Field(
|
||||
description="PGVectoRS host",
|
||||
description='PGVectoRS host',
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
|
||||
description="PGVectoRS port",
|
||||
description='PGVectoRS port',
|
||||
default=5431,
|
||||
)
|
||||
|
||||
PGVECTO_RS_USER: Optional[str] = Field(
|
||||
description="PGVectoRS user",
|
||||
description='PGVectoRS user',
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTO_RS_PASSWORD: Optional[str] = Field(
|
||||
description="PGVectoRS password",
|
||||
description='PGVectoRS password',
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTO_RS_DATABASE: Optional[str] = Field(
|
||||
description="PGVectoRS database",
|
||||
description='PGVectoRS database',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -10,26 +10,26 @@ class QdrantConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
QDRANT_URL: Optional[str] = Field(
|
||||
description="Qdrant url",
|
||||
description='Qdrant url',
|
||||
default=None,
|
||||
)
|
||||
|
||||
QDRANT_API_KEY: Optional[str] = Field(
|
||||
description="Qdrant api key",
|
||||
description='Qdrant api key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
|
||||
description="Qdrant client timeout in seconds",
|
||||
description='Qdrant client timeout in seconds',
|
||||
default=20,
|
||||
)
|
||||
|
||||
QDRANT_GRPC_ENABLED: bool = Field(
|
||||
description="whether enable grpc support for Qdrant connection",
|
||||
description='whether enable grpc support for Qdrant connection',
|
||||
default=False,
|
||||
)
|
||||
|
||||
QDRANT_GRPC_PORT: PositiveInt = Field(
|
||||
description="Qdrant grpc port",
|
||||
description='Qdrant grpc port',
|
||||
default=6334,
|
||||
)
|
||||
|
||||
@@ -10,26 +10,26 @@ class RelytConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
RELYT_HOST: Optional[str] = Field(
|
||||
description="Relyt host",
|
||||
description='Relyt host',
|
||||
default=None,
|
||||
)
|
||||
|
||||
RELYT_PORT: PositiveInt = Field(
|
||||
description="Relyt port",
|
||||
description='Relyt port',
|
||||
default=9200,
|
||||
)
|
||||
|
||||
RELYT_USER: Optional[str] = Field(
|
||||
description="Relyt user",
|
||||
description='Relyt user',
|
||||
default=None,
|
||||
)
|
||||
|
||||
RELYT_PASSWORD: Optional[str] = Field(
|
||||
description="Relyt password",
|
||||
description='Relyt password',
|
||||
default=None,
|
||||
)
|
||||
|
||||
RELYT_DATABASE: Optional[str] = Field(
|
||||
description="Relyt database",
|
||||
default="default",
|
||||
description='Relyt database',
|
||||
default='default',
|
||||
)
|
||||
|
||||
@@ -10,41 +10,41 @@ class TencentVectorDBConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
TENCENT_VECTOR_DB_URL: Optional[str] = Field(
|
||||
description="Tencent Vector URL",
|
||||
description='Tencent Vector URL',
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field(
|
||||
description="Tencent Vector API key",
|
||||
description='Tencent Vector API key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field(
|
||||
description="Tencent Vector timeout in seconds",
|
||||
description='Tencent Vector timeout in seconds',
|
||||
default=30,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
|
||||
description="Tencent Vector username",
|
||||
description='Tencent Vector username',
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field(
|
||||
description="Tencent Vector password",
|
||||
description='Tencent Vector password',
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_SHARD: PositiveInt = Field(
|
||||
description="Tencent Vector sharding number",
|
||||
description='Tencent Vector sharding number',
|
||||
default=1,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
|
||||
description="Tencent Vector replicas",
|
||||
description='Tencent Vector replicas',
|
||||
default=2,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
|
||||
description="Tencent Vector Database",
|
||||
description='Tencent Vector Database',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -10,26 +10,26 @@ class TiDBVectorConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
TIDB_VECTOR_HOST: Optional[str] = Field(
|
||||
description="TiDB Vector host",
|
||||
description='TiDB Vector host',
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_VECTOR_PORT: Optional[PositiveInt] = Field(
|
||||
description="TiDB Vector port",
|
||||
description='TiDB Vector port',
|
||||
default=4000,
|
||||
)
|
||||
|
||||
TIDB_VECTOR_USER: Optional[str] = Field(
|
||||
description="TiDB Vector user",
|
||||
description='TiDB Vector user',
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_VECTOR_PASSWORD: Optional[str] = Field(
|
||||
description="TiDB Vector password",
|
||||
description='TiDB Vector password',
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_VECTOR_DATABASE: Optional[str] = Field(
|
||||
description="TiDB Vector database",
|
||||
description='TiDB Vector database',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -10,21 +10,21 @@ class WeaviateConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
WEAVIATE_ENDPOINT: Optional[str] = Field(
|
||||
description="Weaviate endpoint URL",
|
||||
description='Weaviate endpoint URL',
|
||||
default=None,
|
||||
)
|
||||
|
||||
WEAVIATE_API_KEY: Optional[str] = Field(
|
||||
description="Weaviate API key",
|
||||
description='Weaviate API key',
|
||||
default=None,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENABLED: bool = Field(
|
||||
description="whether to enable gRPC for Weaviate connection",
|
||||
description='whether to enable gRPC for Weaviate connection',
|
||||
default=True,
|
||||
)
|
||||
|
||||
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Weaviate batch size",
|
||||
description='Weaviate batch size',
|
||||
default=100,
|
||||
)
|
||||
|
||||
@@ -8,11 +8,11 @@ class PackagingInfo(BaseSettings):
|
||||
"""
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.8.0-beta1",
|
||||
description='Dify version',
|
||||
default='0.7.1',
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
description="SHA-1 checksum of the git commit used to build the app",
|
||||
default="",
|
||||
default='',
|
||||
)
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("console", __name__, url_prefix="/console/api")
|
||||
bp = Blueprint('console', __name__, url_prefix='/console/api')
|
||||
api = ExternalApi(bp)
|
||||
|
||||
# Import other controllers
|
||||
|
||||
@@ -15,24 +15,24 @@ from models.model import App, InstalledApp, RecommendedApp
|
||||
def admin_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not os.getenv("ADMIN_API_KEY"):
|
||||
raise Unauthorized("API key is invalid.")
|
||||
if not os.getenv('ADMIN_API_KEY'):
|
||||
raise Unauthorized('API key is invalid.')
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
auth_header = request.headers.get('Authorization')
|
||||
if auth_header is None:
|
||||
raise Unauthorized("Authorization header is missing.")
|
||||
raise Unauthorized('Authorization header is missing.')
|
||||
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
if ' ' not in auth_header:
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
|
||||
if os.getenv("ADMIN_API_KEY") != auth_token:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
if os.getenv('ADMIN_API_KEY') != auth_token:
|
||||
raise Unauthorized('API key is invalid.')
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@@ -44,41 +44,37 @@ class InsertExploreAppListApi(Resource):
|
||||
@admin_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("desc", type=str, location="json")
|
||||
parser.add_argument("copyright", type=str, location="json")
|
||||
parser.add_argument("privacy_policy", type=str, location="json")
|
||||
parser.add_argument("custom_disclaimer", type=str, location="json")
|
||||
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
parser.add_argument('app_id', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('desc', type=str, location='json')
|
||||
parser.add_argument('copyright', type=str, location='json')
|
||||
parser.add_argument('privacy_policy', type=str, location='json')
|
||||
parser.add_argument('custom_disclaimer', type=str, location='json')
|
||||
parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json')
|
||||
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('position', type=int, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app = App.query.filter(App.id == args["app_id"]).first()
|
||||
app = App.query.filter(App.id == args['app_id']).first()
|
||||
if not app:
|
||||
raise NotFound(f'App \'{args["app_id"]}\' is not found')
|
||||
|
||||
site = app.site
|
||||
if not site:
|
||||
desc = args["desc"] if args["desc"] else ""
|
||||
copy_right = args["copyright"] if args["copyright"] else ""
|
||||
privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
|
||||
custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
|
||||
desc = args['desc'] if args['desc'] else ''
|
||||
copy_right = args['copyright'] if args['copyright'] else ''
|
||||
privacy_policy = args['privacy_policy'] if args['privacy_policy'] else ''
|
||||
custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else ''
|
||||
else:
|
||||
desc = site.description if site.description else args["desc"] if args["desc"] else ""
|
||||
copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
|
||||
privacy_policy = (
|
||||
site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
|
||||
)
|
||||
custom_disclaimer = (
|
||||
site.custom_disclaimer
|
||||
if site.custom_disclaimer
|
||||
else args["custom_disclaimer"]
|
||||
if args["custom_disclaimer"]
|
||||
else ""
|
||||
)
|
||||
desc = site.description if site.description else \
|
||||
args['desc'] if args['desc'] else ''
|
||||
copy_right = site.copyright if site.copyright else \
|
||||
args['copyright'] if args['copyright'] else ''
|
||||
privacy_policy = site.privacy_policy if site.privacy_policy else \
|
||||
args['privacy_policy'] if args['privacy_policy'] else ''
|
||||
custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \
|
||||
args['custom_disclaimer'] if args['custom_disclaimer'] else ''
|
||||
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
|
||||
|
||||
if not recommended_app:
|
||||
recommended_app = RecommendedApp(
|
||||
@@ -87,9 +83,9 @@ class InsertExploreAppListApi(Resource):
|
||||
copyright=copy_right,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
language=args["language"],
|
||||
category=args["category"],
|
||||
position=args["position"],
|
||||
language=args['language'],
|
||||
category=args['category'],
|
||||
position=args['position']
|
||||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
@@ -97,21 +93,21 @@ class InsertExploreAppListApi(Resource):
|
||||
app.is_public = True
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 201
|
||||
return {'result': 'success'}, 201
|
||||
else:
|
||||
recommended_app.description = desc
|
||||
recommended_app.copyright = copy_right
|
||||
recommended_app.privacy_policy = privacy_policy
|
||||
recommended_app.custom_disclaimer = custom_disclaimer
|
||||
recommended_app.language = args["language"]
|
||||
recommended_app.category = args["category"]
|
||||
recommended_app.position = args["position"]
|
||||
recommended_app.language = args['language']
|
||||
recommended_app.category = args['category']
|
||||
recommended_app.position = args['position']
|
||||
|
||||
app.is_public = True
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class InsertExploreAppApi(Resource):
|
||||
@@ -120,14 +116,15 @@ class InsertExploreAppApi(Resource):
|
||||
def delete(self, app_id):
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
|
||||
if not recommended_app:
|
||||
return {"result": "success"}, 204
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
app = App.query.filter(App.id == recommended_app.app_id).first()
|
||||
if app:
|
||||
app.is_public = False
|
||||
|
||||
installed_apps = InstalledApp.query.filter(
|
||||
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
|
||||
).all()
|
||||
|
||||
for installed_app in installed_apps:
|
||||
@@ -136,8 +133,8 @@ class InsertExploreAppApi(Resource):
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps")
|
||||
api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>")
|
||||
api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps')
|
||||
api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/<uuid:app_id>')
|
||||
|
||||
@@ -14,21 +14,26 @@ from .setup import setup_required
|
||||
from .wraps import account_initialization_required
|
||||
|
||||
api_key_fields = {
|
||||
"id": fields.String,
|
||||
"type": fields.String,
|
||||
"token": fields.String,
|
||||
"last_used_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
'id': fields.String,
|
||||
'type': fields.String,
|
||||
'token': fields.String,
|
||||
'last_used_at': TimestampField,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
|
||||
api_key_list = {
|
||||
'data': fields.List(fields.Nested(api_key_fields), attribute="items")
|
||||
}
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
|
||||
resource = resource_model.query.filter_by(
|
||||
id=resource_id, tenant_id=tenant_id
|
||||
).first()
|
||||
|
||||
if resource is None:
|
||||
flask_restful.abort(404, message=f"{resource_model.__name__} not found.")
|
||||
flask_restful.abort(
|
||||
404, message=f"{resource_model.__name__} not found.")
|
||||
|
||||
return resource
|
||||
|
||||
@@ -45,32 +50,30 @@ class BaseApiKeyListResource(Resource):
|
||||
@marshal_with(api_key_list)
|
||||
def get(self, resource_id):
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
keys = (
|
||||
db.session.query(ApiToken)
|
||||
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||
.all()
|
||||
)
|
||||
_get_resource(resource_id, current_user.current_tenant_id,
|
||||
self.resource_model)
|
||||
keys = db.session.query(ApiToken). \
|
||||
filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
|
||||
all()
|
||||
return {"items": keys}
|
||||
|
||||
@marshal_with(api_key_fields)
|
||||
def post(self, resource_id):
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
_get_resource(resource_id, current_user.current_tenant_id,
|
||||
self.resource_model)
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||
.count()
|
||||
)
|
||||
current_key_count = db.session.query(ApiToken). \
|
||||
filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
|
||||
count()
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
flask_restful.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code="max_keys_exceeded",
|
||||
code='max_keys_exceeded'
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
@@ -94,78 +97,79 @@ class BaseApiKeyResource(Resource):
|
||||
def delete(self, resource_id, api_key_id):
|
||||
resource_id = str(resource_id)
|
||||
api_key_id = str(api_key_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
_get_resource(resource_id, current_user.current_tenant_id,
|
||||
self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
key = (
|
||||
db.session.query(ApiToken)
|
||||
.filter(
|
||||
getattr(ApiToken, self.resource_id_field) == resource_id,
|
||||
ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
key = db.session.query(ApiToken). \
|
||||
filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). \
|
||||
first()
|
||||
|
||||
if key is None:
|
||||
flask_restful.abort(404, message="API key not found")
|
||||
flask_restful.abort(404, message='API key not found')
|
||||
|
||||
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
resp.headers['Access-Control-Allow-Origin'] = '*'
|
||||
resp.headers['Access-Control-Allow-Credentials'] = 'true'
|
||||
return resp
|
||||
|
||||
resource_type = "app"
|
||||
resource_type = 'app'
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
resource_id_field = 'app_id'
|
||||
token_prefix = 'app-'
|
||||
|
||||
|
||||
class AppApiKeyResource(BaseApiKeyResource):
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
resp.headers['Access-Control-Allow-Origin'] = '*'
|
||||
resp.headers['Access-Control-Allow-Credentials'] = 'true'
|
||||
return resp
|
||||
|
||||
resource_type = "app"
|
||||
resource_type = 'app'
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
resource_id_field = 'app_id'
|
||||
|
||||
|
||||
class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
resp.headers['Access-Control-Allow-Origin'] = '*'
|
||||
resp.headers['Access-Control-Allow-Credentials'] = 'true'
|
||||
return resp
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_type = 'dataset'
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
token_prefix = "ds-"
|
||||
resource_id_field = 'dataset_id'
|
||||
token_prefix = 'ds-'
|
||||
|
||||
|
||||
class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
resp.headers['Access-Control-Allow-Origin'] = '*'
|
||||
resp.headers['Access-Control-Allow-Credentials'] = 'true'
|
||||
return resp
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_type = 'dataset'
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
resource_id_field = 'dataset_id'
|
||||
|
||||
|
||||
api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys")
|
||||
api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||
api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys")
|
||||
api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||
api.add_resource(AppApiKeyListResource, '/apps/<uuid:resource_id>/api-keys')
|
||||
api.add_resource(AppApiKeyResource,
|
||||
'/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
|
||||
api.add_resource(DatasetApiKeyListResource,
|
||||
'/datasets/<uuid:resource_id>/api-keys')
|
||||
api.add_resource(DatasetApiKeyResource,
|
||||
'/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
|
||||
|
||||
@@ -8,18 +8,19 @@ from services.advanced_prompt_template_service import AdvancedPromptTemplateServ
|
||||
|
||||
|
||||
class AdvancedPromptTemplateList(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("app_mode", type=str, required=True, location="args")
|
||||
parser.add_argument("model_mode", type=str, required=True, location="args")
|
||||
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
|
||||
parser.add_argument("model_name", type=str, required=True, location="args")
|
||||
parser.add_argument('app_mode', type=str, required=True, location='args')
|
||||
parser.add_argument('model_mode', type=str, required=True, location='args')
|
||||
parser.add_argument('has_context', type=str, required=False, default='true', location='args')
|
||||
parser.add_argument('model_name', type=str, required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
|
||||
api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates")
|
||||
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
|
||||
@@ -18,12 +18,15 @@ class AgentLogApi(Resource):
|
||||
def get(self, app_model):
|
||||
"""Get agent logs"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
|
||||
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
|
||||
parser.add_argument('message_id', type=uuid_value, required=True, location='args')
|
||||
parser.add_argument('conversation_id', type=uuid_value, required=True, location='args')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
||||
|
||||
|
||||
api.add_resource(AgentLogApi, "/apps/<uuid:app_id>/agent/logs")
|
||||
return AgentService.get_agent_logs(
|
||||
app_model,
|
||||
args['conversation_id'],
|
||||
args['message_id']
|
||||
)
|
||||
|
||||
api.add_resource(AgentLogApi, '/apps/<uuid:app_id>/agent/logs')
|
||||
@@ -21,23 +21,23 @@ class AnnotationReplyActionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def post(self, app_id, action):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||
parser.add_argument('score_threshold', required=True, type=float, location='json')
|
||||
parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
|
||||
parser.add_argument('embedding_model_name', required=True, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
if action == "enable":
|
||||
if action == 'enable':
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
||||
elif action == "disable":
|
||||
elif action == 'disable':
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
else:
|
||||
raise ValueError("Unsupported annotation reply action")
|
||||
raise ValueError('Unsupported annotation reply action')
|
||||
return result, 200
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
parser.add_argument('score_threshold', required=True, type=float, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
||||
@@ -77,24 +77,28 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def get(self, app_id, job_id, action):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
|
||||
app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job is not exist.")
|
||||
|
||||
job_status = cache_result.decode()
|
||||
error_msg = ""
|
||||
if job_status == "error":
|
||||
app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
|
||||
error_msg = ''
|
||||
if job_status == 'error':
|
||||
app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
|
||||
error_msg = redis_client.get(app_annotation_error_key).decode()
|
||||
|
||||
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': job_status,
|
||||
'error_msg': error_msg
|
||||
}, 200
|
||||
|
||||
|
||||
class AnnotationListApi(Resource):
|
||||
@@ -105,18 +109,18 @@ class AnnotationListApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
keyword = request.args.get("keyword", default=None, type=str)
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
keyword = request.args.get('keyword', default=None, type=str)
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||
response = {
|
||||
"data": marshal(annotation_list, annotation_fields),
|
||||
"has_more": len(annotation_list) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
'data': marshal(annotation_list, annotation_fields),
|
||||
'has_more': len(annotation_list) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
'page': page
|
||||
}
|
||||
return response, 200
|
||||
|
||||
@@ -131,7 +135,9 @@ class AnnotationExportApi(Resource):
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
response = {"data": marshal(annotation_list, annotation_fields)}
|
||||
response = {
|
||||
'data': marshal(annotation_list, annotation_fields)
|
||||
}
|
||||
return response, 200
|
||||
|
||||
|
||||
@@ -139,7 +145,7 @@ class AnnotationCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
@@ -147,8 +153,8 @@ class AnnotationCreateApi(Resource):
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
parser.add_argument('question', required=True, type=str, location='json')
|
||||
parser.add_argument('answer', required=True, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
||||
return annotation
|
||||
@@ -158,7 +164,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
@@ -167,8 +173,8 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
parser.add_argument('question', required=True, type=str, location='json')
|
||||
parser.add_argument('answer', required=True, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||
return annotation
|
||||
@@ -183,29 +189,29 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class AnnotationBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
file = request.files['file']
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
if 'file' not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
# check file type
|
||||
if not file.filename.endswith(".csv"):
|
||||
if not file.filename.endswith('.csv'):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
||||
|
||||
@@ -214,23 +220,27 @@ class AnnotationBatchImportStatusApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def get(self, app_id, job_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
|
||||
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job is not exist.")
|
||||
job_status = cache_result.decode()
|
||||
error_msg = ""
|
||||
if job_status == "error":
|
||||
indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id))
|
||||
error_msg = ''
|
||||
if job_status == 'error':
|
||||
indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
|
||||
error_msg = redis_client.get(indexing_error_msg_key).decode()
|
||||
|
||||
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': job_status,
|
||||
'error_msg': error_msg
|
||||
}, 200
|
||||
|
||||
|
||||
class AnnotationHitHistoryListApi(Resource):
|
||||
@@ -241,32 +251,30 @@ class AnnotationHitHistoryListApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
|
||||
app_id, annotation_id, page, limit
|
||||
)
|
||||
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
|
||||
page, limit)
|
||||
response = {
|
||||
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
|
||||
"has_more": len(annotation_hit_history_list) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
|
||||
'has_more': len(annotation_hit_history_list) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
'page': page
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||
api.add_resource(
|
||||
AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>"
|
||||
)
|
||||
api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations")
|
||||
api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export")
|
||||
api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import")
|
||||
api.add_resource(AnnotationBatchImportStatusApi, "/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
|
||||
api.add_resource(AnnotationHitHistoryListApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
|
||||
api.add_resource(AppAnnotationSettingDetailApi, "/apps/<uuid:app_id>/annotation-setting")
|
||||
api.add_resource(AppAnnotationSettingUpdateApi, "/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
|
||||
api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
|
||||
api.add_resource(AnnotationReplyActionStatusApi,
|
||||
'/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
|
||||
api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
|
||||
api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
|
||||
api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
|
||||
api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
|
||||
api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
|
||||
api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
|
||||
api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
|
||||
api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
|
||||
|
||||
@@ -18,35 +18,27 @@ from libs.login import login_required
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
|
||||
|
||||
|
||||
class AppListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""Get app list"""
|
||||
|
||||
def uuid_list(value):
|
||||
try:
|
||||
return [str(uuid.UUID(v)) for v in value.split(",")]
|
||||
return [str(uuid.UUID(v)) for v in value.split(',')]
|
||||
except ValueError:
|
||||
abort(400, message="Invalid UUID format in tag_ids.")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
choices=["chat", "workflow", "agent-chat", "channel", "all"],
|
||||
default="all",
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument("name", type=str, location="args", required=False)
|
||||
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
|
||||
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
|
||||
parser.add_argument('name', type=str, location='args', required=False)
|
||||
parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -54,7 +46,7 @@ class AppListApi(Resource):
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
|
||||
if not app_pagination:
|
||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||
return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}
|
||||
|
||||
return marshal(app_pagination, app_pagination_fields)
|
||||
|
||||
@@ -62,23 +54,23 @@ class AppListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_detail_fields)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@cloud_edition_billing_resource_check('apps')
|
||||
def post(self):
|
||||
"""Create app"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
|
||||
parser.add_argument('icon_type', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if "mode" not in args or args["mode"] is None:
|
||||
if 'mode' not in args or args['mode'] is None:
|
||||
raise BadRequest("mode is required")
|
||||
|
||||
app_service = AppService()
|
||||
@@ -92,7 +84,7 @@ class AppImportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@cloud_edition_billing_resource_check('apps')
|
||||
def post(self):
|
||||
"""Import app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
@@ -100,16 +92,19 @@ class AppImportApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('name', type=str, location='json')
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('icon_type', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app = AppDslService.import_and_create_new_app(
|
||||
tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data=args['data'],
|
||||
args=args,
|
||||
account=current_user
|
||||
)
|
||||
|
||||
return app, 201
|
||||
@@ -120,7 +115,7 @@ class AppImportFromUrlApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@cloud_edition_billing_resource_check('apps')
|
||||
def post(self):
|
||||
"""Import app from url"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
@@ -128,21 +123,25 @@ class AppImportFromUrlApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("url", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument('url', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('name', type=str, location='json')
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app = AppDslService.import_and_create_new_app_from_url(
|
||||
tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
url=args['url'],
|
||||
args=args,
|
||||
account=current_user
|
||||
)
|
||||
|
||||
return app, 201
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -166,15 +165,14 @@ class AppApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("max_active_requests", type=int, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('icon_type', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
parser.add_argument('max_active_requests', type=int, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
@@ -195,7 +193,7 @@ class AppApi(Resource):
|
||||
app_service = AppService()
|
||||
app_service.delete_app(app_model)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class AppCopyApi(Resource):
|
||||
@@ -211,16 +209,19 @@ class AppCopyApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument('name', type=str, location='json')
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('icon_type', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
|
||||
app = AppDslService.import_and_create_new_app(
|
||||
tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data=data,
|
||||
args=args,
|
||||
account=current_user
|
||||
)
|
||||
|
||||
return app, 201
|
||||
@@ -239,10 +240,12 @@ class AppExportApi(Resource):
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||
parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
|
||||
return {
|
||||
"data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret'])
|
||||
}
|
||||
|
||||
|
||||
class AppNameApi(Resource):
|
||||
@@ -255,13 +258,13 @@ class AppNameApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_name(app_model, args.get("name"))
|
||||
app_model = app_service.update_app_name(app_model, args.get('name'))
|
||||
|
||||
return app_model
|
||||
|
||||
@@ -276,14 +279,14 @@ class AppIconApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background"))
|
||||
app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background'))
|
||||
|
||||
return app_model
|
||||
|
||||
@@ -298,13 +301,13 @@ class AppSiteStatus(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("enable_site", type=bool, required=True, location="json")
|
||||
parser.add_argument('enable_site', type=bool, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_site_status(app_model, args.get("enable_site"))
|
||||
app_model = app_service.update_app_site_status(app_model, args.get('enable_site'))
|
||||
|
||||
return app_model
|
||||
|
||||
@@ -319,13 +322,13 @@ class AppApiStatus(Resource):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("enable_api", type=bool, required=True, location="json")
|
||||
parser.add_argument('enable_api', type=bool, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_api_status(app_model, args.get("enable_api"))
|
||||
app_model = app_service.update_app_api_status(app_model, args.get('enable_api'))
|
||||
|
||||
return app_model
|
||||
|
||||
@@ -336,7 +339,9 @@ class AppTraceApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
"""Get app trace"""
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id)
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(
|
||||
app_id=app_id
|
||||
)
|
||||
|
||||
return app_trace_config
|
||||
|
||||
@@ -348,27 +353,27 @@ class AppTraceApi(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("enabled", type=bool, required=True, location="json")
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
parser.add_argument('enabled', type=bool, required=True, location='json')
|
||||
parser.add_argument('tracing_provider', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
OpsTraceManager.update_app_tracing_config(
|
||||
app_id=app_id,
|
||||
enabled=args["enabled"],
|
||||
tracing_provider=args["tracing_provider"],
|
||||
enabled=args['enabled'],
|
||||
tracing_provider=args['tracing_provider'],
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(AppListApi, "/apps")
|
||||
api.add_resource(AppImportApi, "/apps/import")
|
||||
api.add_resource(AppImportFromUrlApi, "/apps/import/url")
|
||||
api.add_resource(AppApi, "/apps/<uuid:app_id>")
|
||||
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
|
||||
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
|
||||
api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name")
|
||||
api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon")
|
||||
api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable")
|
||||
api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable")
|
||||
api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace")
|
||||
api.add_resource(AppListApi, '/apps')
|
||||
api.add_resource(AppImportApi, '/apps/import')
|
||||
api.add_resource(AppImportFromUrlApi, '/apps/import/url')
|
||||
api.add_resource(AppApi, '/apps/<uuid:app_id>')
|
||||
api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
|
||||
api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
|
||||
api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
|
||||
api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon')
|
||||
api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
|
||||
api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
|
||||
api.add_resource(AppTraceApi, '/apps/<uuid:app_id>/trace')
|
||||
|
||||
@@ -39,7 +39,7 @@ class ChatMessageAudioApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model):
|
||||
file = request.files["file"]
|
||||
file = request.files['file']
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_asr(
|
||||
@@ -85,31 +85,31 @@ class ChatMessageTextApi(Resource):
|
||||
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
parser.add_argument('message_id', type=str, location='json')
|
||||
parser.add_argument('text', type=str, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||
message_id = args.get('message_id', None)
|
||||
text = args.get('text', None)
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
||||
else:
|
||||
try:
|
||||
voice = (
|
||||
args.get("voice")
|
||||
if args.get("voice")
|
||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
)
|
||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
|
||||
'voice')
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
text=text,
|
||||
message_id=message_id,
|
||||
voice=voice
|
||||
)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
@@ -145,12 +145,12 @@ class TextModesApi(Resource):
|
||||
def get(self, app_model):
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("language", type=str, required=True, location="args")
|
||||
parser.add_argument('language', type=str, required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
response = AudioService.transcript_tts_voices(
|
||||
tenant_id=app_model.tenant_id,
|
||||
language=args["language"],
|
||||
language=args['language'],
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -179,6 +179,6 @@ class TextModesApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(ChatMessageAudioApi, "/apps/<uuid:app_id>/audio-to-text")
|
||||
api.add_resource(ChatMessageTextApi, "/apps/<uuid:app_id>/text-to-audio")
|
||||
api.add_resource(TextModesApi, "/apps/<uuid:app_id>/text-to-audio/voices")
|
||||
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
|
||||
api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
|
||||
api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
|
||||
|
||||
@@ -17,7 +17,6 @@ from controllers.console.app.error import (
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@@ -32,33 +31,37 @@ from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
# define completion message api for user
|
||||
class CompletionMessageApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, location='json', default='')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('model_config', type=dict, required=True, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
streaming = args['response_mode'] != 'blocking'
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
account = flask_login.current_user
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||
app_model=app_model,
|
||||
user=account,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
@@ -94,7 +97,7 @@ class CompletionMessageStopApi(Resource):
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class ChatMessageApi(Resource):
|
||||
@@ -104,23 +107,27 @@ class ChatMessageApi(Resource):
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('model_config', type=dict, required=True, location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
streaming = args['response_mode'] != 'blocking'
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
account = flask_login.current_user
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||
app_model=app_model,
|
||||
user=account,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
@@ -137,8 +144,6 @@ class ChatMessageApi(Resource):
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
@@ -158,10 +163,10 @@ class ChatMessageStopApi(Resource):
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages")
|
||||
api.add_resource(CompletionMessageStopApi, "/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
|
||||
api.add_resource(ChatMessageApi, "/apps/<uuid:app_id>/chat-messages")
|
||||
api.add_resource(ChatMessageStopApi, "/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
|
||||
api.add_resource(CompletionMessageApi, '/apps/<uuid:app_id>/completion-messages')
|
||||
api.add_resource(CompletionMessageStopApi, '/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop')
|
||||
api.add_resource(ChatMessageApi, '/apps/<uuid:app_id>/chat-messages')
|
||||
api.add_resource(ChatMessageStopApi, '/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop')
|
||||
|
||||
@@ -26,6 +26,7 @@ from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotat
|
||||
|
||||
|
||||
class CompletionConversationApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -35,23 +36,24 @@ class CompletionConversationApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
parser.add_argument('keyword', type=str, location='args')
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('annotation_status', type=str,
|
||||
choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
|
||||
parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
|
||||
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion')
|
||||
|
||||
if args["keyword"]:
|
||||
query = query.join(Message, Message.conversation_id == Conversation.id).filter(
|
||||
if args['keyword']:
|
||||
query = query.join(
|
||||
Message, Message.conversation_id == Conversation.id
|
||||
).filter(
|
||||
or_(
|
||||
Message.query.ilike("%{}%".format(args["keyword"])),
|
||||
Message.answer.ilike("%{}%".format(args["keyword"])),
|
||||
Message.query.ilike('%{}%'.format(args['keyword'])),
|
||||
Message.answer.ilike('%{}%'.format(args['keyword']))
|
||||
)
|
||||
)
|
||||
|
||||
@@ -59,8 +61,8 @@ class CompletionConversationApi(Resource):
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
@@ -68,8 +70,8 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
@@ -77,25 +79,29 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||
|
||||
if args["annotation_status"] == "annotated":
|
||||
if args['annotation_status'] == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join(
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args["annotation_status"] == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
elif args['annotation_status'] == "not_annotated":
|
||||
query = query.outerjoin(
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
|
||||
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
|
||||
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
conversations = db.paginate(
|
||||
query,
|
||||
page=args['page'],
|
||||
per_page=args['limit'],
|
||||
error_out=False
|
||||
)
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
class CompletionConversationDetailApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -117,11 +123,8 @@ class CompletionConversationDetailApi(Resource):
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
conversation = db.session.query(Conversation) \
|
||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@@ -129,10 +132,11 @@ class CompletionConversationDetailApi(Resource):
|
||||
conversation.is_deleted = True
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class ChatConversationApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -142,28 +146,22 @@ class ChatConversationApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
||||
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
)
|
||||
parser.add_argument('keyword', type=str, location='args')
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('annotation_status', type=str,
|
||||
choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
|
||||
parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
|
||||
parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
|
||||
required=False, default='-updated_at', location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
subquery = (
|
||||
db.session.query(
|
||||
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
|
||||
Conversation.id.label('conversation_id'),
|
||||
EndUser.session_id.label('from_end_user_session_id')
|
||||
)
|
||||
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
|
||||
.subquery()
|
||||
@@ -171,31 +169,28 @@ class ChatConversationApi(Resource):
|
||||
|
||||
query = db.select(Conversation).where(Conversation.app_id == app_model.id)
|
||||
|
||||
if args["keyword"]:
|
||||
keyword_filter = "%{}%".format(args["keyword"])
|
||||
query = (
|
||||
query.join(
|
||||
Message,
|
||||
Message.conversation_id == Conversation.id,
|
||||
)
|
||||
.join(subquery, subquery.c.conversation_id == Conversation.id)
|
||||
.filter(
|
||||
or_(
|
||||
Message.query.ilike(keyword_filter),
|
||||
Message.answer.ilike(keyword_filter),
|
||||
Conversation.name.ilike(keyword_filter),
|
||||
Conversation.introduction.ilike(keyword_filter),
|
||||
subquery.c.from_end_user_session_id.ilike(keyword_filter),
|
||||
),
|
||||
)
|
||||
if args['keyword']:
|
||||
keyword_filter = '%{}%'.format(args['keyword'])
|
||||
query = query.join(
|
||||
Message, Message.conversation_id == Conversation.id,
|
||||
).join(
|
||||
subquery, subquery.c.conversation_id == Conversation.id
|
||||
).filter(
|
||||
or_(
|
||||
Message.query.ilike(keyword_filter),
|
||||
Message.answer.ilike(keyword_filter),
|
||||
Conversation.name.ilike(keyword_filter),
|
||||
Conversation.introduction.ilike(keyword_filter),
|
||||
subquery.c.from_end_user_session_id.ilike(keyword_filter)
|
||||
),
|
||||
)
|
||||
|
||||
account = current_user
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
@@ -203,8 +198,8 @@ class ChatConversationApi(Resource):
|
||||
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
@@ -212,46 +207,50 @@ class ChatConversationApi(Resource):
|
||||
|
||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||
|
||||
if args["annotation_status"] == "annotated":
|
||||
if args['annotation_status'] == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join(
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args["annotation_status"] == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
elif args['annotation_status'] == "not_annotated":
|
||||
query = query.outerjoin(
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
|
||||
|
||||
if args["message_count_gte"] and args["message_count_gte"] >= 1:
|
||||
if args['message_count_gte'] and args['message_count_gte'] >= 1:
|
||||
query = (
|
||||
query.options(joinedload(Conversation.messages))
|
||||
.join(Message, Message.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(Message.id) >= args["message_count_gte"])
|
||||
.having(func.count(Message.id) >= args['message_count_gte'])
|
||||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
|
||||
match args["sort_by"]:
|
||||
case "created_at":
|
||||
match args['sort_by']:
|
||||
case 'created_at':
|
||||
query = query.order_by(Conversation.created_at.asc())
|
||||
case "-created_at":
|
||||
case '-created_at':
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
case "updated_at":
|
||||
case 'updated_at':
|
||||
query = query.order_by(Conversation.updated_at.asc())
|
||||
case "-updated_at":
|
||||
case '-updated_at':
|
||||
query = query.order_by(Conversation.updated_at.desc())
|
||||
case _:
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
|
||||
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
conversations = db.paginate(
|
||||
query,
|
||||
page=args['page'],
|
||||
per_page=args['limit'],
|
||||
error_out=False
|
||||
)
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
class ChatConversationDetailApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -273,11 +272,8 @@ class ChatConversationDetailApi(Resource):
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
conversation = db.session.query(Conversation) \
|
||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@@ -285,21 +281,18 @@ class ChatConversationDetailApi(Resource):
|
||||
conversation.is_deleted = True
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
api.add_resource(CompletionConversationApi, "/apps/<uuid:app_id>/completion-conversations")
|
||||
api.add_resource(CompletionConversationDetailApi, "/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
|
||||
api.add_resource(ChatConversationApi, "/apps/<uuid:app_id>/chat-conversations")
|
||||
api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
|
||||
api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
|
||||
api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
|
||||
api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
|
||||
api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>')
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
conversation = db.session.query(Conversation) \
|
||||
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@@ -21,7 +21,7 @@ class ConversationVariablesApi(Resource):
|
||||
@marshal_with(paginated_conversation_variable_fields)
|
||||
def get(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", type=str, location="args")
|
||||
parser.add_argument('conversation_id', type=str, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
stmt = (
|
||||
@@ -29,10 +29,10 @@ class ConversationVariablesApi(Resource):
|
||||
.where(ConversationVariable.app_id == app_model.id)
|
||||
.order_by(ConversationVariable.created_at)
|
||||
)
|
||||
if args["conversation_id"]:
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
|
||||
if args['conversation_id']:
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id'])
|
||||
else:
|
||||
raise ValueError("conversation_id is required")
|
||||
raise ValueError('conversation_id is required')
|
||||
|
||||
# NOTE: This is a temporary solution to avoid performance issues.
|
||||
page = 1
|
||||
@@ -43,14 +43,14 @@ class ConversationVariablesApi(Resource):
|
||||
rows = session.scalars(stmt).all()
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"limit": page_size,
|
||||
"total": len(rows),
|
||||
"has_more": False,
|
||||
"data": [
|
||||
'page': page,
|
||||
'limit': page_size,
|
||||
'total': len(rows),
|
||||
'has_more': False,
|
||||
'data': [
|
||||
{
|
||||
"created_at": row.created_at,
|
||||
"updated_at": row.updated_at,
|
||||
'created_at': row.created_at,
|
||||
'updated_at': row.updated_at,
|
||||
**row.to_variable().model_dump(),
|
||||
}
|
||||
for row in rows
|
||||
@@ -58,4 +58,4 @@ class ConversationVariablesApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(ConversationVariablesApi, "/apps/<uuid:app_id>/conversation-variables")
|
||||
api.add_resource(ConversationVariablesApi, '/apps/<uuid:app_id>/conversation-variables')
|
||||
|
||||
@@ -2,128 +2,116 @@ from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class AppNotFoundError(BaseHTTPException):
|
||||
error_code = "app_not_found"
|
||||
error_code = 'app_not_found'
|
||||
description = "App not found."
|
||||
code = 404
|
||||
|
||||
|
||||
class ProviderNotInitializeError(BaseHTTPException):
|
||||
error_code = "provider_not_initialize"
|
||||
description = (
|
||||
"No valid model provider credentials found. "
|
||||
"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
error_code = 'provider_not_initialize'
|
||||
description = "No valid model provider credentials found. " \
|
||||
"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
code = 400
|
||||
|
||||
|
||||
class ProviderQuotaExceededError(BaseHTTPException):
|
||||
error_code = "provider_quota_exceeded"
|
||||
description = (
|
||||
"Your quota for Dify Hosted Model Provider has been exhausted. "
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
||||
)
|
||||
error_code = 'provider_quota_exceeded'
|
||||
description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
||||
code = 400
|
||||
|
||||
|
||||
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
|
||||
error_code = "model_currently_not_support"
|
||||
error_code = 'model_currently_not_support'
|
||||
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
|
||||
code = 400
|
||||
|
||||
|
||||
class ConversationCompletedError(BaseHTTPException):
|
||||
error_code = "conversation_completed"
|
||||
error_code = 'conversation_completed'
|
||||
description = "The conversation has ended. Please start a new conversation."
|
||||
code = 400
|
||||
|
||||
|
||||
class AppUnavailableError(BaseHTTPException):
|
||||
error_code = "app_unavailable"
|
||||
error_code = 'app_unavailable'
|
||||
description = "App unavailable, please check your app configurations."
|
||||
code = 400
|
||||
|
||||
|
||||
class CompletionRequestError(BaseHTTPException):
|
||||
error_code = "completion_request_error"
|
||||
error_code = 'completion_request_error'
|
||||
description = "Completion request failed."
|
||||
code = 400
|
||||
|
||||
|
||||
class AppMoreLikeThisDisabledError(BaseHTTPException):
|
||||
error_code = "app_more_like_this_disabled"
|
||||
error_code = 'app_more_like_this_disabled'
|
||||
description = "The 'More like this' feature is disabled. Please refresh your page."
|
||||
code = 403
|
||||
|
||||
|
||||
class NoAudioUploadedError(BaseHTTPException):
|
||||
error_code = "no_audio_uploaded"
|
||||
error_code = 'no_audio_uploaded'
|
||||
description = "Please upload your audio."
|
||||
code = 400
|
||||
|
||||
|
||||
class AudioTooLargeError(BaseHTTPException):
|
||||
error_code = "audio_too_large"
|
||||
error_code = 'audio_too_large'
|
||||
description = "Audio size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedAudioTypeError(BaseHTTPException):
|
||||
error_code = "unsupported_audio_type"
|
||||
error_code = 'unsupported_audio_type'
|
||||
description = "Audio type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
error_code = "provider_not_support_speech_to_text"
|
||||
error_code = 'provider_not_support_speech_to_text'
|
||||
description = "Provider not support speech to text."
|
||||
code = 400
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = "no_file_uploaded"
|
||||
error_code = 'no_file_uploaded'
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = "too_many_files"
|
||||
error_code = 'too_many_files'
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
|
||||
class DraftWorkflowNotExist(BaseHTTPException):
|
||||
error_code = "draft_workflow_not_exist"
|
||||
error_code = 'draft_workflow_not_exist'
|
||||
description = "Draft workflow need to be initialized."
|
||||
code = 400
|
||||
|
||||
|
||||
class DraftWorkflowNotSync(BaseHTTPException):
|
||||
error_code = "draft_workflow_not_sync"
|
||||
error_code = 'draft_workflow_not_sync'
|
||||
description = "Workflow graph might have been modified, please refresh and resubmit."
|
||||
code = 400
|
||||
|
||||
|
||||
class TracingConfigNotExist(BaseHTTPException):
|
||||
error_code = "trace_config_not_exist"
|
||||
error_code = 'trace_config_not_exist'
|
||||
description = "Trace config not exist."
|
||||
code = 400
|
||||
|
||||
|
||||
class TracingConfigIsExist(BaseHTTPException):
|
||||
error_code = "trace_config_is_exist"
|
||||
error_code = 'trace_config_is_exist'
|
||||
description = "Trace config is exist."
|
||||
code = 400
|
||||
|
||||
|
||||
class TracingConfigCheckError(BaseHTTPException):
|
||||
error_code = "trace_config_check_error"
|
||||
error_code = 'trace_config_check_error'
|
||||
description = "Invalid Credentials."
|
||||
code = 400
|
||||
|
||||
|
||||
class InvokeRateLimitError(BaseHTTPException):
|
||||
"""Raised when the Invoke returns rate limit error."""
|
||||
|
||||
error_code = "rate_limit_error"
|
||||
description = "Rate Limit Error"
|
||||
code = 429
|
||||
|
||||
@@ -24,21 +24,21 @@ class RuleGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
parser.add_argument('instruction', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('no_variable', type=bool, required=True, default=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = current_user
|
||||
PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
|
||||
PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512'))
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=account.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=args["no_variable"],
|
||||
rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
|
||||
instruction=args['instruction'],
|
||||
model_config=args['model_config'],
|
||||
no_variable=args['no_variable'],
|
||||
rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@@ -52,4 +52,4 @@ class RuleGenerateApi(Resource):
|
||||
return rules
|
||||
|
||||
|
||||
api.add_resource(RuleGenerateApi, "/rule-generate")
|
||||
api.add_resource(RuleGenerateApi, '/rule-generate')
|
||||
|
||||
@@ -33,9 +33,9 @@ from services.message_service import MessageService
|
||||
|
||||
class ChatMessageListApi(Resource):
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_detail_fields)),
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(message_detail_fields))
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@@ -45,69 +45,55 @@ class ChatMessageListApi(Resource):
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
conversation = db.session.query(Conversation).filter(
|
||||
Conversation.id == args['conversation_id'],
|
||||
Conversation.app_id == app_model.id
|
||||
).first()
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
if args["first_id"]:
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.conversation_id == conversation.id, Message.id == args["first_id"])
|
||||
.first()
|
||||
)
|
||||
if args['first_id']:
|
||||
first_message = db.session.query(Message) \
|
||||
.filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first()
|
||||
|
||||
if not first_message:
|
||||
raise NotFound("First message not found")
|
||||
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
Message.id != first_message.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args["limit"])
|
||||
.all()
|
||||
)
|
||||
history_messages = db.session.query(Message).filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
Message.id != first_message.id
|
||||
) \
|
||||
.order_by(Message.created_at.desc()).limit(args['limit']).all()
|
||||
else:
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args["limit"])
|
||||
.all()
|
||||
)
|
||||
history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
|
||||
.order_by(Message.created_at.desc()).limit(args['limit']).all()
|
||||
|
||||
has_more = False
|
||||
if len(history_messages) == args["limit"]:
|
||||
if len(history_messages) == args['limit']:
|
||||
current_page_first_message = history_messages[-1]
|
||||
rest_count = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
rest_count = db.session.query(Message).filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id
|
||||
).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||
return InfiniteScrollPagination(
|
||||
data=history_messages,
|
||||
limit=args['limit'],
|
||||
has_more=has_more
|
||||
)
|
||||
|
||||
|
||||
class MessageFeedbackApi(Resource):
|
||||
@@ -117,46 +103,49 @@ class MessageFeedbackApi(Resource):
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
parser.add_argument('message_id', required=True, type=uuid_value, location='json')
|
||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = str(args["message_id"])
|
||||
message_id = str(args['message_id'])
|
||||
|
||||
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id
|
||||
).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
feedback = message.admin_feedback
|
||||
|
||||
if not args["rating"] and feedback:
|
||||
if not args['rating'] and feedback:
|
||||
db.session.delete(feedback)
|
||||
elif args["rating"] and feedback:
|
||||
feedback.rating = args["rating"]
|
||||
elif not args["rating"] and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
elif args['rating'] and feedback:
|
||||
feedback.rating = args['rating']
|
||||
elif not args['rating'] and not feedback:
|
||||
raise ValueError('rating cannot be None when feedback not exists')
|
||||
else:
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=args["rating"],
|
||||
from_source="admin",
|
||||
from_account_id=current_user.id,
|
||||
rating=args['rating'],
|
||||
from_source='admin',
|
||||
from_account_id=current_user.id
|
||||
)
|
||||
db.session.add(feedback)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class MessageAnnotationApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
@get_app_model
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_model):
|
||||
@@ -164,10 +153,10 @@ class MessageAnnotationApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
parser.add_argument('message_id', required=False, type=uuid_value, location='json')
|
||||
parser.add_argument('question', required=True, type=str, location='json')
|
||||
parser.add_argument('answer', required=True, type=str, location='json')
|
||||
parser.add_argument('annotation_reply', required=False, type=dict, location='json')
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
|
||||
|
||||
@@ -180,9 +169,11 @@ class MessageAnnotationCountApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count()
|
||||
count = db.session.query(MessageAnnotation).filter(
|
||||
MessageAnnotation.app_id == app_model.id
|
||||
).count()
|
||||
|
||||
return {"count": count}
|
||||
return {'count': count}
|
||||
|
||||
|
||||
class MessageSuggestedQuestionApi(Resource):
|
||||
@@ -195,7 +186,10 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=current_user,
|
||||
invoke_from=InvokeFrom.DEBUGGER
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
@@ -215,7 +209,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
return {'data': questions}
|
||||
|
||||
|
||||
class MessageApi(Resource):
|
||||
@@ -227,7 +221,10 @@ class MessageApi(Resource):
|
||||
def get(self, app_model, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id
|
||||
).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
@@ -235,9 +232,9 @@ class MessageApi(Resource):
|
||||
return message
|
||||
|
||||
|
||||
api.add_resource(MessageSuggestedQuestionApi, "/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
|
||||
api.add_resource(ChatMessageListApi, "/apps/<uuid:app_id>/chat-messages", endpoint="console_chat_messages")
|
||||
api.add_resource(MessageFeedbackApi, "/apps/<uuid:app_id>/feedbacks")
|
||||
api.add_resource(MessageAnnotationApi, "/apps/<uuid:app_id>/annotations")
|
||||
api.add_resource(MessageAnnotationCountApi, "/apps/<uuid:app_id>/annotations/count")
|
||||
api.add_resource(MessageApi, "/apps/<uuid:app_id>/messages/<uuid:message_id>", endpoint="console_message")
|
||||
api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions')
|
||||
api.add_resource(ChatMessageListApi, '/apps/<uuid:app_id>/chat-messages', endpoint='console_chat_messages')
|
||||
api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks')
|
||||
api.add_resource(MessageAnnotationApi, '/apps/<uuid:app_id>/annotations')
|
||||
api.add_resource(MessageAnnotationCountApi, '/apps/<uuid:app_id>/annotations/count')
|
||||
api.add_resource(MessageApi, '/apps/<uuid:app_id>/messages/<uuid:message_id>', endpoint='console_message')
|
||||
|
||||
@@ -19,35 +19,37 @@ from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
||||
class ModelConfigResource(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model):
|
||||
|
||||
"""Modify app model config"""
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode)
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
config=request.json,
|
||||
app_mode=AppMode.value_of(app_model.mode)
|
||||
)
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
updated_by=current_user.id,
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
# get original app model config
|
||||
original_app_model_config: AppModelConfig = (
|
||||
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
)
|
||||
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
|
||||
AppModelConfig.id == app_model.app_model_config_id
|
||||
).first()
|
||||
agent_mode = original_app_model_config.agent_mode_dict
|
||||
# decrypt agent tool parameters if it's secret-input
|
||||
parameter_map = {}
|
||||
masked_parameter_map = {}
|
||||
tool_map = {}
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
for tool in agent_mode.get('tools') or []:
|
||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
|
||||
@@ -64,7 +66,7 @@ class ModelConfigResource(Resource):
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
identity_id=f"AGENT.{app_model.id}",
|
||||
identity_id=f'AGENT.{app_model.id}'
|
||||
)
|
||||
except Exception as e:
|
||||
continue
|
||||
@@ -77,18 +79,18 @@ class ModelConfigResource(Resource):
|
||||
parameters = {}
|
||||
masked_parameter = {}
|
||||
|
||||
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
|
||||
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
|
||||
masked_parameter_map[key] = masked_parameter
|
||||
parameter_map[key] = parameters
|
||||
tool_map[key] = tool_runtime
|
||||
|
||||
# encrypt agent tool parameters if it's secret-input
|
||||
agent_mode = new_app_model_config.agent_mode_dict
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
for tool in agent_mode.get('tools') or []:
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
|
||||
# get tool
|
||||
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
|
||||
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
|
||||
if key in tool_map:
|
||||
tool_runtime = tool_map[key]
|
||||
else:
|
||||
@@ -106,7 +108,7 @@ class ModelConfigResource(Resource):
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
identity_id=f"AGENT.{app_model.id}",
|
||||
identity_id=f'AGENT.{app_model.id}'
|
||||
)
|
||||
manager.delete_tool_parameters_cache()
|
||||
|
||||
@@ -114,17 +116,15 @@ class ModelConfigResource(Resource):
|
||||
if agent_tool_entity.tool_parameters:
|
||||
if key not in masked_parameter_map:
|
||||
continue
|
||||
|
||||
|
||||
for masked_key, masked_value in masked_parameter_map[key].items():
|
||||
if (
|
||||
masked_key in agent_tool_entity.tool_parameters
|
||||
and agent_tool_entity.tool_parameters[masked_key] == masked_value
|
||||
):
|
||||
if masked_key in agent_tool_entity.tool_parameters and \
|
||||
agent_tool_entity.tool_parameters[masked_key] == masked_value:
|
||||
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
|
||||
|
||||
# encrypt parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
||||
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
||||
|
||||
# update app model config
|
||||
new_app_model_config.agent_mode = json.dumps(agent_mode)
|
||||
@@ -135,9 +135,12 @@ class ModelConfigResource(Resource):
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
db.session.commit()
|
||||
|
||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
app_model_config_was_updated.send(
|
||||
app_model,
|
||||
app_model_config=new_app_model_config
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(ModelConfigResource, "/apps/<uuid:app_id>/model-config")
|
||||
api.add_resource(ModelConfigResource, '/apps/<uuid:app_id>/model-config')
|
||||
|
||||
@@ -18,11 +18,13 @@ class TraceAppConfigApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
parser.add_argument('tracing_provider', type=str, required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
||||
trace_config = OpsService.get_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args['tracing_provider']
|
||||
)
|
||||
if not trace_config:
|
||||
return {"has_not_configured": True}
|
||||
return trace_config
|
||||
@@ -35,17 +37,19 @@ class TraceAppConfigApi(Resource):
|
||||
def post(self, app_id):
|
||||
"""Create a new trace app configuration"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
parser.add_argument('tracing_provider', type=str, required=True, location='json')
|
||||
parser.add_argument('tracing_config', type=dict, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
||||
app_id=app_id,
|
||||
tracing_provider=args['tracing_provider'],
|
||||
tracing_config=args['tracing_config']
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigIsExist()
|
||||
if result.get("error"):
|
||||
if result.get('error'):
|
||||
raise TracingConfigCheckError()
|
||||
return result
|
||||
except Exception as e:
|
||||
@@ -57,13 +61,15 @@ class TraceAppConfigApi(Resource):
|
||||
def patch(self, app_id):
|
||||
"""Update an existing trace app configuration"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
parser.add_argument('tracing_provider', type=str, required=True, location='json')
|
||||
parser.add_argument('tracing_config', type=dict, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
result = OpsService.update_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
||||
app_id=app_id,
|
||||
tracing_provider=args['tracing_provider'],
|
||||
tracing_config=args['tracing_config']
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
@@ -77,11 +83,14 @@ class TraceAppConfigApi(Resource):
|
||||
def delete(self, app_id):
|
||||
"""Delete an existing trace app configuration"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
parser.add_argument('tracing_provider', type=str, required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
||||
result = OpsService.delete_tracing_app_config(
|
||||
app_id=app_id,
|
||||
tracing_provider=args['tracing_provider']
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}
|
||||
@@ -89,4 +98,4 @@ class TraceAppConfigApi(Resource):
|
||||
raise e
|
||||
|
||||
|
||||
api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")
|
||||
api.add_resource(TraceAppConfigApi, '/apps/<uuid:app_id>/trace-config')
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
@@ -17,24 +15,23 @@ from models.model import Site
|
||||
|
||||
def parse_app_site_args():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("title", type=str, required=False, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=False, location="json")
|
||||
parser.add_argument("icon", type=str, required=False, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, location="json")
|
||||
parser.add_argument("description", type=str, required=False, location="json")
|
||||
parser.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||
parser.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||
parser.add_argument("customize_domain", type=str, required=False, location="json")
|
||||
parser.add_argument("copyright", type=str, required=False, location="json")
|
||||
parser.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||
parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||
parser.add_argument(
|
||||
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
|
||||
)
|
||||
parser.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
parser.add_argument('title', type=str, required=False, location='json')
|
||||
parser.add_argument('icon_type', type=str, required=False, location='json')
|
||||
parser.add_argument('icon', type=str, required=False, location='json')
|
||||
parser.add_argument('icon_background', type=str, required=False, location='json')
|
||||
parser.add_argument('description', type=str, required=False, location='json')
|
||||
parser.add_argument('default_language', type=supported_language, required=False, location='json')
|
||||
parser.add_argument('chat_color_theme', type=str, required=False, location='json')
|
||||
parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json')
|
||||
parser.add_argument('customize_domain', type=str, required=False, location='json')
|
||||
parser.add_argument('copyright', type=str, required=False, location='json')
|
||||
parser.add_argument('privacy_policy', type=str, required=False, location='json')
|
||||
parser.add_argument('custom_disclaimer', type=str, required=False, location='json')
|
||||
parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'],
|
||||
required=False,
|
||||
location='json')
|
||||
parser.add_argument('prompt_public', type=bool, required=False, location='json')
|
||||
parser.add_argument('show_workflow_steps', type=bool, required=False, location='json')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -51,38 +48,38 @@ class AppSite(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404()
|
||||
site = db.session.query(Site). \
|
||||
filter(Site.app_id == app_model.id). \
|
||||
one_or_404()
|
||||
|
||||
for attr_name in [
|
||||
"title",
|
||||
"icon_type",
|
||||
"icon",
|
||||
"icon_background",
|
||||
"description",
|
||||
"default_language",
|
||||
"chat_color_theme",
|
||||
"chat_color_theme_inverted",
|
||||
"customize_domain",
|
||||
"copyright",
|
||||
"privacy_policy",
|
||||
"custom_disclaimer",
|
||||
"customize_token_strategy",
|
||||
"prompt_public",
|
||||
"show_workflow_steps",
|
||||
"use_icon_as_answer_icon",
|
||||
'title',
|
||||
'icon_type',
|
||||
'icon',
|
||||
'icon_background',
|
||||
'description',
|
||||
'default_language',
|
||||
'chat_color_theme',
|
||||
'chat_color_theme_inverted',
|
||||
'customize_domain',
|
||||
'copyright',
|
||||
'privacy_policy',
|
||||
'custom_disclaimer',
|
||||
'customize_token_strategy',
|
||||
'prompt_public',
|
||||
'show_workflow_steps'
|
||||
]:
|
||||
value = args.get(attr_name)
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return site
|
||||
|
||||
|
||||
class AppSiteAccessTokenReset(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -99,12 +96,10 @@ class AppSiteAccessTokenReset(Resource):
|
||||
raise NotFound
|
||||
|
||||
site.code = Site.generate_code(16)
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return site
|
||||
|
||||
|
||||
api.add_resource(AppSite, "/apps/<uuid:app_id>/site")
|
||||
api.add_resource(AppSiteAccessTokenReset, "/apps/<uuid:app_id>/site/access-token-reset")
|
||||
api.add_resource(AppSite, '/apps/<uuid:app_id>/site')
|
||||
api.add_resource(AppSiteAccessTokenReset, '/apps/<uuid:app_id>/site/access-token-reset')
|
||||
|
||||
@@ -16,61 +16,8 @@ from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class DailyMessageStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count
|
||||
FROM messages where app_id = :app_id
|
||||
"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "message_count": i.message_count})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
class DailyConversationStatistic(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -79,52 +26,58 @@ class DailyConversationStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
sql_query = '''
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
|
||||
FROM messages where app_id = :app_id
|
||||
"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
'''
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
sql_query += ' and created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
sql_query += ' and created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'conversation_count': i.conversation_count
|
||||
})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
class DailyTerminalsStatistic(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -133,49 +86,54 @@ class DailyTerminalsStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
sql_query = '''
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
|
||||
FROM messages where app_id = :app_id
|
||||
"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
'''
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
sql_query += ' and created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
sql_query += ' and created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'terminal_count': i.terminal_count
|
||||
})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
class DailyTokenCostStatistic(Resource):
|
||||
@@ -187,53 +145,58 @@ class DailyTokenCostStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
sql_query = '''
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
(sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
|
||||
sum(total_price) as total_price
|
||||
FROM messages where app_id = :app_id
|
||||
"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
'''
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
sql_query += ' and created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
sql_query += ' and created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"}
|
||||
)
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'token_count': i.token_count,
|
||||
'total_price': i.total_price,
|
||||
'currency': 'USD'
|
||||
})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
class AverageSessionInteractionStatistic(Resource):
|
||||
@@ -245,8 +208,8 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
@@ -255,30 +218,30 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
|
||||
FROM conversations c
|
||||
JOIN messages m ON c.id = m.conversation_id
|
||||
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and c.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
sql_query += ' and c.created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and c.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
sql_query += ' and c.created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += """
|
||||
GROUP BY m.conversation_id) subquery
|
||||
@@ -287,15 +250,18 @@ GROUP BY date
|
||||
ORDER BY date"""
|
||||
|
||||
response_data = []
|
||||
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
||||
)
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'interactions': float(i.interactions.quantize(Decimal('0.01')))
|
||||
})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
class UserSatisfactionRateStatistic(Resource):
|
||||
@@ -307,57 +273,57 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
sql_query = '''
|
||||
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
|
||||
FROM messages m
|
||||
LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
|
||||
WHERE m.app_id = :app_id
|
||||
"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
'''
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and m.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
sql_query += ' and m.created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and m.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
sql_query += ' and m.created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{
|
||||
"date": str(i.date),
|
||||
"rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
|
||||
}
|
||||
)
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
|
||||
})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
class AverageResponseTimeStatistic(Resource):
|
||||
@@ -369,51 +335,56 @@ class AverageResponseTimeStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
sql_query = '''
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(provider_response_latency) as latency
|
||||
FROM messages
|
||||
WHERE app_id = :app_id
|
||||
"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
'''
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
sql_query += ' and created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
sql_query += ' and created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)})
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'latency': round(i.latency * 1000, 4)
|
||||
})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
class TokensPerSecondStatistic(Resource):
|
||||
@@ -425,59 +396,63 @@ class TokensPerSecondStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
CASE
|
||||
WHEN SUM(provider_response_latency) = 0 THEN 0
|
||||
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
|
||||
END as tokens_per_second
|
||||
FROM messages
|
||||
WHERE app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
WHERE app_id = :app_id'''
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
sql_query += ' and created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
sql_query += ' and created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'tps': round(i.tokens_per_second, 4)
|
||||
})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
api.add_resource(DailyMessageStatistic, "/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
api.add_resource(DailyConversationStatistic, "/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
api.add_resource(DailyTerminalsStatistic, "/apps/<uuid:app_id>/statistics/daily-end-users")
|
||||
api.add_resource(DailyTokenCostStatistic, "/apps/<uuid:app_id>/statistics/token-costs")
|
||||
api.add_resource(AverageSessionInteractionStatistic, "/apps/<uuid:app_id>/statistics/average-session-interactions")
|
||||
api.add_resource(UserSatisfactionRateStatistic, "/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
|
||||
api.add_resource(AverageResponseTimeStatistic, "/apps/<uuid:app_id>/statistics/average-response-time")
|
||||
api.add_resource(TokensPerSecondStatistic, "/apps/<uuid:app_id>/statistics/tokens-per-second")
|
||||
api.add_resource(DailyConversationStatistic, '/apps/<uuid:app_id>/statistics/daily-conversations')
|
||||
api.add_resource(DailyTerminalsStatistic, '/apps/<uuid:app_id>/statistics/daily-end-users')
|
||||
api.add_resource(DailyTokenCostStatistic, '/apps/<uuid:app_id>/statistics/token-costs')
|
||||
api.add_resource(AverageSessionInteractionStatistic, '/apps/<uuid:app_id>/statistics/average-session-interactions')
|
||||
api.add_resource(UserSatisfactionRateStatistic, '/apps/<uuid:app_id>/statistics/user-satisfaction-rate')
|
||||
api.add_resource(AverageResponseTimeStatistic, '/apps/<uuid:app_id>/statistics/average-response-time')
|
||||
api.add_resource(TokensPerSecondStatistic, '/apps/<uuid:app_id>/statistics/tokens-per-second')
|
||||
|
||||
@@ -64,51 +64,51 @@ class DraftWorkflowApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
content_type = request.headers.get('Content-Type', '')
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
if 'application/json' in content_type:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("hash", type=str, required=False, location="json")
|
||||
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('hash', type=str, required=False, location='json')
|
||||
# TODO: set this to required=True after frontend is updated
|
||||
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
parser.add_argument('environment_variables', type=list, required=False, location='json')
|
||||
parser.add_argument('conversation_variables', type=list, required=False, location='json')
|
||||
args = parser.parse_args()
|
||||
elif "text/plain" in content_type:
|
||||
elif 'text/plain' in content_type:
|
||||
try:
|
||||
data = json.loads(request.data.decode("utf-8"))
|
||||
if "graph" not in data or "features" not in data:
|
||||
raise ValueError("graph or features not found in data")
|
||||
data = json.loads(request.data.decode('utf-8'))
|
||||
if 'graph' not in data or 'features' not in data:
|
||||
raise ValueError('graph or features not found in data')
|
||||
|
||||
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
|
||||
raise ValueError("graph or features is not a dict")
|
||||
if not isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict):
|
||||
raise ValueError('graph or features is not a dict')
|
||||
|
||||
args = {
|
||||
"graph": data.get("graph"),
|
||||
"features": data.get("features"),
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
'graph': data.get('graph'),
|
||||
'features': data.get('features'),
|
||||
'hash': data.get('hash'),
|
||||
'environment_variables': data.get('environment_variables'),
|
||||
'conversation_variables': data.get('conversation_variables'),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
return {'message': 'Invalid JSON data'}, 400
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables_list = args.get('environment_variables') or []
|
||||
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
||||
conversation_variables_list = args.get("conversation_variables") or []
|
||||
conversation_variables_list = args.get('conversation_variables') or []
|
||||
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
|
||||
workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
graph=args["graph"],
|
||||
features=args["features"],
|
||||
unique_hash=args.get("hash"),
|
||||
graph=args['graph'],
|
||||
features=args['features'],
|
||||
unique_hash=args.get('hash'),
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
@@ -119,7 +119,7 @@ class DraftWorkflowApi(Resource):
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
|
||||
}
|
||||
|
||||
|
||||
@@ -138,11 +138,13 @@ class DraftWorkflowImportApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow = AppDslService.import_and_overwrite_workflow(
|
||||
app_model=app_model, data=args["data"], account=current_user
|
||||
app_model=app_model,
|
||||
data=args['data'],
|
||||
account=current_user
|
||||
)
|
||||
|
||||
return workflow
|
||||
@@ -160,17 +162,21 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json", default="")
|
||||
parser.add_argument("files", type=list, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument('inputs', type=dict, location='json')
|
||||
parser.add_argument('query', type=str, required=True, location='json', default='')
|
||||
parser.add_argument('files', type=list, location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
@@ -184,7 +190,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -197,14 +202,18 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
parser.add_argument('inputs', type=dict, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
node_id=node_id,
|
||||
args=args,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
@@ -218,7 +227,6 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -231,14 +239,18 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
parser.add_argument('inputs', type=dict, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
node_id=node_id,
|
||||
args=args,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
@@ -252,7 +264,6 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class DraftWorkflowRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -265,15 +276,19 @@ class DraftWorkflowRunApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
@@ -296,10 +311,12 @@ class WorkflowTaskStopApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
return {"result": "success"}
|
||||
return {
|
||||
"result": "success"
|
||||
}
|
||||
|
||||
|
||||
class DraftWorkflowNodeRunApi(Resource):
|
||||
@@ -315,20 +332,24 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_node_execution = workflow_service.run_draft_workflow_node(
|
||||
app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user
|
||||
app_model=app_model,
|
||||
node_id=node_id,
|
||||
user_inputs=args.get('inputs'),
|
||||
account=current_user
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
class PublishedWorkflowApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -341,7 +362,7 @@ class PublishedWorkflowApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
# fetch published workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_published_workflow(app_model=app_model)
|
||||
@@ -360,11 +381,14 @@ class PublishedWorkflowApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
|
||||
|
||||
return {"result": "success", "created_at": TimestampField().format(workflow.created_at)}
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": TimestampField().format(workflow.created_at)
|
||||
}
|
||||
|
||||
|
||||
class DefaultBlockConfigsApi(Resource):
|
||||
@@ -379,7 +403,7 @@ class DefaultBlockConfigsApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
# Get default block configs
|
||||
workflow_service = WorkflowService()
|
||||
return workflow_service.get_default_block_configs()
|
||||
@@ -397,21 +421,24 @@ class DefaultBlockConfigApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("q", type=str, location="args")
|
||||
parser.add_argument('q', type=str, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
filters = None
|
||||
if args.get("q"):
|
||||
if args.get('q'):
|
||||
try:
|
||||
filters = json.loads(args.get("q"))
|
||||
filters = json.loads(args.get('q'))
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid filters")
|
||||
raise ValueError('Invalid filters')
|
||||
|
||||
# Get default block configs
|
||||
workflow_service = WorkflowService()
|
||||
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
return workflow_service.get_default_block_config(
|
||||
node_type=block_type,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
|
||||
class ConvertToWorkflowApi(Resource):
|
||||
@@ -428,43 +455,41 @@ class ConvertToWorkflowApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
if request.data:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
args = {}
|
||||
|
||||
# convert to workflow mode
|
||||
workflow_service = WorkflowService()
|
||||
new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args)
|
||||
new_app_model = workflow_service.convert_to_workflow(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
args=args
|
||||
)
|
||||
|
||||
# return app id
|
||||
return {
|
||||
"new_app_id": new_app_model.id,
|
||||
'new_app_id': new_app_model.id,
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
|
||||
api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
|
||||
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
|
||||
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
|
||||
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
api.add_resource(
|
||||
AdvancedChatDraftRunIterationNodeApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run"
|
||||
)
|
||||
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
|
||||
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||
api.add_resource(
|
||||
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
|
||||
)
|
||||
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")
|
||||
api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft')
|
||||
api.add_resource(DraftWorkflowImportApi, '/apps/<uuid:app_id>/workflows/draft/import')
|
||||
api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run')
|
||||
api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
|
||||
api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop')
|
||||
api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
|
||||
api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run')
|
||||
api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run')
|
||||
api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish')
|
||||
api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
|
||||
api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs'
|
||||
'/<string:block_type>')
|
||||
api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow')
|
||||
|
||||
@@ -22,19 +22,20 @@ class WorkflowAppLogApi(Resource):
|
||||
Get workflow app logs
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
parser.add_argument('keyword', type=str, location='args')
|
||||
parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args')
|
||||
parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
app_model=app_model, args=args
|
||||
app_model=app_model,
|
||||
args=args
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
|
||||
|
||||
api.add_resource(WorkflowAppLogApi, "/apps/<uuid:app_id>/workflow-app-logs")
|
||||
api.add_resource(WorkflowAppLogApi, '/apps/<uuid:app_id>/workflow-app-logs')
|
||||
|
||||
@@ -28,12 +28,15 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
|
||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
|
||||
app_model=app_model,
|
||||
args=args
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -49,12 +52,15 @@ class WorkflowRunListApi(Resource):
|
||||
Get workflow run list
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
|
||||
result = workflow_run_service.get_paginate_workflow_runs(
|
||||
app_model=app_model,
|
||||
args=args
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -92,10 +98,12 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
workflow_run_service = WorkflowRunService()
|
||||
node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
|
||||
|
||||
return {"data": node_executions}
|
||||
return {
|
||||
'data': node_executions
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||
api.add_resource(WorkflowRunListApi, "/apps/<uuid:app_id>/workflow-runs")
|
||||
api.add_resource(WorkflowRunDetailApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
|
||||
api.add_resource(WorkflowRunNodeExecutionListApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps/<uuid:app_id>/advanced-chat/workflow-runs')
|
||||
api.add_resource(WorkflowRunListApi, '/apps/<uuid:app_id>/workflow-runs')
|
||||
api.add_resource(WorkflowRunDetailApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>')
|
||||
api.add_resource(WorkflowRunNodeExecutionListApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions')
|
||||
|
||||
@@ -26,56 +26,56 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
sql_query = '''
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
|
||||
FROM workflow_runs
|
||||
WHERE app_id = :app_id
|
||||
AND triggered_from = :triggered_from
|
||||
"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
'''
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
sql_query += ' and created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
sql_query += ' and created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "runs": i.runs})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'runs': i.runs
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
class WorkflowDailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@@ -86,56 +86,56 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
sql_query = '''
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
|
||||
FROM workflow_runs
|
||||
WHERE app_id = :app_id
|
||||
AND triggered_from = :triggered_from
|
||||
"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
'''
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
sql_query += ' and created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
sql_query += ' and created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'terminal_count': i.terminal_count
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
class WorkflowDailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@@ -146,63 +146,58 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
sql_query = '''
|
||||
SELECT
|
||||
date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
SUM(workflow_runs.total_tokens) as token_count
|
||||
FROM workflow_runs
|
||||
WHERE app_id = :app_id
|
||||
AND triggered_from = :triggered_from
|
||||
"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
'''
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
sql_query += ' and created_at >= :start'
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
sql_query += ' and created_at < :end'
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += ' GROUP BY date order by date'
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{
|
||||
"date": str(i.date),
|
||||
"token_count": i.token_count,
|
||||
}
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'token_count': i.token_count,
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
@setup_required
|
||||
@@ -213,8 +208,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@@ -234,54 +229,50 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
GROUP BY date, c.created_by) sub
|
||||
GROUP BY sub.date
|
||||
"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
if args['start']:
|
||||
start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start')
|
||||
arg_dict['start'] = start_datetime_utc
|
||||
else:
|
||||
sql_query = sql_query.replace("{{start}}", "")
|
||||
sql_query = sql_query.replace('{{start}}', '')
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
if args['end']:
|
||||
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end')
|
||||
arg_dict['end'] = end_datetime_utc
|
||||
else:
|
||||
sql_query = sql_query.replace("{{end}}", "")
|
||||
sql_query = sql_query.replace('{{end}}', '')
|
||||
|
||||
response_data = []
|
||||
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
||||
)
|
||||
response_data.append({
|
||||
'date': str(i.date),
|
||||
'interactions': float(i.interactions.quantize(Decimal('0.01')))
|
||||
})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
return jsonify({
|
||||
'data': response_data
|
||||
})
|
||||
|
||||
|
||||
api.add_resource(WorkflowDailyRunsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||
api.add_resource(WorkflowDailyTerminalsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
|
||||
api.add_resource(WorkflowDailyTokenCostStatistic, "/apps/<uuid:app_id>/workflow/statistics/token-costs")
|
||||
api.add_resource(
|
||||
WorkflowAverageAppInteractionStatistic, "/apps/<uuid:app_id>/workflow/statistics/average-app-interactions"
|
||||
)
|
||||
api.add_resource(WorkflowDailyRunsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-conversations')
|
||||
api.add_resource(WorkflowDailyTerminalsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-terminals')
|
||||
api.add_resource(WorkflowDailyTokenCostStatistic, '/apps/<uuid:app_id>/workflow/statistics/token-costs')
|
||||
api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps/<uuid:app_id>/workflow/statistics/average-app-interactions')
|
||||
|
||||
@@ -8,23 +8,24 @@ from libs.login import current_user
|
||||
from models.model import App, AppMode
|
||||
|
||||
|
||||
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
|
||||
def get_app_model(view: Optional[Callable] = None, *,
|
||||
mode: Union[AppMode, list[AppMode]] = None):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
if not kwargs.get('app_id'):
|
||||
raise ValueError('missing app_id in path parameters')
|
||||
|
||||
app_id = kwargs.get("app_id")
|
||||
app_id = kwargs.get('app_id')
|
||||
app_id = str(app_id)
|
||||
|
||||
del kwargs["app_id"]
|
||||
del kwargs['app_id']
|
||||
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
app_model = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
|
||||
@@ -43,10 +44,9 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
|
||||
mode_values = {m.value for m in modes}
|
||||
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
|
||||
|
||||
kwargs["app_model"] = app_model
|
||||
kwargs['app_model'] = app_model
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
|
||||
@@ -17,61 +17,60 @@ from services.account_service import RegisterService
|
||||
class ActivateCheckApi(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args")
|
||||
parser.add_argument("email", type=email, required=False, nullable=True, location="args")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args')
|
||||
parser.add_argument('email', type=email, required=False, nullable=True, location='args')
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
workspaceId = args["workspace_id"]
|
||||
reg_email = args["email"]
|
||||
token = args["token"]
|
||||
workspaceId = args['workspace_id']
|
||||
reg_email = args['email']
|
||||
token = args['token']
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||
|
||||
return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None}
|
||||
return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None}
|
||||
|
||||
|
||||
class ActivateApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||
)
|
||||
parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('email', type=email, required=False, nullable=True, location='json')
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
|
||||
parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
|
||||
parser.add_argument('interface_language', type=supported_language, required=True, nullable=False,
|
||||
location='json')
|
||||
parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token'])
|
||||
if invitation is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
|
||||
RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
|
||||
|
||||
account = invitation["account"]
|
||||
account.name = args["name"]
|
||||
account = invitation['account']
|
||||
account.name = args['name']
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(args["password"], salt)
|
||||
password_hashed = hash_password(args['password'], salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
account.interface_language = args["interface_language"]
|
||||
account.timezone = args["timezone"]
|
||||
account.interface_theme = "light"
|
||||
account.interface_language = args['interface_language']
|
||||
account.timezone = args['timezone']
|
||||
account.interface_theme = 'light'
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(ActivateCheckApi, "/activate/check")
|
||||
api.add_resource(ActivateApi, "/activate")
|
||||
api.add_resource(ActivateCheckApi, '/activate/check')
|
||||
api.add_resource(ActivateApi, '/activate')
|
||||
|
||||
@@ -19,19 +19,18 @@ class ApiKeyAuthDataSource(Resource):
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
"sources": [
|
||||
{
|
||||
"id": data_source_api_key_binding.id,
|
||||
"category": data_source_api_key_binding.category,
|
||||
"provider": data_source_api_key_binding.provider,
|
||||
"disabled": data_source_api_key_binding.disabled,
|
||||
"created_at": int(data_source_api_key_binding.created_at.timestamp()),
|
||||
"updated_at": int(data_source_api_key_binding.updated_at.timestamp()),
|
||||
}
|
||||
for data_source_api_key_binding in data_source_api_key_bindings
|
||||
]
|
||||
'sources': [{
|
||||
'id': data_source_api_key_binding.id,
|
||||
'category': data_source_api_key_binding.category,
|
||||
'provider': data_source_api_key_binding.provider,
|
||||
'disabled': data_source_api_key_binding.disabled,
|
||||
'created_at': int(data_source_api_key_binding.created_at.timestamp()),
|
||||
'updated_at': int(data_source_api_key_binding.updated_at.timestamp()),
|
||||
}
|
||||
for data_source_api_key_binding in
|
||||
data_source_api_key_bindings]
|
||||
}
|
||||
return {"sources": []}
|
||||
return {'sources': []}
|
||||
|
||||
|
||||
class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@@ -43,16 +42,16 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@@ -66,9 +65,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
|
||||
api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
|
||||
api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")
|
||||
api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
|
||||
api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
|
||||
api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
|
||||
|
||||
@@ -17,13 +17,13 @@ from ..wraps import account_initialization_required
|
||||
|
||||
def get_oauth_providers():
|
||||
with current_app.app_context():
|
||||
notion_oauth = NotionOAuth(
|
||||
client_id=dify_config.NOTION_CLIENT_ID,
|
||||
client_secret=dify_config.NOTION_CLIENT_SECRET,
|
||||
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion",
|
||||
)
|
||||
notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID,
|
||||
client_secret=dify_config.NOTION_CLIENT_SECRET,
|
||||
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion')
|
||||
|
||||
OAUTH_PROVIDERS = {"notion": notion_oauth}
|
||||
OAUTH_PROVIDERS = {
|
||||
'notion': notion_oauth
|
||||
}
|
||||
return OAUTH_PROVIDERS
|
||||
|
||||
|
||||
@@ -37,16 +37,18 @@ class OAuthDataSource(Resource):
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
print(vars(oauth_provider))
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
if dify_config.NOTION_INTEGRATION_TYPE == "internal":
|
||||
return {'error': 'Invalid provider'}, 400
|
||||
if dify_config.NOTION_INTEGRATION_TYPE == 'internal':
|
||||
internal_secret = dify_config.NOTION_INTERNAL_SECRET
|
||||
if not internal_secret:
|
||||
return ({"error": "Internal secret is not set"},)
|
||||
return {'error': 'Internal secret is not set'},
|
||||
oauth_provider.save_internal_access_token(internal_secret)
|
||||
return {"data": ""}
|
||||
return { 'data': '' }
|
||||
else:
|
||||
auth_url = oauth_provider.get_authorization_url()
|
||||
return {"data": auth_url}, 200
|
||||
return { 'data': auth_url }, 200
|
||||
|
||||
|
||||
|
||||
|
||||
class OAuthDataSourceCallback(Resource):
|
||||
@@ -55,18 +57,18 @@ class OAuthDataSourceCallback(Resource):
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
if "code" in request.args:
|
||||
code = request.args.get("code")
|
||||
return {'error': 'Invalid provider'}, 400
|
||||
if 'code' in request.args:
|
||||
code = request.args.get('code')
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}")
|
||||
elif "error" in request.args:
|
||||
error = request.args.get("error")
|
||||
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}')
|
||||
elif 'error' in request.args:
|
||||
error = request.args.get('error')
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}")
|
||||
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}')
|
||||
else:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied")
|
||||
|
||||
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied')
|
||||
|
||||
|
||||
class OAuthDataSourceBinding(Resource):
|
||||
def get(self, provider: str):
|
||||
@@ -74,18 +76,17 @@ class OAuthDataSourceBinding(Resource):
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
if "code" in request.args:
|
||||
code = request.args.get("code")
|
||||
return {'error': 'Invalid provider'}, 400
|
||||
if 'code' in request.args:
|
||||
code = request.args.get('code')
|
||||
try:
|
||||
oauth_provider.get_access_token(code)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logging.exception(
|
||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}"
|
||||
)
|
||||
return {"error": "OAuth data source process failed"}, 400
|
||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
||||
return {'error': 'OAuth data source process failed'}, 400
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class OAuthDataSourceSync(Resource):
|
||||
@@ -99,17 +100,18 @@ class OAuthDataSourceSync(Resource):
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
return {'error': 'Invalid provider'}, 400
|
||||
try:
|
||||
oauth_provider.sync_data_source(binding_id)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
||||
return {"error": "OAuth data source process failed"}, 400
|
||||
logging.exception(
|
||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
||||
return {'error': 'OAuth data source process failed'}, 400
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
|
||||
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
|
||||
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
|
||||
api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
|
||||
api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
|
||||
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
|
||||
api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>')
|
||||
api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
|
||||
|
||||
@@ -2,30 +2,31 @@ from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class ApiKeyAuthFailedError(BaseHTTPException):
|
||||
error_code = "auth_failed"
|
||||
error_code = 'auth_failed'
|
||||
description = "{message}"
|
||||
code = 500
|
||||
|
||||
|
||||
class InvalidEmailError(BaseHTTPException):
|
||||
error_code = "invalid_email"
|
||||
error_code = 'invalid_email'
|
||||
description = "The email address is not valid."
|
||||
code = 400
|
||||
|
||||
|
||||
class PasswordMismatchError(BaseHTTPException):
|
||||
error_code = "password_mismatch"
|
||||
error_code = 'password_mismatch'
|
||||
description = "The passwords do not match."
|
||||
code = 400
|
||||
|
||||
|
||||
class InvalidTokenError(BaseHTTPException):
|
||||
error_code = "invalid_or_expired_token"
|
||||
error_code = 'invalid_or_expired_token'
|
||||
description = "The token is invalid or has expired."
|
||||
code = 400
|
||||
|
||||
|
||||
class PasswordResetRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "password_reset_rate_limit_exceeded"
|
||||
error_code = 'password_reset_rate_limit_exceeded'
|
||||
description = "Password reset rate limit exceeded. Try again later."
|
||||
code = 429
|
||||
|
||||
|
||||
@@ -21,13 +21,14 @@ from services.errors.account import RateLimitExceededError
|
||||
|
||||
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument('email', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
email = args["email"]
|
||||
email = args['email']
|
||||
|
||||
if not email_validate(email):
|
||||
raise InvalidEmailError()
|
||||
@@ -48,36 +49,38 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
|
||||
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
token = args["token"]
|
||||
token = args['token']
|
||||
|
||||
reset_data = AccountService.get_reset_password_data(token)
|
||||
|
||||
if reset_data is None:
|
||||
return {"is_valid": False, "email": None}
|
||||
return {"is_valid": True, "email": reset_data.get("email")}
|
||||
return {'is_valid': False, 'email': None}
|
||||
return {'is_valid': True, 'email': reset_data.get('email')}
|
||||
|
||||
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
|
||||
parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
new_password = args["new_password"]
|
||||
password_confirm = args["password_confirm"]
|
||||
new_password = args['new_password']
|
||||
password_confirm = args['password_confirm']
|
||||
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
raise PasswordMismatchError()
|
||||
|
||||
token = args["token"]
|
||||
token = args['token']
|
||||
reset_data = AccountService.get_reset_password_data(token)
|
||||
|
||||
if reset_data is None:
|
||||
@@ -91,14 +94,14 @@ class ForgotPasswordResetApi(Resource):
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
|
||||
account = Account.query.filter_by(email=reset_data.get("email")).first()
|
||||
account = Account.query.filter_by(email=reset_data.get('email')).first()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
|
||||
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
|
||||
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
|
||||
api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
|
||||
api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
|
||||
api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
|
||||
|
||||
@@ -20,39 +20,37 @@ class LoginApi(Resource):
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
parser.add_argument('email', type=email, required=True, location='json')
|
||||
parser.add_argument('password', type=valid_password, required=True, location='json')
|
||||
parser.add_argument('remember_me', type=bool, required=False, default=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
# todo: Verify the recaptcha
|
||||
|
||||
try:
|
||||
account = AccountService.authenticate(args["email"], args["password"])
|
||||
account = AccountService.authenticate(args['email'], args['password'])
|
||||
except services.errors.account.AccountLoginError as e:
|
||||
return {"code": "unauthorized", "message": str(e)}, 401
|
||||
return {'code': 'unauthorized', 'message': str(e)}, 401
|
||||
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if len(tenants) == 0:
|
||||
return {
|
||||
"result": "fail",
|
||||
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
|
||||
}
|
||||
return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
|
||||
|
||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
return {'result': 'success', 'data': token}
|
||||
|
||||
|
||||
class LogoutApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def get(self):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
token = request.headers.get("Authorization", "").split(" ")[1]
|
||||
token = request.headers.get('Authorization', '').split(' ')[1]
|
||||
AccountService.logout(account=account, token=token)
|
||||
flask_login.logout_user()
|
||||
return {"result": "success"}
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class ResetPasswordApi(Resource):
|
||||
@@ -82,11 +80,11 @@ class ResetPasswordApi(Resource):
|
||||
# 'subject': 'Reset your Dify password',
|
||||
# 'html': """
|
||||
# <p>Dear User,</p>
|
||||
# <p>The Dify team has generated a new password for you, details as follows:</p>
|
||||
# <p>The Dify team has generated a new password for you, details as follows:</p>
|
||||
# <p><strong>{new_password}</strong></p>
|
||||
# <p>Please change your password to log in as soon as possible.</p>
|
||||
# <p>Regards,</p>
|
||||
# <p>The Dify Team</p>
|
||||
# <p>The Dify Team</p>
|
||||
# """
|
||||
# }
|
||||
|
||||
@@ -103,8 +101,8 @@ class ResetPasswordApi(Resource):
|
||||
# # handle error
|
||||
# pass
|
||||
|
||||
return {"result": "success"}
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(LoginApi, "/login")
|
||||
api.add_resource(LogoutApi, "/logout")
|
||||
api.add_resource(LoginApi, '/login')
|
||||
api.add_resource(LogoutApi, '/logout')
|
||||
|
||||
@@ -25,7 +25,7 @@ def get_oauth_providers():
|
||||
github_oauth = GitHubOAuth(
|
||||
client_id=dify_config.GITHUB_CLIENT_ID,
|
||||
client_secret=dify_config.GITHUB_CLIENT_SECRET,
|
||||
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github",
|
||||
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github',
|
||||
)
|
||||
if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
|
||||
google_oauth = None
|
||||
@@ -33,10 +33,10 @@ def get_oauth_providers():
|
||||
google_oauth = GoogleOAuth(
|
||||
client_id=dify_config.GOOGLE_CLIENT_ID,
|
||||
client_secret=dify_config.GOOGLE_CLIENT_SECRET,
|
||||
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
|
||||
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google',
|
||||
)
|
||||
|
||||
OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth}
|
||||
OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth}
|
||||
return OAUTH_PROVIDERS
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class OAuthLogin(Resource):
|
||||
oauth_provider = OAUTH_PROVIDERS.get(provider)
|
||||
print(vars(oauth_provider))
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
return {'error': 'Invalid provider'}, 400
|
||||
|
||||
auth_url = oauth_provider.get_authorization_url()
|
||||
return redirect(auth_url)
|
||||
@@ -59,20 +59,20 @@ class OAuthCallback(Resource):
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
return {'error': 'Invalid provider'}, 400
|
||||
|
||||
code = request.args.get("code")
|
||||
code = request.args.get('code')
|
||||
try:
|
||||
token = oauth_provider.get_access_token(code)
|
||||
user_info = oauth_provider.get_user_info(token)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}')
|
||||
return {'error': 'OAuth process failed'}, 400
|
||||
|
||||
account = _generate_account(provider, user_info)
|
||||
# Check account status
|
||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||
return {"error": "Account is banned or closed."}, 403
|
||||
return {'error': 'Account is banned or closed.'}, 403
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
@@ -83,7 +83,7 @@ class OAuthCallback(Resource):
|
||||
|
||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
|
||||
return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}')
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||
@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
||||
|
||||
if not account:
|
||||
# Create account
|
||||
account_name = user_info.name if user_info.name else "Dify"
|
||||
account_name = user_info.name if user_info.name else 'Dify'
|
||||
account = RegisterService.register(
|
||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||
)
|
||||
@@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
||||
return account
|
||||
|
||||
|
||||
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
|
||||
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")
|
||||
api.add_resource(OAuthLogin, '/oauth/login/<provider>')
|
||||
api.add_resource(OAuthCallback, '/oauth/authorize/<provider>')
|
||||
|
||||
@@ -9,24 +9,28 @@ from services.billing_service import BillingService
|
||||
|
||||
|
||||
class Subscription(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
|
||||
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
|
||||
args = parser.parse_args()
|
||||
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
|
||||
return BillingService.get_subscription(
|
||||
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
||||
)
|
||||
return BillingService.get_subscription(args['plan'],
|
||||
args['interval'],
|
||||
current_user.email,
|
||||
current_user.current_tenant_id)
|
||||
|
||||
|
||||
class Invoices(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -36,5 +40,5 @@ class Invoices(Resource):
|
||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||
|
||||
|
||||
api.add_resource(Subscription, "/billing/subscription")
|
||||
api.add_resource(Invoices, "/billing/invoices")
|
||||
api.add_resource(Subscription, '/billing/subscription')
|
||||
api.add_resource(Invoices, '/billing/invoices')
|
||||
|
||||
@@ -22,22 +22,19 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
||||
class DataSourceApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_list_fields)
|
||||
def get(self):
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.filter(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.disabled == False
|
||||
).all()
|
||||
|
||||
base_url = request.url_root.rstrip("/")
|
||||
base_url = request.url_root.rstrip('/')
|
||||
data_source_oauth_base_path = "/console/api/oauth/data-source"
|
||||
providers = ["notion"]
|
||||
|
||||
@@ -47,30 +44,26 @@ class DataSourceApi(Resource):
|
||||
existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates)
|
||||
if existing_integrates:
|
||||
for existing_integrate in list(existing_integrates):
|
||||
integrate_data.append(
|
||||
{
|
||||
"id": existing_integrate.id,
|
||||
"provider": provider,
|
||||
"created_at": existing_integrate.created_at,
|
||||
"is_bound": True,
|
||||
"disabled": existing_integrate.disabled,
|
||||
"source_info": existing_integrate.source_info,
|
||||
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
|
||||
}
|
||||
)
|
||||
integrate_data.append({
|
||||
'id': existing_integrate.id,
|
||||
'provider': provider,
|
||||
'created_at': existing_integrate.created_at,
|
||||
'is_bound': True,
|
||||
'disabled': existing_integrate.disabled,
|
||||
'source_info': existing_integrate.source_info,
|
||||
'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
|
||||
})
|
||||
else:
|
||||
integrate_data.append(
|
||||
{
|
||||
"id": None,
|
||||
"provider": provider,
|
||||
"created_at": None,
|
||||
"source_info": None,
|
||||
"is_bound": False,
|
||||
"disabled": None,
|
||||
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
|
||||
}
|
||||
)
|
||||
return {"data": integrate_data}, 200
|
||||
integrate_data.append({
|
||||
'id': None,
|
||||
'provider': provider,
|
||||
'created_at': None,
|
||||
'source_info': None,
|
||||
'is_bound': False,
|
||||
'disabled': None,
|
||||
'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
|
||||
})
|
||||
return {'data': integrate_data}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -78,82 +71,92 @@ class DataSourceApi(Resource):
|
||||
def patch(self, binding_id, action):
|
||||
binding_id = str(binding_id)
|
||||
action = str(action)
|
||||
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
|
||||
data_source_binding = DataSourceOauthBinding.query.filter_by(
|
||||
id=binding_id
|
||||
).first()
|
||||
if data_source_binding is None:
|
||||
raise NotFound("Data source binding not found.")
|
||||
raise NotFound('Data source binding not found.')
|
||||
# enable binding
|
||||
if action == "enable":
|
||||
if action == 'enable':
|
||||
if data_source_binding.disabled:
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is not disabled.")
|
||||
raise ValueError('Data source is not disabled.')
|
||||
# disable binding
|
||||
if action == "disable":
|
||||
if action == 'disable':
|
||||
if not data_source_binding.disabled:
|
||||
data_source_binding.disabled = True
|
||||
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.add(data_source_binding)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source is disabled.")
|
||||
return {"result": "success"}, 200
|
||||
raise ValueError('Data source is disabled.')
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class DataSourceNotionListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_notion_info_list_fields)
|
||||
def get(self):
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
dataset_id = request.args.get('dataset_id', default=None, type=str)
|
||||
exist_page_ids = []
|
||||
# import notion in the exist dataset
|
||||
if dataset_id:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if dataset.data_source_type != "notion_import":
|
||||
raise ValueError("Dataset is not notion type.")
|
||||
raise NotFound('Dataset not found.')
|
||||
if dataset.data_source_type != 'notion_import':
|
||||
raise ValueError('Dataset is not notion type.')
|
||||
documents = Document.query.filter_by(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
data_source_type='notion_import',
|
||||
enabled=True
|
||||
).all()
|
||||
if documents:
|
||||
for document in documents:
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
exist_page_ids.append(data_source_info['notion_page_id'])
|
||||
# get all authorized pages
|
||||
data_source_bindings = DataSourceOauthBinding.query.filter_by(
|
||||
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider='notion',
|
||||
disabled=False
|
||||
).all()
|
||||
if not data_source_bindings:
|
||||
return {"notion_info": []}, 200
|
||||
return {
|
||||
'notion_info': []
|
||||
}, 200
|
||||
pre_import_info_list = []
|
||||
for data_source_binding in data_source_bindings:
|
||||
source_info = data_source_binding.source_info
|
||||
pages = source_info["pages"]
|
||||
pages = source_info['pages']
|
||||
# Filter out already bound pages
|
||||
for page in pages:
|
||||
if page["page_id"] in exist_page_ids:
|
||||
page["is_bound"] = True
|
||||
if page['page_id'] in exist_page_ids:
|
||||
page['is_bound'] = True
|
||||
else:
|
||||
page["is_bound"] = False
|
||||
page['is_bound'] = False
|
||||
pre_import_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
'workspace_name': source_info['workspace_name'],
|
||||
'workspace_icon': source_info['workspace_icon'],
|
||||
'workspace_id': source_info['workspace_id'],
|
||||
'pages': pages,
|
||||
}
|
||||
pre_import_info_list.append(pre_import_info)
|
||||
return {"notion_info": pre_import_info_list}, 200
|
||||
return {
|
||||
'notion_info': pre_import_info_list
|
||||
}, 200
|
||||
|
||||
|
||||
class DataSourceNotionApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -163,67 +166,64 @@ class DataSourceNotionApi(Resource):
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
raise NotFound("Data source binding not found.")
|
||||
raise NotFound('Data source binding not found.')
|
||||
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
|
||||
text_docs = extractor.extract()
|
||||
return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200
|
||||
return {
|
||||
'content': "\n".join([doc.page_content for doc in text_docs])
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
)
|
||||
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
notion_info_list = args["notion_info_list"]
|
||||
notion_info_list = args['notion_info_list']
|
||||
extract_settings = []
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
for page in notion_info["pages"]:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
for page in notion_info['pages']:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"notion_obj_id": page['page_id'],
|
||||
"notion_page_type": page['type'],
|
||||
"tenant_id": current_user.current_tenant_id
|
||||
},
|
||||
document_model=args["doc_form"],
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.indexing_estimate(
|
||||
current_user.current_tenant_id,
|
||||
extract_settings,
|
||||
args["process_rule"],
|
||||
args["doc_form"],
|
||||
args["doc_language"],
|
||||
)
|
||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'])
|
||||
return response, 200
|
||||
|
||||
|
||||
class DataSourceNotionDatasetSyncApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -240,6 +240,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
||||
|
||||
|
||||
class DataSourceNotionDocumentSyncApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -257,14 +258,10 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
||||
return 200
|
||||
|
||||
|
||||
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
|
||||
api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
|
||||
api.add_resource(
|
||||
DataSourceNotionApi,
|
||||
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/datasets/notion-indexing-estimate",
|
||||
)
|
||||
api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
|
||||
api.add_resource(
|
||||
DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
|
||||
)
|
||||
api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
|
||||
api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
|
||||
api.add_resource(DataSourceNotionApi,
|
||||
'/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview',
|
||||
'/datasets/notion-indexing-estimate')
|
||||
api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
|
||||
api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync')
|
||||
|
||||
@@ -31,40 +31,45 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
raise ValueError('Name must be between 1 to 40 characters.')
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
raise ValueError('Description cannot exceed 400 characters.')
|
||||
return description
|
||||
|
||||
|
||||
class DatasetListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
ids = request.args.getlist("ids")
|
||||
provider = request.args.get("provider", default="vendor")
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
tag_ids = request.args.getlist("tag_ids")
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
ids = request.args.getlist('ids')
|
||||
provider = request.args.get('provider', default="vendor")
|
||||
search = request.args.get('keyword', default=None, type=str)
|
||||
tag_ids = request.args.getlist('tag_ids')
|
||||
|
||||
if ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
|
||||
else:
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids
|
||||
)
|
||||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
||||
current_user.current_tenant_id, current_user, search, tag_ids)
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||
configurations = provider_manager.get_configurations(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
|
||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||
embedding_models = configurations.get_models(
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
only_active=True
|
||||
)
|
||||
|
||||
model_names = []
|
||||
for embedding_model in embedding_models:
|
||||
@@ -72,22 +77,28 @@ class DatasetListApi(Resource):
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item["indexing_technique"] == "high_quality":
|
||||
if item['indexing_technique'] == 'high_quality':
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item["embedding_available"] = True
|
||||
item['embedding_available'] = True
|
||||
else:
|
||||
item["embedding_available"] = False
|
||||
item['embedding_available'] = False
|
||||
else:
|
||||
item["embedding_available"] = True
|
||||
item['embedding_available'] = True
|
||||
|
||||
if item.get("permission") == "partial_members":
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"])
|
||||
item.update({"partial_member_list": part_users_list})
|
||||
if item.get('permission') == 'partial_members':
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
|
||||
item.update({'partial_member_list': part_users_list})
|
||||
else:
|
||||
item.update({"partial_member_list": []})
|
||||
item.update({'partial_member_list': []})
|
||||
|
||||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
response = {
|
||||
'data': data,
|
||||
'has_more': len(datasets) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
'page': page
|
||||
}
|
||||
return response, 200
|
||||
|
||||
@setup_required
|
||||
@@ -95,21 +106,13 @@ class DatasetListApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
parser.add_argument('name', nullable=False, required=True,
|
||||
help='type is required. Name must be between 1 to 40 characters.',
|
||||
type=_validate_name)
|
||||
parser.add_argument('indexing_technique', type=str, location='json',
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid indexing technique.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
@@ -119,10 +122,9 @@ class DatasetListApi(Resource):
|
||||
try:
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
name=args["name"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
name=args['name'],
|
||||
indexing_technique=args['indexing_technique'],
|
||||
account=current_user
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
@@ -140,36 +142,42 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
DatasetService.check_dataset_permission(
|
||||
dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
if data.get("permission") == "partial_members":
|
||||
if data.get('permission') == 'partial_members':
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({"partial_member_list": part_users_list})
|
||||
data.update({'partial_member_list': part_users_list})
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||
configurations = provider_manager.get_configurations(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
|
||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||
embedding_models = configurations.get_models(
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
only_active=True
|
||||
)
|
||||
|
||||
model_names = []
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
if data["indexing_technique"] == "high_quality":
|
||||
if data['indexing_technique'] == 'high_quality':
|
||||
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
data["embedding_available"] = True
|
||||
data['embedding_available'] = True
|
||||
else:
|
||||
data["embedding_available"] = False
|
||||
data['embedding_available'] = False
|
||||
else:
|
||||
data["embedding_available"] = True
|
||||
data['embedding_available'] = True
|
||||
|
||||
if data.get("permission") == "partial_members":
|
||||
if data.get('permission') == 'partial_members':
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({"partial_member_list": part_users_list})
|
||||
data.update({'partial_member_list': part_users_list})
|
||||
|
||||
return data, 200
|
||||
|
||||
@@ -183,49 +191,42 @@ class DatasetApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||
help="Invalid permission.",
|
||||
)
|
||||
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
||||
parser.add_argument(
|
||||
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
|
||||
)
|
||||
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
||||
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
||||
parser.add_argument('name', nullable=False,
|
||||
help='type is required. Name must be between 1 to 40 characters.',
|
||||
type=_validate_name)
|
||||
parser.add_argument('description',
|
||||
location='json', store_missing=False,
|
||||
type=_validate_description_length)
|
||||
parser.add_argument('indexing_technique', type=str, location='json',
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid indexing technique.')
|
||||
parser.add_argument('permission', type=str, location='json', choices=(
|
||||
DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.'
|
||||
)
|
||||
parser.add_argument('embedding_model', type=str,
|
||||
location='json', help='Invalid embedding model.')
|
||||
parser.add_argument('embedding_model_provider', type=str,
|
||||
location='json', help='Invalid embedding model provider.')
|
||||
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
|
||||
parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
|
||||
args = parser.parse_args()
|
||||
data = request.get_json()
|
||||
|
||||
# check embedding model setting
|
||||
if data.get("indexing_technique") == "high_quality":
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
|
||||
)
|
||||
if data.get('indexing_technique') == 'high_quality':
|
||||
DatasetService.check_embedding_model_setting(dataset.tenant_id,
|
||||
data.get('embedding_model_provider'),
|
||||
data.get('embedding_model')
|
||||
)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
DatasetPermissionService.check_permission(
|
||||
current_user, dataset, data.get("permission"), data.get("partial_member_list")
|
||||
current_user, dataset, data.get('permission'), data.get('partial_member_list')
|
||||
)
|
||||
|
||||
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
|
||||
dataset = DatasetService.update_dataset(
|
||||
dataset_id_str, args, current_user)
|
||||
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
@@ -233,19 +234,16 @@ class DatasetApi(Resource):
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||
if data.get('partial_member_list') and data.get('permission') == 'partial_members':
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
tenant_id, dataset_id_str, data.get("partial_member_list")
|
||||
tenant_id, dataset_id_str, data.get('partial_member_list')
|
||||
)
|
||||
# clear partial member list when permission is only_me or all_team_members
|
||||
elif (
|
||||
data.get("permission") == DatasetPermissionEnum.ONLY_ME
|
||||
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
|
||||
):
|
||||
elif data.get('permission') == DatasetPermissionEnum.ONLY_ME or data.get('permission') == DatasetPermissionEnum.ALL_TEAM:
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
result_data.update({"partial_member_list": partial_member_list})
|
||||
result_data.update({'partial_member_list': partial_member_list})
|
||||
|
||||
return result_data, 200
|
||||
|
||||
@@ -262,13 +260,12 @@ class DatasetApi(Resource):
|
||||
try:
|
||||
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
return {"result": "success"}, 204
|
||||
return {'result': 'success'}, 204
|
||||
else:
|
||||
raise NotFound("Dataset not found.")
|
||||
except services.errors.dataset.DatasetInUseError:
|
||||
raise DatasetInUseError()
|
||||
|
||||
|
||||
class DatasetUseCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -277,10 +274,10 @@ class DatasetUseCheckApi(Resource):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
|
||||
return {"is_using": dataset_is_using}, 200
|
||||
|
||||
return {'is_using': dataset_is_using}, 200
|
||||
|
||||
class DatasetQueryApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -295,53 +292,51 @@ class DatasetQueryApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
|
||||
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
|
||||
dataset_queries, total = DatasetService.get_dataset_queries(
|
||||
dataset_id=dataset.id,
|
||||
page=page,
|
||||
per_page=limit
|
||||
)
|
||||
|
||||
response = {
|
||||
"data": marshal(dataset_queries, dataset_query_detail_fields),
|
||||
"has_more": len(dataset_queries) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
'data': marshal(dataset_queries, dataset_query_detail_fields),
|
||||
'has_more': len(dataset_queries) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
'page': page
|
||||
}
|
||||
return response, 200
|
||||
|
||||
|
||||
class DatasetIndexingEstimateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
)
|
||||
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('indexing_technique', type=str, required=True,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
extract_settings = []
|
||||
if args["info_list"]["data_source_type"] == "upload_file":
|
||||
file_ids = args["info_list"]["file_info_list"]["file_ids"]
|
||||
file_details = (
|
||||
db.session.query(UploadFile)
|
||||
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
|
||||
.all()
|
||||
)
|
||||
if args['info_list']['data_source_type'] == 'upload_file':
|
||||
file_ids = args['info_list']['file_info_list']['file_ids']
|
||||
file_details = db.session.query(UploadFile).filter(
|
||||
UploadFile.tenant_id == current_user.current_tenant_id,
|
||||
UploadFile.id.in_(file_ids)
|
||||
).all()
|
||||
|
||||
if file_details is None:
|
||||
raise NotFound("File not found.")
|
||||
@@ -349,58 +344,55 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
if file_details:
|
||||
for file_detail in file_details:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
|
||||
datasource_type="upload_file",
|
||||
upload_file=file_detail,
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif args["info_list"]["data_source_type"] == "notion_import":
|
||||
notion_info_list = args["info_list"]["notion_info_list"]
|
||||
elif args['info_list']['data_source_type'] == 'notion_import':
|
||||
notion_info_list = args['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
for page in notion_info["pages"]:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
for page in notion_info['pages']:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"notion_obj_id": page['page_id'],
|
||||
"notion_page_type": page['type'],
|
||||
"tenant_id": current_user.current_tenant_id
|
||||
},
|
||||
document_model=args["doc_form"],
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif args["info_list"]["data_source_type"] == "website_crawl":
|
||||
website_info_list = args["info_list"]["website_info_list"]
|
||||
for url in website_info_list["urls"]:
|
||||
elif args['info_list']['data_source_type'] == 'website_crawl':
|
||||
website_info_list = args['info_list']['website_info_list']
|
||||
for url in website_info_list['urls']:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="website_crawl",
|
||||
website_info={
|
||||
"provider": website_info_list["provider"],
|
||||
"job_id": website_info_list["job_id"],
|
||||
"provider": website_info_list['provider'],
|
||||
"job_id": website_info_list['job_id'],
|
||||
"url": url,
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": "crawl",
|
||||
"only_main_content": website_info_list["only_main_content"],
|
||||
"mode": 'crawl',
|
||||
"only_main_content": website_info_list['only_main_content']
|
||||
},
|
||||
document_model=args["doc_form"],
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
else:
|
||||
raise ValueError("Data source type not support")
|
||||
raise ValueError('Data source type not support')
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.indexing_estimate(
|
||||
current_user.current_tenant_id,
|
||||
extract_settings,
|
||||
args["process_rule"],
|
||||
args["doc_form"],
|
||||
args["doc_language"],
|
||||
args["dataset_id"],
|
||||
args["indexing_technique"],
|
||||
)
|
||||
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'],
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
|
||||
)
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except Exception as e:
|
||||
@@ -410,6 +402,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
|
||||
|
||||
class DatasetRelatedAppListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -433,52 +426,52 @@ class DatasetRelatedAppListApi(Resource):
|
||||
if app_model:
|
||||
related_apps.append(app_model)
|
||||
|
||||
return {"data": related_apps, "total": len(related_apps)}, 200
|
||||
return {
|
||||
'data': related_apps,
|
||||
'total': len(related_apps)
|
||||
}, 200
|
||||
|
||||
|
||||
class DatasetIndexingStatusApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
|
||||
.all()
|
||||
)
|
||||
documents = db.session.query(Document).filter(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.tenant_id == current_user.current_tenant_id
|
||||
).all()
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
).count()
|
||||
total_segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
|
||||
).count()
|
||||
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
documents_status.append(marshal(document, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
data = {
|
||||
'data': documents_status
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
class DatasetApiKeyApi(Resource):
|
||||
max_keys = 10
|
||||
token_prefix = "dataset-"
|
||||
resource_type = "dataset"
|
||||
token_prefix = 'dataset-'
|
||||
resource_type = 'dataset'
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_list)
|
||||
def get(self):
|
||||
keys = (
|
||||
db.session.query(ApiToken)
|
||||
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
|
||||
.all()
|
||||
)
|
||||
keys = db.session.query(ApiToken). \
|
||||
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
|
||||
all()
|
||||
return {"items": keys}
|
||||
|
||||
@setup_required
|
||||
@@ -490,17 +483,15 @@ class DatasetApiKeyApi(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
|
||||
.count()
|
||||
)
|
||||
current_key_count = db.session.query(ApiToken). \
|
||||
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
|
||||
count()
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
flask_restful.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code="max_keys_exceeded",
|
||||
code='max_keys_exceeded'
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
@@ -514,7 +505,7 @@ class DatasetApiKeyApi(Resource):
|
||||
|
||||
|
||||
class DatasetApiDeleteApi(Resource):
|
||||
resource_type = "dataset"
|
||||
resource_type = 'dataset'
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -526,23 +517,18 @@ class DatasetApiDeleteApi(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
key = (
|
||||
db.session.query(ApiToken)
|
||||
.filter(
|
||||
ApiToken.tenant_id == current_user.current_tenant_id,
|
||||
ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
key = db.session.query(ApiToken). \
|
||||
filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id). \
|
||||
first()
|
||||
|
||||
if key is None:
|
||||
flask_restful.abort(404, message="API key not found")
|
||||
flask_restful.abort(404, message='API key not found')
|
||||
|
||||
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class DatasetApiBaseUrlApi(Resource):
|
||||
@@ -551,10 +537,8 @@ class DatasetApiBaseUrlApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
return {
|
||||
"api_base_url": (
|
||||
dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
|
||||
)
|
||||
+ "/v1"
|
||||
'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL
|
||||
else request.host_url.rstrip('/')) + '/v1'
|
||||
}
|
||||
|
||||
|
||||
@@ -565,26 +549,15 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
def get(self):
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.MILVUS
|
||||
| VectorType.RELYT
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
VectorType.QDRANT
|
||||
| VectorType.WEAVIATE
|
||||
| VectorType.OPENSEARCH
|
||||
| VectorType.ANALYTICDB
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
):
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
|
||||
return {
|
||||
"retrieval_method": [
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
]
|
||||
}
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
@@ -600,27 +573,15 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.MILVUS
|
||||
| VectorType.RELYT
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
| VectorType.PGVECTO_RS
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
VectorType.QDRANT
|
||||
| VectorType.WEAVIATE
|
||||
| VectorType.OPENSEARCH
|
||||
| VectorType.ANALYTICDB
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.PGVECTOR
|
||||
):
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS:
|
||||
return {
|
||||
"retrieval_method": [
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
]
|
||||
}
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
@@ -630,6 +591,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
|
||||
|
||||
|
||||
class DatasetErrorDocs(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -641,7 +603,10 @@ class DatasetErrorDocs(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
|
||||
|
||||
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
|
||||
return {
|
||||
'data': [marshal(item, document_status_fields) for item in results],
|
||||
'total': len(results)
|
||||
}, 200
|
||||
|
||||
|
||||
class DatasetPermissionUserListApi(Resource):
|
||||
@@ -661,21 +626,21 @@ class DatasetPermissionUserListApi(Resource):
|
||||
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
|
||||
return {
|
||||
"data": partial_members_list,
|
||||
'data': partial_members_list,
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, "/datasets")
|
||||
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
|
||||
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
|
||||
api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries")
|
||||
api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs")
|
||||
api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate")
|
||||
api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps")
|
||||
api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status")
|
||||
api.add_resource(DatasetApiKeyApi, "/datasets/api-keys")
|
||||
api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>")
|
||||
api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
|
||||
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
|
||||
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
|
||||
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
|
||||
api.add_resource(DatasetListApi, '/datasets')
|
||||
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
||||
api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
|
||||
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
|
||||
api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
|
||||
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
|
||||
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
|
||||
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
|
||||
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
|
||||
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
|
||||
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
||||
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
|
||||
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
|
||||
api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -40,7 +40,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
raise NotFound('Dataset not found.')
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
@@ -50,33 +50,37 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
raise NotFound('Document not found.')
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=str, default=None, location="args")
|
||||
parser.add_argument("limit", type=int, default=20, location="args")
|
||||
parser.add_argument("status", type=str, action="append", default=[], location="args")
|
||||
parser.add_argument("hit_count_gte", type=int, default=None, location="args")
|
||||
parser.add_argument("enabled", type=str, default="all", location="args")
|
||||
parser.add_argument("keyword", type=str, default=None, location="args")
|
||||
parser.add_argument('last_id', type=str, default=None, location='args')
|
||||
parser.add_argument('limit', type=int, default=20, location='args')
|
||||
parser.add_argument('status', type=str,
|
||||
action='append', default=[], location='args')
|
||||
parser.add_argument('hit_count_gte', type=int,
|
||||
default=None, location='args')
|
||||
parser.add_argument('enabled', type=str, default='all', location='args')
|
||||
parser.add_argument('keyword', type=str, default=None, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
last_id = args["last_id"]
|
||||
limit = min(args["limit"], 100)
|
||||
status_list = args["status"]
|
||||
hit_count_gte = args["hit_count_gte"]
|
||||
keyword = args["keyword"]
|
||||
last_id = args['last_id']
|
||||
limit = min(args['limit'], 100)
|
||||
status_list = args['status']
|
||||
hit_count_gte = args['hit_count_gte']
|
||||
keyword = args['keyword']
|
||||
|
||||
query = DocumentSegment.query.filter(
|
||||
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
)
|
||||
|
||||
if last_id is not None:
|
||||
last_segment = db.session.get(DocumentSegment, str(last_id))
|
||||
if last_segment:
|
||||
query = query.filter(DocumentSegment.position > last_segment.position)
|
||||
query = query.filter(
|
||||
DocumentSegment.position > last_segment.position)
|
||||
else:
|
||||
return {"data": [], "has_more": False, "limit": limit}, 200
|
||||
return {'data': [], 'has_more': False, 'limit': limit}, 200
|
||||
|
||||
if status_list:
|
||||
query = query.filter(DocumentSegment.status.in_(status_list))
|
||||
@@ -85,12 +89,12 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
|
||||
|
||||
if keyword:
|
||||
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
|
||||
query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
|
||||
|
||||
if args["enabled"].lower() != "all":
|
||||
if args["enabled"].lower() == "true":
|
||||
if args['enabled'].lower() != 'all':
|
||||
if args['enabled'].lower() == 'true':
|
||||
query = query.filter(DocumentSegment.enabled == True)
|
||||
elif args["enabled"].lower() == "false":
|
||||
elif args['enabled'].lower() == 'false':
|
||||
query = query.filter(DocumentSegment.enabled == False)
|
||||
|
||||
total = query.count()
|
||||
@@ -102,11 +106,11 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
segments = segments[:-1]
|
||||
|
||||
return {
|
||||
"data": marshal(segments, segment_fields),
|
||||
"doc_form": document.doc_form,
|
||||
"has_more": has_more,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
'data': marshal(segments, segment_fields),
|
||||
'doc_form': document.doc_form,
|
||||
'has_more': has_more,
|
||||
'limit': limit,
|
||||
'total': total
|
||||
}, 200
|
||||
|
||||
|
||||
@@ -114,12 +118,12 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_resource_check('vector_space')
|
||||
def patch(self, dataset_id, segment_id, action):
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
raise NotFound('Dataset not found.')
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
@@ -130,7 +134,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
@@ -138,32 +142,32 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
).first()
|
||||
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
raise NotFound('Segment not found.')
|
||||
|
||||
if segment.status != "completed":
|
||||
raise NotFound("Segment is not completed, enable or disable function is not allowed")
|
||||
if segment.status != 'completed':
|
||||
raise NotFound('Segment is not completed, enable or disable function is not allowed')
|
||||
|
||||
document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
|
||||
document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
|
||||
cache_result = redis_client.get(document_indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
raise InvalidActionError("Document is being indexed, please try again later")
|
||||
|
||||
indexing_cache_key = "segment_{}_indexing".format(segment.id)
|
||||
indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
raise InvalidActionError("Segment is being indexed, please try again later")
|
||||
@@ -182,7 +186,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
|
||||
enable_segment_to_index_task.delay(segment.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
elif action == "disable":
|
||||
if not segment.enabled:
|
||||
raise InvalidActionError("Segment is already disabled.")
|
||||
@@ -197,7 +201,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
|
||||
disable_segment_from_index_task.delay(segment.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
else:
|
||||
raise InvalidActionError()
|
||||
|
||||
@@ -206,36 +210,35 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_resource_check('vector_space')
|
||||
@cloud_edition_billing_knowledge_limit_check('add_segment')
|
||||
def post(self, dataset_id, document_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
raise NotFound('Document not found.')
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
try:
|
||||
@@ -244,34 +247,37 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.create_segment(args, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
|
||||
class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_resource_check('vector_space')
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
raise NotFound('Dataset not found.')
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
raise NotFound('Document not found.')
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
@@ -279,22 +285,22 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
).first()
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
raise NotFound('Segment not found.')
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
@@ -304,13 +310,16 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
parser.add_argument('content', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(args, segment, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -320,21 +329,22 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
raise NotFound('Dataset not found.')
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
raise NotFound('Document not found.')
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
).first()
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
raise NotFound('Segment not found.')
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
@@ -343,36 +353,36 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_resource_check('vector_space')
|
||||
@cloud_edition_billing_knowledge_limit_check('add_segment')
|
||||
def post(self, dataset_id, document_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
raise NotFound('Document not found.')
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
file = request.files['file']
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
if 'file' not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
# check file type
|
||||
if not file.filename.endswith(".csv"):
|
||||
if not file.filename.endswith('.csv'):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
try:
|
||||
@@ -380,47 +390,51 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
df = pd.read_csv(file)
|
||||
result = []
|
||||
for index, row in df.iterrows():
|
||||
if document.doc_form == "qa_model":
|
||||
data = {"content": row[0], "answer": row[1]}
|
||||
if document.doc_form == 'qa_model':
|
||||
data = {'content': row[0], 'answer': row[1]}
|
||||
else:
|
||||
data = {"content": row[0]}
|
||||
data = {'content': row[0]}
|
||||
result.append(data)
|
||||
if len(result) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
indexing_cache_key = "segment_batch_import_{}".format(str(job_id))
|
||||
indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, "waiting")
|
||||
batch_create_segment_to_index_task.delay(
|
||||
str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
|
||||
)
|
||||
redis_client.setnx(indexing_cache_key, 'waiting')
|
||||
batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
|
||||
current_user.current_tenant_id, current_user.id)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 500
|
||||
return {"job_id": job_id, "job_status": "waiting"}, 200
|
||||
return {'error': str(e)}, 500
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': 'waiting'
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, job_id):
|
||||
job_id = str(job_id)
|
||||
indexing_cache_key = "segment_batch_import_{}".format(job_id)
|
||||
indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job is not exist.")
|
||||
|
||||
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': cache_result.decode()
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>")
|
||||
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentUpdateApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentBatchImportApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
|
||||
"/datasets/batch_import_status/<uuid:job_id>",
|
||||
)
|
||||
api.add_resource(DatasetDocumentSegmentListApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
|
||||
api.add_resource(DatasetDocumentSegmentApi,
|
||||
'/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
|
||||
api.add_resource(DatasetDocumentSegmentAddApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
|
||||
api.add_resource(DatasetDocumentSegmentUpdateApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
|
||||
api.add_resource(DatasetDocumentSegmentBatchImportApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
|
||||
'/datasets/batch_import_status/<uuid:job_id>')
|
||||
|
||||
@@ -2,90 +2,90 @@ from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = "no_file_uploaded"
|
||||
error_code = 'no_file_uploaded'
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = "too_many_files"
|
||||
error_code = 'too_many_files'
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
|
||||
class FileTooLargeError(BaseHTTPException):
|
||||
error_code = "file_too_large"
|
||||
error_code = 'file_too_large'
|
||||
description = "File size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
error_code = "unsupported_file_type"
|
||||
error_code = 'unsupported_file_type'
|
||||
description = "File type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class HighQualityDatasetOnlyError(BaseHTTPException):
|
||||
error_code = "high_quality_dataset_only"
|
||||
error_code = 'high_quality_dataset_only'
|
||||
description = "Current operation only supports 'high-quality' datasets."
|
||||
code = 400
|
||||
|
||||
|
||||
class DatasetNotInitializedError(BaseHTTPException):
|
||||
error_code = "dataset_not_initialized"
|
||||
error_code = 'dataset_not_initialized'
|
||||
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
||||
code = 400
|
||||
|
||||
|
||||
class ArchivedDocumentImmutableError(BaseHTTPException):
|
||||
error_code = "archived_document_immutable"
|
||||
error_code = 'archived_document_immutable'
|
||||
description = "The archived document is not editable."
|
||||
code = 403
|
||||
|
||||
|
||||
class DatasetNameDuplicateError(BaseHTTPException):
|
||||
error_code = "dataset_name_duplicate"
|
||||
error_code = 'dataset_name_duplicate'
|
||||
description = "The dataset name already exists. Please modify your dataset name."
|
||||
code = 409
|
||||
|
||||
|
||||
class InvalidActionError(BaseHTTPException):
|
||||
error_code = "invalid_action"
|
||||
error_code = 'invalid_action'
|
||||
description = "Invalid action."
|
||||
code = 400
|
||||
|
||||
|
||||
class DocumentAlreadyFinishedError(BaseHTTPException):
|
||||
error_code = "document_already_finished"
|
||||
error_code = 'document_already_finished'
|
||||
description = "The document has been processed. Please refresh the page or go to the document details."
|
||||
code = 400
|
||||
|
||||
|
||||
class DocumentIndexingError(BaseHTTPException):
|
||||
error_code = "document_indexing"
|
||||
error_code = 'document_indexing'
|
||||
description = "The document is being processed and cannot be edited."
|
||||
code = 400
|
||||
|
||||
|
||||
class InvalidMetadataError(BaseHTTPException):
|
||||
error_code = "invalid_metadata"
|
||||
error_code = 'invalid_metadata'
|
||||
description = "The metadata content is incorrect. Please check and verify."
|
||||
code = 400
|
||||
|
||||
|
||||
class WebsiteCrawlError(BaseHTTPException):
|
||||
error_code = "crawl_failed"
|
||||
error_code = 'crawl_failed'
|
||||
description = "{message}"
|
||||
code = 500
|
||||
|
||||
|
||||
class DatasetInUseError(BaseHTTPException):
|
||||
error_code = "dataset_in_use"
|
||||
error_code = 'dataset_in_use'
|
||||
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
|
||||
code = 409
|
||||
|
||||
|
||||
class IndexingEstimateError(BaseHTTPException):
|
||||
error_code = "indexing_estimate_error"
|
||||
error_code = 'indexing_estimate_error'
|
||||
description = "Knowledge indexing estimate failed: {message}"
|
||||
code = 500
|
||||
|
||||
@@ -21,6 +21,7 @@ PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
class FileApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -30,22 +31,23 @@ class FileApi(Resource):
|
||||
batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
|
||||
image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
||||
return {
|
||||
"file_size_limit": file_size_limit,
|
||||
"batch_count_limit": batch_count_limit,
|
||||
"image_file_size_limit": image_file_size_limit,
|
||||
'file_size_limit': file_size_limit,
|
||||
'batch_count_limit': batch_count_limit,
|
||||
'image_file_size_limit': image_file_size_limit
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(file_fields)
|
||||
@cloud_edition_billing_resource_check("documents")
|
||||
@cloud_edition_billing_resource_check(resource='documents')
|
||||
def post(self):
|
||||
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
file = request.files['file']
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
if 'file' not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
@@ -67,7 +69,7 @@ class FilePreviewApi(Resource):
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
text = FileService.get_file_preview(file_id)
|
||||
return {"content": text}
|
||||
return {'content': text}
|
||||
|
||||
|
||||
class FileSupportTypeApi(Resource):
|
||||
@@ -76,10 +78,10 @@ class FileSupportTypeApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
etl_type = dify_config.ETL_TYPE
|
||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
|
||||
return {"allowed_extensions": allowed_extensions}
|
||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
|
||||
return {'allowed_extensions': allowed_extensions}
|
||||
|
||||
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
||||
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
||||
api.add_resource(FileApi, '/files/upload')
|
||||
api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
|
||||
api.add_resource(FileSupportTypeApi, '/files/support-type')
|
||||
|
||||
@@ -29,6 +29,7 @@ from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class HitTestingApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -45,8 +46,8 @@ class HitTestingApi(Resource):
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument('query', type=str, location='json')
|
||||
parser.add_argument('retrieval_model', type=dict, required=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
@@ -54,13 +55,13 @@ class HitTestingApi(Resource):
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
query=args['query'],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
limit=10,
|
||||
retrieval_model=args['retrieval_model'],
|
||||
limit=10
|
||||
)
|
||||
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
@@ -72,8 +73,7 @@ class HitTestingApi(Resource):
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
"in the Settings -> Model Provider.")
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
@@ -83,4 +83,4 @@ class HitTestingApi(Resource):
|
||||
raise InternalServerError(str(e))
|
||||
|
||||
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||
api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')
|
||||
|
||||
@@ -9,14 +9,16 @@ from services.website_service import WebsiteService
|
||||
|
||||
|
||||
class WebsiteCrawlApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json")
|
||||
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument('provider', type=str, choices=['firecrawl'],
|
||||
required=True, nullable=True, location='json')
|
||||
parser.add_argument('url', type=str, required=True, nullable=True, location='json')
|
||||
parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
WebsiteService.document_create_args_validate(args)
|
||||
# crawl url
|
||||
@@ -33,15 +35,15 @@ class WebsiteCrawlStatusApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, job_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args")
|
||||
parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
# get crawl status
|
||||
try:
|
||||
result = WebsiteService.get_crawl_status(job_id, args["provider"])
|
||||
result = WebsiteService.get_crawl_status(job_id, args['provider'])
|
||||
except Exception as e:
|
||||
raise WebsiteCrawlError(str(e))
|
||||
return result, 200
|
||||
|
||||
|
||||
api.add_resource(WebsiteCrawlApi, "/website/crawl")
|
||||
api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/<string:job_id>")
|
||||
api.add_resource(WebsiteCrawlApi, '/website/crawl')
|
||||
api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')
|
||||
|
||||
@@ -2,41 +2,35 @@ from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class AlreadySetupError(BaseHTTPException):
|
||||
error_code = "already_setup"
|
||||
error_code = 'already_setup'
|
||||
description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage."
|
||||
code = 403
|
||||
|
||||
|
||||
class NotSetupError(BaseHTTPException):
|
||||
error_code = "not_setup"
|
||||
description = (
|
||||
"Dify has not been initialized and installed yet. "
|
||||
"Please proceed with the initialization and installation process first."
|
||||
)
|
||||
error_code = 'not_setup'
|
||||
description = "Dify has not been initialized and installed yet. " \
|
||||
"Please proceed with the initialization and installation process first."
|
||||
code = 401
|
||||
|
||||
|
||||
class NotInitValidateError(BaseHTTPException):
|
||||
error_code = "not_init_validated"
|
||||
description = (
|
||||
"Init validation has not been completed yet. " "Please proceed with the init validation process first."
|
||||
)
|
||||
error_code = 'not_init_validated'
|
||||
description = "Init validation has not been completed yet. " \
|
||||
"Please proceed with the init validation process first."
|
||||
code = 401
|
||||
|
||||
|
||||
class InitValidateFailedError(BaseHTTPException):
|
||||
error_code = "init_validate_failed"
|
||||
error_code = 'init_validate_failed'
|
||||
description = "Init validation failed. Please check the password and try again."
|
||||
code = 401
|
||||
|
||||
|
||||
class AccountNotLinkTenantError(BaseHTTPException):
|
||||
error_code = "account_not_link_tenant"
|
||||
error_code = 'account_not_link_tenant'
|
||||
description = "Account not link tenant."
|
||||
code = 403
|
||||
|
||||
|
||||
class AlreadyActivateError(BaseHTTPException):
|
||||
error_code = "already_activate"
|
||||
error_code = 'already_activate'
|
||||
description = "Auth Token is invalid or account already activated, please check again."
|
||||
code = 403
|
||||
|
||||
@@ -33,10 +33,14 @@ class ChatAudioApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
|
||||
file = request.files["file"]
|
||||
file = request.files['file']
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
|
||||
response = AudioService.transcript_asr(
|
||||
app_model=app_model,
|
||||
file=file,
|
||||
end_user=None
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
@@ -72,31 +76,30 @@ class ChatTextApi(InstalledAppResource):
|
||||
app_model = installed_app.app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('text', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||
message_id = args.get('message_id', None)
|
||||
text = args.get('text', None)
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
||||
else:
|
||||
try:
|
||||
voice = (
|
||||
args.get("voice")
|
||||
if args.get("voice")
|
||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
)
|
||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
voice=voice,
|
||||
text=text
|
||||
)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
@@ -124,7 +127,7 @@ class ChatTextApi(InstalledAppResource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
|
||||
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
|
||||
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
|
||||
api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
|
||||
# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
|
||||
# endpoint='installed_app_text_with_message_id')
|
||||
|
||||
@@ -30,28 +30,33 @@ from services.app_generate_service import AppGenerateService
|
||||
|
||||
# define completion api for user
|
||||
class CompletionApi(InstalledAppResource):
|
||||
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, location='json', default='')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
@@ -80,12 +85,12 @@ class CompletionApi(InstalledAppResource):
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class ChatApi(InstalledAppResource):
|
||||
@@ -96,21 +101,25 @@ class ChatApi(InstalledAppResource):
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
@@ -145,22 +154,10 @@ class ChatStopApi(InstalledAppResource):
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
|
||||
)
|
||||
api.add_resource(
|
||||
CompletionStopApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
|
||||
endpoint="installed_app_stop_completion",
|
||||
)
|
||||
api.add_resource(
|
||||
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
|
||||
)
|
||||
api.add_resource(
|
||||
ChatStopApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
|
||||
endpoint="installed_app_stop_chat_completion",
|
||||
)
|
||||
api.add_resource(CompletionApi, '/installed-apps/<uuid:installed_app_id>/completion-messages', endpoint='installed_app_completion')
|
||||
api.add_resource(CompletionStopApi, '/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop', endpoint='installed_app_stop_completion')
|
||||
api.add_resource(ChatApi, '/installed-apps/<uuid:installed_app_id>/chat-messages', endpoint='installed_app_chat_completion')
|
||||
api.add_resource(ChatStopApi, '/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop', endpoint='installed_app_stop_chat_completion')
|
||||
|
||||
@@ -16,6 +16,7 @@ from services.web_conversation_service import WebConversationService
|
||||
|
||||
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
@@ -24,21 +25,21 @@ class ConversationListApi(InstalledAppResource):
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
|
||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
pinned = None
|
||||
if "pinned" in args and args["pinned"] is not None:
|
||||
pinned = True if args["pinned"] == "true" else False
|
||||
if 'pinned' in args and args['pinned'] is not None:
|
||||
pinned = True if args['pinned'] == 'true' else False
|
||||
|
||||
try:
|
||||
return WebConversationService.pagination_by_last_id(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
last_id=args["last_id"],
|
||||
limit=args["limit"],
|
||||
last_id=args['last_id'],
|
||||
limit=args['limit'],
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
pinned=pinned,
|
||||
)
|
||||
@@ -64,6 +65,7 @@ class ConversationApi(InstalledAppResource):
|
||||
|
||||
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
@@ -74,19 +76,24 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=False, location="json")
|
||||
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
||||
parser.add_argument('name', type=str, required=False, location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.rename(
|
||||
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
|
||||
app_model,
|
||||
conversation_id,
|
||||
current_user,
|
||||
args['name'],
|
||||
args['auto_generate']
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
class ConversationPinApi(InstalledAppResource):
|
||||
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
@@ -116,26 +123,8 @@ class ConversationUnPinApi(InstalledAppResource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
ConversationRenameApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
|
||||
endpoint="installed_app_conversation_rename",
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
|
||||
endpoint="installed_app_conversation",
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationPinApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
|
||||
endpoint="installed_app_conversation_pin",
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationUnPinApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
|
||||
endpoint="installed_app_conversation_unpin",
|
||||
)
|
||||
api.add_resource(ConversationRenameApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name', endpoint='installed_app_conversation_rename')
|
||||
api.add_resource(ConversationListApi, '/installed-apps/<uuid:installed_app_id>/conversations', endpoint='installed_app_conversations')
|
||||
api.add_resource(ConversationApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>', endpoint='installed_app_conversation')
|
||||
api.add_resource(ConversationPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin', endpoint='installed_app_conversation_pin')
|
||||
api.add_resource(ConversationUnPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin', endpoint='installed_app_conversation_unpin')
|
||||
|
||||
@@ -2,24 +2,24 @@ from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class NotCompletionAppError(BaseHTTPException):
|
||||
error_code = "not_completion_app"
|
||||
error_code = 'not_completion_app'
|
||||
description = "Not Completion App"
|
||||
code = 400
|
||||
|
||||
|
||||
class NotChatAppError(BaseHTTPException):
|
||||
error_code = "not_chat_app"
|
||||
error_code = 'not_chat_app'
|
||||
description = "App mode is invalid."
|
||||
code = 400
|
||||
|
||||
|
||||
class NotWorkflowAppError(BaseHTTPException):
|
||||
error_code = "not_workflow_app"
|
||||
error_code = 'not_workflow_app'
|
||||
description = "Only support workflow app."
|
||||
code = 400
|
||||
|
||||
|
||||
class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
|
||||
error_code = "app_suggested_questions_after_answer_disabled"
|
||||
error_code = 'app_suggested_questions_after_answer_disabled'
|
||||
description = "Function Suggested questions after answer disabled."
|
||||
code = 403
|
||||
|
||||
@@ -21,72 +21,72 @@ class InstalledAppsListApi(Resource):
|
||||
@marshal_with(installed_app_list_fields)
|
||||
def get(self):
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
|
||||
installed_apps = db.session.query(InstalledApp).filter(
|
||||
InstalledApp.tenant_id == current_tenant_id
|
||||
).all()
|
||||
|
||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
||||
installed_apps = [
|
||||
{
|
||||
"id": installed_app.id,
|
||||
"app": installed_app.app,
|
||||
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
|
||||
"is_pinned": installed_app.is_pinned,
|
||||
"last_used_at": installed_app.last_used_at,
|
||||
"editable": current_user.role in ["owner", "admin"],
|
||||
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
|
||||
'id': installed_app.id,
|
||||
'app': installed_app.app,
|
||||
'app_owner_tenant_id': installed_app.app_owner_tenant_id,
|
||||
'is_pinned': installed_app.is_pinned,
|
||||
'last_used_at': installed_app.last_used_at,
|
||||
'editable': current_user.role in ["owner", "admin"],
|
||||
'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id
|
||||
}
|
||||
for installed_app in installed_apps
|
||||
if installed_app.app is not None
|
||||
]
|
||||
installed_apps.sort(
|
||||
key=lambda app: (
|
||||
-app["is_pinned"],
|
||||
app["last_used_at"] is None,
|
||||
-app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0,
|
||||
)
|
||||
)
|
||||
installed_apps.sort(key=lambda app: (-app['is_pinned'],
|
||||
app['last_used_at'] is None,
|
||||
-app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
|
||||
|
||||
return {"installed_apps": installed_apps}
|
||||
return {'installed_apps': installed_apps}
|
||||
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@cloud_edition_billing_resource_check('apps')
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
|
||||
parser.add_argument('app_id', type=str, required=True, help='Invalid app_id')
|
||||
args = parser.parse_args()
|
||||
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
|
||||
if recommended_app is None:
|
||||
raise NotFound("App not found")
|
||||
raise NotFound('App not found')
|
||||
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
app = db.session.query(App).filter(App.id == args["app_id"]).first()
|
||||
app = db.session.query(App).filter(
|
||||
App.id == args['app_id']
|
||||
).first()
|
||||
|
||||
if app is None:
|
||||
raise NotFound("App not found")
|
||||
raise NotFound('App not found')
|
||||
|
||||
if not app.is_public:
|
||||
raise Forbidden("You can't install a non-public app")
|
||||
raise Forbidden('You can\'t install a non-public app')
|
||||
|
||||
installed_app = InstalledApp.query.filter(
|
||||
and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)
|
||||
).first()
|
||||
installed_app = InstalledApp.query.filter(and_(
|
||||
InstalledApp.app_id == args['app_id'],
|
||||
InstalledApp.tenant_id == current_tenant_id
|
||||
)).first()
|
||||
|
||||
if installed_app is None:
|
||||
# todo: position
|
||||
recommended_app.install_count += 1
|
||||
|
||||
new_installed_app = InstalledApp(
|
||||
app_id=args["app_id"],
|
||||
app_id=args['app_id'],
|
||||
tenant_id=current_tenant_id,
|
||||
app_owner_tenant_id=app.tenant_id,
|
||||
is_pinned=False,
|
||||
last_used_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
last_used_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
)
|
||||
db.session.add(new_installed_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"message": "App installed successfully"}
|
||||
return {'message': 'App installed successfully'}
|
||||
|
||||
|
||||
class InstalledAppApi(InstalledAppResource):
|
||||
@@ -94,31 +94,30 @@ class InstalledAppApi(InstalledAppResource):
|
||||
update and delete an installed app
|
||||
use InstalledAppResource to apply default decorators and get installed_app
|
||||
"""
|
||||
|
||||
def delete(self, installed_app):
|
||||
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
raise BadRequest('You can\'t uninstall an app owned by the current tenant')
|
||||
|
||||
db.session.delete(installed_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success", "message": "App uninstalled successfully"}
|
||||
return {'result': 'success', 'message': 'App uninstalled successfully'}
|
||||
|
||||
def patch(self, installed_app):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("is_pinned", type=inputs.boolean)
|
||||
parser.add_argument('is_pinned', type=inputs.boolean)
|
||||
args = parser.parse_args()
|
||||
|
||||
commit_args = False
|
||||
if "is_pinned" in args:
|
||||
installed_app.is_pinned = args["is_pinned"]
|
||||
if 'is_pinned' in args:
|
||||
installed_app.is_pinned = args['is_pinned']
|
||||
commit_args = True
|
||||
|
||||
if commit_args:
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success", "message": "App info updated successfully"}
|
||||
return {'result': 'success', 'message': 'App info updated successfully'}
|
||||
|
||||
|
||||
api.add_resource(InstalledAppsListApi, "/installed-apps")
|
||||
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")
|
||||
api.add_resource(InstalledAppsListApi, '/installed-apps')
|
||||
api.add_resource(InstalledAppApi, '/installed-apps/<uuid:installed_app_id>')
|
||||
|
||||
@@ -44,21 +44,19 @@ class MessageListApi(InstalledAppResource):
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
)
|
||||
return MessageService.pagination_by_first_id(app_model, current_user,
|
||||
args['conversation_id'], args['first_id'], args['limit'])
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.message.FirstMessageNotExistsError:
|
||||
raise NotFound("First Message Not Exists.")
|
||||
|
||||
|
||||
class MessageFeedbackApi(InstalledAppResource):
|
||||
def post(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
@@ -66,32 +64,30 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(app_model, message_id, current_user, args["rating"])
|
||||
MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
|
||||
except services.errors.message.MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
def get(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
|
||||
)
|
||||
parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_more_like_this(
|
||||
@@ -99,7 +95,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
user=current_user,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=streaming,
|
||||
streaming=streaming
|
||||
)
|
||||
return helper.compact_generate_response(response)
|
||||
except MessageNotExistsError:
|
||||
@@ -132,7 +128,10 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.EXPLORE
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
@@ -152,22 +151,10 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
return {'data': questions}
|
||||
|
||||
|
||||
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
|
||||
api.add_resource(
|
||||
MessageFeedbackApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
|
||||
endpoint="installed_app_message_feedback",
|
||||
)
|
||||
api.add_resource(
|
||||
MessageMoreLikeThisApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
|
||||
endpoint="installed_app_more_like_this",
|
||||
)
|
||||
api.add_resource(
|
||||
MessageSuggestedQuestionApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="installed_app_suggested_question",
|
||||
)
|
||||
api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
|
||||
api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
|
||||
api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this')
|
||||
api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
from flask_restful import fields, marshal_with
|
||||
|
||||
from configs import dify_config
|
||||
@@ -10,32 +11,33 @@ from services.app_service import AppService
|
||||
|
||||
class AppParameterApi(InstalledAppResource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
variable_fields = {
|
||||
"key": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"type": fields.String,
|
||||
"default": fields.String,
|
||||
"max_length": fields.Integer,
|
||||
"options": fields.List(fields.String),
|
||||
'key': fields.String,
|
||||
'name': fields.String,
|
||||
'description': fields.String,
|
||||
'type': fields.String,
|
||||
'default': fields.String,
|
||||
'max_length': fields.Integer,
|
||||
'options': fields.List(fields.String)
|
||||
}
|
||||
|
||||
system_parameters_fields = {"image_file_size_limit": fields.String}
|
||||
system_parameters_fields = {
|
||||
'image_file_size_limit': fields.String
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(system_parameters_fields),
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'text_to_speech': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'annotation_reply': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'sensitive_word_avoidance': fields.Raw,
|
||||
'file_upload': fields.Raw,
|
||||
'system_parameters': fields.Nested(system_parameters_fields)
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
@@ -54,35 +56,30 @@ class AppParameterApi(InstalledAppResource):
|
||||
app_model_config = app_model.app_model_config
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
user_input_form = features_dict.get('user_input_form', [])
|
||||
|
||||
return {
|
||||
"opening_statement": features_dict.get("opening_statement"),
|
||||
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||
"suggested_questions_after_answer": features_dict.get(
|
||||
"suggested_questions_after_answer", {"enabled": False}
|
||||
),
|
||||
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||
"user_input_form": user_input_form,
|
||||
"sensitive_word_avoidance": features_dict.get(
|
||||
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||
),
|
||||
"file_upload": features_dict.get(
|
||||
"file_upload",
|
||||
{
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": 3,
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
|
||||
'opening_statement': features_dict.get('opening_statement'),
|
||||
'suggested_questions': features_dict.get('suggested_questions', []),
|
||||
'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
|
||||
{"enabled": False}),
|
||||
'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
|
||||
'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
|
||||
'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
|
||||
'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
|
||||
'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
|
||||
'user_input_form': user_input_form,
|
||||
'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
|
||||
{"enabled": False, "type": "", "configs": []}),
|
||||
'file_upload': features_dict.get('file_upload', {"image": {
|
||||
"enabled": False,
|
||||
"number_limits": 3,
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"]
|
||||
}}),
|
||||
'system_parameters': {
|
||||
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -93,7 +90,6 @@ class ExploreAppMetaApi(InstalledAppResource):
|
||||
return AppService().get_app_meta(app_model)
|
||||
|
||||
|
||||
api.add_resource(
|
||||
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
|
||||
)
|
||||
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
|
||||
api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters',
|
||||
endpoint='installed_app_parameters')
|
||||
api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta')
|
||||
|
||||
@@ -8,28 +8,28 @@ from libs.login import login_required
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
app_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"mode": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String
|
||||
}
|
||||
|
||||
recommended_app_fields = {
|
||||
"app": fields.Nested(app_fields, attribute="app"),
|
||||
"app_id": fields.String,
|
||||
"description": fields.String(attribute="description"),
|
||||
"copyright": fields.String,
|
||||
"privacy_policy": fields.String,
|
||||
"custom_disclaimer": fields.String,
|
||||
"category": fields.String,
|
||||
"position": fields.Integer,
|
||||
"is_listed": fields.Boolean,
|
||||
'app': fields.Nested(app_fields, attribute='app'),
|
||||
'app_id': fields.String,
|
||||
'description': fields.String(attribute='description'),
|
||||
'copyright': fields.String,
|
||||
'privacy_policy': fields.String,
|
||||
'custom_disclaimer': fields.String,
|
||||
'category': fields.String,
|
||||
'position': fields.Integer,
|
||||
'is_listed': fields.Boolean
|
||||
}
|
||||
|
||||
recommended_app_list_fields = {
|
||||
"recommended_apps": fields.List(fields.Nested(recommended_app_fields)),
|
||||
"categories": fields.List(fields.String),
|
||||
'recommended_apps': fields.List(fields.Nested(recommended_app_fields)),
|
||||
'categories': fields.List(fields.String)
|
||||
}
|
||||
|
||||
|
||||
@@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource):
|
||||
def get(self):
|
||||
# language args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("language", type=str, location="args")
|
||||
parser.add_argument('language', type=str, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.get("language") and args.get("language") in languages:
|
||||
language_prefix = args.get("language")
|
||||
if args.get('language') and args.get('language') in languages:
|
||||
language_prefix = args.get('language')
|
||||
elif current_user and current_user.interface_language:
|
||||
language_prefix = current_user.interface_language
|
||||
else:
|
||||
@@ -61,5 +61,5 @@ class RecommendedAppApi(Resource):
|
||||
return RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
|
||||
|
||||
api.add_resource(RecommendedAppListApi, "/explore/apps")
|
||||
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")
|
||||
api.add_resource(RecommendedAppListApi, '/explore/apps')
|
||||
api.add_resource(RecommendedAppApi, '/explore/apps/<uuid:app_id>')
|
||||
|
||||
@@ -11,54 +11,56 @@ from libs.helper import TimestampField, uuid_value
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
feedback_fields = {"rating": fields.String}
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String,
|
||||
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
'id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
||||
class SavedMessageListApi(InstalledAppResource):
|
||||
saved_message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(message_fields))
|
||||
}
|
||||
|
||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||
def get(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
||||
return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit'])
|
||||
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=uuid_value, required=True, location="json")
|
||||
parser.add_argument('message_id', type=uuid_value, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
SavedMessageService.save(app_model, current_user, args["message_id"])
|
||||
SavedMessageService.save(app_model, current_user, args['message_id'])
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class SavedMessageApi(InstalledAppResource):
|
||||
@@ -67,21 +69,13 @@ class SavedMessageApi(InstalledAppResource):
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
if app_model.mode != "completion":
|
||||
if app_model.mode != 'completion':
|
||||
raise NotCompletionAppError()
|
||||
|
||||
SavedMessageService.delete(app_model, current_user, message_id)
|
||||
|
||||
return {"result": "success"}
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
SavedMessageListApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/saved-messages",
|
||||
endpoint="installed_app_saved_messages",
|
||||
)
|
||||
api.add_resource(
|
||||
SavedMessageApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
|
||||
endpoint="installed_app_saved_message",
|
||||
)
|
||||
api.add_resource(SavedMessageListApi, '/installed-apps/<uuid:installed_app_id>/saved-messages', endpoint='installed_app_saved_messages')
|
||||
api.add_resource(SavedMessageApi, '/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>', endpoint='installed_app_saved_message')
|
||||
|
||||
@@ -35,13 +35,17 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
@@ -72,10 +76,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
return {"result": "success"}
|
||||
return {
|
||||
"result": "success"
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
api.add_resource(
|
||||
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
|
||||
)
|
||||
api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps/<uuid:installed_app_id>/workflows/run')
|
||||
api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop')
|
||||
|
||||
@@ -14,33 +14,29 @@ def installed_app_required(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not kwargs.get("installed_app_id"):
|
||||
raise ValueError("missing installed_app_id in path parameters")
|
||||
if not kwargs.get('installed_app_id'):
|
||||
raise ValueError('missing installed_app_id in path parameters')
|
||||
|
||||
installed_app_id = kwargs.get("installed_app_id")
|
||||
installed_app_id = kwargs.get('installed_app_id')
|
||||
installed_app_id = str(installed_app_id)
|
||||
|
||||
del kwargs["installed_app_id"]
|
||||
del kwargs['installed_app_id']
|
||||
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
.filter(
|
||||
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
installed_app = db.session.query(InstalledApp).filter(
|
||||
InstalledApp.id == str(installed_app_id),
|
||||
InstalledApp.tenant_id == current_user.current_tenant_id
|
||||
).first()
|
||||
|
||||
if installed_app is None:
|
||||
raise NotFound("Installed app not found")
|
||||
raise NotFound('Installed app not found')
|
||||
|
||||
if not installed_app.app:
|
||||
db.session.delete(installed_app)
|
||||
db.session.commit()
|
||||
|
||||
raise NotFound("Installed app not found")
|
||||
raise NotFound('Installed app not found')
|
||||
|
||||
return view(installed_app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
|
||||
@@ -13,18 +13,23 @@ from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
||||
|
||||
class CodeBasedExtensionAPI(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("module", type=str, required=True, location="args")
|
||||
parser.add_argument('module', type=str, required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
|
||||
return {
|
||||
'module': args['module'],
|
||||
'data': CodeBasedExtensionService.get_code_based_extension(args['module'])
|
||||
}
|
||||
|
||||
|
||||
class APIBasedExtensionAPI(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -39,22 +44,23 @@ class APIBasedExtensionAPI(Resource):
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||
parser.add_argument("api_key", type=str, required=True, location="json")
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('api_endpoint', type=str, required=True, location='json')
|
||||
parser.add_argument('api_key', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
name=args["name"],
|
||||
api_endpoint=args["api_endpoint"],
|
||||
api_key=args["api_key"],
|
||||
name=args['name'],
|
||||
api_endpoint=args['api_endpoint'],
|
||||
api_key=args['api_key']
|
||||
)
|
||||
|
||||
return APIBasedExtensionService.save(extension_data)
|
||||
|
||||
|
||||
class APIBasedExtensionDetailAPI(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -76,16 +82,16 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||
parser.add_argument("api_key", type=str, required=True, location="json")
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('api_endpoint', type=str, required=True, location='json')
|
||||
parser.add_argument('api_key', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
extension_data_from_db.name = args["name"]
|
||||
extension_data_from_db.api_endpoint = args["api_endpoint"]
|
||||
extension_data_from_db.name = args['name']
|
||||
extension_data_from_db.api_endpoint = args['api_endpoint']
|
||||
|
||||
if args["api_key"] != HIDDEN_VALUE:
|
||||
extension_data_from_db.api_key = args["api_key"]
|
||||
if args['api_key'] != HIDDEN_VALUE:
|
||||
extension_data_from_db.api_key = args['api_key']
|
||||
|
||||
return APIBasedExtensionService.save(extension_data_from_db)
|
||||
|
||||
@@ -100,10 +106,10 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
|
||||
APIBasedExtensionService.delete(extension_data_from_db)
|
||||
|
||||
return {"result": "success"}
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(CodeBasedExtensionAPI, "/code-based-extension")
|
||||
api.add_resource(CodeBasedExtensionAPI, '/code-based-extension')
|
||||
|
||||
api.add_resource(APIBasedExtensionAPI, "/api-based-extension")
|
||||
api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/<uuid:id>")
|
||||
api.add_resource(APIBasedExtensionAPI, '/api-based-extension')
|
||||
api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/<uuid:id>')
|
||||
|
||||
@@ -10,6 +10,7 @@ from .wraps import account_initialization_required, cloud_utm_record
|
||||
|
||||
|
||||
class FeatureApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -23,5 +24,5 @@ class SystemFeatureApi(Resource):
|
||||
return FeatureService.get_system_features().model_dump()
|
||||
|
||||
|
||||
api.add_resource(FeatureApi, "/features")
|
||||
api.add_resource(SystemFeatureApi, "/system-features")
|
||||
api.add_resource(FeatureApi, '/features')
|
||||
api.add_resource(SystemFeatureApi, '/system-features')
|
||||
|
||||
@@ -14,11 +14,12 @@ from .wraps import only_edition_self_hosted
|
||||
|
||||
|
||||
class InitValidateAPI(Resource):
|
||||
|
||||
def get(self):
|
||||
init_status = get_init_validate_status()
|
||||
if init_status:
|
||||
return {"status": "finished"}
|
||||
return {"status": "not_started"}
|
||||
return { 'status': 'finished' }
|
||||
return {'status': 'not_started' }
|
||||
|
||||
@only_edition_self_hosted
|
||||
def post(self):
|
||||
@@ -28,23 +29,22 @@ class InitValidateAPI(Resource):
|
||||
raise AlreadySetupError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("password", type=str_len(30), required=True, location="json")
|
||||
input_password = parser.parse_args()["password"]
|
||||
parser.add_argument('password', type=str_len(30),
|
||||
required=True, location='json')
|
||||
input_password = parser.parse_args()['password']
|
||||
|
||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||
session["is_init_validated"] = False
|
||||
if input_password != os.environ.get('INIT_PASSWORD'):
|
||||
session['is_init_validated'] = False
|
||||
raise InitValidateFailedError()
|
||||
|
||||
session["is_init_validated"] = True
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
session['is_init_validated'] = True
|
||||
return {'result': 'success'}, 201
|
||||
|
||||
def get_init_validate_status():
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
if os.environ.get("INIT_PASSWORD"):
|
||||
return session.get("is_init_validated") or DifySetup.query.first()
|
||||
|
||||
if dify_config.EDITION == 'SELF_HOSTED':
|
||||
if os.environ.get('INIT_PASSWORD'):
|
||||
return session.get('is_init_validated') or DifySetup.query.first()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
api.add_resource(InitValidateAPI, "/init")
|
||||
api.add_resource(InitValidateAPI, '/init')
|
||||
|
||||
@@ -4,11 +4,14 @@ from controllers.console import api
|
||||
|
||||
|
||||
class PingApi(Resource):
|
||||
|
||||
def get(self):
|
||||
"""
|
||||
For connection health check
|
||||
"""
|
||||
return {"result": "pong"}
|
||||
return {
|
||||
"result": "pong"
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(PingApi, "/ping")
|
||||
api.add_resource(PingApi, '/ping')
|
||||
|
||||
@@ -16,13 +16,17 @@ from .wraps import only_edition_self_hosted
|
||||
|
||||
|
||||
class SetupApi(Resource):
|
||||
|
||||
def get(self):
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
if dify_config.EDITION == 'SELF_HOSTED':
|
||||
setup_status = get_setup_status()
|
||||
if setup_status:
|
||||
return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
|
||||
return {"step": "not_started"}
|
||||
return {"step": "finished"}
|
||||
return {
|
||||
'step': 'finished',
|
||||
'setup_at': setup_status.setup_at.isoformat()
|
||||
}
|
||||
return {'step': 'not_started'}
|
||||
return {'step': 'finished'}
|
||||
|
||||
@only_edition_self_hosted
|
||||
def post(self):
|
||||
@@ -34,22 +38,28 @@ class SetupApi(Resource):
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
|
||||
if not get_init_validate_status():
|
||||
raise NotInitValidateError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("name", type=str_len(30), required=True, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||
parser.add_argument('email', type=email,
|
||||
required=True, location='json')
|
||||
parser.add_argument('name', type=str_len(
|
||||
30), required=True, location='json')
|
||||
parser.add_argument('password', type=valid_password,
|
||||
required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
# setup
|
||||
RegisterService.setup(
|
||||
email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
|
||||
email=args['email'],
|
||||
name=args['name'],
|
||||
password=args['password'],
|
||||
ip_address=get_remote_ip(request)
|
||||
)
|
||||
|
||||
return {"result": "success"}, 201
|
||||
return {'result': 'success'}, 201
|
||||
|
||||
|
||||
def setup_required(view):
|
||||
@@ -58,7 +68,7 @@ def setup_required(view):
|
||||
# check setup
|
||||
if not get_init_validate_status():
|
||||
raise NotInitValidateError()
|
||||
|
||||
|
||||
elif not get_setup_status():
|
||||
raise NotSetupError()
|
||||
|
||||
@@ -68,10 +78,9 @@ def setup_required(view):
|
||||
|
||||
|
||||
def get_setup_status():
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
if dify_config.EDITION == 'SELF_HOSTED':
|
||||
return DifySetup.query.first()
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
api.add_resource(SetupApi, "/setup")
|
||||
api.add_resource(SetupApi, '/setup')
|
||||
|
||||
@@ -14,18 +14,19 @@ from services.tag_service import TagService
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 50 characters.")
|
||||
raise ValueError('Name must be between 1 to 50 characters.')
|
||||
return name
|
||||
|
||||
|
||||
class TagListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(tag_fields)
|
||||
def get(self):
|
||||
tag_type = request.args.get("type", type=str)
|
||||
keyword = request.args.get("keyword", default=None, type=str)
|
||||
tag_type = request.args.get('type', type=str)
|
||||
keyword = request.args.get('keyword', default=None, type=str)
|
||||
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
||||
|
||||
return tags, 200
|
||||
@@ -39,21 +40,28 @@ class TagListApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
||||
)
|
||||
parser.add_argument(
|
||||
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||
)
|
||||
parser.add_argument('name', nullable=False, required=True,
|
||||
help='Name must be between 1 to 50 characters.',
|
||||
type=_validate_name)
|
||||
parser.add_argument('type', type=str, location='json',
|
||||
choices=Tag.TAG_TYPE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid tag type.')
|
||||
args = parser.parse_args()
|
||||
tag = TagService.save_tags(args)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
response = {
|
||||
'id': tag.id,
|
||||
'name': tag.name,
|
||||
'type': tag.type,
|
||||
'binding_count': 0
|
||||
}
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -64,15 +72,20 @@ class TagUpdateDeleteApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
||||
)
|
||||
parser.add_argument('name', nullable=False, required=True,
|
||||
help='Name must be between 1 to 50 characters.',
|
||||
type=_validate_name)
|
||||
args = parser.parse_args()
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
response = {
|
||||
'id': tag.id,
|
||||
'name': tag.name,
|
||||
'type': tag.type,
|
||||
'binding_count': binding_count
|
||||
}
|
||||
|
||||
return response, 200
|
||||
|
||||
@@ -91,6 +104,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
|
||||
|
||||
class TagBindingCreateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -100,15 +114,14 @@ class TagBindingCreateApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
|
||||
)
|
||||
parser.add_argument(
|
||||
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
|
||||
)
|
||||
parser.add_argument(
|
||||
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||
)
|
||||
parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json',
|
||||
help='Tag IDs is required.')
|
||||
parser.add_argument('target_id', type=str, nullable=False, required=True, location='json',
|
||||
help='Target ID is required.')
|
||||
parser.add_argument('type', type=str, location='json',
|
||||
choices=Tag.TAG_TYPE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid tag type.')
|
||||
args = parser.parse_args()
|
||||
TagService.save_tag_binding(args)
|
||||
|
||||
@@ -116,6 +129,7 @@ class TagBindingCreateApi(Resource):
|
||||
|
||||
|
||||
class TagBindingDeleteApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -125,18 +139,21 @@ class TagBindingDeleteApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
parser.add_argument(
|
||||
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||
)
|
||||
parser.add_argument('tag_id', type=str, nullable=False, required=True,
|
||||
help='Tag ID is required.')
|
||||
parser.add_argument('target_id', type=str, nullable=False, required=True,
|
||||
help='Target ID is required.')
|
||||
parser.add_argument('type', type=str, location='json',
|
||||
choices=Tag.TAG_TYPE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid tag type.')
|
||||
args = parser.parse_args()
|
||||
TagService.delete_tag_binding(args)
|
||||
|
||||
return 200
|
||||
|
||||
|
||||
api.add_resource(TagListApi, "/tags")
|
||||
api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>")
|
||||
api.add_resource(TagBindingCreateApi, "/tag-bindings/create")
|
||||
api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove")
|
||||
api.add_resource(TagListApi, '/tags')
|
||||
api.add_resource(TagUpdateDeleteApi, '/tags/<uuid:tag_id>')
|
||||
api.add_resource(TagBindingCreateApi, '/tag-bindings/create')
|
||||
api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove')
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -10,39 +11,42 @@ from . import api
|
||||
|
||||
|
||||
class VersionApi(Resource):
|
||||
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("current_version", type=str, required=True, location="args")
|
||||
parser.add_argument('current_version', type=str, required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||
|
||||
result = {
|
||||
"version": dify_config.CURRENT_VERSION,
|
||||
"release_date": "",
|
||||
"release_notes": "",
|
||||
"can_auto_update": False,
|
||||
"features": {
|
||||
"can_replace_logo": dify_config.CAN_REPLACE_LOGO,
|
||||
"model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED,
|
||||
},
|
||||
'version': dify_config.CURRENT_VERSION,
|
||||
'release_date': '',
|
||||
'release_notes': '',
|
||||
'can_auto_update': False,
|
||||
'features': {
|
||||
'can_replace_logo': dify_config.CAN_REPLACE_LOGO,
|
||||
'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED
|
||||
}
|
||||
}
|
||||
|
||||
if not check_update_url:
|
||||
return result
|
||||
|
||||
try:
|
||||
response = requests.get(check_update_url, {"current_version": args.get("current_version")})
|
||||
response = requests.get(check_update_url, {
|
||||
'current_version': args.get('current_version')
|
||||
})
|
||||
except Exception as error:
|
||||
logging.warning("Check update version error: {}.".format(str(error)))
|
||||
result["version"] = args.get("current_version")
|
||||
result['version'] = args.get('current_version')
|
||||
return result
|
||||
|
||||
content = json.loads(response.content)
|
||||
result["version"] = content["version"]
|
||||
result["release_date"] = content["releaseDate"]
|
||||
result["release_notes"] = content["releaseNotes"]
|
||||
result["can_auto_update"] = content["canAutoUpdate"]
|
||||
result['version'] = content['version']
|
||||
result['release_date'] = content['releaseDate']
|
||||
result['release_notes'] = content['releaseNotes']
|
||||
result['can_auto_update'] = content['canAutoUpdate']
|
||||
return result
|
||||
|
||||
|
||||
api.add_resource(VersionApi, "/version")
|
||||
api.add_resource(VersionApi, '/version')
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user