Mirror of https://github.com/langgenius/dify.git (synced 2026-02-17 22:34:02 +00:00)

Compare commits: fix/remove... → fix/index-... (88 commits)
Commits (SHA1):
ea5e8ee7cc, 63e34e5227, c606295ea6, 27d72e30ad, 5660878f7b, 12e55b2cac, 97e094dfd8, 9622fbb62f,
cc8dc6d35e, 215661ef91, 5a3e09518c, ebba124c5c, a62325ac87, 1d2ab2126c, b07dea836c, f9d00e0498,
757ceda063, d27e3ab99d, ce930f19b9, 3b14939d66, 279caf033c, eff280f3e7, 7c70eb87bc, 6ef401a9f0,
b29a36f461, 17f22347ae, 22aaf8960b, 0046ef7707, 68b1d063f7, 5e6c3001bd, 7ed4e963aa, 001d868cbd,
6610b4cee5, cbbe28f40d, 603187393a, 411e938e3b, 610da4f662, 3ec80f9dda, 91c5818236, c436454cd4,
a877d4831d, d522308a29, 85744b72e5, f0b7051e1a, 3b23d6764f, 9b7c74a5d9, 4d105d7bd7, eee779a923,
ab847c81fa, b217ee414f, 23dc6edb99, 79df8825c8, 71c50b7e20, af98fd29bf, cddea83e65, 9f16739518,
3f0da88ff7, cc63af8e72, bf2268b0af, 00b4cc3cd4, f546db5437, f8aaa57f31, cabcf94be3, 2d6624cf9e,
02982df0d4, 421a24c38d, d7f75d17cc, 5d9ad430af, 46eca01fa3, 4d9c22bfc6, 52e59cf4df, 688b8fe114,
aecdfa2d5c, cb8feb732f, c490bdfbf9, e7494d632c, e3006f98c9, b34baf1e3a, 372dc7ac1a, 66a62e6c13,
04c0a9ad45, 0944ca9d91, d468f8b75c, 6e256507d3, a33ce09e6e, 9f973bb703, 6d0a605c5f, a948bf6ee8
@@ -1,6 +1,7 @@
 #!/bin/bash

 cd web && npm install
+pipx install poetry

 echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
 echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
.github/pull_request_template.md (vendored, 30 changes)
@@ -1,13 +1,21 @@
+# Checklist:
+
+> [!IMPORTANT]
+> Please review the checklist below before submitting your pull request.
+
+- [ ] Please open an issue before creating a PR or link to an existing issue
+- [ ] I have performed a self-review of my own code
+- [ ] I have commented my code, particularly in hard-to-understand areas
+- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
+
 # Description

-Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
+Describe the big picture of your changes here to communicate to the maintainers why we should accept this pull request. If it fixes a bug or resolves a feature request, be sure to link to that issue. Close issue syntax: `Fixes #<issue number>`, see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) for more details.

-Fixes # (issue)
+Fixes

 ## Type of Change

 Please delete options that are not relevant.

 - [ ] Bug fix (non-breaking change which fixes an issue)
 - [ ] New feature (non-breaking change which adds functionality)
 - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
@@ -15,18 +23,12 @@ Please delete options that are not relevant.
 - [ ] Improvement, including but not limited to code refactoring, performance optimization, and UI/UX improvement
 - [ ] Dependency upgrade

-# How Has This Been Tested?
+# Testing Instructions

 Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration

-- [ ] TODO
+- [ ] Test A
+- [ ] Test B
-
-# Suggested Checklist:
-
-- [ ] I have performed a self-review of my own code
-- [ ] I have commented my code, particularly in hard-to-understand areas
-- [ ] My changes generate no new warnings
-- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
-- [ ] `optional` I have made corresponding changes to the documentation
-- [ ] `optional` I have added tests that prove my fix is effective or that my feature works
-- [ ] `optional` New and existing unit tests pass locally with my changes
.github/workflows/api-tests.yml (vendored, 3 changes)
@@ -75,7 +75,7 @@ jobs:
       - name: Run Workflow
         run: poetry run -C api bash dev/pytest/pytest_workflow.sh

-      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
+      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale)
        uses: hoverkraft-tech/compose-action@v2.0.0
        with:
          compose-file: |
@@ -89,5 +89,6 @@ jobs:
           pgvecto-rs
           pgvector
           chroma
+          myscale
      - name: Test Vector Stores
        run: poetry run -C api bash dev/pytest/pytest_vdb.sh
.github/workflows/build-push.yml (vendored, 12 changes)
@@ -48,18 +48,18 @@ jobs:
          platform=${{ matrix.platform }}
          echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV

-      - name: Set up QEMU
-        uses: docker/setup-qemu-action@v3
-
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
-
       - name: Login to Docker Hub
         uses: docker/login-action@v2
         with:
           username: ${{ env.DOCKERHUB_USER }}
           password: ${{ env.DOCKERHUB_TOKEN }}

+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v3
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
       - name: Extract metadata for Docker
         id: meta
         uses: docker/metadata-action@v5
.gitignore (vendored, 2 changes)
@@ -174,3 +174,5 @@ sdks/python-client/dify_client.egg-info
 .vscode/*
 !.vscode/launch.json
 pyrightconfig.json
+
+.idea/
.vscode/launch.json (vendored, 1 change)
@@ -13,7 +13,6 @@
             "jinja": true,
             "env": {
                 "FLASK_APP": "app.py",
-                "FLASK_DEBUG": "1",
                 "GEVENT_SUPPORT": "True"
             },
             "args": [
@@ -83,7 +83,7 @@ OCI_REGION=your-region
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

-# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
+# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
 VECTOR_STORE=weaviate

 # Weaviate configuration
@@ -106,6 +106,14 @@ MILVUS_USER=root
 MILVUS_PASSWORD=Milvus
 MILVUS_SECURE=false

+# MyScale configuration
+MYSCALE_HOST=127.0.0.1
+MYSCALE_PORT=8123
+MYSCALE_USER=default
+MYSCALE_PASSWORD=
+MYSCALE_DATABASE=default
+MYSCALE_FTS_PARAMS=
+
 # Relyt configuration
 RELYT_HOST=127.0.0.1
 RELYT_PORT=5432
@@ -151,6 +159,16 @@ CHROMA_DATABASE=default_database
 CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
 CHROMA_AUTH_CREDENTIALS=difyai123456

+# AnalyticDB configuration
+ANALYTICDB_KEY_ID=your-ak
+ANALYTICDB_KEY_SECRET=your-sk
+ANALYTICDB_REGION_ID=cn-hangzhou
+ANALYTICDB_INSTANCE_ID=gp-ab123456
+ANALYTICDB_ACCOUNT=testaccount
+ANALYTICDB_PASSWORD=testpassword
+ANALYTICDB_NAMESPACE=dify
+ANALYTICDB_NAMESPACE_PASSWORD=difypassword
+
 # OpenSearch configuration
 OPENSEARCH_HOST=127.0.0.1
 OPENSEARCH_PORT=9200
@@ -237,4 +255,4 @@ WORKFLOW_CALL_MAX_DEPTH=5

 # App configuration
 APP_MAX_EXECUTION_TIME=1200
+APP_MAX_ACTIVE_REQUESTS=0
@@ -5,8 +5,7 @@ WORKDIR /app/api

 # Install Poetry
 ENV POETRY_VERSION=1.8.3
-RUN pip install --no-cache-dir --upgrade pip && \
-    pip install --no-cache-dir --upgrade poetry==${POETRY_VERSION}
+RUN pip install --no-cache-dir poetry==${POETRY_VERSION}

 # Configure Poetry
 ENV POETRY_CACHE_DIR=/tmp/poetry_cache
@@ -11,6 +11,7 @@
 ```bash
 cd ../docker
+cp middleware.env.example middleware.env
 docker compose -f docker-compose.middleware.yaml -p dify up -d
 cd ../api
 ```
@@ -1,8 +1,8 @@
 import os

-from configs.app_config import DifyConfig
+from configs import dify_config

-if not os.environ.get("DEBUG") or os.environ.get("DEBUG", "false").lower() != 'true':
+if os.environ.get("DEBUG", "false").lower() != 'true':
     from gevent import monkey

     monkey.patch_all()
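The simplified condition is behavior-preserving: with a default of "false", `os.environ.get("DEBUG", "false")` is never falsy, so the extra `not os.environ.get("DEBUG")` clause was redundant. A quick standalone check (not part of the diff itself):

```python
# Equivalence check for the simplified DEBUG condition above.
def old_condition(env: dict) -> bool:
    return not env.get("DEBUG") or env.get("DEBUG", "false").lower() != 'true'

def new_condition(env: dict) -> bool:
    return env.get("DEBUG", "false").lower() != 'true'

# Both predicates agree for unset, truthy, and falsy DEBUG values.
for env in ({}, {"DEBUG": "true"}, {"DEBUG": "false"}, {"DEBUG": "1"}, {"DEBUG": ""}):
    assert old_condition(env) == new_condition(env)
```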
@@ -43,6 +43,8 @@ from extensions import (
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
 from libs.passport import PassportService

+# TODO: Find a way to avoid importing models here
+from models import account, dataset, model, source, task, tool, tools, web
 from services.account_service import AccountService
@@ -81,7 +83,7 @@ def create_flask_app_with_configs() -> Flask:
     with configs loaded from .env file
     """
     dify_app = DifyApp(__name__)
-    dify_app.config.from_mapping(DifyConfig().model_dump())
+    dify_app.config.from_mapping(dify_config.model_dump())

     # populate configs into system environment variables
     for key, value in dify_app.config.items():
@@ -8,6 +8,7 @@ import click
 from flask import current_app
 from werkzeug.exceptions import NotFound

+from configs import dify_config
 from constants.languages import languages
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.datasource.vdb.vector_type import VectorType
@@ -112,7 +113,7 @@ def reset_encrypt_key_pair():
     After the reset, all LLM credentials will become invalid, requiring re-entry.
     Only support SELF_HOSTED mode.
     """
-    if current_app.config['EDITION'] != 'SELF_HOSTED':
+    if dify_config.EDITION != 'SELF_HOSTED':
         click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
         return
@@ -336,6 +337,14 @@ def migrate_knowledge_vector_database():
                     "vector_store": {"class_prefix": collection_name}
                 }
                 dataset.index_struct = json.dumps(index_struct_dict)
+            elif vector_type == VectorType.ANALYTICDB:
+                dataset_id = dataset.id
+                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                index_struct_dict = {
+                    "type": VectorType.ANALYTICDB,
+                    "vector_store": {"class_prefix": collection_name}
+                }
+                dataset.index_struct = json.dumps(index_struct_dict)
             else:
                 raise ValueError(f"Vector store {vector_type} is not supported.")
@@ -0,0 +1,3 @@
+from .app_config import DifyConfig
+
+dify_config = DifyConfig()
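The point of this new module is a single shared, already-validated settings object. A hypothetical call site (the `EDITION` field is referenced later in this comparison; the snippet itself is illustrative):

```python
# Illustrative usage of the shared singleton introduced above.
from configs import dify_config

# Every importer sees the same instance, so the .env file is parsed
# and validated exactly once, at module import time.
if dify_config.EDITION != 'SELF_HOSTED':
    print('running a hosted edition')
```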
@@ -1,4 +1,5 @@
-from pydantic_settings import BaseSettings, SettingsConfigDict
+from pydantic import Field, computed_field
+from pydantic_settings import SettingsConfigDict

 from configs.deploy import DeploymentConfig
 from configs.enterprise import EnterpriseFeatureConfig
@@ -8,13 +9,7 @@ from configs.middleware import MiddlewareConfig
 from configs.packaging import PackagingInfo


-# TODO: Both `BaseModel` and `BaseSettings` has `model_config` attribute but they are in different types.
-# This inheritance is depends on the order of the classes.
-# It is better to use `BaseSettings` as the base class.
 class DifyConfig(
-    # based on pydantic-settings
-    BaseSettings,
-
     # Packaging info
     PackagingInfo,
@@ -34,12 +29,39 @@ class DifyConfig(
     # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
     EnterpriseFeatureConfig,
 ):
+    DEBUG: bool = Field(default=False, description='whether to enable debug mode.')
+
     model_config = SettingsConfigDict(
         # read from dotenv format config file
         env_file='.env',
         env_file_encoding='utf-8',
+        frozen=True,

         # ignore extra attributes
         extra='ignore',
     )

     CODE_MAX_NUMBER: int = 9223372036854775807
     CODE_MIN_NUMBER: int = -9223372036854775808
     CODE_MAX_STRING_LENGTH: int = 80000
     CODE_MAX_STRING_ARRAY_LENGTH: int = 30
     CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
     CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000

     HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300
     HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600
     HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600
     HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10

+    @computed_field
+    def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
+        return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'

     HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024

+    @computed_field
+    def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str:
+        return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'

     SSRF_PROXY_HTTP_URL: str | None = None
     SSRF_PROXY_HTTPS_URL: str | None = None
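For readers unfamiliar with pydantic v2, `@computed_field` turns a method into a derived, read-only field that also appears in `model_dump()` output. A minimal standalone sketch of the pattern used above (class and field names here are illustrative, not from the diff):

```python
# Standalone sketch of the @computed_field pattern, assuming
# pydantic v2 with pydantic-settings installed.
from pydantic import computed_field
from pydantic_settings import BaseSettings

class SizeConfig(BaseSettings):
    MAX_BINARY_SIZE: int = 1024 * 1024 * 10  # 10 MiB, in bytes

    @computed_field
    def READABLE_MAX_BINARY_SIZE(self) -> str:
        # Derived value; serialized alongside the regular fields.
        return f'{self.MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'

print(SizeConfig().READABLE_MAX_BINARY_SIZE)  # -> '10.00MB'
```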
@@ -1,7 +1,8 @@
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings


-class DeploymentConfig(BaseModel):
+class DeploymentConfig(BaseSettings):
     """
     Deployment configs
     """
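This BaseModel → BaseSettings swap recurs through the rest of the comparison. The practical difference, sketched below under the assumption of pydantic v2 with pydantic-settings: a `BaseSettings` subclass populates its fields from environment variables (and any configured dotenv file), while a plain `BaseModel` only takes constructor arguments. `FOO` here is an illustrative field name, not one from the diff:

```python
import os
from pydantic import BaseModel
from pydantic_settings import BaseSettings

os.environ['FOO'] = 'from-env'

class ModelConfig(BaseModel):
    FOO: str = 'default'

class SettingsConfig(BaseSettings):
    FOO: str = 'default'

print(ModelConfig().FOO)     # 'default'  - a BaseModel ignores the environment
print(SettingsConfig().FOO)  # 'from-env' - a BaseSettings reads os.environ
```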
@@ -1,7 +1,8 @@
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings


-class EnterpriseFeatureConfig(BaseModel):
+class EnterpriseFeatureConfig(BaseSettings):
     """
     Enterprise feature configs.
     **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
@@ -1,5 +1,3 @@
-from pydantic import BaseModel
-
 from configs.extra.notion_config import NotionConfig
 from configs.extra.sentry_config import SentryConfig
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings


-class NotionConfig(BaseModel):
+class NotionConfig(BaseSettings):
     """
     Notion integration configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, NonNegativeFloat
+from pydantic import Field, NonNegativeFloat
+from pydantic_settings import BaseSettings


-class SentryConfig(BaseModel):
+class SentryConfig(BaseSettings):
     """
     Sentry configs
     """
@@ -1,11 +1,12 @@
 from typing import Optional

-from pydantic import AliasChoices, BaseModel, Field, NonNegativeInt, PositiveInt, computed_field
+from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field
+from pydantic_settings import BaseSettings

 from configs.feature.hosted_service import HostedServiceConfig


-class SecurityConfig(BaseModel):
+class SecurityConfig(BaseSettings):
     """
     Secret Key configs
     """
@@ -17,8 +18,12 @@ class SecurityConfig(BaseModel):
         default=None,
     )

+    RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
+        description='Expiry time in hours for reset token',
+        default=24,
+    )

-class AppExecutionConfig(BaseModel):
+class AppExecutionConfig(BaseSettings):
     """
     App Execution configs
     """
@@ -26,9 +31,13 @@ class AppExecutionConfig(BaseModel):
         description='execution timeout in seconds for app execution',
         default=1200,
     )
+    APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
+        description='max active request per app, 0 means unlimited',
+        default=0,
+    )


-class CodeExecutionSandboxConfig(BaseModel):
+class CodeExecutionSandboxConfig(BaseSettings):
     """
     Code Execution Sandbox configs
     """
@@ -43,36 +52,36 @@ class CodeExecutionSandboxConfig(BaseModel):
     )


-class EndpointConfig(BaseModel):
+class EndpointConfig(BaseSettings):
     """
     Module URL configs
     """
     CONSOLE_API_URL: str = Field(
         description='The backend URL prefix of the console API.'
                     'used to concatenate the login authorization callback or notion integration callback.',
-        default='https://cloud.dify.ai',
+        default='',
     )

     CONSOLE_WEB_URL: str = Field(
         description='The front-end URL prefix of the console web.'
                     'used to concatenate some front-end addresses and for CORS configuration use.',
-        default='https://cloud.dify.ai',
+        default='',
     )

     SERVICE_API_URL: str = Field(
         description='Service API Url prefix.'
                     'used to display Service API Base Url to the front-end.',
-        default='https://api.dify.ai',
+        default='',
     )

     APP_WEB_URL: str = Field(
         description='WebApp Url prefix.'
                     'used to display WebAPP API Base Url to the front-end.',
-        default='https://udify.app',
+        default='',
     )


-class FileAccessConfig(BaseModel):
+class FileAccessConfig(BaseSettings):
     """
     File Access configs
     """
@@ -82,7 +91,7 @@ class FileAccessConfig(BaseModel):
                     'Url is signed and has expiration time.',
         validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'),
         alias_priority=1,
-        default='https://cloud.dify.ai',
+        default='',
     )

     FILES_ACCESS_TIMEOUT: int = Field(
@@ -91,7 +100,7 @@ class FileAccessConfig(BaseModel):
     )


-class FileUploadConfig(BaseModel):
+class FileUploadConfig(BaseSettings):
     """
     File Uploading configs
     """
@@ -116,7 +125,7 @@ class FileUploadConfig(BaseModel):
     )


-class HttpConfig(BaseModel):
+class HttpConfig(BaseSettings):
     """
     HTTP configs
     """
@@ -148,7 +157,7 @@ class HttpConfig(BaseModel):
         return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')


-class InnerAPIConfig(BaseModel):
+class InnerAPIConfig(BaseSettings):
     """
     Inner API configs
     """
@@ -163,7 +172,7 @@ class InnerAPIConfig(BaseModel):
     )


-class LoggingConfig(BaseModel):
+class LoggingConfig(BaseSettings):
     """
     Logging configs
     """
@@ -195,7 +204,7 @@ class LoggingConfig(BaseModel):
     )


-class ModelLoadBalanceConfig(BaseModel):
+class ModelLoadBalanceConfig(BaseSettings):
     """
     Model load balance configs
     """
@@ -205,7 +214,7 @@ class ModelLoadBalanceConfig(BaseModel):
     )


-class BillingConfig(BaseModel):
+class BillingConfig(BaseSettings):
     """
     Platform Billing Configurations
     """
@@ -215,7 +224,7 @@ class BillingConfig(BaseModel):
     )


-class UpdateConfig(BaseModel):
+class UpdateConfig(BaseSettings):
     """
     Update configs
     """
@@ -225,7 +234,7 @@ class UpdateConfig(BaseModel):
     )


-class WorkflowConfig(BaseModel):
+class WorkflowConfig(BaseSettings):
     """
     Workflow feature configs
     """
@@ -246,7 +255,7 @@ class WorkflowConfig(BaseModel):
     )


-class OAuthConfig(BaseModel):
+class OAuthConfig(BaseSettings):
     """
     oauth configs
     """
@@ -276,7 +285,7 @@ class OAuthConfig(BaseModel):
     )


-class ModerationConfig(BaseModel):
+class ModerationConfig(BaseSettings):
     """
     Moderation in app configs.
     """
@@ -288,7 +297,7 @@ class ModerationConfig(BaseModel):
     )


-class ToolConfig(BaseModel):
+class ToolConfig(BaseSettings):
     """
     Tool configs
     """
@@ -299,7 +308,7 @@ class ToolConfig(BaseModel):
     )


-class MailConfig(BaseModel):
+class MailConfig(BaseSettings):
     """
     Mail Configurations
     """
@@ -355,7 +364,7 @@ class MailConfig(BaseModel):
     )


-class RagEtlConfig(BaseModel):
+class RagEtlConfig(BaseSettings):
     """
     RAG ETL Configurations.
     """
@@ -381,7 +390,7 @@ class RagEtlConfig(BaseModel):
     )


-class DataSetConfig(BaseModel):
+class DataSetConfig(BaseSettings):
     """
     Dataset configs
     """
@@ -391,8 +400,13 @@ class DataSetConfig(BaseModel):
         default=30,
     )

+    DATASET_OPERATOR_ENABLED: bool = Field(
+        description='whether to enable dataset operator',
+        default=False,
+    )
+
-class WorkspaceConfig(BaseModel):
+class WorkspaceConfig(BaseSettings):
     """
     Workspace configs
     """
@@ -403,7 +417,7 @@ class WorkspaceConfig(BaseModel):
     )


-class IndexingConfig(BaseModel):
+class IndexingConfig(BaseSettings):
     """
     Indexing configs.
     """
@@ -414,7 +428,7 @@ class IndexingConfig(BaseModel):
     )


-class ImageFormatConfig(BaseModel):
+class ImageFormatConfig(BaseSettings):
     MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
         description='multi model send image format, support base64, url, default is base64',
         default='base64',
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, NonNegativeInt
+from pydantic import Field, NonNegativeInt
+from pydantic_settings import BaseSettings


-class HostedOpenAiConfig(BaseModel):
+class HostedOpenAiConfig(BaseSettings):
     """
     Hosted OpenAI service config
     """
@@ -68,7 +69,7 @@ class HostedOpenAiConfig(BaseModel):
     )


-class HostedAzureOpenAiConfig(BaseModel):
+class HostedAzureOpenAiConfig(BaseSettings):
     """
     Hosted OpenAI service config
     """
@@ -94,7 +95,7 @@ class HostedAzureOpenAiConfig(BaseModel):
     )


-class HostedAnthropicConfig(BaseModel):
+class HostedAnthropicConfig(BaseSettings):
     """
     Hosted Azure OpenAI service config
     """
@@ -125,7 +126,7 @@ class HostedAnthropicConfig(BaseModel):
     )


-class HostedMinmaxConfig(BaseModel):
+class HostedMinmaxConfig(BaseSettings):
     """
     Hosted Minmax service config
     """
@@ -136,7 +137,7 @@ class HostedMinmaxConfig(BaseModel):
     )


-class HostedSparkConfig(BaseModel):
+class HostedSparkConfig(BaseSettings):
     """
     Hosted Spark service config
     """
@@ -147,7 +148,7 @@ class HostedSparkConfig(BaseModel):
     )


-class HostedZhipuAIConfig(BaseModel):
+class HostedZhipuAIConfig(BaseSettings):
     """
     Hosted Minmax service config
     """
@@ -158,7 +159,7 @@ class HostedZhipuAIConfig(BaseModel):
     )


-class HostedModerationConfig(BaseModel):
+class HostedModerationConfig(BaseSettings):
     """
     Hosted Moderation service config
     """
@@ -174,7 +175,7 @@ class HostedModerationConfig(BaseModel):
     )


-class HostedFetchAppTemplateConfig(BaseModel):
+class HostedFetchAppTemplateConfig(BaseSettings):
     """
     Hosted Moderation service config
     """
@@ -1,6 +1,7 @@
 from typing import Any, Optional

-from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt, computed_field
+from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
+from pydantic_settings import BaseSettings

 from configs.middleware.cache.redis_config import RedisConfig
 from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
@@ -9,8 +10,10 @@ from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
 from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
 from configs.middleware.storage.oci_storage_config import OCIStorageConfig
 from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
+from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
 from configs.middleware.vdb.chroma_config import ChromaConfig
 from configs.middleware.vdb.milvus_config import MilvusConfig
+from configs.middleware.vdb.myscale_config import MyScaleConfig
 from configs.middleware.vdb.opensearch_config import OpenSearchConfig
 from configs.middleware.vdb.oracle_config import OracleConfig
 from configs.middleware.vdb.pgvector_config import PGVectorConfig
@@ -22,7 +25,7 @@ from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
 from configs.middleware.vdb.weaviate_config import WeaviateConfig


-class StorageConfig(BaseModel):
+class StorageConfig(BaseSettings):
     STORAGE_TYPE: str = Field(
         description='storage type,'
                     ' default to `local`,'
@@ -36,14 +39,14 @@ class StorageConfig(BaseModel):
     )


-class VectorStoreConfig(BaseModel):
+class VectorStoreConfig(BaseSettings):
     VECTOR_STORE: Optional[str] = Field(
         description='vector store type',
         default=None,
     )


-class KeywordStoreConfig(BaseModel):
+class KeywordStoreConfig(BaseSettings):
     KEYWORD_STORE: str = Field(
         description='keyword store type',
         default='jieba',
@@ -81,6 +84,11 @@ class DatabaseConfig:
         default='',
     )

+    DB_EXTRAS: str = Field(
+        description='db extras options. Example: keepalives_idle=60&keepalives=1',
+        default='',
+    )
+
     SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
         description='db uri scheme',
         default='postgresql',
@@ -89,7 +97,12 @@ class DatabaseConfig:
     @computed_field
     @property
     def SQLALCHEMY_DATABASE_URI(self) -> str:
-        db_extras = f"?client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else ""
+        db_extras = (
+            f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}"
+            if self.DB_CHARSET
+            else self.DB_EXTRAS
+        ).strip("&")
+        db_extras = f"?{db_extras}" if db_extras else ""
         return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
                 f"{self.DB_USERNAME}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
                 f"{db_extras}")
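The reworked `db_extras` logic merges the new `DB_EXTRAS` options with the optional `client_encoding` parameter. An illustrative trace of the same string manipulation, with made-up credentials:

```python
# Inlined version of the URI-building logic above; values are examples only.
DB_EXTRAS = "keepalives_idle=60&keepalives=1"
DB_CHARSET = "utf8"

db_extras = (
    f"{DB_EXTRAS}&client_encoding={DB_CHARSET}" if DB_CHARSET else DB_EXTRAS
).strip("&")  # strip handles the leading '&' when DB_EXTRAS is empty
db_extras = f"?{db_extras}" if db_extras else ""

uri = f"postgresql://dify:secret@localhost:5432/dify{db_extras}"
print(uri)
# postgresql://dify:secret@localhost:5432/dify?keepalives_idle=60&keepalives=1&client_encoding=utf8
```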
@@ -114,7 +127,7 @@ class DatabaseConfig:
         default=False,
     )

-    SQLALCHEMY_ECHO: bool = Field(
+    SQLALCHEMY_ECHO: bool | str = Field(
         description='whether to enable SqlAlchemy echo',
         default=False,
     )
@@ -172,8 +185,10 @@ class MiddlewareConfig(

     # configs of vdb and vdb providers
     VectorStoreConfig,
+    AnalyticdbConfig,
     ChromaConfig,
     MilvusConfig,
+    MyScaleConfig,
     OpenSearchConfig,
     OracleConfig,
     PGVectorConfig,
api/configs/middleware/cache/redis_config.py (vendored, 5 changes)
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt
+from pydantic import Field, NonNegativeInt, PositiveInt
+from pydantic_settings import BaseSettings


-class RedisConfig(BaseModel):
+class RedisConfig(BaseSettings):
     """
     Redis configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings


-class AliyunOSSStorageConfig(BaseModel):
+class AliyunOSSStorageConfig(BaseSettings):
     """
     Aliyun storage configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings


-class S3StorageConfig(BaseModel):
+class S3StorageConfig(BaseSettings):
     """
     S3 storage configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings


-class AzureBlobStorageConfig(BaseModel):
+class AzureBlobStorageConfig(BaseSettings):
     """
     Azure Blob storage configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings


-class GoogleCloudStorageConfig(BaseModel):
+class GoogleCloudStorageConfig(BaseSettings):
     """
     Google Cloud storage configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings


-class OCIStorageConfig(BaseModel):
+class OCIStorageConfig(BaseSettings):
     """
     OCI storage configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings


-class TencentCloudCOSStorageConfig(BaseModel):
+class TencentCloudCOSStorageConfig(BaseSettings):
     """
     Tencent Cloud COS storage configs
     """
api/configs/middleware/vdb/analyticdb_config.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class AnalyticdbConfig(BaseModel):
+    """
+    Configuration for connecting to AnalyticDB.
+    Refer to the following documentation for details on obtaining credentials:
+    https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
+    """
+
+    ANALYTICDB_KEY_ID : Optional[str] = Field(
+        default=None,
+        description="The Access Key ID provided by Alibaba Cloud for authentication."
+    )
+    ANALYTICDB_KEY_SECRET : Optional[str] = Field(
+        default=None,
+        description="The Secret Access Key corresponding to the Access Key ID for secure access."
+    )
+    ANALYTICDB_REGION_ID : Optional[str] = Field(
+        default=None,
+        description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
+    )
+    ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
+        default=None,
+        description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')."
+    )
+    ANALYTICDB_ACCOUNT : Optional[str] = Field(
+        default=None,
+        description="The account name used to log in to the AnalyticDB instance."
+    )
+    ANALYTICDB_PASSWORD : Optional[str] = Field(
+        default=None,
+        description="The password associated with the AnalyticDB account for authentication."
+    )
+    ANALYTICDB_NAMESPACE : Optional[str] = Field(
+        default=None,
+        description="The namespace within AnalyticDB for schema isolation."
+    )
+    ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
+        default=None,
+        description="The password for accessing the specified namespace within the AnalyticDB instance."
+    )
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings


-class ChromaConfig(BaseModel):
+class ChromaConfig(BaseSettings):
     """
     Chroma configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings


-class MilvusConfig(BaseModel):
+class MilvusConfig(BaseSettings):
     """
     Milvus configs
     """
api/configs/middleware/vdb/myscale_config.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, PositiveInt
+
+
+class MyScaleConfig(BaseModel):
+    """
+    MyScale configs
+    """
+
+    MYSCALE_HOST: Optional[str] = Field(
+        description='MyScale host',
+        default=None,
+    )
+
+    MYSCALE_PORT: Optional[PositiveInt] = Field(
+        description='MyScale port',
+        default=8123,
+    )
+
+    MYSCALE_USER: Optional[str] = Field(
+        description='MyScale user',
+        default=None,
+    )
+
+    MYSCALE_PASSWORD: Optional[str] = Field(
+        description='MyScale password',
+        default=None,
+    )
+
+    MYSCALE_DATABASE: Optional[str] = Field(
+        description='MyScale database name',
+        default=None,
+    )
+
+    MYSCALE_FTS_PARAMS: Optional[str] = Field(
+        description='MyScale fts index parameters',
+        default=None,
+    )
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings


-class OpenSearchConfig(BaseModel):
+class OpenSearchConfig(BaseSettings):
     """
     OpenSearch configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings


-class OracleConfig(BaseModel):
+class OracleConfig(BaseSettings):
     """
     ORACLE configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings


-class PGVectorConfig(BaseModel):
+class PGVectorConfig(BaseSettings):
     """
     PGVector configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings


-class PGVectoRSConfig(BaseModel):
+class PGVectoRSConfig(BaseSettings):
     """
     PGVectoRS configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt
+from pydantic import Field, NonNegativeInt, PositiveInt
+from pydantic_settings import BaseSettings


-class QdrantConfig(BaseModel):
+class QdrantConfig(BaseSettings):
     """
     Qdrant configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings


-class RelytConfig(BaseModel):
+class RelytConfig(BaseSettings):
     """
     Relyt configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, NonNegativeInt, PositiveInt
+from pydantic_settings import BaseSettings


-class TencentVectorDBConfig(BaseModel):
+class TencentVectorDBConfig(BaseSettings):
     """
     Tencent Vector configs
     """
@@ -24,7 +25,7 @@ class TencentVectorDBConfig(BaseModel):
     )

     TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
-        description='Tencent Vector password',
+        description='Tencent Vector username',
         default=None,
     )
@@ -38,7 +39,12 @@ class TencentVectorDBConfig(BaseModel):
         default=1,
     )

-    TENCENT_VECTOR_DB_REPLICAS: PositiveInt = Field(
+    TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
         description='Tencent Vector replicas',
         default=2,
     )

+    TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
+        description='Tencent Vector Database',
+        default=None,
+    )
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings


-class TiDBVectorConfig(BaseModel):
+class TiDBVectorConfig(BaseSettings):
     """
     TiDB Vector configs
     """
@@ -1,9 +1,10 @@
 from typing import Optional

-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings


-class WeaviateConfig(BaseModel):
+class WeaviateConfig(BaseSettings):
     """
     Weaviate configs
     """
@@ -1,14 +1,15 @@
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings


-class PackagingInfo(BaseModel):
+class PackagingInfo(BaseSettings):
     """
     Packaging build information
     """

     CURRENT_VERSION: str = Field(
         description='Dify version',
-        default='0.6.12-fix1',
+        default='0.6.13',
     )

     COMMIT_SHA: str = Field(
@@ -14,7 +14,7 @@ language_timezone_mapping = {
     'vi-VN': 'Asia/Ho_Chi_Minh',
     'ro-RO': 'Europe/Bucharest',
     'pl-PL': 'Europe/Warsaw',
-    'hi-IN': 'Asia/Kolkata'
+    'hi-IN': 'Asia/Kolkata',
 }

 languages = list(language_timezone_mapping.keys())
File diff suppressed because one or more lines are too long
api/constants/tts_auto_play_timeout.py (new file, 4 lines)
@@ -0,0 +1,4 @@
+TTS_AUTO_PLAY_TIMEOUT = 5
+
+# sleep 20 ms ( 40ms => 1280 byte audio file, 20ms => 640 byte audio file)
+TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02
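A sketch of how constants like these are typically consumed in a streaming audio loop; `audio_chunks` is a hypothetical placeholder, not a Dify API:

```python
# Hypothetical consumer of the constants above.
import time

TTS_AUTO_PLAY_TIMEOUT = 5            # seconds to wait for playback to start
TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02  # 20 ms pause between yielded chunks

def stream(audio_chunks):
    for chunk in audio_chunks:
        yield chunk
        # Briefly yield the CPU so other greenlets/threads can run
        # while audio is being pushed to the client.
        time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
```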
@@ -30,7 +30,7 @@ from .app import (
 )

 # Import auth controllers
-from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth
+from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth

 # Import billing controllers
 from .billing import billing
@@ -134,6 +134,7 @@ class AppApi(Resource):
         parser.add_argument('description', type=str, location='json')
         parser.add_argument('icon', type=str, location='json')
         parser.add_argument('icon_background', type=str, location='json')
+        parser.add_argument('max_active_requests', type=int, location='json')
         args = parser.parse_args()

         app_service = AppService()
@@ -81,15 +81,36 @@ class ChatMessageTextApi(Resource):
     @account_initialization_required
     @get_app_model
     def post(self, app_model):
+        from werkzeug.exceptions import InternalServerError
+
         try:
+            parser = reqparse.RequestParser()
+            parser.add_argument('message_id', type=str, location='json')
+            parser.add_argument('text', type=str, location='json')
+            parser.add_argument('voice', type=str, location='json')
+            parser.add_argument('streaming', type=bool, location='json')
+            args = parser.parse_args()
+
+            message_id = args.get('message_id', None)
+            text = args.get('text', None)
+            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                    and app_model.workflow
+                    and app_model.workflow.features_dict):
+                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
+                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            else:
+                try:
+                    voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
+                        'voice')
+                except Exception:
+                    voice = None
             response = AudioService.transcript_tts(
                 app_model=app_model,
-                text=request.form['text'],
-                voice=request.form['voice'],
-                streaming=False
+                text=text,
+                message_id=message_id,
+                voice=voice
             )
-
-            return {'data': response.data.decode('latin1')}
+            return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
@@ -19,7 +19,12 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    AppInvokeQuotaExceededError,
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import uuid_value
@@ -75,7 +80,7 @@ class CompletionMessageApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except ValueError as e:
+        except (ValueError, AppInvokeQuotaExceededError) as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -141,7 +146,7 @@ class ChatMessageApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except ValueError as e:
+        except (ValueError, AppInvokeQuotaExceededError) as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -13,6 +13,7 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.errors.error import AppInvokeQuotaExceededError
 from fields.workflow_fields import workflow_fields
 from fields.workflow_run_fields import workflow_run_node_execution_fields
 from libs import helper
@@ -279,7 +280,7 @@ class DraftWorkflowRunApi(Resource):
             )

             return helper.compact_generate_response(response)
-        except ValueError as e:
+        except (ValueError, AppInvokeQuotaExceededError) as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -6,6 +6,7 @@ from flask_login import current_user
 from flask_restful import Resource
 from werkzeug.exceptions import Forbidden

+from configs import dify_config
 from controllers.console import api
 from libs.login import login_required
 from libs.oauth_data_source import NotionOAuth
@@ -16,11 +17,11 @@ from ..wraps import account_initialization_required

 def get_oauth_providers():
     with current_app.app_context():
-        notion_oauth = NotionOAuth(client_id=current_app.config.get('NOTION_CLIENT_ID'),
-                                   client_secret=current_app.config.get(
-                                       'NOTION_CLIENT_SECRET'),
-                                   redirect_uri=current_app.config.get(
-                                       'CONSOLE_API_URL') + '/console/api/oauth/data-source/callback/notion')
+        if not dify_config.NOTION_CLIENT_ID or not dify_config.NOTION_CLIENT_SECRET:
+            return {}
+        notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID,
+                                   client_secret=dify_config.NOTION_CLIENT_SECRET,
+                                   redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion')

         OAUTH_PROVIDERS = {
             'notion': notion_oauth
@@ -39,8 +40,10 @@ class OAuthDataSource(Resource):
         print(vars(oauth_provider))
         if not oauth_provider:
             return {'error': 'Invalid provider'}, 400
-        if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
-            internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
+        if dify_config.NOTION_INTEGRATION_TYPE == 'internal':
+            internal_secret = dify_config.NOTION_INTERNAL_SECRET
+            if not internal_secret:
+                return {'error': 'Internal secret is not set'},
             oauth_provider.save_internal_access_token(internal_secret)
             return { 'data': '' }
         else:
@@ -60,13 +63,13 @@ class OAuthDataSourceCallback(Resource):
         if 'code' in request.args:
             code = request.args.get('code')

-            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&code={code}')
+            return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}')
         elif 'error' in request.args:
             error = request.args.get('error')

-            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error={error}')
+            return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}')
         else:
-            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error=Access denied')
+            return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied')


 class OAuthDataSourceBinding(Resource):
@@ -5,3 +5,28 @@ class ApiKeyAuthFailedError(BaseHTTPException):
     error_code = 'auth_failed'
     description = "{message}"
     code = 500

+
+class InvalidEmailError(BaseHTTPException):
+    error_code = 'invalid_email'
+    description = "The email address is not valid."
+    code = 400
+
+
+class PasswordMismatchError(BaseHTTPException):
+    error_code = 'password_mismatch'
+    description = "The passwords do not match."
+    code = 400
+
+
+class InvalidTokenError(BaseHTTPException):
+    error_code = 'invalid_or_expired_token'
+    description = "The token is invalid or has expired."
+    code = 400
+
+
+class PasswordResetRateLimitExceededError(BaseHTTPException):
+    error_code = 'password_reset_rate_limit_exceeded'
+    description = "Password reset rate limit exceeded. Try again later."
+    code = 429
api/controllers/console/auth/forgot_password.py (new file, 107 lines)
@@ -0,0 +1,107 @@
+import base64
+import logging
+import secrets
+
+from flask_restful import Resource, reqparse
+
+from controllers.console import api
+from controllers.console.auth.error import (
+    InvalidEmailError,
+    InvalidTokenError,
+    PasswordMismatchError,
+    PasswordResetRateLimitExceededError,
+)
+from controllers.console.setup import setup_required
+from extensions.ext_database import db
+from libs.helper import email as email_validate
+from libs.password import hash_password, valid_password
+from models.account import Account
+from services.account_service import AccountService
+from services.errors.account import RateLimitExceededError
+
+
+class ForgotPasswordSendEmailApi(Resource):
+
+    @setup_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('email', type=str, required=True, location='json')
+        args = parser.parse_args()
+
+        email = args['email']
+
+        if not email_validate(email):
+            raise InvalidEmailError()
+
+        account = Account.query.filter_by(email=email).first()
+
+        if account:
+            try:
+                AccountService.send_reset_password_email(account=account)
+            except RateLimitExceededError:
+                logging.warning(f"Rate limit exceeded for email: {account.email}")
+                raise PasswordResetRateLimitExceededError()
+        else:
+            # Return success to avoid revealing email registration status
+            logging.warning(f"Attempt to reset password for unregistered email: {email}")
+
+        return {"result": "success"}
+
+
+class ForgotPasswordCheckApi(Resource):
+
+    @setup_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('token', type=str, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+        token = args['token']
+
+        reset_data = AccountService.get_reset_password_data(token)
+
+        if reset_data is None:
+            return {'is_valid': False, 'email': None}
+        return {'is_valid': True, 'email': reset_data.get('email')}
+
+
+class ForgotPasswordResetApi(Resource):
+
+    @setup_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('token', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
+        parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        new_password = args['new_password']
+        password_confirm = args['password_confirm']
+
+        if str(new_password).strip() != str(password_confirm).strip():
+            raise PasswordMismatchError()
+
+        token = args['token']
+        reset_data = AccountService.get_reset_password_data(token)
+
+        if reset_data is None:
+            raise InvalidTokenError()
+
+        AccountService.revoke_reset_password_token(token)
+
+        salt = secrets.token_bytes(16)
+        base64_salt = base64.b64encode(salt).decode()
+
+        password_hashed = hash_password(new_password, salt)
+        base64_password_hashed = base64.b64encode(password_hashed).decode()
+
+        account = Account.query.filter_by(email=reset_data.get('email')).first()
+        account.password = base64_password_hashed
+        account.password_salt = base64_salt
+        db.session.commit()
+
+        return {'result': 'success'}
+
+
+api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
+api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
+api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
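Taken together, the three resources above implement the whole reset flow. A hypothetical client-side walk-through (base URL and values are illustrative; the routes match those registered above):

```python
# Hypothetical client calls against the endpoints registered above.
import requests

BASE = 'http://localhost:5001/console/api'  # assumed console API address

# 1. Request a reset email. The endpoint always reports success,
#    so the response does not reveal whether the address is registered.
requests.post(f'{BASE}/forgot-password', json={'email': 'user@example.com'})

# 2. Validate the token from the email link before showing the reset form.
check = requests.post(f'{BASE}/forgot-password/validity',
                      json={'token': '<token from email>'}).json()

# 3. Submit the new password; the token is revoked once used.
if check['is_valid']:
    requests.post(f'{BASE}/forgot-password/resets', json={
        'token': '<token from email>',
        'new_password': 'new-secret-password',
        'password_confirm': 'new-secret-password',
    })
```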
@@ -1,7 +1,7 @@
 from typing import cast

 import flask_login
-from flask import current_app, request
+from flask import request
 from flask_restful import Resource, reqparse

 import services
@@ -56,14 +56,14 @@ class LogoutApi(Resource):
 class ResetPasswordApi(Resource):
     @setup_required
     def get(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument('email', type=email, required=True, location='json')
-        args = parser.parse_args()
+        # parser = reqparse.RequestParser()
+        # parser.add_argument('email', type=email, required=True, location='json')
+        # args = parser.parse_args()

         # import mailchimp_transactional as MailchimpTransactional
         # from mailchimp_transactional.api_client import ApiClientError

-        account = {'email': args['email']}
+        # account = {'email': args['email']}
         # account = AccountService.get_by_email(args['email'])
         # if account is None:
         #     raise ValueError('Email not found')
@@ -71,22 +71,22 @@ class ResetPasswordApi(Resource):
         # AccountService.update_password(account, new_password)

         # todo: Send email
-        MAILCHIMP_API_KEY = current_app.config['MAILCHIMP_TRANSACTIONAL_API_KEY']
+        # MAILCHIMP_API_KEY = current_app.config['MAILCHIMP_TRANSACTIONAL_API_KEY']
         # mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY)

-        message = {
-            'from_email': 'noreply@example.com',
-            'to': [{'email': account.email}],
-            'subject': 'Reset your Dify password',
-            'html': """
-                <p>Dear User,</p>
-                <p>The Dify team has generated a new password for you, details as follows:</p>
-                <p><strong>{new_password}</strong></p>
-                <p>Please change your password to log in as soon as possible.</p>
-                <p>Regards,</p>
-                <p>The Dify Team</p>
-            """
-        }
+        # message = {
+        #     'from_email': 'noreply@example.com',
+        #     'to': [{'email': account['email']}],
+        #     'subject': 'Reset your Dify password',
+        #     'html': """
+        #         <p>Dear User,</p>
+        #         <p>The Dify team has generated a new password for you, details as follows:</p>
+        #         <p><strong>{new_password}</strong></p>
+        #         <p>Please change your password to log in as soon as possible.</p>
+        #         <p>Regards,</p>
+        #         <p>The Dify Team</p>
+        #     """
+        # }

         # response = mailchimp.messages.send({
         #     'message': message,
@@ -6,6 +6,7 @@ import requests
 from flask import current_app, redirect, request
 from flask_restful import Resource

+from configs import dify_config
 from constants.languages import languages
 from extensions.ext_database import db
 from libs.helper import get_remote_ip
@@ -18,22 +19,24 @@ from .. import api

 def get_oauth_providers():
     with current_app.app_context():
-        github_oauth = GitHubOAuth(client_id=current_app.config.get('GITHUB_CLIENT_ID'),
-                                   client_secret=current_app.config.get(
-                                       'GITHUB_CLIENT_SECRET'),
-                                   redirect_uri=current_app.config.get(
-                                       'CONSOLE_API_URL') + '/console/api/oauth/authorize/github')
+        if not dify_config.GITHUB_CLIENT_ID or not dify_config.GITHUB_CLIENT_SECRET:
+            github_oauth = None
+        else:
+            github_oauth = GitHubOAuth(
+                client_id=dify_config.GITHUB_CLIENT_ID,
+                client_secret=dify_config.GITHUB_CLIENT_SECRET,
+                redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github',
+            )
+        if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
+            google_oauth = None
+        else:
+            google_oauth = GoogleOAuth(
+                client_id=dify_config.GOOGLE_CLIENT_ID,
+                client_secret=dify_config.GOOGLE_CLIENT_SECRET,
+                redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google',
+            )

-        google_oauth = GoogleOAuth(client_id=current_app.config.get('GOOGLE_CLIENT_ID'),
-                                   client_secret=current_app.config.get(
-                                       'GOOGLE_CLIENT_SECRET'),
-                                   redirect_uri=current_app.config.get(
-                                       'CONSOLE_API_URL') + '/console/api/oauth/authorize/google')

-        OAUTH_PROVIDERS = {
-            'github': github_oauth,
-            'google': google_oauth
-        }
+        OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth}
         return OAUTH_PROVIDERS
@@ -63,8 +66,7 @@ class OAuthCallback(Resource):
             token = oauth_provider.get_access_token(code)
             user_info = oauth_provider.get_user_info(token)
         except requests.exceptions.HTTPError as e:
-            logging.exception(
-                f"An error occurred during the OAuth process with {provider}: {e.response.text}")
+            logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}')
             return {'error': 'OAuth process failed'}, 400

         account = _generate_account(provider, user_info)
@@ -81,7 +83,7 @@ class OAuthCallback(Resource):

         token = AccountService.login(account, ip_address=get_remote_ip(request))

-        return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')
+        return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}')


 def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
@@ -101,11 +103,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
     # Create account
     account_name = user_info.name if user_info.name else 'Dify'
     account = RegisterService.register(
-        email=user_info.email,
-        name=account_name,
-        password=None,
-        open_id=user_info.id,
-        provider=provider
+        email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
     )

     # Set interface language
@@ -25,7 +25,7 @@ from fields.document_fields import document_status_fields
 from libs.login import login_required
 from models.dataset import Dataset, Document, DocumentSegment
 from models.model import ApiToken, UploadFile
-from services.dataset_service import DatasetService, DocumentService
+from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService


 def _validate_name(name):
@@ -85,6 +85,12 @@ class DatasetListApi(Resource):
             else:
                 item['embedding_available'] = True

+            if item.get('permission') == 'partial_members':
+                part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
+                item.update({'partial_member_list': part_users_list})
+            else:
+                item.update({'partial_member_list': []})
+
         response = {
             'data': data,
             'has_more': len(datasets) == limit,
@@ -108,8 +114,8 @@ class DatasetListApi(Resource):
                             help='Invalid indexing technique.')
         args = parser.parse_args()

-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
+        if not current_user.is_dataset_editor:
             raise Forbidden()

         try:
@@ -140,6 +146,10 @@ class DatasetApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         data = marshal(dataset, dataset_detail_fields)
+        if data.get('permission') == 'partial_members':
+            part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
+            data.update({'partial_member_list': part_users_list})
+
         # check embedding setting
         provider_manager = ProviderManager()
         configurations = provider_manager.get_configurations(
@@ -163,6 +173,11 @@ class DatasetApi(Resource):
             data['embedding_available'] = False
         else:
             data['embedding_available'] = True

+        if data.get('permission') == 'partial_members':
+            part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
+            data.update({'partial_member_list': part_users_list})
+
         return data, 200

     @setup_required
@@ -188,17 +203,21 @@ class DatasetApi(Resource):
                             nullable=True,
                             help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
-            'only_me', 'all_team_members'), help='Invalid permission.')
+            'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.'
+        )
         parser.add_argument('embedding_model', type=str,
                             location='json', help='Invalid embedding model.')
         parser.add_argument('embedding_model_provider', type=str,
                             location='json', help='Invalid embedding model provider.')
         parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
+        parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
         args = parser.parse_args()
+        data = request.get_json()

-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
+        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+        DatasetPermissionService.check_permission(
+            current_user, dataset, data.get('permission'), data.get('partial_member_list')
+        )

         dataset = DatasetService.update_dataset(
             dataset_id_str, args, current_user)
@@ -206,7 +225,20 @@ class DatasetApi(Resource):
         if dataset is None:
             raise NotFound("Dataset not found.")

-        return marshal(dataset, dataset_detail_fields), 200
+        result_data = marshal(dataset, dataset_detail_fields)
+        tenant_id = current_user.current_tenant_id
+
+        if data.get('partial_member_list') and data.get('permission') == 'partial_members':
+            DatasetPermissionService.update_partial_member_list(
+                tenant_id, dataset_id_str, data.get('partial_member_list')
+            )
+        else:
+            DatasetPermissionService.clear_partial_member_list(dataset_id_str)
+
+        partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
+        result_data.update({'partial_member_list': partial_member_list})
+
+        return result_data, 200

     @setup_required
     @login_required
@@ -215,11 +247,12 @@ class DatasetApi(Resource):
         dataset_id_str = str(dataset_id)

         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not current_user.is_editor or current_user.is_dataset_operator:
             raise Forbidden()

         try:
             if DatasetService.delete_dataset(dataset_id_str, current_user):
+                DatasetPermissionService.clear_partial_member_list(dataset_id_str)
                 return {'result': 'success'}, 204
             else:
                 raise NotFound("Dataset not found.")
@@ -515,7 +548,7 @@ class DatasetRetrievalSettingApi(Resource):
                         RetrievalMethod.SEMANTIC_SEARCH
                     ]
                 }
-            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
+            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
                 return {
                     'retrieval_method': [
                         RetrievalMethod.SEMANTIC_SEARCH,
@@ -539,7 +572,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                         RetrievalMethod.SEMANTIC_SEARCH
                     ]
                 }
-            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
+            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
                 return {
                     'retrieval_method': [
                         RetrievalMethod.SEMANTIC_SEARCH,
@@ -569,6 +602,27 @@ class DatasetErrorDocs(Resource):
         }, 200


+class DatasetPermissionUserListApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id):
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+
+        partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
+
+        return {
+            'data': partial_members_list,
+        }, 200
+
+
 api.add_resource(DatasetListApi, '/datasets')
 api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
 api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
@@ -582,3 +636,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
 api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
 api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
 api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
+api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')

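For orientation, the new route registered above can be exercised roughly like this; a minimal sketch, assuming a hypothetical deployment host and an authenticated console session (the console API prefix matches the redirect URIs seen in the oauth diff above):

import requests  # sketch; host, token, and dataset id are placeholders

CONSOLE = 'https://cloud.example.com/console/api'            # hypothetical deployment
headers = {'Authorization': 'Bearer CONSOLE_SESSION_TOKEN'}  # assumed session auth

resp = requests.get(f'{CONSOLE}/datasets/DATASET_UUID/permission-part-users', headers=headers)
print(resp.json())  # -> {'data': [...partial member list...]}
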
@@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
             raise NotFound('Dataset not found.')

         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not current_user.is_dataset_editor:
             raise Forbidden()

         try:
@@ -294,6 +294,11 @@ class DatasetInitApi(Resource):
         parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
                             location='json')
         args = parser.parse_args()
+
+        # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
+        if not current_user.is_dataset_editor:
+            raise Forbidden()
+
         if args['indexing_technique'] == 'high_quality':
             try:
                 model_manager = ModelManager()
@@ -344,7 +349,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
         document = self.get_document(dataset_id, document_id)

         if document.indexing_status in ['completed', 'error']:
             raise DocumentAlreadyFinishedError()
         indexing_runner.calculate_tokens(document)

         data_process_rule = document.dataset_process_rule
         data_process_rule_dict = data_process_rule.to_dict()
@@ -757,14 +762,18 @@ class DocumentStatusApi(DocumentResource):
         dataset = DatasetService.get_dataset(dataset_id)
         if dataset is None:
             raise NotFound("Dataset not found.")

+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_dataset_editor:
+            raise Forbidden()
+
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)

-        document = self.get_document(dataset_id, document_id)
+        # check user's permission
+        DatasetService.check_dataset_permission(dataset, current_user)

-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
+        document = self.get_document(dataset_id, document_id)

         indexing_cache_key = 'document_{}_indexing'.format(document.id)
         cache_result = redis_client.get(indexing_cache_key)
@@ -955,10 +964,11 @@ class DocumentRenameApi(DocumentResource):
     @account_initialization_required
     @marshal_with(document_fields)
     def post(self, dataset_id, document_id):
-        # The role of the current user in the ta table must be admin or owner
-        if not current_user.is_admin_or_owner:
+        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+        if not current_user.is_dataset_editor:
             raise Forbidden()

         dataset = DatasetService.get_dataset(dataset_id)
+        DatasetService.check_dataset_operator_permission(current_user, dataset)
         parser = reqparse.RequestParser()
         parser.add_argument('name', type=str, required=True, nullable=False, location='json')
         args = parser.parse_args()

@@ -19,6 +19,7 @@ from controllers.console.app.error import (
 from controllers.console.explore.wraps import InstalledAppResource
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
+from models.model import AppMode
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -70,16 +71,33 @@ class ChatAudioApi(InstalledAppResource):

 class ChatTextApi(InstalledAppResource):
     def post(self, installed_app):
-        app_model = installed_app.app
+        from flask_restful import reqparse
+
+        app_model = installed_app.app
         try:
+            parser = reqparse.RequestParser()
+            parser.add_argument('message_id', type=str, required=False, location='json')
+            parser.add_argument('voice', type=str, location='json')
+            parser.add_argument('streaming', type=bool, location='json')
+            args = parser.parse_args()
+
+            message_id = args.get('message_id')
+            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                    and app_model.workflow
+                    and app_model.workflow.features_dict):
+                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
+                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            else:
+                try:
+                    voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
+                except Exception:
+                    voice = None
             response = AudioService.transcript_tts(
                 app_model=app_model,
                 text=request.form['text'],
-                voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
-                streaming=False
+                message_id=message_id,
+                voice=voice
             )
-            return {'data': response.data.decode('latin1')}
+            return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
@@ -108,3 +126,5 @@ class ChatTextApi(InstalledAppResource):

 api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
 api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
+# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
+#                  endpoint='installed_app_text_with_message_id')

@@ -36,7 +36,7 @@ class TagListApi(Resource):
     @account_initialization_required
     def post(self):
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not (current_user.is_editor or current_user.is_dataset_editor):
             raise Forbidden()

         parser = reqparse.RequestParser()
@@ -68,7 +68,7 @@ class TagUpdateDeleteApi(Resource):
     def patch(self, tag_id):
         tag_id = str(tag_id)
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not (current_user.is_editor or current_user.is_dataset_editor):
             raise Forbidden()

         parser = reqparse.RequestParser()
@@ -109,8 +109,8 @@ class TagBindingCreateApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+        if not (current_user.is_editor or current_user.is_dataset_editor):
             raise Forbidden()

         parser = reqparse.RequestParser()
@@ -134,8 +134,8 @@ class TagBindingDeleteApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+        if not (current_user.is_editor or current_user.is_dataset_editor):
             raise Forbidden()

         parser = reqparse.RequestParser()

@@ -245,6 +245,8 @@ class AccountIntegrateApi(Resource):
         return {'data': integrate_data}


+
+
 # Register API resources
 api.add_resource(AccountInitApi, '/account/init')
 api.add_resource(AccountProfileApi, '/account/profile')

@@ -131,7 +131,20 @@ class MemberUpdateRoleApi(Resource):
         return {'result': 'success'}


+class DatasetOperatorMemberListApi(Resource):
+    """List all members of current tenant."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(account_with_role_list_fields)
+    def get(self):
+        members = TenantService.get_dataset_operator_members(current_user.current_tenant)
+        return {'result': 'success', 'accounts': members}, 200
+
+
 api.add_resource(MemberListApi, '/workspaces/current/members')
 api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
 api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
 api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')
+api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators')

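A quick sketch of the new listing endpoint, under the same placeholder conventions as the earlier example (authenticated console session assumed):

import requests  # sketch; host and token are placeholders

resp = requests.get(
    'https://cloud.example.com/console/api/workspaces/current/dataset-operators',
    headers={'Authorization': 'Bearer CONSOLE_SESSION_TOKEN'},
)
operators = resp.json().get('accounts', [])  # marshalled with account_with_role_list_fields
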
@@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
-from models.model import App, EndUser
+from models.model import App, AppMode, EndUser
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -72,19 +72,30 @@ class AudioApi(Resource):
 class TextApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     def post(self, app_model: App, end_user: EndUser):
-        parser = reqparse.RequestParser()
-        parser.add_argument('text', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('voice', type=str, location='json')
-        parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
-        args = parser.parse_args()
-
         try:
+            parser = reqparse.RequestParser()
+            parser.add_argument('message_id', type=str, required=False, location='json')
+            parser.add_argument('voice', type=str, location='json')
+            parser.add_argument('streaming', type=bool, location='json')
+            args = parser.parse_args()
+
+            message_id = args.get('message_id')
+            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                    and app_model.workflow
+                    and app_model.workflow.features_dict):
+                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
+                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            else:
+                try:
+                    voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
+                        'voice')
+                except Exception:
+                    voice = None
             response = AudioService.transcript_tts(
                 app_model=app_model,
-                text=args['text'],
-                end_user=end_user,
-                voice=args.get('voice'),
-                streaming=args['streaming']
+                message_id=message_id,
+                end_user=end_user.external_user_id,
+                voice=voice
             )

             return response

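If the Service API handler keeps the shape above, a client request might look roughly like this; the host, the `/v1/text-to-audio` path, and the end-user field name are assumptions rather than something this diff confirms:

import requests  # illustrative only; adjust to the actual Service API contract

resp = requests.post(
    'https://api.dify.example/v1/text-to-audio',          # hypothetical Service API host/path
    headers={'Authorization': 'Bearer APP_API_KEY'},
    json={'message_id': 'MESSAGE_UUID', 'user': 'end-user-1'},  # field names assumed
)
audio = resp.content  # synthesized audio for the referenced message
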
@@ -17,7 +17,12 @@ from controllers.service_api.app.error import (
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    AppInvokeQuotaExceededError,
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import uuid_value
@@ -69,7 +74,7 @@ class CompletionApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except ValueError as e:
+        except (ValueError, AppInvokeQuotaExceededError) as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -132,7 +137,7 @@ class ChatApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except ValueError as e:
+        except (ValueError, AppInvokeQuotaExceededError) as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")

@@ -14,7 +14,12 @@ from controllers.service_api.app.error import (
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    AppInvokeQuotaExceededError,
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from models.model import App, AppMode, EndUser
@@ -59,7 +64,7 @@ class WorkflowRunApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except ValueError as e:
+        except (ValueError, AppInvokeQuotaExceededError) as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")

@@ -19,7 +19,7 @@ from controllers.web.error import (
 from controllers.web.wraps import WebApiResource
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
-from models.model import App
+from models.model import App, AppMode
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -69,16 +69,35 @@ class AudioApi(WebApiResource):

 class TextApi(WebApiResource):
     def post(self, app_model: App, end_user):
+        from flask_restful import reqparse
         try:
+            parser = reqparse.RequestParser()
+            parser.add_argument('message_id', type=str, required=False, location='json')
+            parser.add_argument('voice', type=str, location='json')
+            parser.add_argument('streaming', type=bool, location='json')
+            args = parser.parse_args()
+
+            message_id = args.get('message_id')
+            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                    and app_model.workflow
+                    and app_model.workflow.features_dict):
+                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
+                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            else:
+                try:
+                    voice = args.get('voice') if args.get(
+                        'voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
+                except Exception:
+                    voice = None
+
             response = AudioService.transcript_tts(
                 app_model=app_model,
-                text=request.form['text'],
+                message_id=message_id,
                 end_user=end_user.external_user_id,
-                voice=request.form['voice'] if request.form.get('voice') else None,
-                streaming=False
+                voice=voice
             )

-            return {'data': response.data.decode('latin1')}
+            return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()

@@ -114,6 +114,10 @@ class VariableEntity(BaseModel):
     default: Optional[str] = None
     hint: Optional[str] = None

+    @property
+    def name(self) -> str:
+        return self.variable
+

 class ExternalDataVariableEntity(BaseModel):
     """

api/core/app/apps/advanced_chat/app_generator_tts_publisher.py (135 lines, new file)
@@ -0,0 +1,135 @@
import base64
import concurrent.futures
import logging
import queue
import re
import threading

from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueTextChunkEvent
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType


class AudioTrunk:
    def __init__(self, status: str, audio):
        self.audio = audio
        self.status = status


def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
    if not text_content or text_content.isspace():
        return
    return model_instance.invoke_tts(
        content_text=text_content.strip(),
        user="responding_tts",
        tenant_id=tenant_id,
        voice=voice
    )


def _process_future(future_queue, audio_queue):
    while True:
        try:
            future = future_queue.get()
            if future is None:
                break
            for audio in future.result():
                audio_base64 = base64.b64encode(bytes(audio))
                audio_queue.put(AudioTrunk("responding", audio=audio_base64))
        except Exception as e:
            logging.getLogger(__name__).warning(e)
            break
    audio_queue.put(AudioTrunk("finish", b''))


class AppGeneratorTTSPublisher:

    def __init__(self, tenant_id: str, voice: str):
        self.logger = logging.getLogger(__name__)
        self.tenant_id = tenant_id
        self.msg_text = ''
        self._audio_queue = queue.Queue()
        self._msg_queue = queue.Queue()
        self.match = re.compile(r'[。.!?]')
        self.model_manager = ModelManager()
        self.model_instance = self.model_manager.get_default_model_instance(
            tenant_id=self.tenant_id,
            model_type=ModelType.TTS
        )
        self.voices = self.model_instance.get_tts_voices()
        values = [voice.get('value') for voice in self.voices]
        self.voice = voice
        if not voice or voice not in values:
            self.voice = self.voices[0].get('value')
        self.MAX_SENTENCE = 2
        self._last_audio_event = None
        self._runtime_thread = threading.Thread(target=self._runtime).start()
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)

    def publish(self, message):
        try:
            self._msg_queue.put(message)
        except Exception as e:
            self.logger.warning(e)

    def _runtime(self):
        future_queue = queue.Queue()
        threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
        while True:
            try:
                message = self._msg_queue.get()
                if message is None:
                    if self.msg_text and len(self.msg_text.strip()) > 0:
                        futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
                                                              self.model_instance, self.tenant_id, self.voice)
                        future_queue.put(futures_result)
                    break
                elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
                    self.msg_text += message.event.chunk.delta.message.content
                elif isinstance(message.event, QueueTextChunkEvent):
                    self.msg_text += message.event.text
                self.last_message = message
                sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
                if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
                    self.MAX_SENTENCE += 1
                    text_content = ''.join(sentence_arr)
                    futures_result = self.executor.submit(_invoiceTTS, text_content,
                                                          self.model_instance,
                                                          self.tenant_id,
                                                          self.voice)
                    future_queue.put(futures_result)
                    if text_tmp:
                        self.msg_text = text_tmp
                    else:
                        self.msg_text = ''
            except Exception as e:
                self.logger.warning(e)
                break
        future_queue.put(None)

    def checkAndGetAudio(self) -> AudioTrunk | None:
        try:
            if self._last_audio_event and self._last_audio_event.status == "finish":
                if self.executor:
                    self.executor.shutdown(wait=False)
                return self.last_message
            audio = self._audio_queue.get_nowait()
            if audio and audio.status == "finish":
                self.executor.shutdown(wait=False)
                self._runtime_thread = None
            if audio:
                self._last_audio_event = audio
            return audio
        except queue.Empty:
            return None

    def _extract_sentence(self, org_text):
        tx = self.match.finditer(org_text)
        start = 0
        result = []
        for i in tx:
            end = i.regs[0][1]
            result.append(org_text[start:end])
            start = end
        return result, org_text[start:]

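For orientation, the task pipelines below consume this publisher with a publish/poll pattern; here is a minimal sketch of that flow, with the message source and the audio player left as placeholders (the real wiring comes from AppQueueManager.listen()):

import time

def consume(publisher, messages, play):
    # Feed pipeline messages to the TTS worker thread...
    for message in messages:
        publisher.publish(message)
    publisher.publish(None)  # end-of-stream sentinel

    # ...then poll for synthesized audio until the "finish" trunk arrives.
    while True:
        trunk = publisher.checkAndGetAudio()  # non-blocking; None means nothing ready yet
        if trunk is None:
            time.sleep(0.02)  # yield the CPU, mirroring TTS_AUTO_PLAY_YIELD_CPU_TIME
            continue
        if trunk.status == 'finish':
            break
        play(trunk.audio)  # base64-encoded audio chunk; `play` is hypothetical
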
@@ -4,6 +4,8 @@ import time
 from collections.abc import Generator
 from typing import Any, Optional, Union, cast

+from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
+from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import (
     AdvancedChatAppGenerateEntity,
@@ -33,6 +35,8 @@ from core.app.entities.task_entities import (
     ChatbotAppStreamResponse,
     ChatflowStreamGenerateRoute,
     ErrorStreamResponse,
+    MessageAudioEndStreamResponse,
+    MessageAudioStreamResponse,
     MessageEndStreamResponse,
     StreamResponse,
 )
@@ -71,13 +75,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     _iteration_nested_relations: dict[str, list[str]]

     def __init__(
-        self, application_generate_entity: AdvancedChatAppGenerateEntity,
-        workflow: Workflow,
-        queue_manager: AppQueueManager,
-        conversation: Conversation,
-        message: Message,
-        user: Union[Account, EndUser],
-        stream: bool
+            self, application_generate_entity: AdvancedChatAppGenerateEntity,
+            workflow: Workflow,
+            queue_manager: AppQueueManager,
+            conversation: Conversation,
+            message: Message,
+            user: Union[Account, EndUser],
+            stream: bool
     ) -> None:
         """
         Initialize AdvancedChatAppGenerateTaskPipeline.
@@ -129,7 +133,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             self._application_generate_entity.query
         )

-        generator = self._process_stream_response(
+        generator = self._wrapper_process_stream_response(
             trace_manager=self._application_generate_entity.trace_manager
         )
         if self._stream:
@@ -138,7 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         return self._to_blocking_response(generator)

     def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-        -> ChatbotAppBlockingResponse:
+            -> ChatbotAppBlockingResponse:
         """
         Process blocking response.
         :return:
@@ -169,7 +173,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         raise Exception('Queue listening stopped unexpectedly.')

     def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-        -> Generator[ChatbotAppStreamResponse, None, None]:
+            -> Generator[ChatbotAppStreamResponse, None, None]:
         """
         To stream response.
         :return:
@@ -182,14 +186,68 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 stream_response=stream_response
             )

+    def _listenAudioMsg(self, publisher, task_id: str):
+        if not publisher:
+            return None
+        audio_msg: AudioTrunk = publisher.checkAndGetAudio()
+        if audio_msg and audio_msg.status != "finish":
+            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
+        return None
+
+    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
+            Generator[StreamResponse, None, None]:
+
+        publisher = None
+        task_id = self._application_generate_entity.task_id
+        tenant_id = self._application_generate_entity.app_config.tenant_id
+        features_dict = self._workflow.features_dict
+
+        if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
+                'text_to_speech'].get('autoPlay') == 'enabled':
+            publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
+        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
+            while True:
+                audio_response = self._listenAudioMsg(publisher, task_id=task_id)
+                if audio_response:
+                    yield audio_response
+                else:
+                    break
+            yield response
+
+        start_listener_time = time.time()
+        # timeout
+        while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
+            try:
+                if not publisher:
+                    break
+                audio_trunk = publisher.checkAndGetAudio()
+                if audio_trunk is None:
+                    # release cpu
+                    # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
+                    time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
+                    continue
+                if audio_trunk.status == "finish":
+                    break
+                else:
+                    start_listener_time = time.time()
+                    yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
+            except Exception as e:
+                logger.error(e)
+                break
+        yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
+
     def _process_stream_response(
-        self, trace_manager: Optional[TraceQueueManager] = None
+            self,
+            publisher: AppGeneratorTTSPublisher,
+            trace_manager: Optional[TraceQueueManager] = None
     ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response.
         :return:
         """
         for message in self._queue_manager.listen():
+            if publisher:
+                publisher.publish(message=message)
             event = message.event

             if isinstance(event, QueueErrorEvent):
@@ -301,7 +359,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     continue

                 if not self._is_stream_out_support(
-                    event=event
+                        event=event
                 ):
                     continue

@@ -318,7 +376,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 yield self._ping_stream_response()
             else:
                 continue
-
+        if publisher:
+            publisher.publish(None)
         if self._conversation_name_generate_thread:
             self._conversation_name_generate_thread.join()

@@ -402,7 +461,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         return stream_generate_routes

     def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-        -> list[str]:
+            -> list[str]:
         """
         Get answer start at node id.
         :param graph: graph
@@ -457,7 +516,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 start_node_id = target_node_id
                 start_node_ids.append(start_node_id)
             elif node_type == NodeType.START.value or \
-                node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
+                    node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
                 start_node_id = source_node_id
                 start_node_ids.append(start_node_id)
             else:
@@ -515,7 +574,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

         # all route chunks are generated
         if self._task_state.current_stream_generate_state.current_route_position == len(
-            self._task_state.current_stream_generate_state.generate_route
+                self._task_state.current_stream_generate_state.generate_route
         ):
             self._task_state.current_stream_generate_state = None

@@ -525,7 +584,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         :return:
         """
         if not self._task_state.current_stream_generate_state:
-            return None
+            return

         route_chunks = self._task_state.current_stream_generate_state.generate_route[
             self._task_state.current_stream_generate_state.current_route_position:]
@@ -573,7 +632,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             # get route chunk node execution info
             route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
             if (route_chunk_node_execution_info.node_type == NodeType.LLM
-                and latest_node_execution_info.node_type == NodeType.LLM):
+                    and latest_node_execution_info.node_type == NodeType.LLM):
                 # only LLM support chunk stream output
                 self._task_state.current_stream_generate_state.current_route_position += 1
                 continue
@@ -643,7 +702,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

         # all route chunks are generated
         if self._task_state.current_stream_generate_state.current_route_position == len(
-            self._task_state.current_stream_generate_state.generate_route
+                self._task_state.current_stream_generate_state.generate_route
         ):
             self._task_state.current_stream_generate_state = None

@@ -1,52 +1,56 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
 from core.app.app_config.entities import AppConfig, VariableEntity


 class BaseAppGenerator:
-    def _get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig):
-        if user_inputs is None:
-            user_inputs = {}
-
-        filtered_inputs = {}
-
+    def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
+        user_inputs = user_inputs or {}
         # Filter input variables from form configuration, handle required fields, default values, and option values
         variables = app_config.variables
-        for variable_config in variables:
-            variable = variable_config.variable
-
-            if (variable not in user_inputs
-                    or user_inputs[variable] is None
-                    or (isinstance(user_inputs[variable], str) and user_inputs[variable] == '')):
-                if variable_config.required:
-                    raise ValueError(f"{variable} is required in input form")
-                else:
-                    filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
-                continue
-
-            value = user_inputs[variable]
-
-            if value is not None:
-                if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
-                    raise ValueError(f"{variable} in input form must be a string")
-                elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
-                    if '.' in value:
-                        value = float(value)
-                    else:
-                        value = int(value)
-
-            if variable_config.type == VariableEntity.Type.SELECT:
-                options = variable_config.options if variable_config.options is not None else []
-                if value not in options:
-                    raise ValueError(f"{variable} in input form must be one of the following: {options}")
-            elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
-                if variable_config.max_length is not None:
-                    max_length = variable_config.max_length
-                    if len(value) > max_length:
-                        raise ValueError(f'{variable} in input form must be less than {max_length} characters')
-
-            if value and isinstance(value, str):
-                filtered_inputs[variable] = value.replace('\x00', '')
-            else:
-                filtered_inputs[variable] = value if value is not None else None
-
+        filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
+        filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
         return filtered_inputs

+    def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
+        user_input_value = inputs.get(var.name)
+        if var.required and not user_input_value:
+            raise ValueError(f'{var.name} is required in input form')
+        if not var.required and not user_input_value:
+            # TODO: should we return None here if the default value is None?
+            return var.default or ''
+        if (
+            var.type
+            in (
+                VariableEntity.Type.TEXT_INPUT,
+                VariableEntity.Type.SELECT,
+                VariableEntity.Type.PARAGRAPH,
+            )
+            and user_input_value
+            and not isinstance(user_input_value, str)
+        ):
+            raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
+        if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
+            # may raise ValueError if user_input_value is not a valid number
+            try:
+                if '.' in user_input_value:
+                    return float(user_input_value)
+                else:
+                    return int(user_input_value)
+            except ValueError:
+                raise ValueError(f"{var.name} in input form must be a valid number")
+        if var.type == VariableEntity.Type.SELECT:
+            options = var.options or []
+            if user_input_value not in options:
+                raise ValueError(f'{var.name} in input form must be one of the following: {options}')
+        elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
+            if var.max_length and user_input_value and len(user_input_value) > var.max_length:
+                raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
+
+        return user_input_value
+
+    def _sanitize_value(self, value: Any) -> Any:
+        if isinstance(value, str):
+            return value.replace('\x00', '')
+        return value

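The refactor splits input cleaning into a validate pass and a sanitize pass, both expressed as dict comprehensions. A self-contained toy replica of that two-pass shape (ToyVar stands in for VariableEntity; the real validator also handles number, select, and max_length rules):

from dataclasses import dataclass
from typing import Any, Mapping, Optional

@dataclass
class ToyVar:                      # stand-in for VariableEntity
    name: str
    required: bool = False
    default: Optional[str] = None

def validate(inputs: Mapping[str, Any], var: ToyVar) -> Any:
    value = inputs.get(var.name)
    if var.required and not value:
        raise ValueError(f'{var.name} is required in input form')
    return value if value else (var.default or '')

def sanitize(value: Any) -> Any:
    return value.replace('\x00', '') if isinstance(value, str) else value

variables = [ToyVar('query', required=True), ToyVar('tone', default='neutral')]
user_inputs = {'query': 'hello\x00world'}
cleaned = {v.name: sanitize(validate(user_inputs, v)) for v in variables}
assert cleaned == {'query': 'helloworld', 'tone': 'neutral'}
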
@@ -51,7 +51,6 @@ class AppQueueManager:
         listen_timeout = current_app.config.get("APP_MAX_EXECUTION_TIME")
         start_time = time.time()
         last_ping_time = 0
-
         while True:
             try:
                 message = self._q.get(timeout=1)

@@ -94,7 +94,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
             application_generate_entity=application_generate_entity,
             invoke_from=invoke_from,
             stream=stream,
-            call_depth=call_depth,
         )

     def _generate(
@@ -104,7 +103,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         application_generate_entity: WorkflowAppGenerateEntity,
         invoke_from: InvokeFrom,
         stream: bool = True,
-        call_depth: int = 0
     ) -> Union[dict, Generator[dict, None, None]]:
         """
         Generate App response.
@@ -166,10 +164,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
         """
         if not node_id:
             raise ValueError('node_id is required')

         if args.get('inputs') is None:
             raise ValueError('inputs is required')

         extras = {
             "auto_generate_conversation_name": False
         }

@@ -1,7 +1,10 @@
 import logging
 import time
 from collections.abc import Generator
 from typing import Any, Optional, Union

+from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
+from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import (
     InvokeFrom,
@@ -25,6 +28,8 @@ from core.app.entities.queue_entities import (
 )
 from core.app.entities.task_entities import (
     ErrorStreamResponse,
+    MessageAudioEndStreamResponse,
+    MessageAudioStreamResponse,
     StreamResponse,
     TextChunkStreamResponse,
     TextReplaceStreamResponse,
@@ -105,7 +110,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         db.session.refresh(self._user)
         db.session.close()

-        generator = self._process_stream_response(
+        generator = self._wrapper_process_stream_response(
             trace_manager=self._application_generate_entity.trace_manager
         )
         if self._stream:
@@ -161,8 +166,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 stream_response=stream_response
             )

+    def _listenAudioMsg(self, publisher, task_id: str):
+        if not publisher:
+            return None
+        audio_msg: AudioTrunk = publisher.checkAndGetAudio()
+        if audio_msg and audio_msg.status != "finish":
+            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
+        return None
+
+    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
+            Generator[StreamResponse, None, None]:
+
+        publisher = None
+        task_id = self._application_generate_entity.task_id
+        tenant_id = self._application_generate_entity.app_config.tenant_id
+        features_dict = self._workflow.features_dict
+
+        if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
+                'text_to_speech'].get('autoPlay') == 'enabled':
+            publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
+        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
+            while True:
+                audio_response = self._listenAudioMsg(publisher, task_id=task_id)
+                if audio_response:
+                    yield audio_response
+                else:
+                    break
+            yield response
+
+        start_listener_time = time.time()
+        while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
+            try:
+                if not publisher:
+                    break
+                audio_trunk = publisher.checkAndGetAudio()
+                if audio_trunk is None:
+                    # release cpu
+                    # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
+                    time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
+                    continue
+                if audio_trunk.status == "finish":
+                    break
+                else:
+                    yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
+            except Exception as e:
+                logger.error(e)
+                break
+        yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
+
     def _process_stream_response(
             self,
             publisher: AppGeneratorTTSPublisher,
             trace_manager: Optional[TraceQueueManager] = None
     ) -> Generator[StreamResponse, None, None]:
         """
@@ -170,6 +225,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         :return:
         """
         for message in self._queue_manager.listen():
+            if publisher:
+                publisher.publish(message=message)
             event = message.event

             if isinstance(event, QueueErrorEvent):
@@ -251,6 +308,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             else:
                 continue

+        if publisher:
+            publisher.publish(None)
+
+
     def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
         """
         Save workflow app log.

@@ -69,6 +69,7 @@ class WorkflowTaskState(TaskState):

     iteration_nested_node_ids: list[str] = None

+
 class AdvancedChatTaskState(WorkflowTaskState):
     """
     AdvancedChatTaskState entity
@@ -86,6 +87,8 @@ class StreamEvent(Enum):
     ERROR = "error"
     MESSAGE = "message"
     MESSAGE_END = "message_end"
+    TTS_MESSAGE = "tts_message"
+    TTS_MESSAGE_END = "tts_message_end"
     MESSAGE_FILE = "message_file"
     MESSAGE_REPLACE = "message_replace"
     AGENT_THOUGHT = "agent_thought"
@@ -130,6 +133,22 @@ class MessageStreamResponse(StreamResponse):
     answer: str


+class MessageAudioStreamResponse(StreamResponse):
+    """
+    MessageStreamResponse entity
+    """
+    event: StreamEvent = StreamEvent.TTS_MESSAGE
+    audio: str
+
+
+class MessageAudioEndStreamResponse(StreamResponse):
+    """
+    MessageStreamResponse entity
+    """
+    event: StreamEvent = StreamEvent.TTS_MESSAGE_END
+    audio: str
+
+
 class MessageEndStreamResponse(StreamResponse):
     """
     MessageEndStreamResponse entity
@@ -186,6 +205,7 @@ class WorkflowStartStreamResponse(StreamResponse):
     """
     WorkflowStartStreamResponse entity
     """
+
     class Data(BaseModel):
         """
         Data entity
@@ -205,6 +225,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
     """
     WorkflowFinishStreamResponse entity
     """
+
     class Data(BaseModel):
         """
         Data entity
@@ -232,6 +253,7 @@ class NodeStartStreamResponse(StreamResponse):
     """
     NodeStartStreamResponse entity
     """
+
     class Data(BaseModel):
         """
         Data entity
@@ -273,6 +295,7 @@ class NodeFinishStreamResponse(StreamResponse):
     """
     NodeFinishStreamResponse entity
     """
+
     class Data(BaseModel):
         """
         Data entity
@@ -323,10 +346,12 @@ class NodeFinishStreamResponse(StreamResponse):
         }
     }

+
 class IterationNodeStartStreamResponse(StreamResponse):
     """
     NodeStartStreamResponse entity
     """
+
     class Data(BaseModel):
         """
         Data entity
@@ -344,10 +369,12 @@ class IterationNodeStartStreamResponse(StreamResponse):
     workflow_run_id: str
     data: Data

+
 class IterationNodeNextStreamResponse(StreamResponse):
     """
     NodeStartStreamResponse entity
     """
+
     class Data(BaseModel):
         """
         Data entity
@@ -365,10 +392,12 @@ class IterationNodeNextStreamResponse(StreamResponse):
     workflow_run_id: str
     data: Data

+
 class IterationNodeCompletedStreamResponse(StreamResponse):
     """
     NodeCompletedStreamResponse entity
     """
+
     class Data(BaseModel):
         """
         Data entity
@@ -393,10 +422,12 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
     workflow_run_id: str
     data: Data

+
 class TextChunkStreamResponse(StreamResponse):
     """
     TextChunkStreamResponse entity
     """
+
     class Data(BaseModel):
         """
         Data entity
@@ -411,6 +442,7 @@ class TextReplaceStreamResponse(StreamResponse):
     """
     TextReplaceStreamResponse entity
     """
+
     class Data(BaseModel):
         """
         Data entity
@@ -473,6 +505,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
     """
     ChatbotAppBlockingResponse entity
     """
+
     class Data(BaseModel):
         """
         Data entity
@@ -492,6 +525,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
     """
     CompletionAppBlockingResponse entity
     """
+
     class Data(BaseModel):
         """
         Data entity
@@ -510,6 +544,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
     """
     WorkflowAppBlockingResponse entity
     """
+
     class Data(BaseModel):
         """
         Data entity
@@ -528,10 +563,12 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
     workflow_run_id: str
     data: Data

+
 class WorkflowIterationState(BaseModel):
     """
     WorkflowIterationState entity
     """
+
     class Data(BaseModel):
         """
         Data entity

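With TTS auto-play enabled, a client reading the SSE stream will see the two new event types interleaved with ordinary message chunks. A small sketch of consuming them, with frame contents invented but shapes following MessageAudioStreamResponse / MessageAudioEndStreamResponse above:

import base64, json

# Hypothetical SSE data lines (values are made up; shapes follow the entities above):
frames = [
    '{"event": "tts_message", "task_id": "t1", "audio": "' + base64.b64encode(b'pcm-bytes').decode() + '"}',
    '{"event": "tts_message_end", "task_id": "t1", "audio": ""}',
]
for line in frames:
    payload = json.loads(line)
    if payload['event'] == 'tts_message':
        chunk = base64.b64decode(payload['audio'])  # feed to an audio player
    elif payload['event'] == 'tts_message_end':
        break
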
api/core/app/features/rate_limiting/__init__.py (1 line, new file)
@@ -0,0 +1 @@
from .rate_limit import RateLimit

api/core/app/features/rate_limiting/rate_limit.py (120 lines, new file)
@@ -0,0 +1,120 @@
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimit:
|
||||
_MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
|
||||
_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
|
||||
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
|
||||
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
|
||||
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
|
||||
_instance_dict = {}
|
||||
|
||||
def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
|
||||
if client_id not in cls._instance_dict:
|
||||
instance = super().__new__(cls)
|
||||
cls._instance_dict[client_id] = instance
|
||||
return cls._instance_dict[client_id]
|
||||
|
||||
def __init__(self, client_id: str, max_active_requests: int):
|
||||
self.max_active_requests = max_active_requests
|
||||
if hasattr(self, 'initialized'):
|
||||
return
|
||||
self.initialized = True
|
||||
self.client_id = client_id
|
||||
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
|
||||
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
|
||||
self.last_recalculate_time = float('-inf')
|
||||
self.flush_cache(use_local_value=True)
|
||||
|
||||
def flush_cache(self, use_local_value=False):
|
||||
self.last_recalculate_time = time.time()
|
||||
# flush max active requests
|
||||
if use_local_value or not redis_client.exists(self.max_active_requests_key):
|
||||
with redis_client.pipeline() as pipe:
|
||||
pipe.set(self.max_active_requests_key, self.max_active_requests)
|
||||
pipe.expire(self.max_active_requests_key, timedelta(days=1))
|
||||
pipe.execute()
|
||||
else:
|
||||
with redis_client.pipeline() as pipe:
|
||||
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
|
||||
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
|
||||
|
||||
# flush max active requests (in-transit request list)
|
||||
if not redis_client.exists(self.active_requests_key):
|
||||
return
|
||||
request_details = redis_client.hgetall(self.active_requests_key)
|
||||
redis_client.expire(self.active_requests_key, timedelta(days=1))
|
||||
timeout_requests = [k for k, v in request_details.items() if
|
||||
time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
|
||||
if timeout_requests:
|
||||
redis_client.hdel(self.active_requests_key, *timeout_requests)
|
||||
|
||||
def enter(self, request_id: Optional[str] = None) -> str:
|
||||
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
|
||||
self.flush_cache()
|
||||
if self.max_active_requests <= 0:
|
||||
return RateLimit._UNLIMITED_REQUEST_ID
|
||||
if not request_id:
|
||||
request_id = RateLimit.gen_request_key()
|
||||
|
||||
active_requests_count = redis_client.hlen(self.active_requests_key)
|
||||
if active_requests_count >= self.max_active_requests:
|
||||
raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
|
||||
"concurrent requests allowed is {}.".format(self.max_active_requests))
|
||||
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
|
||||
return request_id
|
||||
|
||||
def exit(self, request_id: str):
|
||||
if request_id == RateLimit._UNLIMITED_REQUEST_ID:
|
||||
return
|
||||
redis_client.hdel(self.active_requests_key, request_id)
|
||||
|
||||
@staticmethod
|
||||
def gen_request_key() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def generate(self, generator: Union[Generator, callable, dict], request_id: str):
|
||||
if isinstance(generator, dict):
|
||||
return generator
|
||||
else:
|
||||
return RateLimitGenerator(self, generator, request_id)
|
||||
|
||||
|
||||
class RateLimitGenerator:
|
||||
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
|
||||
self.rate_limit = rate_limit
|
||||
if callable(generator):
|
||||
self.generator = generator()
|
||||
else:
|
||||
self.generator = generator
|
||||
self.request_id = request_id
|
||||
self.closed = False
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.closed:
|
||||
raise StopIteration
|
||||
try:
|
||||
return next(self.generator)
|
||||
except StopIteration:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
if not self.closed:
|
||||
self.closed = True
|
||||
self.rate_limit.exit(self.request_id)
|
||||
if self.generator is not None and hasattr(self.generator, 'close'):
|
||||
self.generator.close()
|
||||
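A quick usage sketch of the rate-limiting primitives added above (illustrative only, not part of the diff; it assumes a configured `redis_client`, and `client_id` and `fake_stream` are made-up stand-ins):

```python
# Hypothetical usage of the new RateLimit / RateLimitGenerator pair (not from the PR).
from core.app.features.rate_limiting import RateLimit


def fake_stream():
    # stand-in for an app's streaming response generator
    yield from ("chunk-1", "chunk-2")


limiter = RateLimit(client_id="app-123", max_active_requests=5)  # one instance per client_id
request_id = limiter.enter()  # raises AppInvokeQuotaExceededError once 5 requests are in flight
try:
    # generate() passes dicts through untouched and wraps generators/callables
    for chunk in limiter.generate(fake_stream, request_id):
        print(chunk)
finally:
    limiter.exit(request_id)  # redundant after exhaustion, but Redis hdel makes repeats safe
```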
@@ -4,6 +4,8 @@ import time
from collections.abc import Generator
from typing import Optional, Union, cast

+from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
+from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
@@ -32,6 +34,8 @@ from core.app.entities.task_entities import (
    CompletionAppStreamResponse,
    EasyUITaskState,
    ErrorStreamResponse,
+    MessageAudioEndStreamResponse,
+    MessageAudioStreamResponse,
    MessageEndStreamResponse,
    StreamResponse,
)
@@ -87,6 +91,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
        """
        super().__init__(application_generate_entity, queue_manager, user, stream)
        self._model_config = application_generate_entity.model_conf
+        self._app_config = application_generate_entity.app_config
        self._conversation = conversation
        self._message = message

@@ -102,7 +107,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
        self._conversation_name_generate_thread = None

    def process(
        self,
    ) -> Union[
        ChatbotAppBlockingResponse,
        CompletionAppBlockingResponse,
@@ -123,7 +128,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
            self._application_generate_entity.query
        )

-        generator = self._process_stream_response(
+        generator = self._wrapper_process_stream_response(
            trace_manager=self._application_generate_entity.trace_manager
        )
        if self._stream:
@@ -202,14 +207,64 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
            stream_response=stream_response
        )

+    def _listenAudioMsg(self, publisher, task_id: str):
+        if publisher is None:
+            return None
+        audio_msg: AudioTrunk = publisher.checkAndGetAudio()
+        if audio_msg and audio_msg.status != "finish":
+            # audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
+            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
+        return None
+
+    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
+            Generator[StreamResponse, None, None]:
+
+        tenant_id = self._application_generate_entity.app_config.tenant_id
+        task_id = self._application_generate_entity.task_id
+        publisher = None
+        text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
+        if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
+            publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
+        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
+            while True:
+                audio_response = self._listenAudioMsg(publisher, task_id)
+                if audio_response:
+                    yield audio_response
+                else:
+                    break
+            yield response
+
+        start_listener_time = time.time()
+        # timeout
+        while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
+            if publisher is None:
+                break
+            audio = publisher.checkAndGetAudio()
+            if audio is None:
+                # release cpu
+                # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
+                time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
+                continue
+            if audio.status == "finish":
+                break
+            else:
+                start_listener_time = time.time()
+                yield MessageAudioStreamResponse(audio=audio.audio,
+                                                 task_id=task_id)
+        yield MessageAudioEndStreamResponse(audio='', task_id=task_id)

    def _process_stream_response(
-        self, trace_manager: Optional[TraceQueueManager] = None
+        self,
+        publisher: AppGeneratorTTSPublisher,
+        trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
+            if publisher:
+                publisher.publish(message)
            event = message.event

            if isinstance(event, QueueErrorEvent):
@@ -272,12 +327,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                yield self._ping_stream_response()
            else:
                continue

+        if publisher:
+            publisher.publish(None)
        if self._conversation_name_generate_thread:
            self._conversation_name_generate_thread.join()

    def _save_message(
        self, trace_manager: Optional[TraceQueueManager] = None
    ) -> None:
        """
        Save message.
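The `_wrapper_process_stream_response` change above follows a simple interleaving pattern: before each text event is forwarded, any audio the TTS publisher has already produced is drained first. A self-contained sketch of that pattern (generic stand-ins, not Dify code):

```python
# Generic sketch of the drain-audio-then-yield-text loop used above (illustrative).
import queue


def interleave(text_events, audio_queue: queue.Queue):
    for event in text_events:
        # drain whatever audio is ready before forwarding the next text event
        while True:
            try:
                yield ("audio", audio_queue.get_nowait())
            except queue.Empty:
                break
        yield ("text", event)


q = queue.Queue()
q.put(b"\x00\x01")  # pretend the TTS worker already produced a chunk
for kind, payload in interleave(["hello", "world"], q):
    print(kind, payload)
```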
@@ -31,6 +31,13 @@ class QuotaExceededError(Exception):
    description = "Quota Exceeded"


+class AppInvokeQuotaExceededError(Exception):
+    """
+    Custom exception raised when the quota for an app has been exceeded.
+    """
+    description = "App Invoke Quota Exceeded"
+
+
class ModelCurrentlyNotSupportError(Exception):
    """
    Custom exception raised when the model not support
@@ -1,7 +1,6 @@
-import os
-
import requests

+from configs import dify_config
from models.api_based_extension import APIBasedExtensionPoint


@@ -31,10 +30,10 @@ class APIBasedExtensionRequestor:
        try:
            # proxy support for security
            proxies = None
-            if os.environ.get("SSRF_PROXY_HTTP_URL") and os.environ.get("SSRF_PROXY_HTTPS_URL"):
+            if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
                proxies = {
-                    'http': os.environ.get("SSRF_PROXY_HTTP_URL"),
-                    'https': os.environ.get("SSRF_PROXY_HTTPS_URL"),
+                    'http': dify_config.SSRF_PROXY_HTTP_URL,
+                    'https': dify_config.SSRF_PROXY_HTTPS_URL,
                }

            response = requests.request(
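For reference, the `proxies` mapping built above is passed straight to `requests`, which routes each URL scheme through the given proxy (standard `requests` behavior; the proxy and endpoint URLs below are placeholders):

```python
# Minimal illustration of scheme-keyed proxies in requests (placeholder URLs).
import requests

proxies = {
    'http': 'http://ssrf-proxy:3128',
    'https': 'http://ssrf-proxy:3128',
}
response = requests.request(
    'POST', 'https://extension.example.com/hook',
    json={'point': 'ping'}, proxies=proxies, timeout=10,
)
```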
@@ -186,7 +186,7 @@ class MessageFileParser:
            }

            response = requests.head(url, headers=headers, allow_redirects=True)
-            if response.status_code == 200:
+            if response.status_code in {200, 304}:
                return True, ""
            else:
                return False, "URL does not exist."
@@ -1,5 +1,4 @@
import logging
-import os
import time
from enum import Enum
from threading import Lock
@@ -9,6 +8,7 @@ from httpx import get, post
from pydantic import BaseModel
from yarl import URL

+from configs import dify_config
from core.helper.code_executor.entities import CodeDependency
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
@@ -18,8 +18,8 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
logger = logging.getLogger(__name__)

# Code Executor
-CODE_EXECUTION_ENDPOINT = os.environ.get('CODE_EXECUTION_ENDPOINT', 'http://sandbox:8194')
-CODE_EXECUTION_API_KEY = os.environ.get('CODE_EXECUTION_API_KEY', 'dify-sandbox')
+CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
+CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY

CODE_EXECUTION_TIMEOUT = (10, 60)
@@ -214,6 +214,61 @@ class IndexingRunner:
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()

+    def calculate_tokens(self, tenant_id: str, tokens: int, dataset_id: str = None,
+                         indexing_technique: str = 'economy') -> dict:
+        """
+        Estimate the indexing cost for the document.
+        """
+        embedding_model_instance = None
+        if dataset_id:
+            dataset = Dataset.query.filter_by(
+                id=dataset_id
+            ).first()
+            if not dataset:
+                raise ValueError('Dataset not found.')
+            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
+                if dataset.embedding_model_provider:
+                    embedding_model_instance = self.model_manager.get_model_instance(
+                        tenant_id=tenant_id,
+                        provider=dataset.embedding_model_provider,
+                        model_type=ModelType.TEXT_EMBEDDING,
+                        model=dataset.embedding_model
+                    )
+                else:
+                    embedding_model_instance = self.model_manager.get_default_model_instance(
+                        tenant_id=tenant_id,
+                        model_type=ModelType.TEXT_EMBEDDING,
+                    )
+        else:
+            if indexing_technique == 'high_quality':
+                embedding_model_instance = self.model_manager.get_default_model_instance(
+                    tenant_id=tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                )
+        preview_texts = []
+        total_segments = 0
+        total_price = 0
+        currency = 'USD'
+        if embedding_model_instance:
+            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
+            embedding_price_info = embedding_model_type_instance.get_price(
+                model=embedding_model_instance.model,
+                credentials=embedding_model_instance.credentials,
+                price_type=PriceType.INPUT,
+                tokens=tokens
+            )
+            total_price = '{:f}'.format(embedding_price_info.total_amount)
+            currency = embedding_price_info.currency
+        return {
+            "total_segments": total_segments,
+            "tokens": tokens,
+            "total_price": total_price,
+            "currency": currency,
+            "preview": preview_texts
+        }
+
    def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
                          doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                          indexing_technique: str = 'economy') -> dict:
@@ -730,7 +785,7 @@ class IndexingRunner:
                self._check_document_paused_status(dataset_document.id)

                tokens = 0
-                if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
+                if embedding_model_instance:
                    tokens += sum(
                        embedding_model_instance.get_text_embedding_num_tokens(
                            [document.page_content]
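The price `calculate_tokens` reports comes from the embedding model's `get_price`, which in Dify's pricing schema multiplies the token count by a per-unit price and a unit factor. A worked example with made-up numbers (not real model prices):

```python
# Illustrative pricing arithmetic; mirrors the tokens * unit_price * unit shape.
from decimal import Decimal

tokens = 12_000
unit_price = Decimal('0.02')   # hypothetical price per 1K tokens
unit = Decimal('0.001')        # converts a raw token count into thousands
total_amount = Decimal(tokens) * unit_price * unit
print('{:f} USD'.format(total_amount))  # -> 0.24000 USD
```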
@@ -12,7 +12,8 @@ from core.model_runtime.entities.message_entities import (
    UserPromptMessage,
)
from extensions.ext_database import db
-from models.model import AppMode, Conversation, Message
+from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowRun


class TokenBufferMemory:
@@ -30,33 +31,46 @@ class TokenBufferMemory:
        app_record = self.conversation.app

        # fetch limited messages, and return reversed
-        query = db.session.query(Message).filter(
+        query = db.session.query(
+            Message.id,
+            Message.query,
+            Message.answer,
+            Message.created_at,
+            Message.workflow_run_id
+        ).filter(
            Message.conversation_id == self.conversation.id,
            Message.answer != ''
        ).order_by(Message.created_at.desc())

        if message_limit and message_limit > 0:
-            messages = query.limit(message_limit).all()
+            message_limit = message_limit if message_limit <= 500 else 500
        else:
-            messages = query.all()
+            message_limit = 500
+
+        messages = query.limit(message_limit).all()

        messages = list(reversed(messages))
        message_file_parser = MessageFileParser(
            tenant_id=app_record.tenant_id,
            app_id=app_record.id
        )

        prompt_messages = []
        for message in messages:
-            files = message.message_files
+            files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
            if files:
                file_extra_config = None
                if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
-                    file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+                    file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
                else:
-                    file_extra_config = FileUploadConfigManager.convert(
-                        message.workflow_run.workflow.features_dict,
-                        is_vision=False
-                    )
+                    if message.workflow_run_id:
+                        workflow_run = (db.session.query(WorkflowRun)
+                                        .filter(WorkflowRun.id == message.workflow_run_id).first())
+
+                        if workflow_run:
+                            file_extra_config = FileUploadConfigManager.convert(
+                                workflow_run.workflow.features_dict,
+                                is_vision=False
+                            )

                if file_extra_config:
                    file_objs = message_file_parser.transform_message_files(
@@ -136,4 +150,4 @@ class TokenBufferMemory:
            message = f"{role}: {m.content}"
            string_messages.append(message)

        return "\n".join(string_messages)
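The memory change above swaps a full-entity query for a column-only query, caps the window at 500 messages, and loads `MessageFile` rows on demand. The capping rule in isolation (plain Python, illustrative):

```python
# Standalone illustration of the 500-message cap introduced above.
def effective_limit(message_limit):
    if message_limit and message_limit > 0:
        return message_limit if message_limit <= 500 else 500
    return 500

assert effective_limit(10) == 10
assert effective_limit(9999) == 500
assert effective_limit(None) == 500
assert effective_limit(0) == 500
```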
@@ -264,7 +264,7 @@ class ModelInstance:
            user=user
        )

-    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
+    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \
            -> str:
        """
        Invoke large language tts model
@@ -287,8 +287,7 @@ class ModelInstance:
            content_text=content_text,
            user=user,
            tenant_id=tenant_id,
-            voice=voice,
-            streaming=streaming
+            voice=voice
        )

    def _round_robin_invoke(self, function: Callable, *args, **kwargs):
@@ -1,4 +1,6 @@
import hashlib
+import logging
+import re
import subprocess
import uuid
from abc import abstractmethod
@@ -10,7 +12,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.ai_model import AIModel


+logger = logging.getLogger(__name__)
class TTSModel(AIModel):
    """
    Model class for ttstext model.
@@ -20,7 +22,7 @@ class TTSModel(AIModel):
    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

-    def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
+    def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
               user: Optional[str] = None):
        """
        Invoke large language model
@@ -35,14 +37,15 @@ class TTSModel(AIModel):
        :return: translated audio file
        """
        try:
+            logger.info(f"Invoke TTS model: {model} , invoke content : {content_text}")
            self._is_ffmpeg_installed()
-            return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming,
+            return self._invoke(model=model, credentials=credentials, user=user,
                                content_text=content_text, voice=voice, tenant_id=tenant_id)
        except Exception as e:
            raise self._transform_invoke_error(e)

    @abstractmethod
-    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
+    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
                user: Optional[str] = None):
        """
        Invoke large language model
@@ -123,26 +126,26 @@ class TTSModel(AIModel):
        return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]

    @staticmethod
-    def _split_text_into_sentences(text: str, limit: int, delimiters=None):
-        if delimiters is None:
-            delimiters = set('。!?;\n')
-
-        buf = []
-        word_count = 0
-        for char in text:
-            buf.append(char)
-            if char in delimiters:
-                if word_count >= limit:
-                    yield ''.join(buf)
-                    buf = []
-                    word_count = 0
-                else:
-                    word_count += 1
-            else:
-                word_count += 1
-
-        if buf:
-            yield ''.join(buf)
+    def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'):
+        match = re.compile(pattern)
+        tx = match.finditer(org_text)
+        start = 0
+        result = []
+        one_sentence = ''
+        for i in tx:
+            end = i.regs[0][1]
+            tmp = org_text[start:end]
+            if len(one_sentence + tmp) > max_length:
+                result.append(one_sentence)
+                one_sentence = ''
+            one_sentence += tmp
+            start = end
+        last_sens = org_text[start:]
+        if last_sens:
+            one_sentence += last_sens
+        if one_sentence != '':
+            result.append(one_sentence)
+        return result

    @staticmethod
    def _is_ffmpeg_installed():
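The new regex-based `_split_text_into_sentences` accumulates sentence-terminated spans until adding the next one would exceed `max_length`. It can be run standalone (logic copied from the diff; trivial inputs):

```python
# Standalone demonstration of the regex-based splitter above.
import re


def split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'):
    match = re.compile(pattern)
    tx = match.finditer(org_text)
    start = 0
    result = []
    one_sentence = ''
    for i in tx:
        end = i.regs[0][1]
        tmp = org_text[start:end]
        if len(one_sentence + tmp) > max_length:
            result.append(one_sentence)
            one_sentence = ''
        one_sentence += tmp
        start = end
    last_sens = org_text[start:]
    if last_sens:
        one_sentence += last_sens
    if one_sentence != '':
        result.append(one_sentence)
    return result


print(split_text_into_sentences("One. Two. Three.", max_length=8))
# -> ['One.', ' Two.', ' Three.']
```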
@@ -33,3 +33,4 @@
- deepseek
- hunyuan
- siliconflow
+- perfxcloud
@@ -4,7 +4,7 @@ from functools import reduce
from io import BytesIO
from typing import Optional

-from flask import Response, stream_with_context
+from flask import Response
from openai import AzureOpenAI
from pydub import AudioSegment

@@ -14,7 +14,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
-from extensions.ext_storage import storage


class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
@@ -23,7 +22,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
    """

    def _invoke(self, model: str, tenant_id: str, credentials: dict,
-                content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
+                content_text: str, voice: str, user: Optional[str] = None) -> any:
        """
        _invoke text2speech model

@@ -32,30 +31,23 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
-        :param streaming: output is streaming
        :param user: unique user id
        :return: text translated to audio file
        """
-        audio_type = self._get_model_audio_type(model, credentials)
        if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
            voice = self._get_model_default_voice(model, credentials)
-        if streaming:
-            return Response(stream_with_context(self._tts_invoke_streaming(model=model,
-                                                                           credentials=credentials,
-                                                                           content_text=content_text,
-                                                                           tenant_id=tenant_id,
-                                                                           voice=voice)),
-                            status=200, mimetype=f'audio/{audio_type}')
-        else:
-            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
-
-    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
+        return self._tts_invoke_streaming(model=model,
+                                          credentials=credentials,
+                                          content_text=content_text,
+                                          voice=voice)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        validate credentials text2speech model

        :param model: model name
        :param credentials: model credentials
-        :param user: unique user id
        :return: text translated to audio file
        """
        try:
@@ -82,7 +74,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
        word_limit = self._get_model_word_limit(model, credentials)
        max_workers = self._get_model_workers_limit(model, credentials)
        try:
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
+            sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
            audio_bytes_list = []

            # Create a thread pool and map the function to the list of sentences
@@ -107,34 +99,37 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

-    # Todo: To improve the streaming function
-    def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
+    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
                              voice: str) -> any:
        """
        _tts_invoke_streaming text2speech model

        :param model: model name
-        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :return: text translated to audio file
        """
-        # transform credentials to kwargs for model instance
-        credentials_kwargs = self._to_credential_kwargs(credentials)
-        if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
-            voice = self._get_model_default_voice(model, credentials)
-        word_limit = self._get_model_word_limit(model, credentials)
-        audio_type = self._get_model_audio_type(model, credentials)
-        tts_file_id = self._get_file_name(content_text)
-        file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
        try:
            # doc: https://platform.openai.com/docs/guides/text-to-speech
+            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = AzureOpenAI(**credentials_kwargs)
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
-            for sentence in sentences:
-                response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
-                # response.stream_to_file(file_path)
-                storage.save(file_path, response.read())
+            # max font is 4096,there is 3500 limit for each request
+            max_length = 3500
+            if len(content_text) > max_length:
+                sentences = self._split_text_into_sentences(content_text, max_length=max_length)
+                executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
+                futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
+                                           response_format="mp3",
+                                           input=sentences[i], voice=voice) for i in range(len(sentences))]
+                for index, future in enumerate(futures):
+                    yield from future.result().__enter__().iter_bytes(1024)
+
+            else:
+                response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
                                                                              response_format="mp3",
                                                                              input=content_text.strip())
+
+                yield from response.__enter__().iter_bytes(1024)
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

@@ -162,7 +157,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):


    @staticmethod
-    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None:
        for ai_model_entity in TTS_BASE_MODELS:
            if ai_model_entity.base_model_name == base_model_name:
                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
@@ -170,5 +165,4 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
                ai_model_entity_copy.entity.label.en_US = model
                ai_model_entity_copy.entity.label.zh_Hans = model
                return ai_model_entity_copy

        return None
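Both TTS providers now yield raw MP3 bytes from the SDK's streaming response. Consuming that API directly looks roughly like this (a sketch against the OpenAI Python SDK; model, voice, and text are placeholders, and the `with` form is the managed alternative to the `__enter__()` call the diff uses):

```python
# Sketch: reading streamed TTS audio in 1 KiB chunks (placeholder model/voice/text).
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
with client.audio.speech.with_streaming_response.create(
    model="tts-1", voice="alloy", response_format="mp3", input="Hello there.",
) as response:
    with open("hello.mp3", "wb") as f:
        for chunk in response.iter_bytes(1024):
            f.write(chunk)
```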
@@ -5,6 +5,8 @@ model_type: llm
features:
  - agent-thought
  - vision
+  - tool-call
+  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
@@ -5,6 +5,8 @@ model_type: llm
features:
  - agent-thought
  - vision
+  - tool-call
+  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
@@ -5,6 +5,8 @@ model_type: llm
features:
  - agent-thought
  - vision
+  - tool-call
+  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
@@ -5,6 +5,8 @@ model_type: llm
features:
  - agent-thought
  - vision
+  - tool-call
+  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
@@ -29,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
+    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
@@ -68,7 +69,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        # TODO: consolidate different invocation methods for models based on base model capabilities
        # invoke anthropic models via boto3 client
        if "anthropic" in model:
-            return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+            return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
        # invoke Cohere models via boto3 client
        if "cohere.command-r" in model:
            return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
@@ -151,7 +152,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):


    def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+                            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
        """
        Invoke Anthropic large language model

@@ -171,23 +172,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
        inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)

+        parameters = {
+            'modelId': model,
+            'messages': prompt_message_dicts,
+            'inferenceConfig': inference_config,
+            'additionalModelRequestFields': additional_model_fields,
+        }
+
+        if system and len(system) > 0:
+            parameters['system'] = system
+
+        if tools:
+            parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
+
        if stream:
-            response = bedrock_client.converse_stream(
-                modelId=model,
-                messages=prompt_message_dicts,
-                system=system,
-                inferenceConfig=inference_config,
-                additionalModelRequestFields=additional_model_fields
-            )
+            response = bedrock_client.converse_stream(**parameters)
            return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
        else:
-            response = bedrock_client.converse(
-                modelId=model,
-                messages=prompt_message_dicts,
-                system=system,
-                inferenceConfig=inference_config,
-                additionalModelRequestFields=additional_model_fields
-            )
+            response = bedrock_client.converse(**parameters)
            return self._handle_converse_response(model, credentials, response, prompt_messages)

    def _handle_converse_response(self, model: str, credentials: dict, response: dict,
@@ -246,12 +248,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
            output_tokens = 0
            finish_reason = None
            index = 0
+            tool_calls: list[AssistantPromptMessage.ToolCall] = []
+            tool_use = {}

            for chunk in response['stream']:
                if 'messageStart' in chunk:
                    return_model = model
                elif 'messageStop' in chunk:
                    finish_reason = chunk['messageStop']['stopReason']
+                elif 'contentBlockStart' in chunk:
+                    tool = chunk['contentBlockStart']['start']['toolUse']
+                    tool_use['toolUseId'] = tool['toolUseId']
+                    tool_use['name'] = tool['name']
                elif 'metadata' in chunk:
                    input_tokens = chunk['metadata']['usage']['inputTokens']
                    output_tokens = chunk['metadata']['usage']['outputTokens']
@@ -260,29 +268,49 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                        model=return_model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
-                            index=index + 1,
+                            index=index,
                            message=AssistantPromptMessage(
-                                content=''
+                                content='',
+                                tool_calls=tool_calls
                            ),
                            finish_reason=finish_reason,
                            usage=usage
                        )
                    )
                elif 'contentBlockDelta' in chunk:
-                    chunk_text = chunk['contentBlockDelta']['delta']['text'] if chunk['contentBlockDelta']['delta']['text'] else ''
-                    full_assistant_content += chunk_text
-                    assistant_prompt_message = AssistantPromptMessage(
-                        content=chunk_text if chunk_text else '',
-                    )
-                    index = chunk['contentBlockDelta']['contentBlockIndex']
-                    yield LLMResultChunk(
-                        model=model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=index,
-                            message=assistant_prompt_message,
-                        )
-                    )
+                    delta = chunk['contentBlockDelta']['delta']
+                    if 'text' in delta:
+                        chunk_text = delta['text'] if delta['text'] else ''
+                        full_assistant_content += chunk_text
+                        assistant_prompt_message = AssistantPromptMessage(
+                            content=chunk_text if chunk_text else '',
+                        )
+                        index = chunk['contentBlockDelta']['contentBlockIndex']
+                        yield LLMResultChunk(
+                            model=model,
+                            prompt_messages=prompt_messages,
+                            delta=LLMResultChunkDelta(
+                                index=index+1,
+                                message=assistant_prompt_message,
+                            )
+                        )
+                    elif 'toolUse' in delta:
+                        if 'input' not in tool_use:
+                            tool_use['input'] = ''
+                        tool_use['input'] += delta['toolUse']['input']
+                elif 'contentBlockStop' in chunk:
+                    if 'input' in tool_use:
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=tool_use['toolUseId'],
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=tool_use['name'],
+                                arguments=tool_use['input']
+                            )
+                        )
+                        tool_calls.append(tool_call)
+                        tool_use = {}

        except Exception as ex:
            raise InvokeError(str(ex))

@@ -312,16 +340,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        """

        system = []
-        first_loop = True
        for message in prompt_messages:
            if isinstance(message, SystemPromptMessage):
                message.content=message.content.strip()
-                if first_loop:
-                    system=message.content
-                    first_loop=False
-                else:
-                    system+="\n"
-                    system+=message.content
+                system.append({"text": message.content})

        prompt_message_dicts = []
        for message in prompt_messages:
@@ -330,6 +352,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):

        return system, prompt_message_dicts

+    def _convert_converse_tool_config(self, tools: Optional[list[PromptMessageTool]] = None) -> dict:
+        tool_config = {}
+        configs = []
+        if tools:
+            for tool in tools:
+                configs.append(
+                    {
+                        "toolSpec": {
+                            "name": tool.name,
+                            "description": tool.description,
+                            "inputSchema": {
+                                "json": tool.parameters
+                            }
+                        }
+                    }
+                )
+            tool_config["tools"] = configs
+        return tool_config
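For a single registered tool, `_convert_converse_tool_config` produces a `toolConfig` payload shaped like the following (field names follow the Bedrock Converse API as used in the diff; the weather tool itself is a made-up example):

```python
# Example output of _convert_converse_tool_config for one hypothetical tool.
tool_config = {
    "tools": [
        {
            "toolSpec": {
                "name": "get_weather",
                "description": "Look up the current weather for a city.",
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    }
                },
            }
        }
    ]
}
```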
    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
        """
        Convert PromptMessage to dict
@@ -379,10 +420,32 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
            message_dict = {"role": "user", "content": sub_messages}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "assistant", "content": [{'text': message.content}]}
+            if message.tool_calls:
+                message_dict = {
+                    "role": "assistant", "content":[{
+                        "toolUse": {
+                            "toolUseId": message.tool_calls[0].id,
+                            "name": message.tool_calls[0].function.name,
+                            "input": json.loads(message.tool_calls[0].function.arguments)
+                        }
+                    }]
+                }
+            else:
+                message_dict = {"role": "assistant", "content": [{'text': message.content}]}
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = [{'text': message.content}]
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {
+                "role": "user",
+                "content": [{
+                    "toolResult": {
+                        "toolUseId": message.tool_call_id,
+                        "content": [{"json": {"text": message.content}}]
+                    }
+                }]
+            }
        else:
            raise ValueError(f"Got unknown type {message}")

@@ -401,11 +464,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        """
        prefix = model.split('.')[0]
+        model_name = model.split('.')[1]

        if isinstance(prompt_messages, str):
            prompt = prompt_messages
        else:
            prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)

        return self._get_num_tokens_by_gpt2(prompt)

    def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -489,11 +554,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        content = message.content

        if isinstance(message, UserPromptMessage):
-            message_text = f"{human_prompt_prefix} {content} {human_prompt_postfix}"
+            body = content
+            if (isinstance(content, list)):
+                body = "".join([c.data for c in content if c.type == PromptMessageContentType.TEXT])
+            message_text = f"{human_prompt_prefix} {body} {human_prompt_postfix}"
        elif isinstance(message, AssistantPromptMessage):
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemPromptMessage):
            message_text = content
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt_prefix} {message.content}"
        else:
            raise ValueError(f"Got unknown type {message}")
@@ -21,7 +21,7 @@ model_properties:
    - mode: 'shimmer'
      name: 'Shimmer'
      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
-  word_limit: 120
+  word_limit: 3500
  audio_type: 'mp3'
  max_workers: 5
pricing:
@@ -21,7 +21,7 @@ model_properties:
    - mode: 'shimmer'
      name: 'Shimmer'
      language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
-  word_limit: 120
+  word_limit: 3500
  audio_type: 'mp3'
  max_workers: 5
pricing:
@@ -3,7 +3,7 @@ from functools import reduce
from io import BytesIO
from typing import Optional

-from flask import Response, stream_with_context
+from flask import Response
from openai import OpenAI
from pydub import AudioSegment

@@ -11,7 +11,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.openai._common import _CommonOpenAI
-from extensions.ext_storage import storage


class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
@@ -20,7 +19,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
    """

    def _invoke(self, model: str, tenant_id: str, credentials: dict,
-                content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
+                content_text: str, voice: str, user: Optional[str] = None) -> any:
        """
        _invoke text2speech model

@@ -29,22 +28,17 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
-        :param streaming: output is streaming
        :param user: unique user id
        :return: text translated to audio file
        """
-        audio_type = self._get_model_audio_type(model, credentials)
-
        if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
            voice = self._get_model_default_voice(model, credentials)
-        if streaming:
-            return Response(stream_with_context(self._tts_invoke_streaming(model=model,
-                                                                           credentials=credentials,
-                                                                           content_text=content_text,
-                                                                           tenant_id=tenant_id,
-                                                                           voice=voice)),
-                            status=200, mimetype=f'audio/{audio_type}')
-        else:
-            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
+        # if streaming:
+        return self._tts_invoke_streaming(model=model,
+                                          credentials=credentials,
+                                          content_text=content_text,
+                                          voice=voice)

    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
        """
@@ -79,7 +73,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
        word_limit = self._get_model_word_limit(model, credentials)
        max_workers = self._get_model_workers_limit(model, credentials)
        try:
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
+            sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
            audio_bytes_list = []

            # Create a thread pool and map the function to the list of sentences
@@ -104,34 +98,40 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

-    # Todo: To improve the streaming function
-    def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
+
+    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
                              voice: str) -> any:
        """
        _tts_invoke_streaming text2speech model

        :param model: model name
-        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :return: text translated to audio file
        """
-        # transform credentials to kwargs for model instance
-        credentials_kwargs = self._to_credential_kwargs(credentials)
-        if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
-            voice = self._get_model_default_voice(model, credentials)
-        word_limit = self._get_model_word_limit(model, credentials)
-        audio_type = self._get_model_audio_type(model, credentials)
-        tts_file_id = self._get_file_name(content_text)
-        file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
        try:
            # doc: https://platform.openai.com/docs/guides/text-to-speech
+            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = OpenAI(**credentials_kwargs)
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
-            for sentence in sentences:
-                response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
-                # response.stream_to_file(file_path)
-                storage.save(file_path, response.read())
+            if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
+                voice = self._get_model_default_voice(model, credentials)
+            word_limit = self._get_model_word_limit(model, credentials)
+            if len(content_text) > word_limit:
+                sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
+                executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
+                futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
+                                           response_format="mp3",
+                                           input=sentences[i], voice=voice) for i in range(len(sentences))]
+                for index, future in enumerate(futures):
+                    yield from future.result().__enter__().iter_bytes(1024)
+
+            else:
+                response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
+                                                                              response_format="mp3",
+                                                                              input=content_text.strip())
+
+                yield from response.__enter__().iter_bytes(1024)
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))
File diff suppressed because one or more lines are too long
After: image, 21 KiB
File diff suppressed because one or more lines are too long
After: image, 48 KiB
@@ -0,0 +1,61 @@
model: Qwen-14B-Chat-Int4
label:
  en_US: Qwen-14B-Chat-Int4
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 600
    min: 1
    max: 1248
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
  input: '0.000'
  output: '0.000'
  unit: '0.000'
  currency: RMB
@@ -0,0 +1,61 @@
model: Qwen1.5-110B-Chat-GPTQ-Int4
label:
  en_US: Qwen1.5-110B-Chat-GPTQ-Int4
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 128
    min: 1
    max: 256
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
  input: '0.000'
  output: '0.000'
  unit: '0.000'
  currency: RMB
@@ -0,0 +1,61 @@
model: Qwen1.5-72B-Chat-GPTQ-Int4
label:
  en_US: Qwen1.5-72B-Chat-GPTQ-Int4
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 600
    min: 1
    max: 2000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
  input: '0.000'
  output: '0.000'
  unit: '0.000'
  currency: RMB
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user