Compare commits

...

88 Commits

Author SHA1 Message Date
jyong
ea5e8ee7cc calculate tokens 2024-07-15 18:39:25 +08:00
Zhuo Qiu
63e34e5227 feat: support MyScale vector database (#6092) 2024-07-11 15:21:59 +08:00
crazywoola
c606295ea6 fix: data not updated (#6161) 2024-07-11 11:09:14 +08:00
非法操作
27d72e30ad fix: can add a custom tool without a name (#6172) 2024-07-11 11:08:50 +08:00
非法操作
5660878f7b chore: update the tool's doc (#6167) 2024-07-11 11:02:58 +08:00
Nam Vu
12e55b2cac chore: update i18n for #6069 (#6163) 2024-07-11 10:02:35 +08:00
Nam Vu
97e094dfd8 chore: update i18n for #5943 (#6162) 2024-07-10 23:28:02 +08:00
liuzhenghua
9622fbb62f feat: app rate limit (#5844)
Co-authored-by: liuzhenghua-jk <liuzhenghua-jk@360shuke.com>
Co-authored-by: takatost <takatost@gmail.com>
2024-07-10 21:31:35 +08:00
crazywoola
cc8dc6d35e Revert "chore: update the tool's doc" (#6153) 2024-07-10 19:57:12 +08:00
Su Yang
215661ef91 feat: add PerfXCloud, Qwen series #6116 (#6117) 2024-07-10 18:26:10 +08:00
Joe
5a3e09518c feat: add if elif (#6094) 2024-07-10 18:22:51 +08:00
zxhlyh
ebba124c5c feat: workflow if-else support elif (#6072) 2024-07-10 18:20:13 +08:00
Joel
a62325ac87 feat: add until className defines (#6141) 2024-07-10 15:34:56 +08:00
非法操作
1d2ab2126c chore: update the tool's doc (#6122) 2024-07-10 12:42:34 +08:00
天魂
b07dea836c feat(embed): enhance config and add custom styling support (#5781) 2024-07-10 09:27:24 +08:00
Bowen Liang
f9d00e0498 chore: use poetry for linter tools installation and bump Ruff from 0.4 to 0.5 (#6081) 2024-07-09 23:06:23 +08:00
dependabot[bot]
757ceda063 chore(deps): bump braces from 3.0.2 to 3.0.3 in /web (#6098)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-07-09 23:05:12 +08:00
sino
d27e3ab99d chore: remove unresolved reference (#6110) 2024-07-09 23:04:44 +08:00
Joe
ce930f19b9 fix dataset operator (#6064)
Co-authored-by: JzoNg <jzongcode@gmail.com>
2024-07-09 17:47:54 +08:00
Joel
3b14939d66 Chore: new tailwind vars (#6100) 2024-07-09 16:37:59 +08:00
AIxGEEK
279caf033c fix: node-title-is-overflow-in-checklist (#5870) 2024-07-09 15:12:41 +08:00
Joel
eff280f3e7 feat: tailwind related improvement (#6085) 2024-07-09 15:05:40 +08:00
8bitpd
7c70eb87bc feat: support AnalyticDB vector store (#5586)
Co-authored-by: xiaozeyu <xiaozeyu.xzy@alibaba-inc.com>
2024-07-09 13:32:04 +08:00
chenxu9741
6ef401a9f0 feat: add tts-streaming config and feature (#5492) 2024-07-09 11:33:58 +08:00
非法操作
b29a36f461 Feat: add index bar to select tool panel of workflow (#6066)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-07-09 09:43:34 +08:00
takatost
17f22347ae bump to 0.6.13 (#6078) 2024-07-08 23:23:07 +08:00
AIxGEEK
22aaf8960b fix: Inconsistency Between Actual and Debug Input Variables (#6055) 2024-07-08 22:27:55 +08:00
Whitewater
0046ef7707 refactor: revamp picker block (#4227)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-07-08 21:56:09 +08:00
takatost
68b1d063f7 chore: remove tsne unused code (#6077) 2024-07-08 21:25:19 +08:00
AIxGEEK
5e6c3001bd fix: relative in overflow div (#5998) 2024-07-08 18:35:12 +08:00
takatost
7ed4e963aa chore(action): move docker login above Set up QEMU in build-push action workflow (#6073) 2024-07-08 18:32:29 +08:00
Chenhe Gu
001d868cbd remove clunky welcome message (#6069) 2024-07-08 18:17:41 +08:00
Xiao Ley
6610b4cee5 feat: add request_params field to jina_reader tool (#5610) 2024-07-08 18:11:50 +08:00
Jyong
cbbe28f40d fix azure stream download (#6063) 2024-07-08 17:13:16 +08:00
Joel
603187393a chore: hide tracing introduce detail (#6049) 2024-07-08 10:50:18 +08:00
Waffle
411e938e3b Address the issue of the absence of poetry in the development container. (#6036)
Co-authored-by: ox01024@163.com <Waffle>
2024-07-08 09:44:07 +08:00
75py
610da4f662 Fix authorization header validation to handle bearer types correctly - "authorization config header is required" error (#6040) 2024-07-08 00:09:59 +08:00
crazywoola
3ec80f9dda Fix/6034 get random order of categories in explore and workflow is missing in zh hant (#6043) 2024-07-07 17:06:47 +08:00
Mab
91c5818236 Modify slack webhook url validation to allow workflow (#6041) (#6042)
Co-authored-by: Shunsuke Mabuchi <mabuchs@amazon.co.jp>
2024-07-07 14:09:20 +08:00
-LAN-
c436454cd4 fix(configs): Update pydantic settings in config files (#6023) 2024-07-07 12:18:15 +08:00
Yeuoly
a877d4831d Fix/incorrect parameter extractor memory (#6038) 2024-07-07 12:17:34 +08:00
takatost
d522308a29 chore: optimize memory fetch performance (#6039) 2024-07-07 08:54:24 +08:00
sino
85744b72e5 feat: support moonshot and glm base models for volcengine provider (#6029) 2024-07-07 01:17:33 +08:00
Cherilyn Buren
f0b7051e1a Optimize db config (#6011) 2024-07-07 01:06:51 +08:00
Masashi Tomooka
3b23d6764f fix: token count includes base64 string of input images (#5868) 2024-07-06 16:53:32 +08:00
Bowen Liang
9b7c74a5d9 chore: skip pip upgrade preparation in api dockerfile (#5999) 2024-07-06 14:17:34 +08:00
-LAN-
4d105d7bd7 feat(*): Switch to dify_config. (#6025) 2024-07-06 12:05:13 +08:00
非法操作
eee779a923 fix: the input field of tool panel not worked as expected (#6003) 2024-07-06 09:54:30 +08:00
ahasasjeb
ab847c81fa Add 2 firecrawl tools: Scrape and Search (#6016)
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-07-06 09:45:39 +08:00
-LAN-
b217ee414f test(test_rerank): Remove duplicate test cases. (#6024) 2024-07-06 09:44:50 +08:00
takatost
23dc6edb99 chore: optimize memory messages fetch count limit (#6021) 2024-07-06 03:25:38 +08:00
takatost
79df8825c8 Revert "feat: knowledge admin role" (#6018) 2024-07-05 21:31:34 +08:00
K8sCat
71c50b7e20 feat: add Llama 3 and Mixtral model options to ddgo_ai.yaml (#5979)
Signed-off-by: K8sCat <k8scat@gmail.com>
2024-07-05 21:11:15 +08:00
opriuwohg
af98fd29bf fix: add status_code 304 (#6000)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-07-05 21:10:33 +08:00
crazywoola
cddea83e65 6014 i18n add support for spanish (#6017) 2024-07-05 21:05:33 +08:00
Jinq Qian
9f16739518 [Feature] Support loading for mermaid. (#6004)
Co-authored-by: 靖谦 <jingqian@kaiwu.cloud>
2024-07-05 21:01:50 +08:00
Joe
3f0da88ff7 fix: update workflow trace query (#6010) 2024-07-05 18:37:26 +08:00
ahasasjeb
cc63af8e72 Removed firecrawl-py, fixed and improved firecrawl tool (#5896)
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-07-05 18:04:51 +08:00
非法操作
bf2268b0af fix API tool's schema not support array (#6006) 2024-07-05 17:11:59 +08:00
xielong
00b4cc3cd4 feat: implement forgot password feature (#5534) 2024-07-05 13:38:51 +08:00
Aurelius Huang
f546db5437 fix: document truncation and loss in notion document sync (#5631)
Co-authored-by: Aurelius Huang <cm.huang@aftership.com>
2024-07-05 11:48:17 +08:00
orangeclk
f8aaa57f31 feat: add retry mechanism for zhipuai (#5926) 2024-07-05 10:49:18 +08:00
jianglin1008
cabcf94be3 fix: TENCENT_VECTOR_DB_REPLICAS can be set to 0 (#5968)
Co-authored-by: jianglin <jianglin@wangxiaobao.com>
2024-07-05 08:32:28 +08:00
legao
2d6624cf9e typo: Update README.md (#5987) 2024-07-04 22:50:27 +08:00
-LAN-
02982df0d4 fix: Fix some type error in http executor. (#5915) 2024-07-04 19:34:37 +08:00
-LAN-
421a24c38d refactor(api/app.py): Simplify the retrieval of debug settings. (#5916) 2024-07-04 18:19:06 +08:00
-LAN-
d7f75d17cc Chore/remove-unused-code (#5917) 2024-07-04 18:18:26 +08:00
Joe
5d9ad430af feat: knowledge admin role (#5965)
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: jyong <718720800@qq.com>
2024-07-04 16:21:40 +08:00
非法操作
46eca01fa3 fix: no json output vars in front-page tool (#5943) 2024-07-04 16:15:56 +08:00
AIxGEEK
4d9c22bfc6 refactor: optimize-the-performance-of-var-reference-picker (#5918) 2024-07-04 15:33:36 +08:00
Joel
52e59cf4df chore: enhance firecrawl user experience (#5958) 2024-07-04 15:26:38 +08:00
Joe
688b8fe114 fix: langfuse logical operator error (#5948) 2024-07-04 13:47:15 +08:00
longzhihun
aecdfa2d5c feat: add claude3 function calling (#5889) 2024-07-03 22:21:02 +08:00
-LAN-
cb8feb732f refactor: Create a dify_config with Pydantic. (#5938) 2024-07-03 21:09:23 +08:00
orangeclk
c490bdfbf9 fix: zhipuai pytest correction (#5934) 2024-07-03 19:19:33 +08:00
-LAN-
e7494d632c docs(api/core/tools/docs/en_US/tool_scale_out.md): Format by markdownlint. (#5903) 2024-07-03 13:13:38 +08:00
takatost
e3006f98c9 chore: remove dify SaaS URL in default configs (#5888) 2024-07-02 22:49:18 +08:00
crazywoola
b34baf1e3a feat: pr template (#5886) 2024-07-02 22:38:17 +08:00
quicksand
372dc7ac1a fix bug: TencentVectorDBConfig add TENCENT_VECTOR_DB_DATABASE (#5879) 2024-07-02 21:31:14 +08:00
-LAN-
66a62e6c13 refactor(api/core/app/apps/base_app_generator.py): improve input validation and sanitization in BaseAppGenerator (#5866) 2024-07-02 18:58:07 +08:00
Joel
04c0a9ad45 chore: click area that trigger showing tracing config is too large (#5878) 2024-07-02 18:28:41 +08:00
Jyong
0944ca9d91 Fix/remove tsne position test (#5858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2024-07-02 17:57:42 +08:00
Chenhe Gu
d468f8b75c Fix/docker env namings (#5857)
Co-authored-by: dahuahua <38651850@qq.com>
2024-07-02 16:14:34 +08:00
Masashi Tomooka
6e256507d3 doc: docker-compose won't start due to wrong README (#5859) 2024-07-02 16:10:22 +08:00
Joel
a33ce09e6e fix: retrieval setting document link 404 (#5861) 2024-07-02 16:02:17 +08:00
quicksand
9f973bb703 chore: remove .env.example duplicate key (#5853) 2024-07-02 15:10:18 +08:00
Joel
6d0a605c5f fix: not show opening question if the opening message is empty (#5856) 2024-07-02 15:04:02 +08:00
Xingyu
a948bf6ee8 fix: react.js error 185 maximum update depth exceeded in streaming responses during conversation (#5849)
Co-authored-by: 林星宇 <xy@udxx.cn>
2024-07-02 14:26:56 +08:00
680 changed files with 16559 additions and 5589 deletions


@@ -1,6 +1,7 @@
#!/bin/bash
cd web && npm install
pipx install poetry
echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc


@@ -1,13 +1,21 @@
# Checklist:
> [!IMPORTANT]
> Please review the checklist below before submitting your pull request.
- [ ] Please open an issue before creating a PR or link to an existing issue
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
# Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Describe the big picture of your changes here to communicate to the maintainers why we should accept this pull request. If it fixes a bug or resolves a feature request, be sure to link to that issue. Close issue syntax: `Fixes #<issue number>`, see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) for more details.
Fixes # (issue)
Fixes
## Type of Change
Please delete options that are not relevant.
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
@@ -15,18 +23,12 @@ Please delete options that are not relevant.
- [ ] Improvement, including but not limited to code refactoring, performance optimization, and UI/UX improvement
- [ ] Dependency upgrade
# How Has This Been Tested?
# Testing Instructions
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration
- [ ] TODO
- [ ] Test A
- [ ] Test B
# Suggested Checklist:
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] My changes generate no new warnings
- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
- [ ] `optional` I have made corresponding changes to the documentation
- [ ] `optional` I have added tests that prove my fix is effective or that my feature works
- [ ] `optional` New and existing unit tests pass locally with my changes


@@ -75,7 +75,7 @@ jobs:
- name: Run Workflow
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale)
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
@@ -89,5 +89,6 @@ jobs:
pgvecto-rs
pgvector
chroma
myscale
- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh


@@ -48,18 +48,18 @@ jobs:
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Extract metadata for Docker
id: meta
uses: docker/metadata-action@v5

.gitignore

@@ -174,3 +174,5 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json
pyrightconfig.json
.idea/

.vscode/launch.json

@@ -13,7 +13,6 @@
"jinja": true,
"env": {
"FLASK_APP": "app.py",
"FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},
"args": [


@@ -83,7 +83,7 @@ OCI_REGION=your-region
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
VECTOR_STORE=weaviate
# Weaviate configuration
@@ -106,6 +106,14 @@ MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false
# MyScale configuration
MYSCALE_HOST=127.0.0.1
MYSCALE_PORT=8123
MYSCALE_USER=default
MYSCALE_PASSWORD=
MYSCALE_DATABASE=default
MYSCALE_FTS_PARAMS=
# Relyt configuration
RELYT_HOST=127.0.0.1
RELYT_PORT=5432
@@ -151,6 +159,16 @@ CHROMA_DATABASE=default_database
CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
CHROMA_AUTH_CREDENTIALS=difyai123456
# AnalyticDB configuration
ANALYTICDB_KEY_ID=your-ak
ANALYTICDB_KEY_SECRET=your-sk
ANALYTICDB_REGION_ID=cn-hangzhou
ANALYTICDB_INSTANCE_ID=gp-ab123456
ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
# OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1
OPENSEARCH_PORT=9200
@@ -237,4 +255,4 @@ WORKFLOW_CALL_MAX_DEPTH=5
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0


@@ -5,8 +5,7 @@ WORKDIR /app/api
# Install Poetry
ENV POETRY_VERSION=1.8.3
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir --upgrade poetry==${POETRY_VERSION}
RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
# Configure Poetry
ENV POETRY_CACHE_DIR=/tmp/poetry_cache


@@ -11,6 +11,7 @@
```bash
cd ../docker
cp middleware.env.example middleware.env
docker compose -f docker-compose.middleware.yaml -p dify up -d
cd ../api
```


@@ -1,8 +1,8 @@
import os
from configs.app_config import DifyConfig
from configs import dify_config
if not os.environ.get("DEBUG") or os.environ.get("DEBUG", "false").lower() != 'true':
if os.environ.get("DEBUG", "false").lower() != 'true':
from gevent import monkey
monkey.patch_all()
@@ -43,6 +43,8 @@ from extensions import (
from extensions.ext_database import db
from extensions.ext_login import login_manager
from libs.passport import PassportService
# TODO: Find a way to avoid importing models here
from models import account, dataset, model, source, task, tool, tools, web
from services.account_service import AccountService
@@ -81,7 +83,7 @@ def create_flask_app_with_configs() -> Flask:
with configs loaded from .env file
"""
dify_app = DifyApp(__name__)
dify_app.config.from_mapping(DifyConfig().model_dump())
dify_app.config.from_mapping(dify_config.model_dump())
# populate configs into system environment variables
for key, value in dify_app.config.items():
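
As an aside on the DEBUG guard change above: the simplified one-clause condition enables gevent monkey-patching in exactly the same cases as the old two-clause version. A quick standalone equivalence check (a sketch with hypothetical helper names):

```python
# Sketch: the old and new DEBUG guards from app.py agree on every input.
def old_guard(env: dict) -> bool:
    return not env.get("DEBUG") or env.get("DEBUG", "false").lower() != 'true'

def new_guard(env: dict) -> bool:
    return env.get("DEBUG", "false").lower() != 'true'

for env in ({}, {"DEBUG": ""}, {"DEBUG": "true"}, {"DEBUG": "TRUE"}, {"DEBUG": "false"}, {"DEBUG": "1"}):
    assert old_guard(env) == new_guard(env)  # monkey-patch in exactly the same cases
```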


@@ -8,6 +8,7 @@ import click
from flask import current_app
from werkzeug.exceptions import NotFound
from configs import dify_config
from constants.languages import languages
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
@@ -112,7 +113,7 @@ def reset_encrypt_key_pair():
After the reset, all LLM credentials will become invalid, requiring re-entry.
Only support SELF_HOSTED mode.
"""
if current_app.config['EDITION'] != 'SELF_HOSTED':
if dify_config.EDITION != 'SELF_HOSTED':
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
return
@@ -336,6 +337,14 @@ def migrate_knowledge_vector_database():
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ANALYTICDB:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.ANALYTICDB,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {vector_type} is not supported.")


@@ -0,0 +1,3 @@
from .app_config import DifyConfig
dify_config = DifyConfig()
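
This new module (evidently configs/__init__.py, given the `from configs import dify_config` imports above) exposes a module-level singleton, so call sites import the configured instance rather than constructing DifyConfig themselves. A minimal usage sketch mirroring the app.py and commands.py hunks above:

```python
# Call-site usage of the new singleton (mirrors the diffs above).
from configs import dify_config

if dify_config.EDITION != 'SELF_HOSTED':
    ...
```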


@@ -1,4 +1,5 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field, computed_field
from pydantic_settings import SettingsConfigDict
from configs.deploy import DeploymentConfig
from configs.enterprise import EnterpriseFeatureConfig
@@ -8,13 +9,7 @@ from configs.middleware import MiddlewareConfig
from configs.packaging import PackagingInfo
# TODO: Both `BaseModel` and `BaseSettings` has `model_config` attribute but they are in different types.
# This inheritance is depends on the order of the classes.
# It is better to use `BaseSettings` as the base class.
class DifyConfig(
# based on pydantic-settings
BaseSettings,
# Packaging info
PackagingInfo,
@@ -34,12 +29,39 @@ class DifyConfig(
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
):
DEBUG: bool = Field(default=False, description='whether to enable debug mode.')
model_config = SettingsConfigDict(
# read from dotenv format config file
env_file='.env',
env_file_encoding='utf-8',
frozen=True,
# ignore extra attributes
extra='ignore',
)
CODE_MAX_NUMBER: int = 9223372036854775807
CODE_MIN_NUMBER: int = -9223372036854775808
CODE_MAX_STRING_LENGTH: int = 80000
CODE_MAX_STRING_ARRAY_LENGTH: int = 30
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300
HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10
@computed_field
def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024
@computed_field
def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str:
return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
SSRF_PROXY_HTTP_URL: str | None = None
SSRF_PROXY_HTTPS_URL: str | None = None
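
The hunks above move DifyConfig onto pydantic-settings, so environment variables and the .env file populate fields directly, while @computed_field exposes derived, read-only values such as the human-readable size limits. A minimal self-contained sketch of that pattern (class and field names hypothetical):

```python
# Minimal pydantic-settings sketch of the pattern DifyConfig uses above.
from pydantic import Field, computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict

class ExampleConfig(BaseSettings):
    model_config = SettingsConfigDict(
        env_file='.env',            # read from dotenv format config file
        env_file_encoding='utf-8',
        extra='ignore',             # ignore extra attributes
    )

    DEBUG: bool = Field(default=False, description='whether to enable debug mode.')
    MAX_BINARY_SIZE: int = 1024 * 1024 * 10

    @computed_field
    def READABLE_MAX_BINARY_SIZE(self) -> str:
        # derived from the raw byte limit, like the HTTP_REQUEST_NODE_* fields above
        return f'{self.MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'

config = ExampleConfig()                # env vars / .env override the defaults
print(config.READABLE_MAX_BINARY_SIZE)  # 10.00MB
```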


@@ -1,7 +1,8 @@
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class DeploymentConfig(BaseModel):
class DeploymentConfig(BaseSettings):
"""
Deployment configs
"""


@@ -1,7 +1,8 @@
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class EnterpriseFeatureConfig(BaseModel):
class EnterpriseFeatureConfig(BaseSettings):
"""
Enterprise feature configs.
**Before using, please contact business@dify.ai by email to inquire about licensing matters.**


@@ -1,5 +1,3 @@
from pydantic import BaseModel
from configs.extra.notion_config import NotionConfig
from configs.extra.sentry_config import SentryConfig


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class NotionConfig(BaseModel):
class NotionConfig(BaseSettings):
"""
Notion integration configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, NonNegativeFloat
from pydantic import Field, NonNegativeFloat
from pydantic_settings import BaseSettings
class SentryConfig(BaseModel):
class SentryConfig(BaseSettings):
"""
Sentry configs
"""


@@ -1,11 +1,12 @@
from typing import Optional
from pydantic import AliasChoices, BaseModel, Field, NonNegativeInt, PositiveInt, computed_field
from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig
class SecurityConfig(BaseModel):
class SecurityConfig(BaseSettings):
"""
Secret Key configs
"""
@@ -17,8 +18,12 @@ class SecurityConfig(BaseModel):
default=None,
)
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
description='Expiry time in hours for reset token',
default=24,
)
class AppExecutionConfig(BaseModel):
class AppExecutionConfig(BaseSettings):
"""
App Execution configs
"""
@@ -26,9 +31,13 @@ class AppExecutionConfig(BaseModel):
description='execution timeout in seconds for app execution',
default=1200,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description='max active request per app, 0 means unlimited',
default=0,
)
class CodeExecutionSandboxConfig(BaseModel):
class CodeExecutionSandboxConfig(BaseSettings):
"""
Code Execution Sandbox configs
"""
@@ -43,36 +52,36 @@ class CodeExecutionSandboxConfig(BaseModel):
)
class EndpointConfig(BaseModel):
class EndpointConfig(BaseSettings):
"""
Module URL configs
"""
CONSOLE_API_URL: str = Field(
description='The backend URL prefix of the console API.'
'used to concatenate the login authorization callback or notion integration callback.',
default='https://cloud.dify.ai',
default='',
)
CONSOLE_WEB_URL: str = Field(
description='The front-end URL prefix of the console web.'
'used to concatenate some front-end addresses and for CORS configuration use.',
default='https://cloud.dify.ai',
default='',
)
SERVICE_API_URL: str = Field(
description='Service API Url prefix.'
'used to display Service API Base Url to the front-end.',
default='https://api.dify.ai',
default='',
)
APP_WEB_URL: str = Field(
description='WebApp Url prefix.'
'used to display WebAPP API Base Url to the front-end.',
default='https://udify.app',
default='',
)
class FileAccessConfig(BaseModel):
class FileAccessConfig(BaseSettings):
"""
File Access configs
"""
@@ -82,7 +91,7 @@ class FileAccessConfig(BaseModel):
'Url is signed and has expiration time.',
validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'),
alias_priority=1,
default='https://cloud.dify.ai',
default='',
)
FILES_ACCESS_TIMEOUT: int = Field(
@@ -91,7 +100,7 @@ class FileAccessConfig(BaseModel):
)
class FileUploadConfig(BaseModel):
class FileUploadConfig(BaseSettings):
"""
File Uploading configs
"""
@@ -116,7 +125,7 @@ class FileUploadConfig(BaseModel):
)
class HttpConfig(BaseModel):
class HttpConfig(BaseSettings):
"""
HTTP configs
"""
@@ -148,7 +157,7 @@ class HttpConfig(BaseModel):
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')
class InnerAPIConfig(BaseModel):
class InnerAPIConfig(BaseSettings):
"""
Inner API configs
"""
@@ -163,7 +172,7 @@ class InnerAPIConfig(BaseModel):
)
class LoggingConfig(BaseModel):
class LoggingConfig(BaseSettings):
"""
Logging configs
"""
@@ -195,7 +204,7 @@ class LoggingConfig(BaseModel):
)
class ModelLoadBalanceConfig(BaseModel):
class ModelLoadBalanceConfig(BaseSettings):
"""
Model load balance configs
"""
@@ -205,7 +214,7 @@ class ModelLoadBalanceConfig(BaseModel):
)
class BillingConfig(BaseModel):
class BillingConfig(BaseSettings):
"""
Platform Billing Configurations
"""
@@ -215,7 +224,7 @@ class BillingConfig(BaseModel):
)
class UpdateConfig(BaseModel):
class UpdateConfig(BaseSettings):
"""
Update configs
"""
@@ -225,7 +234,7 @@ class UpdateConfig(BaseModel):
)
class WorkflowConfig(BaseModel):
class WorkflowConfig(BaseSettings):
"""
Workflow feature configs
"""
@@ -246,7 +255,7 @@ class WorkflowConfig(BaseModel):
)
class OAuthConfig(BaseModel):
class OAuthConfig(BaseSettings):
"""
oauth configs
"""
@@ -276,7 +285,7 @@ class OAuthConfig(BaseModel):
)
class ModerationConfig(BaseModel):
class ModerationConfig(BaseSettings):
"""
Moderation in app configs.
"""
@@ -288,7 +297,7 @@ class ModerationConfig(BaseModel):
)
class ToolConfig(BaseModel):
class ToolConfig(BaseSettings):
"""
Tool configs
"""
@@ -299,7 +308,7 @@ class ToolConfig(BaseModel):
)
class MailConfig(BaseModel):
class MailConfig(BaseSettings):
"""
Mail Configurations
"""
@@ -355,7 +364,7 @@ class MailConfig(BaseModel):
)
class RagEtlConfig(BaseModel):
class RagEtlConfig(BaseSettings):
"""
RAG ETL Configurations.
"""
@@ -381,7 +390,7 @@ class RagEtlConfig(BaseModel):
)
class DataSetConfig(BaseModel):
class DataSetConfig(BaseSettings):
"""
Dataset configs
"""
@@ -391,8 +400,13 @@ class DataSetConfig(BaseModel):
default=30,
)
DATASET_OPERATOR_ENABLED: bool = Field(
description='whether to enable dataset operator',
default=False,
)
class WorkspaceConfig(BaseModel):
class WorkspaceConfig(BaseSettings):
"""
Workspace configs
"""
@@ -403,7 +417,7 @@ class WorkspaceConfig(BaseModel):
)
class IndexingConfig(BaseModel):
class IndexingConfig(BaseSettings):
"""
Indexing configs.
"""
@@ -414,7 +428,7 @@ class IndexingConfig(BaseModel):
)
class ImageFormatConfig(BaseModel):
class ImageFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
description='multi model send image format, support base64, url, default is base64',
default='base64',


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, NonNegativeInt
from pydantic import Field, NonNegativeInt
from pydantic_settings import BaseSettings
class HostedOpenAiConfig(BaseModel):
class HostedOpenAiConfig(BaseSettings):
"""
Hosted OpenAI service config
"""
@@ -68,7 +69,7 @@ class HostedOpenAiConfig(BaseModel):
)
class HostedAzureOpenAiConfig(BaseModel):
class HostedAzureOpenAiConfig(BaseSettings):
"""
Hosted OpenAI service config
"""
@@ -94,7 +95,7 @@ class HostedAzureOpenAiConfig(BaseModel):
)
class HostedAnthropicConfig(BaseModel):
class HostedAnthropicConfig(BaseSettings):
"""
Hosted Azure OpenAI service config
"""
@@ -125,7 +126,7 @@ class HostedAnthropicConfig(BaseModel):
)
class HostedMinmaxConfig(BaseModel):
class HostedMinmaxConfig(BaseSettings):
"""
Hosted Minmax service config
"""
@@ -136,7 +137,7 @@ class HostedMinmaxConfig(BaseModel):
)
class HostedSparkConfig(BaseModel):
class HostedSparkConfig(BaseSettings):
"""
Hosted Spark service config
"""
@@ -147,7 +148,7 @@ class HostedSparkConfig(BaseModel):
)
class HostedZhipuAIConfig(BaseModel):
class HostedZhipuAIConfig(BaseSettings):
"""
Hosted Minmax service config
"""
@@ -158,7 +159,7 @@ class HostedZhipuAIConfig(BaseModel):
)
class HostedModerationConfig(BaseModel):
class HostedModerationConfig(BaseSettings):
"""
Hosted Moderation service config
"""
@@ -174,7 +175,7 @@ class HostedModerationConfig(BaseModel):
)
class HostedFetchAppTemplateConfig(BaseModel):
class HostedFetchAppTemplateConfig(BaseSettings):
"""
Hosted Moderation service config
"""


@@ -1,6 +1,7 @@
from typing import Any, Optional
from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt, computed_field
from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.middleware.cache.redis_config import RedisConfig
from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
@@ -9,8 +10,10 @@ from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorag
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
from configs.middleware.vdb.oracle_config import OracleConfig
from configs.middleware.vdb.pgvector_config import PGVectorConfig
@@ -22,7 +25,7 @@ from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
from configs.middleware.vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseModel):
class StorageConfig(BaseSettings):
STORAGE_TYPE: str = Field(
description='storage type,'
' default to `local`,'
@@ -36,14 +39,14 @@ class StorageConfig(BaseModel):
)
class VectorStoreConfig(BaseModel):
class VectorStoreConfig(BaseSettings):
VECTOR_STORE: Optional[str] = Field(
description='vector store type',
default=None,
)
class KeywordStoreConfig(BaseModel):
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(
description='keyword store type',
default='jieba',
@@ -81,6 +84,11 @@ class DatabaseConfig:
default='',
)
DB_EXTRAS: str = Field(
description='db extras options. Example: keepalives_idle=60&keepalives=1',
default='',
)
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
description='db uri scheme',
default='postgresql',
@@ -89,7 +97,12 @@ class DatabaseConfig:
@computed_field
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = f"?client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else ""
db_extras = (
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}"
if self.DB_CHARSET
else self.DB_EXTRAS
).strip("&")
db_extras = f"?{db_extras}" if db_extras else ""
return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
f"{self.DB_USERNAME}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
f"{db_extras}")
@@ -114,7 +127,7 @@ class DatabaseConfig:
default=False,
)
SQLALCHEMY_ECHO: bool = Field(
SQLALCHEMY_ECHO: bool | str = Field(
description='whether to enable SqlAlchemy echo',
default=False,
)
@@ -172,8 +185,10 @@ class MiddlewareConfig(
# configs of vdb and vdb providers
VectorStoreConfig,
AnalyticdbConfig,
ChromaConfig,
MilvusConfig,
MyScaleConfig,
OpenSearchConfig,
OracleConfig,
PGVectorConfig,
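
To make the new DB_EXTRAS handling concrete: the SQLALCHEMY_DATABASE_URI property above now merges DB_EXTRAS with the optional client_encoding derived from DB_CHARSET. A worked illustration with assumed values:

```python
# Assumed values; reproduces the db_extras assembly from DatabaseConfig above.
DB_EXTRAS = "keepalives_idle=60&keepalives=1"   # the documented example value
DB_CHARSET = "utf8"

db_extras = (
    f"{DB_EXTRAS}&client_encoding={DB_CHARSET}" if DB_CHARSET else DB_EXTRAS
).strip("&")
db_extras = f"?{db_extras}" if db_extras else ""

print(f"postgresql://user:pass@localhost:5432/dify{db_extras}")
# postgresql://user:pass@localhost:5432/dify?keepalives_idle=60&keepalives=1&client_encoding=utf8
```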


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class RedisConfig(BaseModel):
class RedisConfig(BaseSettings):
"""
Redis configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class AliyunOSSStorageConfig(BaseModel):
class AliyunOSSStorageConfig(BaseSettings):
"""
Aliyun storage configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class S3StorageConfig(BaseModel):
class S3StorageConfig(BaseSettings):
"""
S3 storage configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class AzureBlobStorageConfig(BaseModel):
class AzureBlobStorageConfig(BaseSettings):
"""
Azure Blob storage configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class GoogleCloudStorageConfig(BaseModel):
class GoogleCloudStorageConfig(BaseSettings):
"""
Google Cloud storage configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class OCIStorageConfig(BaseModel):
class OCIStorageConfig(BaseSettings):
"""
OCI storage configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class TencentCloudCOSStorageConfig(BaseModel):
class TencentCloudCOSStorageConfig(BaseSettings):
"""
Tencent Cloud COS storage configs
"""


@@ -0,0 +1,44 @@
from typing import Optional
from pydantic import BaseModel, Field
class AnalyticdbConfig(BaseModel):
"""
Configuration for connecting to AnalyticDB.
Refer to the following documentation for details on obtaining credentials:
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
"""
ANALYTICDB_KEY_ID : Optional[str] = Field(
default=None,
description="The Access Key ID provided by Alibaba Cloud for authentication."
)
ANALYTICDB_KEY_SECRET : Optional[str] = Field(
default=None,
description="The Secret Access Key corresponding to the Access Key ID for secure access."
)
ANALYTICDB_REGION_ID : Optional[str] = Field(
default=None,
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
)
ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
default=None,
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
)
ANALYTICDB_ACCOUNT : Optional[str] = Field(
default=None,
description="The account name used to log in to the AnalyticDB instance."
)
ANALYTICDB_PASSWORD : Optional[str] = Field(
default=None,
description="The password associated with the AnalyticDB account for authentication."
)
ANALYTICDB_NAMESPACE : Optional[str] = Field(
default=None,
description="The namespace within AnalyticDB for schema isolation."
)
ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
default=None,
description="The password for accessing the specified namespace within the AnalyticDB instance."
)


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class ChromaConfig(BaseModel):
class ChromaConfig(BaseSettings):
"""
Chroma configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class MilvusConfig(BaseModel):
class MilvusConfig(BaseSettings):
"""
Milvus configs
"""


@@ -0,0 +1,39 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
class MyScaleConfig(BaseModel):
"""
MyScale configs
"""
MYSCALE_HOST: Optional[str] = Field(
description='MyScale host',
default=None,
)
MYSCALE_PORT: Optional[PositiveInt] = Field(
description='MyScale port',
default=8123,
)
MYSCALE_USER: Optional[str] = Field(
description='MyScale user',
default=None,
)
MYSCALE_PASSWORD: Optional[str] = Field(
description='MyScale password',
default=None,
)
MYSCALE_DATABASE: Optional[str] = Field(
description='MyScale database name',
default=None,
)
MYSCALE_FTS_PARAMS: Optional[str] = Field(
description='MyScale fts index parameters',
default=None,
)


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class OpenSearchConfig(BaseModel):
class OpenSearchConfig(BaseSettings):
"""
OpenSearch configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class OracleConfig(BaseModel):
class OracleConfig(BaseSettings):
"""
ORACLE configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class PGVectorConfig(BaseModel):
class PGVectorConfig(BaseSettings):
"""
PGVector configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class PGVectoRSConfig(BaseModel):
class PGVectoRSConfig(BaseSettings):
"""
PGVectoRS configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class QdrantConfig(BaseModel):
class QdrantConfig(BaseSettings):
"""
Qdrant configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class RelytConfig(BaseModel):
class RelytConfig(BaseSettings):
"""
Relyt configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
class TencentVectorDBConfig(BaseModel):
class TencentVectorDBConfig(BaseSettings):
"""
Tencent Vector configs
"""
@@ -24,7 +25,7 @@ class TencentVectorDBConfig(BaseModel):
)
TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
description='Tencent Vector password',
description='Tencent Vector username',
default=None,
)
@@ -38,7 +39,12 @@ class TencentVectorDBConfig(BaseModel):
default=1,
)
TENCENT_VECTOR_DB_REPLICAS: PositiveInt = Field(
TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description='Tencent Vector replicas',
default=2,
)
TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
description='Tencent Vector Database',
default=None,
)


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class TiDBVectorConfig(BaseModel):
class TiDBVectorConfig(BaseSettings):
"""
TiDB Vector configs
"""


@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class WeaviateConfig(BaseModel):
class WeaviateConfig(BaseSettings):
"""
Weaviate configs
"""


@@ -1,14 +1,15 @@
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class PackagingInfo(BaseModel):
class PackagingInfo(BaseSettings):
"""
Packaging build information
"""
CURRENT_VERSION: str = Field(
description='Dify version',
default='0.6.12-fix1',
default='0.6.13',
)
COMMIT_SHA: str = Field(


@@ -14,7 +14,7 @@ language_timezone_mapping = {
'vi-VN': 'Asia/Ho_Chi_Minh',
'ro-RO': 'Europe/Bucharest',
'pl-PL': 'Europe/Warsaw',
'hi-IN': 'Asia/Kolkata'
'hi-IN': 'Asia/Kolkata',
}
languages = list(language_timezone_mapping.keys())

File diff suppressed because one or more lines are too long


@@ -0,0 +1,4 @@
TTS_AUTO_PLAY_TIMEOUT = 5
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02
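
The comment above implies the TTS stream yields roughly 32 bytes of audio per millisecond, so sleeping TTS_AUTO_PLAY_YIELD_CPU_TIME between chunks paces the generator at about one 640-byte chunk per 20 ms. A quick check of that arithmetic:

```python
# Implied pacing from the constants above.
assert 1280 / 40 == 640 / 20 == 32  # ~32 bytes of audio per millisecond
```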


@@ -30,7 +30,7 @@ from .app import (
)
# Import auth controllers
from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
# Import billing controllers
from .billing import billing


@@ -134,6 +134,7 @@ class AppApi(Resource):
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument('max_active_requests', type=int, location='json')
args = parser.parse_args()
app_service = AppService()


@@ -81,15 +81,36 @@ class ChatMessageTextApi(Resource):
@account_initialization_required
@get_app_model
def post(self, app_model):
from werkzeug.exceptions import InternalServerError
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, location='json')
parser.add_argument('text', type=str, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id', None)
text = args.get('text', None)
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
'voice')
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
text=request.form['text'],
voice=request.form['voice'],
streaming=False
text=text,
message_id=message_id,
voice=voice
)
return {'data': response.data.decode('latin1')}
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()


@@ -19,7 +19,12 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@@ -75,7 +80,7 @@ class CompletionMessageApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@@ -141,7 +146,7 @@ class ChatMessageApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")


@@ -13,6 +13,7 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import AppInvokeQuotaExceededError
from fields.workflow_fields import workflow_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
@@ -279,7 +280,7 @@ class DraftWorkflowRunApi(Resource):
)
return helper.compact_generate_response(response)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")


@@ -6,6 +6,7 @@ from flask_login import current_user
from flask_restful import Resource
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
@@ -16,11 +17,11 @@ from ..wraps import account_initialization_required
def get_oauth_providers():
with current_app.app_context():
notion_oauth = NotionOAuth(client_id=current_app.config.get('NOTION_CLIENT_ID'),
client_secret=current_app.config.get(
'NOTION_CLIENT_SECRET'),
redirect_uri=current_app.config.get(
'CONSOLE_API_URL') + '/console/api/oauth/data-source/callback/notion')
if not dify_config.NOTION_CLIENT_ID or not dify_config.NOTION_CLIENT_SECRET:
return {}
notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID,
client_secret=dify_config.NOTION_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion')
OAUTH_PROVIDERS = {
'notion': notion_oauth
@@ -39,8 +40,10 @@ class OAuthDataSource(Resource):
print(vars(oauth_provider))
if not oauth_provider:
return {'error': 'Invalid provider'}, 400
if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
if dify_config.NOTION_INTEGRATION_TYPE == 'internal':
internal_secret = dify_config.NOTION_INTERNAL_SECRET
if not internal_secret:
return {'error': 'Internal secret is not set'},
oauth_provider.save_internal_access_token(internal_secret)
return { 'data': '' }
else:
@@ -60,13 +63,13 @@ class OAuthDataSourceCallback(Resource):
if 'code' in request.args:
code = request.args.get('code')
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&code={code}')
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}')
elif 'error' in request.args:
error = request.args.get('error')
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error={error}')
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}')
else:
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error=Access denied')
return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied')
class OAuthDataSourceBinding(Resource):


@@ -5,3 +5,28 @@ class ApiKeyAuthFailedError(BaseHTTPException):
error_code = 'auth_failed'
description = "{message}"
code = 500
class InvalidEmailError(BaseHTTPException):
error_code = 'invalid_email'
description = "The email address is not valid."
code = 400
class PasswordMismatchError(BaseHTTPException):
error_code = 'password_mismatch'
description = "The passwords do not match."
code = 400
class InvalidTokenError(BaseHTTPException):
error_code = 'invalid_or_expired_token'
description = "The token is invalid or has expired."
code = 400
class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = 'password_reset_rate_limit_exceeded'
description = "Password reset rate limit exceeded. Try again later."
code = 429


@@ -0,0 +1,107 @@
import base64
import logging
import secrets
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.auth.error import (
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
PasswordResetRateLimitExceededError,
)
from controllers.console.setup import setup_required
from extensions.ext_database import db
from libs.helper import email as email_validate
from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService
from services.errors.account import RateLimitExceededError
class ForgotPasswordSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('email', type=str, required=True, location='json')
args = parser.parse_args()
email = args['email']
if not email_validate(email):
raise InvalidEmailError()
account = Account.query.filter_by(email=email).first()
if account:
try:
AccountService.send_reset_password_email(account=account)
except RateLimitExceededError:
logging.warning(f"Rate limit exceeded for email: {account.email}")
raise PasswordResetRateLimitExceededError()
else:
# Return success to avoid revealing email registration status
logging.warning(f"Attempt to reset password for unregistered email: {email}")
return {"result": "success"}
class ForgotPasswordCheckApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()
token = args['token']
reset_data = AccountService.get_reset_password_data(token)
if reset_data is None:
return {'is_valid': False, 'email': None}
return {'is_valid': True, 'email': reset_data.get('email')}
class ForgotPasswordResetApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
args = parser.parse_args()
new_password = args['new_password']
password_confirm = args['password_confirm']
if str(new_password).strip() != str(password_confirm).strip():
raise PasswordMismatchError()
token = args['token']
reset_data = AccountService.get_reset_password_data(token)
if reset_data is None:
raise InvalidTokenError()
AccountService.revoke_reset_password_token(token)
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get('email')).first()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
return {'result': 'success'}
api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
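
Taken together, the three resources above define a simple reset flow: request an email, validate the token, then submit the new password. A hypothetical client-side walkthrough (base URL and payload values assumed):

```python
# Hypothetical walkthrough of the forgot-password endpoints registered above.
import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix

# 1. Request a reset email; the API reports success even for unregistered
#    addresses so it does not reveal registration status.
requests.post(f"{BASE}/forgot-password", json={"email": "user@example.com"})

# 2. Check the token from the emailed link before showing the reset form.
r = requests.post(f"{BASE}/forgot-password/validity", json={"token": "<token-from-email>"})
print(r.json())  # e.g. {'is_valid': True, 'email': 'user@example.com'}

# 3. Submit the new password; new_password and password_confirm must match.
requests.post(f"{BASE}/forgot-password/resets", json={
    "token": "<token-from-email>",
    "new_password": "a-new-valid-password",
    "password_confirm": "a-new-valid-password",
})
```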


@@ -1,7 +1,7 @@
from typing import cast
import flask_login
from flask import current_app, request
from flask import request
from flask_restful import Resource, reqparse
import services
@@ -56,14 +56,14 @@ class LogoutApi(Resource):
class ResetPasswordApi(Resource):
@setup_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('email', type=email, required=True, location='json')
args = parser.parse_args()
# parser = reqparse.RequestParser()
# parser.add_argument('email', type=email, required=True, location='json')
# args = parser.parse_args()
# import mailchimp_transactional as MailchimpTransactional
# from mailchimp_transactional.api_client import ApiClientError
account = {'email': args['email']}
# account = {'email': args['email']}
# account = AccountService.get_by_email(args['email'])
# if account is None:
# raise ValueError('Email not found')
@@ -71,22 +71,22 @@ class ResetPasswordApi(Resource):
# AccountService.update_password(account, new_password)
# todo: Send email
MAILCHIMP_API_KEY = current_app.config['MAILCHIMP_TRANSACTIONAL_API_KEY']
# MAILCHIMP_API_KEY = current_app.config['MAILCHIMP_TRANSACTIONAL_API_KEY']
# mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY)
message = {
'from_email': 'noreply@example.com',
'to': [{'email': account.email}],
'subject': 'Reset your Dify password',
'html': """
<p>Dear User,</p>
<p>The Dify team has generated a new password for you, details as follows:</p>
<p><strong>{new_password}</strong></p>
<p>Please change your password to log in as soon as possible.</p>
<p>Regards,</p>
<p>The Dify Team</p>
"""
}
# message = {
# 'from_email': 'noreply@example.com',
# 'to': [{'email': account['email']}],
# 'subject': 'Reset your Dify password',
# 'html': """
# <p>Dear User,</p>
# <p>The Dify team has generated a new password for you, details as follows:</p>
# <p><strong>{new_password}</strong></p>
# <p>Please change your password to log in as soon as possible.</p>
# <p>Regards,</p>
# <p>The Dify Team</p>
# """
# }
# response = mailchimp.messages.send({
# 'message': message,


@@ -6,6 +6,7 @@ import requests
from flask import current_app, redirect, request
from flask_restful import Resource
from configs import dify_config
from constants.languages import languages
from extensions.ext_database import db
from libs.helper import get_remote_ip
@@ -18,22 +19,24 @@ from .. import api
def get_oauth_providers():
with current_app.app_context():
github_oauth = GitHubOAuth(client_id=current_app.config.get('GITHUB_CLIENT_ID'),
client_secret=current_app.config.get(
'GITHUB_CLIENT_SECRET'),
redirect_uri=current_app.config.get(
'CONSOLE_API_URL') + '/console/api/oauth/authorize/github')
if not dify_config.GITHUB_CLIENT_ID or not dify_config.GITHUB_CLIENT_SECRET:
github_oauth = None
else:
github_oauth = GitHubOAuth(
client_id=dify_config.GITHUB_CLIENT_ID,
client_secret=dify_config.GITHUB_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github',
)
if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
google_oauth = None
else:
google_oauth = GoogleOAuth(
client_id=dify_config.GOOGLE_CLIENT_ID,
client_secret=dify_config.GOOGLE_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google',
)
google_oauth = GoogleOAuth(client_id=current_app.config.get('GOOGLE_CLIENT_ID'),
client_secret=current_app.config.get(
'GOOGLE_CLIENT_SECRET'),
redirect_uri=current_app.config.get(
'CONSOLE_API_URL') + '/console/api/oauth/authorize/google')
OAUTH_PROVIDERS = {
'github': github_oauth,
'google': google_oauth
}
OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth}
return OAUTH_PROVIDERS
@@ -63,8 +66,7 @@ class OAuthCallback(Resource):
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.HTTPError as e:
logging.exception(
f"An error occurred during the OAuth process with {provider}: {e.response.text}")
logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}')
return {'error': 'OAuth process failed'}, 400
account = _generate_account(provider, user_info)
@@ -81,7 +83,7 @@ class OAuthCallback(Resource):
token = AccountService.login(account, ip_address=get_remote_ip(request))
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')
return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}')
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
@@ -101,11 +103,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
# Create account
account_name = user_info.name if user_info.name else 'Dify'
account = RegisterService.register(
email=user_info.email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
)
# Set interface language


@@ -25,7 +25,7 @@ from fields.document_fields import document_status_fields
from libs.login import login_required
from models.dataset import Dataset, Document, DocumentSegment
from models.model import ApiToken, UploadFile
from services.dataset_service import DatasetService, DocumentService
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _validate_name(name):
@@ -85,6 +85,12 @@ class DatasetListApi(Resource):
else:
item['embedding_available'] = True
if item.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
item.update({'partial_member_list': part_users_list})
else:
item.update({'partial_member_list': []})
response = {
'data': data,
'has_more': len(datasets) == limit,
@@ -108,8 +114,8 @@ class DatasetListApi(Resource):
help='Invalid indexing technique.')
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
try:
@@ -140,6 +146,10 @@ class DatasetApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({'partial_member_list': part_users_list})
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
@@ -163,6 +173,11 @@ class DatasetApi(Resource):
data['embedding_available'] = False
else:
data['embedding_available'] = True
if data.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({'partial_member_list': part_users_list})
return data, 200
@setup_required
@@ -188,17 +203,21 @@ class DatasetApi(Resource):
nullable=True,
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members'), help='Invalid permission.')
'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.'
)
parser.add_argument('embedding_model', type=str,
location='json', help='Invalid embedding model.')
parser.add_argument('embedding_model_provider', type=str,
location='json', help='Invalid embedding model provider.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
args = parser.parse_args()
data = request.get_json()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get('permission'), data.get('partial_member_list')
)
dataset = DatasetService.update_dataset(
dataset_id_str, args, current_user)
@@ -206,7 +225,20 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
return marshal(dataset, dataset_detail_fields), 200
result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id
if data.get('partial_member_list') and data.get('permission') == 'partial_members':
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get('partial_member_list')
)
else:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({'partial_member_list': partial_member_list})
return result_data, 200
@setup_required
@login_required
@@ -215,11 +247,12 @@ class DatasetApi(Resource):
dataset_id_str = str(dataset_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.is_editor or current_user.is_dataset_operator:
raise Forbidden()
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {'result': 'success'}, 204
else:
raise NotFound("Dataset not found.")
@@ -515,7 +548,7 @@ class DatasetRetrievalSettingApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH,
@@ -539,7 +572,7 @@ class DatasetRetrievalSettingMockApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH,
@@ -569,6 +602,27 @@ class DatasetErrorDocs(Resource):
}, 200
class DatasetPermissionUserListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
return {
'data': partial_members_list,
}, 200
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
@@ -582,3 +636,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
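The partial_members permission added in this file keeps a per-dataset member list: it is written when a dataset is updated with permission == 'partial_members', cleared on delete, and read back into API responses. A hedged sketch of that lifecycle, with an in-memory store standing in for the real DatasetPermissionService:

class InMemoryPermissionStore:
    """Illustrative stand-in for DatasetPermissionService (not the real API)."""
    def __init__(self):
        self._partial_members: dict[str, list[str]] = {}

    def update_partial_member_list(self, dataset_id: str, user_ids: list[str]) -> None:
        self._partial_members[dataset_id] = list(user_ids)

    def clear_partial_member_list(self, dataset_id: str) -> None:
        self._partial_members.pop(dataset_id, None)

    def get_dataset_partial_member_list(self, dataset_id: str) -> list[str]:
        return self._partial_members.get(dataset_id, [])

store = InMemoryPermissionStore()
store.update_partial_member_list('ds-1', ['user-a', 'user-b'])  # update path
assert store.get_dataset_partial_member_list('ds-1') == ['user-a', 'user-b']
store.clear_partial_member_list('ds-1')  # mirrors the delete path above
assert store.get_dataset_partial_member_list('ds-1') == []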

View File

@@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
raise NotFound('Dataset not found.')
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
try:
@@ -294,6 +294,11 @@ class DatasetInitApi(Resource):
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
location='json')
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
if args['indexing_technique'] == 'high_quality':
try:
model_manager = ModelManager()
@@ -344,7 +349,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document = self.get_document(dataset_id, document_id)
if document.indexing_status in ['completed', 'error']:
raise DocumentAlreadyFinishedError()
indexing_runner.calculate_tokens(document)
data_process_rule = document.dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict()
@@ -757,14 +762,18 @@ class DocumentStatusApi(DocumentResource):
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id, document_id)
# check user's permission
DatasetService.check_dataset_permission(dataset, current_user)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
document = self.get_document(dataset_id, document_id)
indexing_cache_key = 'document_{}_indexing'.format(document.id)
cache_result = redis_client.get(indexing_cache_key)
@@ -955,10 +964,11 @@ class DocumentRenameApi(DocumentResource):
@account_initialization_required
@marshal_with(document_fields)
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()

View File

@@ -19,6 +19,7 @@ from controllers.console.app.error import (
from controllers.console.explore.wraps import InstalledAppResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import AppMode
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -70,16 +71,33 @@ class ChatAudioApi(InstalledAppResource):
class ChatTextApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
from flask_restful import reqparse
app_model = installed_app.app
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id')
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
text=request.form['text'],
voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
message_id=message_id,
voice=voice
)
return {'data': response.data.decode('latin1')}
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
@@ -108,3 +126,5 @@ class ChatTextApi(InstalledAppResource):
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
# endpoint='installed_app_text_with_message_id')
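The voice-selection logic introduced here (and repeated in the service-API and web audio controllers below) prefers the request's voice, then the workflow's text_to_speech feature for advanced-chat and workflow apps, then the app model config, and finally falls back to None. A standalone sketch of that fallback chain, assuming a simplified app dict rather than the actual model objects:

from typing import Optional

WORKFLOW_MODES = {'advanced-chat', 'workflow'}  # assumed AppMode values

def resolve_voice(app: dict, requested_voice: Optional[str]) -> Optional[str]:
    """Pick the TTS voice: request > workflow features > app config > None."""
    if requested_voice:
        return requested_voice
    if app.get('mode') in WORKFLOW_MODES and app.get('workflow_features'):
        return app['workflow_features'].get('text_to_speech', {}).get('voice')
    try:
        return app['model_config']['text_to_speech']['voice']
    except (KeyError, TypeError):
        return None

print(resolve_voice({'mode': 'workflow',
                     'workflow_features': {'text_to_speech': {'voice': 'alloy'}}},
                    requested_voice=None))  # -> 'alloy'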

View File

@@ -36,7 +36,7 @@ class TagListApi(Resource):
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@@ -68,7 +68,7 @@ class TagUpdateDeleteApi(Resource):
def patch(self, tag_id):
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@@ -109,8 +109,8 @@ class TagBindingCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@@ -134,8 +134,8 @@ class TagBindingDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()

View File

@@ -245,6 +245,8 @@ class AccountIntegrateApi(Resource):
return {'data': integrate_data}
# Register API resources
api.add_resource(AccountInitApi, '/account/init')
api.add_resource(AccountProfileApi, '/account/profile')

View File

@@ -131,7 +131,20 @@ class MemberUpdateRoleApi(Resource):
return {'result': 'success'}
class DatasetOperatorMemberListApi(Resource):
"""List all members of current tenant."""
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
return {'result': 'success', 'accounts': members}, 200
api.add_resource(MemberListApi, '/workspaces/current/members')
api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')
api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators')

View File

@@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, EndUser
from models.model import App, AppMode, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -72,19 +72,30 @@ class AudioApi(Resource):
class TextApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
args = parser.parse_args()
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id')
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
'voice')
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
text=args['text'],
end_user=end_user,
voice=args.get('voice'),
streaming=args['streaming']
message_id=message_id,
end_user=end_user.external_user_id,
voice=voice
)
return response

View File

@@ -17,7 +17,12 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@@ -69,7 +74,7 @@ class CompletionApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@@ -132,7 +137,7 @@ class ChatApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@@ -14,7 +14,12 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import App, AppMode, EndUser
@@ -59,7 +64,7 @@ class WorkflowRunApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@@ -19,7 +19,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App
from models.model import App, AppMode
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -69,16 +69,35 @@ class AudioApi(WebApiResource):
class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
from flask_restful import reqparse
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id')
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get(
'voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
text=request.form['text'],
message_id=message_id,
end_user=end_user.external_user_id,
voice=request.form['voice'] if request.form.get('voice') else None,
streaming=False
voice=voice
)
return {'data': response.data.decode('latin1')}
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()

View File

@@ -114,6 +114,10 @@ class VariableEntity(BaseModel):
default: Optional[str] = None
hint: Optional[str] = None
@property
def name(self) -> str:
return self.variable
class ExternalDataVariableEntity(BaseModel):
"""

View File

@@ -0,0 +1,135 @@
import base64
import concurrent.futures
import logging
import queue
import re
import threading
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueTextChunkEvent
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
class AudioTrunk:
def __init__(self, status: str, audio):
self.audio = audio
self.status = status
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(),
user="responding_tts",
tenant_id=tenant_id,
voice=voice
)
def _process_future(future_queue, audio_queue):
while True:
try:
future = future_queue.get()
if future is None:
break
for audio in future.result():
audio_base64 = base64.b64encode(bytes(audio))
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
except Exception as e:
logging.getLogger(__name__).warning(e)
break
audio_queue.put(AudioTrunk("finish", b''))
class AppGeneratorTTSPublisher:
def __init__(self, tenant_id: str, voice: str):
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ''
self._audio_queue = queue.Queue()
self._msg_queue = queue.Queue()
self.match = re.compile(r'[。.!?]')
self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.TTS
)
self.voices = self.model_instance.get_tts_voices()
values = [voice.get('value') for voice in self.voices]
self.voice = voice
if not voice or voice not in values:
self.voice = self.voices[0].get('value')
self.MAX_SENTENCE = 2
self._last_audio_event = None
self._runtime_thread = threading.Thread(target=self._runtime).start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
def publish(self, message):
try:
self._msg_queue.put(message)
except Exception as e:
self.logger.warning(e)
def _runtime(self):
future_queue = queue.Queue()
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
while True:
try:
message = self._msg_queue.get()
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
self.model_instance, self.tenant_id, self.voice)
future_queue.put(futures_result)
break
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
self.msg_text += message.event.chunk.delta.message.content
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
self.MAX_SENTENCE += 1
text_content = ''.join(sentence_arr)
futures_result = self.executor.submit(_invoiceTTS, text_content,
self.model_instance,
self.tenant_id,
self.voice)
future_queue.put(futures_result)
if text_tmp:
self.msg_text = text_tmp
else:
self.msg_text = ''
except Exception as e:
self.logger.warning(e)
break
future_queue.put(None)
def checkAndGetAudio(self) -> AudioTrunk | None:
try:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:
self.executor.shutdown(wait=False)
return self.last_message
audio = self._audio_queue.get_nowait()
if audio and audio.status == "finish":
self.executor.shutdown(wait=False)
self._runtime_thread = None
if audio:
self._last_audio_event = audio
return audio
except queue.Empty:
return None
def _extract_sentence(self, org_text):
tx = self.match.finditer(org_text)
start = 0
result = []
for i in tx:
end = i.regs[0][1]
result.append(org_text[start:end])
start = end
return result, org_text[start:]
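_extract_sentence above is what lets the publisher start synthesizing audio before the full answer arrives: it flushes complete sentences at the delimiters 。.!? and carries the trailing fragment forward. A quick standalone check of that splitting behavior (m.end() is equivalent to the i.regs[0][1] used above):

import re

match = re.compile(r'[。.!?]')

def extract_sentence(org_text: str):
    """Return (complete sentences, trailing fragment), as in _extract_sentence."""
    start = 0
    result = []
    for m in match.finditer(org_text):
        end = m.end()
        result.append(org_text[start:end])
        start = end
    return result, org_text[start:]

sentences, remainder = extract_sentence('Hello world. How are')
print(sentences)   # ['Hello world.']
print(remainder)   # ' How are'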

View File

@@ -4,6 +4,8 @@ import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
@@ -33,6 +35,8 @@ from core.app.entities.task_entities import (
ChatbotAppStreamResponse,
ChatflowStreamGenerateRoute,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
)
@@ -71,13 +75,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_iteration_nested_relations: dict[str, list[str]]
def __init__(
self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool
self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@@ -129,7 +133,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._application_generate_entity.query
)
generator = self._process_stream_response(
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
@@ -138,7 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-> ChatbotAppBlockingResponse:
-> ChatbotAppBlockingResponse:
"""
Process blocking response.
:return:
@@ -169,7 +173,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
raise Exception('Queue listening stopped unexpectedly.')
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-> Generator[ChatbotAppStreamResponse, None, None]:
-> Generator[ChatbotAppStreamResponse, None, None]:
"""
To stream response.
:return:
@@ -182,14 +186,68 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream_response=stream_response
)
def _listenAudioMsg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
break
audio_trunk = publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms (40 ms => 1280-byte audio file, 20 ms => 640-byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event
if isinstance(event, QueueErrorEvent):
@@ -301,7 +359,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
continue
if not self._is_stream_out_support(
event=event
event=event
):
continue
@@ -318,7 +376,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._ping_stream_response()
else:
continue
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
@@ -402,7 +461,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return stream_generate_routes
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
-> list[str]:
"""
Get answer start at node id.
:param graph: graph
@@ -457,7 +516,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
@@ -515,7 +574,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
@@ -525,7 +584,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return:
"""
if not self._task_state.current_stream_generate_state:
return None
return
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
@@ -573,7 +632,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# get route chunk node execution info
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
if (route_chunk_node_execution_info.node_type == NodeType.LLM
and latest_node_execution_info.node_type == NodeType.LLM):
and latest_node_execution_info.node_type == NodeType.LLM):
# only LLM support chunk stream output
self._task_state.current_stream_generate_state.current_route_position += 1
continue
@@ -643,7 +702,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
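The _wrapper_process_stream_response pattern added here drains any buffered audio chunks before yielding each text response, then keeps polling until the publisher reports "finish" or the timeout elapses. A reduced, runnable sketch of that interleaving loop, using a plain queue and a sentinel in place of the publisher (timeout and sleep values are illustrative):

import queue
import time

FINISH = object()  # sentinel marking the end of the audio stream

def interleave(text_events, audio_queue: queue.Queue, timeout: float = 2.0):
    """Yield any buffered audio before each text event, then drain the rest."""
    finished = False

    def pending_audio():
        nonlocal finished
        while not finished:
            try:
                chunk = audio_queue.get_nowait()
            except queue.Empty:
                return
            if chunk is FINISH:
                finished = True
                return
            yield ('audio', chunk)

    for text in text_events:
        yield from pending_audio()
        yield ('text', text)
    deadline = time.time() + timeout
    while not finished and time.time() < deadline:
        yield from pending_audio()
        time.sleep(0.02)  # release the CPU, as in the pipeline above

q = queue.Queue()
for c in ('a1', 'a2', FINISH):
    q.put(c)
print(list(interleave(['t1'], q)))
# [('audio', 'a1'), ('audio', 'a2'), ('text', 't1')]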

View File

@@ -1,52 +1,56 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.app.app_config.entities import AppConfig, VariableEntity
class BaseAppGenerator:
def _get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig):
if user_inputs is None:
user_inputs = {}
filtered_inputs = {}
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
for variable_config in variables:
variable = variable_config.variable
if (variable not in user_inputs
or user_inputs[variable] is None
or (isinstance(user_inputs[variable], str) and user_inputs[variable] == '')):
if variable_config.required:
raise ValueError(f"{variable} is required in input form")
else:
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
continue
value = user_inputs[variable]
if value is not None:
if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
raise ValueError(f"{variable} in input form must be a string")
elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
if '.' in value:
value = float(value)
else:
value = int(value)
if variable_config.type == VariableEntity.Type.SELECT:
options = variable_config.options if variable_config.options is not None else []
if value not in options:
raise ValueError(f"{variable} in input form must be one of the following: {options}")
elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
if variable_config.max_length is not None:
max_length = variable_config.max_length
if len(value) > max_length:
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
if value and isinstance(value, str):
filtered_inputs[variable] = value.replace('\x00', '')
else:
filtered_inputs[variable] = value if value is not None else None
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.name)
if var.required and not user_input_value:
raise ValueError(f'{var.name} is required in input form')
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ''
if (
var.type
in (
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.SELECT,
VariableEntity.Type.PARAGRAPH,
)
and user_input_value
and not isinstance(user_input_value, str)
):
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
try:
if '.' in user_input_value:
return float(user_input_value)
else:
return int(user_input_value)
except ValueError:
raise ValueError(f"{var.name} in input form must be a valid number")
if var.type == VariableEntity.Type.SELECT:
options = var.options or []
if user_input_value not in options:
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
return user_input_value
def _sanitize_value(self, value: Any) -> Any:
if isinstance(value, str):
return value.replace('\x00', '')
return value
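The rewrite above splits input cleaning into two passes: a dict comprehension that validates each variable, then a sanitizing pass that strips NUL bytes from strings. A hedged sketch of the same two-phase shape on plain data, with a Var dataclass standing in for VariableEntity:

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Var:  # illustrative stand-in for VariableEntity
    name: str
    required: bool = False
    default: Optional[str] = None

def validate_input(inputs: dict, var: Var) -> Any:
    value = inputs.get(var.name)
    if var.required and not value:
        raise ValueError(f'{var.name} is required in input form')
    if not value:
        return var.default or ''
    return value

def sanitize(value: Any) -> Any:
    return value.replace('\x00', '') if isinstance(value, str) else value

variables = [Var('query', required=True), Var('topic', default='general')]
user_inputs = {'query': 'hi\x00there'}
cleaned = {v.name: validate_input(user_inputs, v) for v in variables}
cleaned = {k: sanitize(v) for k, v in cleaned.items()}
print(cleaned)  # {'query': 'hithere', 'topic': 'general'}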

View File

@@ -51,7 +51,6 @@ class AppQueueManager:
listen_timeout = current_app.config.get("APP_MAX_EXECUTION_TIME")
start_time = time.time()
last_ping_time = 0
while True:
try:
message = self._q.get(timeout=1)

View File

@@ -94,7 +94,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
call_depth=call_depth,
)
def _generate(
@@ -104,7 +103,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0
) -> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@@ -166,10 +164,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
"""
if not node_id:
raise ValueError('node_id is required')
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}

View File

@@ -1,7 +1,10 @@
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import (
InvokeFrom,
@@ -25,6 +28,8 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
StreamResponse,
TextChunkStreamResponse,
TextReplaceStreamResponse,
@@ -105,7 +110,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
db.session.refresh(self._user)
db.session.close()
generator = self._process_stream_response(
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
@@ -161,8 +166,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
stream_response=stream_response
)
def _listenAudioMsg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
break
audio_trunk = publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms (40 ms => 1280-byte audio file, 20 ms => 640-byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
break
else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
@@ -170,6 +225,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return:
"""
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event
if isinstance(event, QueueErrorEvent):
@@ -251,6 +308,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else:
continue
if publisher:
publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
"""
Save workflow app log.

View File

@@ -69,6 +69,7 @@ class WorkflowTaskState(TaskState):
iteration_nested_node_ids: list[str] = None
class AdvancedChatTaskState(WorkflowTaskState):
"""
AdvancedChatTaskState entity
@@ -86,6 +87,8 @@ class StreamEvent(Enum):
ERROR = "error"
MESSAGE = "message"
MESSAGE_END = "message_end"
TTS_MESSAGE = "tts_message"
TTS_MESSAGE_END = "tts_message_end"
MESSAGE_FILE = "message_file"
MESSAGE_REPLACE = "message_replace"
AGENT_THOUGHT = "agent_thought"
@@ -130,6 +133,22 @@ class MessageStreamResponse(StreamResponse):
answer: str
class MessageAudioStreamResponse(StreamResponse):
"""
MessageAudioStreamResponse entity
"""
event: StreamEvent = StreamEvent.TTS_MESSAGE
audio: str
class MessageAudioEndStreamResponse(StreamResponse):
"""
MessageAudioEndStreamResponse entity
"""
event: StreamEvent = StreamEvent.TTS_MESSAGE_END
audio: str
class MessageEndStreamResponse(StreamResponse):
"""
MessageEndStreamResponse entity
@@ -186,6 +205,7 @@ class WorkflowStartStreamResponse(StreamResponse):
"""
WorkflowStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@@ -205,6 +225,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
"""
WorkflowFinishStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@@ -232,6 +253,7 @@ class NodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@@ -273,6 +295,7 @@ class NodeFinishStreamResponse(StreamResponse):
"""
NodeFinishStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@@ -323,10 +346,12 @@ class NodeFinishStreamResponse(StreamResponse):
}
}
class IterationNodeStartStreamResponse(StreamResponse):
"""
IterationNodeStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@@ -344,10 +369,12 @@ class IterationNodeStartStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
class IterationNodeNextStreamResponse(StreamResponse):
"""
IterationNodeNextStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@@ -365,10 +392,12 @@ class IterationNodeNextStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
class IterationNodeCompletedStreamResponse(StreamResponse):
"""
IterationNodeCompletedStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@@ -393,10 +422,12 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
class TextChunkStreamResponse(StreamResponse):
"""
TextChunkStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@@ -411,6 +442,7 @@ class TextReplaceStreamResponse(StreamResponse):
"""
TextReplaceStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
@@ -473,6 +505,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
"""
ChatbotAppBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
@@ -492,6 +525,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
"""
CompletionAppBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
@@ -510,6 +544,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
"""
WorkflowAppBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
@@ -528,10 +563,12 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str
data: Data
class WorkflowIterationState(BaseModel):
"""
WorkflowIterationState entity
"""
class Data(BaseModel):
"""
Data entity

View File

@@ -0,0 +1 @@
from .rate_limit import RateLimit

View File

@@ -0,0 +1,120 @@
import logging
import time
import uuid
from collections.abc import Generator
from datetime import timedelta
from typing import Optional, Union
from core.errors.error import AppInvokeQuotaExceededError
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class RateLimit:
_MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict = {}
def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance
return cls._instance_dict[client_id]
def __init__(self, client_id: str, max_active_requests: int):
self.max_active_requests = max_active_requests
if hasattr(self, 'initialized'):
return
self.initialized = True
self.client_id = client_id
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
self.last_recalculate_time = float('-inf')
self.flush_cache(use_local_value=True)
def flush_cache(self, use_local_value=False):
self.last_recalculate_time = time.time()
# flush max active requests
if use_local_value or not redis_client.exists(self.max_active_requests_key):
with redis_client.pipeline() as pipe:
pipe.set(self.max_active_requests_key, self.max_active_requests)
pipe.expire(self.max_active_requests_key, timedelta(days=1))
pipe.execute()
else:
with redis_client.pipeline() as pipe:
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
# flush max active requests (in-transit request list)
if not redis_client.exists(self.active_requests_key):
return
request_details = redis_client.hgetall(self.active_requests_key)
redis_client.expire(self.active_requests_key, timedelta(days=1))
timeout_requests = [k for k, v in request_details.items() if
time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
if timeout_requests:
redis_client.hdel(self.active_requests_key, *timeout_requests)
def enter(self, request_id: Optional[str] = None) -> str:
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
self.flush_cache()
if self.max_active_requests <= 0:
return RateLimit._UNLIMITED_REQUEST_ID
if not request_id:
request_id = RateLimit.gen_request_key()
active_requests_count = redis_client.hlen(self.active_requests_key)
if active_requests_count >= self.max_active_requests:
raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
"concurrent requests allowed is {}.".format(self.max_active_requests))
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
return request_id
def exit(self, request_id: str):
if request_id == RateLimit._UNLIMITED_REQUEST_ID:
return
redis_client.hdel(self.active_requests_key, request_id)
@staticmethod
def gen_request_key() -> str:
return str(uuid.uuid4())
def generate(self, generator: Union[Generator, callable, dict], request_id: str):
if isinstance(generator, dict):
return generator
else:
return RateLimitGenerator(self, generator, request_id)
class RateLimitGenerator:
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
self.rate_limit = rate_limit
if callable(generator):
self.generator = generator()
else:
self.generator = generator
self.request_id = request_id
self.closed = False
def __iter__(self):
return self
def __next__(self):
if self.closed:
raise StopIteration
try:
return next(self.generator)
except StopIteration:
self.close()
raise
def close(self):
if not self.closed:
self.closed = True
self.rate_limit.exit(self.request_id)
if self.generator is not None and hasattr(self.generator, 'close'):
self.generator.close()
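RateLimit tracks in-flight requests in a per-client Redis hash: enter() raises AppInvokeQuotaExceededError once the hash reaches max_active_requests, and exit() (called directly, or via RateLimitGenerator.close when a streamed response ends) releases the slot. A self-contained, in-process sketch of that admission-control idea, with a set and a lock standing in for Redis and the TTL/flush machinery omitted:

import threading
import uuid

class SimpleRateLimit:
    """In-process sketch of the Redis-backed RateLimit above."""
    def __init__(self, max_active_requests: int):
        self.max_active_requests = max_active_requests
        self._active: set[str] = set()
        self._lock = threading.Lock()

    def enter(self) -> str:
        with self._lock:
            if len(self._active) >= self.max_active_requests:
                raise RuntimeError('Too many requests. Please try again later.')
            request_id = str(uuid.uuid4())
            self._active.add(request_id)
            return request_id

    def exit(self, request_id: str) -> None:
        with self._lock:
            self._active.discard(request_id)

limiter = SimpleRateLimit(max_active_requests=1)
rid = limiter.enter()
try:
    limiter.enter()          # second concurrent request is rejected
except RuntimeError as e:
    print(e)
finally:
    limiter.exit(rid)        # slot freed; new requests may enter again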

View File

@@ -4,6 +4,8 @@ import time
from collections.abc import Generator
from typing import Optional, Union, cast
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
@@ -32,6 +34,8 @@ from core.app.entities.task_entities import (
CompletionAppStreamResponse,
EasyUITaskState,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
)
@@ -87,6 +91,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
self._model_config = application_generate_entity.model_conf
self._app_config = application_generate_entity.app_config
self._conversation = conversation
self._message = message
@@ -102,7 +107,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._conversation_name_generate_thread = None
def process(
self,
self,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
@@ -123,7 +128,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._application_generate_entity.query
)
generator = self._process_stream_response(
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
@@ -202,14 +207,64 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
stream_response=stream_response
)
def _listenAudioMsg(self, publisher, task_id: str):
if publisher is None:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
if audio_msg and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
publisher = None
text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
if publisher is None:
break
audio = publisher.checkAndGetAudio()
if audio is None:
# release cpu
# sleep 20 ms (40 ms => 1280-byte audio file, 20 ms => 640-byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio.status == "finish":
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio.audio,
task_id=task_id)
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message)
event = message.event
if isinstance(event, QueueErrorEvent):
@@ -272,12 +327,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
yield self._ping_stream_response()
else:
continue
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(
self, trace_manager: Optional[TraceQueueManager] = None
self, trace_manager: Optional[TraceQueueManager] = None
) -> None:
"""
Save message.

View File

@@ -31,6 +31,13 @@ class QuotaExceededError(Exception):
description = "Quota Exceeded"
class AppInvokeQuotaExceededError(Exception):
"""
Custom exception raised when the quota for an app has been exceeded.
"""
description = "App Invoke Quota Exceeded"
class ModelCurrentlyNotSupportError(Exception):
"""
Custom exception raised when the model is not supported

View File

@@ -1,7 +1,6 @@
import os
import requests
from configs import dify_config
from models.api_based_extension import APIBasedExtensionPoint
@@ -31,10 +30,10 @@ class APIBasedExtensionRequestor:
try:
# proxy support for security
proxies = None
if os.environ.get("SSRF_PROXY_HTTP_URL") and os.environ.get("SSRF_PROXY_HTTPS_URL"):
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxies = {
'http': os.environ.get("SSRF_PROXY_HTTP_URL"),
'https': os.environ.get("SSRF_PROXY_HTTPS_URL"),
'http': dify_config.SSRF_PROXY_HTTP_URL,
'https': dify_config.SSRF_PROXY_HTTPS_URL,
}
response = requests.request(

View File

@@ -186,7 +186,7 @@ class MessageFileParser:
}
response = requests.head(url, headers=headers, allow_redirects=True)
if response.status_code == 200:
if response.status_code in {200, 304}:
return True, ""
else:
return False, "URL does not exist."

View File

@@ -1,5 +1,4 @@
import logging
import os
import time
from enum import Enum
from threading import Lock
@@ -9,6 +8,7 @@ from httpx import get, post
from pydantic import BaseModel
from yarl import URL
from configs import dify_config
from core.helper.code_executor.entities import CodeDependency
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
@@ -18,8 +18,8 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
logger = logging.getLogger(__name__)
# Code Executor
CODE_EXECUTION_ENDPOINT = os.environ.get('CODE_EXECUTION_ENDPOINT', 'http://sandbox:8194')
CODE_EXECUTION_API_KEY = os.environ.get('CODE_EXECUTION_API_KEY', 'dify-sandbox')
CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY
CODE_EXECUTION_TIMEOUT = (10, 60)

View File

@@ -214,6 +214,61 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
def calculate_tokens(self, tenant_id: str, tokens: int, dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing cost for the given token count.
"""
embedding_model_instance = None
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
if indexing_technique == 'high_quality':
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
preview_texts = []
total_segments = 0
total_price = 0
currency = 'USD'
if embedding_model_instance:
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
embedding_price_info = embedding_model_type_instance.get_price(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
total_price = '{:f}'.format(embedding_price_info.total_amount)
currency = embedding_price_info.currency
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": total_price,
"currency": currency,
"preview": preview_texts
}
def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
@@ -730,7 +785,7 @@ class IndexingRunner:
self._check_document_paused_status(dataset_document.id)
tokens = 0
if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
if embedding_model_instance:
tokens += sum(
embedding_model_instance.get_text_embedding_num_tokens(
[document.page_content]
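calculate_tokens above converts a raw token count into a price through the embedding model's get_price at PriceType.INPUT. The underlying arithmetic is just tokens divided by the pricing unit, times the unit price; a minimal illustration with made-up numbers (real rates come from the provider's pricing config):

from decimal import Decimal

def estimate_embedding_cost(tokens: int, unit_price: Decimal,
                            per_tokens: int = 1000) -> Decimal:
    """Price = tokens / per_tokens * unit_price."""
    return (Decimal(tokens) / Decimal(per_tokens)) * unit_price

# e.g. 25,000 tokens at a hypothetical $0.0001 per 1K tokens
print(estimate_embedding_cost(25_000, Decimal('0.0001')))  # 0.0025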

View File

@@ -12,7 +12,8 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from extensions.ext_database import db
from models.model import AppMode, Conversation, Message
from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowRun
class TokenBufferMemory:
@@ -30,33 +31,46 @@ class TokenBufferMemory:
app_record = self.conversation.app
# fetch limited messages, and return reversed
query = db.session.query(Message).filter(
query = db.session.query(
Message.id,
Message.query,
Message.answer,
Message.created_at,
Message.workflow_run_id
).filter(
Message.conversation_id == self.conversation.id,
Message.answer != ''
).order_by(Message.created_at.desc())
if message_limit and message_limit > 0:
messages = query.limit(message_limit).all()
message_limit = message_limit if message_limit <= 500 else 500
else:
messages = query.all()
message_limit = 500
messages = query.limit(message_limit).all()
messages = list(reversed(messages))
message_file_parser = MessageFileParser(
tenant_id=app_record.tenant_id,
app_id=app_record.id
)
prompt_messages = []
for message in messages:
files = message.message_files
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
else:
-                    file_extra_config = FileUploadConfigManager.convert(
-                        message.workflow_run.workflow.features_dict,
-                        is_vision=False
-                    )
+                    if message.workflow_run_id:
+                        workflow_run = (db.session.query(WorkflowRun)
+                                        .filter(WorkflowRun.id == message.workflow_run_id).first())
+                        if workflow_run:
+                            file_extra_config = FileUploadConfigManager.convert(
+                                workflow_run.workflow.features_dict,
+                                is_vision=False
+                            )
if file_extra_config:
file_objs = message_file_parser.transform_message_files(
@@ -136,4 +150,4 @@ class TokenBufferMemory:
message = f"{role}: {m.content}"
string_messages.append(message)
-        return "\n".join(string_messages)
\ No newline at end of file
+        return "\n".join(string_messages)
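Note: history is now always fetched through query.limit(), with caller-supplied limits clamped to 500. A standalone sketch of that clamping rule (the function name is ours, not part of the diff):

from typing import Optional

def effective_message_limit(message_limit: Optional[int], cap: int = 500) -> int:
    # Positive caller-supplied limits are honored up to the cap;
    # everything else falls back to the cap itself.
    if message_limit and message_limit > 0:
        return message_limit if message_limit <= cap else cap
    return cap

assert effective_message_limit(10) == 10
assert effective_message_limit(10_000) == 500
assert effective_message_limit(None) == 500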

View File

@@ -264,7 +264,7 @@ class ModelInstance:
user=user
)
-    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
+    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \
-> str:
"""
Invoke large language tts model
@@ -287,8 +287,7 @@ class ModelInstance:
content_text=content_text,
user=user,
tenant_id=tenant_id,
-            voice=voice,
-            streaming=streaming
+            voice=voice
)
def _round_robin_invoke(self, function: Callable, *args, **kwargs):
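Note: with the streaming flag gone, invoke_tts always behaves as a generator of audio byte chunks. A toy sketch of consuming that contract (fake_invoke_tts is a stand-in, not the real ModelInstance method):

from collections.abc import Iterable

def fake_invoke_tts(content_text: str, tenant_id: str, voice: str) -> Iterable[bytes]:
    # Stand-in for the new contract: yields audio chunks, no `streaming` argument.
    yield b'ID3'                          # placeholder bytes, not real mp3 audio
    yield content_text.encode('utf-8')

with open('out.mp3', 'wb') as f:
    for chunk in fake_invoke_tts('Hello, world.', tenant_id='tenant-1', voice='alloy'):
        f.write(chunk)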

View File

@@ -1,4 +1,6 @@
import hashlib
import logging
import re
import subprocess
import uuid
from abc import abstractmethod
@@ -10,7 +12,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class TTSModel(AIModel):
"""
Model class for ttstext model.
@@ -20,7 +22,7 @@ class TTSModel(AIModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
-    def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
+    def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None):
"""
Invoke large language model
@@ -35,14 +37,15 @@ class TTSModel(AIModel):
:return: translated audio file
"""
try:
logger.info(f"Invoke TTS model: {model} , invoke content : {content_text}")
self._is_ffmpeg_installed()
-            return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming,
+            return self._invoke(model=model, credentials=credentials, user=user,
content_text=content_text, voice=voice, tenant_id=tenant_id)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
-    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
+    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None):
"""
Invoke large language model
@@ -123,26 +126,26 @@ class TTSModel(AIModel):
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
@staticmethod
-    def _split_text_into_sentences(text: str, limit: int, delimiters=None):
-        if delimiters is None:
-            delimiters = set('。!?;\n')
-        buf = []
-        word_count = 0
-        for char in text:
-            buf.append(char)
-            if char in delimiters:
-                if word_count >= limit:
-                    yield ''.join(buf)
-                    buf = []
-                    word_count = 0
-                else:
-                    word_count += 1
-            else:
-                word_count += 1
-        if buf:
-            yield ''.join(buf)
+    def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'):
+        match = re.compile(pattern)
+        tx = match.finditer(org_text)
+        start = 0
+        result = []
+        one_sentence = ''
+        for i in tx:
+            end = i.regs[0][1]
+            tmp = org_text[start:end]
+            if len(one_sentence + tmp) > max_length:
+                result.append(one_sentence)
+                one_sentence = ''
+            one_sentence += tmp
+            start = end
+        last_sens = org_text[start:]
+        if last_sens:
+            one_sentence += last_sens
+        if one_sentence != '':
+            result.append(one_sentence)
+        return result
@staticmethod
def _is_ffmpeg_installed():
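Note: the new splitter accumulates whole sentences until a batch would exceed max_length, so a single sentence longer than max_length can still come through oversized. A standalone copy for quick experimentation (same logic as above, just a module-level function instead of a static method):

import re

def split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'):
    match = re.compile(pattern)
    tx = match.finditer(org_text)
    start = 0
    result = []
    one_sentence = ''
    for i in tx:
        end = i.regs[0][1]          # end offset of this sentence terminator
        tmp = org_text[start:end]
        if len(one_sentence + tmp) > max_length:
            result.append(one_sentence)
            one_sentence = ''
        one_sentence += tmp
        start = end
    last_sens = org_text[start:]    # trailing text after the last terminator
    if last_sens:
        one_sentence += last_sens
    if one_sentence != '':
        result.append(one_sentence)
    return result

print(split_text_into_sentences('One. Two. Three.', max_length=8))
# -> ['One.', ' Two.', ' Three.']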

View File

@@ -33,3 +33,4 @@
- deepseek
- hunyuan
- siliconflow
- perfxcloud

View File

@@ -4,7 +4,7 @@ from functools import reduce
from io import BytesIO
from typing import Optional
-from flask import Response, stream_with_context
+from flask import Response
from openai import AzureOpenAI
from pydub import AudioSegment
@@ -14,7 +14,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
-from extensions.ext_storage import storage
class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
@@ -23,7 +22,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
"""
def _invoke(self, model: str, tenant_id: str, credentials: dict,
-                content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
+                content_text: str, voice: str, user: Optional[str] = None) -> any:
"""
_invoke text2speech model
@@ -32,30 +31,23 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
-        :param streaming: output is streaming
:param user: unique user id
:return: text translated to audio file
"""
-        audio_type = self._get_model_audio_type(model, credentials)
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
-        if streaming:
-            return Response(stream_with_context(self._tts_invoke_streaming(model=model,
-                                                                           credentials=credentials,
-                                                                           content_text=content_text,
-                                                                           tenant_id=tenant_id,
-                                                                           voice=voice)),
-                            status=200, mimetype=f'audio/{audio_type}')
-        else:
-            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
-    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
+        return self._tts_invoke_streaming(model=model,
+                                          credentials=credentials,
+                                          content_text=content_text,
+                                          voice=voice)
+    def validate_credentials(self, model: str, credentials: dict) -> None:
"""
validate credentials text2speech model
:param model: model name
:param credentials: model credentials
-        :param user: unique user id
:return: text translated to audio file
"""
try:
@@ -82,7 +74,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
word_limit = self._get_model_word_limit(model, credentials)
max_workers = self._get_model_workers_limit(model, credentials)
try:
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
+            sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
audio_bytes_list = []
# Create a thread pool and map the function to the list of sentences
@@ -107,34 +99,37 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
except Exception as ex:
raise InvokeBadRequestError(str(ex))
# Todo: To improve the streaming function
-    def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
+    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
voice: str) -> any:
"""
_tts_invoke_streaming text2speech model
:param model: model name
-        :param tenant_id: user tenant id
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:return: text translated to audio file
"""
-        # transform credentials to kwargs for model instance
-        credentials_kwargs = self._to_credential_kwargs(credentials)
        if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
            voice = self._get_model_default_voice(model, credentials)
-        word_limit = self._get_model_word_limit(model, credentials)
-        audio_type = self._get_model_audio_type(model, credentials)
-        tts_file_id = self._get_file_name(content_text)
-        file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
        try:
            # doc: https://platform.openai.com/docs/guides/text-to-speech
+            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = AzureOpenAI(**credentials_kwargs)
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
-            for sentence in sentences:
-                response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
-                # response.stream_to_file(file_path)
-                storage.save(file_path, response.read())
+            # max font is 4096,there is 3500 limit for each request
+            max_length = 3500
+            if len(content_text) > max_length:
+                sentences = self._split_text_into_sentences(content_text, max_length=max_length)
+                executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
+                futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
+                                           response_format="mp3",
+                                           input=sentences[i], voice=voice) for i in range(len(sentences))]
+                for index, future in enumerate(futures):
+                    yield from future.result().__enter__().iter_bytes(1024)
+            else:
+                response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
+                                                                              response_format="mp3",
+                                                                              input=content_text.strip())
+                yield from response.__enter__().iter_bytes(1024)
except Exception as ex:
raise InvokeBadRequestError(str(ex))
@@ -162,7 +157,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
@staticmethod
-    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None:
for ai_model_entity in TTS_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
@@ -170,5 +165,4 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
return None
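Note: with Response and stream_with_context dropped from this model class, wrapping the byte generator in an HTTP response becomes the caller's job. A hypothetical sketch of what an API layer might do (the endpoint and generator below are ours, not part of this diff):

from flask import Flask, Response, stream_with_context

app = Flask(__name__)

def tts_chunks(text: str):
    # Stand-in for the byte generator now returned by _tts_invoke_streaming.
    for i in range(0, len(text), 4):
        yield text[i:i + 4].encode('utf-8')

@app.route('/tts')
def tts():
    # Streaming the audio over HTTP is now handled at the API layer.
    return Response(stream_with_context(tts_chunks('demo audio bytes')),
                    status=200, mimetype='audio/mp3')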

View File

@@ -5,6 +5,8 @@ model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000

View File

@@ -5,6 +5,8 @@ model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000

View File

@@ -5,6 +5,8 @@ model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000

View File

@@ -5,6 +5,8 @@ model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000

View File

@@ -29,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
@@ -68,7 +69,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
# TODO: consolidate different invocation methods for models based on base model capabilities
# invoke anthropic models via boto3 client
if "anthropic" in model:
-            return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+            return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
# invoke Cohere models via boto3 client
if "cohere.command-r" in model:
return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
@@ -151,7 +152,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+                            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
"""
Invoke Anthropic large language model
@@ -171,23 +172,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
parameters = {
'modelId': model,
'messages': prompt_message_dicts,
'inferenceConfig': inference_config,
'additionalModelRequestFields': additional_model_fields,
}
if system and len(system) > 0:
parameters['system'] = system
if tools:
parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
if stream:
-            response = bedrock_client.converse_stream(
-                modelId=model,
-                messages=prompt_message_dicts,
-                system=system,
-                inferenceConfig=inference_config,
-                additionalModelRequestFields=additional_model_fields
-            )
+            response = bedrock_client.converse_stream(**parameters)
            return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
        else:
-            response = bedrock_client.converse(
-                modelId=model,
-                messages=prompt_message_dicts,
-                system=system,
-                inferenceConfig=inference_config,
-                additionalModelRequestFields=additional_model_fields
-            )
+            response = bedrock_client.converse(**parameters)
return self._handle_converse_response(model, credentials, response, prompt_messages)
def _handle_converse_response(self, model: str, credentials: dict, response: dict,
@@ -246,12 +248,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
output_tokens = 0
finish_reason = None
index = 0
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_use = {}
for chunk in response['stream']:
if 'messageStart' in chunk:
return_model = model
elif 'messageStop' in chunk:
finish_reason = chunk['messageStop']['stopReason']
elif 'contentBlockStart' in chunk:
tool = chunk['contentBlockStart']['start']['toolUse']
tool_use['toolUseId'] = tool['toolUseId']
tool_use['name'] = tool['name']
elif 'metadata' in chunk:
input_tokens = chunk['metadata']['usage']['inputTokens']
output_tokens = chunk['metadata']['usage']['outputTokens']
@@ -260,29 +268,49 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
-                            index=index + 1,
+                            index=index,
message=AssistantPromptMessage(
-                            content=''
+                            content='',
+                            tool_calls=tool_calls
),
finish_reason=finish_reason,
usage=usage
)
)
elif 'contentBlockDelta' in chunk:
-                    chunk_text = chunk['contentBlockDelta']['delta']['text'] if chunk['contentBlockDelta']['delta']['text'] else ''
-                    full_assistant_content += chunk_text
-                    assistant_prompt_message = AssistantPromptMessage(
-                        content=chunk_text if chunk_text else '',
-                    )
-                    index = chunk['contentBlockDelta']['contentBlockIndex']
-                    yield LLMResultChunk(
-                        model=model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=index,
-                            message=assistant_prompt_message,
-                        )
-                    )
+                    delta = chunk['contentBlockDelta']['delta']
+                    if 'text' in delta:
+                        chunk_text = delta['text'] if delta['text'] else ''
+                        full_assistant_content += chunk_text
+                        assistant_prompt_message = AssistantPromptMessage(
+                            content=chunk_text if chunk_text else '',
+                        )
+                        index = chunk['contentBlockDelta']['contentBlockIndex']
+                        yield LLMResultChunk(
+                            model=model,
+                            prompt_messages=prompt_messages,
+                            delta=LLMResultChunkDelta(
+                                index=index+1,
+                                message=assistant_prompt_message,
+                            )
+                        )
+                    elif 'toolUse' in delta:
+                        if 'input' not in tool_use:
+                            tool_use['input'] = ''
+                        tool_use['input'] += delta['toolUse']['input']
+                elif 'contentBlockStop' in chunk:
+                    if 'input' in tool_use:
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=tool_use['toolUseId'],
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=tool_use['name'],
+                                arguments=tool_use['input']
+                            )
+                        )
+                        tool_calls.append(tool_call)
+                        tool_use = {}
except Exception as ex:
raise InvokeError(str(ex))
@@ -312,16 +340,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
"""
system = []
-        first_loop = True
        for message in prompt_messages:
            if isinstance(message, SystemPromptMessage):
                message.content=message.content.strip()
-                if first_loop:
-                    system=message.content
-                    first_loop=False
-                else:
-                    system+="\n"
-                    system+=message.content
+                system.append({"text": message.content})
prompt_message_dicts = []
for message in prompt_messages:
@@ -330,6 +352,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return system, prompt_message_dicts
def _convert_converse_tool_config(self, tools: Optional[list[PromptMessageTool]] = None) -> dict:
tool_config = {}
configs = []
if tools:
for tool in tools:
configs.append(
{
"toolSpec": {
"name": tool.name,
"description": tool.description,
"inputSchema": {
"json": tool.parameters
}
}
}
)
tool_config["tools"] = configs
return tool_config
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict
@@ -379,10 +420,32 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "assistant", "content": [{'text': message.content}]}
+            if message.tool_calls:
+                message_dict = {
+                    "role": "assistant", "content":[{
+                        "toolUse": {
+                            "toolUseId": message.tool_calls[0].id,
+                            "name": message.tool_calls[0].function.name,
+                            "input": json.loads(message.tool_calls[0].function.arguments)
+                        }
+                    }]
+                }
+            else:
+                message_dict = {"role": "assistant", "content": [{'text': message.content}]}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = [{'text': message.content}]
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "user",
"content": [{
"toolResult": {
"toolUseId": message.tool_call_id,
"content": [{"json": {"text": message.content}}]
}
}]
}
else:
raise ValueError(f"Got unknown type {message}")
@@ -401,11 +464,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
"""
prefix = model.split('.')[0]
model_name = model.split('.')[1]
if isinstance(prompt_messages, str):
prompt = prompt_messages
else:
prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)
return self._get_num_tokens_by_gpt2(prompt)
def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -489,11 +554,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
content = message.content
if isinstance(message, UserPromptMessage):
-            message_text = f"{human_prompt_prefix} {content} {human_prompt_postfix}"
+            body = content
+            if (isinstance(content, list)):
+                body = "".join([c.data for c in content if c.type == PromptMessageContentType.TEXT])
+            message_text = f"{human_prompt_prefix} {body} {human_prompt_postfix}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = content
elif isinstance(message, ToolPromptMessage):
message_text = f"{human_prompt_prefix} {message.content}"
else:
raise ValueError(f"Got unknown type {message}")
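Note: in the new stream handler, a tool call arrives across three event types: contentBlockStart carries the tool id and name, toolUse deltas carry JSON argument fragments, and contentBlockStop seals the call. A standalone re-creation of that accumulation, fed fabricated Converse-style events (not real Bedrock output):

import json

events = [
    {'contentBlockStart': {'start': {'toolUse': {'toolUseId': 'tu-1', 'name': 'get_weather'}}}},
    {'contentBlockDelta': {'delta': {'toolUse': {'input': '{"city": '}}, 'contentBlockIndex': 0}},
    {'contentBlockDelta': {'delta': {'toolUse': {'input': '"Paris"}'}}, 'contentBlockIndex': 0}},
    {'contentBlockStop': {}},
]

tool_use, tool_calls = {}, []
for chunk in events:
    if 'contentBlockStart' in chunk:
        tool = chunk['contentBlockStart']['start']['toolUse']
        tool_use['toolUseId'], tool_use['name'] = tool['toolUseId'], tool['name']
    elif 'contentBlockDelta' in chunk:
        delta = chunk['contentBlockDelta']['delta']
        if 'toolUse' in delta:
            # Argument JSON arrives in fragments; concatenate until the block stops.
            tool_use['input'] = tool_use.get('input', '') + delta['toolUse']['input']
    elif 'contentBlockStop' in chunk and 'input' in tool_use:
        tool_calls.append({'id': tool_use['toolUseId'], 'name': tool_use['name'],
                           'arguments': json.loads(tool_use['input'])})
        tool_use = {}

print(tool_calls)  # [{'id': 'tu-1', 'name': 'get_weather', 'arguments': {'city': 'Paris'}}]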

View File

@@ -21,7 +21,7 @@ model_properties:
- mode: 'shimmer'
name: 'Shimmer'
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
-    word_limit: 120
+    word_limit: 3500
audio_type: 'mp3'
max_workers: 5
pricing:

View File

@@ -21,7 +21,7 @@ model_properties:
- mode: 'shimmer'
name: 'Shimmer'
language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
-    word_limit: 120
+    word_limit: 3500
audio_type: 'mp3'
max_workers: 5
pricing:

View File

@@ -3,7 +3,7 @@ from functools import reduce
from io import BytesIO
from typing import Optional
-from flask import Response, stream_with_context
+from flask import Response
from openai import OpenAI
from pydub import AudioSegment
@@ -11,7 +11,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.openai._common import _CommonOpenAI
-from extensions.ext_storage import storage
class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
@@ -20,7 +19,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
"""
def _invoke(self, model: str, tenant_id: str, credentials: dict,
-                content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
+                content_text: str, voice: str, user: Optional[str] = None) -> any:
"""
_invoke text2speech model
@@ -29,22 +28,17 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
-        :param streaming: output is streaming
:param user: unique user id
:return: text translated to audio file
"""
-        audio_type = self._get_model_audio_type(model, credentials)
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
-        if streaming:
-            return Response(stream_with_context(self._tts_invoke_streaming(model=model,
-                                                                           credentials=credentials,
-                                                                           content_text=content_text,
-                                                                           tenant_id=tenant_id,
-                                                                           voice=voice)),
-                            status=200, mimetype=f'audio/{audio_type}')
-        else:
-            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
+        # if streaming:
+        return self._tts_invoke_streaming(model=model,
+                                          credentials=credentials,
+                                          content_text=content_text,
+                                          voice=voice)
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
"""
@@ -79,7 +73,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
word_limit = self._get_model_word_limit(model, credentials)
max_workers = self._get_model_workers_limit(model, credentials)
try:
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
+            sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
audio_bytes_list = []
# Create a thread pool and map the function to the list of sentences
@@ -104,34 +98,40 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
except Exception as ex:
raise InvokeBadRequestError(str(ex))
# Todo: To improve the streaming function
-    def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
+    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
voice: str) -> any:
"""
_tts_invoke_streaming text2speech model
:param model: model name
-        :param tenant_id: user tenant id
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:return: text translated to audio file
"""
-        # transform credentials to kwargs for model instance
-        credentials_kwargs = self._to_credential_kwargs(credentials)
-        if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
-            voice = self._get_model_default_voice(model, credentials)
-        word_limit = self._get_model_word_limit(model, credentials)
-        audio_type = self._get_model_audio_type(model, credentials)
-        tts_file_id = self._get_file_name(content_text)
-        file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
        try:
            # doc: https://platform.openai.com/docs/guides/text-to-speech
+            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = OpenAI(**credentials_kwargs)
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
-            for sentence in sentences:
-                response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
-                # response.stream_to_file(file_path)
-                storage.save(file_path, response.read())
+            if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
+                voice = self._get_model_default_voice(model, credentials)
+            word_limit = self._get_model_word_limit(model, credentials)
+            if len(content_text) > word_limit:
+                sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
+                executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
+                futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
+                                           response_format="mp3",
+                                           input=sentences[i], voice=voice) for i in range(len(sentences))]
+                for index, future in enumerate(futures):
+                    yield from future.result().__enter__().iter_bytes(1024)
+            else:
+                response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
+                                                                              response_format="mp3",
+                                                                              input=content_text.strip())
+                yield from response.__enter__().iter_bytes(1024)
except Exception as ex:
raise InvokeBadRequestError(str(ex))
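Note: both providers now fan out one synthesis request per sentence and drain the futures in submission order, which keeps the audio chunks in sequence while later requests run concurrently. A toy sketch of the pattern (fake_tts stands in for the OpenAI streaming call):

import concurrent.futures

def fake_tts(sentence: str) -> list[bytes]:
    # Stand-in for client.audio.speech.with_streaming_response.create(...):
    # "synthesizes" one sentence into a list of byte chunks.
    return [sentence.encode('utf-8')]

def stream_in_order(sentences: list[str]):
    # Fan out one request per sentence, then drain the futures in submission
    # order so the output stays in sequence while later requests overlap.
    with concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) as executor:
        futures = [executor.submit(fake_tts, s) for s in sentences]
        for future in futures:
            yield from future.result()

print(b''.join(stream_in_order(['Hello. ', 'World.'])))  # b'Hello. World.'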

(Two image files changed, 21 KiB and 48 KiB; their text diffs are suppressed because of overlong lines.)
View File

@@ -0,0 +1,61 @@
model: Qwen-14B-Chat-Int4
label:
en_US: Qwen-14B-Chat-Int4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the degree of repetition in generation. Increasing repetition_penalty reduces the repetition in the model's output. 1.0 means no penalty.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

View File

@@ -0,0 +1,61 @@
model: Qwen1.5-110B-Chat-GPTQ-Int4
label:
en_US: Qwen1.5-110B-Chat-GPTQ-Int4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 128
min: 1
max: 256
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the degree of repetition in generation. Increasing repetition_penalty reduces the repetition in the model's output. 1.0 means no penalty.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

View File

@@ -0,0 +1,61 @@
model: Qwen1.5-72B-Chat-GPTQ-Int4
label:
en_US: Qwen1.5-72B-Chat-GPTQ-Int4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 2000
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the degree of repetition in generation. Increasing repetition_penalty reduces the repetition in the model's output. 1.0 means no penalty.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

Some files were not shown because too many files have changed in this diff.