Mirror of https://github.com/langgenius/dify.git (synced 2026-01-07 14:58:32 +00:00)
Compare commits
164 Commits
| SHA1 |
|---|
| f101956d0f |
| 5c5221f2cc |
| ef7e47d162 |
| 4211b9abbd |
| 0c0120ef27 |
| dacd457478 |
| 7b03a0316d |
| 52201d95b1 |
| e2cde628bb |
| 3335fa78fc |
| 7abc7fa573 |
| f6247fe67c |
| 996a9135f6 |
| 3599751f93 |
| 2d186e1e76 |
| bb2f46d7cc |
| 463fbe2680 |
| 95a7e50137 |
| 9d93ad1f16 |
| 44104797d6 |
| 1548501050 |
| de3911e930 |
| 5a8a901560 |
| 12d45e9114 |
| d057067543 |
| 560d375e0f |
| 3388d6636c |
| 2624a6dcd0 |
| b5c2785e10 |
| 493834d45d |
| 926546b153 |
| 56434db4f5 |
| 688292e6ff |
| f7415e1ca4 |
| 2961fa0e08 |
| ad17ff9a92 |
| 558ab25f51 |
| a5db7c9acb |
| 580297e290 |
| 79d11ea709 |
| 99f40a9682 |
| e86756cb39 |
| 1325246da8 |
| dfa9a91906 |
| 5e2926a207 |
| 9048832a9a |
| 7d5a385811 |
| 900e93f758 |
| 99430a5931 |
| c9b4029ce7 |
| 78c3051585 |
| cd4310df25 |
| 259cff9f22 |
| 7b7eb00385 |
| 62b9e5a6f9 |
| a399502ecd |
| 92a840f1b2 |
| 74fdc16bd1 |
| 56cfdce453 |
| efa8eb379f |
| 7f095bdc42 |
| e20161b3de |
| fc8fdbacb4 |
| 7fde638556 |
| be93c19b7e |
| 967eb81112 |
| 9f602f73eb |
| 41de7e76ec |
| 607a22ad12 |
| 4b402c4041 |
| daccb10d8c |
| 63f1dd7877 |
| 79801f5c30 |
| 9c7a1bc067 |
| cf0ff88120 |
| e0b67536e0 |
| 94c7dcc7f1 |
| 38e155d819 |
| efd5575683 |
| 1a7c213405 |
| 8e3d60c359 |
| 924b4fe742 |
| 7e154a467b |
| b90f1581be |
| 821992e21f |
| f0c0ce9db1 |
| 8ecb9aaa91 |
| 22258fb0bf |
| a725b8bb6e |
| bdfdccd511 |
| 194bc60429 |
| 430ca3322b |
| 3d803c2e80 |
| fa3dcbb3bc |
| ee342063d8 |
| bb3bc60f83 |
| e7a4cfac4d |
| 6478aa1c9d |
| 7b5839335a |
| a360af8687 |
| 36cb25b341 |
| e565ecdaef |
| f96fdc2970 |
| 0d04cdc323 |
| 926f604f09 |
| 180743612c |
| d05f189049 |
| ceaa9f1101 |
| 6f4cbe0bde |
| 8d4bb9b40d |
| 1765fe2a29 |
| 79a710ce98 |
| bec5451f12 |
| 86dfdcb8ec |
| 42d986b96d |
| fbc4ca980c |
| 80c52e0ea4 |
| 50b76dd5a2 |
| 225fcd5e41 |
| afffd345bc |
| 716576043d |
| 28231d39a4 |
| 9e23c3d625 |
| bdd5869244 |
| fc1415d705 |
| 8218f62478 |
| fd354d999d |
| ec00b25793 |
| 967b7d89e3 |
| 32f8439143 |
| 0ff8bd2aa9 |
| 2866383228 |
| 00ac7edeb3 |
| 537068cfde |
| c3c6a48059 |
| 5c166b3f40 |
| 230fa3286b |
| 061c0b10fd |
| 32f8a98cf8 |
| 6c60ecb237 |
| c3fae5e801 |
| a594e256ae |
| 41d90c2408 |
| 7ff42b1b7a |
| 4d7cfd0de5 |
| 266d32bd77 |
| 7e1184c071 |
| 1ce51e57ab |
| 142b4fd699 |
| cc8feaa483 |
| d9d5d35a77 |
| 9277156b6c |
| 1490a19fa1 |
| 9b7adcd4d9 |
| a8d32f9964 |
| 5093337de1 |
| f54225568c |
| 255ff446ba |
| 9a0dc4bfdc |
| 9d975750bc |
| 7c979e6490 |
| d60ca1661c |
| bb62391a4c |
| 0b25c0b677 |
@@ -7,5 +7,6 @@ echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run
 echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
 echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
 echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
+echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc
 
 source /home/vscode/.bashrc
.github/pull_request_template.md (13 changes, vendored)

@@ -8,16 +8,9 @@ Please include a summary of the change and which issue is fixed. Please also inc
 
 # Screenshots
 
-<table>
-<tr>
-<td>Before: </td>
-<td>After: </td>
-</tr>
-<tr>
-<td>...</td>
-<td>...</td>
-</table>
+| Before | After |
+|--------|-------|
+| ... | ... |
 
 # Checklist
 
.github/workflows/api-tests.yml (3 changes, vendored)

@@ -50,6 +50,9 @@ jobs:
       - name: Run ModelRuntime
         run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
 
+      - name: Run dify config tests
+        run: poetry run -C api python dev/pytest/pytest_config_tests.py
+
       - name: Run Tool
         run: poetry run -C api bash dev/pytest/pytest_tools.sh
.github/workflows/expose_service_ports.sh (3 changes, vendored)

@@ -9,5 +9,6 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compos
 yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
 yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
 yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
+yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml
 
-echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
+echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
.github/workflows/style.yml (1 change, vendored)

@@ -37,6 +37,7 @@ jobs:
       - name: Ruff check
         if: steps.changed-files.outputs.any_changed == 'true'
         run: |
+          poetry run -C api ruff --version
           poetry run -C api ruff check ./api
           poetry run -C api ruff format --check ./api
.github/workflows/vdb-tests.yml (3 changes, vendored)

@@ -51,7 +51,7 @@ jobs:
       - name: Expose Service Ports
         run: sh .github/workflows/expose_service_ports.sh
 
-      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
+      - name: Set up Vector Stores (TiDB, Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
         uses: hoverkraft-tech/compose-action@v2.0.2
         with:
           compose-file: |
@@ -67,6 +67,7 @@ jobs:
             pgvector
             chroma
             elasticsearch
+            tidb
 
       - name: Test Vector Stores
         run: poetry run -C api bash dev/pytest/pytest_vdb.sh
@@ -56,20 +56,27 @@ DB_DATABASE=dify
 
 # Storage configuration
 # use for store upload files, private keys...
-# storage type: local, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
-STORAGE_TYPE=local
-STORAGE_LOCAL_PATH=storage
+# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
+STORAGE_TYPE=opendal
+
+# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
+OPENDAL_SCHEME=fs
+OPENDAL_FS_ROOT=storage
 
 # S3 Storage configuration
 S3_USE_AWS_MANAGED_IAM=false
 S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com
 S3_BUCKET_NAME=your-bucket-name
 S3_ACCESS_KEY=your-access-key
 S3_SECRET_KEY=your-secret-key
 S3_REGION=your-region
 
 # Azure Blob Storage configuration
 AZURE_BLOB_ACCOUNT_NAME=your-account-name
 AZURE_BLOB_ACCOUNT_KEY=your-account-key
 AZURE_BLOB_CONTAINER_NAME=yout-container-name
 AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
 
 # Aliyun oss Storage configuration
 ALIYUN_OSS_BUCKET_NAME=your-bucket-name
 ALIYUN_OSS_ACCESS_KEY=your-access-key
@@ -79,6 +86,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
 ALIYUN_OSS_REGION=your-region
+# Don't start with '/'. OSS doesn't support leading slash in object names.
 ALIYUN_OSS_PATH=your-path
 
 # Google Storage configuration
 GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
 GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
@@ -125,8 +133,8 @@ SUPABASE_URL=your-server-url
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
-
-# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
+# Vector database configuration
+# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
 VECTOR_STORE=weaviate
 
 # Weaviate configuration
@@ -277,6 +285,7 @@ VIKINGDB_SOCKET_TIMEOUT=30
 LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
 LINDORM_USERNAME=admin
 LINDORM_PASSWORD=admin
+USING_UGC_INDEX=False
 
 # OceanBase Vector configuration
 OCEANBASE_VECTOR_HOST=127.0.0.1
@@ -295,8 +304,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 
 # Model configuration
-MULTIMODAL_SEND_IMAGE_FORMAT=base64
-MULTIMODAL_SEND_VIDEO_FORMAT=base64
+MULTIMODAL_SEND_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
 CODE_GENERATION_MAX_TOKENS=1024
 
@@ -381,6 +389,8 @@ LOG_FILE_BACKUP_COUNT=5
 LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
+# Log Timezone
+LOG_TZ=UTC
 # Log format
 LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
 
 # Indexing configuration
 INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
@@ -389,6 +399,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
 WORKFLOW_MAX_EXECUTION_STEPS=500
 WORKFLOW_MAX_EXECUTION_TIME=1200
 WORKFLOW_CALL_MAX_DEPTH=5
+WORKFLOW_PARALLEL_DEPTH_LIMIT=3
 MAX_VARIABLE_SIZE=204800
 
 # App configuration
@@ -413,3 +424,7 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
 
 CREATE_TIDB_SERVICE_JOB_ENABLED=false
 
+# Maximum number of submitted thread count in a ThreadPool for parallel node execution
+MAX_SUBMIT_COUNT=100
+# Lockout duration in seconds
+LOGIN_LOCKOUT_DURATION=86400
@@ -70,7 +70,6 @@ ignore = [
     "SIM113", # eumerate-for-loop
     "SIM117", # multiple-with-statements
     "SIM210", # if-expr-with-true-false
-    "SIM300", # yoda-conditions,
 ]
 
 [lint.per-file-ignores]
api/app.py (27 changes)

@@ -1,13 +1,30 @@
-from app_factory import create_app
-from libs import threadings_utils, version_utils
+from libs import version_utils
 
 # preparation before creating app
 version_utils.check_supported_python_version()
-threadings_utils.apply_gevent_threading_patch()
+
+
+def is_db_command():
+    import sys
+
+    if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
+        return True
+    return False
+
 
 # create app
-app = create_app()
-celery = app.extensions["celery"]
+if is_db_command():
+    from app_factory import create_migrations_app
+
+    app = create_migrations_app()
+else:
+    from app_factory import create_app
+    from libs import threadings_utils
+
+    threadings_utils.apply_gevent_threading_patch()
+
+    app = create_app()
+    celery = app.extensions["celery"]
 
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=5001)
@@ -1,5 +1,4 @@
 import logging
-import os
 import time
 
 from configs import dify_config
@@ -17,15 +16,6 @@ def create_flask_app_with_configs() -> DifyApp:
     dify_app = DifyApp(__name__)
     dify_app.config.from_mapping(dify_config.model_dump())
 
-    # populate configs into system environment variables
-    for key, value in dify_app.config.items():
-        if isinstance(value, str):
-            os.environ[key] = value
-        elif isinstance(value, int | float | bool):
-            os.environ[key] = str(value)
-        elif value is None:
-            os.environ[key] = ""
-
     return dify_app
 
 
@@ -98,3 +88,14 @@ def initialize_extensions(app: DifyApp):
     end_time = time.perf_counter()
     if dify_config.DEBUG:
         logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)")
+
+
+def create_migrations_app():
+    app = create_flask_app_with_configs()
+    from extensions import ext_database, ext_migrate
+
+    # Initialize only required extensions
+    ext_database.init_app(app)
+    ext_migrate.init_app(app)
+
+    return app
@@ -1,11 +1,51 @@
-from pydantic_settings import SettingsConfigDict
+import logging
+from typing import Any
 
-from configs.deploy import DeploymentConfig
-from configs.enterprise import EnterpriseFeatureConfig
-from configs.extra import ExtraServiceConfig
-from configs.feature import FeatureConfig
-from configs.middleware import MiddlewareConfig
-from configs.packaging import PackagingInfo
+from pydantic.fields import FieldInfo
+from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
+
+from .deploy import DeploymentConfig
+from .enterprise import EnterpriseFeatureConfig
+from .extra import ExtraServiceConfig
+from .feature import FeatureConfig
+from .middleware import MiddlewareConfig
+from .packaging import PackagingInfo
+from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
+from .remote_settings_sources.apollo import ApolloSettingsSource
+
+logger = logging.getLogger(__name__)
+
+
+class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
+    def __init__(self, settings_cls: type[BaseSettings]):
+        super().__init__(settings_cls)
+
+    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
+        raise NotImplementedError
+
+    def __call__(self) -> dict[str, Any]:
+        current_state = self.current_state
+        remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")
+        if not remote_source_name:
+            return {}
+
+        remote_source: RemoteSettingsSource | None = None
+        match remote_source_name:
+            case RemoteSettingsSourceName.APOLLO:
+                remote_source = ApolloSettingsSource(current_state)
+            case _:
+                logger.warning(f"Unsupported remote source: {remote_source_name}")
+                return {}
+
+        d: dict[str, Any] = {}
+
+        for field_name, field in self.settings_cls.model_fields.items():
+            field_value, field_key, value_is_complex = remote_source.get_field_value(field, field_name)
+            field_value = remote_source.prepare_field_value(field_name, field, field_value, value_is_complex)
+            if field_value is not None:
+                d[field_key] = field_value
+
+        return d
 
 
 class DifyConfig(
@@ -19,6 +59,8 @@ class DifyConfig(
     MiddlewareConfig,
     # Extra service configs
     ExtraServiceConfig,
+    # Remote source configs
+    RemoteSettingsSourceConfig,
     # Enterprise feature configs
     # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
     EnterpriseFeatureConfig,
@@ -35,3 +77,20 @@ class DifyConfig(
     # please consider to arrange it in the proper config group of existed or added
     # for better readability and maintainability.
     # Thanks for your concentration and consideration.
+
+    @classmethod
+    def settings_customise_sources(
+        cls,
+        settings_cls: type[BaseSettings],
+        init_settings: PydanticBaseSettingsSource,
+        env_settings: PydanticBaseSettingsSource,
+        dotenv_settings: PydanticBaseSettingsSource,
+        file_secret_settings: PydanticBaseSettingsSource,
+    ) -> tuple[PydanticBaseSettingsSource, ...]:
+        return (
+            init_settings,
+            env_settings,
+            RemoteSettingsSourceFactory(settings_cls),
+            dotenv_settings,
+            file_secret_settings,
+        )
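
For context on the ordering above: pydantic-settings calls each source in the returned tuple and earlier sources take precedence, so placing RemoteSettingsSourceFactory between env_settings and dotenv_settings means real environment variables still override Apollo, while Apollo overrides .env files. A minimal sketch of the same mechanism with a stand-in source (HardcodedRemoteSource and DemoSettings are hypothetical names for this illustration, not part of Dify):

from typing import Any

from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource


class HardcodedRemoteSource(PydanticBaseSettingsSource):
    # Stand-in for a remote store: returns a fixed mapping, as the factory's __call__ does.
    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        raise NotImplementedError  # unused because __call__ supplies the whole dict

    def __call__(self) -> dict[str, Any]:
        return {"GREETING": "from-remote"}


class DemoSettings(BaseSettings):
    GREETING: str = "default"

    @classmethod
    def settings_customise_sources(cls, settings_cls, init_settings, env_settings, dotenv_settings, file_secret_settings):
        # Same ordering as DifyConfig: env vars beat the remote source, which beats dotenv.
        return (init_settings, env_settings, HardcodedRemoteSource(settings_cls), dotenv_settings, file_secret_settings)


print(DemoSettings().GREETING)  # "from-remote" unless GREETING is set in the environment
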
@@ -433,12 +433,28 @@ class WorkflowConfig(BaseSettings):
         default=5,
     )
 
+    WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
+        description="Maximum allowed depth for nested parallel executions",
+        default=3,
+    )
+
     MAX_VARIABLE_SIZE: PositiveInt = Field(
         description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
         default=200 * 1024,
     )
 
 
+class WorkflowNodeExecutionConfig(BaseSettings):
+    """
+    Configuration for workflow node execution
+    """
+
+    MAX_SUBMIT_COUNT: PositiveInt = Field(
+        description="Maximum number of submitted thread count in a ThreadPool for parallel node execution",
+        default=100,
+    )
+
+
 class AuthConfig(BaseSettings):
     """
     Configuration for authentication and OAuth
@@ -474,6 +490,11 @@ class AuthConfig(BaseSettings):
         default=60,
     )
 
+    LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
+        description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
+        default=86400,
+    )
+
 
 class ModerationConfig(BaseSettings):
     """
@@ -649,14 +670,9 @@ class IndexingConfig(BaseSettings):
     )
 
 
-class VisionFormatConfig(BaseSettings):
-    MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
-        description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
-        default="base64",
-    )
-
-    MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
-        description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
+class MultiModalTransferConfig(BaseSettings):
+    MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
+        description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
         default="base64",
     )
 
@@ -762,19 +778,20 @@ class FeatureConfig(
     FileAccessConfig,
     FileUploadConfig,
     HttpConfig,
-    VisionFormatConfig,
     InnerAPIConfig,
     IndexingConfig,
     LoggingConfig,
     MailConfig,
     ModelLoadBalanceConfig,
     ModerationConfig,
+    MultiModalTransferConfig,
     PositionConfig,
     RagEtlConfig,
     SecurityConfig,
     ToolConfig,
     UpdateConfig,
     WorkflowConfig,
+    WorkflowNodeExecutionConfig,
     WorkspaceConfig,
     LoginConfig,
     # hosted services config
@@ -1,54 +1,69 @@
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 from urllib.parse import quote_plus
 
 from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
 from pydantic_settings import BaseSettings
 
-from configs.middleware.cache.redis_config import RedisConfig
-from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
-from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
-from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
-from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageConfig
-from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
-from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
-from configs.middleware.storage.oci_storage_config import OCIStorageConfig
-from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig
-from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
-from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
-from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
-from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
-from configs.middleware.vdb.chroma_config import ChromaConfig
-from configs.middleware.vdb.couchbase_config import CouchbaseConfig
-from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
-from configs.middleware.vdb.lindorm_config import LindormConfig
-from configs.middleware.vdb.milvus_config import MilvusConfig
-from configs.middleware.vdb.myscale_config import MyScaleConfig
-from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
-from configs.middleware.vdb.opensearch_config import OpenSearchConfig
-from configs.middleware.vdb.oracle_config import OracleConfig
-from configs.middleware.vdb.pgvector_config import PGVectorConfig
-from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
-from configs.middleware.vdb.qdrant_config import QdrantConfig
-from configs.middleware.vdb.relyt_config import RelytConfig
-from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
-from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
-from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
-from configs.middleware.vdb.upstash_config import UpstashConfig
-from configs.middleware.vdb.vikingdb_config import VikingDBConfig
-from configs.middleware.vdb.weaviate_config import WeaviateConfig
+from .cache.redis_config import RedisConfig
+from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
+from .storage.amazon_s3_storage_config import S3StorageConfig
+from .storage.azure_blob_storage_config import AzureBlobStorageConfig
+from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
+from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
+from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
+from .storage.oci_storage_config import OCIStorageConfig
+from .storage.opendal_storage_config import OpenDALStorageConfig
+from .storage.supabase_storage_config import SupabaseStorageConfig
+from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
+from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
+from .vdb.analyticdb_config import AnalyticdbConfig
+from .vdb.baidu_vector_config import BaiduVectorDBConfig
+from .vdb.chroma_config import ChromaConfig
+from .vdb.couchbase_config import CouchbaseConfig
+from .vdb.elasticsearch_config import ElasticsearchConfig
+from .vdb.lindorm_config import LindormConfig
+from .vdb.milvus_config import MilvusConfig
+from .vdb.myscale_config import MyScaleConfig
+from .vdb.oceanbase_config import OceanBaseVectorConfig
+from .vdb.opensearch_config import OpenSearchConfig
+from .vdb.oracle_config import OracleConfig
+from .vdb.pgvector_config import PGVectorConfig
+from .vdb.pgvectors_config import PGVectoRSConfig
+from .vdb.qdrant_config import QdrantConfig
+from .vdb.relyt_config import RelytConfig
+from .vdb.tencent_vector_config import TencentVectorDBConfig
+from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
+from .vdb.tidb_vector_config import TiDBVectorConfig
+from .vdb.upstash_config import UpstashConfig
+from .vdb.vikingdb_config import VikingDBConfig
+from .vdb.weaviate_config import WeaviateConfig
 
 
 class StorageConfig(BaseSettings):
-    STORAGE_TYPE: str = Field(
+    STORAGE_TYPE: Literal[
+        "opendal",
+        "s3",
+        "aliyun-oss",
+        "azure-blob",
+        "baidu-obs",
+        "google-storage",
+        "huawei-obs",
+        "oci-storage",
+        "tencent-cos",
+        "volcengine-tos",
+        "supabase",
+        "local",
+    ] = Field(
         description="Type of storage to use."
-        " Options: 'local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', 'huawei-obs', "
-        "'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'local'.",
-        default="local",
+        " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', "
+        "'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
+        default="opendal",
     )
 
     STORAGE_LOCAL_PATH: str = Field(
         description="Path for local storage when STORAGE_TYPE is set to 'local'.",
         default="storage",
+        deprecated=True,
     )
@@ -73,7 +88,7 @@ class KeywordStoreConfig(BaseSettings):
     )
 
 
-class DatabaseConfig:
+class DatabaseConfig(BaseSettings):
     DB_HOST: str = Field(
         description="Hostname or IP address of the database server.",
         default="localhost",
@@ -235,6 +250,7 @@ class MiddlewareConfig(
     GoogleCloudStorageConfig,
     HuaweiCloudOBSStorageConfig,
     OCIStorageConfig,
+    OpenDALStorageConfig,
     S3StorageConfig,
     SupabaseStorageConfig,
     TencentCloudCOSStorageConfig,
@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
 
-class BaiduOBSStorageConfig(BaseModel):
+class BaiduOBSStorageConfig(BaseSettings):
     """
     Configuration settings for Baidu Object Storage Service (OBS)
     """
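
The BaseModel to BaseSettings switch here, repeated for the storage and vector-database configs that follow, is what makes these fields readable from environment variables: pydantic's BaseModel never consults the environment, while pydantic-settings' BaseSettings does. A minimal illustration (the demo class names are made up for this sketch):

import os

from pydantic import BaseModel
from pydantic_settings import BaseSettings

os.environ["BUCKET"] = "from-env"


class ModelDemo(BaseModel):
    BUCKET: str = "default"


class SettingsDemo(BaseSettings):
    BUCKET: str = "default"


print(ModelDemo().BUCKET)     # "default": BaseModel ignores environment variables
print(SettingsDemo().BUCKET)  # "from-env": BaseSettings reads them
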
@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
 
-class HuaweiCloudOBSStorageConfig(BaseModel):
+class HuaweiCloudOBSStorageConfig(BaseSettings):
     """
     Configuration settings for Huawei Cloud Object Storage Service (OBS)
     """
api/configs/middleware/storage/opendal_storage_config.py (9 lines, new file)

@@ -0,0 +1,9 @@
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+
+class OpenDALStorageConfig(BaseSettings):
+    OPENDAL_SCHEME: str = Field(
+        default="fs",
+        description="OpenDAL scheme.",
+    )
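
For reference, the OPENDAL_SCHEME value feeds Apache OpenDAL's operator factory. A minimal sketch with the opendal Python binding, assuming the fs scheme and the storage root that the new .env defaults use (the paths here are illustrative only):

import opendal

# OPENDAL_SCHEME=fs plus OPENDAL_FS_ROOT=storage roughly translate to:
op = opendal.Operator("fs", root="storage")
op.write("upload_files/hello.txt", b"hello dify")
print(op.read("upload_files/hello.txt"))  # b"hello dify"
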
@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
 
-class SupabaseStorageConfig(BaseModel):
+class SupabaseStorageConfig(BaseSettings):
     """
     Configuration settings for Supabase Object Storage Service
     """
@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
 
-class VolcengineTOSStorageConfig(BaseModel):
+class VolcengineTOSStorageConfig(BaseSettings):
     """
     Configuration settings for Volcengine Tinder Object Storage (TOS)
     """
@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings
 
 
-class AnalyticdbConfig(BaseModel):
+class AnalyticdbConfig(BaseSettings):
     """
     Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
     Refer to the following documentation for details on obtaining credentials:
@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
 
-class CouchbaseConfig(BaseModel):
+class CouchbaseConfig(BaseSettings):
     """
     Couchbase configs
     """
@@ -21,3 +21,14 @@ class LindormConfig(BaseSettings):
         description="Lindorm password",
         default=None,
     )
+    DEFAULT_INDEX_TYPE: Optional[str] = Field(
+        description="Lindorm Vector Index Type, hnsw or flat is available in dify",
+        default="hnsw",
+    )
+    DEFAULT_DISTANCE_TYPE: Optional[str] = Field(
+        description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
+    )
+    USING_UGC_INDEX: Optional[bool] = Field(
+        description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
+        default=False,
+    )
@@ -1,7 +1,8 @@
-from pydantic import BaseModel, Field, PositiveInt
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings
 
 
-class MyScaleConfig(BaseModel):
+class MyScaleConfig(BaseSettings):
     """
     Configuration settings for MyScale vector database
     """
@@ -1,9 +1,10 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
 
-class VikingDBConfig(BaseModel):
+class VikingDBConfig(BaseSettings):
     """
     Configuration for connecting to Volcengine VikingDB.
     Refer to the following documentation for details on obtaining credentials:
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.13.1",
+        default="0.14.1",
     )
 
     COMMIT_SHA: str = Field(
api/configs/remote_settings_sources/__init__.py (17 lines, new file)

@@ -0,0 +1,17 @@
+from typing import Optional
+
+from pydantic import Field
+
+from .apollo import ApolloSettingsSourceInfo
+from .base import RemoteSettingsSource
+from .enums import RemoteSettingsSourceName
+
+
+class RemoteSettingsSourceConfig(ApolloSettingsSourceInfo):
+    REMOTE_SETTINGS_SOURCE_NAME: RemoteSettingsSourceName | str = Field(
+        description="name of remote config source",
+        default="",
+    )
+
+
+__all__ = ["RemoteSettingsSource", "RemoteSettingsSourceConfig", "RemoteSettingsSourceName"]
api/configs/remote_settings_sources/apollo/__init__.py (55 lines, new file)

@@ -0,0 +1,55 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from pydantic import Field
+from pydantic.fields import FieldInfo
+from pydantic_settings import BaseSettings
+
+from configs.remote_settings_sources.base import RemoteSettingsSource
+
+from .client import ApolloClient
+
+
+class ApolloSettingsSourceInfo(BaseSettings):
+    """
+    Packaging build information
+    """
+
+    APOLLO_APP_ID: Optional[str] = Field(
+        description="apollo app_id",
+        default=None,
+    )
+
+    APOLLO_CLUSTER: Optional[str] = Field(
+        description="apollo cluster",
+        default=None,
+    )
+
+    APOLLO_CONFIG_URL: Optional[str] = Field(
+        description="apollo config url",
+        default=None,
+    )
+
+    APOLLO_NAMESPACE: Optional[str] = Field(
+        description="apollo namespace",
+        default=None,
+    )
+
+
+class ApolloSettingsSource(RemoteSettingsSource):
+    def __init__(self, configs: Mapping[str, Any]):
+        self.client = ApolloClient(
+            app_id=configs["APOLLO_APP_ID"],
+            cluster=configs["APOLLO_CLUSTER"],
+            config_url=configs["APOLLO_CONFIG_URL"],
+            start_hot_update=False,
+            _notification_map={configs["APOLLO_NAMESPACE"]: -1},
+        )
+        self.namespace = configs["APOLLO_NAMESPACE"]
+        self.remote_configs = self.client.get_all_dicts(self.namespace)
+
+    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
+        if not isinstance(self.remote_configs, dict):
+            raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
+        field_value = self.remote_configs.get(field_name)
+        return field_value, field_name, False
api/configs/remote_settings_sources/apollo/client.py (303 lines, new file)

@@ -0,0 +1,303 @@
+import hashlib
+import json
+import logging
+import os
+import threading
+import time
+from pathlib import Path
+
+from .python_3x import http_request, makedirs_wrapper
+from .utils import (
+    CONFIGURATIONS,
+    NAMESPACE_NAME,
+    NOTIFICATION_ID,
+    get_value_from_dict,
+    init_ip,
+    no_key_cache_key,
+    signature,
+    url_encode_wrapper,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ApolloClient:
+    def __init__(
+        self,
+        config_url,
+        app_id,
+        cluster="default",
+        secret="",
+        start_hot_update=True,
+        change_listener=None,
+        _notification_map=None,
+    ):
+        # Core routing parameters
+        self.config_url = config_url
+        self.cluster = cluster
+        self.app_id = app_id
+
+        # Non-core parameters
+        self.ip = init_ip()
+        self.secret = secret
+
+        # Check the parameter variables
+
+        # Private control variables
+        self._cycle_time = 5
+        self._stopping = False
+        self._cache = {}
+        self._no_key = {}
+        self._hash = {}
+        self._pull_timeout = 75
+        self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
+        self._long_poll_thread = None
+        self._change_listener = change_listener  # "add" "delete" "update"
+        if _notification_map is None:
+            _notification_map = {"application": -1}
+        self._notification_map = _notification_map
+        self.last_release_key = None
+        # Private startup method
+        self._path_checker()
+        if start_hot_update:
+            self._start_hot_update()
+
+        # start the heartbeat thread
+        heartbeat = threading.Thread(target=self._heart_beat)
+        heartbeat.daemon = True
+        heartbeat.start()
+
+    def get_json_from_net(self, namespace="application"):
+        url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
+            self.config_url, self.app_id, self.cluster, namespace, "", self.ip
+        )
+        try:
+            code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
+            if code == 200:
+                if not body:
+                    logger.error(f"get_json_from_net load configs failed, body is {body}")
+                    return None
+                data = json.loads(body)
+                data = data["configurations"]
+                return_data = {CONFIGURATIONS: data}
+                return return_data
+            else:
+                return None
+        except Exception:
+            logger.exception("an error occurred in get_json_from_net")
+            return None
+
+    def get_value(self, key, default_val=None, namespace="application"):
+        try:
+            # read memory configuration
+            namespace_cache = self._cache.get(namespace)
+            val = get_value_from_dict(namespace_cache, key)
+            if val is not None:
+                return val
+
+            no_key = no_key_cache_key(namespace, key)
+            if no_key in self._no_key:
+                return default_val
+
+            # read the network configuration
+            namespace_data = self.get_json_from_net(namespace)
+            val = get_value_from_dict(namespace_data, key)
+            if val is not None:
+                self._update_cache_and_file(namespace_data, namespace)
+                return val
+
+            # read the file configuration
+            namespace_cache = self._get_local_cache(namespace)
+            val = get_value_from_dict(namespace_cache, key)
+            if val is not None:
+                self._update_cache_and_file(namespace_cache, namespace)
+                return val
+
+            # If all of them are not obtained, the default value is returned
+            # and the local cache is set to None
+            self._set_local_cache_none(namespace, key)
+            return default_val
+        except Exception:
+            logger.exception("get_value has error, [key is %s], [namespace is %s]", key, namespace)
+            return default_val
+
+    # Set the key of a namespace to none, and do not set default val
+    # to ensure the real-time correctness of the function call.
+    # If the user does not have the same default val twice
+    # and the default val is used here, there may be a problem.
+    def _set_local_cache_none(self, namespace, key):
+        no_key = no_key_cache_key(namespace, key)
+        self._no_key[no_key] = key
+
+    def _start_hot_update(self):
+        self._long_poll_thread = threading.Thread(target=self._listener)
+        # When the asynchronous thread is started, the daemon thread will automatically exit
+        # when the main thread is launched.
+        self._long_poll_thread.daemon = True
+        self._long_poll_thread.start()
+
+    def stop(self):
+        self._stopping = True
+        logger.info("Stopping listener...")
+
+    # Call the set callback function, and if it is abnormal, try it out
+    def _call_listener(self, namespace, old_kv, new_kv):
+        if self._change_listener is None:
+            return
+        if old_kv is None:
+            old_kv = {}
+        if new_kv is None:
+            new_kv = {}
+        try:
+            for key in old_kv:
+                new_value = new_kv.get(key)
+                old_value = old_kv.get(key)
+                if new_value is None:
+                    # If newValue is empty, it means key, and the value is deleted.
+                    self._change_listener("delete", namespace, key, old_value)
+                    continue
+                if new_value != old_value:
+                    self._change_listener("update", namespace, key, new_value)
+                    continue
+            for key in new_kv:
+                new_value = new_kv.get(key)
+                old_value = old_kv.get(key)
+                if old_value is None:
+                    self._change_listener("add", namespace, key, new_value)
+        except BaseException as e:
+            logger.warning(str(e))
+
+    def _path_checker(self):
+        if not os.path.isdir(self._cache_file_path):
+            makedirs_wrapper(self._cache_file_path)
+
+    # update the local cache and file cache
+    def _update_cache_and_file(self, namespace_data, namespace="application"):
+        # update the local cache
+        self._cache[namespace] = namespace_data
+        # update the file cache
+        new_string = json.dumps(namespace_data)
+        new_hash = hashlib.md5(new_string.encode("utf-8")).hexdigest()
+        if self._hash.get(namespace) == new_hash:
+            pass
+        else:
+            file_path = Path(self._cache_file_path) / f"{self.app_id}_configuration_{namespace}.txt"
+            file_path.write_text(new_string)
+            self._hash[namespace] = new_hash
+
+    # get the configuration from the local file
+    def _get_local_cache(self, namespace="application"):
+        cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
+        if os.path.isfile(cache_file_path):
+            with open(cache_file_path) as f:
+                result = json.loads(f.readline())
+            return result
+        return {}
+
+    def _long_poll(self):
+        notifications = []
+        for key in self._cache:
+            namespace_data = self._cache[key]
+            notification_id = -1
+            if NOTIFICATION_ID in namespace_data:
+                notification_id = self._cache[key][NOTIFICATION_ID]
+            notifications.append({NAMESPACE_NAME: key, NOTIFICATION_ID: notification_id})
+        try:
+            # if the length is 0 it is returned directly
+            if len(notifications) == 0:
+                return
+            url = "{}/notifications/v2".format(self.config_url)
+            params = {
+                "appId": self.app_id,
+                "cluster": self.cluster,
+                "notifications": json.dumps(notifications, ensure_ascii=False),
+            }
+            param_str = url_encode_wrapper(params)
+            url = url + "?" + param_str
+            code, body = http_request(url, self._pull_timeout, headers=self._sign_headers(url))
+            http_code = code
+            if http_code == 304:
+                logger.debug("No change, loop...")
+                return
+            if http_code == 200:
+                if not body:
+                    logger.error(f"_long_poll load configs failed,body is {body}")
+                    return
+                data = json.loads(body)
+                for entry in data:
+                    namespace = entry[NAMESPACE_NAME]
+                    n_id = entry[NOTIFICATION_ID]
+                    logger.info("%s has changes: notificationId=%d", namespace, n_id)
+                    self._get_net_and_set_local(namespace, n_id, call_change=True)
+                return
+            else:
+                logger.warning("Sleep...")
+        except Exception as e:
+            logger.warning(str(e))
+
+    def _get_net_and_set_local(self, namespace, n_id, call_change=False):
+        namespace_data = self.get_json_from_net(namespace)
+        if not namespace_data:
+            return
+        namespace_data[NOTIFICATION_ID] = n_id
+        old_namespace = self._cache.get(namespace)
+        self._update_cache_and_file(namespace_data, namespace)
+        if self._change_listener is not None and call_change and old_namespace:
+            old_kv = old_namespace.get(CONFIGURATIONS)
+            new_kv = namespace_data.get(CONFIGURATIONS)
+            self._call_listener(namespace, old_kv, new_kv)
+
+    def _listener(self):
+        logger.info("start long_poll")
+        while not self._stopping:
+            self._long_poll()
+            time.sleep(self._cycle_time)
+        logger.info("stopped, long_poll")
+
+    # add the need for endorsement to the header
+    def _sign_headers(self, url):
+        headers = {}
+        if self.secret == "":
+            return headers
+        uri = url[len(self.config_url) : len(url)]
+        time_unix_now = str(int(round(time.time() * 1000)))
+        headers["Authorization"] = "Apollo " + self.app_id + ":" + signature(time_unix_now, uri, self.secret)
+        headers["Timestamp"] = time_unix_now
+        return headers
+
+    def _heart_beat(self):
+        while not self._stopping:
+            for namespace in self._notification_map:
+                self._do_heart_beat(namespace)
+            time.sleep(60 * 10)  # 10 minutes
+
+    def _do_heart_beat(self, namespace):
+        url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
+        try:
+            code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
+            if code == 200:
+                if not body:
+                    logger.error(f"_do_heart_beat load configs failed,body is {body}")
+                    return None
+                data = json.loads(body)
+                if self.last_release_key == data["releaseKey"]:
+                    return None
+                self.last_release_key = data["releaseKey"]
+                data = data["configurations"]
+                self._update_cache_and_file(data, namespace)
+            else:
+                return None
+        except Exception:
+            logger.exception("an error occurred in _do_heart_beat")
+            return None
+
+    def get_all_dicts(self, namespace):
+        namespace_data = self._cache.get(namespace)
+        if namespace_data is None:
+            net_namespace_data = self.get_json_from_net(namespace)
+            if not net_namespace_data:
+                return namespace_data
+            namespace_data = net_namespace_data.get(CONFIGURATIONS)
+            if namespace_data:
+                self._update_cache_and_file(namespace_data, namespace)
+        return namespace_data
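
Putting the pieces of the client together, a hypothetical standalone use (the URL and app id below are placeholders for illustration, not Dify defaults):

client = ApolloClient(
    config_url="http://127.0.0.1:8080",  # placeholder Apollo config service address
    app_id="dify",
    cluster="default",
    start_hot_update=False,  # same as ApolloSettingsSource: one-shot fetch, no long polling
)

# Resolution order inside get_value: memory cache -> network -> file cache -> default.
storage_type = client.get_value("STORAGE_TYPE", default_val="opendal", namespace="application")
all_settings = client.get_all_dicts("application")
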
api/configs/remote_settings_sources/apollo/python_3x.py (41 lines, new file)

@@ -0,0 +1,41 @@
+import logging
+import os
+import ssl
+import urllib.request
+from urllib import parse
+from urllib.error import HTTPError
+
+# Create an SSL context that allows for a lower level of security
+ssl_context = ssl.create_default_context()
+ssl_context.set_ciphers("HIGH:!DH:!aNULL")
+ssl_context.check_hostname = False
+ssl_context.verify_mode = ssl.CERT_NONE
+
+# Create an opener object and pass in a custom SSL context
+opener = urllib.request.build_opener(urllib.request.HTTPSHandler(context=ssl_context))
+
+urllib.request.install_opener(opener)
+
+logger = logging.getLogger(__name__)
+
+
+def http_request(url, timeout, headers={}):
+    try:
+        request = urllib.request.Request(url, headers=headers)
+        res = urllib.request.urlopen(request, timeout=timeout)
+        body = res.read().decode("utf-8")
+        return res.code, body
+    except HTTPError as e:
+        if e.code == 304:
+            logger.warning("http_request error,code is 304, maybe you should check secret")
+            return 304, None
+        logger.warning("http_request error,code is %d, msg is %s", e.code, e.msg)
+        raise e
+
+
+def url_encode(params):
+    return parse.urlencode(params)
+
+
+def makedirs_wrapper(path):
+    os.makedirs(path, exist_ok=True)
api/configs/remote_settings_sources/apollo/utils.py (51 lines, new file)

@@ -0,0 +1,51 @@
+import hashlib
+import socket
+
+from .python_3x import url_encode
+
+# define constants
+CONFIGURATIONS = "configurations"
+NOTIFICATION_ID = "notificationId"
+NAMESPACE_NAME = "namespaceName"
+
+
+# add timestamps uris and keys
+def signature(timestamp, uri, secret):
+    import base64
+    import hmac
+
+    string_to_sign = "" + timestamp + "\n" + uri
+    hmac_code = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha1).digest()
+    return base64.b64encode(hmac_code).decode()
+
+
+def url_encode_wrapper(params):
+    return url_encode(params)
+
+
+def no_key_cache_key(namespace, key):
+    return "{}{}{}".format(namespace, len(namespace), key)
+
+
+# Returns whether the obtained value is obtained, and None if it does not
+def get_value_from_dict(namespace_cache, key):
+    if namespace_cache:
+        kv_data = namespace_cache.get(CONFIGURATIONS)
+        if kv_data is None:
+            return None
+        if key in kv_data:
+            return kv_data[key]
+    return None
+
+
+def init_ip():
+    ip = ""
+    s = None
+    try:
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        s.connect(("8.8.8.8", 53))
+        ip = s.getsockname()[0]
+    finally:
+        if s:
+            s.close()
+    return ip
api/configs/remote_settings_sources/base.py (15 lines, new file)

@@ -0,0 +1,15 @@
+from collections.abc import Mapping
+from typing import Any
+
+from pydantic.fields import FieldInfo
+
+
+class RemoteSettingsSource:
+    def __init__(self, configs: Mapping[str, Any]):
+        pass
+
+    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
+        raise NotImplementedError
+
+    def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
+        return value
api/configs/remote_settings_sources/enums.py (5 lines, new file)

@@ -0,0 +1,5 @@
+from enum import StrEnum
+
+
+class RemoteSettingsSourceName(StrEnum):
+    APOLLO = "apollo"
@@ -14,11 +14,11 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
 
 
 if dify_config.ETL_TYPE == "Unstructured":
-    DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
+    DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]
     DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
     if dify_config.UNSTRUCTURED_API_URL:
         DOCUMENT_EXTENSIONS.append("ppt")
     DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
 else:
-    DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
+    DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
     DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
@@ -31,7 +31,7 @@ def admin_required(view):
         if auth_scheme != "bearer":
             raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
 
-        if dify_config.ADMIN_API_KEY != auth_token:
+        if auth_token != dify_config.ADMIN_API_KEY:
             raise Unauthorized("API key is invalid.")
 
         return view(*args, **kwargs)
@@ -65,7 +65,7 @@ class ModelConfigResource(Resource):
                     provider_type=agent_tool_entity.provider_type,
                     identity_id=f"AGENT.{app_model.id}",
                 )
-            except Exception as e:
+            except Exception:
                 continue
 
             # get decrypted parameters
@@ -97,7 +97,7 @@ class ModelConfigResource(Resource):
                     app_id=app_model.id,
                     agent_tool=agent_tool_entity,
                 )
-            except Exception as e:
+            except Exception:
                 continue
 
             manager = ToolParameterConfigurationManager(
@@ -1,4 +1,5 @@
 from flask_restful import Resource, reqparse
+from werkzeug.exceptions import BadRequest
 
 from controllers.console import api
 from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
@@ -26,7 +27,7 @@ class TraceAppConfigApi(Resource):
                 return {"has_not_configured": True}
             return trace_config
         except Exception as e:
-            raise e
+            raise BadRequest(str(e))
 
     @setup_required
     @login_required
@@ -48,7 +49,7 @@ class TraceAppConfigApi(Resource):
                 raise TracingConfigCheckError()
             return result
         except Exception as e:
-            raise e
+            raise BadRequest(str(e))
 
     @setup_required
     @login_required
@@ -68,7 +69,7 @@ class TraceAppConfigApi(Resource):
                 raise TracingConfigNotExist()
            return {"result": "success"}
         except Exception as e:
-            raise e
+            raise BadRequest(str(e))
 
     @setup_required
     @login_required
@@ -85,7 +86,7 @@ class TraceAppConfigApi(Resource):
                 raise TracingConfigNotExist()
             return {"result": "success"}
         except Exception as e:
-            raise e
+            raise BadRequest(str(e))
 
 
 api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")
@@ -6,6 +6,7 @@ from flask_restful import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 import services
+from configs import dify_config
 from controllers.console import api
 from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
 from controllers.console.app.wraps import get_app_model
@@ -426,7 +427,21 @@ class ConvertToWorkflowApi(Resource):
         }
 
 
+class WorkflowConfigApi(Resource):
+    """Resource for workflow configuration."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def get(self, app_model: App):
+        return {
+            "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
+        }
+
+
 api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
+api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config")
 api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
 api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
 api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
@@ -5,8 +5,7 @@ from typing import Optional, Union
 from controllers.console.app.error import AppNotFoundError
 from extensions.ext_database import db
 from libs.login import current_user
-from models import App
-from models.model import AppMode
+from models import App, AppMode
 
 
 def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
@@ -948,7 +948,7 @@ class DocumentRetryApi(DocumentResource):
                 if document.indexing_status == "completed":
                     raise DocumentAlreadyFinishedError()
                 retry_documents.append(document)
-            except Exception as e:
+            except Exception:
                 logging.exception(f"Failed to retry document, document id: {document_id}")
                 continue
         # retry document
@@ -1,5 +1,6 @@
 from datetime import UTC, datetime
 
+from flask import request
 from flask_login import current_user
 from flask_restful import Resource, inputs, marshal_with, reqparse
 from sqlalchemy import and_
@@ -20,8 +21,17 @@ class InstalledAppsListApi(Resource):
     @account_initialization_required
     @marshal_with(installed_app_list_fields)
     def get(self):
+        app_id = request.args.get("app_id", default=None, type=str)
         current_tenant_id = current_user.current_tenant_id
-        installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
+
+        if app_id:
+            installed_apps = (
+                db.session.query(InstalledApp)
+                .filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
+                .all()
+            )
+        else:
+            installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
 
         current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
         installed_apps = [
@@ -4,6 +4,7 @@ from flask_restful import Resource, fields, marshal_with, reqparse
 from constants.languages import languages
 from controllers.console import api
 from controllers.console.wraps import account_initialization_required
+from libs.helper import AppIconUrlField
 from libs.login import login_required
 from services.recommended_app_service import RecommendedAppService
 
@@ -12,6 +13,8 @@ app_fields = {
     "name": fields.String,
     "mode": fields.String,
     "icon": fields.String,
+    "icon_type": fields.String,
+    "icon_url": AppIconUrlField,
     "icon_background": fields.String,
 }
 
@@ -1,6 +1,7 @@
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal_with
+from werkzeug.exceptions import Forbidden
 
 import services
 from configs import dify_config
@@ -58,6 +59,9 @@ class FileApi(Resource):
         if not file.filename:
             raise FilenameNotExistsError
 
+        if source == "datasets" and not current_user.is_dataset_editor:
+            raise Forbidden()
+
         if source not in ("datasets", None):
             source = None
 
@@ -368,6 +368,7 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||
description=args["description"],
|
||||
parameters=args["parameters"],
|
||||
privacy_policy=args["privacy_policy"],
|
||||
labels=args["labels"],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -36,6 +36,29 @@ logger = logging.getLogger(__name__)
|
||||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
_dialogue_count: int
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
@@ -44,7 +67,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@@ -19,8 +19,10 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
@@ -31,6 +33,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
@@ -127,7 +130,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
self._conversation_name_generate_thread = None
|
||||
self._recorded_files: list[Mapping[str, Any]] = []
|
||||
self.total_tokens: int = 0
|
||||
|
||||
def process(self):
|
||||
"""
|
||||
@@ -318,7 +320,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
@@ -327,6 +329,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(
|
||||
event,
|
||||
QueueNodeRetryEvent,
|
||||
):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||
workflow_run=workflow_run, event=event
|
||||
)
|
||||
|
||||
response = self._workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
@@ -361,8 +379,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
# FIXME for issue #11221 quick fix maybe have a better solution
|
||||
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
@@ -376,7 +392,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
conversation_id=self._conversation.id,
|
||||
@@ -387,6 +403,29 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
|
||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_partial_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
exceptions_count=event.exceptions_count,
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
|
||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
if not workflow_run:
|
||||
@@ -404,6 +443,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
error=event.error,
|
||||
conversation_id=self._conversation.id,
|
||||
trace_manager=trace_manager,
|
||||
exceptions_count=event.exceptions_count,
|
||||
)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -28,6 +28,39 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
@@ -36,7 +69,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
for resource in metadata["retriever_resources"]:
|
||||
updated_resources.append(
|
||||
{
|
||||
"segment_id": resource["segment_id"],
|
||||
"segment_id": resource.get("segment_id", ""),
|
||||
"position": resource["position"],
|
||||
"document_name": resource["document_name"],
|
||||
"score": resource["score"],
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
@@ -34,9 +34,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@@ -44,19 +44,29 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
streaming: Literal[False],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[dict, Generator[str, None, None]]:
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
@@ -34,9 +34,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@@ -44,14 +44,29 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
streaming: Literal[False],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True
|
||||
) -> Union[dict, Generator[str, None, None]]:
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -30,6 +30,35 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
@@ -41,7 +70,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
):
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
# parse files
|
||||
|
||||
@@ -6,6 +6,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
@@ -34,7 +35,8 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueWorkflowSucceededEvent
|
||||
| QueueWorkflowFailedEvent,
|
||||
| QueueWorkflowFailedEvent
|
||||
| QueueWorkflowPartialSuccessEvent,
|
||||
):
|
||||
self.stop_listen()
|
||||
|
||||
|
||||
@@ -15,8 +15,10 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
@@ -26,6 +28,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
@@ -106,7 +109,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
self.total_tokens: int = 0
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
@@ -258,29 +260,44 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
|
||||
|
||||
response = self._workflow_node_start_to_stream_response(
|
||||
node_start_response = self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
if node_start_response:
|
||||
yield node_start_response
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
node_success_response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
|
||||
if node_success_response:
|
||||
yield node_success_response
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
node_failed_response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
if node_failed_response:
|
||||
yield node_failed_response
|
||||
elif isinstance(
|
||||
event,
|
||||
QueueNodeRetryEvent,
|
||||
):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||
workflow_run=workflow_run, event=event
|
||||
)
|
||||
|
||||
response = self._workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
@@ -288,6 +305,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
if response:
|
||||
yield response
|
||||
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
@@ -320,8 +338,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
# FIXME for issue #11221 quick fix maybe have a better solution
|
||||
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
@@ -335,7 +351,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
conversation_id=None,
|
||||
@@ -345,6 +361,30 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(workflow_run)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_partial_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
exceptions_count=event.exceptions_count,
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(workflow_run)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
@@ -354,7 +394,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@@ -366,6 +405,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
|
||||
@@ -8,8 +8,10 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
@@ -18,6 +20,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
@@ -25,6 +28,7 @@ from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
@@ -32,8 +36,10 @@ from core.workflow.graph_engine.entities.event import (
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeInIterationFailedEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
@@ -176,8 +182,12 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
)
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
|
||||
)
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeStartedEvent(
|
||||
@@ -253,6 +263,36 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunExceptionEvent):
|
||||
self._publish_event(
|
||||
QueueNodeExceptionEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeInIterationFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInIterationFailedEvent(
|
||||
@@ -382,6 +422,36 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
self._publish_event(
|
||||
QueueNodeRetryEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.error,
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
retry_index=event.retry_index,
|
||||
start_index=event.start_index,
|
||||
)
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
@@ -25,12 +25,14 @@ class QueueEvent(StrEnum):
|
||||
WORKFLOW_STARTED = "workflow_started"
|
||||
WORKFLOW_SUCCEEDED = "workflow_succeeded"
|
||||
WORKFLOW_FAILED = "workflow_failed"
|
||||
WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"
|
||||
ITERATION_START = "iteration_start"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_SUCCEEDED = "node_succeeded"
|
||||
NODE_FAILED = "node_failed"
|
||||
NODE_EXCEPTION = "node_exception"
|
||||
RETRIEVER_RESOURCES = "retriever_resources"
|
||||
ANNOTATION_REPLY = "annotation_reply"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
@@ -41,6 +43,7 @@ class QueueEvent(StrEnum):
|
||||
ERROR = "error"
|
||||
PING = "ping"
|
||||
STOP = "stop"
|
||||
RETRY = "retry"
|
||||
|
||||
|
||||
class AppQueueEvent(BaseModel):
|
||||
@@ -113,18 +116,6 @@ class QueueIterationNextEvent(AppQueueEvent):
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
duration: Optional[float] = None
|
||||
|
||||
@field_validator("output", mode="before")
|
||||
@classmethod
|
||||
def set_output(cls, v):
|
||||
"""
|
||||
Set output
|
||||
"""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, int | float | str | bool | dict | list):
|
||||
return v
|
||||
raise ValueError("output must be a valid type")
|
||||
|
||||
|
||||
class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
"""
|
||||
@@ -249,6 +240,17 @@ class QueueWorkflowFailedEvent(AppQueueEvent):
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_FAILED
|
||||
error: str
|
||||
exceptions_count: int
|
||||
|
||||
|
||||
class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueWorkflowPartialSuccessEvent entity
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
|
||||
exceptions_count: int
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class QueueNodeStartedEvent(AppQueueEvent):
|
||||
@@ -312,6 +314,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
iteration_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class QueueNodeRetryEvent(AppQueueEvent):
|
||||
"""QueueNodeRetryEvent entity"""
|
||||
|
||||
event: QueueEvent = QueueEvent.RETRY
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
process_data: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
retry_index: int # retry index
|
||||
start_index: int # start index
|
||||
|
||||
|
||||
class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInIterationFailedEvent entity
|
||||
@@ -343,6 +376,37 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeExceptionEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_EXCEPTION
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
process_data: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeFailedEvent entity
|
||||
|
||||
@@ -52,6 +52,7 @@ class StreamEvent(Enum):
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
NODE_RETRY = "node_retry"
|
||||
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
|
||||
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
|
||||
ITERATION_STARTED = "iteration_started"
|
||||
@@ -213,6 +214,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
created_by: Optional[dict] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
exceptions_count: Optional[int] = 0
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
|
||||
@@ -341,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
}
|
||||
|
||||
|
||||
class NodeRetryStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeRetryStreamResponse entity
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
id: str
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
index: int
|
||||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Optional[dict] = None
|
||||
process_data: Optional[dict] = None
|
||||
outputs: Optional[dict] = None
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Optional[dict] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
retry_index: int = 0
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_RETRY
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
def to_ignore_detail_dict(self):
|
||||
return {
|
||||
"event": self.event.value,
|
||||
"task_id": self.task_id,
|
||||
"workflow_run_id": self.workflow_run_id,
|
||||
"data": {
|
||||
"id": self.data.id,
|
||||
"node_id": self.data.node_id,
|
||||
"node_type": self.data.node_type,
|
||||
"title": self.data.title,
|
||||
"index": self.data.index,
|
||||
"predecessor_node_id": self.data.predecessor_node_id,
|
||||
"inputs": None,
|
||||
"process_data": None,
|
||||
"outputs": None,
|
||||
"status": self.data.status,
|
||||
"error": None,
|
||||
"elapsed_time": self.data.elapsed_time,
|
||||
"execution_metadata": None,
|
||||
"created_at": self.data.created_at,
|
||||
"finished_at": self.data.finished_at,
|
||||
"files": [],
|
||||
"parallel_id": self.data.parallel_id,
|
||||
"parallel_start_node_id": self.data.parallel_start_node_id,
|
||||
"parent_parallel_id": self.data.parent_parallel_id,
|
||||
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||
"iteration_id": self.data.iteration_id,
|
||||
"retry_index": self.data.retry_index,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ParallelBranchStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchStartStreamResponse entity
|
||||
|
||||
@@ -110,7 +110,7 @@ class RateLimitGenerator:
|
||||
raise StopIteration
|
||||
try:
|
||||
return next(self.generator)
|
||||
except StopIteration:
|
||||
except Exception:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
|
||||
@@ -12,8 +12,10 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
@@ -25,6 +27,7 @@ from core.app.entities.task_entities import (
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
ParallelBranchFinishedStreamResponse,
|
||||
ParallelBranchStartStreamResponse,
|
||||
@@ -164,6 +167,55 @@ class WorkflowCycleManage:
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_partial_success(
|
||||
self,
|
||||
workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
exceptions_count: int = 0,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run partial success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)

outputs = WorkflowEntry.handle_special_values(outputs)

workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
db.session.commit()
db.session.refresh(workflow_run)

if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)

db.session.close()

return workflow_run

def _handle_workflow_run_failed(
self,
workflow_run: WorkflowRun,
@@ -174,6 +226,7 @@ class WorkflowCycleManage:
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
exceptions_count: int = 0,
) -> WorkflowRun:
"""
Workflow run failed
@@ -193,7 +246,7 @@ class WorkflowCycleManage:
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)

workflow_run.exceptions_count = exceptions_count
db.session.commit()

running_workflow_node_executions = (
@@ -220,9 +273,9 @@ class WorkflowCycleManage:

db.session.close()

with Session(db.engine, expire_on_commit=False) as session:
session.add(workflow_run)
session.refresh(workflow_run)
# with Session(db.engine, expire_on_commit=False) as session:
# session.add(workflow_run)
# session.refresh(workflow_run)

if trace_manager:
trace_manager.add_trace_task(
@@ -318,7 +371,7 @@ class WorkflowCycleManage:
return workflow_node_execution

def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
@@ -337,7 +390,11 @@ class WorkflowCycleManage:
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.status: (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
),
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
@@ -351,8 +408,11 @@ class WorkflowCycleManage:
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)

workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.status = (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
)
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
@@ -365,6 +425,52 @@ class WorkflowCycleManage:

return workflow_node_execution

def _handle_workflow_node_execution_retried(
self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution retried
:param event: queue node retry event
:return:
"""
|
||||
created_at = event.start_at
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - created_at).total_seconds()
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.created_at = created_at
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = json.dumps(
|
||||
{
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
}
|
||||
)
|
||||
workflow_node_execution.index = event.start_index
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
#################################################
|
||||
# to stream responses #
|
||||
#################################################
|
||||
@@ -433,6 +539,7 @@ class WorkflowCycleManage:
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
|
||||
exceptions_count=workflow_run.exceptions_count,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -483,7 +590,10 @@ class WorkflowCycleManage:
|
||||
|
||||
def _workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
|
||||
event: QueueNodeSucceededEvent
|
||||
| QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
@@ -525,6 +635,51 @@ class WorkflowCycleManage:
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_node_retry_to_stream_response(
|
||||
self,
|
||||
event: QueueNodeRetryEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeRetryStreamResponse]:
"""
Workflow node retry to stream response.
:param event: queue node retry event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None

return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
process_data=workflow_node_execution.process_data_dict,
outputs=workflow_node_execution.outputs_dict,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
),
)

def _workflow_parallel_branch_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:

@@ -1,15 +1,14 @@
import base64

from configs import dify_config
from core.file import file_repository
from core.helper import ssrf_proxy
from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
MultiModalPromptMessageContent,
VideoPromptMessageContent,
)
from extensions.ext_database import db
from extensions.ext_storage import storage

from . import helpers
@@ -41,53 +40,42 @@ def to_prompt_message_content(
/,
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
match f.type:
case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
) -> MultiModalPromptMessageContent:
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")

return ImagePromptMessageContent(data=data, detail=image_detail_config)
case FileType.AUDIO:
encoded_string = _get_encoded_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _get_encoded_string(f)
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return DocumentPromptMessageContent(
encode_format="base64",
mime_type=f.mime_type,
data=data,
)
case _:
raise ValueError(f"file type {f.type} is not supported")
params = {
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

prompt_class_map = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}

try:
return prompt_class_map[f.type](**params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")


def download(f: File, /):
if f.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
return _download_file_content(tool_file.file_key)
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
return _download_file_content(upload_file.key)
# remote file
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
return _download_file_content(f._storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
raise ValueError(f"unsupported transfer method: {f.transfer_method}")


def _download_file_content(path: str, /):
@@ -118,21 +106,14 @@ def _get_encoded_string(f: File, /):
response.raise_for_status()
data = response.content
case FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
data = _download_file_content(upload_file.key)
data = _download_file_content(f._storage_key)
case FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
data = _download_file_content(tool_file.file_key)
data = _download_file_content(f._storage_key)

encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string


def _to_base64_data_string(f: File, /):
encoded_string = _get_encoded_string(f)
return f"data:{f.mime_type};base64,{encoded_string}"


def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
@@ -141,7 +122,7 @@ def _to_url(f: File, /):
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
if f.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=f.related_id)
return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
# add sign url
if f.related_id is None or f.extension is None:

@@ -1,32 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session

from models import ToolFile, UploadFile

from .models import File


def get_upload_file(*, session: Session, file: File):
if file.related_id is None:
raise ValueError("Missing file related_id")
stmt = select(UploadFile).filter(
UploadFile.id == file.related_id,
UploadFile.tenant_id == file.tenant_id,
)
record = session.scalar(stmt)
if not record:
raise ValueError(f"upload file {file.related_id} not found")
return record


def get_tool_file(*, session: Session, file: File):
if file.related_id is None:
raise ValueError("Missing file related_id")
stmt = select(ToolFile).filter(
ToolFile.id == file.related_id,
ToolFile.tenant_id == file.tenant_id,
)
record = session.scalar(stmt)
if not record:
raise ValueError(f"tool file {file.related_id} not found")
return record
@@ -47,6 +47,38 @@ class File(BaseModel):
mime_type: Optional[str] = None
size: int = -1

# Those properties are private, should not be exposed to the outside.
_storage_key: str

def __init__(
self,
*,
id: Optional[str] = None,
tenant_id: str,
type: FileType,
transfer_method: FileTransferMethod,
remote_url: Optional[str] = None,
related_id: Optional[str] = None,
filename: Optional[str] = None,
extension: Optional[str] = None,
mime_type: Optional[str] = None,
size: int = -1,
storage_key: str,
):
super().__init__(
id=id,
tenant_id=tenant_id,
type=type,
transfer_method=transfer_method,
remote_url=remote_url,
related_id=related_id,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
)
self._storage_key = storage_key

def to_dict(self) -> Mapping[str, str | int | None]:
data = self.model_dump(mode="json")
return {

@@ -1,6 +1,5 @@
import base64

from extensions.ext_database import db
from libs import rsa


@@ -14,6 +13,7 @@ def obfuscated_token(token: str):

def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
from models.engine import db

if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")

@@ -24,6 +24,12 @@ BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]


class MaxRetriesExceededError(Exception):
"""Raised when the maximum number of retries is exceeded."""

pass


def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@@ -39,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
)

retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:
@@ -64,7 +69,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if retries <= max_retries:
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))

raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}")
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")


def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

@@ -4,6 +4,7 @@ from .message_entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
@@ -27,6 +28,7 @@ __all__ = [
"LLMResultChunkDelta",
"LLMUsage",
"ModelPropertyKey",
"MultiModalPromptMessageContent",
"PromptMessage",
|
||||
"PromptMessage",
|
||||
"PromptMessageContent",
|
||||
|
||||
@@ -1,9 +1,9 @@
from abc import ABC
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Literal, Optional
from typing import Optional

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, computed_field, field_validator


class PromptMessageRole(Enum):
@@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
    """

    type: PromptMessageContentType
    data: str


class TextPromptMessageContent(PromptMessageContent):
@@ -76,21 +75,35 @@ class TextPromptMessageContent(PromptMessageContent):
    """

    type: PromptMessageContentType = PromptMessageContentType.TEXT
    data: str


class VideoPromptMessageContent(PromptMessageContent):
class MultiModalPromptMessageContent(PromptMessageContent):
    """
    Model class for multi-modal prompt message content.
    """

    type: PromptMessageContentType
    format: str = Field(default=..., description="the format of multi-modal file")
    base64_data: str = Field(default="", description="the base64 data of multi-modal file")
    url: str = Field(default="", description="the url of multi-modal file")
    mime_type: str = Field(default=..., description="the mime type of multi-modal file")

    @computed_field(return_type=str)
    @property
    def data(self):
        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"


class VideoPromptMessageContent(MultiModalPromptMessageContent):
    type: PromptMessageContentType = PromptMessageContentType.VIDEO
    data: str = Field(..., description="Base64 encoded video data")
    format: str = Field(..., description="Video format")


class AudioPromptMessageContent(PromptMessageContent):
class AudioPromptMessageContent(MultiModalPromptMessageContent):
    type: PromptMessageContentType = PromptMessageContentType.AUDIO
    data: str = Field(..., description="Base64 encoded audio data")
    format: str = Field(..., description="Audio format")


class ImagePromptMessageContent(PromptMessageContent):
class ImagePromptMessageContent(MultiModalPromptMessageContent):
    """
    Model class for image prompt message content.
    """
@@ -103,11 +116,8 @@ class ImagePromptMessageContent(PromptMessageContent):
    detail: DETAIL = DETAIL.LOW


class DocumentPromptMessageContent(PromptMessageContent):
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
    encode_format: Literal["base64"]
    mime_type: str
    data: str


class PromptMessage(ABC, BaseModel):
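An illustrative sketch (not from the commit) of the new computed data field: subclasses no longer store data directly; it is taken from url when present, otherwise rendered as a data: URI from mime_type and base64_data:

    from core.model_runtime.entities.message_entities import ImagePromptMessageContent  # path per the import hunk above

    inline = ImagePromptMessageContent(format="png", mime_type="image/png", base64_data="aGVsbG8=")
    print(inline.data)   # data:image/png;base64,aGVsbG8=

    remote = ImagePromptMessageContent(format="png", mime_type="image/png", url="https://example.com/a.png")
    print(remote.data)   # https://example.com/a.png (url wins over base64_data)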
@@ -1,5 +1,4 @@
import base64
import io
import json
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
@@ -18,7 +17,6 @@ from anthropic.types import (
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout
from PIL import Image

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities import (
@@ -498,22 +496,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                        sub_messages.append(sub_message_dict)
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        if not message_content.data.startswith("data:"):
                        if not message_content.base64_data:
                            # fetch image data from url
                            try:
                                image_content = requests.get(message_content.data).content
                                with Image.open(io.BytesIO(image_content)) as img:
                                    mime_type = f"image/{img.format.lower()}"
                                image_content = requests.get(message_content.url).content
                                base64_data = base64.b64encode(image_content).decode("utf-8")
                            except Exception as ex:
                                raise ValueError(
                                    f"Failed to fetch image data from url {message_content.data}, {ex}"
                                )
                        else:
                            data_split = message_content.data.split(";base64,")
                            mime_type = data_split[0].replace("data:", "")
                            base64_data = data_split[1]
                            base64_data = message_content.base64_data

                        mime_type = message_content.mime_type
                        if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                            raise ValueError(
                                f"Unsupported image type {mime_type}, "
@@ -534,7 +529,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                    sub_message_dict = {
                        "type": "document",
                        "source": {
                            "type": message_content.encode_format,
                            "type": "base64",
                            "media_type": message_content.mime_type,
                            "data": message_content.data,
                        },
@@ -819,6 +819,82 @@ LLM_BASE_MODELS = [
            ),
        ),
    ),
    AzureBaseModel(
        base_model_name="gpt-4o-2024-11-20",
        entity=AIModelEntity(
            model="fake-deployment-name",
            label=I18nObject(
                en_US="fake-deployment-name-label",
            ),
            model_type=ModelType.LLM,
            features=[
                ModelFeature.AGENT_THOUGHT,
                ModelFeature.VISION,
                ModelFeature.MULTI_TOOL_CALL,
                ModelFeature.STREAM_TOOL_CALL,
            ],
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.MODE: LLMMode.CHAT.value,
                ModelPropertyKey.CONTEXT_SIZE: 128000,
            },
            parameter_rules=[
                ParameterRule(
                    name="temperature",
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
                ),
                ParameterRule(
                    name="top_p",
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
                ),
                ParameterRule(
                    name="presence_penalty",
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
                ),
                ParameterRule(
                    name="frequency_penalty",
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
                ),
                _get_max_tokens(default=512, min_val=1, max_val=16384),
                ParameterRule(
                    name="seed",
                    label=I18nObject(zh_Hans="种子", en_US="Seed"),
                    type="int",
                    help=AZURE_DEFAULT_PARAM_SEED_HELP,
                    required=False,
                    precision=2,
                    min=0,
                    max=1,
                ),
                ParameterRule(
                    name="response_format",
                    label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
                    type="string",
                    help=I18nObject(
                        zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
                    ),
                    required=False,
                    options=["text", "json_object", "json_schema"],
                ),
                ParameterRule(
                    name="json_schema",
                    label=I18nObject(en_US="JSON Schema"),
                    type="text",
                    help=I18nObject(
                        zh_Hans="设置返回的json schema,llm将按照它返回",
                        en_US="Set a response json schema will ensure LLM to adhere it.",
                    ),
                    required=False,
                ),
            ],
            pricing=PriceConfig(
                input=5.00,
                output=15.00,
                unit=0.000001,
                currency="USD",
            ),
        ),
    ),
    AzureBaseModel(
        base_model_name="gpt-4-turbo",
        entity=AIModelEntity(
@@ -86,6 +86,9 @@ model_credential_schema:
        - label:
            en_US: '2024-06-01'
          value: '2024-06-01'
        - label:
            en_US: '2024-10-21'
          value: '2024-10-21'
      placeholder:
        zh_Hans: 在此选择您的 API 版本
        en_US: Select your API Version here
@@ -168,6 +171,12 @@ model_credential_schema:
          show_on:
            - variable: __model_type
              value: llm
        - label:
            en_US: gpt-4o-2024-11-20
          value: gpt-4o-2024-11-20
          show_on:
            - variable: __model_type
              value: llm
        - label:
            en_US: gpt-4-turbo
          value: gpt-4-turbo
@@ -92,7 +92,10 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
                average = embeddings_batch[0]
            else:
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()
            embedding = (average / np.linalg.norm(average)).tolist()
            if np.isnan(embedding).any():
                raise ValueError("Normalized embedding is nan please try again")
            embeddings[i] = embedding

        # calc usage
        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
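A short sketch (not from the commit) of the failure mode this guard catches: an all-zero average vector has norm 0, and 0/0 yields nan with only a RuntimeWarning, so without the check a silent nan embedding would be stored:

    import numpy as np

    average = np.zeros(3)
    embedding = (average / np.linalg.norm(average)).tolist()  # [nan, nan, nan]
    assert np.isnan(embedding).any()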
@@ -10,6 +10,7 @@ from core.model_runtime.entities.llm_entities import (
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
@@ -105,7 +106,11 @@ class BaichuanLanguageModel(LargeLanguageModel):
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                raise ValueError("User message content must be str")
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_dict = {"role": "user", "content": message_content.data}
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        raise ValueError("Content object type not support image_url")
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
@@ -0,0 +1,29 @@
from collections.abc import Mapping

import boto3
from botocore.config import Config

from core.model_runtime.errors.invoke import InvokeBadRequestError


def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
    region_name = credentials.get("aws_region")
    if not region_name:
        raise InvokeBadRequestError("aws_region is required")
    client_config = Config(region_name=region_name)
    aws_access_key_id = credentials.get("aws_access_key_id")
    aws_secret_access_key = credentials.get("aws_secret_access_key")

    if aws_access_key_id and aws_secret_access_key:
        # use AK/SK credentials to call bedrock
        client = boto3.client(
            service_name=service_name,
            config=client_config,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )
    else:
        # use an IAM role, without AK/SK, to call
        client = boto3.client(service_name=service_name, config=client_config)

    return client
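A usage sketch (values are placeholders): when both keys are present the helper signs requests with them; otherwise boto3 resolves credentials through its default chain (environment variables, shared config, or an attached IAM role):

    client = get_bedrock_client(
        "bedrock-runtime",
        {"aws_region": "us-east-1"},  # no AK/SK here, so the default credential chain applies
    )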
@@ -6,6 +6,7 @@ features:
  - agent-thought
  - tool-call
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 300000

@@ -6,6 +6,7 @@ features:
  - agent-thought
  - tool-call
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 300000
@@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client

logger = logging.getLogger(__name__)
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
@@ -173,13 +174,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        :param stream: is stream response
        :return: full response or stream response chunk generator result
        """
        bedrock_client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=credentials.get("aws_access_key_id"),
            aws_secret_access_key=credentials.get("aws_secret_access_key"),
            region_name=credentials["aws_region"],
        )

        bedrock_client = get_bedrock_client("bedrock-runtime", credentials)
        system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
        inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
@@ -6,6 +6,7 @@ features:
  - agent-thought
  - tool-call
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 300000

@@ -6,6 +6,7 @@ features:
  - agent-thought
  - tool-call
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 300000
@@ -1,8 +1,5 @@
from typing import Optional

import boto3
from botocore.config import Config

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
@@ -14,6 +11,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client


class BedrockRerankModel(RerankModel):
@@ -48,13 +46,7 @@ class BedrockRerankModel(RerankModel):
            return RerankResult(model=model, docs=docs)

        # initialize client
        client_config = Config(region_name=credentials["aws_region"])
        bedrock_runtime = boto3.client(
            service_name="bedrock-agent-runtime",
            config=client_config,
            aws_access_key_id=credentials.get("aws_access_key_id", ""),
            aws_secret_access_key=credentials.get("aws_secret_access_key"),
        )
        bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials)
        queries = [{"type": "TEXT", "textQuery": {"text": query}}]
        text_sources = []
        for text in docs:
@@ -70,7 +62,10 @@ class BedrockRerankModel(RerankModel):
            }
        )
        modelId = model
        region = credentials["aws_region"]
        region = credentials.get("aws_region")
        # region is a required field
        if not region:
            raise InvokeBadRequestError("aws_region is required in credentials")
        model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
        rerankingConfiguration = {
            "type": "BEDROCK_RERANKING_MODEL",
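For illustration only (the values below are hypothetical), the ARN assembled by the f-string above:

    region = "us-west-2"
    modelId = "cohere.rerank-v3-5:0"
    model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
    # arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0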
@@ -3,8 +3,6 @@ import logging
import time
from typing import Optional

import boto3
from botocore.config import Config
from botocore.exceptions import (
    ClientError,
    EndpointConnectionError,
@@ -25,6 +23,7 @@ from core.model_runtime.errors.invoke import (
    InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client

logger = logging.getLogger(__name__)

@@ -48,14 +47,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
        :param input_type: input type
        :return: embeddings result
        """
        client_config = Config(region_name=credentials["aws_region"])

        bedrock_runtime = boto3.client(
            service_name="bedrock-runtime",
            config=client_config,
            aws_access_key_id=credentials.get("aws_access_key_id"),
            aws_secret_access_key=credentials.get("aws_secret_access_key"),
        )
        bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials)

        embeddings = []
        token_usage = 0
@@ -2,3 +2,4 @@
- rerank-english-v3.0
- rerank-multilingual-v2.0
- rerank-multilingual-v3.0
- rerank-v3.5

@@ -0,0 +1,4 @@
model: rerank-v3.5
model_type: rerank
model_properties:
  context_size: 5120
@@ -88,7 +88,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
                average = embeddings_batch[0]
            else:
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()
            embedding = (average / np.linalg.norm(average)).tolist()
            if np.isnan(embedding).any():
                raise ValueError("Normalized embedding is nan please try again")
            embeddings[i] = embedding

        # calc usage
        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
@@ -24,6 +24,9 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        # {"response_format": "xx"} needs to be converted to {"response_format": {"type": "xx"}}
        if "response_format" in model_parameters:
            model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

    def validate_credentials(self, model: str, credentials: dict) -> None:
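A worked example (not from the commit) of that rewrite, which adapts Dify's flat string to the nested object the OpenAI-compatible endpoint expects:

    model_parameters = {"response_format": "json_object"}
    if "response_format" in model_parameters:
        model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
    assert model_parameters == {"response_format": {"type": "json_object"}}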
@@ -0,0 +1,93 @@
model: InternVL2-8B
label:
  en_US: InternVL2-8B
model_type: llm
features:
  - vision
  - agent-thought
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    label:
      en_US: "Max Tokens"
      zh_Hans: "最大Token数"
    type: int
    default: 512
    min: 1
    required: true
    help:
      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"

  - name: temperature
    use_template: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    use_template: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_k
    use_template: top_k
    label:
      en_US: "Top K"
      zh_Hans: "Top K"
    type: int
    default: 50
    min: 0
    max: 100
    required: true
    help:
      en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
      zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: -1.0
    max: 1.0
    precision: 1
    required: false
    help:
      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"
@@ -0,0 +1,93 @@
model: InternVL2.5-26B
label:
  en_US: InternVL2.5-26B
model_type: llm
features:
  - vision
  - agent-thought
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    label:
      en_US: "Max Tokens"
      zh_Hans: "最大Token数"
    type: int
    default: 512
    min: 1
    required: true
    help:
      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"

  - name: temperature
    use_template: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    use_template: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_k
    use_template: top_k
    label:
      en_US: "Top K"
      zh_Hans: "Top K"
    type: int
    default: 50
    min: 0
    max: 100
    required: true
    help:
      en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
      zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: -1.0
    max: 1.0
    precision: 1
    required: false
    help:
      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"
@@ -6,3 +6,5 @@
- deepseek-coder-33B-instruct-chat
- deepseek-coder-33B-instruct-completions
- codegeex4-all-9b
- InternVL2.5-26B
- InternVL2-8B
@@ -29,18 +29,26 @@ class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials, model, model_parameters)
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
        return super()._invoke(
            GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model),
            credentials,
            prompt_messages,
            model_parameters,
            tools,
            stop,
            stream,
            user,
        )

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials, None)
        super().validate_credentials(model, credentials)
        self._add_custom_parameters(credentials, model, None)
        super().validate_credentials(GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model), credentials)

    def _add_custom_parameters(self, credentials: dict, model: Optional[str]) -> None:
    def _add_custom_parameters(self, credentials: dict, model: Optional[str], model_parameters: dict) -> None:
        if model is None:
            model = "Qwen2-72B-Instruct"

        model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model)
        credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/"
        credentials["endpoint_url"] = "https://ai.gitee.com/v1"
        if model.endswith("completions"):
            credentials["mode"] = LLMMode.COMPLETION.value
        else:
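An illustrative sketch of the MODEL_TO_IDENTITY lookup used above (the mapping entry shown is hypothetical; unknown model names simply fall through unchanged):

    MODEL_TO_IDENTITY = {"codegeex4-all-9b": "codegeex4-all-9b"}  # hypothetical entry
    model = "Qwen2-72B-Instruct"
    identity = MODEL_TO_IDENTITY.get(model, model)  # no entry, so the name passes through as-is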
@@ -1,3 +1,5 @@
- gemini-2.0-flash-exp
- gemini-2.0-flash-thinking-exp-1219
- gemini-1.5-pro
- gemini-1.5-pro-latest
- gemini-1.5-pro-001
@@ -11,6 +13,8 @@
- gemini-1.5-flash-exp-0827
- gemini-1.5-flash-8b-exp-0827
- gemini-1.5-flash-8b-exp-0924
- gemini-exp-1206
- gemini-exp-1121
- gemini-exp-1114
- gemini-pro
- gemini-pro-vision
@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152

@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152

@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152

@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152

@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152

@@ -8,6 +8,8 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152
@@ -0,0 +1,41 @@
model: gemini-2.0-flash-exp
label:
  en_US: Gemini 2.0 Flash Exp
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-1219
label:
  en_US: Gemini 2.0 Flash Thinking Exp 1219
model_type: llm
features:
  - agent-thought
  - vision
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 32767
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
Some files were not shown because too many files have changed in this diff.